일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- 오블완
- 데드리프트
- 운동
- 프로그래머스
- 바디프로필
- 코드
- 암풀다운
- 바프준비
- 라섹 수술 후기
- 체스트프레스
- 영화 비평
- Knowledge Tracing
- 디버깅
- 코딩테스트
- 개인 피티
- PT 운동
- 연구 시작
- 개인 PT
- 개인 운동
- 덤벨운동
- github
- 티스토리챌린지
- 코테준비
- 개발자
- 코테 공부
- pytorch
- 논문 리뷰
- 다이어트
- 건강
- 하체운동
- Today
- Total
치즈의 AI 녹이기
state_dict()로 best 모델 저장 및 불러오기 본문
state_dict란 이름 그대로 Python 사전(dict) 객체 입니다.
오늘은 state_dict에 어떤 것을 저장할 수 있을 지 알아보겠습니다.
Best 모델 저장하기
학습 도중에 가장 좋은 성능을 가지는 모델을 저장 한다면,
학습을 중단한 뒤에도 해당 모델을 불러와 학습을 재개할 수 있습니다.
저장하는 지점
아래의 동료 코드를 보면, 현재 사용하고 있는 metric인 'ndcg 5'를 기준으로 best 수치를 갱신할 때마다
새롭게 정의한 함수 save_checkpoint를 통해 여러 매개변수 및 하이퍼 파라미터들을 저장하고 있습니다.
저장하는 방법
save_checkpoint 함수를 보면, state이라는 이름으로 dict 타입의 변수를 하나 만들어
인자에서 전달받은 매개변수 및 하이퍼 파라미터들을 차례대로 저장하는 것을 볼 수 있습니다.
그 다음, torch.save(state, ckpt_path)를 통해 state를 한번에 저장합니다.
특히 pytorch, np.random, random의 seed를 저장하기 위한 "rng_state" 항목이 중요해 보입니다.
모델 파라미터, 옵티마이져 뿐만 아니라 dict 타입의 기타 정보 또한 저장할 수 있다는 점이 새롭습니다.
pytorch tutorial에 나와있는 노트 중에서 중요한 내용인 것 같아 첨부합니다.
결과 확인하기
굉장히 다양한 정보들이 torch.save()를 통해 저장될 수 있음을 알 수 있습니다.
Best 모델 불러오기
이제 저장해놓은 모델의 state를 다시 불러오겠습니다.
불러오는 지점
불러오는 지점은 이미 저장된 모델의 경로 net_t가 존재 할 때와, 학습이 끝난 후 바로 검증할 때
새롭게 정의한 함수 load_checkpoint를 통해 불러오고 있습니다.
불러오는 방법
load_checkpoint 함수를 보면, save_checkpoint 함수에서 .pt 형식으로 저장했던 state를 torch.load()로 불러오는 것을 볼 수 있습니다.
그 다음 차례대로 load_state_dict()를 통해 state에 저장되어 있던 정보를 불러옵니다.
참고 링크 :
https://tutorials.pytorch.kr/beginner/saving_loading_models.html
https://tutorials.pytorch.kr/recipes/recipes/what_is_state_dict.html
코드 출처는 앞선 글에서 언급하였습니다.
'인공지능 대학원생의 생활 > 동료 코드 따라잡기' 카테고리의 다른 글
코딜리티를 이용한 코딩테스트 준비 (lesson 1~4) (1) | 2023.03.12 |
---|---|
kwargs 이용해서 효율적으로 인자 관리하기 (0) | 2022.09.19 |
nn.DataParallel 사용하기 (0) | 2021.07.16 |
Huggingface Bert 모델 커스터마이징 하기 (0) | 2021.07.16 |
nn.Embedding sparse 파라미터 (0) | 2021.06.25 |