Python Code

PyTorch 이미지 데이터셋 분할(Dataset Split) 방법

Kimhj 2023. 11. 6. 14:25
  • random split
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

# Preprocessing pipelines keyed by split name ('train' / 'val').
# Both splits currently run the same steps: resize to a square of
# IMAGE_SIZE x IMAGE_SIZE, convert to a tensor, then normalize each channel
# with the mean / std values given below.
#
# Train-only augmentations that are currently disabled; to enable them,
# give 'train' its own Compose that includes the ones you need:
#   transforms.CenterCrop(IMAGE_SIZE),
#   transforms.RandomHorizontalFlip(),
#   transforms.RandomAffine(degrees=10, translate=(10/224, 0)),
#   transforms.ColorJitter(contrast=0.5, brightness=0.5),
image_transform = {
    split: transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                [0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225],
            ),
        ]
    )
    for split in ('train', 'val')
}

# Create Dataset
# ImageDataset, train_meta, test_meta and BATCH_SIZE are expected to be
# defined elsewhere in the script.
trainset = ImageDataset(train_meta, image_transform=image_transform['train'])
testset = ImageDataset(test_meta, image_transform=image_transform['val'])

# Data Split: 80% train / 20% validation, chosen at random.
# NOTE(review): fractional lengths require a PyTorch release whose
# random_split accepts fractions; on older releases pass absolute sizes.
# NOTE(review): validset is carved out of the dataset built with
# image_transform['train'], so it inherits the train pipeline. That is
# harmless while both pipelines are identical, but revisit this if train-only
# augmentations are enabled.
trainset, validset = torch.utils.data.random_split(trainset, [0.8, 0.2])

# DataLoader
# drop_last is kept only for training. For validation / test it would
# silently discard the final partial batch, so metrics would be computed on
# fewer samples than the sizes printed below.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_loader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
print(f"Trainset: {len(trainset)}, Validset: {len(validset)}, Testset: {len(testset)}")

 

 

  • 클래스 비율 유지하면서 random split
    • scikit-learn의 train_test_split을 사용해 클래스 비율을 유지(stratify)하면서 인덱스(idx) 번호 기준으로 split
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Preprocessing pipelines keyed by split name. IMAGE_SIZE is expected to be
# defined elsewhere in the script.
image_transform = {
    'train': transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        # Optional train-only augmentations (disabled):
        # transforms.CenterCrop(IMAGE_SIZE),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomAffine(degrees=10, translate=(10/224, 0)),
        # transforms.ColorJitter(contrast=0.5, brightness=0.5),
    ]),
    'val': transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Stratified Sampling for train and val
# Split row *positions* (0 .. len-1) 90% / 10% while keeping the class ratio
# of train_meta.label in both parts; random_state fixes the split.
train_idx, validation_idx = train_test_split(
    np.arange(len(train_meta)),
    test_size=0.1,
    random_state=999,
    shuffle=True,
    stratify=train_meta.label,
)

# The indices above are positions, so select rows with .iloc rather than by
# matching index labels (which is only equivalent for a default RangeIndex).
# Sorting keeps the rows in their original order.
train_df = train_meta.iloc[np.sort(train_idx)].reset_index(drop=True)
valid_df = train_meta.iloc[np.sort(validation_idx)].reset_index(drop=True)

# Create Dataset
# ImageDataset, test_meta and BATCH_SIZE are expected to be defined elsewhere.
trainset = ImageDataset(train_df, image_transform=image_transform['train'])
validset = ImageDataset(valid_df, image_transform=image_transform['val'])
testset = ImageDataset(test_meta, image_transform=image_transform['val'])

# Dataloader for train and val
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
print(f"Trainset: {len(trainset)}, Validset: {len(validset)}, Testset: {len(testset)}")