
그래프 신경망(GNN)은 데이터 간의 관계를 학습하는 데 탁월한 성능을 발휘합니다. 하지만 층(Layer)이 깊어질수록, 즉 Message Passing(메시지 전파) 횟수가 늘어날수록 모든 노드의 임베딩 벡터가 서로 유사해지는 Over-smoothing(과도한 평활화) 문제에 직면하게 됩니다. 이는 결국 모델이 노드 간의 변별력을 잃게 만들어 성능을 급격히 저하시킵니다. 본 포스팅에서는 Python 환경에서 PyTorch Geometric(PyG)을 활용하여 Over-smoothing의 원인을 심층 분석하고, 이를 극복하기 위한 7가지 핵심 해결 방법과 실무 코드를 상세히 다룹니다.
1. Over-smoothing 현상의 이해와 성능 차이
Over-smoothing은 GNN이 고유하게 가지는 특성인 '이웃 정보의 평균화'에서 기인합니다. 층이 깊어질수록 각 노드는 그래프 전체의 정보를 수용하게 되며, 결과적으로 모든 노드가 동일한 정보(평균값)를 갖게 됩니다. 이는 수학적으로 고유값 분해(Eigen-decomposition)를 통해 증명되기도 합니다.
레이어 깊이에 따른 성능 및 지표 비교
| 비교 항목 | Shallow GNN (2-3 Layers) | Deep GNN (10+ Layers) | Optimized Deep GNN (해결책 적용) |
|---|---|---|---|
| 노드 임베딩 거리 | 멀고 뚜렷함 | 매우 가까움 (수렴) | 적절한 거리 유지 |
| 수용장 (Receptive Field) | 국소적 (Local) | 전역적 (Global) | 계층적 (Hierarchical) |
| 성능 (Accuracy) | 높음 | 매우 낮음 | 최적화 및 향상 |
| 학습 안정성 | 안정적 | Gradient Vanishing 발생 | Residual 연결로 보완 |
2. Over-smoothing 해결을 위한 7가지 실전 Python 예제
개발자가 즉시 실무에 적용할 수 있도록 PyTorch 기반의 구현 코드를 제공합니다.
Example 1: Residual Connections (잔차 연결) 적용
이전 층의 정보를 직접 전달하여 고유한 노드 특징을 유지합니다.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class ResidualGNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(GCNConv(in_channels, hidden_channels))
for _ in range(num_layers - 2):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
self.convs.append(GCNConv(hidden_channels, out_channels))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
identity = x
x = F.relu(conv(x, edge_index))
if identity.size() == x.size():
x = x + identity # Residual Addition
return self.convs[-1](x, edge_index)
Example 2: Jumping Knowledge (JK) Networks 구현
마지막 층에서 모든 중간 층의 출력을 결합하여 최적의 표현력을 선택합니다.
from torch_geometric.nn import JumpingKnowledge
class JKGNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super().__init__()
self.convs = torch.nn.ModuleList([GCNConv(in_channels, hidden_channels)])
for _ in range(num_layers - 1):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
self.jk = JumpingKnowledge(mode='cat') # Concatenation 방식
self.lin = torch.nn.Linear(num_layers * hidden_channels, out_channels)
def forward(self, x, edge_index):
xs = []
for conv in self.convs:
x = F.relu(conv(x, edge_index))
xs.append(x)
x = self.jk(xs)
return self.lin(x)
Example 3: DropEdge 전략 적용
학습 시 무작위로 엣지를 제거하여 메시지 전파의 과도한 밀집을 방지합니다.
from torch_geometric.utils import dropout_edge
def train_step_with_dropedge(model, data, p=0.2):
model.train()
# 학습 시에만 무작위로 엣지 제거
edge_index, _ = dropout_edge(data.edge_index, p=p)
out = model(data.x, edge_index)
# ... 후속 loss 계산 및 최적화
Example 4: Initial Connection (GCNII 스타일)
초기 입력 특징($x_0$)을 모든 층에 직접 연결하여 노드의 본질적 정보를 보존합니다.
class GCNII_Simple(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, alpha=0.1):
super().__init__()
self.conv = GCNConv(hidden_channels, hidden_channels)
self.alpha = alpha
def forward(self, x, x_0, edge_index):
# x_0는 초기 입력 임베딩
hi = F.relu(self.conv(x, edge_index))
x = (1 - self.alpha) * hi + self.alpha * x_0
return x
Example 5: PairNorm (Normalization 계층)
노드 간의 총 거리를 일정하게 유지하도록 정규화하여 뭉침 현상을 억제합니다.
class PairNorm(torch.nn.Module):
def __init__(self, scale=1.0):
super().__init__()
self.scale = scale
def forward(self, x):
col_mean = x.mean(dim=0)
x = x - col_mean
dist = torch.norm(x, p=2, dim=1).mean()
return self.scale * x / dist
Example 6: Differentiable Group Normalization (DGN)
노드들을 그룹화하여 그룹 내/그룹 간 정규화를 다르게 적용합니다.
# PyG에서 제공하는 정규화 모듈 활용 예시
from torch_geometric.nn import DiffGroupNorm
class DGNGNN(torch.nn.Module):
def __init__(self, channels, clusters=10):
super().__init__()
self.norm = DiffGroupNorm(channels, groups=clusters)
def forward(self, x):
return self.norm(x)
Example 7: 가변적인 Message Passing 강도 조절
학습 가능한 파라미터를 통해 이웃 정보 수용량을 동적으로 조절합니다.
class AdaptiveGNN(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = GCNConv(in_channels, out_channels)
self.beta = torch.nn.Parameter(torch.Tensor([0.5])) # 학습 가능한 가중치
def forward(self, x, edge_index):
x_new = self.conv(x, edge_index)
return (1 - self.beta) * x + self.beta * x_new
3. 결론 및 실무 가이드
GNN에서 레이어를 깊게 쌓는 것은 단순히 수용장을 넓히는 것 이상의 의미를 갖습니다. Over-smoothing은 기술적인 결함이라기보다 그래프 구조가 가지는 고유한 수학적 성질입니다. 따라서 실무에서는 다음과 같은 순서로 접근하는 것을 추천합니다.
- 먼저 2-3층의 얕은 모델로 기준 성능(Baseline)을 잡습니다.
- 더 넓은 컨텍스트가 필요하다면 Residual Connection을 가장 먼저 고려하십시오.
- 데이터의 노이즈가 많다면 DropEdge를 통해 모델의 강건함을 높이십시오.
- 매우 깊은 구조(10층 이상)가 필수적이라면 GCNII나 PairNorm과 같은 전문적인 정규화 기법을 도입해야 합니다.
이러한 기법들을 통해 여러분의 GNN 모델은 깊어지면서도 강력한 변별력을 유지할 수 있을 것입니다.
참고 문헌 (Sources)
- Li, Q., Han, Z., & Wu, X. M. (2018). "Deeper Insights into Graph Convolutional Networks for Semi-Supervised Learning." AAAI.
- Rong, Y., et al. (2020). "DropEdge: Towards Deep Graph Convolutional Networks on Node Classification." ICLR.
- Chen, M., et al. (2020). "Simple and Deep Graph Convolutional Networks." ICML.
- Xu, K., et al. (2018). "Representation Learning on Graphs with Jumping Knowledge Networks." ICML.