본문 바로가기
Artificial Intelligence/60. Python

[PYTHON] GAN Mode Collapse 감지 방법 3가지와 구조적 해결 로직 7가지

by Papa Martino V 2026. 4. 17.
728x90

GAN Mode Collapse
GAN Mode Collapse

 

생성적 적대 신경망(Generative Adversarial Networks, GAN)은 데이터의 분포를 학습하여 새로운 샘플을 생성하는 혁신적인 모델이지만, 학습 과정에서 마주하는 Mode Collapse(모드 붕괴)는 수많은 연구자들을 고뇌에 빠뜨리는 난제입니다. 모드 붕괴란 생성자(Generator)가 판별자(Discriminator)를 속이기 쉬운 몇 가지 특정 형태의 샘플(Mode)만을 반복해서 생성하여, 결과물의 다양성을 완전히 상실하는 현상을 말합니다. 2026년 최신 딥러닝 실무 관점에서 볼 때, 단순한 시각적 확인만으로는 모드 붕괴를 사전에 차단하기 어렵습니다. 본 포스팅에서는 파이썬(Python) 기반의 통계적 감지 로직과 생성 품질을 보호하기 위한 7가지 고도화된 아키텍처 해결 패턴을 심층적으로 다룹니다.


1. Mode Collapse의 원인과 감지 지표의 구조적 차이

모드 붕괴는 생성자가 전체 데이터 분포 $P_{data}$를 학습하지 못하고 특정 $P_{model}$의 국소 영역에 갇힐 때 발생합니다. 이를 감지하기 위해 통계적 거리 측정과 엔트로피 분석이 동원됩니다.

감지 메커니즘 주요 지표 (Metric) 측정 원리 실무 해결 포인트
분포 다양성 측정 Inception Score (IS) 생성물 클래스의 엔트로피 분석 생성된 이미지의 선명도와 다양성 평가
거리 기반 평가 FID (Fréchet Inception Distance) 실제 데이터와 생성 데이터의 가우시안 거리 특징 벡터 분포의 통계적 일치성 확인
픽셀 중복성 분석 Standard Deviation of Batches 배치 내 샘플 간 유사도 계산 생성 샘플의 수치적 고착화 실시간 감지
판별자 상태 추적 Discriminator Loss Saturation 판별자 손실값이 0에 가깝게 수렴 생성자가 한 패턴에 고정되었는지 확인

2. Mode Collapse 해결을 위한 7가지 실무 파이썬 로직 (Examples)

실제 PyTorch 환경에서 학습 루프에 즉시 이식 가능한 7가지 해결 시나리오와 로직 예시입니다.

Example 1: Minibatch Discrimination 로직 구현

배치 내의 샘플들 사이의 거리를 계산하여 생성자가 너무 유사한 샘플만 만들지 않도록 판별자에게 추가 정보를 제공하는 방법입니다.

import torch
import torch.nn as nn

class MinibatchDiscrimination(nn.Module):
    def __init__(self, in_features, out_features, kernel_dims):
        super().__init__()
        self.T = nn.Parameter(torch.Tensor(in_features, out_features, kernel_dims))
        nn.init.normal_(self.T)

    def forward(self, x):
        # x shape: [batch, in_features]
        matrices = x.mm(self.T.view(self.T.shape[0], -1))
        matrices = matrices.view(-1, self.T.shape[1], self.T.shape[2])
        
        # 샘플 간 L1 거리 계산
        M = matrices.unsqueeze(0) # [1, batch, out, kernel]
        M_t = matrices.unsqueeze(1) # [batch, 1, out, kernel]
        diffs = torch.exp(-torch.abs(M - M_t).sum(3)) # [batch, batch, out]
        
        out = diffs.sum(0) - 1 # 자기 자신 제외
        return torch.cat([x, out], 1) # 특징 결합

Example 2: Unrolled GAN을 이용한 판별자 선행 예측 해결

생성자가 현재의 판별자뿐만 아니라 미래의 K-단계 업데이트된 판별자까지 고려하여 학습하게 함으로써 모드 고착화를 방지하는 방법입니다.

# 로직 개념: 생성자 업데이트 시 판별자의 가중치를 복사하여 K번 가상 업데이트 수행
# 이를 통해 생성자는 판별자가 미래에 취할 방어 전략을 예측하여 '속이기 쉬운 쉬운 길'을 거부함
def unrolled_loss(generator, discriminator, real_data, noise, k_steps=5):
    # 1. 판별자 현재 상태 백업
    backup = copy.deepcopy(discriminator.state_dict())
    
    # 2. 판별자 K번 가상 학습
    for _ in range(k_steps):
        d_optimizer.step(real_data, generator(noise))
        
    # 3. K번 후의 판별자로 생성자 손실 계산
    g_loss = criterion(discriminator(generator(noise)), target_real)
    
    # 4. 판별자 상태 원복
    discriminator.load_state_dict(backup)
    return g_loss

