Deep Learning

Pytorch Pix2Pix 모델 구현

Kimhj 2024. 2. 16. 11:23
  • Pix2pix 모델은 A-B처럼 1:1 대응되는 paired한 이미지를 입력으로 해서 특정 입력데이터가 들어왔을 때, 어떤 이미지가 나올지를 예측하여 생성하는 생성 모델이다.
  • 예를 들면, 건물 segmentation을 입력으로 했을 때, 건물을 생성하는 이미지가 출력되는 것이 가장 대표적이다.

Pix2Pix

 

  • 기존에 활용되던 CGAN 모델과는 다르게 Pix2Pix 에서는 Noise vector와 Class vector를 입력으로 받지 않고, 오직 이미지만을 입력으로 한다.

Pix2Pix

 

  • Discriminator는 두 가지의 입력데이터를 받는데, Generator에 넣었던 원본 이미지와 짝지어졌던 실제이미지, 그리고 생성된 이미지를 각각 입력받는다.
  • 그리고 discriminator는 pair를 비교해서 생성된 이미지인지, 혹은 진짜 이미지인지를 구분하도록 학습을 시키는 것이다.
  • 아래 사진에서 보다시피, Generator는 U-Net 구조를 사용하여 skip-connection을 적용했고, Discriminator는 PatchGAN을 이용했다.

 

  • 모델 코드구현은 아래와 같이 layer.py 와 model.py 로 구분하여 작성했다.

 

  • layer.py
import os
import numpy as np

import torch
import torch.nn as nn

class DECBR2d(nn.Module):
    """Transposed-Conv -> (optional) Norm -> (optional) ReLU block.

    Upsampling building block used by the Pix2Pix decoder.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size, stride, padding, output_padding, bias: forwarded to
            ``nn.ConvTranspose2d``. NOTE(review): with the defaults
            (kernel 3, stride 2, padding 1, output_padding 1) spatial size
            is exactly doubled; callers that override ``kernel_size=4``
            must pass ``output_padding=0`` to keep exact doubling.
        norm: "bnorm" for BatchNorm2d, "inorm" for InstanceNorm2d,
            None (or any unrecognized value) for no normalization.
        relu: 0.0 for plain ReLU, a positive slope for LeakyReLU,
            None (or a negative value) for no activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True, norm="bnorm", relu=0.0):
        super().__init__()

        layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=kernel_size, stride=stride,
                                     padding=padding, output_padding=output_padding,
                                     bias=bias)]

        if norm is not None:
            if norm == "bnorm":
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            elif norm == "inorm":
                layers += [nn.InstanceNorm2d(num_features=out_channels)]

        if relu is not None and relu >= 0.0:
            # relu == 0 selects a plain ReLU; a positive value is used as the
            # negative slope of a LeakyReLU.
            layers += [nn.ReLU() if relu == 0 else nn.LeakyReLU(relu)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)

class CBR2d(nn.Module):
    """Pad -> Conv2d -> (optional) Norm -> (optional) ReLU block.

    Padding is applied by an explicit padding layer (the conv itself uses
    ``padding=0``) so that several padding modes are supported.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size, stride, bias: forwarded to ``nn.Conv2d``.
        padding: width of the padding applied by the padding layer.
        padding_mode: one of 'reflection', 'replication', 'constant'
            (pads with 0) or 'zeros'. Any other value raises ValueError.
        norm: "bnorm" for BatchNorm2d, "inorm" for InstanceNorm2d,
            None (or any unrecognized value) for no normalization.
        relu: 0.0 for plain ReLU, a positive slope for LeakyReLU,
            None (or a negative value) for no activation.

    Raises:
        ValueError: if ``padding_mode`` is not a supported mode.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflection', bias=True, norm="bnorm", relu=0.0):
        super().__init__()

        layers = []

        if padding_mode == 'reflection':
            layers += [nn.ReflectionPad2d(padding)]
        elif padding_mode == 'replication':
            layers += [nn.ReplicationPad2d(padding)]
        elif padding_mode == 'constant':
            layers += [nn.ConstantPad2d(padding, 0)]
        elif padding_mode == 'zeros':
            layers += [nn.ZeroPad2d(padding)]
        else:
            # Previously an unknown mode silently skipped padding, which
            # silently changed the output spatial size; fail loudly instead.
            raise ValueError(f"unsupported padding_mode: {padding_mode!r}")

        layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=stride, padding=0,
                             bias=bias)]

        if norm is not None:
            if norm == "bnorm":
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            elif norm == "inorm":
                layers += [nn.InstanceNorm2d(num_features=out_channels)]

        if relu is not None and relu >= 0.0:
            # relu == 0 selects a plain ReLU; a positive value is used as the
            # negative slope of a LeakyReLU.
            layers += [nn.ReLU() if relu == 0 else nn.LeakyReLU(relu)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)



