에러 삽질
model load missing keys and unexpected keys
ylab
2022. 4. 2. 14:28
모델로드시 생긴 에러
파이토치 만드신분들 존경하며, 차근차근 에러를 고쳐보자
무슨뜻일까?
네트워크는 그대로인데 학습시킨 weight 와 bias를 각 레이어에 맞게 불러오지 못했다는 소리
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.conv2.weight", "module.bn2.weight", "module.bn2.bias", "module.bn2.runnin
g_mean", "module.bn2.running_var", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.conv2.weight", "module
.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias",
#생략
Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "conv2.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.
num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "la
yer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias",
#생략
운이 좋게도(학습단계에서 생긴 문제는 아니라는 소리) 학습은 되었고, 에러를 침착하게 보면 요소들은 다 있으나, module. 차이가 있는 것을 알 수 있다.
<에러 발생한 코드>
if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict.module.state_dict())
else:
model.module.load_state_dict(state_dict)
보면 모델에서 부를 때에 module을 거쳐서 부르는 것을 알 수 있다.
그러면 module을 안거쳐서 부르면 될것 아닌가?! (그러면 module도 딸려 올것이기 때문에)
if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict.state_dict())
else:
model.load_state_dict(state_dict)
해결!
비슷한 문제를 다르게 해결한 분도 계시다. 같이 참고하면 좋을 것 같다.
https://mostar39.tistory.com/30?category=915889
[Python Torch] 얼렁뚱땅 load_state_dict 에러 잡기
많은 분들이 학습 시킬 때, 특정 epoch마다 혹은 특정 iter마다 모델을 저장하는데 저장한 모델을 다시 불러올 때가 있죠.... 그 때 종종 에러가 뜨는거에요 .. 슬프게 ... ㅜ 저는 CcGAN(https://github.com/U
mostar39.tistory.com