에러 삽질

model load missing keys and unexpected keys

ylab 2022. 4. 2. 14:28

모델로드시 생긴 에러

파이토치 만드신분들 존경하며, 차근차근 에러를 고쳐보자

무슨뜻일까? 

네트워크는 그대로인데 학습시킨 weight 와 bias를 각 레이어에 맞게 불러오지 못했다는 소리

 

 

RuntimeError: Error(s) in loading state_dict for DataParallel:
        Missing key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.conv2.weight", "module.bn2.weight", "module.bn2.bias", "module.bn2.runnin
g_mean", "module.bn2.running_var", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.conv2.weight", "module
.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", 

#생략


        Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "conv2.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.
num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "la
yer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias",

#생략

운이 좋게도(학습단계에서 생긴 문제는 아니라는 소리) 학습은 되었고, 에러를 침착하게 보면 요소들은 다 있으나, module. 차이가 있는 것을 알 수 있다.

 

 

<에러 발생한 코드>

    if "state_dict" in state_dict.keys():
       	state_dict = state_dict["state_dict"]
       	model.load_state_dict(state_dict.module.state_dict())
    else:
      	model.module.load_state_dict(state_dict)

보면 모델에서 부를 때에 module을 거쳐서 부르는 것을 알 수 있다.

그러면 module을 안거쳐서 부르면 될것 아닌가?! (그러면 module도 딸려 올것이기 때문에)

 

    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]
        model.load_state_dict(state_dict.state_dict())
    else:
        model.load_state_dict(state_dict)

해결!

 

비슷한 문제를 다르게 해결한 분도 계시다. 같이 참고하면 좋을 것 같다.

https://mostar39.tistory.com/30?category=915889 

 

[Python Torch] 얼렁뚱땅 load_state_dict 에러 잡기

많은 분들이 학습 시킬 때, 특정 epoch마다 혹은 특정 iter마다 모델을 저장하는데 저장한 모델을 다시 불러올 때가 있죠.... 그 때 종종 에러가 뜨는거에요 .. 슬프게 ... ㅜ 저는 CcGAN(https://github.com/U

mostar39.tistory.com