본문 바로가기
Artificial Intelligence/60. Python

[PYTHON] Transformer Attention Masking 구현 방법 3가지와 성능 병목 해결책 7가지

by Papa Martino V 2026. 4. 18.
728x90

Transformer Attention Masking 구현 방법
Transformer Attention Masking 구현 방법

 

트랜스포머(Transformer) 아키텍처가 자연어 처리(NLP)를 넘어 컴퓨터 비전(Vision Transformer)과 멀티모달 학습의 표준이 된 핵심 비결은 모든 토큰 간의 관계를 한 번에 계산하는 셀프 어텐션(Self-Attention) 메커니즘에 있습니다. 하지만 모든 관계를 허용하는 것이 항상 정답은 아닙니다. 문장의 길이를 맞추기 위한 패딩(Padding)을 연산에서 제외하거나, 생성 모델에서 미래의 정보를 미리 보지 못하게 차단하는 어텐션 마스킹(Attention Masking)은 모델의 무결성과 성능을 결정짓는 결정적인 디테일입니다.

본 가이드에서는 파이썬(Python) 환경에서 마스킹이 수학적으로 어떻게 소프트맥스(Softmax) 결과에 영향을 미치는지 분석하고, 실무에서 마주하는 가변 길이 데이터 처리와 메모리 부족 문제를 해결하는 7가지 고도화된 구현 패턴을 제안합니다.


1. 마스킹 유형별 구조적 차이와 수치적 해결 원리

어텐션 마스킹은 특정 에너지 값(Attention Score)에 매우 작은 음수(대개 $-\infty$에 가까운 $-1e9$)를 더해 소프트맥스 통과 후 해당 확률을 0으로 만드는 원리로 작동합니다.

마스킹 유형 주요 목적 구현 형태 실전 해결 포인트
Padding Mask 배치 내 가변 길이 대응 [Batch, 1, 1, Seq_len] 텐서 의미 없는 [PAD] 토큰 정보 차단
Causal (Look-ahead) Mask 미래 정보 유출 방지 Lower Triangular Matrix 디코더의 자기회귀(AR) 속성 유지
Memory Mask Cross-Attention 제어 Key-Value 시퀀스 제어 Encoder-Decoder 간 정보 흐름 조절
Local/Window Mask 메모리 사용량 최적화 Band Matrix (대각선 집중) 긴 시퀀스 처리 시 연산 복잡도 해결

2. 실무 트랜스포머 설계를 위한 7가지 마스킹 해결 패턴 (Examples)

실제 PyTorch 실무에서 즉시 활용 가능한, 수치적으로 안정적이고 효율적인 마스킹 구현 예시입니다.

Example 1: Boolean Mask를 이용한 기초적인 Padding Mask 생성

가변 길이 시퀀스에서 패딩된 부분에만 마스크를 씌워 어텐션 점수를 무효화하는 기본 해결책입니다.

import torch

def create_padding_mask(seq):
    # seq shape: (batch_size, seq_len)
    # 패딩 토큰(0)인 부분을 1로 표시
    mask = (seq == 0).float()
    # 어텐션 스코어에 더해줄 수 있도록 차원 확장 (batch, heads, q_len, k_len)
    return mask[:, None, None, :] * -1e9

# 예시 데이터 (0은 패딩)
input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
pad_mask = create_padding_mask(input_ids)

Example 2: 생성 모델을 위한 삼각 행렬 기반 Causal Mask 해결

디코더가 현재 시점 이전의 토큰만 참조하도록 상삼각 행렬(Upper Triangle)을 마스킹하는 방법입니다.

def create_causal_mask(size):
    # (size, size) 크기의 하삼각 행렬은 0, 나머지는 1인 마스크 생성
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask.bool() # True인 부분이 마스킹됨

# 어텐션 내부에서: energy.masked_fill_(causal_mask, -1e9)

Example 3: Scaled Dot-Product Attention 내 마스크 통합 적용

패딩 마스크와 인과 마스크를 결합하여 실제 어텐션 연산에 적용하는 통합 해결 패턴입니다.

