치즈의 AI 녹이기

[pytorch] collate_fn에 arg 추가하기 본문

인공지능 대학원생의 생활/구글링

[pytorch] collate_fn에 arg 추가하기

개발자 치즈 2022. 4. 8. 14:12

data.DataLoader에서 사용하는 collate_fn은

일반적으로 사용자 정의 함수에 의한 batch 단위의 데이터를 처리를 할 때 이용한다. 

따라서 사용자 정의 함수의 argument는 batch 단위의 데이터 하나만 받는데

 

나의 경우, 특정 조건에 따라 전처리를 다르게 하기 위해 추가 인자를 넣어줘야 하는 상황이었다. 

해결 방법은 간단하게 collate_fn을 위한 새로운 클래스를 생성하여, 추가 인자를 넣어주면 되었다. 

# 예시코드
class MyCollator(object):
    def __init__(self, *params):
        self.params = params
    def __call__(self, batch):
        # do something with batch and self.params
        
.
.
.
# feeding to the dataloader
my_collator = MyCollator(param1, param2, ...)
data_loader = torch.utils.data.DataLoader(..., collate_fn=my_collator)

 

참고 링크: https://intrepidgeeks.com/tutorial/add-parameter-to-torch-collate-fn

 

[torch] collate_fn에 arguments 추가하기

pytorch Dataloader의 collate_fn 매개변수를 조작하면서 얻은 간단한 해결법에 대해 적는다. pytorch는 torch.utils.data.Dataset과 torch.utils.data.DataLoader의 두 가지 도구를 제공한다. Dataset은 input feature x와 label y

intrepidgeeks.com