- 파이토치로 딥러닝 모델을 학습시키다가 환경적인 요인으로 학습이 중단될 경우, 처음부터 다시 학습시키면 시간과 자원 낭비가 심하게 된다.
- 따라서, 현재까지 학습되어 있는 모델을 지속적으로 저장해주고, 마지막 모델로부터 학습을 이어서 할 수 있는 코드 작성이 필요하다.
- Training 을 이어서 하기 위해 아래 함수로 모델 가중치와 optimizer load가 필요
def continue_training(model, model_path, optimizer, device):
saved_checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(saved_checkpoint['model'])
optimizer.load_state_dict(saved_checkpoint['optim'])
model = model.to(device)
return model
- def train() 함수에서 train_continue argument가 필요
# parser에 train_continue 추가
parser.add_argument("--train_continue", default="off", type=str, dest="train_continue")
- train.py 파일 내의 training main 함수에서 아래처럼 train_continue='on' 인 경우, 모델 파일명과 epoch, 최종 성능을 작성해주어야 함. (파일명으로 오름차순 되어있으므로, string 타입의 파일명을 파싱&인덱싱해서 가장 마지막 인덱스 [-1] 모델을 자동으로 불러오는 것이 좋음.)
def main():
...
...
if mode == 'train': # TRAIN MODE
ST_EPOCH = 0
global_f1 = 0.0
if train_continue == 'on':
ST_EPOCH = 17
model_path = os.path.join(ckpt_dir, "epoch_17_iou0.7854_f10.8716.pth")
model = continue_training(model, model_path, optimizer, device)
global_f1 = 0.8716
print(f"Model Loaded from [ { model_path } ]... Completed !")
...
...
'Python Code' 카테고리의 다른 글
psycopg2 로 python 에서 postgresql 활용하는 법 (0) | 2024.03.20 |
---|---|
glob 사용법 (0) | 2024.02.08 |
Attention UNET 모델 구조 파이토치(Pytorch) 코드 (1) | 2024.01.23 |
torch 모델 save, load (0) | 2024.01.22 |
Nibabel 사용법 (0) | 2024.01.19 |