치즈의 AI 녹이기

state_dict()로 best 모델 저장 및 불러오기 본문

인공지능 대학원생의 생활/동료 코드 따라잡기

state_dict()로 best 모델 저장 및 불러오기

개발자 치즈 2021. 7. 19. 18:29

state_dict란 이름 그대로 Python 사전(dict) 객체 입니다.

오늘은 state_dict에 어떤 것을 저장할 수 있을 지 알아보겠습니다. 

 

Best 모델 저장하기

학습 도중에 가장 좋은 성능을 가지는 모델을 저장 한다면,

학습을 중단한 뒤에도 해당 모델을 불러와 학습을 재개할 수 있습니다. 

 

저장하는 지점

아래의 동료 코드를 보면, 현재 사용하고 있는 metric인 'ndcg 5'를 기준으로 best 수치를 갱신할 때마다 

새롭게 정의한 함수 save_checkpoint를 통해 여러 매개변수 및 하이퍼 파라미터들을 저장하고 있습니다.  

 

 

저장하는 방법

save_checkpoint 함수를 보면, state이라는 이름으로 dict 타입의 변수를 하나 만들어  

인자에서 전달받은 매개변수 및 하이퍼 파라미터들을 차례대로 저장하는 것을 볼 수 있습니다.

그 다음, torch.save(state, ckpt_path)를 통해 state를 한번에 저장합니다. 

특히 pytorch, np.random, random의 seed를 저장하기 위한 "rng_state" 항목이 중요해 보입니다.

모델 파라미터, 옵티마이져 뿐만 아니라 dict 타입의 기타 정보 또한 저장할 수 있다는 점이 새롭습니다.

 

save_checkpoint function

 

pytorch tutorial에 나와있는 노트 중에서 중요한 내용인 것 같아 첨부합니다. 

https://tutorials.pytorch.kr/beginner/saving_loading_models.html

결과 확인하기

굉장히 다양한 정보들이 torch.save()를 통해 저장될 수 있음을 알 수 있습니다.

model state_dict()
other_states

 

 

Best 모델 불러오기

이제 저장해놓은 모델의 state를 다시 불러오겠습니다.

 

불러오는 지점

불러오는 지점은 이미 저장된 모델의 경로 net_t가 존재 할 때와, 학습이 끝난 후 바로 검증할 때

새롭게 정의한 함수 load_checkpoint를 통해 불러오고 있습니다. 

불러오는 지점

 

불러오는 방법

load_checkpoint 함수를 보면, save_checkpoint 함수에서 .pt 형식으로 저장했던 state를 torch.load()로 불러오는 것을 볼 수 있습니다.

그 다음 차례대로 load_state_dict()를 통해 state에 저장되어 있던 정보를 불러옵니다. 

 

 

 

참고 링크 :

https://tutorials.pytorch.kr/beginner/saving_loading_models.html

 

모델 저장하기 & 불러오기 — PyTorch Tutorials 1.9.0+cu102 documentation

Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich 번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다

tutorials.pytorch.kr

https://tutorials.pytorch.kr/recipes/recipes/what_is_state_dict.html

 

PyTorch에서 state_dict란 무엇인가요? — PyTorch Tutorials 1.9.0+cu102 documentation

Note Click here to download the full example code PyTorch에서 state_dict란 무엇인가요? PyTorch에서 torch.nn.Module 모델의 학습 가능한 매개변수(예. 가중치와 편향)들은 모델의 매개변수에 포함되어 있습니다. (model

tutorials.pytorch.kr

 

코드 출처는 앞선 글에서 언급하였습니다.