치즈의 AI 녹이기

nn.DataParallel 사용하기 본문

인공지능 대학원생의 생활/동료 코드 따라잡기

nn.DataParallel 사용하기

개발자 치즈 2021. 7. 16. 19:21

오늘은 nn.DataParallel를 사용하여 데이터 병렬처리하는 코드를 가져왔습니다. 

데이터 병렬처리란, 다수의 GPU에서 모델을 병렬로 실행하여 작업할 수 있도록 처리하는 것입니다.

 

먼저 torch.cuda.device_count()를 이용하여 갖고 있는 GPU 개수를 num_gpus에 저장합니다. 

 

그 다음 원하는 모델을 불러 온 후, 갖고 있는 GPU 개수가 2개 이상이면 nn.DataParallel을 사용하여 모델을 wrapping할 수 있습니다. 이렇게만 해주면 알아서 모델 내에 들어오는 데이터를 각 GPU에 할당하여 처리하도록 해줍니다. 

 

특정 모델의 학습 매개변수를 불러와 inference를 하고 싶은 경우, 

다음과 같이 isinstance(model, nn.DataParallel)을 활용하여 불러올 수 있습니다.  

 

 

참고 링크 : https://tutorials.pytorch.kr/beginner/blitz/data_parallel_tutorial.html

 

선택 사항: 데이터 병렬 처리 (Data Parallelism) — PyTorch Tutorials 1.9.0+cu102 documentation

Note Click here to download the full example code 선택 사항: 데이터 병렬 처리 (Data Parallelism) 글쓴이: Sung Kim and Jenny Kang 번역: ‘정아진 ’ 이 튜토리얼에서는 DataParallel (데이터 병렬)

tutorials.pytorch.kr

 

코드 출처는 앞선 글에서 언급하였습니다.