
PyTorch의 DataLoader는 기본적으로 모든 데이터 샘플의 크기가 동일하다고 가정하고 이를 단순히 스택(Stack)하여 배치를 만듭니다. 하지만 자연어 처리(NLP)나 오디오 분석과 같이 가변 길이 시퀀스(Variable-length sequences)를 다룰 때는 기본 방식이 에러를 발생시킵니다. 본 가이드에서는 collate_fn을 커스텀하여 복잡한 데이터 구조를 효율적인 텐서 배치로 변환하는 전문적인 해결책을 제시합니다.
1. Default Collate와 Custom Collate의 핵심 차이 및 해결 과제
기본 default_collate는 리스트 형태의 샘플을 받아 torch.stack()을 호출합니다. 만약 샘플들의 shape가 하나라도 다르면 학습은 중단됩니다. 이를 해결하기 위해 개발자는 데이터 로딩 파이프라인의 최종 단계인 collate_fn에서 패딩(Padding)이나 데이터 재구조화를 수행해야 합니다.
2. Custom collate_fn이 반드시 필요한 3가지 시나리오 비교
실무에서 기본 로더 대신 커스텀 함수를 작성해야 하는 대표적인 상황을 정리했습니다.
| 시나리오 | 데이터 특성 | 기본 로더의 문제점 | Custom collate의 해결 방법 |
|---|---|---|---|
| NLP / 가변 시퀀스 | 문장마다 단어 수가 다름 | Stack 시 차원 불일치 에러 | 최대 길이에 맞춰 Zero-padding 수행 |
| 멀티모달 (Multi-modal) | 텍스트, 이미지, 메타데이터 혼합 | Dict/List 복합 구조 처리 불가 | 데이터 타입별 개별 텐서화 및 묶음 |
| 객체 탐지 (Detection) | 이미지당 바운딩 박스 개수 다름 | 가변 길이 Target 텐서 스택 불가 | List of Tensors 형태로 유지 또는 Packing |
| 희소 데이터 (Sparse) | 대부분이 0인 거대 행렬 | 메모리 효율성 급감 | Sparse Tensor 포맷으로 직접 변환 |
3. 실무 즉시 적용 가능한 collate_fn Example 7가지
다양한 도메인의 가변 데이터를 처리하기 위한 실무용 코드 스니펫입니다.
Example 1: 텍스트 데이터 패딩 (Zero-Padding) 해결 방법
가장 흔한 사례로, 배차 내 가장 긴 문장을 기준으로 나머지 문장에 패딩을 채웁니다.
import torch
from torch.nn.utils.rnn import pad_sequence
def collate_text_padding(batch):
# batch: [(text_tensor, label), (text_tensor, label), ...]
texts = [item[0] for item in batch]
labels = [item[1] for item in batch]
# 가변 길이 텐서들을 패딩하여 (Batch, Max_Len)으로 변환
texts_padded = pad_sequence(texts, batch_first=True, padding_value=0)
labels = torch.tensor(labels, dtype=torch.long)
return texts_padded, labels
Example 2: 실제 시퀀스 길이(Lengths)를 포함한 배치 구성
RNN의 pack_padded_sequence를 사용하기 위해 원본 길이를 함께 반환합니다.
def collate_with_lengths(batch):
batch.sort(key=lambda x: len(x[0]), reverse=True) # 길이순 정렬
texts, labels = zip(*batch)
lengths = torch.tensor([len(x) for x in texts])
texts_padded = pad_sequence(texts, batch_first=True)
return texts_padded, torch.tensor(labels), lengths
Example 3: 딕셔너리 형태의 멀티모달 데이터 처리
def collate_dict_data(batch):
# batch: [{'img': T, 'cap': T, 'id': int}, ...]
imgs = torch.stack([item['img'] for item in batch])
caps = pad_sequence([item['cap'] for item in batch], batch_first=True)
ids = [item['id'] for item in batch] # 숫자 리스트는 그대로 유지
return {'images': imgs, 'captions': caps, 'ids': ids}
Example 4: 객체 탐지를 위한 가변 Target 처리 (Bbox)
이미지당 박스 개수가 다를 때 Target을 리스트로 유지하는 방법입니다.
def collate_detection(batch):
images = torch.stack([item[0] for item in batch])
targets = [item[1] for item in batch] # 리스트로 묶어서 반환
return images, targets
Example 5: 특정 데이터 샘플링 오류(None) 제외 방법
데이터 로딩 중 손상된 파일(None)이 섞여 있을 때 배치를 깨뜨리지 않고 건너뛰는 해결책입니다.
def collate_skip_none(batch):
batch = list(filter(lambda x: x is not None, batch))
if len(batch) == 0: return torch.tensor([]), torch.tensor([])
return torch.utils.data.dataloader.default_collate(batch)
Example 6: 오디오 스펙트로그램 가변 시간축 패딩
def collate_audio(batch):
# 오디오는 (Channel, Time) 구조인 경우가 많음
waveforms = [item[0].t() for item in batch] # (Time, Channel)로 전치
waveforms_padded = pad_sequence(waveforms, batch_first=True)
return waveforms_padded.permute(0, 2, 1) # 다시 (Batch, Channel, Time)으로
Example 7: 정답 레이블의 원-핫 인코딩(One-hot) 동적 변환
def collate_one_hot(batch, num_classes=10):
images, labels = zip(*batch)
images = torch.stack(images)
labels = torch.tensor(labels)
one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes)
return images, one_hot
4. Custom collate_fn 사용 시 주의할 점 3가지
- Worker 프로세스 오버헤드: 커스텀 함수 내에서 너무 무거운 연산(예: 복잡한 이미지 증강)을 수행하면
num_workers가 많아도 병목이 발생할 수 있습니다. - 메모리 복사: 리스트를 텐서로 변환할 때
torch.as_tensor를 활용하여 불필요한 복사를 방지하십시오. - Pinned Memory 호환성:
pin_memory=True를 사용할 경우, 커스텀 함수가 반환하는 객체가 텐서이거나 텐서를 포함한 구조여야 정상 작동합니다.
5. 결론: 유연한 데이터 공급이 모델의 지능을 결정한다
PyTorch의 collate_fn은 원시 데이터셋과 신경망 연산 사이의 가교 역할을 합니다. 가변 길이 시퀀스나 복잡한 멀티모달 데이터를 정제된 텐서 묶음으로 변환하는 능력은 전문 딥러닝 엔지니어의 핵심 역량입니다. 본 가이드의 7가지 패턴을 활용하여 어떤 비정형 데이터라도 막힘없이 처리하는 파이프라인을 구축해 보시기 바랍니다.
'Artificial Intelligence > 21. PyTorch' 카테고리의 다른 글
| [PYTORCH] 비정형 데이터를 텐서로 변환하는 7가지 방법과 데이터 손실 해결 가이드 (0) | 2026.03.25 |
|---|---|
| [PYTORCH] Subset을 이용해 학습/검증 데이터를 나누는 3가지 방법과 데이터 누수 해결 가이드 (0) | 2026.03.25 |
| [PYTORCH] 거대한 데이터셋을 메모리 부족 없이 로드하는 7가지 전략 및 성능 해결 방법 (0) | 2026.03.25 |
| [PYTORCH] 데이터 증강(Data Augmentation) 기법 적용 방법 및 7가지 성능 차이 해결 가이드 (0) | 2026.03.25 |
| [PYTORCH] CSV 파일을 읽어 데이터셋으로 만드는 7가지 방법과 성능 해결 가이드 (0) | 2026.03.25 |