
현실 세계의 데이터는 결코 공평하지 않습니다. 암 진단 데이터셋에서 정상 데이터가 99%이고 암 데이터가 1%인 상황은 매우 흔합니다. 이러한 **클래스 불균형(Class Imbalance)** 상황에서 일반적인 무작위 샘플링을 사용하면, 모델은 단순히 다수 클래스만 맞추도록 학습되어 정작 중요한 소수 클래스를 식별하지 못하게 됩니다. PyTorch의 WeightedRandomSampler는 이러한 통계적 편향을 학습 파이프라인 레벨에서 우아하게 해결할 수 있는 강력한 도구입니다. 본 가이드에서는 가중치 계산의 수학적 원리부터 7가지 실무 시나리오별 구현 예제까지 상세히 다룹니다.
1. 불균형 데이터 문제와 WeightedRandomSampler의 필요성
모델 학습 시 DataLoader에서 shuffle=True를 설정하면 데이터셋 전체에서 균일한 확률로 샘플을 추출합니다. 불균형 데이터셋에서는 다수 클래스가 배치(Batch)를 장악하게 되며, 소수 클래스의 Gradient는 무시됩니다. 이는 결과적으로 소수 클래스에 대한 낮은 Recall(재현율)을 초래합니다.
WeightedRandomSampler는 각 데이터 샘플에 **추출 확률(Weight)**을 부여합니다. 소수 클래스 샘플에 높은 가중치를 부여함으로써, 미니배치 내에서 클래스 비율이 대략적으로 1:1이 되도록 조정(Over-sampling)합니다. 이는 손실 함수(Loss Function)에 가중치를 주는 방식보다 학습 안정성 면에서 유리한 경우가 많습니다.
2. 불균형 데이터 해결 전략 비교: 샘플링 vs 손실 함수
클래스 불균형을 해결하는 대표적인 두 가지 접근 방법의 차이를 표로 정리하였습니다.
| 비교 항목 | 샘플링 방식 (WeightedRandomSampler) | 손실 함수 가중치 방식 (Weighted Loss) |
|---|---|---|
| 작동 레벨 | DataLoader (데이터 입력 단계) | Criterion (Criterion (손실 계산 단계) |
| 주요 메커니즘 | 소수 클래스의 배치 등장 빈도를 높임 (Over-sampling) | 소수 클래스 오차에 대한 페널티를 높임 |
| 데이터 접근 | 모든 데이터에 접근하지만 소수 데이터가 반복됨 | 모든 데이터를 한 번씩만 접근함 (에포크당 속도 빠름) |
| 장점 | 학습 Gradient가 안정적이며, 특정 클래스 과적합 방지에 유리 | 구현이 매우 간단하며, 대규모 데이터셋 학습에 유리 |
| 단점 | 에포크당 학습 데이터 수가 늘어나 학습 시간이 증가할 수 있음 | Learning Rate 튜닝이 민감해질 수 있으며 최적화가 어려움 |
| 추천 상황 | 불균형이 심하며(1:100 이상), 데이터 소실 없이 패턴을 학습해야 할 때 | 데이터셋이 매우 커서 샘플링 오버헤드가 크거나, 불균형이 적당할 때 |
3. 실무 즉시 적용 가능한 WeightedRandomSampler Example 7가지
개발자가 실무에서 불균형 데이터를 처리할 때 마주치는 시나리오별 구현 예제입니다. copy & paste 하여 바로 사용 가능하도록 작성되었습니다.
Example 1: 표준적인 2클래스(Binary) 가중치 계산 및 적용 방법
가장 기본적이고 흔한 시나리오입니다. 클래스 빈도의 역수를 가중치로 사용합니다.
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
import numpy as np
# 1. 가상의 불균형 데이터셋 생성 (0: 900개, 1: 100개)
num_samples = 1000
num_class_0 = 900
num_class_1 = 100
# 타겟 레이블 생성 (NumPy 사용이 빈도 계산에 편리)
target_np = np.concatenate([np.zeros(num_class_0), np.ones(num_class_1)])
np.random.shuffle(target_np)
targets = torch.from_numpy(target_np).long()
# 2. 클래스별 빈도수 계산
class_sample_count = np.array([num_class_0, num_class_1]) # [900, 100]
# 3. 클래스별 가중치 계산 (빈도의 역수) -> [1/900, 1/100]
weight = 1. / class_sample_count
# 4. 각 데이터 샘플에 해당하는 가중치 부여 (가장 중요한 단계)
samples_weight = torch.from_numpy(np.array([weight[t] for t in targets]))
# 5. Sampler 정의 (num_samples는 보통 전체 데이터 수와 동일하게 설정)
sampler = WeightedRandomSampler(weights=samples_weight, num_samples=len(samples_weight), replacement=True)
# 6. DataLoader 연동 (shuffle=True와 상호 배타적이므로 꺼야 함)
dataset = TensorDataset(torch.randn(num_samples, 10), targets) # 임의의 피처
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# 배치 내 클래스 비율 확인 (대략 1:1 확인)
for inputs, labels in train_loader:
print(f"Batch labels: {labels.bincount(minlength=2)}")
break # 첫 배치만 확인
Example 2: 다중 클래스(Multi-class) 불균형 처리 해결 예제
3개 이상의 클래스가 있을 때 각 클래스의 비율을 맞춥니다.
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
from collections import Counter
# 1. 가상의 다중 불균형 데이터 (Class 0: 80%, 1: 15%, 2: 5%)
targets = torch.cat([torch.zeros(800), torch.ones(150), torch.full((50,), 2)]).long()
# 2. Counter를 이용한 빈도 계산 (NumPy 없는 환경)
target_list = targets.tolist()
count_dict = Counter(target_list)
class_count = [count_dict[0], count_dict[1], count_dict[2]]
# 3. 가중치 계산 및 부여
weights = 1. / torch.tensor(class_count, dtype=torch.float)
samples_weight = weights[targets]
# 4. Sampler 및 Loader 설정
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
dataset = TensorDataset(torch.randn(1000, 10), targets)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# 배치 비율 확인
for _, labels in train_loader:
print(f"Batch counts: {labels.bincount(minlength=3)}")
break
Example 3: num_samples 조정을 통한 데이터 증강(Over-sampling) 효과 구현
전체 데이터 수보다 num_samples를 크게 설정하여 소수 클래스를 더 많이 반복 학습시킵니다.
# Example 1의 데이터를 그대로 사용한다고 가정
# len(samples_weight)가 1000일 때, num_samples를 2000으로 설정
oversampled_sampler = WeightedRandomSampler(samples_weight, num_samples=2000, replacement=True)
# DataLoader는 한 에포크당 2000개의 샘플을 추출하게 됨 (대부분 소수 클래스의 반복)
train_loader = DataLoader(dataset, batch_size=32, sampler=oversampled_sampler)
print(f"Loader length (batches): {len(train_loader)}") # 2000/32 대략 63배치
Example 4: replacement=False를 이용한 가중치 기반 언더샘플링(Under-sampling)
데이터 복원 추출을 하지 않고 가중치가 높은 순서대로 데이터를 뽑습니다. 다수 데이터를 버리고 소수 데이터를 유지할 때 사용합니다.
# replacement=False일 때는 반드시 num_samples가 데이터 수보다 작아야 함
# 예: 0: 900개, 1: 100개 데이터셋에서 총 200개만 추출 (0: 100개, 1: 100개 목표)
undersampled_sampler = WeightedRandomSampler(samples_weight, num_samples=200, replacement=False)
train_loader = DataLoader(dataset, batch_size=20, sampler=undersampled_sampler)
# 에포크당 총 10배치(20*10=200샘플)만 학습
total_labels = []
for _, labels in train_loader:
total_labels.append(labels)
print(f"Total extracted counts: {torch.cat(total_labels).bincount(minlength=2)}")
Example 5: 대용량 이미지 폴더 데이터셋(ImageFolder)에 가중치 적용 해결
폴더 구조로 된 이미지 데이터셋에서 레이블 빈도를 계산하는 실무 방법입니다.
from torchvision.datasets import ImageFolder
import os
# 이미지 폴더 경로 (예시)
data_dir = 'path/to/data'
# 1. 데이터셋 로드 (transforms는 임의)
# dataset.targets에 각 이미지의 레이블 인덱스가 리스트로 담겨 있음
image_dataset = ImageFolder(root=data_dir)
# 2. 레이블 빈도 계산 (NumPy bincount가 가장 빠름)
train_targets = np.array(image_dataset.targets)
class_count = np.bincount(train_targets)
# 3. 가중치 부여
class_weights = 1. / class_count
samples_weight = torch.from_numpy(class_weights[train_targets]).float()
# 4. Sampler 및 DataLoader 설정
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
image_loader = DataLoader(image_dataset, batch_size=64, sampler=sampler, num_workers=4)
Example 6: Custom Dataset에서의 가중치 사전 계산 및 로딩 전략
커스텀 데이터셋 내부에서 가중치 정보를 미리 가지고 있어 초기화 시간을 줄이는 해결책입니다.
import pandas as pd
from torch.utils.data import Dataset
class CustomCSVDataset(Dataset):
def __init__(self, csv_file):
self.df = pd.read_csv(csv_file)
self.targets = torch.from_numpy(self.df['label'].values).long()
# 데이터셋 생성 시 가중치를 미리 계산하여 속성으로 저장
label_counts = self.df['label'].value_counts().sort_index().values
class_weights = 1. / label_counts
self.samples_weight = torch.from_numpy(class_weights[self.targets.numpy()]).float()
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# 데이터 로직 (생략)
return torch.randn(10), self.targets[idx]
# 사용 예시
csv_dataset = CustomCSVDataset('data.csv')
# 외부에서 다시 계산할 필요 없이 속성을 가져옴
sampler = WeightedRandomSampler(csv_dataset.samples_weight, len(csv_dataset))
csv_loader = DataLoader(csv_dataset, batch_size=32, sampler=sampler)
Example 7: Weighted Loss(손실 함수 가중치)와 Sampler의 하이브리드 적용 해결
극도로 심한 불균형(예: 1:1000) 상황에서 두 가지 방식을 모두 사용하여 시너지를 냅니다.
# 1. DataLoader에서는 Sampler로 배치 내 클래스 비율을 1:1로 조정
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler) # Example 1의 sampler
# 2. Criterion(Loss)에서는 여전히 소수 클래스 오차에 가중치를 부여
# 배치 비율이 조정되었더라도, 소수 데이터의 변별력을 높이기 위해 가중 손실 사용
# 비율이 1:1이므로 weights는 [1.0, 1.0]에 가깝게 설정하거나 소수 클래스에 약간만 더 부여
# (1. / class_sample_count 로 계산한 극단적 가중치는 사용 금지)
# 하이브리드 상황에서는 튜닝이 매우 중요함
loss_weights = torch.tensor([1.0, 2.0]) # Class 1 오차에 2배 페널티
criterion = torch.nn.CrossEntropyLoss(weight=loss_weights)
# 학습 루프 (생략)
# outputs = model(inputs)
# loss = criterion(outputs, labels)
4. WeightedRandomSampler 사용 시 범하는 치명적인 실수 3가지 해결
많은 개발자들이 이 도구를 사용할 때 흔히 저지르는 실수와 그 해결책을 전문 엔지니어 관점에서 정리했습니다.
- shuffle=True와 sampler 동시 사용: PyTorch DataLoader에서 sampler 옵션을 설정하면 shuffle 옵션은 내부적으로 True로 간주되거나 충돌을 일으킵니다. 반드시 DataLoader 정의 시
shuffle=True는 제거하거나 False로 설정해야 합니다. (Example 1참조) - 가중치 Normalize 누락:
WeightedRandomSampler에 전달하는weights인자는 확률값처럼 0~1 사이로 정규화될 필요는 없습니다. 내부적으로 sum으로 나누어 상대적 확률을 계산합니다. 하지만 극단적인 가중치 값(예: 0.0000001 vs 100)은 부동소수점 정밀도 문제를 유발할 수 있으므로, 1 클래스 수의 역수로 가중치를 주는 방식이 가장 안전합니다. - replacement=False 시 num_samples 초과: 복원 추출을 하지 않는
replacement=False모드일 때,num_samples인자가 전체 데이터 가중치 리스트의 길이보다 크면 데이터가 부족하여 에러가 발생합니다. 이 모드는 오직 언더샘플링(다수 클래스 폐기) 용도로만 사용해야 합니다. (Example 4참조)
5. 결론: 불균형 데이터를 넘어 일반화 성능으로
PyTorch의 WeightedRandomSampler는 학습 파이프라인의 데이터 입력 단계에서 클래스 불균형 문제를 수학적으로 우아하게 해결합니다. 미니배치 내의 클래스 분포를 인위적으로 조정함으로써 모델이 소수 클래스의 고유한 패턴을 학습할 수 있는 공정한 기회를 제공합니다. 이는 손실 함수 가중치 방식보다 학습 과정의 Gradient를 안정적으로 유지시켜 특정 클래스에 대한 과적합을 방지하고 일반화 성능을 높이는 데 기여합니다. 본 가이드의 7가지 실무 예제를 바탕으로 귀하의 딥러닝 프로젝트가 통계적 불균형에 흔들리지 않는 견고한 성능을 달성하길 바랍니다.
'Artificial Intelligence > 21. PyTorch' 카테고리의 다른 글
| [PYTORCH] 데이터 증강(Data Augmentation) 기법 적용 방법 및 7가지 성능 차이 해결 가이드 (0) | 2026.03.25 |
|---|---|
| [PYTORCH] CSV 파일을 읽어 데이터셋으로 만드는 7가지 방법과 성능 해결 가이드 (0) | 2026.03.25 |
| [PYTORCH] 사전 학습된(Pre-trained) 모델의 데이터 전처리 일치 방법 및 7가지 성능 저하 해결 가이드 (0) | 2026.03.25 |
| [PYTORCH] Hook 기능을 활용한 모델 디버깅 방법 3가지와 에러 해결 전략 7가지 (0) | 2026.03.24 |
| [PYTORCH] nn.Module 상속 시 super().__init__() 호출 필수 이유 2가지와 속성 에러 해결 방법 7가지 (0) | 2026.03.24 |