- Pix2Pix 모델은 A-B처럼 1:1로 대응되는 paired 이미지를 입력으로 해서, 특정 입력 데이터가 들어왔을 때 어떤 이미지가 나올지를 예측하여 생성하는 생성 모델이다.
- 예를 들면, 건물 segmentation을 입력으로 했을 때, 건물을 생성하는 이미지가 출력되는 것이 가장 대표적이다.
- 기존에 활용되던 CGAN 모델과는 다르게 Pix2Pix 에서는 Noise vector 와 Class vector를 입력으로 받지 않고, 오직 이미지만을 입력으로 한다.
- Discriminator는 두 가지의 입력데이터를 받는데, Generator에 넣었던 원본 이미지와 짝지어졌던 실제이미지, 그리고 생성된 이미지를 각각 입력받는다.
- 그리고 discriminator는 pair를 비교해서 생성된 이미지인지, 혹은 진짜 이미지인지를 구분하도록 학습을 시키는 것이다.
- 아래 사진에서 보다시피, Generator는 U-Net 구조를 적용하여 skip-connection을 사용했고, Discriminator는 PatchGAN을 이용했다.
- 모델 코드구현은 아래와 같이 layer.py 와 model.py 로 구분하여 작성했다.
- layer.py
import os
import numpy as np
import torch
import torch.nn as nn
class DECBR2d(nn.Module):
    """Decoder block: ConvTranspose2d -> optional norm -> optional (Leaky)ReLU.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size, stride, padding, output_padding, bias: forwarded to
            ``nn.ConvTranspose2d``. NOTE(review): the default
            ``output_padding=1`` with kernel 4 / stride 2 / padding 1 yields
            2*H+1 outputs; callers that need an exact 2x upsample should pass
            ``output_padding=0``.
        norm: "bnorm" for BatchNorm2d, "inorm" for InstanceNorm2d,
            None (or any other value) for no normalization layer.
        relu: None for no activation, 0.0 for ReLU, a positive slope for
            LeakyReLU with that negative slope.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=1, output_padding=1, bias=True, norm="bnorm", relu=0.0):
        super().__init__()

        layers = [nn.ConvTranspose2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size, stride=stride,
                                     padding=padding,
                                     output_padding=output_padding,
                                     bias=bias)]

        if norm is not None:
            if norm == "bnorm":
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            elif norm == "inorm":
                layers += [nn.InstanceNorm2d(num_features=out_channels)]

        if relu is not None and relu >= 0.0:
            # relu == 0 means plain ReLU; any positive value is the
            # LeakyReLU negative slope.
            layers += [nn.ReLU() if relu == 0 else nn.LeakyReLU(relu)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the transposed-conv / norm / activation stack."""
        return self.cbr(x)
class CBR2d(nn.Module):
    """Encoder block: explicit padding -> Conv2d -> optional norm -> activation.

    Padding is applied as a separate layer (so reflection/replication modes
    are available) and the convolution itself uses ``padding=0``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size, stride, bias: forwarded to ``nn.Conv2d``.
        padding: amount of spatial padding applied by the padding layer.
        padding_mode: one of 'reflection', 'replication', 'constant'
            (zero-valued), or 'zeros'.
        norm: "bnorm" for BatchNorm2d, "inorm" for InstanceNorm2d,
            None (or any other value) for no normalization layer.
        relu: None for no activation, 0.0 for ReLU, a positive slope for
            LeakyReLU with that negative slope.

    Raises:
        ValueError: if ``padding_mode`` is not a recognized mode.  (The
            original silently skipped padding for unknown modes, which
            changed output sizes without warning.)
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, padding_mode='reflection', bias=True,
                 norm="bnorm", relu=0.0):
        super().__init__()

        layers = []
        if padding_mode == 'reflection':
            layers += [nn.ReflectionPad2d(padding)]
        elif padding_mode == 'replication':
            layers += [nn.ReplicationPad2d(padding)]
        elif padding_mode == 'constant':
            layers += [nn.ConstantPad2d(padding, 0)]
        elif padding_mode == 'zeros':
            layers += [nn.ZeroPad2d(padding)]
        else:
            raise ValueError(f"unknown padding_mode: {padding_mode!r}")

        # Conv uses padding=0 because padding was added explicitly above.
        layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=stride, padding=0,
                             bias=bias)]

        if norm is not None:
            if norm == "bnorm":
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            elif norm == "inorm":
                layers += [nn.InstanceNorm2d(num_features=out_channels)]

        if relu is not None and relu >= 0.0:
            # relu == 0 means plain ReLU; any positive value is the
            # LeakyReLU negative slope.
            layers += [nn.ReLU() if relu == 0 else nn.LeakyReLU(relu)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the padding / conv / norm / activation stack."""
        return self.cbr(x)
