
딥러닝 모델의 학습 과정에서 loss.backward()는 마법의 주문과 같습니다. 이 함수 한 줄로 수만 개의 파라미터에 대한 미분값이 계산되지만, 그 내부에서 어떤 일이 벌어지는지 정확히 이해하는 개발자는 드뭅니다. 단순히 "역전파(Backpropagation)가 일어난다"는 설명만으로는 실무에서 마주치는 RuntimeError: Trying to backward through the graph a second time 같은 문제를 해결할 수 없습니다. 본 포스팅에서는 파이토치(PyTorch)의 심장부인 Autograd Engine이 loss.backward()를 만났을 때 텐서 엔진 내부에서 수행하는 연쇄적인 하부 로직을 분석하고, 실무 최적화를 위한 7가지 구체적인 해결 예제를 제시합니다.
1. loss.backward() 호출 시 발생하는 내부 메커니즘의 차이
파이토치는 Dynamic Computational Graph (DCG) 방식을 채택합니다. 즉, 순전파(Forward Pass) 과정에서 연산의 순서를 기록하고, backward()가 호출되는 순간 기록된 경로를 역추적합니다. 이 과정에서 발생하는 핵심적인 차이점들을 표로 정리했습니다.
| 내부 단계 | 주요 작업 내용 | 비고 (Side Effects) |
|---|---|---|
| 1. 그래프 역추적 | Loss 텐서부터 Leaf Node까지 DAG(유향 비순환 그래프)를 거슬러 올라감 | 연산 그래프가 동적으로 탐색됨 |
| 2. 연쇄 법칙 적용 | Chain Rule을 사용하여 상위 그래디언트를 하위 텐서로 전달 및 곱셈 수행 | Vector-Jacobian Product 계산 |
| 3. .grad 필드 업데이트 | 계산된 미분값을 각 파라미터 텐서의 .grad 속성에 누적(Add)함 |
덮어쓰기가 아닌 덧셈 방식임 |
| 4. 그래프 버퍼 해제 | 특별한 설정이 없다면 계산 완료 후 중간 연산 결과(Buffer)를 메모리에서 삭제 | 메모리 효율화 및 재사용 방지 |
2. 실무 개발자를 위한 해결 중심의 Sample Examples (7가지)
실제 딥러닝 프로젝트 아키텍처를 설계할 때 loss.backward()와 관련하여 즉시 적용 가능한 실무 코드들입니다.
Example 1: 기본 역전파와 그래디언트 확인 방법
import torch
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
# 역전파 수행
out.backward()
# x의 grad 필드에 저장된 d(out)/dx 확인
print(f"Gradients of x:\n{x.grad}")
Example 2: 그래디언트 누적 문제 해결 (zero_grad 활용)
파이토치는 기본적으로 그래디언트를 더하기 때문에, 루프마다 초기화하지 않으면 값이 무한히 커집니다.
import torch.optim as optim
model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
for data, target in dataset:
# 필수: 이전 배치의 미분값을 비워줌
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
Example 3: 동일한 그래프에서 두 번 backward 수행하는 방법
GAN이나 Multi-task Learning에서 하나의 그래프로 여러 Loss를 역전파해야 할 때 사용합니다.
loss1 = criterion1(pred1, target1)
loss2 = criterion2(pred2, target2)
# retain_graph=True를 통해 중간 버퍼를 유지함
loss1.backward(retain_graph=True)
loss2.backward() # 이제 에러 없이 수행 가능
Example 4: 특정 레이어만 역전파 차단하기 (Stop Gradient)
# 연산 중간에 detach()를 사용하여 그래프를 끊음
intermediate = model_part1(input)
detached_val = intermediate.detach()
output = model_part2(detached_val)
loss = criterion(output, target)
loss.backward()
# model_part1의 파라미터는 업데이트되지 않음
Example 5: 스칼라가 아닌 텐서에 대한 backward 해결
Loss가 단일 값이 아닐 경우, 동일한 형태의 그래디언트 가중치를 인자로 넘겨야 합니다.
x = torch.randn(3, requires_grad=True)
y = x * 2
# y는 [3] 크기의 텐서이므로 직접 backward 불가
# gradient 인자를 통해 가중치 합산 방식 지정
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)
print(x.grad)
Example 6: 그래디언트 클리핑을 통한 폭주 해결
역전파 중 미분값이 너무 커져 모델이 붕괴되는 현상을 막는 실무 기법입니다.
loss.backward()
# backward 호출 직후, step 호출 직전에 수행
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
Example 7: 사용자 정의 Autograd Function 구현
특정 연산의 미분 로직을 직접 설계해야 할 때 사용합니다.
class MyReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
# 사용 예시
custom_relu = MyReLU.apply
output = custom_relu(torch.randn(5, requires_grad=True))
3. 독창적인 분석: 왜 .grad는 누적(Accumulate)되는가?
초보자들은 optimizer.zero_grad()를 매번 써야 하는 번거로움에 의문을 갖습니다. 하지만 파이토치가 이를 '누적' 방식으로 설계한 데에는 메모리 한계를 극복하기 위한 고도의 전략이 숨어 있습니다. 만약 GPU 메모리가 부족하여 큰 배치(Batch)를 한 번에 올릴 수 없다면, 작은 배치를 여러 번 backward() 시킨 뒤 마지막에 step()을 한 번만 호출함으로써 '가상 대형 배치 학습(Gradient Accumulation)'을 구현할 수 있기 때문입니다.
4. 결론 및 요약
loss.backward()는 단순히 미분을 수행하는 단계를 넘어, 메모리 버퍼 관리와 동적 그래프의 수명 주기를 결정하는 핵심 함수입니다. 실무에서는 다음 세 가지만 기억하십시오.
- 미분값은 기본적으로 **누적**되므로 반드시 초기화가 필요하다.
- 연산 그래프는 효율성을 위해 한 번 쓰면 **파괴**된다 (재사용 시 retain_graph 필요).
- 계산된 결과는 텐서의 **.grad** 속성에 정밀하게 저장된다.
내용 출처 및 참고 문헌 (Sources)
- PyTorch Official Documentation: Autograd - Automatic Differentiation
- "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.
- PyTorch GitHub Repository: aten/src/ATen/native/autograd/ 소스 코드 분석.
'Artificial Intelligence > 21. PyTorch' 카테고리의 다른 글
| [PYTORCH] 커스텀 레이어(Custom Layer)를 정의하는 3가지 방법과 성능 최적화 해결 가이드 (0) | 2026.03.24 |
|---|---|
| [PYTORCH] requires_grad=True 설정의 3가지 핵심 의미와 역전파 문제 해결 방법 7가지 (0) | 2026.03.23 |
| [PYTORCH] optimizer.zero_grad() 호출 이유 2가지와 누적 그래디언트 해결 방법 7가지 (0) | 2026.03.23 |
| [PYTORCH] with torch.no_grad() 사용 방법 2가지와 메모리 부족 해결 방법 7가지 (0) | 2026.03.23 |
| [PYTORCH] detach()와 clone()의 치명적 차이점 3가지와 메모리 누수 해결 방법 7가지 (0) | 2026.03.23 |