- Model inputs: EMR (static variables) + vital signs (time-series variables)
class CNNLSTMBlock(nn.Module):
    """Encode a multivariate time series with a 3-layer 1-D CNN followed by two stacked BiLSTMs.

    The CNN widens the channel axis input_size -> num_filters -> 2*num_filters -> 4*num_filters
    (each stage: Conv1d, BatchNorm1d, ReLU, Dropout). A 2-layer BiLSTM with ``units`` hidden
    units per direction and a 2-layer BiLSTM with ``units // 2`` hidden units per direction then
    run over the time axis, and the second LSTM's output at the last time step is returned.

    Args:
        input_size: number of input features per time step.
        num_filters: output channels of the first conv layer (doubled at each later layer).
        kernel_size: temporal kernel size shared by all three conv layers.
        units: hidden size per direction of the first BiLSTM.
        dropout: dropout probability after each conv stage and after the first BiLSTM.
        padding: zero-padding added on both sides by every conv layer. The default of 1 keeps
            the original behaviour; with kernel_size > 3 each conv then shortens the sequence
            by ``kernel_size - 3`` steps (too-short inputs raise). Pass ``kernel_size // 2``
            with an odd kernel to keep the sequence length unchanged.

    Shapes:
        input:  (batch, time_steps, input_size)
        output: (batch, self.output_size), where output_size == 2 * (units // 2)
    """

    def __init__(self, input_size, num_filters, kernel_size, units, dropout, padding=1):
        super().__init__()
        channels = [input_size, num_filters, num_filters * 2, num_filters * 4]
        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size, padding=padding),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]
        # Four modules per stage, three stages: indices cnn.0 ... cnn.11 are the same as an
        # explicit listing, so previously saved state_dicts still load.
        self.cnn = nn.Sequential(*layers)
        self.lstm1 = nn.LSTM(num_filters * 4, units, num_layers=2, bidirectional=True, batch_first=True)
        # Bidirectional lstm1 emits 2 * units features per step.
        self.lstm2 = nn.LSTM(units * 2, units // 2, num_layers=2, bidirectional=True, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        # Not applied in forward(); kept so the module's attributes stay the same.
        self.dropout2 = nn.Dropout(dropout)
        # Width of the returned feature vector: both directions of lstm2 concatenated.
        # Equals `units` only when units is even.
        self.output_size = 2 * (units // 2)

    def forward(self, x):
        # Conv1d convolves over the last axis, so move features to the channel axis:
        # (batch, time, features) -> (batch, features, time).
        x = self.cnn(x.permute(0, 2, 1))  # (batch, 4*num_filters, time')
        x = x.permute(0, 2, 1)            # (batch, time', 4*num_filters)
        x, _ = self.lstm1(x)              # (batch, time', 2*units)
        x = self.dropout1(x)
        x, _ = self.lstm2(x)              # (batch, time', 2*(units//2))
        # NOTE(review): in a bidirectional LSTM the backward half of the output at the last
        # step has seen only that one step; combining both directions' final hidden states
        # would use the whole sequence. Kept as-is to preserve trained-model behaviour.
        return x[:, -1, :]                # (batch, 2*(units//2))
class CNNLSTMModel(nn.Module):
    """Score a sample from static EMR features plus a CNN-LSTM encoding of its vital-sign series.

    The vital-sign sequence is encoded by a CNNLSTMBlock, concatenated with the EMR vector,
    and passed through Linear -> BatchNorm1d -> Dropout -> Linear(…, 1) -> Sigmoid.

    Args:
        emr_size: number of static EMR features.
        vitals_size: number of vital-sign features per time step.
        num_filters, kernel_size, units, dropout: forwarded to CNNLSTMBlock; ``units`` is also
            the hidden width of the fully connected head.

    forward(x_emr, x_vitals):
        x_emr:    (batch, emr_size)
        x_vitals: (batch, time_steps, vitals_size)
        returns:  (batch,) sigmoid scores in (0, 1)
    """

    def __init__(self, emr_size, vitals_size, num_filters=128, kernel_size=7, units=256, dropout=0.4):
        super().__init__()
        # A single CNN-LSTM encoder, kept in a ModuleList so state_dict keys stay
        # "cnn_lstm_block.0.*".
        self.cnn_lstm_block = nn.ModuleList(
            [CNNLSTMBlock(vitals_size, num_filters, kernel_size, units, dropout)]
        )
        # The encoder's final BiLSTM has units // 2 hidden units per direction, so it emits
        # 2 * (units // 2) features. Using `units` here breaks for odd values.
        encoded_size = 2 * (units // 2)
        # FC head
        self.fc1 = nn.Linear(emr_size + encoded_size, units)
        self.bn = nn.BatchNorm1d(units)
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): there is no activation between fc1/bn and fc2, so the head has no
        # nonlinearity; consider adding nn.ReLU(). Left unchanged to preserve trained outputs.
        self.fc2 = nn.Linear(units, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_emr, x_vitals):
        # .float() casts to float32 (not double).
        x_emr = x_emr.float()
        x_vitals = x_vitals.float()
        encodings = [block(x_vitals) for block in self.cnn_lstm_block]
        h = torch.cat([x_emr, *encodings], dim=-1)  # (batch, emr_size + 2*(units//2))
        h = self.fc1(h)                             # (batch, units)
        h = self.bn(h)
        h = self.dropout(h)
        h = self.fc2(h)                             # (batch, 1)
        h = self.sigmoid(h)                         # (batch, 1)
        # Squeeze only the trailing axis so a batch of one stays shape (1,) instead of
        # collapsing to a 0-d tensor.
        return h.squeeze(-1)
'Python Code' 카테고리의 다른 글
| Post | Date |
|---|---|
| Pytorch CRNN 모델 (ResNet + LSTM) (0) | 2024.03.20 |
| psycopg2 로 python 에서 postgresql 활용하는 법 (0) | 2024.03.20 |
| glob 사용법 (0) | 2024.02.08 |
| Pytorch training continue 코드 (0) | 2024.01.31 |
| Attention UNET 모델 구조 파이토치(Pytorch) 코드 (1) | 2024.01.23 |