Python Code

Pytorch training continue 코드

Kimhj 2024. 1. 31. 09:21
  • 파이토치로 딥러닝 모델을 학습시키다가 환경적인 요인으로 학습이 중단될 경우, 처음부터 다시 학습시키면 시간과 자원 낭비가 심하게 된다.
  • 따라서, 현재까지 학습되어 있는 모델을 지속적으로 저장해주고, 마지막 모델로부터 학습을 이어서 할 수 있는 코드 작성이 필요하다.

 

  • Training 을 이어서 하기 위해 아래 함수로 모델 가중치와 optimizer load가 필요
def continue_training(model, model_path, optimizer, device):
    
    saved_checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(saved_checkpoint['model'])
    optimizer.load_state_dict(saved_checkpoint['optim'])
    model = model.to(device)

    return model

 

  • def train() 함수에서 train_continue argument가 필요
# parser에 train_continue 추가
parser.add_argument("--train_continue", default="off", type=str, dest="train_continue")

 

  • train.py 파일 내의 training main 함수에서 아래처럼 train_continue='on' 인 경우, 모델 파일명과 epoch, 최종 성능을 작성해주어야 함. (파일명으로 오름차순 되어있으므로, string 타입의 파일명을 파싱&인덱싱해서 가장 마지막 인덱스 [-1] 모델을 자동으로 불러오는 것이 좋음.)
def main():
	...
	...
    if mode == 'train': # TRAIN MODE
        ST_EPOCH = 0
        global_f1 = 0.0

        if train_continue == 'on':
            ST_EPOCH = 17
            model_path = os.path.join(ckpt_dir, "epoch_17_iou0.7854_f10.8716.pth")
            model = continue_training(model, model_path, optimizer, device)
            global_f1 = 0.8716
            print(f"Model Loaded from [ { model_path } ]... Completed !")
            
            
	...
    ...

'Python Code' 카테고리의 다른 글

psycopg2 로 python 에서 postgresql 활용하는 법  (0) 2024.03.20
glob 사용법  (0) 2024.02.08
Attention UNET 모델 구조 파이토치(Pytorch) 코드  (1) 2024.01.23
torch 모델 save, load  (0) 2024.01.22
Nibabel 사용법  (0) 2024.01.19