Deep Learning

이미지 전이학습(Image Transfer learning) 패키지 Timm

Kimhj 2023. 11. 3. 13:55
  • 이미지 데이터 수가 많지 않을 때, 적은 데이터로 모델을 학습시키기 어렵기 때문에, ImageNet과 같은 대량의 데이터로 사전학습 되어있는 모델을 import 해서 모델을 학습시키는 경우가 많다.
  • Torchvision 의 models 내에 있는 사전학습 모델들을 활용해도 좋지만, 생각보다 모델 종류가 많지 않다.
  • Torchvision models list (링크 : https://pytorch.org/vision/0.9/models.html)
    • AlexNet, VGG, ResNet, SqueezeNet, DenseNet, Inception-v3, GoogLeNet, ShuffleNet-v2, MobileNet-V2, MobileNet-V3, ResNeXt, Wide ResNet, MNASNet
  • 최근에 알게된 Timm 이라는 패키지는 Vision Transformer 계열뿐 아니라 여러 모델을 조합한 모델 종류들이 많고, 이미지 사이즈나 Layer Depth별로 Pretrain된 모델들이 많다.
  • Timm Docs : https://timm.fast.ai/
  • 모델 종류는 아래 코드로 간단하게 확인할 수 있다.
import timm

for model in timm.list_models():
    print(model)

 

  • 모델 사용법은 아래와 같이 create_model 을 사용해서 pretrain 모델을 불러오고, fc layer 부분을 현재 데이터셋에 맞게 수정해서 사용하면 된다.
import torch.nn as nn
import timm

class ConViT_Model(nn.Module):
    def __init__(self):
        super(ConViT_Model, self).__init__()

        self.model = timm.create_model('convit_small', pretrained=True)
        self.model.head_drop = nn.Dropout(0.5)
        self.model.head = nn.Linear(self.model.head.in_features, 5) # num of class=5 

    def forward(self, x):
        x = self.model(x)
        return x