Deep Learning

UNET 구조, 설명

Kimhj 2023. 12. 12. 09:27
  • UNET(United Encoder-Decoder Networks)은 주로 이미지 분할, 세그멘테이션, 객체인식 등에 사용되는 딥러닝 모델중 하나임.
  • 이미지의 각 픽셀을 해당하는 클래스나 세그멘테이션 마스크로 할당하는 작업에 적합하며, 2015년 소개된 이후 이미지 세그맨테이션 분야에서 많이 활용되고 있음.
  • UNET 특징은 아래와 같음.
    • Encoder-Decoder 구조
      • UNET은 대칭구조의 Encoder와 Decoder로 구성
      • Encoder : 이미지의 공간적 계층 구조를 추출하여 저차원 임베딩으로 변환
      • Decoder : 저차원 임베딩을 사용하여 입력 이미지의 공간적 구조를 복원하고, 원래 입력의 크기와 동일한 크기의 세그멘테이션 맵 생성
    • Skip Connections
      • UNET은 각 인코더 레이어와 디코더 레이어간 스킵 커넥션 사용
      • 스킵 커넥션은 원본 입력값 정보를 Conv layer를 통과한 feature와 결합하여 층이 깊어짐으로 인한 정보손실을 방지하고 더 정확한 세그멘테이션을 가능하게 함.
    • Valid Convolution(패딩없는 합성곱)
      • UNET은 합성곱 연산 시, 패딩이 없는 형태를 사용
      • 입력과 출력 크기를 정확히 일치시켜줘서 세그멘테이션 맵을 정확하게 복원하는 데 도움을 줌.
    • Activation Function
      • Encoder 부분에서는 주로 ReLU(Rectified Linear Unit) 활성화함수를 사용하고, Decoder에서는 Sigmoid 함수를 사용
      • Sigmoid 함수는 각 픽셀에 대한 이진 분류를 수행하므로, 픽셀이 특정 클래스에 속하는지 여부를 결정하는데 사용됨

딥러닝 태스크 종류

 

 

 

UNET 구조

 

# UNET Model
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, n_filters):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, n_filters, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1)

        self.bn1 = nn.BatchNorm2d(n_filters)
        self.bn2 = nn.BatchNorm2d(n_filters)

        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, n_filters):
        super(EncoderBlock, self).__init__()

        self.conv_blk = ConvBlock(n_filters)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, inputs):
        x = self.conv_blk(inputs)
        p = self.pool(x)
        return x, p

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, n_filters):
        super(DecoderBlock, self).__init__()

        self.up = nn.ConvTranspose2d(in_channels, n_filters, kernel_size=2, stride=2, padding=0)
        self.conv_blk = ConvBlock(n_filters)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], dim=1)
        x = self.conv_blk(x)

        return x

class UNET(nn.Module):
    def __init__(self, in_channels, n_classes):
        super(UNET, self).__init__()

        # Encoder
        self.e1 = EncoderBlock(in_channels, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)

        # Bridge
        self.b = ConvBlock(512)

        # Decoder
        self.d1 = DecoderBlock(512, 256)
        self.d2 = DecoderBlock(256, 128)
        self.d3 = DecoderBlock(128, 64)
        self.d4 = DecoderBlock(64, in_channels)

        # Outputs
        if n_classes == 1:
            activation = nn.Sigmoid()
        else:
            activation = nn.Softmax(dim=1)

        self.outputs = nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.activation = activation

    def forward(self, inputs):
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)
        outputs = self.activation(outputs)

        return outputs