# Python Code
#
# PyTorch CRNN 모델 (CNN + LSTM)
#
# Kimhj 2024. 3. 20. 15:50
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

# BLOCK
class CNNLSTMBlock(nn.Module):
    """One CNN -> BiLSTM -> BiLSTM branch that summarises a sequence as a vector.

    Args:
        input_size: Number of input channels (size of the last axis of ``x``).
        num_filters: Output channels of the Conv1d layer.
        kernel_size: Conv1d kernel width. No padding is used, so the time axis
            shrinks by ``kernel_size - 1``; sequences must be at least
            ``kernel_size`` steps long.
        units: Hidden size (per direction) of the first LSTM. The second LSTM
            uses ``units // 2`` per direction, so the block's output width is
            ``2 * (units // 2)`` (equal to ``units`` when ``units`` is even).
        dropout_rate: Dropout probability after the CNN, after the first LSTM,
            and on the final summary vector.

    NOTE: the summary vector is built from the final hidden state of each
    direction of ``lstm2``. Parameter names and shapes are unchanged from the
    earlier ``output[:, -1, :]`` readout, so old checkpoints still load, but
    they will produce different (correct) backward-direction features.
    """

    def __init__(self, input_size, num_filters, kernel_size=3, units=128, dropout_rate=0.5):
        super().__init__()

        # CNN layer: Conv1d -> BatchNorm -> ReLU -> Dropout over the time axis.
        self.cnn1 = nn.Sequential(
            nn.Conv1d(input_size, num_filters, kernel_size),
            nn.BatchNorm1d(num_filters),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        # First BiLSTM: output width is 2 * units.
        self.lstm1 = nn.LSTM(num_filters, units,
                             bidirectional=True, batch_first=True)
        # Second BiLSTM: consumes 2 * units, hidden size units // 2 per direction.
        self.lstm2 = nn.LSTM(
            units * 2, units // 2, bidirectional=True, batch_first=True)
        # Dropout after lstm1.
        self.dropout1 = nn.Dropout(dropout_rate)
        # Dropout on the final summary vector.
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Encode ``x`` into a ``(batch, 2 * (units // 2))`` feature tensor.

        ``x`` is expected batch-first as ``(batch, time, input_size)``; it is
        permuted to channels-first for Conv1d and back for the LSTMs.
        """
        x = x.permute(0, 2, 1)      # (B, input_size, T) for Conv1d
        x = self.cnn1(x)
        x = x.permute(0, 2, 1)      # (B, T - kernel_size + 1, num_filters)
        x, _ = self.lstm1(x)
        x = self.dropout1(x)
        # Take each direction's final hidden state instead of output[:, -1, :].
        # At the last time step the backward direction has seen only one
        # element; h_n[-1] is its state after reading the whole sequence.
        # h_n[-2] (forward) equals output[:, -1, :hidden].
        _, (h_n, _) = self.lstm2(x)
        x = torch.cat([h_n[-2], h_n[-1]], dim=-1)   # (B, 2 * (units // 2))
        x = self.dropout2(x)

        return x

class CNNLSTMModel(nn.Module):
    """Multi-branch CNN-LSTM binary classifier with an extra dense ``demo`` input.

    One ``CNNLSTMBlock`` is built per sequence input. Their summary vectors are
    concatenated with ``demo`` along the last axis, then passed through
    ``fc1 -> BatchNorm1d -> Dropout -> fc2 -> sigmoid``.

    Args:
        demo_size: Width of the ``demo`` feature tensor.
        input_sizes: Per-branch input channel counts.
        num_filters: Per-branch Conv1d filter counts.
        kernel_sizes: Per-branch Conv1d kernel widths.
        units: Hidden size shared by every branch and by the head.
        dropout_rate: Dropout probability used throughout.

    Raises:
        ValueError: If ``input_sizes``, ``num_filters`` and ``kernel_sizes``
            have different lengths (``zip`` would otherwise silently drop
            branches).
    """

    def __init__(self, demo_size, input_sizes, num_filters, kernel_sizes, units=128, dropout_rate=0.5):
        super().__init__()
        input_sizes = list(input_sizes)
        num_filters = list(num_filters)
        kernel_sizes = list(kernel_sizes)
        if not len(input_sizes) == len(num_filters) == len(kernel_sizes):
            raise ValueError(
                "input_sizes, num_filters and kernel_sizes must have the same length, "
                f"got {len(input_sizes)}, {len(num_filters)} and {len(kernel_sizes)}"
            )

        self.cnn_lstm_blocks = nn.ModuleList([
            CNNLSTMBlock(in_size, n_filters, k_size, units, dropout_rate)
            for in_size, n_filters, k_size in zip(input_sizes, num_filters, kernel_sizes)
        ])
        # Each branch's second BiLSTM has units // 2 hidden units per direction,
        # so a branch emits 2 * (units // 2) features (units - 1 when units is odd).
        branch_width = 2 * (units // 2)
        self.fc1 = nn.Linear(demo_size + len(self.cnn_lstm_blocks) * branch_width, units)
        self.batch_norm = nn.BatchNorm1d(units)
        self.dropout = nn.Dropout(dropout_rate)
        # NOTE(review): there is no nonlinearity between fc1 and fc2, so at
        # inference fc1 -> BatchNorm -> fc2 collapses to one affine map.
        # Confirm whether an activation was intended here.
        self.fc2 = nn.Linear(units, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, demo, inputs):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        Args:
            demo: Dense features; presumably ``(batch, demo_size)``, since it
                is concatenated with 2-D branch outputs on the last axis.
            inputs: One sequence tensor per branch, in the same order as the
                constructor's per-branch arguments.

        Raises:
            ValueError: If the number of ``inputs`` differs from the number of
                branches.
        """
        inputs = list(inputs)
        if len(inputs) != len(self.cnn_lstm_blocks):
            raise ValueError(
                f"expected {len(self.cnn_lstm_blocks)} sequence inputs, got {len(inputs)}"
            )
        features = [block(x) for block, x in zip(self.cnn_lstm_blocks, inputs)]
        h = torch.cat([demo, *features], dim=-1)
        h = self.fc1(h)
        h = self.batch_norm(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = self.sigmoid(h)

        return h