치즈의 AI 녹이기

How to Detach specific components in the loss? 본문

인공지능 대학원생의 생활/구글링

How to Detach specific components in the loss?

개발자 치즈 2022. 4. 3. 03:16

loss를 구할 때  detach()된 텐서가 어떻게 작용하는지 궁금해서 구글링 해봤다.

 

질문자는 다음과 같은 상황을 가정한다. 

1. input x에 대하여 3개의 모델(A/ B/ C), 3개의 loss(L1 / L2 / L3)를 구한다. 

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)
b = modelB(a.detach())
b.mean().backward()
print(modelA.weight.grad) #1
print(modelB.weight.grad) #2
print(modelC.weight.grad) #3

c = modelC(a)
c.mean().backward()
print(modelA.weight.grad) #4
print(modelB.weight.grad) #5
print(modelC.weight.grad) #6

 

위 코드 예시에 따르면 결과는 다음과 같다. 

#1. A는 학습되지 않는다.

#2. B는 학습된다.

#3. C는 학습되지 않는다. 

 

#4. A는 학습된다.

#5. B는 학습되지 않는다.

#6. C는 학습된다. 

 

나의 상황은 모델 A, B 둘 다 학습을 시키고 싶었는데, 대충 B(A(x)).detach()+A(x)가 되는 상황이었다. 

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)
b = modelB(a)
(b.detach()+a).mean().backward()
print(1, modelA.weight.grad) #1
print(2, modelB.weight.grad) #2
print(3, modelC.weight.grad) #3

 

즉, 다음과 같은 코드의 결과는 

 

#1. A는 학습된다.

#2. B는 학습되지 않는다.

#3. C는 학습되지 않는다. 

 

결국 이 지점에서 B가 학습이 안되고 있다는 문제를 찾았다. 

 

 

참고 링크 : https://discuss.pytorch.org/t/how-to-detach-specific-components-in-the-loss/13983

 

How to Detach specific components in the loss?

I’m a little confused how to detach certain model from the loss computation graph. If I have 3 models that generate an output: A, B and C, given an input. And, 3 optimizers for those 3 models: Oa, Ob and Oc, for A, B and C resp. Let’s assume I have an

discuss.pytorch.org