본문 바로가기
AIML 분야/Generative Model과 GAN

GAN Cocktail: 학습된 GAN 모델을 합친다고?

by 포숑은 맛있어 2021. 7. 2.
반응형
GAN Cocktail: mixing GANs without dataset access

arxiv sanity top recent에 랭크된 논문. 일단 제목이 시선강탈이기는 하다.

https://arxiv.org/pdf/2106.03847v1.pdf     

 

저자들이 예루살렘에 있는 대학이라는데.

이스라엘의 딥러닝을 맛볼 수 있는건가요?

슬픈 DF-VO 논문이 넘 안읽히는 관계로 다른걸 기웃거리고 있다. 좋아.

 

주제

일단 주제는 흔하지 않다. 누가 GAN을 합치긴 합쳐?

조건으로는,

  • 모델 크기가 커지지 않을 것. 모델 두개를 대충 붙이는 게 아니다.
  • 각각의 모델들을 학습시켰던 데이터에 접근할 수 었는 상황을 가정한다.

요런 셋팅에서 시도하는건 본인들이 최초라고 한다.

 

 

다시 말해, data를 모르고 그냥 다른 도메인에서 학습된 여러 GAN 모델들을 한 모델로 merging하는 것!

GAN이 N개가 있다고 하면, 하나의 Union GAN을 만드는 것이다.

이 GAN은 Conditional GAN인데, 그 condition 'c'는 어느 도메인 (어떤 GAN)에서 왔는지를 의미하는 condition이다.

예를 들자면, 모델 하나는 개를 만든다. 다른 모델은 고양이를 만드는 Generator이다. 두 도메인을 합치기 때문에 고양이도 강아지도 만들 수 있는 GAN을 만드는 것이다!

 

키워드로는 continual learning, GAN, transfer learning을 걸어놨다.

 

 

들어가기에 앞서 (내가) 궁금한 것

  • 그럼 여러개로 합치면 더 잘 만드나?
    예를들어 원래 개를 만드는 단독모델의 FID랑, 합친 모델이 만드는 FID 비교하면 뭐가 더 좋지?

 

아래에 나올 3가지는 걍 베이스라인이다.

이분들이 만든 GAN 칵테일이 아니다.

 

Baseline A: Training From Scratch

그냥 간단하게 생각할 수 있는 것.

아 그냥 GAN 학습 잘 된거면, 각각의 generator로 왕창 생성해서 학습하면 되겠네.

= 여러 GAN의 Data를 활용하는 방법

일반 GAN의 Discriminator랑 좀 다른게, class c랑 latent code z도 같이 받는다. 

하지만 실험 결과, GAN으로 만든건 데이터가 무한정 많겠지만 걍 Real data 사용한게 성능이 더 좋다고 한다. 뒤에서 실험란에 설명 나올 예정.

대략 이렇게.

 

Baseline B: TransferGAN

왜 데이터만 쓰나? 모델들의 weights도 쓸 수 있는거 아닌가?

= 모델 Weights도 쓰자

 

간단하다. 그냥 우리 학습할 모델의 Initializer로 후보 GAN들을 쓰자는거다.

TransferGAN이라는 논문의 방법을 썼다.

 

Baseline C: Elastic Weight Consolidation

catastropic forgetting이 발생하는걸 해결하려고 Elastic Weight Consolidation (EWC)라는 논문의 방법도 실험했다고 한다.

하지만 별로 좋지 않았다는 것 같음.

 

* catastropic forgetting 문제를 해결하기 위해 고안됨. 설명 참조

https://adioshun.gitbook.io/deep-learning/online-learning/2018-a-study-on-sequential-iterative-learning-for-overcoming-catastrophic-forgetting-phenomenon-of-a

 

Fisher information을 계산해서, 모델의 파라미터의 성능을 위한 중요성을 평가한다. pretrained model 파라미터(theta)들에 대한 empirical fisher information을 계산하기 위해서, L(x|theta) log likelihood를 가지고 F_i값을 계산한다. log likelihoodd의 편미분의 제곱의 평균이다. 이건 Discriminator output의 BCE loss와 동치다.

그래서 generator의 아웃풋을 discriminator에 같이 넣어가지고 데이터를 generator에서 왕창 생산하고, BCE loss를 구하며, backprop을 통해 미분을 계산할 수 있다.

그리고 아래와 같은 로스를 추가. 

 

 

 

여기서부터는 진짜 GAN 칵테일 레시피를 알 수 있다.

 