class ResBlock(nn.Module):
    """Residual block: two CBR2d convolutions plus an identity skip.

    The first convolution keeps its activation; the second has none, so the
    nonlinearity is effectively applied before the residual addition.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, bias=True, norm="bnorm", relu=0.0):
        super().__init__()

        first = CBR2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      bias=bias, norm=norm, relu=relu)
        second = CBR2d(in_channels=out_channels, out_channels=out_channels,
                       kernel_size=kernel_size, stride=stride, padding=padding,
                       bias=bias, norm=norm, relu=None)
        self.resblk = nn.Sequential(first, second)

    def forward(self, x):
        # NOTE(review): the skip addition assumes in_channels == out_channels
        # and stride == 1 so shapes match — confirm at call sites.
        return self.resblk(x) + x
- model.py
# Pix2Pix
# https://arxiv.org/pdf/1611.07004.pdf
# Pix2Pix
# https://arxiv.org/pdf/1611.07004.pdf
class Pix2Pix(nn.Module):
    """Pix2Pix U-Net generator (Isola et al., 2017).

    Eight stride-2 encoder stages halve the spatial size down to 1x1; eight
    transposed-conv decoder stages upsample back, each concatenated with the
    mirrored encoder feature map (skip connection). Dropout(0.5) follows the
    first three decoder stages, and the output passes through tanh.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        nker: base channel multiplier (64 in the paper).
        norm: normalization kind forwarded to CBR2d/DECBR2d ("bnorm"/"inorm").
    """

    def __init__(self, in_channels, out_channels, nker=64, norm="bnorm"):
        super().__init__()

        # Encoder: C64-C128-C256-C512-C512-C512-C512-C512, LeakyReLU(0.2).
        # First stage has no normalization, per the paper.
        self.enc1 = CBR2d(in_channels, 1 * nker, kernel_size=4, padding=1,
                          norm=None, relu=0.2, stride=2)
        self.enc2 = CBR2d(1 * nker, 2 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc3 = CBR2d(2 * nker, 4 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc4 = CBR2d(4 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc5 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc6 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc7 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)
        self.enc8 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        # Decoder. BUG FIX: output_padding=0 is passed explicitly — the
        # DECBR2d default of 1 with kernel 4 / stride 2 / padding 1 produced
        # 2*H+1 outputs, which broke every skip-connection torch.cat in
        # forward() with a spatial-size mismatch. With output_padding=0 the
        # output is exactly 2*H, mirroring the encoder.
        # Input to dec2..dec8 is doubled by the skip concatenation.
        self.dec1 = DECBR2d(in_channels=8*nker, out_channels=8*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        self.drop1 = nn.Dropout2d(0.5)
        self.dec2 = DECBR2d(in_channels=2*8*nker, out_channels=8*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        self.drop2 = nn.Dropout2d(0.5)
        self.dec3 = DECBR2d(in_channels=2*8*nker, out_channels=8*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        self.drop3 = nn.Dropout2d(0.5)
        self.dec4 = DECBR2d(in_channels=2*8*nker, out_channels=8*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        self.dec5 = DECBR2d(in_channels=2*8*nker, out_channels=4*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        self.dec6 = DECBR2d(in_channels=2*4*nker, out_channels=2*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        self.dec7 = DECBR2d(in_channels=2*2*nker, out_channels=1*nker,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=norm, relu=0.0)
        # Final stage: no norm, no ReLU — tanh is applied in forward().
        self.dec8 = DECBR2d(in_channels=2*1*nker, out_channels=out_channels,
                            kernel_size=4, stride=2, padding=1,
                            output_padding=0, norm=None, relu=None)

    def forward(self, x):
        """Generate an image; returns a tanh-scaled tensor in [-1, 1]."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)
        enc6 = self.enc6(enc5)
        enc7 = self.enc7(enc6)
        enc8 = self.enc8(enc7)

        # Decoder with U-Net skip connections (channel-dim concatenation).
        dec1 = self.dec1(enc8)
        drop1 = self.drop1(dec1)

        cat2 = torch.cat((drop1, enc7), dim=1)
        dec2 = self.dec2(cat2)
        drop2 = self.drop2(dec2)

        cat3 = torch.cat((drop2, enc6), dim=1)
        dec3 = self.dec3(cat3)
        drop3 = self.drop3(dec3)

        cat4 = torch.cat((drop3, enc5), dim=1)
        dec4 = self.dec4(cat4)

        cat5 = torch.cat((dec4, enc4), dim=1)
        dec5 = self.dec5(cat5)

        cat6 = torch.cat((dec5, enc3), dim=1)
        dec6 = self.dec6(cat6)

        cat7 = torch.cat((dec6, enc2), dim=1)
        dec7 = self.dec7(cat7)

        cat8 = torch.cat((dec7, enc1), dim=1)
        dec8 = self.dec8(cat8)

        # BUG FIX: the original assigned tanh(dec8) to x but never returned
        # it, so forward() returned None.
        return torch.tanh(dec8)
