Medical AI & Article Review

Review of a Medical Image SOTA Model (DUCK-Net) with a PyTorch Implementation

Kimhj 2024. 1. 26. 14:21
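  • Original paper: https://arxiv.org/pdf/2311.02239v1.pdf
  • The most widely used model for medical image segmentation is U-Net. Typically a CNN-family model is used as the encoder, and the decoder is trained with U-Net-style up-sampling via deconvolution.
  • Successors such as UNet++ and UNet 3+ exist, but when I tested them on the same data there was no large difference in performance. (The models did become somewhat smaller and lighter, though.)
  • U-Net has been widely used for roughly 10 years since it was introduced in 2015, but in many results it performs considerably worse than recently proposed models.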

U-Net architecture

  • Performance comparison of SOTA semantic segmentation models

 

  • Performance comparison by dataset

 

  • Comparison of model outputs

  • Overall model architecture
    • In our study, we utilized a parameter, F (filter size), to modify the depth of convolutional layers. Through comprehensive experimentation, we determined that a model incorporating 17 filters serves as an optimal representation of a smaller model, while a model incorporating 34 filters represents a larger model effectively.
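To make the role of F concrete, the sketch below lists the channel width at each level, assuming the width doubles at every downsampling step as in the DuckNet code later in this post. The helper name `level_widths` is hypothetical and only used for illustration.

```python
def level_widths(starting_filters, levels=6):
    """Channel width at each level, doubling per downsampling step."""
    return [starting_filters * 2 ** i for i in range(levels)]


small = level_widths(17)   # smaller model quoted above (F = 17)
large = level_widths(34)   # larger model quoted above (F = 34)
print(small)  # [17, 34, 68, 136, 272, 544]
print(large)  # [34, 68, 136, 272, 544, 1088]
```

Doubling F doubles every layer width, so the parameter count of each convolution grows by roughly four times.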

  • Res Block
    • The Residual block, first introduced in ResUNet++ paper [10], is the first component in our novel DUCK. Its purpose is to understand the small details that make a polyp. While using multiple small convolutions is usually a good idea, having too many can mean that the network has difficulty training and understanding what features to look for. We use combinations of one, two, and three Residual blocks to simulate kernel sizes of 5x5, 9x9, and 13x13.
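The 5x5, 9x9, and 13x13 figures can be checked with the usual receptive-field arithmetic for stacked stride-1 convolutions. This is a small sketch, assuming each residual block stacks two 3x3 convolutions in its main path; the function name `receptive_field` is hypothetical.

```python
def receptive_field(kernel_sizes, dilations=None):
    """Receptive field (in pixels) of stacked stride-1 convolutions."""
    dilations = dilations or [1] * len(kernel_sizes)
    rf = 1
    for k, d in zip(kernel_sizes, dilations):
        # Each layer widens the field by (k - 1) * dilation pixels.
        rf += (k - 1) * d
    return rf


for n_blocks in (1, 2, 3):
    # Assumption: two 3x3 convolutions per residual block.
    print(n_blocks, receptive_field([3, 3] * n_blocks))
# prints "1 5", "2 9", "3 13"
```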

 

  • Midscope Block

  • Widescope Block

  • Our novel Midscope and Widescope blocks use dilated convolutions to reduce the parameters needed to simulate larger kernels while allowing the network to understand higher-level features better. They work by spreading the nine cells that would typically be in a 3x3 kernel over a larger area. These two blocks aim to learn prominent features that only require a little attention to detail, as the dilation effect has the side effect of losing information. The Midscope cell simulates a kernel size of 7x7, and the Widescope simulates a kernel size of 15x15.
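The post's `conv_block_2D` calls `midscope_conv2D_block` and `widescope_conv2D_block` without showing them, so here is a minimal sketch of how such blocks might look. The dilation rates (1, 2) and (1, 2, 3), the BatchNorm/ReLU placement, and the helper class `DilatedStack` are my assumptions, not something stated in the text above.

```python
import torch
import torch.nn as nn


class DilatedStack(nn.Module):
    """Stack of 3x3 dilated convolutions, each followed by BatchNorm and ReLU."""

    def __init__(self, filters, dilations, half_channel=False):
        super().__init__()
        out_channels = filters // 2 if half_channel else filters
        layers = []
        in_channels = filters
        for d in dilations:
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          padding='same', dilation=d),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            in_channels = out_channels
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


def midscope_conv2D_block(filters, half_channel=False):
    # Assumed dilation rates (1, 2): the stack spans 3 + (3 - 1) * 2 = 7 pixels.
    return DilatedStack(filters, dilations=(1, 2), half_channel=half_channel)


def widescope_conv2D_block(filters, half_channel=False):
    # Assumed dilation rates (1, 2, 3); not confirmed by the text above.
    return DilatedStack(filters, dilations=(1, 2, 3), half_channel=half_channel)


x = torch.randn(1, 8, 32, 32)
print(midscope_conv2D_block(8)(x).shape)                      # torch.Size([1, 8, 32, 32])
print(widescope_conv2D_block(8, half_channel=True)(x).shape)  # torch.Size([1, 4, 32, 32])

# A dilated 3x3 kernel still has only nine weights per channel pair.
print(nn.Conv2d(1, 1, kernel_size=3, dilation=3).weight.numel())  # 9
```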

 

  • Separated Block
    • The Separated block (Figure 5) is our third way of simulating big kernels. The main idea behind it is that combining a 1xN kernel with an Nx1 kernel results in a behavior similar to an NxN kernel. However, this method encounters a drawback related to the concept known as "diagonality". Essentially, diagonality implies the capacity of a convolutional layer to capture and sustain spatial details linked to diagonal patterns in an image, a feature intrinsic to the structure of a conventional NxN convolutional kernel. It retains these diagonal elements owing to its bi-dimensional characteristics, enabling it to capture spatial connections in both vertical and horizontal directions, which also encompasses diagonal aspects.
    • Yet, the distinctive processing approach of separable convolutions (1xN followed by Nx1), where filters operate on one dimension at a time, potentially obstructs their capacity to efficiently encode diagonal features. This leads to the so-called "loss of diagonality". Such diagonal relationships can prove useful for detecting specific intricate patterns or shapes within an image, hence the other blocks are designed to compensate.
import torch
import torch.nn as nn

class separated_conv2D_block(nn.Module):
    def __init__(self, filters, half_channel, size=3, padding='same'):
        super(separated_conv2D_block, self).__init__()
        # Optionally halve the number of output channels
        if half_channel:
            self.out_channel = filters // 2
        else:
            self.out_channel = filters

        # 1xN convolution followed by Nx1 convolution
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=filters, out_channels=self.out_channel, kernel_size=(1, size), padding=padding),
            nn.BatchNorm2d(self.out_channel),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.out_channel, out_channels=self.out_channel, kernel_size=(size, 1), padding=padding),
            nn.BatchNorm2d(self.out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        output = self.conv_block(x)
        return output
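The "loss of diagonality" can be illustrated with a small linear-algebra check. This sketch assumes a single channel and no nonlinearity between the two convolutions (the block above does have BatchNorm and ReLU in between, so this is only the idealized linear case). The kernel values are arbitrary examples.

```python
import torch

# A 1xN kernel followed by an Nx1 kernel (single channel, purely linear)
# acts like one NxN kernel equal to the outer product of the two.
row = torch.tensor([1.0, 2.0, 1.0])      # 1x3 kernel
col = torch.tensor([1.0, 0.0, -1.0])     # 3x1 kernel
combined = torch.outer(col, row)         # effective 3x3 kernel

# The outer product always has rank 1 ...
print(torch.linalg.matrix_rank(combined).item())   # 1

# ... whereas a purely diagonal 3x3 kernel has full rank, so no single
# 1xN / Nx1 pair can reproduce it exactly.
diagonal = torch.eye(3)
print(torch.linalg.matrix_rank(diagonal).item())   # 3

# Weights per channel pair: 2N for the separated pair versus N*N.
n = 7
print(2 * n, n * n)                                 # 14 49
```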

 

  • DUCK Block (a block that combines the conv/midscope/widescope/separated blocks)
    • DUCK is our novel convolutional block that combines the previously mentioned blocks, all used in parallel so that the network can use the behavior it deems best at each step. The idea behind it is that it has a wide variety of kernel sizes simulated in three different ways. This means that the network can decide how to compensate for the drawbacks of one way to simulate a kernel over another.
    • Having a variety of kernel sizes means it can find the general area of the target while also finding the edges correctly. We incorporated a one-two-three combination of residual blocks based on empirical observations suggesting no significant performance gains from multiple instances of Midscope, Widescope, and Separable blocks. Essentially, the computational resources required for these additions did not justify the marginal improvements in results. The result is a novel block that searches for low-level and high-level features simultaneously with auspicious results.
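The `duckv2_conv2D_block` used later in this post is not shown, so the following is a simplified sketch of the parallel-branch idea. Combining the branches by element-wise addition, the dilation rates, the residual unit layout, and all names (`conv_bn_relu`, `ResidualUnit`, `DuckBlockSketch`) are assumptions for illustration, not the paper's exact definition.

```python
import torch
import torch.nn as nn


def conv_bn_relu(channels_in, channels_out, kernel_size=3, dilation=1):
    """Convolution -> BatchNorm -> ReLU with 'same' padding."""
    return nn.Sequential(
        nn.Conv2d(channels_in, channels_out, kernel_size,
                  padding='same', dilation=dilation),
        nn.BatchNorm2d(channels_out),
        nn.ReLU(),
    )


class ResidualUnit(nn.Module):
    """Two 3x3 convolutions plus a 1x1 skip projection (simplified stand-in)."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.skip = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        self.body = nn.Sequential(
            conv_bn_relu(channels_in, channels_out),
            conv_bn_relu(channels_out, channels_out),
        )

    def forward(self, x):
        return self.body(x) + self.skip(x)


class DuckBlockSketch(nn.Module):
    """Parallel branches with different simulated kernel sizes, summed."""

    def __init__(self, filters, size=3, half_channel=False):
        super().__init__()
        out = filters // 2 if half_channel else filters

        def residual_chain(n):
            units = [ResidualUnit(filters, out)]
            units += [ResidualUnit(out, out) for _ in range(n - 1)]
            return nn.Sequential(*units)

        self.branches = nn.ModuleList([
            # Widescope stand-in: assumed dilation rates 1, 2, 3.
            nn.Sequential(conv_bn_relu(filters, out, dilation=1),
                          conv_bn_relu(out, out, dilation=2),
                          conv_bn_relu(out, out, dilation=3)),
            # Midscope stand-in: assumed dilation rates 1, 2.
            nn.Sequential(conv_bn_relu(filters, out, dilation=1),
                          conv_bn_relu(out, out, dilation=2)),
            # One, two, and three residual units.
            residual_chain(1),
            residual_chain(2),
            residual_chain(3),
            # Separated branch: 1xN followed by Nx1.
            nn.Sequential(conv_bn_relu(filters, out, (1, size)),
                          conv_bn_relu(out, out, (size, 1))),
        ])
        self.norm = nn.BatchNorm2d(out)

    def forward(self, x):
        # Every branch sees the same input; the outputs are added element-wise.
        return self.norm(sum(branch(x) for branch in self.branches))


block = DuckBlockSketch(filters=16, half_channel=True)
print(block(torch.randn(2, 16, 32, 32)).shape)   # torch.Size([2, 8, 32, 32])
```

Because all branches keep the spatial size and produce the same number of channels, their outputs can be added directly, which lets training weight whichever simulated kernel size helps at that level.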

  • Double conv with BatchNorm block
    • A block made of two convolution layers and two BN layers
  • Convolution Block (a wrapper implemented so that all of the blocks above can be used conveniently)
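The class `double_convolution_with_batch_normalization` referenced in the Convolution Block code is not listed in this post. A minimal sketch under my own assumptions (3x3 kernels, ReLU after each BatchNorm, the same dilation rate for both convolutions) could look like this:

```python
import torch
import torch.nn as nn


class double_convolution_with_batch_normalization(nn.Module):
    """Two (Conv2d -> BatchNorm2d -> ReLU) stages sharing one dilation rate."""

    def __init__(self, filters, dilation_rate=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, padding='same',
                      dilation=dilation_rate),
            nn.BatchNorm2d(filters),
            nn.ReLU(),  # assumed activation placement
            nn.Conv2d(filters, filters, kernel_size=3, padding='same',
                      dilation=dilation_rate),
            nn.BatchNorm2d(filters),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)


block = double_convolution_with_batch_normalization(8, dilation_rate=2)
print(block(torch.randn(1, 8, 32, 32)).shape)   # torch.Size([1, 8, 32, 32])
```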
# Convolution Block
class conv_block_2D(nn.Module):
    def __init__(self, filters, block_type, repeat=1, dilation_rate=1, size=3, padding='same', half_channel=False):
        super(conv_block_2D, self).__init__()

        self.conv_block = nn.ModuleList()
        self.half_channel = half_channel
        self.filters = filters
        
        for i in range(0, repeat):
            if i == 1:
                if self.half_channel:
                    self.filters = int(self.filters/2)
                    
                self.half_channel = False

            if block_type == 'separated':
                self.conv_block.append(
                    separated_conv2D_block(self.filters, size=size, padding=padding, half_channel=self.half_channel)
                )
            elif block_type == 'duckv2':
                self.conv_block.append(
                    duckv2_conv2D_block(self.filters, size=size, half_channel=self.half_channel)
                    )
                
            elif block_type == 'midscope':
                self.conv_block.append(
                    midscope_conv2D_block(self.filters, half_channel=self.half_channel)
                )
            elif block_type == 'widescope':
                self.conv_block.append(
                    widescope_conv2D_block(self.filters, half_channel=self.half_channel)
                )           
            elif block_type == 'resnet':
                self.conv_block.append(
                    resnet_conv2D_block(self.filters, self.half_channel, dilation_rate)
                )
            elif block_type == 'conv':
                self.conv_block.append(
                    nn.Conv2d(in_channels=self.filters, out_channels=self.filters, kernel_size=size, padding=padding)
                )
                self.conv_block.append(nn.ReLU())
            elif block_type == 'double_convolution':
                self.conv_block.append(
                    double_convolution_with_batch_normalization(self.filters, dilation_rate)
                )
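            else:
                # Fail loudly instead of returning from __init__ with a half-built module
                raise ValueError(f"Unknown block_type: {block_type!r}")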
    def forward(self, x):
        for i in range(len(self.conv_block)):
            x = self.conv_block[i](x)
        
        return x
  • DuckNet model implementation
class DuckNet(nn.Module):
    def __init__(self, starting_filters):
        super(DuckNet, self).__init__()

        # Down Sampling
        self.change_channels = nn.Conv2d(in_channels=1, out_channels=starting_filters, kernel_size=1, stride=1)
        self.downsample_1 = nn.Conv2d(in_channels=starting_filters, out_channels=starting_filters*2, kernel_size=2, stride=2)
        self.downsample_2 = nn.Conv2d(in_channels=starting_filters*2, out_channels=starting_filters*4, kernel_size=2, stride=2)
        self.downsample_3 = nn.Conv2d(in_channels=starting_filters*4, out_channels=starting_filters*8, kernel_size=2, stride=2)
        self.downsample_4 = nn.Conv2d(in_channels=starting_filters*8, out_channels=starting_filters*16, kernel_size=2, stride=2)
        self.downsample_5 = nn.Conv2d(in_channels=starting_filters*16, out_channels=starting_filters*32, kernel_size=2, stride=2)
        
        # Down Duck Block
        self.downduck_0 = conv_block_2D(starting_filters, block_type='duckv2')
        self.downduck_1 = conv_block_2D(starting_filters*2, block_type='duckv2')
        self.downduck_2 = conv_block_2D(starting_filters*4, block_type='duckv2')
        self.downduck_3 = conv_block_2D(starting_filters*8, block_type='duckv2')
        self.downduck_4 = conv_block_2D(starting_filters*16, block_type='duckv2')

        # Down Sampling (Duck block)
        self.downsample_d_1 = nn.Conv2d(in_channels=starting_filters, out_channels=starting_filters*2, kernel_size=2, stride=2)
        self.downsample_d_2 = nn.Conv2d(in_channels=starting_filters*2, out_channels=starting_filters*4, kernel_size=2, stride=2)
        self.downsample_d_3 = nn.Conv2d(in_channels=starting_filters*4, out_channels=starting_filters*8, kernel_size=2, stride=2)
        self.downsample_d_4 = nn.Conv2d(in_channels=starting_filters*8, out_channels=starting_filters*16, kernel_size=2, stride=2)
        self.downsample_d_5 = nn.Conv2d(in_channels=starting_filters*16, out_channels=starting_filters*32, kernel_size=2, stride=2)
        
        # Res Block
        self.res_1 = conv_block_2D(starting_filters*32, block_type='resnet', repeat=2)
        self.res_2 = conv_block_2D(starting_filters*32, block_type='resnet', repeat=2, half_channel=True)

        # Up Duck Block
        self.upduck_0 = conv_block_2D(starting_filters, 'duckv2', repeat=1)
        self.upduck_1 = conv_block_2D(starting_filters*2, 'duckv2', repeat=1, half_channel=True)
        self.upduck_2 = conv_block_2D(starting_filters*4, 'duckv2', repeat=1, half_channel=True)
        self.upduck_3 = conv_block_2D(starting_filters*8, 'duckv2', repeat=1, half_channel=True)
        self.upduck_4 = conv_block_2D(starting_filters*16, 'duckv2', repeat=1, half_channel=True)

        self.output = nn.Conv2d(in_channels=starting_filters, out_channels=32, kernel_size=1, stride=1, padding='same')
        self.softmax = nn.Softmax2d()

    def forward(self, x):                   # Input shape: (1, 1, 256, 256) ==> (Batch, Channels, Height, Width)
        x = self.change_channels(x)

        # Down sampling (original image)
        down_1 = self.downsample_1(x)
        down_2 = self.downsample_2(down_1)
        down_3 = self.downsample_3(down_2)
        down_4 = self.downsample_4(down_3)
        down_5 = self.downsample_5(down_4)

        # Down sampling (original image + Duck block)
        duck_0 = self.downduck_0(x)
        # duck block down 1 (blue arrow in article)
        down_d_1 = self.downsample_d_1(duck_0)
        downadd_1 = down_1 + down_d_1  # down + duck down summation
        downduck_1 = self.downduck_1(downadd_1)

        # duck block down 2
        down_d_2 = self.downsample_d_2(downduck_1)
        downadd_2 = down_2 + down_d_2
        downduck_2 = self.downduck_2(downadd_2)

        # duck block down 3
        down_d_3 = self.downsample_d_3(downduck_2)
        downadd_3 = down_3 + down_d_3
        downduck_3 = self.downduck_3(downadd_3)

        # duck block down 4
        down_d_4 = self.downsample_d_4(downduck_3)
        downadd_4 = down_4 + down_d_4
        downduck_4 = self.downduck_4(downadd_4)

        # duck block down 5
        down_d_5 = self.downsample_d_5(downduck_4)
        downadd_5 = down_5 + down_d_5

        # 2 res blocks (yellow arrow in article)
        res_1 = self.res_1(downadd_5)
        res_2 = self.res_2(res_1)

        # Up sampling (original image + Duck block / nearest-neighbor)
        up_4 = nn.Upsample(scale_factor=2, mode='nearest')(res_2)
        upadd_4 = downduck_4 + up_4
        upduck_4 = self.upduck_4(upadd_4)

        up_3 = nn.Upsample(scale_factor=2, mode='nearest')(upduck_4)
        upadd_3 = downduck_3 + up_3
        upduck_3 = self.upduck_3(upadd_3)

        up_2 = nn.Upsample(scale_factor=2, mode='nearest')(upduck_3)
        upadd_2 = downduck_2 + up_2
        upduck_2 = self.upduck_2(upadd_2)

        up_1 = nn.Upsample(scale_factor=2, mode='nearest')(upduck_2)
        upadd_1 = downduck_1 + up_1
        upduck_1 = self.upduck_1(upadd_1)

        up_0 = nn.Upsample(scale_factor=2, mode='nearest')(upduck_1)
        upadd_0 = duck_0 + up_0
        upduck_0 = self.upduck_0(upadd_0)

        output = self.output(upduck_0)
        output = self.softmax(output)
        return output
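One practical consequence of the model above: the skip additions (`downduck_4 + up_4` and so on) only work when each x2 upsampled map has the same size as the encoder map it is added to. The sketch below checks this with plain arithmetic, assuming the five `kernel_size=2, stride=2` convolutions without padding used in the code. The helper names are hypothetical.

```python
def encoder_sizes(size, levels=5):
    """Spatial size after each kernel_size=2, stride=2 convolution (no padding)."""
    sizes = [size]
    for _ in range(levels):
        sizes.append((sizes[-1] - 2) // 2 + 1)
    return sizes


def skip_shapes_match(size, levels=5):
    """True if every x2 upsampled map matches the encoder map it is added to."""
    sizes = encoder_sizes(size, levels)
    return all(sizes[i + 1] * 2 == sizes[i] for i in range(levels))


print(encoder_sizes(256))      # [256, 128, 64, 32, 16, 8]
print(skip_shapes_match(256))  # True
print(encoder_sizes(250))      # [250, 125, 62, 31, 15, 7]
print(skip_shapes_match(250))  # False
```

In other words, the input height and width should be multiples of 32 (for example 256), otherwise the additions in the up-sampling path fail with a shape mismatch.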