카테고리 없음

딥러닝 Early Stopping 코드

Kimhj 2023. 10. 25. 10:04
def train_model(model, ealry_stop, n_epochs, progress_interval):
    """Train ``model`` in place with early stopping on the mean validation loss.

    Relies on module-level names defined elsewhere in this file:
    ``train_batches`` / ``val_batches`` (iterated as ``(x, y)`` pairs each
    epoch), ``criterion``, ``optimizer``, ``sequence_length`` and
    ``feature_size``. Each ``x`` is reshaped to
    ``(-1, sequence_length, feature_size)`` before the forward pass.

    Args:
        model: The ``torch.nn.Module`` to train; its weights are updated in place.
        ealry_stop: Patience. Training stops at the first epoch that exceeds
            the best (lowest validation loss) epoch by MORE than ``ealry_stop``
            epochs. A non-positive value disables early stopping.
        n_epochs: Maximum number of epochs to run.
        progress_interval: Print a progress line every ``progress_interval``
            epochs; ``0``/``None`` disables progress output.

    On return, ``model`` holds the weights from the epoch with the lowest mean
    validation loss, whether training stopped early or ran all epochs.

    Raises:
        ValueError: If ``val_batches`` yields no batches, since there is then
            no validation loss to monitor.
    """
    train_losses, valid_losses = [], []  # per-epoch mean losses
    lowest_losse = np.inf
    # Bound up front so a NaN loss in the first epoch (NaN < inf is False)
    # cannot leave these undefined when the patience check runs.
    lowest_epoch = 0
    best_model = None

    for epoch in range(n_epochs):
        model.train()
        epoch_train = []
        for x, y in train_batches:
            x = x.reshape(-1, sequence_length, feature_size)
            outputs = model(x)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train.append(loss.item())

        model.eval()
        epoch_valid = []
        with torch.no_grad():
            for x, y in val_batches:
                x = x.reshape(-1, sequence_length, feature_size)
                outputs = model(x)
                epoch_valid.append(criterion(outputs, y).item())

        if not epoch_valid:
            raise ValueError(
                "val_batches yielded no batches; cannot monitor validation loss"
            )

        # Judge the whole epoch (mean of batch losses), not just its last batch.
        train_losses.append(
            sum(epoch_train) / len(epoch_train) if epoch_train else float("nan")
        )
        valid_losses.append(sum(epoch_valid) / len(epoch_valid))

        if valid_losses[-1] < lowest_losse:
            # New best epoch: snapshot the weights and keep training.
            lowest_losse = valid_losses[-1]
            lowest_epoch = epoch
            best_model = deepcopy(model.state_dict())
        elif ealry_stop > 0 and (lowest_epoch + ealry_stop) < epoch:
            print(f"Early Stopped at {epoch} Epoch")
            break

        if progress_interval and epoch % progress_interval == 0:
            print(f"Epoch: {epoch} / {n_epochs}, Train Loss: { train_losses[-1] }, Valid Loss: { valid_losses[-1] }, Lowest Loss: { lowest_losse }")

    # Restore the best snapshot on both exits (early stop or all epochs done).
    if best_model is not None:
        model.load_state_dict(best_model)