class Discriminator(nn.Module):
    """PatchGAN discriminator for Pix2Pix.

    Five stride-2 CBR2d stages; the input is expected to be the channel-wise
    concatenation of the conditioning image and the real/generated image, and
    the output is a sigmoid per-patch real/fake probability map.

    Args:
        in_channels: channels of the (concatenated) input.
        out_channels: channels of the prediction map (typically 1).
        nker: base channel multiplier.
        norm: normalization kind for the middle stages.
    """

    def __init__(self, in_channels, out_channels, nker=64, norm='bnorm'):
        # BUG FIX: the default was the typo 'bonrm', which matched neither
        # "bnorm" nor "inorm" inside CBR2d, so callers relying on the default
        # silently got no normalization layers at all.
        super().__init__()

        # First stage has no normalization, per the paper.
        self.enc1 = CBR2d(in_channels=1*in_channels, out_channels=1*nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=None, relu=0.2, bias=False)
        self.enc2 = CBR2d(in_channels=1*nker, out_channels=2*nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=norm, relu=0.2, bias=False)
        self.enc3 = CBR2d(in_channels=2*nker, out_channels=4*nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=norm, relu=0.2, bias=False)
        self.enc4 = CBR2d(in_channels=4*nker, out_channels=8*nker,
                          kernel_size=4, stride=2, padding=1,
                          norm=norm, relu=0.2, bias=False)
        # Final stage: raw logits (no norm/activation); sigmoid in forward().
        self.enc5 = CBR2d(in_channels=8*nker, out_channels=out_channels,
                          kernel_size=4, stride=2, padding=1,
                          norm=None, relu=None, bias=False)

    def forward(self, x):
        """Return the per-patch probability map in (0, 1)."""
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)
        x = self.enc5(x)
        # The original had a second, unreachable `return x` after this one;
        # it has been removed.
        return torch.sigmoid(x)
'Deep Learning' 카테고리의 다른 글
PGGAN, StyleGAN 설명, Torch 코드 구현 (0) | 2024.02.19 |
---|---|
[TORCH_USE_CUDA_DSA] CUDA 오류 (0) | 2024.02.05 |
conda env 실행 시 오류 해결 (0) | 2024.01.31 |
pytorch gpu(cuda) 정보 확인 (0) | 2024.01.17 |
anaconda 다중 사용자(multi-user) 환경 세팅 (0) | 2024.01.15 |