본문 바로가기
Artificial Intelligence/60. Python

[PYTHON] 그래프 신경망(GNN) 구현을 위한 PyTorch Geometric 활용 방법과 데이터 구조 해결 7가지 전략

by Papa Martino V 2026. 4. 13.
728x90

PyTorch Geometric(PyG)
PyTorch Geometric (PyG)

1. 관계의 미학: 그래프 신경망(GNN)과 PyTorch Geometric의 등장

우리가 살아가는 세상의 데이터는 단순히 격자 구조(이미지)나 시퀀스(텍스트)로만 이루어져 있지 않습니다. 소셜 네트워크의 사용자 관계, 단백질 분자의 결합 구조, 금융 거래의 흐름 등은 모두 노드(Node)와 엣지(Edge)로 연결된 그래프(Graph) 형태를 띱니다. 이러한 비정형 관계 데이터를 딥러닝으로 해석하기 위해 탄생한 것이 바로 그래프 신경망(GNN)입니다. Python 생태계에서 GNN을 구현할 때 가장 강력한 라이브러리는 단연 PyTorch Geometric (PyG)입니다. PyG는 그래프 데이터의 희소성(Sparsity)을 효율적으로 처리하며, 최신 GNN 논문의 핵심 알고리즘들을 직관적인 API로 제공합니다. 본 포스팅에서는 PyG의 독창적인 데이터 구조를 파헤치고, 실무에서 즉시 활용 가능한 7가지 핵심 구현 패턴을 제시합니다.


2. 기존 CNN과 GNN의 데이터 처리 방식 차이 분석

GNN을 이해하기 위해서는 기존의 합성곱 신경망(CNN)과 데이터 구조적 측면에서 어떤 차이가 있는지 명확히 알아야 합니다.

비교 항목 Convolutional Neural Network (CNN) Graph Neural Network (GNN)
데이터 도메인 Euclidean Space (이미지, 격자) Non-Euclidean Space (그래프, 네트워크)
이웃의 정의 고정된 픽셀 주변 (3x3, 5x5 등) 가변적인 연결 관계 (Degree에 따라 다름)
순서 불변성 위치 정보가 중요함 노드 순서가 바뀌어도 구조적 의미 유지
데이터 구조 Dense Tensor (Tensor) Sparse Matrix / Edge List (Data Object)
주요 연산 Sliding Window Convolution Message Passing (Aggregate & Update)

3. PyTorch Geometric(PyG) 실전 활용 및 데이터 설계 Example 7선

그래프 데이터를 다루는 개발자가 실무 프로젝트의 초기 설계부터 모델 배포까지 참고할 수 있는 7가지 단계별 Python 예제입니다.

Example 1: PyG의 핵심 'Data' 객체 설계 방법

그래프의 노드 특징(x)과 연결성(edge_index)을 정의하는 가장 기초적이고 중요한 단계입니다.

import torch
from torch_geometric.data import Data

# 노드 특징 행렬 (4개 노드, 각 16차원 특징)
x = torch.randn(4, 16)

# 엣지 연결 리스트 (COO 형식: [Source, Target])
# 0->1, 1->0, 1->2, 2->3 방향의 연결
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 3]], dtype=torch.long)

# 데이터 객체 생성
data = Data(x=x, edge_index=edge_index)

print(f"그래프 노드 수: {data.num_nodes}")
print(f"그래프 엣지 수: {data.num_edges}")
        

Example 2: 대규모 그래프를 위한 Neighbor Sampling 워크플로우

수백만 개의 노드를 가진 그래프를 한 번에 GPU에 올릴 수 없을 때 사용하는 '메모리 부족 해결' 기법입니다.

from torch_geometric.loader import NeighborLoader

# 대규모 그래프 data가 있다고 가정
loader = NeighborLoader(
    data,
    num_neighbors=[10, 5], # 1계층에서 10개, 2계층에서 5개 노드 샘플링
    batch_size=128,
    input_nodes=data.train_mask, # 학습 노드 기준
)

for batch in loader:
    # 전체 그래프가 아닌 샘플링된 서브그래프만 처리
    out = model(batch.x, batch.edge_index)
        

Example 3: GCN(Graph Convolutional Network) 레이어 커스텀 구현

기존 레이어를 활용하는 것을 넘어, 자신만의 Message Passing 로직을 설계하는 방법입니다.

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add') # "Add" aggregation 사용
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # 1. Self-loop 추가
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # 2. 선형 변환
        x = self.lin(x)
        # 3. 메시지 전달 시작
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j는 이웃 노드들의 특징값
        return x_j
        

