
현업 딥러닝 엔지니어의 관점에서 분석한 멀티 태스크 학습(Multi-task Learning) 시 손실 함수 결합 및 그래디언트 불균형 해결 가이드
1. 다중 손실 함수(Multi-loss) 결합의 핵심 개념
딥러닝 모델이 복잡해짐에 따라 하나의 모델이 여러 개의 태스크를 동시에 수행해야 하는 경우가 많아졌습니다. 예를 들어, 자율 주행 시스템에서는 단일 백본 네트워크를 통해 객체 검출(Object Detection), 세그멘테이션(Segmentation), 그리고 깊이 추정(Depth Estimation)을 동시에 수행합니다. 이때 각 태스크는 고유의 손실 함수($L_1, L_2, ..., L_n$)를 가지며, 이를 최적화하기 위해 하나로 합치는 과정이 필요합니다. 단순히 모든 손실을 더하는 방식($L_{total} = \sum L_i$)은 구현이 쉽지만, 각 손실의 스케일(Scale) 차이나 수렴 속도 차이로 인해 특정 태스크에만 모델이 편향되는 그래디언트 지배(Gradient Dominance) 현상이 발생할 수 있습니다. 본 가이드에서는 이를 전문적으로 해결하는 방법론을 제시합니다.
2. 손실 함수 결합 방식 및 장단점 비교
| 결합 방식 | 주요 특징 | 장점 | 단점 및 한계 |
|---|---|---|---|
| Simple Summation | 모든 손실을 가중치 없이 합산 | 구현이 매우 간단함 | 손실 값의 스케일에 민감함 |
| Weighted Linear Combination | 각 손실에 고정된 하이퍼파라미터 $\lambda$ 적용 | 중요도에 따른 수동 조절 가능 | 최적의 $\lambda$를 찾는 튜닝 비용 발생 |
| Uncertainty Weighting | 태스크별 불확실성을 학습하여 가중치 조절 | 자동으로 스케일 밸런싱 수행 | 학습 파라미터가 소폭 증가함 |
| Gradient Normalization (GradNorm) | 그래디언트의 크기를 기준으로 가중치 동적 변경 | 안정적인 멀티 태스크 학습 보장 | 구현 복잡도가 높음 |
3. 실무 적용 가능한 PyTorch 코드 Example (7가지)
실제 프로젝트에서 즉시 활용할 수 있는 다양한 수준의 구현 예시입니다.
Example 1: 기본적인 가중 합산 (Weighted Sum)
가장 범용적으로 사용되는 방식으로, 수동으로 가중치를 부여합니다.
import torch
# 예측값과 실제값 가정
pred_class = torch.randn(8, 10, requires_grad=True)
target_class = torch.randint(0, 10, (8,))
pred_bbox = torch.randn(8, 4, requires_grad=True)
target_bbox = torch.randn(8, 4)
# 개별 손실 함수 정의
criterion_cls = torch.nn.CrossEntropyLoss()
criterion_reg = torch.nn.MSELoss()
# 가중치 설정 (하이퍼파라미터)
w1, w2 = 1.0, 0.5
loss_cls = criterion_cls(pred_class, target_class)
loss_reg = criterion_reg(pred_bbox, target_bbox)
# 최종 손실 합산
total_loss = (w1 * loss_cls) + (w2 * loss_reg)
total_loss.backward()
Example 2: 딕셔너리 구조를 이용한 확장형 손실 관리
태스크가 많아질 때 유지보수가 용이한 구조입니다.
losses = {
'classification': loss_cls,
'regression': loss_reg,
'auxiliary': loss_aux
}
weights = {
'classification': 1.0,
'regression': 2.5,
'auxiliary': 0.1
}
total_loss = sum(weights[k] * losses[k] for k in losses.keys())
total_loss.backward()
Example 3: 동적 손실 가중치 (Uncertainty-based Weighting)
Kendall et al. (CVPR 2018)의 논문을 기반으로 한 자동 가중치 조절 방식입니다.
class MultiLossLayer(torch.nn.Module):
def __init__(self, num_tasks):
super(MultiLossLayer, self).__init__()
# log(sigma^2)를 학습 파라미터로 설정
self.log_vars = torch.nn.Parameter(torch.zeros(num_tasks))
def forward(self, losses):
weighted_losses = []
for i, loss in enumerate(losses):
precision = torch.exp(-self.log_vars[i])
weighted_losses.append(precision * loss + self.log_vars[i])
return torch.sum(torch.stack(weighted_losses))
# 사용 예시
multi_loss_fn = MultiLossLayer(num_tasks=2)
total_loss = multi_loss_fn([loss_cls, loss_reg])
total_loss.backward()
Example 4: 개별 최적화(Accumulation) 방식
메모리가 부족하거나 특정 태스크의 그래디언트만 따로 처리해야 할 때 유용합니다.
optimizer.zero_grad()
# Task 1 역전파
loss_cls = criterion_cls(pred_cls, target_cls)
loss_cls.backward(retain_graph=True) # 그래프 유지 필수
# Task 2 역전파
loss_reg = criterion_reg(pred_reg, target_reg)
loss_reg.backward()
optimizer.step()
Example 5: 그래디언트 클리핑(Clipping)과 병합
다중 손실 합산 시 폭주하는 그래디언트를 방지합니다.
total_loss = loss_cls + loss_reg
total_loss.backward()
# 역전파 후 step 전 클리핑 수행
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
Example 6: PCGrad (Projected Conflicting Gradients) 컨셉 구현
서로 충돌하는 그래디언트 방향을 투영하여 상쇄하는 기법의 핵심 로직입니다.
# 간단 개념 예시: g1, g2가 태스크별 그래디언트일 때
# dot_product = torch.dot(g1, g2)
# if dot_product < 0:
# g1 = g1 - (dot_product / (g2.norm()**2)) * g2
Example 7: 커스텀 Loss 클래스 캡슐화
프로덕션 환경에서 깔끔한 코드를 유지하는 객체지향적 접근법입니다.
class CombinedLoss(torch.nn.Module):
def __init__(self, cls_weight=1.0, reg_weight=1.0):
super().__init__()
self.cls_weight = cls_weight
self.reg_weight = reg_weight
self.ce = torch.nn.CrossEntropyLoss()
self.mse = torch.nn.MSELoss()
def forward(self, outputs, targets):
l1 = self.ce(outputs['cls'], targets['cls'])
l2 = self.mse(outputs['reg'], targets['reg'])
return self.cls_weight * l1 + self.reg_weight * l2
4. 다중 손실 최적화 시 주의사항 및 해결 전략
- Scale Normalization: L1 Loss는 수백 단위인데 CrossEntropy는 0~2 단위라면, L1이 학습을 지배합니다. 각 손실의 초기값을 측정하여 비슷한 스케일이 되도록 가중치를 조정하십시오.
- Learning Rate Scheduling: 멀티 태스크 학습 시에는 `ReduceLROnPlateau` 보다는 `CosineAnnealing` 계열이 전체적인 밸런스를 잡기에 유리할 수 있습니다.
- Gradient Conflict: 한 태스크의 성능이 오를 때 다른 태스크의 성능이 떨어진다면, Shared Encoder 이후의 Head 부분을 더 깊게 설계하여 태스크 특화 정보를 분리해야 합니다.
5. 결론
PyTorch에서 다중 손실 함수를 다루는 것은 단순히 더하기 연산을 수행하는 것을 넘어, 모델이 학습하고자 하는 방향성(Objective Landscape)을 정의하는 매우 정교한 작업입니다. 실무에서는 우선 Weighted Linear Combination으로 베이스라인을 잡고, 성능 정체기(Plateau)가 오면 Uncertainty Weighting과 같은 동적 기법 도입을 고려하는 것이 가장 효율적인 해결 순서입니다.
'Artificial Intelligence > 21. PyTorch' 카테고리의 다른 글
| [PYTORCH] 모델 학습 중 Loss NaN 발생 시 7가지 체크리스트와 즉시 해결 방법 (0) | 2026.04.04 |
|---|---|
| [PYTORCH] 오버피팅(Overfitting) 확인 및 해결을 위한 7가지 방지 방법과 차이 분석 (0) | 2026.04.04 |
| [PYTORCH] Warmup Step이 학습 안정성에 미치는 5가지 영향과 해결 방법 (0) | 2026.04.04 |
| [PYTORCH] DistributedDataParallel (DDP) 기본 개념과 DataParallel의 3가지 차이 및 성능 해결 방법 (0) | 2026.04.04 |
| [PYTORCH] 딥러닝 모델의 7가지 파라미터 수 계산 방법과 최적화 해결 가이드 (0) | 2026.03.25 |