def scaled_dot_product_attention(q, k, v, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))
    dk = k.size(-1)
    scaled_attention_logits = matmul_qk / (dk ** 0.5)

    if mask is not None:
        # 마스크 위치에 매우 작은 음수를 더해 Softmax 후 0이 되게 함
        scaled_attention_logits += mask

    attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
    return torch.matmul(attention_weights, v)

Example 4: 2D 데이터(이미지 패치)를 위한 시각적 어텐션 마스킹

Vision Transformer에서 특정 패치 영역만 참조하거나 제외할 때 사용하는 해결 기법입니다.

def create_window_mask(grid_size, window_size):
    # 특정 윈도우 내의 패치들끼리만 어텐션을 수행하도록 설계
    mask = torch.zeros((grid_size**2, grid_size**2))
    # ... 윈도우 인덱스 계산 로직 ...
    return mask.masked_fill(condition, -1e9)

Example 5: Flash Attention을 활용한 메모리 효율적 마스킹 해결

시퀀스가 매우 길 때(Long-context), 메모리 사용량을 줄이기 위해 커스텀 커널 수준의 마스킹을 적용하는 방법입니다.

from flash_attn import flash_attn_func

# Flash Attention은 내부적으로 causal 옵션을 지원하여 메모리를 절약함
output = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)

Example 6: 가변 길이 배치 처리를 위한 `PackedSequence`와 마스크 호환

RNN 스타일의 패킹된 시퀀스를 트랜스포머용 마스크 텐서로 변환하는 해결 프로세스입니다.

def lengths_to_mask(lengths, max_len):
    # 각 샘플의 실제 길이를 바탕으로 마스크 생성
    mask = torch.arange(max_len)[None, :] < lengths[:, None]
    return (~mask).float() * -1e9

Example 7: Multi-Head Attention에서의 브로드캐스팅(Broadcasting) 에러 해결

헤드 수에 맞춰 마스크 차원을 정확히 맞추지 않아 발생하는 런타임 에러를 방지하는 코드입니다.

# mask shape: (batch, seq_len) -> (batch, 1, 1, seq_len)
# 4차원으로 확장해야 Multi-head attention(batch, heads, q_len, k_len)과 연산 가능
attn_mask = mask.unsqueeze(1).unsqueeze(2) 
attn_mask = attn_mask.expand(-1, num_heads, -1, -1)

3. 마스킹 설계 시 반드시 고려해야 할 3가지 성능 원칙

  • 수치적 안정성(Numerical Stability): $-\infty$를 직접 사용하기보다 torch.finfo(dtype).min이나 -1e9를 사용하여 Softmax 계산 시 NaN이 발생하는 문제를 해결하십시오.
  • 메모리 오버헤드: 마스크 텐서는 [Batch, Heads, Seq, Seq] 크기로 커질 수 있습니다. 가급적 브로드캐스팅을 활용하여 [Batch, 1, 1, Seq] 형태로 메모리 점유를 최소화하십시오.
  • 학습과 추론의 일치: 추론 시 KV-Cache를 사용한다면 마스크의 모양이 실시간으로 변해야 합니다. 증분 디코딩(Incremental Decoding) 시의 마스킹 로직을 별도로 구축하십시오.

4. 결론 및 향후 전망

2026년 기준, 대규모 언어 모델(LLM)은 1M 토큰 이상의 긴 문맥(Long-context)을 처리하는 방향으로 진화하고 있습니다. 이에 따라 단순한 Full-masking보다는 Sliding Window Attention이나 Block-wise Masking 같은 기술이 병목 해결의 핵심이 되고 있습니다. 마스킹은 단순한 차단 도구가 아니라, 모델이 데이터의 인과 관계와 구조를 배우게 하는 학습의 나침반입니다.

 

내용 출처 및 참조:

  • Vaswani et al. (2017), "Attention Is All You Need"
  • PyTorch Documentation: "Transformer Layers and Masking Modules"
  • NVIDIA Technical Blog: "Optimizing Transformer Models with FlashAttention"
728x90