Python Code

torch 모델 save, load

Kimhj 2024. 1. 22. 09:29
  • Pytorch 에서 모델을 pth 파일로 저장 후, 가중치를 불러와서 prediction 시에 발생하는 오류가 있어서 정리
  • 모델 파일에서 load_state_dict 로 정보들을 가져온 뒤 파라미터에 weight를 적용해도 성능이 0.0 으로 나오는 경우가 있음. 이런 경우는 모델 저장 시에 weight와 optimizer 정보까지 같이 저장한 경우여서, 두개 전부 load 해주어야함.

 

# Torch model Save

def save(ckpt_dir, model, optim, epoch, iou, f1):
    """
    Save torch model in check point directory.
    Args: 
        ckpt_dir (str): checkpoint directory path
        model (torch model) : Pytorch model weights
        optim (torch.optimizer): pytorch optimizer
        epoch (int): train epoch        
    """
    os.makedirs(ckpt_dir, exist_ok=True)

    torch.save({'model': model.state_dict(),
                'optim': optim.state_dict()},
                f"{ckpt_dir}/epoch_{epoch}_iou{round(iou,2)}_f1{round(f1,2)}.pth")


# Torch model Load

def load(ckpt_dir, model, optim):
    """
        Load torch model in check point directory.
    """
    if os.path.exists(ckpt_dir):
        epoch = 0
        return model, optim, epoch
    
    ckpt_lst = os.listdir(ckpt_dir)
    ckpt_lst.sort(key=lambda f: int("".join(filter(str.isdigit, f))))

    dict_model = torch.load(f"{ckpt_dir}/{ckpt_lst[-1]}")   # load laetset model
    model.load_state_dict(dict_model['model'])
    optim.load_state_dict(dict_model['optim'])
    epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])

    return model, optim, epoch