본문 바로가기
Artificial Intelligence/21. PyTorch

[PYTORCH] WeightedRandomSampler를 이용한 불균형 데이터 처리 방법 및 7가지 실무 해결 가이드

by Papa Martino V 2026. 3. 25.
728x90

WeightedRandomSampler를 이용한 불균형 데이터 처리
WeightedRandomSampler를 이용한 불균형 데이터 처리

 

현실 세계의 데이터는 결코 공평하지 않습니다. 암 진단 데이터셋에서 정상 데이터가 99%이고 암 데이터가 1%인 상황은 매우 흔합니다. 이러한 **클래스 불균형(Class Imbalance)** 상황에서 일반적인 무작위 샘플링을 사용하면, 모델은 단순히 다수 클래스만 맞추도록 학습되어 정작 중요한 소수 클래스를 식별하지 못하게 됩니다. PyTorch의 WeightedRandomSampler는 이러한 통계적 편향을 학습 파이프라인 레벨에서 우아하게 해결할 수 있는 강력한 도구입니다. 본 가이드에서는 가중치 계산의 수학적 원리부터 7가지 실무 시나리오별 구현 예제까지 상세히 다룹니다.


1. 불균형 데이터 문제와 WeightedRandomSampler의 필요성

모델 학습 시 DataLoader에서 shuffle=True를 설정하면 데이터셋 전체에서 균일한 확률로 샘플을 추출합니다. 불균형 데이터셋에서는 다수 클래스가 배치(Batch)를 장악하게 되며, 소수 클래스의 Gradient는 무시됩니다. 이는 결과적으로 소수 클래스에 대한 낮은 Recall(재현율)을 초래합니다.

WeightedRandomSampler는 각 데이터 샘플에 **추출 확률(Weight)**을 부여합니다. 소수 클래스 샘플에 높은 가중치를 부여함으로써, 미니배치 내에서 클래스 비율이 대략적으로 1:1이 되도록 조정(Over-sampling)합니다. 이는 손실 함수(Loss Function)에 가중치를 주는 방식보다 학습 안정성 면에서 유리한 경우가 많습니다.

2. 불균형 데이터 해결 전략 비교: 샘플링 vs 손실 함수

클래스 불균형을 해결하는 대표적인 두 가지 접근 방법의 차이를 표로 정리하였습니다.

