- 학습 코드, 평가 코드
import torch
import os
import random
import numpy as np
import pandas as pd
from copy import deepcopy
from model import RFModel
from dataset import RFDataset
from torchmetrics.classification import BinaryAUROC
def train_one_epoch(model, train_dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network invoked as ``model(x_emr, x_vital, x_ecg)``.
        train_dataloader: yields ``(x, y)`` pairs where ``x`` is a dict
            with ``'emr'``, ``'vital'`` and ``'ecg'`` tensors.
        criterion: loss function (e.g. ``torch.nn.BCELoss``).
        optimizer: optimizer bound to ``model.parameters()``.
        device: device each batch is moved to before the forward pass.

    Returns:
        float: running loss averaged over the number of batches seen;
        0.0 if the dataloader is empty.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for x, y in train_dataloader:
        x_emr = x['emr'].to(device)
        x_vital = x['vital'].to(device)
        x_ecg = x['ecg'].to(device)
        y = y.to(device)

        # Reset gradients from the previous step.
        optimizer.zero_grad()
        outputs = model(x_emr, x_vital, x_ecg)
        loss = criterion(outputs, y)
        # Backpropagate and update weights.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
    # The original divided by ``batch_idx + 1`` after the loop, which
    # raises NameError when the dataloader yields no batches.
    return running_loss / num_batches if num_batches else 0.0
def validate_one_epoch(model, valid_dataloader, criterion, metric, device):
    """Evaluate the model for one epoch; return (mean loss, metric value).

    Args:
        model: network invoked as ``model(x_emr, x_vital, x_ecg)``.
        valid_dataloader: yields ``(x, y)`` pairs where ``x`` is a dict
            with ``'emr'``, ``'vital'`` and ``'ecg'`` tensors.
        criterion: loss function (e.g. ``torch.nn.BCELoss``).
        metric: callable ``metric(outputs, targets)`` over the full
            epoch's predictions (e.g. ``BinaryAUROC``).
        device: device each batch is moved to before the forward pass.

    Returns:
        tuple: ``(val_epoch_loss, val_metric)`` where the loss is the
        mean per-batch loss (0.0 for an empty loader).
    """
    model.eval()
    val_running_loss = 0.0
    num_batches = 0
    val_outputs = []
    val_targets = []
    # No gradients needed during evaluation.
    with torch.no_grad():
        for x, y in valid_dataloader:
            x_emr = x['emr'].to(device)
            x_vital = x['vital'].to(device)
            x_ecg = x['ecg'].to(device)
            y = y.to(device)
            outputs = model(x_emr, x_vital, x_ecg)
            loss = criterion(outputs, y)
            val_running_loss += loss.item()
            num_batches += 1
            # Collect whole batches; concatenated below. The original
            # extended a Python list per-sample via numpy and rebuilt a
            # tensor, which is equivalent but slower.
            val_outputs.append(outputs.cpu())
            val_targets.append(y.cpu())
    # Guard against an empty dataloader: the original divided by
    # ``batch_idx + 1`` which raises NameError when no batch is seen.
    val_epoch_loss = val_running_loss / num_batches if num_batches else 0.0
    val_outputs = torch.cat(val_outputs) if val_outputs else torch.empty(0)
    val_targets = torch.cat(val_targets) if val_targets else torch.empty(0)
    # NOTE(review): with an empty loader the metric is called on empty
    # tensors; some metrics may raise — confirm desired behavior.
    val_auroc = metric(val_outputs, val_targets)
    return val_epoch_loss, val_auroc
- 데이터셋 구축 (예시)
import torch
import numpy as np
from torch.utils.data import Dataset
class RFDataset(Dataset):
    """Dataset wrapping EMR, vital-sign and ECG arrays.

    Normalization is applied lazily, per sample, inside ``__getitem__``.
    All arrays are indexed by sample along axis 0.
    """

    def __init__(self, data):
        # Raw, unnormalized arrays taken from the input dictionary.
        self.emr = data["emr"]
        self.vital = data["vital"]
        self.ecg = data["ecg"]
        self.label = data["label"]

    def __len__(self):
        return len(self.label)

    @staticmethod
    def normalize_data(data, mean_val, scale_val):
        """Shift by ``mean_val`` then divide by ``scale_val``."""
        return (data - mean_val) / scale_val

    def preprocess_ecg(self, idx):
        """Z-score each ECG channel of sample ``idx`` along the time axis."""
        sample = self.ecg[idx]
        mu = np.mean(sample, axis=1, keepdims=True)
        sigma = np.std(sample, axis=1, keepdims=True)
        # Small epsilon guards against a flat (zero-variance) channel.
        return (sample - mu) / (sigma + 1e-10)

    def preprocess_vital(self, idx):
        """Normalize the five vital-sign channels of sample ``idx``.

        NOTE(review): DBP and SBP use identical (60, 120) constants —
        confirm this is intentional and not a copy/paste slip.
        """
        norm = self.normalize_data
        vit = self.vital[idx]
        channels = [
            norm(vit[:, 0], 50, 50),   # PR
            norm(vit[:, 1], 12, 8),    # RR
            norm(vit[:, 2], 95, 5),    # SpO2
            norm(vit[:, 3], 60, 120),  # DBP
            norm(vit[:, 4], 60, 120),  # SBP
        ]
        return np.stack(channels, axis=1)

    def preprocess_emr(self, idx):
        """Build the normalized EMR feature vector for sample ``idx``.

        Columns 1, 3 and 4 are passed through unscaled; column 2 (AVPU)
        is one-hot encoded over 5 levels, clipped at level 4.
        """
        row = self.emr[idx]
        norm = self.normalize_data
        avpu_level = min(int(row[2]), 4)
        avpu = [int(i == avpu_level) for i in range(5)]
        features = [
            norm(row[0], 18, 83),    # age
            row[1],
            *avpu,
            row[3],
            row[4],
            norm(row[5], 35, 10),    # body temperature
            norm(row[6], 6.5, 1.2),  # pH
            norm(row[7], 180, 10),   # pCO2
            norm(row[8], 570, 40),   # pO2
            norm(row[9], -30, 50),   # base excess
            norm(row[10], 0, 55),    # HCO3
            norm(row[11], 20, 80),   # corrected/imputed FiO2
        ]
        return np.array(features, dtype=np.float32)

    def __getitem__(self, idx):
        features = {
            "ecg": self.preprocess_ecg(idx),
            "vital": self.preprocess_vital(idx),
            "emr": self.preprocess_emr(idx),
        }
        return features, self.label[idx]
- 활용
# Hyper-parameters
SEED = 42
BATCH_SIZE = 32
NUM_EPOCHS = 300
LEARNING_RATE = 0.001
MODEL_PATH = ''

# Fix all RNG seeds for reproducibility (SEED was previously declared
# but never applied).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1. Convert the train/valid/test splits into dictionaries of arrays.
#    NOTE(review): assumes the ``numerics`` and ``ecg`` columns hold
#    equal-length sequences so np.array yields rectangular arrays.
train_dict = {
    'emr': train_data[EMR_COLUMNS].to_numpy().astype(np.float32),
    'vital': np.array(train_data.numerics.tolist()),
    'ecg': np.array(train_data.ecg.tolist()),
    'label': train_data[LABEL_COLUMN].to_numpy().astype(np.float32),
}
valid_dict = {
    'emr': valid_data[EMR_COLUMNS].to_numpy().astype(np.float32),
    'vital': np.array(valid_data.numerics.tolist()),
    'ecg': np.array(valid_data.ecg.tolist()),
    'label': valid_data[LABEL_COLUMN].to_numpy().astype(np.float32),
}
test_dict = {
    'emr': test_data[EMR_COLUMNS].to_numpy().astype(np.float32),
    'vital': np.array(test_data.numerics.tolist()),
    'ecg': np.array(test_data.ecg.tolist()),
    'label': test_data[LABEL_COLUMN].to_numpy().astype(np.float32),
}

# 2. Wrap each split in an RFDataset.
train_dataset = RFDataset(train_dict)
valid_dataset = RFDataset(valid_dict)
test_dataset = RFDataset(test_dict)

# 3. Build the DataLoaders (shuffle only the training split).
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 4. Model, optimizer, (optional) scheduler, criterion, metric.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RFModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0003)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.25, mode='max')
criterion = torch.nn.BCELoss()
metric = BinaryAUROC()

