-
model load missing keys and unexpected keys에러 삽질 2022. 4. 2. 14:28
모델로드시 생긴 에러
파이토치 만드신분들 존경하며, 차근차근 에러를 고쳐보자
무슨뜻일까?
네트워크는 그대로인데 학습시킨 weight 와 bias를 각 레이어에 맞게 불러오지 못했다는 소리
RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.conv2.weight", "module.bn2.weight", "module.bn2.bias", "module.bn2.runnin g_mean", "module.bn2.running_var", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.conv2.weight", "module .layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", #생략 Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "conv2.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2. num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "la yer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", #생략
운이 좋게도(학습단계에서 생긴 문제는 아니라는 소리) 학습은 되었고, 에러를 침착하게 보면 요소들은 다 있으나, module. 차이가 있는 것을 알 수 있다.
<에러 발생한 코드>
if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] model.load_state_dict(state_dict.module.state_dict()) else: model.module.load_state_dict(state_dict)
보면 모델에서 부를 때에 module을 거쳐서 부르는 것을 알 수 있다.
그러면 module을 안거쳐서 부르면 될것 아닌가?! (그러면 module도 딸려 올것이기 때문에)
if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] model.load_state_dict(state_dict.state_dict()) else: model.load_state_dict(state_dict)
해결!
비슷한 문제를 다르게 해결한 분도 계시다. 같이 참고하면 좋을 것 같다.
https://mostar39.tistory.com/30?category=915889
[Python Torch] 얼렁뚱땅 load_state_dict 에러 잡기
많은 분들이 학습 시킬 때, 특정 epoch마다 혹은 특정 iter마다 모델을 저장하는데 저장한 모델을 다시 불러올 때가 있죠.... 그 때 종종 에러가 뜨는거에요 .. 슬프게 ... ㅜ 저는 CcGAN(https://github.com/U
mostar39.tistory.com
'에러 삽질' 카테고리의 다른 글
데이터셋 간에 차이 (mesh 편) (1) 2022.06.06 ps -af 우분투 실행창 관리 (0) 2022.06.06 if 문 괄호 생활화 (1) 2022.04.20 어느 카테고리에도 속하지 못한 depth rendering (0) 2022.03.15 opecv (0) 2022.03.06