본문 바로가기
Artificial Intelligence/21. PyTorch

[PYTORCH] 체크포인트(Checkpoint) 저장 및 불러오기 방법 7가지와 state_dict 차이 해결

by Papa Martino V 2026. 4. 4.
728x90

체크포인트(Checkpoint) 저장 및 불러오기
체크포인트(Checkpoint) 저장 및 불러오기

 

 

딥러닝 모델 학습은 수 시간에서 수일, 길게는 수주까지 소요되는 고된 작업입니다. 학습 도중 예상치 못한 서버 다운, 전원 공급 중단, 혹은 중간 성능 확인을 위해 반드시 마스터해야 하는 기술이 바로 체크포인트(Checkpoint) 관리입니다. PyTorch에서는 모델 전체를 저장하는 방식보다 state_dict를 활용한 가중치 저장 방식이 표준으로 권장됩니다. 본 가이드에서는 실무에서 즉시 활용 가능한 7가지 핵심 예제와 함께, 입문자들이 흔히 겪는 저장 방식 간의 차이와 오류 해결 방안을 심도 있게 다룹니다.


1. 왜 전체 모델이 아닌 state_dict를 저장해야 하는가?

PyTorch에서 모델을 저장하는 방법은 크게 두 가지입니다. 모델 객체 자체를 직렬화(Serialization)하는 방식과 모델의 매개변수(Weights & Biases)만을 딕셔너리 형태로 저장하는 state_dict 방식입니다. 파이토치 공식 문서와 전문가들은 state_dict 사용을 강력히 권장합니다.

모델 저장 방식에 따른 장점과 차이점 비교

비교 항목 model.state_dict() (권장) torch.save(model) (전체 저장)
저장 내용 레이어별 가중치 매개변수 딕셔너리 모델 클래스 구조 + 가중치 전체
유연성 매우 높음 (다른 구조의 모델에도 이식 가능) 낮음 (동일한 클래스 코드가 반드시 존재해야 함)
파일 크기 상대적으로 가벼움 상대적으로 무거움
오류 발생률 낮음 (Python 버전, 경로 문제에 강함) 높음 (디렉토리 구조 변경 시 로드 실패)
추천 용도 실무 배포, 전이 학습, 지속 학습 매우 간단한 개인 테스트용

2. 실무 개발자를 위한 7가지 핵심 Checkpoint 구현 Example

단순한 저장부터 옵티마이저 상태를 포함한 완전한 체크포인트 관리까지, 현업에서 바로 사용하는 코드 패턴입니다.

Example 1: 가장 기본적인 가중치 저장 및 불러오기

학습이 완료된 모델의 '뇌'에 해당하는 가중치만 추출하여 저장합니다.

import torch

# 저장하기
torch.save(model.state_dict(), 'model_weights.pth')

# 불러오기
model = MyModel() # 저장할 때와 동일한 클래스 인스턴스 생성
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 추론 모드 전환 필수

Example 2: 학습 재개를 위한 완전한 체크포인트 (Optimizer 포함)

서버가 꺼졌을 때 에포크 수와 옵티마이저의 모멘텀 상태까지 복구해야 진정한 학습 재개가 가능합니다.

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, 'checkpoint.tar')

# 로드 후 학습 재개
checkpoint = torch.load('checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

Example 3: CPU 환경에서 GPU 모델 불러오기 (Device Mapping)

GPU(CUDA)에서 학습된 모델을 일반 CPU 서버에서 서빙할 때 필수적인 방법입니다.

device = torch.device('cpu')
model.load_state_dict(torch.load('model.pth', map_location=device))

Example 4: DataParallel(Multi-GPU) 모델 로드 해결

DataParallel로 학습하면 가중치 키값에 module. 접두어가 붙어 일반 모델에서 로드가 안 됩니다. 이를 해결하는 팁입니다.

state_dict = torch.load('multigpu_model.pth')
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 'module.' 제거
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

Example 5: 특정 레이어만 제외하고 로드 (Partial Load)

전이 학습(Transfer Learning) 시 마지막 분류기 레이어만 바꿀 때 유용합니다.

pretrained_dict = torch.load('pretrained.pth')
model_dict = model.state_dict()

# 현재 모델 구조에 존재하는 키값만 필터링
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Example 6: Best Model 선별 저장 로직

모든 에포크를 저장하지 않고, 검증 손실이 최소일 때만 갱신합니다.

if val_loss < best_loss:
    best_loss = val_loss
    torch.save(model.state_dict(), 'best_model.pth')

Example 7: 확장자 선택 및 경로 관리 (Best Practice)

가중치는 .pth 혹은 .pt를 사용하고, 정보가 포함된 체크포인트는 .tar를 사용하여 구분하는 것이 업계 관례입니다.

import os
save_path = os.path.join('checkpoints', f'model_epoch_{epoch}.pth')
torch.save(model.state_dict(), save_path)

3. Checkpoint 활용 시 발생하는 3가지 흔한 오류와 해결 방법

  1. RuntimeError: Error(s) in loading state_dict: 모델의 레이어 이름이나 개수가 저장 시점과 다를 때 발생합니다. load_state_dict의 옵션 중 strict=False를 주면 일치하는 부분만 로드할 수 있지만, 신중해야 합니다.
  2. Inplace 업데이트 오류: 불러온 직후 모델 구조를 변경하면 가중치 매핑이 깨질 수 있습니다. 반드시 구조를 먼저 정의한 뒤 load를 수행하십시오.
  3. eval() 모드 누락: 모델을 불러온 뒤 model.eval()을 호출하지 않으면 Dropout이나 Batch Normalization 레이어가 학습 모드로 작동하여 추론 결과가 일관되지 않게 나옵니다.

4. 결론 및 요약

PyTorch의 체크포인트 시스템은 유연하지만 명확한 규칙을 따를 때 가장 안전합니다. 가중치만 저장할 때는 .pth와 state_dict를, 학습 전체를 백업할 때는 딕셔너리 패키징을 기억하세요. 이러한 정교한 저장 전략은 협업 효율을 높이고 모델 배포 시 발생할 수 있는 환경 문제를 사전에 차단해 줍니다.

상황 해결 및 적용 방법
단순 모델 배포 model.state_dict() 저장 -> load_state_dict()
장기 학습 중 백업 에포크, 옵티마이저, 스케줄러 상태를 하나의 딕셔너리로 저장
기기(Device) 변경 torch.load 시 map_location 파라미터 활용

 

내용 출처 및 참고 문헌:
1. PyTorch Official Tutorials: "Saving and Loading Models".
2. Paszke, A., et al. (2019). "PyTorch: An Imperative Style, High-Performance Deep Learning Library".
3. Deep Learning with PyTorch (Manning Publications).
4. PyTorch Forum: Best practices for saving checkpoints.

728x90