# 5. Early-stopping / best-model bookkeeping.
best_val_auroc = float("-inf")
best_model = None
patience = 50
no_improvement_epochs = 0

# 6. Train and evaluate. The test split is evaluated with
#    validate_one_epoch — the original called an undefined
#    ``test_one_epoch``, which would raise NameError on epoch 1.
for epoch in range(NUM_EPOCHS):
    epoch_loss = train_one_epoch(model, train_dataloader, criterion, optimizer, device)
    val_epoch_loss, val_auroc = validate_one_epoch(model, valid_dataloader, criterion, metric, device)
    test_epoch_loss, test_auroc = validate_one_epoch(model, test_dataloader, criterion, metric, device)
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Train Loss: {epoch_loss:.4f}, Val Loss: {val_epoch_loss:.4f}, Val AUROC: {val_auroc:.4f}, Test Loss: {test_epoch_loss:.4f}, Test AUROC: {test_auroc}")
    # scheduler.step(val_auroc)

    # Track the best validation AUROC; store it as a plain float so the
    # saved filename does not embed a tensor repr.
    if val_auroc > best_val_auroc:
        best_val_auroc = float(val_auroc)
        best_model = deepcopy(model.state_dict())
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    # Early stopping
    if no_improvement_epochs >= patience:
        print(f"Early stopping at epoch {epoch + 1}")
        break

# 7. Save the best model. os.path.join avoids the original
#    f"{MODEL_PATH}/..." producing an absolute path when MODEL_PATH is
#    empty; the None guard covers the (theoretical) zero-epoch case.
if best_model is not None:
    model.load_state_dict(best_model)
    torch.save(model.state_dict(), os.path.join(MODEL_PATH, f"best_model-{best_val_auroc:.4f}.pth"))
'Python Code' 카테고리의 다른 글

| 글 | 날짜 |
|---|---|
| Config 파일 관리 (argparse) (0) | 2023.10.10 |
| Cut-off 별 Sensitivity(민감도), Specificity(특이도) 성능 측정 코드 (0) | 2023.10.10 |
| Docker 자주 쓰는 명령어 정리 (0) | 2023.09.23 |
| PDF 파일 Concat (0) | 2023.09.20 |
| Config 파일 관리 (yaml) (0) | 2023.09.19 |