- Pytorch 에서 모델을 pth 파일로 저장 후, 가중치를 불러와서 prediction 시에 발생하는 오류가 있어서 정리
- 모델 파일에서 load_state_dict 로 정보들을 가져온 뒤 파라미터에 weight를 적용해도 성능이 0.0 으로 나오는 경우가 있음. 이런 경우는 모델 저장 시에 weight와 optimizer 정보까지 같이 저장한 경우여서, 두개 전부 load 해주어야함.
# Torch model Save
def save(ckpt_dir, model, optim, epoch, iou, f1):
"""
Save torch model in check point directory.
Args:
ckpt_dir (str): checkpoint directory path
model (torch model) : Pytorch model weights
optim (torch.optimizer): pytorch optimizer
epoch (int): train epoch
"""
os.makedirs(ckpt_dir, exist_ok=True)
torch.save({'model': model.state_dict(),
'optim': optim.state_dict()},
f"{ckpt_dir}/epoch_{epoch}_iou{round(iou,2)}_f1{round(f1,2)}.pth")
# Torch model Load
def load(ckpt_dir, model, optim):
"""
Load torch model in check point directory.
"""
if os.path.exists(ckpt_dir):
epoch = 0
return model, optim, epoch
ckpt_lst = os.listdir(ckpt_dir)
ckpt_lst.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
dict_model = torch.load(f"{ckpt_dir}/{ckpt_lst[-1]}") # load laetset model
model.load_state_dict(dict_model['model'])
optim.load_state_dict(dict_model['optim'])
epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])
return model, optim, epoch
'Python Code' 카테고리의 다른 글
Pytorch training continue 코드 (0) | 2024.01.31 |
---|---|
Attention UNET 모델 구조 파이토치(Pytorch) 코드 (1) | 2024.01.23 |
Nibabel 사용법 (0) | 2024.01.19 |
DICOM 파일 전처리 (0) | 2024.01.19 |
Glob 이용해서 특정폴더의 이미지 경로들을 DataFrame으로 만들기 (0) | 2024.01.11 |