
현대 자연어 처리(NLP)와 컴퓨터 비전의 핵심인 트랜스포머 아키텍처는 강력하지만, 시퀀스 길이의 제곱에 비례하는 연산량이라는 치명적인 단점이 있습니다. 본 포스팅에서는 이 $O(n^2)$의 굴레를 벗어나기 위한 최신 선형 어텐션(Linear Attention) 기법들과 파이썬 구현 사례를 심도 있게 다룹니다.
1. 왜 $O(n^2)$이 문제인가? 복잡도 분석과 해결의 필요성
표준 셀프 어텐션(Self-Attention)은 쿼리(Query)와 키(Key) 행렬의 내적을 통해 유사도를 계산합니다. 시퀀스 길이가 $n$일 때, $n \times n$ 크기의 어텐션 맵이 생성됩니다. 이는 시퀀스가 길어질수록 메모리 점유율과 연산 시간이 기하급수적으로 증가함을 의미합니다. 특히 4K 이상의 긴 컨텍스트를 다루는 LLM이나 고해상도 이미지 처리에서 이는 하드웨어적인 한계로 작용합니다.
2. 표준 어텐션 vs 선형 어텐션 아키텍처 비교
선형 어텐션의 핵심 아이디어는 행렬 곱셈의 결합 법칙(Associative Property)을 활용하여 소프트맥스(Softmax) 연산을 커널 함수로 대체하는 것입니다.
| 비교 항목 | 표준 소프트맥스 어텐션 | 선형 어텐션 (Linear Attention) |
|---|---|---|
| 시간 복잡도 | $O(n^2 d)$ | $O(nd^2)$ |
| 공간 복잡도 | $O(n^2)$ | $O(nd)$ |
| 핵심 연산 | $\text{Softmax}(QK^T)V$ | $Q(K^TV)$ (연산 순서 변경) |
| 장점 | 높은 정확도, 풍부한 표현력 | 초장거리 시퀀스 처리 가능, 추론 속도 빠름 |
| 단점 | 메모리 병목, 긴 시퀀스 불가 | 근사화로 인한 일부 정보 손실 발생 가능 |
3. 실무 적용 가능한 선형 어텐션 구현 예제 (Python/PyTorch)
개발자가 현업 프로젝트에서 즉시 성능을 벤치마킹하고 적용해 볼 수 있는 7가지 파이썬 코드 예시를 제공합니다.
Example 1: Katharopoulos 식 표준 선형 어텐션 (기본 원리)
커널 함수 $\phi(x) = \text{elu}(x) + 1$을 사용하여 소프트맥스를 우회하는 가장 기본적인 방법입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearAttention(nn.Module):
def __init__(self, eps=1e-6):
super().__init__()
self.eps = eps
def forward(self, q, k, v):
# q, k, v shape: (batch, heads, seq_len, dim)
q = F.elu(q) + 1
k = F.elu(k) + 1
# k와 v를 먼저 곱하여 (dim, dim) 크기의 컨텍스트 행렬 생성
# O(n * d^2) 복잡도 달성
v_sum = v.sum(dim=-2, keepdim=True)
k_sum = k.sum(dim=-2, keepdim=True)
kv = torch.einsum("b h n d, b h n m -> b h d m", k, v)
num = torch.einsum("b h n d, b h d m -> b h n m", q, kv)
den = torch.einsum("b h n d, b h d -> b h n", q, k_sum.squeeze(-2))
return num / (den.unsqueeze(-1) + self.eps)
Example 2: Performer의 Fast Attention (FAVOR+)
Random Fourier Features를 사용하여 소프트맥스 어텐션을 수학적으로 근사하는 기법입니다.
import math
def favor_plus_kernel(x, projection_matrix):
# x: (batch, n, d), projection_matrix: (m, d)
# 정규화 및 지수 함수 근사
x_norm = (x**2).sum(dim=-1, keepdim=True) / 2
projection = torch.matmul(x, projection_matrix.t())
phi = torch.exp(projection - x_norm) / math.sqrt(projection_matrix.size(0))
return phi
# 실무 팁: projection_matrix는 Orthogonal하게 설계할 때 성능이 가장 좋음
Example 3: 효율적인 인과적(Causal) 선형 어텐션 구현
디코더 모델에서 과거의 정보만 참조하도록 누적 합(Cumsum)을 이용한 고속 연산 방법입니다.
def causal_linear_attention(q, k, v):
q = F.softplus(q)
k = F.softplus(k)
# KV의 누적합 계산
kv = torch.einsum("b h n d, b h n m -> b h n d m", k, v)
kv_cum = torch.cumsum(kv, dim=2)
# Q와 누적 KV 곱셈
out = torch.einsum("b h n d, b h n d m -> b h n m", q, kv_cum)
k_cum = torch.cumsum(k, dim=2)
den = torch.einsum("b h n d, b h n d -> b h n", q, k_cum)
return out / den.unsqueeze(-1)
Example 4: FLASH (Fast Linear Attention with Sliding Window)
전역적인 선형 어텐션과 지역적인 슬라이딩 윈도우를 결합하여 정확도를 해결하는 하이브리드 방식입니다.
class SlidingWindowMask:
@staticmethod
def apply(attn_weights, window_size):
mask = torch.ones_like(attn_weights)
mask = torch.tril(mask) * torch.triu(mask, diagonal=-window_size)
return attn_weights * mask
# 실무 적용: 선형 어텐션 결과에 로컬 윈도우 특징을 더해주면 성능이 크게 향상됨
Example 5: Cosine Attention (복잡도 해결을 위한 단순화)
복잡한 지수 함수 대신 코사인 유사도를 기반으로 시퀀스 길이에 선형적으로 대응하는 방식입니다.
def cosine_linear_attention(q, k, v):
# L2 Normalization을 통한 코사인 유사도 유도
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
# 일반적인 행렬곱 순서 변경 적용
kv = torch.matmul(k.transpose(-2, -1), v)
out = torch.matmul(q, kv)
return out
Example 6: Nyströmformer 기반의 랜드마크 선택 기법
시퀀스 전체가 아닌 일부 랜드마크 노드만 샘플링하여 어텐션을 계산하는 전략입니다.
def nystrom_attention(q, k, v, num_landmarks=64):
# Landmark 포인트 샘플링 (평균 풀링 활용)
n = q.size(-2)
landmark_k = F.adaptive_avg_pool1d(k.transpose(-2, -1), num_landmarks).transpose(-2, -1)
landmark_q = F.adaptive_avg_pool1d(q.transpose(-2, -1), num_landmarks).transpose(-2, -1)
# 랜드마크를 이용한 근사 행렬 계산
# A = Q * (K_langmark^T) * inverse(Q_landmark * K_landmark^T) * (Q_landmark * K^T)
# 실제 구현시 Moore-Penrose Pseudo-inverse 사용 권장
pass
Example 7: 효율적인 메모리 관리를 위한 'Chunked' Linear Attention
대규모 텐서 연산 시 발생하는 VRAM 부족 문제를 해결하기 위해 데이터를 청크(Chunk) 단위로 처리합니다.
def chunked_linear_attn(q, k, v, chunk_size=256):
chunks = q.split(chunk_size, dim=-2)
outputs = []
prev_kv = 0
for i, q_chunk in enumerate(chunks):
k_chunk = k.split(chunk_size, dim=-2)[i]
v_chunk = v.split(chunk_size, dim=-2)[i]
current_kv = torch.matmul(k_chunk.transpose(-2, -1), v_chunk)
# 이전 청크의 컨텍스트를 누적하여 선형성 유지
combined_kv = prev_kv + current_kv
out = torch.matmul(q_chunk, combined_kv)
outputs.append(out)
prev_kv = combined_kv
return torch.cat(outputs, dim=-2)
4. 현업에서 선형 어텐션을 선택할 때 고려해야 할 3가지 요소
- 시퀀스 길이 정체 구간: 시퀀스 길이가 512 이하인 경우 표준 소프트맥스 어텐션이 커널 연산 오버헤드가 적어 더 빠를 수 있습니다. 1024 이상의 환경에서 선형 어텐션의 도입을 검토하십시오.
- 훈련 vs 추론: 선형 어텐션은 추론 시에 이전 상태(State)를 캐싱하기 매우 유리하여 RNN처럼 동작할 수 있습니다. 실시간 서빙 시스템에서 매우 강력한 이점을 가집니다.
- 소프트맥스 특성 유지: 단순 선형화는 어텐션 맵의 '집중(Sparsity)' 특성을 약화시킬 수 있습니다. 이를 해결하기 위해 개별 모델에 맞는 활성화 함수(Feature Map) 선택이 중요합니다.
5. 결론 및 향후 전망
Attention의 $O(n^2)$ 복잡도를 해결하기 위한 노력은 Linear Attention을 넘어 FlashAttention 같은 하드웨어 가속 기법으로 진화하고 있습니다. 파이썬 개발자로서 이러한 알고리즘적 최적화를 이해하고 적용하는 것은 자원 효율적인 AI 모델 구축의 핵심 역량이 될 것입니다.