Example 3: WGAN-GP (Wasserstein GAN with Gradient Penalty) 적용

손실 함수를 Earth Mover's Distance로 변경하고 기울기 패널티를 부여하여 학습의 수치적 안정성을 해결하는 패턴입니다.

def compute_gradient_penalty(D, real_samples, fake_samples):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

Example 4: Experience Replay (Buffer)를 통한 모드 기억 해결

생성자가 이전에 만들었던 다양한 모드들을 버퍼에 저장해두고 판별자에게 주기적으로 보여줌으로써, 판별자가 특정 모드만 기억하지 못하게 하는 해결책입니다.

import random

class ImageBuffer:
    def __init__(self, buffer_size=50):
        self.buffer_size = buffer_size
        self.images = []

    def query(self, images):
        if self.buffer_size == 0: return images
        return_images = []
        for image in images:
            if len(self.images) < self.buffer_size:
                self.images.append(image)
                return_images.append(image)
            else:
                if random.uniform(0, 1) > 0.5:
                    idx = random.randint(0, self.buffer_size - 1)
                    return_images.append(self.images[idx].clone())
                    self.images[idx] = image
                else:
                    return_images.append(image)
        return torch.stack(return_images)

Example 5: Multi-Generator GAN 아키텍처 구현

여러 개의 생성자를 병렬로 두고 각각 다른 데이터 모드를 담당하게 하여 전체 분포의 다양성을 확보하는 구조적 해결 방법입니다.

class MultiGen(nn.Module):
    def __init__(self, num_gen=3):
        super().__init__()
        self.generators = nn.ModuleList([Generator() for _ in range(num_gen)])
        
    def forward(self, z):
        # z를 분할하거나 각 생성자에게 다른 시드를 주어 다양한 결과 유도
        outputs = [gen(z) for gen in self.generators]
        return torch.cat(outputs, dim=0)

Example 6: Spectral Normalization을 통한 판별자 립시츠 연속성 해결

판별자의 가중치를 최대 고윳값으로 나누어 급격한 기울기 변화를 막고 모드 붕괴를 억제하는 실무 패턴입니다.

from torch.nn.utils import spectral_norm

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 모든 Conv 레이어에 Spectral Norm 적용
        self.conv1 = spectral_norm(nn.Conv2d(3, 64, 4, 2, 1))
        self.conv2 = spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))
        # ...

Example 7: 생성 샘플 간 Pairwise Distance 모니터링 로직

학습 도중 배치 내 샘플 간 거리가 일정 임계치 이하로 떨어지면 경고를 보내고 학습 파라미터를 조정하는 감지 로직입니다.

def detect_mode_collapse(fake_images, threshold=0.1):
    # fake_images: [batch, c, h, w]
    flat_images = fake_images.view(fake_images.size(0), -1)
    dist_matrix = torch.pdist(flat_images) # 모든 쌍의 유클리드 거리
    mean_dist = dist_matrix.mean().item()
    
    if mean_dist < threshold:
        print(f"Warning: Mode Collapse suspected. Mean Dist: {mean_dist:.4f}")
        return True
    return False

3. GAN 모드 붕괴 방지를 위한 3대 황금 규칙

  • 균형 잡힌 학습(Balancing): 생성자와 판별자의 학습 속도 차이를 조절하십시오. 판별자가 너무 강력해지면 생성자가 포기(기울기 소실)하거나 꼼수(모드 붕괴)를 쓰게 됩니다.
  • 다양한 하이퍼파라미터 실험: Batch Size를 키우는 것이 모드 다양성 확보에 유리합니다. 큰 배치는 생성자가 한 번에 더 넓은 데이터 분포를 볼 수 있게 합니다.
  • 정교한 손실 함수 선택: 단순히 Binary Cross Entropy를 쓰기보다는 Hinge LossWasserstein Loss처럼 기울기가 유연하게 흐르는 함수를 선택하여 해결하십시오.

4. 결론 및 향후 전망

2026년 기준으로 GAN은 Diffusion 모델과 경쟁하며 더욱 정교해지고 있습니다. 특히 StyleGAN-VVQ-GAN 아키텍처는 내부적으로 강력한 정규화 기법을 도입하여 모드 붕괴 문제를 상당 부분 해결했습니다. 하지만 커스텀 데이터셋을 다루는 실무 엔지니어에게는 여전히 통계적 감지와 적응형 학습 로직이 필수적입니다. 본 가이드의 7가지 파이썬 로직을 통해 생성 인공지능의 품질과 다양성을 동시에 확보하시기 바랍니다.

 

전문 지식 출처 및 참조:

  • Salimans et al. (2016), "Improved Techniques for Training GANs" (Minibatch Discrimination 제안)
  • Metz et al. (2017), "Unrolled Generative Adversarial Networks"
  • Gulrajani et al. (2017), "Improved Training of Wasserstein GANs"
  • PyTorch Official Examples: "DCGAN and WGAN-GP Implementation"
728x90