class ResBlock(nn.Module):
    """Residual block: two CBR2d convolutions plus an identity skip.

    The second convolution has no activation (``relu=None``), so the
    skip connection is added on pre-activation features.
    NOTE(review): the identity add requires ``in_channels == out_channels``
    and ``stride == 1`` to keep tensor shapes compatible.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, norm="bnorm", relu=0.0):
        super().__init__()

        first = CBR2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      bias=bias, norm=norm, relu=relu)
        second = CBR2d(in_channels=out_channels, out_channels=out_channels,
                       kernel_size=kernel_size, stride=stride, padding=padding,
                       bias=bias, norm=norm, relu=None)

        self.resblk = nn.Sequential(first, second)

    def forward(self, x):
        return x + self.resblk(x)

 

  • model.py
# Pix2Pix
# https://arxiv.org/pdf/1611.07004.pdf
class Pix2Pix(nn.Module):
    """Pix2Pix U-Net generator: 8 encoder / 8 decoder stages with skips.

    Args:
        in_channels: channels of the input (condition) image.
        out_channels: channels of the generated image.
        nker: base channel width (64 in the paper).
        norm: normalization mode forwarded to CBR2d/DECBR2d.

    The output is passed through ``tanh``, so values lie in [-1, 1].
    NOTE(review): 8 stride-2 stages assume the input spatial size is a
    multiple of 256 — confirm against the data pipeline.
    """

    def __init__(self, in_channels, out_channels, nker=64, norm="bnorm"):
        super(Pix2Pix, self).__init__()

        # Encoder: each 4x4/stride-2 conv halves the spatial size.
        # The first stage has no normalization, per the paper.
        self.enc1 = CBR2d(in_channels, 1 * nker, kernel_size=4, padding=1,
                          norm=None, relu=0.2, stride=2)
        self.enc2 = CBR2d(1 * nker, 2 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc3 = CBR2d(2 * nker, 4 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc4 = CBR2d(4 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc5 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc6 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc7 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc8 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        # Decoder: each 4x4/stride-2 transposed conv doubles the size.
        # BUG FIX: output_padding=0 must be passed explicitly — DECBR2d's
        # default output_padding=1 combined with kernel_size=4 yields
        # 2H+1 instead of 2H, breaking the skip-connection concatenations.
        # Inputs of dec2..dec8 are doubled (2*) by the channel concat.
        self.dec1 = DECBR2d(in_channels=8 * nker, out_channels=8 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        self.drop1 = nn.Dropout2d(0.5)

        self.dec2 = DECBR2d(in_channels=2 * 8 * nker, out_channels=8 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        self.drop2 = nn.Dropout2d(0.5)

        self.dec3 = DECBR2d(in_channels=2 * 8 * nker, out_channels=8 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        self.drop3 = nn.Dropout2d(0.5)

        self.dec4 = DECBR2d(in_channels=2 * 8 * nker, out_channels=8 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        self.dec5 = DECBR2d(in_channels=2 * 8 * nker, out_channels=4 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        self.dec6 = DECBR2d(in_channels=2 * 4 * nker, out_channels=2 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        self.dec7 = DECBR2d(in_channels=2 * 2 * nker, out_channels=1 * nker,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=norm, relu=0.0)
        # Last stage: no norm / no activation before the tanh.
        self.dec8 = DECBR2d(in_channels=2 * 1 * nker, out_channels=out_channels,
                            kernel_size=4, stride=2, padding=1, output_padding=0,
                            norm=None, relu=None)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)
        enc6 = self.enc6(enc5)
        enc7 = self.enc7(enc6)
        enc8 = self.enc8(enc7)

        # Decoder: each stage consumes the previous output concatenated
        # with the matching encoder feature map (U-Net skip connections).
        # Dropout on the three deepest stages acts as the noise source,
        # per the paper.
        dec1 = self.dec1(enc8)
        drop1 = self.drop1(dec1)

        dec2 = self.dec2(torch.cat((drop1, enc7), dim=1))
        drop2 = self.drop2(dec2)

        dec3 = self.dec3(torch.cat((drop2, enc6), dim=1))
        drop3 = self.drop3(dec3)

        dec4 = self.dec4(torch.cat((drop3, enc5), dim=1))
        dec5 = self.dec5(torch.cat((dec4, enc4), dim=1))
        dec6 = self.dec6(torch.cat((dec5, enc3), dim=1))
        dec7 = self.dec7(torch.cat((dec6, enc2), dim=1))
        dec8 = self.dec8(torch.cat((dec7, enc1), dim=1))

        # BUG FIX: the original assigned torch.tanh(dec8) to a local and
        # fell off the end of forward, returning None.
        return torch.tanh(dec8)
        
   
class Discriminator(nn.Module):
    """PatchGAN discriminator: five 4x4/stride-2 conv stages + sigmoid.

    Emits a grid of per-patch real/fake probabilities rather than a single
    scalar. NOTE(review): for the conditional Pix2Pix setup the caller is
    expected to concatenate the condition and target images on the channel
    axis before calling this module — confirm against the training loop.

    Args:
        in_channels: channels of the (concatenated) input.
        out_channels: channels of the output probability map (usually 1).
        nker: base channel width.
        norm: normalization mode for the middle stages.
    """

    # BUG FIX: the default used the typo 'bonrm', which CBR2d silently
    # treated as "no normalization".
    def __init__(self, in_channels, out_channels, nker=64, norm='bnorm'):
        super(Discriminator, self).__init__()

        # First and last stages have no normalization, per the paper.
        self.enc1 = CBR2d(in_channels=1 * in_channels, out_channels=1 * nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=None, relu=0.2, bias=False)
        self.enc2 = CBR2d(in_channels=1 * nker, out_channels=2 * nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=norm, relu=0.2, bias=False)
        self.enc3 = CBR2d(in_channels=2 * nker, out_channels=4 * nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=norm, relu=0.2, bias=False)
        self.enc4 = CBR2d(in_channels=4 * nker, out_channels=8 * nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=norm, relu=0.2, bias=False)
        self.enc5 = CBR2d(in_channels=8 * nker, out_channels=out_channels,
                          kernel_size=4, stride=2, padding=1,
                          norm=None, relu=None, bias=False)

    def forward(self, x):
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)
        x = self.enc5(x)

        # Map logits to [0, 1] patch probabilities.
        # (The original had a second, unreachable `return x` after this.)
        return torch.sigmoid(x)