Example 4: Heterogeneous Graph(이종 그래프) 처리 전략

사용자와 상품, 논문과 저자 등 서로 다른 타입의 노드와 관계를 가진 복합 데이터 구조 해결법입니다.

from torch_geometric.data import HeteroData

data = HeteroData()

# 노드 타입별 특징 정의
data['user'].x = torch.randn(100, 64)
data['item'].x = torch.randn(500, 128)

# 관계 타입별 연결 정의
data['user', 'buys', 'item'].edge_index = torch.randint(0, 100, (2, 1000))

print(data.metadata()) # (['user', 'item'], [('user', 'buys', 'item')])
        

Example 5: 그래프 분류(Graph Classification)를 위한 Pooling 레이어 적용

개별 노드가 아닌 그래프 전체의 성질(예: 분자 독성 여부)을 예측하기 위한 데이터 집계 방법입니다.

from torch_geometric.nn import global_mean_pool

def forward(self, x, edge_index, batch):
    x = self.conv1(x, edge_index).relu()
    x = self.conv2(x, edge_index).relu()
    
    # batch 벡터를 기준으로 노드들을 그래프 단위로 묶어 평균 계산
    x = global_mean_pool(x, batch) 
    return self.lin(x)
        

Example 6: Edge Feature(엣지 특징값)를 모델에 반영하는 방법

노드뿐만 아니라 연결의 강도, 거리 등 엣지가 가진 정보를 학습에 포함하는 해결책입니다.

from torch_geometric.nn import GINEConv

# 엣지 특징(edge_attr)을 지원하는 GINEConv 사용
# nn은 특징 변환을 위한 MLP
conv = GINEConv(nn=torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()))

# forward 호출 시 edge_attr 전달
out = conv(x, edge_index, edge_attr=edge_attr)
        

Example 7: InMemoryDataset을 활용한 커스텀 데이터셋 구축

외부 raw 데이터를 PyG의 표준 데이터셋 형식으로 변환하여 관리 효율성을 높이는 방법입니다.

from torch_geometric.data import InMemoryDataset

class MyGraphDataset(InMemoryDataset):
    def __init__(self, root, transform=None):
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self): return ['raw_data.csv']
    
    @property
    def processed_file_names(self): return ['data.pt']

    def process(self):
        # 로직: CSV 읽기 -> Data 객체 리스트 생성 -> self.collate() -> 저장
        pass
        

4. GNN 성능 최적화를 위한 3가지 핵심 데이터 해결 가이드

실무에서 GNN 모델이 수렴하지 않거나 속도가 느릴 때 체크해야 할 전문적인 가이드라인입니다.

  • Over-smoothing 문제 해결: 레이어가 깊어질수록 모든 노드의 특징이 비슷해지는 현상을 막기 위해 Residual Connection이나 PairNorm 기법을 데이터 파이프라인에 추가하십시오.
  • Edge Index의 정렬: PyG의 edge_index는 가급적 정렬된(Sorted) 상태로 유지하는 것이 연산 속도 최적화에 유리합니다. torch_geometric.utils.sort_edge_index를 활용하세요.
  • Self-loop의 유무: 중심 노드 자신의 정보를 보존하려면 학습 시 반드시 Self-loop를 추가해야 합니다. 레이어 내부에서 자동으로 처리되는지 확인하십시오.

5. 결론: PyTorch Geometric으로 그리는 데이터의 미래

그래프 신경망은 단순한 유행을 넘어 검색 엔진, 추천 시스템, 신약 개발 등 현대 산업 전반의 난제를 해결하는 필살기가 되고 있습니다. PyTorch Geometric은 이러한 복잡한 그래프 이론을 Python이라는 유연한 언어 위에서 가장 효율적으로 구현할 수 있게 돕는 도구입니다. 본 포스팅에서 다룬 7가지 데이터 구조 설계와 활용 기법을 토대로, 보이지 않는 관계 속에서 가치 있는 인사이트를 추출해 보시기 바랍니다.

내용 출처

  • Fey, M., & Lenssen, J. E. (2019). "Fast Graph Representation Learning with PyTorch Geometric." ICLR Workshop.
  • Kipf, T. N., & Welling, M. (2017). "Semi-Supervised Classification with Graph Convolutional Networks." ICLR.
  • PyTorch Geometric (PyG) Official Documentation: "Introduction by Example".
  • Stanford CS224W: "Machine Learning with Graphs" Lecture Series.
728x90