def _mean_epoch_loss(model, batches, train):
    """Run one pass over ``batches`` and return the mean per-batch loss.

    Uses the module-level ``criterion``, ``optimizer``, ``sequence_length``
    and ``feature_size`` defined elsewhere in this file.

    Args:
        model: network to run on each batch.
        batches: iterable of ``(x, y)`` pairs; ``x`` is reshaped to
            ``(-1, sequence_length, feature_size)`` before the forward pass.
        train: when true, each batch also performs an optimizer step.

    Returns:
        The unweighted mean of the per-batch loss values (as Python floats).

    Raises:
        ValueError: if ``batches`` yields nothing.
    """
    total, count = 0.0, 0
    for x, y in batches:
        x = x.reshape(-1, sequence_length, feature_size)
        # NOTE(review): assumes criterion returns a scalar (e.g. mean
        # reduction), since .item() below requires a one-element tensor.
        loss = criterion(model(x), y)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        total += loss.item()
        count += 1
    if count == 0:
        # Without this, an empty loader would silently report a stale loss.
        raise ValueError("no batches to compute a loss over")
    return total / count


def train_model(model, ealry_stop, n_epochs, progress_interval):
    """Train ``model`` with per-epoch validation and optional early stopping.

    Relies on the module-level ``train_batches``, ``val_batches``,
    ``criterion``, ``optimizer``, ``sequence_length`` and ``feature_size``
    defined elsewhere in this file. ``model`` is updated in place.

    Args:
        model: the network to train.
        ealry_stop: patience in epochs. When > 0, training stops once more
            than this many epochs have passed since the lowest validation
            loss, and the weights from that epoch are restored. Values <= 0
            disable early stopping.
        n_epochs: maximum number of epochs to run.
        progress_interval: print a progress line every this many epochs
            (starting at epoch 0). Values <= 0 disable progress output.

    Raises:
        ValueError: if ``train_batches`` or ``val_batches`` yields no batches.
    """
    train_losses, valid_losses, lowest_losse = list(), list(), np.inf
    # Bound up front so a NaN loss at epoch 0 (NaN < inf is False) cannot
    # leave them undefined when the early-stopping check runs.
    lowest_epoch, best_model = 0, None

    for epoch in range(n_epochs):
        model.train()
        train_losses.append(_mean_epoch_loss(model, train_batches, train=True))

        model.eval()
        with torch.no_grad():
            valid_losses.append(
                _mean_epoch_loss(model, val_batches, train=False)
            )

        if valid_losses[-1] < lowest_losse:
            # New best: snapshot the weights and keep training. Breaking here
            # would end training on the first improvement (always epoch 0).
            lowest_losse = valid_losses[-1]
            lowest_epoch = epoch
            best_model = deepcopy(model.state_dict())
        elif ealry_stop > 0 and (lowest_epoch + ealry_stop) < epoch:
            print(f'Ealry Stopped at { epoch } Epoch')
            if best_model is not None:
                model.load_state_dict(best_model)
            break

        # Guard against progress_interval == 0, which would divide by zero.
        if progress_interval > 0 and (epoch % progress_interval) == 0:
            print(f"Epoch: {epoch} / {n_epochs}, Train Loss: { train_losses[-1] }, Valid Loss: { valid_losses[-1] }, Lowest Loss: { lowest_losse }")