- To try applying an attention score when the skip-connection features are passed to the decoder in the original UNET architecture, an Attention Layer is added before each concatenation.
- Code
# 02. Attention UNET Model
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        output = self.conv(x)
        return output
class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        output = self.up(x)
        return output
class AttentionBlock(nn.Module):
    """
    Attention block (attention gate) with learnable parameters.
    Weights the skip-connection features by an attention score computed
    from the gating signal (decoder) and the skip connection (encoder).
    """
    def __init__(self, F_g, F_l, n_coefficients):
        super(AttentionBlock, self).__init__()
        # 1x1 convolution on the gating signal (decoder feature map)
        self.W_gate = nn.Sequential(
            nn.Conv2d(F_g, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_coefficients)
        )
        # 1x1 convolution on the skip-connection features (encoder feature map)
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_coefficients)
        )
        # Collapse to a single-channel attention map in [0, 1]
        self.attn_score = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        g1 = self.W_gate(gate)
        x1 = self.W_x(skip_connection)
        attn_score = self.relu(g1 + x1)
        attn_score = self.attn_score(attn_score)
        # Scale the skip-connection features by the attention score
        out = skip_connection * attn_score
        return out
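# The attention gate preserves the shape of the skip-connection features; a quick
# illustrative check (the shapes below are assumptions, chosen so that the gate and
# the skip connection share the same spatial size and channel count):
#   attn = AttentionBlock(F_g=512, F_l=512, n_coefficients=256)
#   out = attn(gate=torch.randn(1, 512, 32, 32), skip_connection=torch.randn(1, 512, 32, 32))
#   out.shape  # torch.Size([1, 512, 32, 32])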
class AttentionUNet(nn.Module):
    def __init__(self, img_ch=1, output_ch=32):
        super(AttentionUNet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = ConvBlock(img_ch, 64)
        self.Conv2 = ConvBlock(64, 128)
        self.Conv3 = ConvBlock(128, 256)
        self.Conv4 = ConvBlock(256, 512)
        self.Conv5 = ConvBlock(512, 1024)

        # Decoder: upsampling, attention gate on the skip connection, then convolution
        self.Up5 = UpConv(1024, 512)
        self.Att5 = AttentionBlock(F_g=512, F_l=512, n_coefficients=256)
        self.UpConv5 = ConvBlock(1024, 512)

        self.Up4 = UpConv(512, 256)
        self.Att4 = AttentionBlock(F_g=256, F_l=256, n_coefficients=128)
        self.UpConv4 = ConvBlock(512, 256)

        self.Up3 = UpConv(256, 128)
        self.Att3 = AttentionBlock(F_g=128, F_l=128, n_coefficients=64)
        self.UpConv3 = ConvBlock(256, 128)

        self.Up2 = UpConv(128, 64)
        self.Att2 = AttentionBlock(F_g=64, F_l=64, n_coefficients=32)
        self.UpConv2 = ConvBlock(128, 64)

        # 1x1 convolution mapping to the number of output channels (classes)
        self.Conv = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        self.softmax = nn.Softmax2d()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder
        e1 = self.Conv1(x)

        e2 = self.MaxPool(e1)
        e2 = self.Conv2(e2)

        e3 = self.MaxPool(e2)
        e3 = self.Conv3(e3)

        e4 = self.MaxPool(e3)
        e4 = self.Conv4(e4)

        e5 = self.MaxPool(e4)
        e5 = self.Conv5(e5)

        # Decoder
        d5 = self.Up5(e5)
        s4 = self.Att5(gate=d5, skip_connection=e4)  # attention-weighted skip connection
        d5 = torch.cat((s4, d5), dim=1)
        d5 = self.UpConv5(d5)

        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e3)
        d4 = torch.cat((s3, d4), dim=1)
        d4 = self.UpConv4(d4)

        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e2)
        d3 = torch.cat((s2, d3), dim=1)
        d3 = self.UpConv3(d3)

        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e1)
        d2 = torch.cat((s1, d2), dim=1)
        d2 = self.UpConv2(d2)

        output = self.Conv(d2)
        output = self.softmax(output)    # channel-wise softmax for multi-class output
        # output = self.sigmoid(output)  # alternative for binary segmentation
        return output
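- A minimal forward-pass shape check (a sketch; the batch size and the 1-channel 256x256 input are assumed values, and the spatial size must be divisible by 16 because of the four max-pooling steps):
# Sanity check: run a dummy batch through the model
model = AttentionUNet(img_ch=1, output_ch=32)
dummy = torch.randn(2, 1, 256, 256)   # (batch, channels, height, width) -- assumed input size
with torch.no_grad():
    out = model(dummy)
print(out.shape)   # expected: torch.Size([2, 32, 256, 256])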