Python Code

CRNN multi-modal 모델

Kimhj 2024. 3. 21. 11:20
  • EMR (static variables) + vital signs (time-series variables)

 

class CNNLSTMBlock(nn.Module):
    """Encode a multivariate time series with a 1-D CNN followed by two BiLSTMs.

    The convolutional front end has three stages, each being
    Conv1d -> BatchNorm1d -> ReLU -> Dropout, with channel widths
    ``num_filters``, ``2 * num_filters`` and ``4 * num_filters``. Its output
    is fed to two stacked 2-layer bidirectional LSTMs; the features of the
    last time step of the second LSTM are returned.

    Args:
        input_size: Number of input features (channels) per time step.
        num_filters: Output channels of the first convolution.
        kernel_size: Kernel size shared by all three convolutions.
        units: Hidden size per direction of the first LSTM; the second LSTM
            uses ``units // 2`` per direction.
        dropout: Dropout probability used after every conv stage and between
            the two LSTMs.
    """

    def __init__(self, input_size, num_filters, kernel_size, units, dropout):
        super().__init__()

        # Channel progression input -> F -> 2F -> 4F; consecutive pairs give
        # the (in, out) channels of each conv stage.
        widths = [input_size, num_filters, num_filters * 2, num_filters * 4]
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                # padding is fixed at 1, so each stage changes the sequence
                # length by (3 - kernel_size); only kernel_size == 3 keeps it.
                nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=kernel_size, padding=1),
                nn.BatchNorm1d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]
        self.cnn = nn.Sequential(*stages)

        # Both LSTMs are bidirectional, so their output width is twice the
        # hidden size: 2 * units, then 2 * (units // 2).
        self.lstm1 = nn.LSTM(num_filters * 4, units, num_layers=2, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(units * 2, units // 2, num_layers=2, bidirectional=True, batch_first=True)

        self.dropout1 = nn.Dropout(dropout)
        # Registered for completeness; forward() does not apply it.
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Return the last-time-step encoding of ``x``.

        Args:
            x: Tensor laid out as (batch, time, input_size).

        Returns:
            Tensor of shape (batch, 2 * (units // 2)).
        """
        channels_first = x.permute(0, 2, 1)         # (batch, input_size, time)
        conv_out = self.cnn(channels_first)         # (batch, 4 * num_filters, time')
        sequence = conv_out.permute(0, 2, 1)        # (batch, time', 4 * num_filters)
        sequence, _ = self.lstm1(sequence)          # (batch, time', 2 * units)
        sequence, _ = self.lstm2(self.dropout1(sequence))  # (batch, time', 2 * (units // 2))
        return sequence[:, -1, :]                   # keep only the final time step
    

class CNNLSTMModel(nn.Module):
    """Binary classifier combining static EMR features with a CRNN vitals encoder.

    The vital-sign time series is encoded by each module in
    ``self.cnn_lstm_block``; the encodings are concatenated with the static
    EMR vector and passed through Linear -> BatchNorm1d -> Dropout -> Linear
    -> Sigmoid.

    Args:
        emr_size: Number of static EMR features per sample.
        vitals_size: Number of vital-sign features per time step.
        num_filters: Output channels of the encoder's first convolution.
        kernel_size: Kernel size of the encoder's convolutions.
        units: Encoder LSTM size and width of the hidden FC layer.
        dropout: Dropout probability used in the encoder and the FC head.
    """

    def __init__(self, emr_size, vitals_size, num_filters=128, kernel_size=7, units=256, dropout=0.4):
        super().__init__()
        # CRNN blocks
        self.cnn_lstm_block = nn.ModuleList(
            [CNNLSTMBlock(vitals_size, num_filters, kernel_size, units, dropout)]
        )
        # The encoder ends in a bidirectional LSTM with hidden size
        # units // 2, so it emits 2 * (units // 2) features. That equals
        # `units` for even values; spelling it out keeps odd values working.
        encoded_size = 2 * (units // 2)
        # FC layers
        # NOTE(review): there is no non-linearity between fc1 and fc2 apart
        # from batch norm and dropout; kept as-is to preserve trained weights.
        self.fc1 = nn.Linear(emr_size + encoded_size, units)
        self.bn = nn.BatchNorm1d(units)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(units, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_emr, x_vitals):
        """Return one probability per sample.

        Args:
            x_emr: Static features, laid out as (batch, emr_size).
            x_vitals: Time series, laid out as (batch, time, vitals_size).

        Returns:
            Tensor of shape (batch,) with values in (0, 1).
        """
        # Cast both inputs to float32 so they match the module parameters.
        x_emr = x_emr.float()
        x_vitals = x_vitals.float()
        encodings = [block(x_vitals) for block in self.cnn_lstm_block]
        h = torch.cat([x_emr] + encodings, dim=-1)  # (batch, emr_size + encoded features)
        h = self.fc1(h)                             # (batch, units)
        h = self.bn(h)
        h = self.dropout(h)
        h = self.fc2(h)                             # (batch, 1)
        h = self.sigmoid(h)                         # (batch, 1)

        # Remove only the trailing singleton. A bare squeeze() would also
        # drop the batch axis for a single-sample batch and return a 0-d
        # tensor, which breaks loss functions and per-sample indexing.
        return h.squeeze(-1)