- random split
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
image_transform = {
    'train': transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        # transforms.CenterCrop(IMAGE_SIZE),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomAffine(degrees=10, translate=(10/224, 0)),
        # transforms.ColorJitter(contrast=0.5, brightness=0.5),
    ]),
    'val': transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
# Create Dataset (ImageDataset, train_meta, and test_meta are defined elsewhere)
trainset = ImageDataset(train_meta, image_transform=image_transform['train'])
testset = ImageDataset(test_meta, image_transform=image_transform['val'])
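# Data Split: 80% train / 20% validation
# (fractional lengths require PyTorch 1.13 or later; older versions need integer lengths)
trainset, validset = torch.utils.data.random_split(trainset, [0.8, 0.2])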
# DataLoader
# drop_last=True only for training; validation and test keep every sample
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_loader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
print(f"Trainset: {len(trainset)}, Validset: {len(validset)}, Testset: {len(testset)}")
- random split that preserves class ratios
- Use scikit-learn's train_test_split and split by index number while preserving the class ratios
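The index-based stratified split described above could look roughly like the sketch below. `ToyDataset` and the 80/20 toy labels are hypothetical stand-ins for `ImageDataset` and its metadata, which this post does not show. The idea is to pass index numbers to `train_test_split` with `stratify=labels`, then wrap the resulting indices in `torch.utils.data.Subset`.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for ImageDataset: returns (feature, label) pairs."""

    def __init__(self, labels):
        self.labels = [int(label) for label in labels]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Dummy feature tensor; a real dataset would load and transform an image here.
        return torch.zeros(3), self.labels[idx]


# Imbalanced toy labels (assumption): 80 samples of class 0, 20 of class 1.
labels = np.array([0] * 80 + [1] * 20)
dataset = ToyDataset(labels)

# Split the index numbers rather than the samples; stratify keeps the class ratios.
train_idx, valid_idx = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

# Subset exposes only the selected indices of the underlying dataset.
train_subset = Subset(dataset, train_idx.tolist())
valid_subset = Subset(dataset, valid_idx.tolist())

print(f"Train: {len(train_subset)}, Valid: {len(valid_subset)}")
print("Train class counts:", np.bincount(labels[train_idx]))
print("Valid class counts:", np.bincount(labels[valid_idx]))
```

The two subsets can be passed to `DataLoader` the same way as the output of `random_split`. Both subsets share the underlying dataset's transform, so if the training transform includes augmentation, a separate dataset instance with the validation transform may be preferable for the validation indices.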