Our Approach: GAN Cocktail

transfer learning을 하다보면 하나의 모델 파라미터만 가져오게 된다. 그래서 여기선 모든 모델을 활용할 수 있게 한다.

  • first stage: model rooting for all the input GAN models
  • second stage: model merging by averaging the weights of the rooted models and then fine-tuning them using the original models.

 

1. Model Rooting

original model의 성능을 최대한 유지하면서 여러 GAN을 합치기?!

그러기 위해서, 모델의 weights를 어떻게든 합쳐야한다.

 

여러 모델을 합치는 방법에는 산술연산이 있다. 예를들면 EMA (Exponential Moving Average). 원래 있는 기법이며, 학습중에 checkpoints들을 가지고 모델 weights를 평균내는 방식이다. 이 얘기는, 평균내는 모델들이 전부 common ancestor를 가졌다는 것이다.

 

그런데 우린 지금 여러 모델을 합쳐야한다. 이런 방식을 쓰려고 하면 모델 weights가 다 같은 차원을 가지고 있어야하지 않는가? 그래서 일단 다를수도 있고 같을수도 있는 구조를 가진 N개의 모델들을,

 

1. N개의 모델을 다 같은 구조로 변환하는 작업이 필요하다. (N개의 모델구조가 각각이 다르다면 말이다)
convert 어떻게 하는건지 이따 나올까? 찾아보자.
뭔가 그냥 해당 모델의 output을 가지고 GAN을 학습했을 것 같다. (다른 도메인-페이크, 해당 도메인-리얼?)

그리고나서,
2. 모든 모델이 공유할 수 있는 common ancestor를 만든다. 이게 GAN_r으로, root model이 된다.
위에서 만든 N개의 통일된 구조를 가진 모델 중에서 하나를 그냥 고른 것이다.

3. 이제 GAN_{r->i}은 GAN_r을 가지고 initialization 해주고, GAN_i의 output을 가지고 학습할것!
그러면 source data r에 대해 catastropic forgetting이 필연적으로 발생할 모델이다.
그럼 GAN_r과 GAN_r->i는 아키텍쳐도 동일하고 common ancestor도 가진다.
그럼 이제 common ancestor를 가지기 때문에 EMA를 적용해도 semantically 의미있는 결과를 뽑을 수 있다는거다!

 

아래 그림을 보면, 똑같이 EMA를 가지고 평균을 낸 결과가 있다.

root의 유무를 가지고 EMA로 모델을 합친 것을 visualization 했다. 역시 common root가 중요한가보다.

 

2. Model Merging

이제 우리는 N개의 모델이 있다. 게다가 averaging까지 해줘서 뭔가 여러 클래스를 아우르는 정보가 있다.

그런데 묘하다... 각각 학습에 썼던 클래스들 사이 어중간한 곳에 있는 느낌...? (위에 기괴한 사진)

우리가 합침으로써 원하는건? 모든 데이터셋을 아우를 수 있는 filter와, class-specific한 filter를 구분하는 것.

이를 위해 averaged model들에 대해 adversarial training을 할건데, 처음에 original GAN 모델들을 데이터 소스로서 사용할 것이다.

 

이전 단계에서 N개의 rooted models가 있으면 (GANr, GANr->i), GAN_a라는 평균모델 하나를 만들 수 있다.  다른 GANr->i들의 모델 파라미터값과 GANr의 모델 파라미터를 합치면 총 N개의 모델의 파라미터 평균을 낼 수 있으니까.

물론 걍 평균내기보다도 fisher information matrix의 diagonal을 기반으로 하는 등 더 멋지게 합치는 방법이 있긴 하겠으나 딱히 성능이 오르진 않았기에, 단순 평균을 내는게 괜찮다고 한다.

그리고나서, 아까 맨 위에 나온 1번 식을 minimizing하도록 GAN_a를 (GAN_i를 사용하여) finetuning한다.

 

어우 말로 쓰니까 복잡하다. GANi, GANr, GANr->i, GANa... 그냥 그림으로 그려줬으면 좋을텐데 없다. 아쉽다.

 

 

Results

FID score로 evaluation 할거다. 50k의 generated 이미지를 쓸거다.

 

 

 

 

실험이랑 visualization 몇개 더 있는데, GAN 퀄리티가 그렇게 좋지는 않아보인다.

그래도 다른 transferGAN이나 EWC같은 것 보다는 좋으니 뭐...

 

여튼 신기한 논문이다.

반응형

댓글