비교 항목 샘플링 방식 (WeightedRandomSampler) 손실 함수 가중치 방식 (Weighted Loss)
작동 레벨 DataLoader (데이터 입력 단계) Criterion (Criterion (손실 계산 단계)
주요 메커니즘 소수 클래스의 배치 등장 빈도를 높임 (Over-sampling) 소수 클래스 오차에 대한 페널티를 높임
데이터 접근 모든 데이터에 접근하지만 소수 데이터가 반복됨 모든 데이터를 한 번씩만 접근함 (에포크당 속도 빠름)
장점 학습 Gradient가 안정적이며, 특정 클래스 과적합 방지에 유리 구현이 매우 간단하며, 대규모 데이터셋 학습에 유리
단점 에포크당 학습 데이터 수가 늘어나 학습 시간이 증가할 수 있음 Learning Rate 튜닝이 민감해질 수 있으며 최적화가 어려움
추천 상황 불균형이 심하며(1:100 이상), 데이터 소실 없이 패턴을 학습해야 할 때 데이터셋이 매우 커서 샘플링 오버헤드가 크거나, 불균형이 적당할 때

3. 실무 즉시 적용 가능한 WeightedRandomSampler Example 7가지

개발자가 실무에서 불균형 데이터를 처리할 때 마주치는 시나리오별 구현 예제입니다. copy & paste 하여 바로 사용 가능하도록 작성되었습니다.

Example 1: 표준적인 2클래스(Binary) 가중치 계산 및 적용 방법

가장 기본적이고 흔한 시나리오입니다. 클래스 빈도의 역수를 가중치로 사용합니다.

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
import numpy as np

# 1. 가상의 불균형 데이터셋 생성 (0: 900개, 1: 100개)
num_samples = 1000
num_class_0 = 900
num_class_1 = 100

# 타겟 레이블 생성 (NumPy 사용이 빈도 계산에 편리)
target_np = np.concatenate([np.zeros(num_class_0), np.ones(num_class_1)])
np.random.shuffle(target_np)
targets = torch.from_numpy(target_np).long()

# 2. 클래스별 빈도수 계산
class_sample_count = np.array([num_class_0, num_class_1]) # [900, 100]

# 3. 클래스별 가중치 계산 (빈도의 역수) -> [1/900, 1/100]
weight = 1. / class_sample_count

# 4. 각 데이터 샘플에 해당하는 가중치 부여 (가장 중요한 단계)
samples_weight = torch.from_numpy(np.array([weight[t] for t in targets]))

# 5. Sampler 정의 (num_samples는 보통 전체 데이터 수와 동일하게 설정)
sampler = WeightedRandomSampler(weights=samples_weight, num_samples=len(samples_weight), replacement=True)

# 6. DataLoader 연동 (shuffle=True와 상호 배타적이므로 꺼야 함)
dataset = TensorDataset(torch.randn(num_samples, 10), targets) # 임의의 피처
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# 배치 내 클래스 비율 확인 (대략 1:1 확인)
for inputs, labels in train_loader:
    print(f"Batch labels: {labels.bincount(minlength=2)}")
    break # 첫 배치만 확인
        

Example 2: 다중 클래스(Multi-class) 불균형 처리 해결 예제

3개 이상의 클래스가 있을 때 각 클래스의 비율을 맞춥니다.

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
from collections import Counter

# 1. 가상의 다중 불균형 데이터 (Class 0: 80%, 1: 15%, 2: 5%)
targets = torch.cat([torch.zeros(800), torch.ones(150), torch.full((50,), 2)]).long()

# 2. Counter를 이용한 빈도 계산 (NumPy 없는 환경)
target_list = targets.tolist()
count_dict = Counter(target_list)
class_count = [count_dict[0], count_dict[1], count_dict[2]]

# 3. 가중치 계산 및 부여
weights = 1. / torch.tensor(class_count, dtype=torch.float)
samples_weight = weights[targets]

# 4. Sampler 및 Loader 설정
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
dataset = TensorDataset(torch.randn(1000, 10), targets)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# 배치 비율 확인
for _, labels in train_loader:
    print(f"Batch counts: {labels.bincount(minlength=3)}")
    break
        

Example 3: num_samples 조정을 통한 데이터 증강(Over-sampling) 효과 구현

전체 데이터 수보다 num_samples를 크게 설정하여 소수 클래스를 더 많이 반복 학습시킵니다.

# Example 1의 데이터를 그대로 사용한다고 가정
# len(samples_weight)가 1000일 때, num_samples를 2000으로 설정
oversampled_sampler = WeightedRandomSampler(samples_weight, num_samples=2000, replacement=True)

# DataLoader는 한 에포크당 2000개의 샘플을 추출하게 됨 (대부분 소수 클래스의 반복)
train_loader = DataLoader(dataset, batch_size=32, sampler=oversampled_sampler)
print(f"Loader length (batches): {len(train_loader)}") # 2000/32 대략 63배치
        

Example 4: replacement=False를 이용한 가중치 기반 언더샘플링(Under-sampling)

데이터 복원 추출을 하지 않고 가중치가 높은 순서대로 데이터를 뽑습니다. 다수 데이터를 버리고 소수 데이터를 유지할 때 사용합니다.

# replacement=False일 때는 반드시 num_samples가 데이터 수보다 작아야 함
# 예: 0: 900개, 1: 100개 데이터셋에서 총 200개만 추출 (0: 100개, 1: 100개 목표)
undersampled_sampler = WeightedRandomSampler(samples_weight, num_samples=200, replacement=False)

train_loader = DataLoader(dataset, batch_size=20, sampler=undersampled_sampler)

# 에포크당 총 10배치(20*10=200샘플)만 학습
total_labels = []
for _, labels in train_loader:
    total_labels.append(labels)
print(f"Total extracted counts: {torch.cat(total_labels).bincount(minlength=2)}")
        

Example 5: 대용량 이미지 폴더 데이터셋(ImageFolder)에 가중치 적용 해결

폴더 구조로 된 이미지 데이터셋에서 레이블 빈도를 계산하는 실무 방법입니다.

from torchvision.datasets import ImageFolder
import os

# 이미지 폴더 경로 (예시)
data_dir = 'path/to/data' 

# 1. 데이터셋 로드 (transforms는 임의)
# dataset.targets에 각 이미지의 레이블 인덱스가 리스트로 담겨 있음
image_dataset = ImageFolder(root=data_dir)

# 2. 레이블 빈도 계산 (NumPy bincount가 가장 빠름)
train_targets = np.array(image_dataset.targets)
class_count = np.bincount(train_targets)

# 3. 가중치 부여
class_weights = 1. / class_count
samples_weight = torch.from_numpy(class_weights[train_targets]).float()

# 4. Sampler 및 DataLoader 설정
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
image_loader = DataLoader(image_dataset, batch_size=64, sampler=sampler, num_workers=4)
        

Example 6: Custom Dataset에서의 가중치 사전 계산 및 로딩 전략

커스텀 데이터셋 내부에서 가중치 정보를 미리 가지고 있어 초기화 시간을 줄이는 해결책입니다.

import pandas as pd
from torch.utils.data import Dataset

class CustomCSVDataset(Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
        self.targets = torch.from_numpy(self.df['label'].values).long()
        
        # 데이터셋 생성 시 가중치를 미리 계산하여 속성으로 저장
        label_counts = self.df['label'].value_counts().sort_index().values
        class_weights = 1. / label_counts
        self.samples_weight = torch.from_numpy(class_weights[self.targets.numpy()]).float()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 데이터 로직 (생략)
        return torch.randn(10), self.targets[idx]

# 사용 예시
csv_dataset = CustomCSVDataset('data.csv')
# 외부에서 다시 계산할 필요 없이 속성을 가져옴
sampler = WeightedRandomSampler(csv_dataset.samples_weight, len(csv_dataset))
csv_loader = DataLoader(csv_dataset, batch_size=32, sampler=sampler)
        

Example 7: Weighted Loss(손실 함수 가중치)와 Sampler의 하이브리드 적용 해결

극도로 심한 불균형(예: 1:1000) 상황에서 두 가지 방식을 모두 사용하여 시너지를 냅니다.

# 1. DataLoader에서는 Sampler로 배치 내 클래스 비율을 1:1로 조정
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler) # Example 1의 sampler

# 2. Criterion(Loss)에서는 여전히 소수 클래스 오차에 가중치를 부여
# 배치 비율이 조정되었더라도, 소수 데이터의 변별력을 높이기 위해 가중 손실 사용
# 비율이 1:1이므로 weights는 [1.0, 1.0]에 가깝게 설정하거나 소수 클래스에 약간만 더 부여
# (1. / class_sample_count 로 계산한 극단적 가중치는 사용 금지)
# 하이브리드 상황에서는 튜닝이 매우 중요함
loss_weights = torch.tensor([1.0, 2.0]) # Class 1 오차에 2배 페널티
criterion = torch.nn.CrossEntropyLoss(weight=loss_weights)

# 학습 루프 (생략)
# outputs = model(inputs)
# loss = criterion(outputs, labels)
        

4. WeightedRandomSampler 사용 시 범하는 치명적인 실수 3가지 해결

많은 개발자들이 이 도구를 사용할 때 흔히 저지르는 실수와 그 해결책을 전문 엔지니어 관점에서 정리했습니다.

  1. shuffle=True와 sampler 동시 사용: PyTorch DataLoader에서 sampler 옵션을 설정하면 shuffle 옵션은 내부적으로 True로 간주되거나 충돌을 일으킵니다. 반드시 DataLoader 정의 시 shuffle=True는 제거하거나 False로 설정해야 합니다. (Example 1 참조)
  2. 가중치 Normalize 누락: WeightedRandomSampler에 전달하는 weights 인자는 확률값처럼 0~1 사이로 정규화될 필요는 없습니다. 내부적으로 sum으로 나누어 상대적 확률을 계산합니다. 하지만 극단적인 가중치 값(예: 0.0000001 vs 100)은 부동소수점 정밀도 문제를 유발할 수 있으므로, 1 클래스 수의 역수로 가중치를 주는 방식이 가장 안전합니다.
  3. replacement=False 시 num_samples 초과: 복원 추출을 하지 않는 replacement=False 모드일 때, num_samples 인자가 전체 데이터 가중치 리스트의 길이보다 크면 데이터가 부족하여 에러가 발생합니다. 이 모드는 오직 언더샘플링(다수 클래스 폐기) 용도로만 사용해야 합니다. (Example 4 참조)

5. 결론: 불균형 데이터를 넘어 일반화 성능으로

PyTorch의 WeightedRandomSampler는 학습 파이프라인의 데이터 입력 단계에서 클래스 불균형 문제를 수학적으로 우아하게 해결합니다. 미니배치 내의 클래스 분포를 인위적으로 조정함으로써 모델이 소수 클래스의 고유한 패턴을 학습할 수 있는 공정한 기회를 제공합니다. 이는 손실 함수 가중치 방식보다 학습 과정의 Gradient를 안정적으로 유지시켜 특정 클래스에 대한 과적합을 방지하고 일반화 성능을 높이는 데 기여합니다. 본 가이드의 7가지 실무 예제를 바탕으로 귀하의 딥러닝 프로젝트가 통계적 불균형에 흔들리지 않는 견고한 성능을 달성하길 바랍니다.

내용 출처 및 기술 참조

  • PyTorch Official Documentation: `torch.utils.data.WeightedRandomSampler` API
  • Deep Learning Book (Ian Goodfellow et al.) - Optimization for Training Deep Models Section
  • PyTorch Community Forum: Best practices for imbalanced data handling
728x90