본문 바로가기
Artificial Intelligence/21. PyTorch

[PYTORCH] Dataset 클래스의 __len__과 __getitem__ 구현 방법 및 효율적 데이터 로딩 해결 가이드 7가지

by Papa Martino V 2026. 3. 25.
728x90

데이터 로딩 아키텍처
데이터 로딩 아키텍처

 

파이토치(PyTorch)를 활용한 딥러닝 프로젝트에서 성능의 병목 현상은 모델의 아키텍처보다 데이터 로딩 아키텍처에서 발생하는 경우가 많습니다. 본 가이드에서는 torch.utils.data.Dataset 커스텀 클래스를 통해 대규모 데이터를 효율적으로 관리하는 전문적인 방법론을 제시합니다.


1. PyTorch 데이터 파이프라인의 핵심: 추상화와 인터페이스

PyTorch의 데이터 관리 체계는 크게 DatasetDataLoader로 나뉩니다. Dataset은 데이터셋의 구조를 정의하고 개별 샘플을 가져오는 역할을 하며, DataLoader는 이를 병렬로 로드하고 셔플링, 배치 생성을 담당합니다.

커스텀 Dataset을 구축하기 위해서는 반드시 torch.utils.data.Dataset을 상속받아 다음 두 가지 매직 메서드를 구현해야 합니다.

  • __len__(self): 데이터셋의 전체 크기를 반환합니다.
  • __getitem__(self, idx): 인덱스(idx)에 해당하는 단일 샘플(입력 데이터 및 레이블)을 반환합니다.

2. Dataset 메서드의 역할 및 구현 차이 분석

단순한 구현을 넘어, 효율적인 리소스 관리를 위해 각 메서드가 갖는 기술적 차이점과 최적화 포인트를 아래 표로 정리하였습니다.

구분 __len__ 메서드 __getitem__ 메서드
주요 목적 전체 데이터 샘플 수 정의 인덱스 기반 데이터 추출 및 변환(Transform)
호출 시점 DataLoader 시작 시, 인덱스 샘플링 시 DataLoader가 다음 배치를 준비할 때 반복 호출
성능 영향도 무시할 수 있을 정도로 낮음 (O(1)) 매우 높음 (I/O 작업, 데이터 전처리 포함)
반환 타입 Integer (정수) Tensor, Dict, Tuple 등 유연한 형태
구현 핵심 인덱스 범위 초과 방지 Lazy Loading(지연 로딩) 적용 및 메모리 관리

3. 실무 중심의 커스텀 Dataset 구현 Example 7가지

실제 현업 개발 환경에서 바로 적용 가능한 시나리오별 코드 예제입니다.

Example 1: CSV 파일을 활용한 정형 데이터 Dataset

import pandas as pd
import torch
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, file_path):
        self.df = pd.read_csv(file_path)
        self.x = self.df.iloc[:, :-1].values
        self.y = self.df.iloc[:, -1].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 텐서 변환 후 반환
        return torch.tensor(self.x[idx], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.long)
        

Example 2: 이미지 폴더 구조를 이용한 Vision Dataset

import os
from PIL import Image

class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, 0  # 임시 레이블
        

Example 3: 대용량 데이터를 위한 Lazy Loading 방식

메모리에 모든 데이터를 올릴 수 없는 경우, 인덱스가 호출될 때 파일에서 로드합니다.

class LazyTextDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        # 호출 시점에 파일을 읽어 메모리 과부하 방지
        with open(self.file_list[idx], 'r') as f:
            data = f.read()
        return data
        

Example 4: 다중 모달(Multi-modal) 데이터 처리

class MultiModalDataset(Dataset):
    def __init__(self, text_data, image_paths):
        self.text_data = text_data
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        text = self.text_data[idx]
        image = Image.open(self.image_paths[idx])
        return {"text": text, "image": image}
        

Example 5: 시계열(Time-series) 슬라이딩 윈도우 데이터셋

import numpy as np

class TimeSeriesDataset(Dataset):
    def __init__(self, data, window_size):
        self.data = data
        self.window_size = window_size

    def __len__(self):
        return len(self.data) - self.window_size

    def __getitem__(self, idx):
        x = self.data[idx:idx+self.window_size]
        y = self.data[idx+self.window_size]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
        

Example 6: 분할 로딩을 위한 파티셔닝 데이터셋

class PartitionDataset(Dataset):
    def __init__(self, full_dataset, partition_idx):
        self.data = full_dataset[partition_idx]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
        

Example 7: Augmentation이 포함된 실무형 Dataset

from torchvision import transforms

class AugmentDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        return self.transform(img)
        

4. 데이터 로딩 성능 해결을 위한 3가지 골든 룰

전문적인 AI 엔지니어라면 다음 사항을 반드시 고려하여 __getitem__을 설계해야 합니다.

  1. I/O 병목 방지: 가능하다면 SSD 환경에서 데이터를 읽고, 데이터가 수백만 개일 경우 LMDB나 HDF5 같은 포맷을 사용하십시오.
  2. NumPy 활용: 파이썬 리스트보다 NumPy 배열이 인덱싱 속도에서 압도적입니다.
  3. Worker 프로세스 튜닝: DataLoader의 num_workers를 CPU 코어 수에 맞춰 조정하여 __getitem__의 호출을 병렬화하십시오.

5. 결론 및 요약

PyTorch의 Dataset 클래스 구현은 단순한 코드 작성을 넘어 효율적인 데이터 흐름을 설계하는 작업입니다. __len__으로 정체성을 부여하고, __getitem__으로 유연한 데이터 변환을 수행함으로써 모델이 학습에만 집중할 수 있는 최적의 환경을 제공할 수 있습니다.

내용 출처 및 참고 자료

  • PyTorch Official Documentation: `torch.utils.data` Module
  • Deep Learning with PyTorch (Manning Publications)
  • PyTorch Tutorials: "Writing Custom Datasets, DataLoaders and Transforms"
728x90