치즈의 AI 녹이기

GAN 이해하기 본문

인공지능 대학원생의 생활/구글링

GAN 이해하기

개발자 치즈 2022. 3. 22. 17:46

GAN의 구조 

  • Discriminator 입장 : 들어오는 데이터(fake or real)에 대하여 fake/real을 구분하는 task 수행.
  • Generator 입장 : 최대한 real에 가까운 데이터 생성. 

 

GAN Loss 이해하기 

GAN Loss function

 

  1. real을 real이라고 하는 경우 - 0
  2. real을 fake라고 하는 경우 - log0으로 -무한대가 됨. 
  3. fake를 real이라고 하는 경우 - log0으로 -무한대가 됨
  4. fake를 fake라고 하는 경우 - 0

결국, Discriminator 입장에서는 최대값인 0으로 수렴하는 것이 목표이고, 

Generator 입장에서는 3, 4번 경우에 따라 Discriminator를 최대한 속여 최소값인 -무한대로 수렴하는 것이 목표가 된다. 

 

수렴포인트 : generator가 real data의 분포를 학습하고, discriminator는 들어오는 real data에 대해서 real인지 fake인지 1/2확률로 찍는(구별을 못하는) 목표를 가짐. 

 

학습과정 : discriminator가 우선적으로 input data에 대한 정보를 가지기 위해 k번 학습하고 generator가 1번 학습하는 방식으로 번갈아

 

Generator 학습 시 문제 해결

초기 Discriminator을 학습할 때, 기존 loss를 이용하면 작은 gradient에서 시작하기 때문에 학습이 잘 안되는 문제가 있어 loss를 수정함. 

바뀐 Generator loss

 

  1. fake를 real이라고 하는 경우 - 0
  2. fake를 fake라고 하는 경우 - -log0으로 무한대가 됨. 

결국, Discriminator 입장에서는 이전과 동일하게 최대값인 무한대로 수렴하는 것이 목표이고, 

Generator 입장에서는 이전과 동일하게 Discriminator를 최대한 속여 최소값인 0으로 수렴하는 것이 목표가 된다. 

 

GAN 학습 시 문제 해결

Mode Collapse 

여러개의 data class 분포 각각을 mode라고 한다. 

이 mode들 중 generator가 하나만 집중해서 잘 그리는 현상을 collapse되었다고 한다. 

기본 GAN Loss는 Generator 입장에서 "진짜 같은 그림을 그려라"가 목표가 아닌, "Discriminator를 속여라"라는 목표를 가지고 있기 때문이다. 

 

WGAN

Wasserstein GAN은 discriminator한테 잘 구분하라는 목적보다 generator한테 잘 생성해내라는 목적을 주는 아이디어를 가짐. 즉, discriminator 대신 생성한 이미지의 feature 평균과 실제 이미지의 feature 평균이 유사하도록(차이를 좁히도록) 만듦. 

여기서 f란, feature 평균을 계산하는 계산 함수로 보면 된다.

 

바뀐 Generator loss

 

Conditional GAN (CGAN)

해당 모델은 discriminator 입장에서 단순히 real/fake를 판별하는 것이 아닌,

클래스 정보를 받아 어떤 클래스를 판별해야 하는지 미리 알 수 있음.