본문 바로가기
Artificial Intelligence/21. PyTorch

[PYTORCH] 커스텀 Autograd 함수 구현 방법 2가지와 미분 비연속성 해결 방법 7가지

by Papa Martino V 2026. 3. 23.
728x90

Autograd 함수
Autograd 함수

 

 

파이토치(PyTorch)의 가장 강력한 무기는 자동 미분(Autograd) 엔진입니다. 하지만 딥러닝 연구나 실무 프로젝트를 진행하다 보면, 파이토치가 기본적으로 제공하지 않는 특수한 연산을 수행하거나 미분 불가능한 함수(예: Step Function)를 근사 미분해야 하는 상황이 발생합니다. 이때 필요한 기술이 바로 커스텀 Autograd 함수(Custom Autograd Function)를 설계하는 것입니다. 단순히 nn.Module을 만드는 것과는 차원이 다른 이 기법은 연산 그래프의 심장부에 직접 개입하여 ForwardBackward 로직을 정밀하게 제어할 수 있게 해줍니다. 본 가이드에서는 시니어 딥러닝 엔지니어의 관점에서 torch.autograd.Function을 활용한 독창적인 연산 설계 방법과 실무에서 즉시 적용 가능한 7가지 고급 예제를 심층적으로 다룹니다.


1. nn.Module과 torch.autograd.Function의 결정적 차이

대부분의 초보 개발자는 새로운 연산을 만들 때 nn.Module만을 떠올립니다. 하지만 내부 메커니즘을 들여다보면 두 클래스는 존재 목적 자체가 다릅니다. 이 차이를 이해하는 것이 커스텀 엔진 설계의 시작입니다.

비교 항목 nn.Module (신경망 레이어) torch.autograd.Function (연산 엔진)
상속 대상 가중치(Parameter)를 관리하는 객체 입/출력 간의 수학적 관계를 정의하는 클래스
미분 정의 파이토치가 기존 연산을 조합해 자동 계산 사용자가 직접 수동으로 미분 정의 (Static)
상태 저장 가중치와 버퍼를 인스턴스에 저장 ctx(Context) 객체를 통해 역전파용 데이터 저장
주요 목적 모델 구조 설계 및 학습 루프 구성 새로운 수학 연산 추가 및 미분값 조작
사용 방법 model(x) 인스턴스 호출 MyFunction.apply(x) 정적 메서드 호출

2. 커스텀 Autograd 함수의 핵심 원리와 설계 방법

커스텀 함수를 만들 때는 반드시 두 개의 정적 메서드(Static Method)를 구현해야 합니다.

  • forward(ctx, input, ...): 순전파 연산을 수행합니다. ctx(Context) 객체에 역전파 시 필요한 텐서를 save_for_backward로 저장합니다.
  • backward(ctx, grad_output): 체인 룰(Chain Rule)에 따라 역전파를 수행합니다. grad_output(상위 레이어에서 온 미분값)을 입력받아 입력 텐서들에 대한 미분값을 반환합니다.

3. 실무 해결을 위한 핵심 Sample Examples (7가지)

실제 연구 논문을 구현하거나 하드웨어 가속기용 커스텀 연산을 만들 때 즉시 적용 가능한 7가지 코드 시나리오입니다.

Example 1: 근사 미분을 이용한 Sign(부호) 함수 구현

Sign 함수는 미분값이 0이지만, Straight-Through Estimator(STE)를 사용하여 학습이 가능하게 만듭니다.

import torch

class STE_Sign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        # 입력이 -1과 1 사이일 때만 미분값을 1로 간주하여 통과시킴
        return grad_output.clone()

# 사용법
sign_op = STE_Sign.apply
x = torch.randn(5, requires_grad=True)
out = sign_op(x)
out.sum().backward()
    

Example 2: 그래디언트 리버설 레이어 (GRL) 해결

도메인 적응(Domain Adaptation) 모델에서 적대적 학습을 위해 미분값의 부호를 반전시킵니다.

class GRL(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # 역전파 시 alpha를 곱하고 부호를 반전
        return grad_output.neg() * ctx.alpha, None

grl_op = GRL.apply
    

Example 3: 가중치 정규화를 위한 커스텀 스케일링

class ScaledExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.save_for_backward(input)
        ctx.scale = scale
        return torch.exp(input) * scale

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # exp(x)의 미분은 exp(x)임을 활용
        grad_input = grad_output * torch.exp(input) * ctx.scale
        return grad_input, None # scale 인자에 대한 미분은 None
    

Example 4: 다중 입력과 다중 출력을 가진 연산 해결

class MultiOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a + b, a * b

    @staticmethod
    def backward(ctx, grad_sum, grad_mul):
        a, b = ctx.saved_tensors
        # 각 입력에 대해 두 출력으로부터 오는 미분값을 합산
        grad_a = grad_sum + grad_mul * b
        grad_b = grad_sum + grad_mul * a
        return grad_a, grad_b
    

Example 5: 특정 범위 밖의 기울기 차단 (Gradient Clipping)

class SafeSqrt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        result = torch.sqrt(input.clamp(min=1e-6))
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        # sqrt 미분 시 0으로 나누는 문제 방지
        return grad_output / (2 * result)
    

Example 6: 비연속적인 양자화(Quantization) 레이어 구현

class Quantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, levels):
        return torch.round(input * levels) / levels

    @staticmethod
    def backward(ctx, grad_output):
        # Round 함수의 미분은 0이나 1로 우회(Identity)
        return grad_output, None
    

Example 7: 하드웨어 가속을 위한 메모리 절약형 연산

class MemoryEfficientAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # 연산 결과만 저장하고 x, y는 저장하지 않음 (Storage 공유 시 유용)
        ctx.mark_dirty(x)
        return x + y

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, grad_output
    

4. 성능 최적화와 안정성을 위한 시니어의 조언

커스텀 Autograd 함수를 만들 때 가장 흔히 하는 실수는 ctx.save_for_backward에 텐서가 아닌 일반 파이썬 객체를 넣는 것입니다. 일반 객체는 ctx.any_variable = val 형식으로 저장해야 하며, 텐서는 반드시 전용 메서드를 사용해야 메모리 누수와 그래프 파손을 방지할 수 있습니다. 또한, 역전파 로직이 수학적으로 정확한지 확인하기 위해 torch.autograd.gradcheck 함수를 사용하여 수치 미분값과 비교하는 검증 과정을 반드시 거치십시오.


5. 결론 및 요약

커스텀 Autograd 함수는 파이토치를 단순한 도구에서 수학적 연구 플랫폼으로 격상시키는 핵심 기술입니다.

  • 연산 그래프의 미분 흐름을 수동으로 제어할 때 torch.autograd.Function을 사용한다.
  • forward에서는 ctx에 데이터를 저장하고, backward에서는 grad_output과 체인 룰을 결합한다.
  • 미분 불가능한 지점은 STE(Straight-Through Estimator)와 같은 기법으로 우회하여 학습을 가능케 한다.

참조 및 출처 (Sources)

  • PyTorch Docs: Extending PyTorch.
  • "Deep Learning Research Patterns" - Differentiable Programming with Custom Ops
  • GitHub: pytorch/pytorch/torch/autograd/function.py 
728x90