Python Code

ViT (Vision Transformer) PyTorch Code

Kimhj 2023. 11. 1. 14:38

 

  • Creating the image dataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

class ImageDataset(Dataset):
    '''
    Dataset class that overrides the __init__, __len__, and __getitem__ methods of Dataset.

    Parameters:
        df: DataFrame built from the csv file.
        data_path: Location of the image files.
        image_transform: Transformations to apply to the image.
        train: A boolean indicating whether this is the training set or not.
    '''
    def __init__(self, df, data_path, image_transform=None, train=True):
        super().__init__()

        self.df = df
        self.data_path = data_path
        self.image_transform = image_transform
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.df['id_code'][idx]
        image = Image.open(f"{self.data_path}/{image_id}.png")
        if self.image_transform:
            image = self.image_transform(image)  # Apply the transformations to the image.
        if self.train:  # train mode: return the image with its label
            label = self.df['diagnosis'][idx]
            return image, label
        else:   # eval mode: return only the image
            return image
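
The `df` argument above is a pandas DataFrame, referred to as `train_meta` later in this post, with `id_code` and `diagnosis` columns. A minimal sketch of loading it (the `./train.csv` path is an assumption, not from the original code):

import pandas as pd

# Hypothetical example: load the metadata csv that provides 'id_code' and 'diagnosis'.
# The actual file name/path may differ in your environment.
train_meta = pd.read_csv("./train.csv")
print(train_meta.head())
print(train_meta['diagnosis'].value_counts())  # class distribution (5 classes here)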

 

  • Image pre-processing
import torch

# define transforms
train_path = "./train_images"

# The pre-processing below should be customized depending on the model and the task.
image_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet mean/std
])

# Create train Dataset
data_set = ImageDataset(train_meta, train_path, image_transform=image_transform)

# Split into train/validation sets (fractional lengths require PyTorch >= 1.13)
train_set, valid_set = torch.utils.data.random_split(data_set, [0.9, 0.1])

# Create DataLoaders
train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True)
val_dataloader = DataLoader(valid_set, batch_size=32, shuffle=False)
print(f"Trainset: {len(train_set)}, Validset: {len(valid_set)}")

 

  • Creating the ViT model class
    • Design the transformer part carefully: decide whether the model is encoder-only or uses both the encoder and the decoder (see the sketch below).
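
For reference, the two options look roughly like this in PyTorch; ViT uses only the encoder stack. A minimal sketch with illustrative hyperparameters:

import torch.nn as nn

# Encoder-only stack, as used by ViT (batch_first=True expects input of shape (B, seq_len, dim)).
encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
encoder_only  = nn.TransformerEncoder(encoder_layer, num_layers=12)

# Full encoder-decoder model, as used in seq2seq NLP tasks; forward() needs both src and tgt.
encoder_decoder = nn.Transformer(d_model=768, nhead=12,
                                 num_encoder_layers=12, num_decoder_layers=6,
                                 batch_first=True)
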
# ViT model
import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, num_heads, num_layers):
        super(VisionTransformer, self).__init__()

        self.num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, padding=0)
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # ViT is encoder-only: unlike the NLP seq2seq transformer, no decoder is needed.
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_embedding(x)             # (B, dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)        # (B, num_patches, dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)   # prepend the [CLS] token
        x = x + self.positional_embedding       # torch.Size([32, 197, 768])
        x = self.transformer(x)                 # encoder-only forward pass

        x = x[:, 0, :]   # the first token is the classification ([CLS]) token
        x = self.fc(x)

        return x
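
A quick shape check with a dummy batch confirms the patch count and output dimensions (a minimal sketch; the batch size of 4 and the reduced layer count are arbitrary):

# Sanity check with random input: 224x224 images, 16x16 patches -> 196 patches + 1 CLS token.
vit = VisionTransformer(image_size=224, patch_size=16, num_classes=5,
                        dim=768, num_heads=12, num_layers=2)  # 2 layers just to keep this quick
dummy = torch.randn(4, 3, 224, 224)
out = vit(dummy)
print(out.shape)  # expected: torch.Size([4, 5])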

 

  • Defining the ViT model
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define hyperparameters (ViT-Base settings for 224x224 input)
image_size = 224
patch_size = 16
num_classes = 5
dim = 768
num_heads = 12
num_layers = 12

model = VisionTransformer(image_size, patch_size, num_classes, dim, num_heads, num_layers).to(device)
model
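
Printing the number of trainable parameters is a quick way to sanity-check the model size (a small sketch):

# Count trainable parameters.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")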

 

  • Creating the training/validation functions
def train_one_epoch(dataloader, model, loss_fn, optimizer, device):
    total = 0
    correct = 0
    running_loss = 0.0

    model.train()   # train mode
    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        outputs = model(x)  # forward pass: feed the batch to the model
        loss = loss_fn(outputs, y)

        # Calculate results
        running_loss += loss.item()
        total += y.size(0)
        preds = outputs.argmax(dim=1).cpu().detach()  # index of the highest score for each sample
        correct += (preds == y.cpu().detach()).sum().item()

        optimizer.zero_grad()   # reset gradients to zero
        loss.backward()         # back-propagation
        optimizer.step()        # update weights

    # Aggregate results
    epoch_loss = running_loss / len(dataloader)
    acc = (correct / total) * 100

    print(f"Train Loss: {epoch_loss}, Train Acc: {acc}, Correct: [{correct}/{total}]")

    return epoch_loss, acc
    
    
def valid_one_epoch(dataloader, model, loss_fn, device):
    total_valid = 0
    correct_valid = 0
    running_loss_valid = 0.0

    model.eval()   # eval mode
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            outputs = model(x)  # forward pass: feed the batch to the model
            loss = loss_fn(outputs, y)

            # Calculate results
            total_valid += y.size(0)
            preds = outputs.argmax(dim=1).cpu().detach()  # index of the highest score for each sample in the batch
            correct_valid += (preds == y.cpu().detach()).sum().item()
            running_loss_valid += loss.item()

        # Aggregate results
        epoch_loss_valid = running_loss_valid / len(dataloader)
        acc_valid = (correct_valid / total_valid) * 100

    print(f"Valid Loss: {epoch_loss_valid}, Valid Acc: {acc_valid}, Correct: [{correct_valid}/{total_valid}]")

    return epoch_loss_valid, acc_valid

 

  • Training and model-saving code
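The loss function below uses `class_weights`, which is not defined in this post. A minimal sketch of one common choice, inverse class frequency computed from `train_meta['diagnosis']`, is shown here as an assumption, not the original author's exact computation.

# Hypothetical inverse-frequency class weights for the 5 diagnosis classes.
class_counts = train_meta['diagnosis'].value_counts().sort_index()
class_weights = torch.tensor((1.0 / class_counts).values, dtype=torch.float32).to(device)
print(class_weights)
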
loss_fn   = nn.CrossEntropyLoss(weight=class_weights)      # CrossEntropyLoss with class weights
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam

n_epochs = 60
# Lists to store the losses and accuracies for all epochs.
train_losses = []
train_acc_list = []
valid_losses = []
valid_acc_list = []

for epoch in range(n_epochs):
    print(f'\nEpoch {epoch+1}/{n_epochs}')
    print('-------------------------------')
    # train
    train_loss, train_acc = train_one_epoch(train_dataloader, model, loss_fn, optimizer, device)
    train_losses.append(train_loss)
    train_acc_list.append(train_acc)
    # valid
    valid_loss, valid_acc = valid_one_epoch(val_dataloader, model, loss_fn, device)
    valid_losses.append(valid_loss)
    valid_acc_list.append(valid_acc)

print('\nTraining has completed!')

# save model
save_path = './ViT.pth'
torch.save(model.state_dict(), save_path)

 

  • Visualizing the training results (Loss Plot / Accuracy Plot)
import matplotlib.pyplot as plt

# Loss curve
epochs = range(n_epochs)
plt.plot(epochs, train_losses, 'g', label='Training loss')
plt.plot(epochs, valid_losses, 'b', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()


# Accuracy curve
epochs = range(n_epochs)
plt.plot(epochs, train_acc_list, label='Training Acc')
plt.plot(epochs, valid_acc_list, label='Validation Acc')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()
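
To use the saved weights later, the model can be reloaded and run on a test set built with the same ImageDataset class in eval mode (train=False). A minimal sketch, assuming a test_meta DataFrame and a ./test_images directory, which are not part of the original post:

# Hypothetical inference example with the saved weights.
model = VisionTransformer(image_size, patch_size, num_classes, dim, num_heads, num_layers).to(device)
model.load_state_dict(torch.load('./ViT.pth', map_location=device))
model.eval()

test_set = ImageDataset(test_meta, "./test_images", image_transform=image_transform, train=False)
test_dataloader = DataLoader(test_set, batch_size=32, shuffle=False)

predictions = []
with torch.no_grad():
    for x in test_dataloader:
        x = x.to(device)
        preds = model(x).argmax(dim=1)
        predictions.extend(preds.cpu().tolist())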