본문 바로가기
Artificial Intelligence/21. PyTorch

[PYTORCH] Custom collate_fn 구현 방법 및 7가지 가변 길이 시퀀스 해결 가이드

by Papa Martino V 2026. 3. 25.
728x90

Custom collate_fn 구현 방법
Custom collate_fn 구현 방법

 

PyTorch의 DataLoader는 기본적으로 모든 데이터 샘플의 크기가 동일하다고 가정하고 이를 단순히 스택(Stack)하여 배치를 만듭니다. 하지만 자연어 처리(NLP)나 오디오 분석과 같이 가변 길이 시퀀스(Variable-length sequences)를 다룰 때는 기본 방식이 에러를 발생시킵니다. 본 가이드에서는 collate_fn을 커스텀하여 복잡한 데이터 구조를 효율적인 텐서 배치로 변환하는 전문적인 해결책을 제시합니다.


1. Default Collate와 Custom Collate의 핵심 차이 및 해결 과제

기본 default_collate는 리스트 형태의 샘플을 받아 torch.stack()을 호출합니다. 만약 샘플들의 shape가 하나라도 다르면 학습은 중단됩니다. 이를 해결하기 위해 개발자는 데이터 로딩 파이프라인의 최종 단계인 collate_fn에서 패딩(Padding)이나 데이터 재구조화를 수행해야 합니다.

2. Custom collate_fn이 반드시 필요한 3가지 시나리오 비교

실무에서 기본 로더 대신 커스텀 함수를 작성해야 하는 대표적인 상황을 정리했습니다.

시나리오 데이터 특성 기본 로더의 문제점 Custom collate의 해결 방법
NLP / 가변 시퀀스 문장마다 단어 수가 다름 Stack 시 차원 불일치 에러 최대 길이에 맞춰 Zero-padding 수행
멀티모달 (Multi-modal) 텍스트, 이미지, 메타데이터 혼합 Dict/List 복합 구조 처리 불가 데이터 타입별 개별 텐서화 및 묶음
객체 탐지 (Detection) 이미지당 바운딩 박스 개수 다름 가변 길이 Target 텐서 스택 불가 List of Tensors 형태로 유지 또는 Packing
희소 데이터 (Sparse) 대부분이 0인 거대 행렬 메모리 효율성 급감 Sparse Tensor 포맷으로 직접 변환

3. 실무 즉시 적용 가능한 collate_fn Example 7가지

다양한 도메인의 가변 데이터를 처리하기 위한 실무용 코드 스니펫입니다.

Example 1: 텍스트 데이터 패딩 (Zero-Padding) 해결 방법

가장 흔한 사례로, 배차 내 가장 긴 문장을 기준으로 나머지 문장에 패딩을 채웁니다.

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_text_padding(batch):
    # batch: [(text_tensor, label), (text_tensor, label), ...]
    texts = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    
    # 가변 길이 텐서들을 패딩하여 (Batch, Max_Len)으로 변환
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.tensor(labels, dtype=torch.long)
    
    return texts_padded, labels
        

Example 2: 실제 시퀀스 길이(Lengths)를 포함한 배치 구성

RNN의 pack_padded_sequence를 사용하기 위해 원본 길이를 함께 반환합니다.

def collate_with_lengths(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True) # 길이순 정렬
    texts, labels = zip(*batch)
    
    lengths = torch.tensor([len(x) for x in texts])
    texts_padded = pad_sequence(texts, batch_first=True)
    
    return texts_padded, torch.tensor(labels), lengths
        

Example 3: 딕셔너리 형태의 멀티모달 데이터 처리

def collate_dict_data(batch):
    # batch: [{'img': T, 'cap': T, 'id': int}, ...]
    imgs = torch.stack([item['img'] for item in batch])
    caps = pad_sequence([item['cap'] for item in batch], batch_first=True)
    ids = [item['id'] for item in batch] # 숫자 리스트는 그대로 유지
    
    return {'images': imgs, 'captions': caps, 'ids': ids}
        

Example 4: 객체 탐지를 위한 가변 Target 처리 (Bbox)

이미지당 박스 개수가 다를 때 Target을 리스트로 유지하는 방법입니다.

def collate_detection(batch):
    images = torch.stack([item[0] for item in batch])
    targets = [item[1] for item in batch] # 리스트로 묶어서 반환
    
    return images, targets
        

Example 5: 특정 데이터 샘플링 오류(None) 제외 방법

데이터 로딩 중 손상된 파일(None)이 섞여 있을 때 배치를 깨뜨리지 않고 건너뛰는 해결책입니다.

def collate_skip_none(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0: return torch.tensor([]), torch.tensor([])
    return torch.utils.data.dataloader.default_collate(batch)
        

Example 6: 오디오 스펙트로그램 가변 시간축 패딩

def collate_audio(batch):
    # 오디오는 (Channel, Time) 구조인 경우가 많음
    waveforms = [item[0].t() for item in batch] # (Time, Channel)로 전치
    waveforms_padded = pad_sequence(waveforms, batch_first=True)
    return waveforms_padded.permute(0, 2, 1) # 다시 (Batch, Channel, Time)으로
        

Example 7: 정답 레이블의 원-핫 인코딩(One-hot) 동적 변환

def collate_one_hot(batch, num_classes=10):
    images, labels = zip(*batch)
    images = torch.stack(images)
    labels = torch.tensor(labels)
    one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes)
    return images, one_hot
        

4. Custom collate_fn 사용 시 주의할 점 3가지

  1. Worker 프로세스 오버헤드: 커스텀 함수 내에서 너무 무거운 연산(예: 복잡한 이미지 증강)을 수행하면 num_workers가 많아도 병목이 발생할 수 있습니다.
  2. 메모리 복사: 리스트를 텐서로 변환할 때 torch.as_tensor를 활용하여 불필요한 복사를 방지하십시오.
  3. Pinned Memory 호환성: pin_memory=True를 사용할 경우, 커스텀 함수가 반환하는 객체가 텐서이거나 텐서를 포함한 구조여야 정상 작동합니다.

5. 결론: 유연한 데이터 공급이 모델의 지능을 결정한다

PyTorch의 collate_fn은 원시 데이터셋과 신경망 연산 사이의 가교 역할을 합니다. 가변 길이 시퀀스나 복잡한 멀티모달 데이터를 정제된 텐서 묶음으로 변환하는 능력은 전문 딥러닝 엔지니어의 핵심 역량입니다. 본 가이드의 7가지 패턴을 활용하여 어떤 비정형 데이터라도 막힘없이 처리하는 파이프라인을 구축해 보시기 바랍니다.

참고 문헌 및 기술 출처

  • PyTorch Documentation: `torch.utils.data.DataLoader` - collate_fn 섹션
  • "Advanced PyTorch Data Pipelines" (NVIDIA Developer Blog, 2025)
  • PyTorch Forum: "How to create a dataloader with variable-size input"
728x90