- 이미지 데이터 수가 많지 않을 때, 적은 데이터로 모델을 학습시키기 어렵기 때문에, ImageNet과 같은 대량의 데이터로 사전학습 되어있는 모델을 import 해서 모델을 학습시키는 경우가 많다.
- Torchvision 의 models 내에 있는 사전학습 모델들을 활용해도 좋지만, 생각보다 모델 종류가 많지 않다.
- Torchvision models list (링크 : https://pytorch.org/vision/0.9/models.html)
- AlexNet, VGG, ResNet, SqueezeNet, DenseNet, Inception-v3, GoogLeNet, ShuffleNet-v2, MobileNet-V2, MobileNet-V3, ResNeXt, Wide ResNet, MNASNet
- 최근에 알게된 Timm 이라는 패키지는 Vision Transformer 계열뿐 아니라 여러 모델을 조합한 모델 종류들이 많고, 이미지 사이즈나 Layer Depth별로 Pretrain된 모델들이 많다.
- Timm Docs : https://timm.fast.ai/
- 모델 종류는 아래 코드로 간단하게 확인할 수 있다.
import timm
for model in timm.list_models():
print(model)
- 모델 사용법은 아래와 같이 create_model 을 사용해서 pretrain 모델을 불러오고, fc layer 부분을 현재 데이터셋에 맞게 수정해서 사용하면 된다.
import torch.nn as nn
import timm
class ConViT_Model(nn.Module):
def __init__(self):
super(ConViT_Model, self).__init__()
self.model = timm.create_model('convit_small', pretrained=True)
self.model.head_drop = nn.Dropout(0.5)
self.model.head = nn.Linear(self.model.head.in_features, 5) # num of class=5
def forward(self, x):
x = self.model(x)
return x
'Deep Learning' 카테고리의 다른 글
편향(Bias)와 분산(Variance) 개념 정리 (0) | 2023.12.11 |
---|---|
Representation Learning (1) | 2023.12.01 |
딥러닝 서버 세팅-우분투 설치 후 gpu 환경 설치 (0) | 2023.10.31 |
Tensor를 detach() 하는 이유 (0) | 2023.10.19 |
비전 트랜스포머(Vision Transformer, ViT) 개념, 설명 (0) | 2023.10.10 |