본문 바로가기
Artificial Intelligence/60. Python

[PYTHON] 모델 안정성 해결을 위한 Stochastic Weight Averaging (SWA) 적용 시점과 7가지 활용 방법

by Papa Martino V 2026. 4. 25.
728x90

Stochastic Weight Averaging (SWA)
Stochastic Weight Averaging (SWA)

 

딥러닝 모델을 학습시킬 때 가장 허무한 순간은 검증 데이터셋(Validation Set)에서는 최고의 성능을 보였으나, 실제 배포 환경(In-the-wild)에서 성능이 급격히 하락하는 경우입니다. 이는 모델이 가파른 손실 함수 곡면(Sharp Minima)에 빠졌기 때문일 가능성이 큽니다. 본 포스팅에서는 이를 해결하기 위해 Stochastic Weight Averaging (SWA)를 활용하여 더 넓고 평평한 곡면(Flat Minima)을 찾아 모델의 일반화 성능을 극대화하는 실무적인 전략을 다룹니다.


1. SWA의 개념과 왜 평평한 곡면(Flat Minima)이 중요한가?

전통적인 SGD(Stochastic Gradient Descent)는 학습 종료 시점의 가중치($w$) 하나만을 사용합니다. 하지만 학습 후반부의 가중치들은 최적해 주변을 진동하며 이동합니다. SWA는 이 진동하는 가중치들을 평균 내어, 기하학적으로 더 중앙에 위치하고 평평한 지점을 찾습니다. 평평한 지점은 입력 데이터에 약간의 노이즈가 섞여도 출력값이 크게 변하지 않으므로 모델 안정성일반화 성능이 비약적으로 향상됩니다.

2. 기존 학습 방식과 SWA의 핵심 차이 분석

단순한 앙상블(Ensemble)과 SWA는 비슷해 보이지만, 연산 비용과 구조적 측면에서 큰 차이가 있습니다.

구분 항목 표준 SGD/Adam 학습 Deep Ensemble (일반 앙상블) Stochastic Weight Averaging (SWA)
가중치 선택 최종 Epoch 가중치 여러 모델의 가중치 각각 저장 학습 후반 가중치들의 평균치
추론 비용 $1 \times$ (표준) $N \times$ (모델 수만큼 증가) $1 \times$ (단일 모델과 동일)
학습 시간 표준 매우 높음 (N번 개별 학습) 낮음 (표준 학습 + 약간의 오버헤드)
해결 목표 로컬 최적화 다양성 확보 일반화(Flat Minima) 달성
실무 적용성 기본 리소스 풍부할 때만 가능 비용 효율적 모델 개선 시 필수

3. SWA를 적용해야 하는 결정적 시점(When to Apply)

SWA는 학습의 처음부터 적용하는 것이 아닙니다. 일반적으로 다음과 같은 3가지 시나리오에서 적용을 고려해야 합니다.

  • 수렴 지점 도달 직후: 학습 곡선이 평탄해지기 시작할 때(보통 전체 Epoch의 75%~80% 지점) 시작합니다.
  • 스케줄러 교체 시점: Cyclic Learning Rate나 Cosine Annealing 스케줄러를 사용하여 학습률이 낮아지는 주기에 가중치를 수집합니다.
  • Fine-tuning 단계: 사전 학습된 모델을 새로운 도메인에 맞출 때, 안정적인 전이를 위해 마지막에 SWA를 수행합니다.

4. 실무 적용을 위한 SWA 구현 예제 (Python/PyTorch)

PyTorch의 torch.optim.swa_utils를 활용하여 현업 프로젝트에 즉시 반영할 수 있는 7가지 핵심 구현 예제입니다.

Example 1: 기본 SWA 모델 생성 및 스케줄러 정의

import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

# 1. 모델과 AveragedModel 초기화
model = MyNetwork()
swa_model = AveragedModel(model)

# 2. 옵티마이저와 표준 스케줄러
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# 3. SWA 전용 스케줄러 (높은 고정 학습률로 최적해 주변 탐색)
swa_start = 75
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
        

Example 2: 학습 루프 내 SWA 가중치 업데이트 해결 방법

for epoch in range(100):
    train(model, loader, optimizer)
    
    if epoch > swa_start:
        # 가중치 누적 시작
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# 학습 종료 후 BN 통계치 업데이트 (중요!)
torch.optim.swa_utils.update_bn(loader, swa_model)
        

Example 3: BatchNorm 통계치 재계산 (update_bn)의 필요성

가중치가 평균화되면 기존의 Batch Normalization의 running mean/var가 맞지 않게 됩니다. 이를 해결하기 위해 학습 데이터를 한 번 더 통과시켜 통계치를 맞춰줘야 합니다.

# SWA 모델은 반드시 이 과정을 거쳐야 평가 모드에서 정상 작동합니다.
device = torch.device("cuda")
swa_model.to(device)
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        

Example 4: 커스텀 가중치 평균화 (Exponential Moving Average)

최근의 가중치에 더 높은 가중치를 두는 EMA 방식의 변형입니다.

def ema_avg(averaged_model_parameter, model_parameter, num_averaged):
    return 0.9 * averaged_model_parameter + 0.1 * model_parameter

# update_fn 파라미터를 통해 커스텀 averaging 적용
swa_model_ema = AveragedModel(model, avg_fn=ema_avg)
        

Example 5: 검증 데이터 성능 기반의 조건부 SWA 업데이트

best_acc = 0
for epoch in range(100):
    acc = validate(model, val_loader)
    if epoch > swa_start and acc > threshold:
        # 성능이 일정 수준 이상일 때만 SWA 수집에 참여시켜 노이즈 제거
        swa_model.update_parameters(model)
        

Example 6: 분산 학습(DDP) 환경에서의 SWA 적용 차이 해결

# DistributedDataParallel 환경에서는 모델 자체가 래핑되어 있으므로 .module 접근 필요
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
    swa_model.update_parameters(model.module)
        

Example 7: 추론(Inference) 시 SWA 모델 로드 및 활용

# 저장된 가중치 불러오기
checkpoint = torch.load("swa_model.pth")
swa_model.load_state_dict(checkpoint['state_dict'])

swa_model.eval()
with torch.no_grad():
    output = swa_model(input_tensor)
        

5. SWA 도입 시 주의해야 할 실패 요인 (Solution)

단순히 평균을 낸다고 성능이 오르지 않는 경우가 있습니다. 다음 해결책을 점검하세요.

  • 너무 높은 학습률: swa_lr이 너무 높으면 최적 지점을 이탈합니다. 보통 학습 초기 LR의 1/10 정도로 설정하세요.
  • 너무 이른 시작: 모델이 아직 특정 해(Solution)로 수렴하지 않았는데 평균을 내면 성능이 하락합니다.
  • Weight Decay 충돌: SWA 단계에서도 적절한 Weight Decay를 유지해야 가중치가 폭발하지 않습니다.

6. 결론

SWA는 추가적인 추론 비용 없이 딥러닝 모델의 일반화 능력을 개선할 수 있는 가장 우아한 방법 중 하나입니다. 특히 모델이 과적합(Overfitting)되기 쉬운 소규모 데이터셋이나, 배포 환경의 변동성이 큰 실무 프로젝트에서 SWA 적용 시점의 조절은 최고의 가성비 전략이 될 것입니다.


내용 출처 및 전문 문헌

  • Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. (2018). "Averaging Weights Leads to Wider Optima and Better Generalization." UAI.
  • PyTorch Official Documentation 
  • Seong, H., et al. "Stochastic Weight Averaging Revisited." (2023).
728x90