본문 바로가기
AIML 분야/Generative Model과 GAN

[논문리뷰] Transformer + GAN에 관한 논문 리뷰

by 포숑은 맛있어 2021. 6. 15.
반응형

논문 두개 대충 봐야지

 

Transformer-based GAN 논문이 있고, Transformer Generator + CNN Discriminator 논문이 있다.

후자가 더 나중에 나왔다.

 

1.

"TransGAN: Two Transformers Can Make One Strong GAN"
[요약]

transformer만 활용한 아키텍쳐에 대한 고찰이 주된 주제.
Transformer 기반의 Discriptor는 그냥은 안좋다. 여러 문제점이 있다.
CNN Discriptor를 사용한 두 경우가 성능이 좋았는데, 이걸 끌어올리려는 노력보다는 그냥 transformer의 한계를 보고싶었던 듯.
그래서 여러 기법으로 영끌하면 AutoGAN 정도의 성능은 transformer만으로도 달성 가능하다는 것을 입증.
따라서 GAN 성능을 끌어 올리고 싶거나, transformer가 가지는 장점을 극대화하여 뭔가 활용하려는 목적으로 읽는거라면 비추천.

github: https://github.com/VITA-Group/TransGAN 

무소속 간지~

 

 

말그대로 트랜스포머만 사용해서 GAN 아키텍쳐를 만들었다는 것 같은데.

이 논문은 3가지 contribution이 있다.

  • Model Architecture: CNN없이 transformer만 사용한 GAN은 이 논문이 최초라고 밝혔다.
  • Training Technique: 어떻게 해야 TransGAN을 더 잘 학습시킬 수 있는지에 대해 말한다.
    augmentation, multitask co-training for generator + Self-Supervised auxiliary loss, self-attention을 위한 localized initialization이 있다고 한다.
  • Performance: CNN기반 SOTA와 유사한 성능이 나왔다고 한다.

 

# 그러면 Image Generation에서 이전에 트랜스포머가 사용된 적이 있을까?

18년, 19년도에 있긴 했는데 GAN은 아니었다.

20년도 논문중 하나는 CNN GAN을 사용해서 context정보가 많은 코드북을 만들고, 이후에 composition하는 부분에 autoregressive transformer 아키텍쳐를 사용했다.

 

 

# 그러면 본론으로 들어가자. 어떻게 만들었나?

 

일단 GAN은 Generator, Discriminator가 있다.

GAN 설명은 검색하면 위조지폐와 경찰 어쩌구가 무진장 많이 나오므로 설명 스킵.

 

 

Transformer Encoder는?

Vaswani et al., 2017의 Transformer encoder를 basic block으로 사용하였다.
이따 그림 보면 알겠지만 "input - LN - (멀티헤드 어텐션) - LN- (MLP - GELU) " 구조이며,
(~~~) 뒤에 residual connection이 추가되어있다.

이게 모든 Generator, Discriminator의 기본 단위라고 생각하면 된다.

 

Generator는? (feat. 메모리!!)

메모리가 왕창 든다. 생각해보자. 32*32 이미지를 만들려고 해도 1024라는 겁나 짱긴 시퀀스를 다뤄야한다.
self-attention연산... 시퀀스 길이의 제곱에 비례할텐데?
-> CNN과 비슷하게, 멀티 스테이지로 만들면 되겠구나!
-> embedding dimension을 줄이며 seq len은 점진적으로 늘리는 방식을 택함
-> 이걸 여기서 memory-friendly transformer라고 부름.

 

Generator 과정

  • Input: Random Noise
  • MLP를 거쳐서 H*W*C 벡터가 나오는데, 이걸 reshape하여 (HW)*C 이렇게 C채널의 HW길이 시퀀스로 만듦.
    8*8을 사용했다.
  • 우리가 원래 아는 그 transformer encoder를 거침.
  • 하지만 이때, 각 스테이지별로 Upsampling 모듈을 적용. high resolution image를 만들기 위함이라고함.
    이는 reshape, pixelshuffle 모듈로 나눠짐.
    • reshape: 1D -> 2D image 형태로 바꿈
    • Pixelshuffle: 원래 16년도에 나온 논문의 방법을 사용. 이걸 거치면 H,W는 각각 2배이고 채널은 1/4로 줄어든다.
    • reshape: 다시 1D로 바꾸기
  • 따라서, 인코더 스테이지가 올라갈수록 이미지는 커지고 채널은 줄어든다. 이렇게 해서 computation과 메모리 문제를 해결.
    그렇게 원래 목표 이미지 크기가 되면 projection을 통해 마지막엔 RGB 3채널이 된다.

 

Discriminator

real/fake 탐지만 하면 그만!

BERT와 똑같이 [CLS] token를 쓰는데 이게 real/fake 용도.

-> 똑같이 이미지를 8*8 토큰으로 쪼갬. 그러면 H*W*C = 8*8*C (Generator의 것과 같은 사이즈)

 

Evaluation

19년도의 AutoGAN의 G, D와 비교했다.

각각 G, D에 transformer를 쓴게 좋을지, 안쓰는게 좋을지 총 4가지 combination이 나올 것이다.

 

# 1. 일단 이건 베이직한 실험. 실험을 통해 발견한 사실을 정리하면,

  1. Transformer G는 굉장한 capacity를 가진다!
  2. 하지만 Transformer D는 구리다. AutoGAN의 D가 나음.

Inception Score와 FID를 보니까 D는 그냥 CNN 기반이 낫다.

# 2. Data Augmentation

Discriminator가 학습이 안된다는걸 위 실험으로 알았으니, 개선해보자.

transformer는 data-hungry한 모델이라고 알려져있다 (Dosovitskiy et al., 2020). 그러면 Augmentation을 활용해볼까?

 

Data Augmentation: DiffAug(20년) 방법을 사용

비교대상: 원래 CNN기반 다른 GAN {WGAN-GP(17년), AutoGAN, StyleGANv2(20년)}, transGAN

 

실험결과를 보면, Augmentation을 하면 TransGAN은 성능 차이가 심하다.

물론 이 성능은 테이블1의 Discriminator에 AutoGAN을 사용한 것보다는 안좋다.

하지만 Discriminator가 학습할 수 있다는걸 의미하지 않을까 싶다. 역시 데이터 양이 중요한 transformer.

 

 

# 3. Co-Training with Self-Supervised Auxiliary Task

transformer G와 super resolution task를 같이 학습시켰다.

아래 그림에서 LR: low resolution, SR: high resolution을 의미한다.

 

이전 연구에서도 일반적인 GAN에서, self-supervised auxiliary task (rotation pred)가 GAN 학습 안정화에 도움이 된다고 했다.

이에따라 transGAN에서도 도입했다는데, Generator loss에 MSE를 추가한 것이다.

제일 위에서 만들어진게 high resolution이고, 초반에 만들어진건 low resolution이었으니까. (처음에 아키텍처 그림 참고)

이렇게

# 4. Locality-Aware Initialization for self-attention

일단 CNN은 natural image smoothness라는  built-in-prior를 가지고 있도록 설계 되었음. 근데 트랜스포머는 그런 bias가 적은 아키텍처란말임.

그런데 ViT같은데서 요즘 나오는 결과를 보면, 트랜스포머가 convolutional structure를 여전히 배우고 있다는 것임. 그럼 그냥 역시 이미지 도메인을 다룰때에는 로컬리티가 중요하단 소리 아님?

저번에 gMLP 논문 읽을때도 비슷한 생각을 했음. self-attention 없는 MLP를 가지고 학습할때도 2D로 visualization 하고나면 locality가 강하다는 느낌을 받았다.

 

그렇다면 이제 고민이 되는거다. inductive bias가 강한 CNN을 택하냐, 어느정도의 유동성을 위해 transformer를 택해야하나?

Esser et al., 2020에서는 low level image strueture를 위해서는 CNN을 유지해야한다고 주장했다.

이 TransGAN 논문에서는 CNN을 도입하는 대신에, 그냥 transformer 기반을 유지하기는 하지만 저 주장을 참고하여 'self-attention의 warm-starting'을 함으로써 보완하려고 한다.

이게 어떤건지는 아래 그림.

 

나머지를 다 마스크 씌워서 allowable region을 조절하는거다.

이전에도 그런식으로 마스킹되지 않은 local neighbor를 더 중점적으로 하게 하는 연구가 있었지만, 여기서는 방법이 다르다고 한다.

이 논문에서는 마스크를 점진적으로 약화시켜서 최종적으로는 self-attention이 global하도록 만들었다.

 

방금 3,4에서 언급한 auxiliary task를 추가하는 것, attention mask initialization 기법의 실험결과는 다음과 같다.

성능 개선이 있다!

# 5. 더 큰 모델로 스케일 업!

아래 결과를 보자.

1. 원래 것

2, 3. embedding dimension 늘리기

4. depth 늘리기 (transformer encoder) + embedding dim 같이 늘리기

 

이렇게 방법을 총동원하여 영혼까지 끌어올리면 transformer만 사용한게 AutoGAN 정도의 성능까지 드디어 올라왔다!

StyleGAN v2가 여전히 막강하긴 하다.  

 

암튼 이런 과정을 통해서 든든한 모델 TransGAN-XL을 얻었으니, SOTA들과 비교하러 가자.

여기부터는 각종 모델들과 성능 비교에 관한 내용이다.

CIFAR10, STL10, CelebA Dataset에서 한다. 각각 32*32, 64*64

ㅇㅕ기 StyleGAN v2가 없는건 함정

 

 

 

 

Conclusion.

pure transformer는 취약점이 있다. 왜냐면 데이터 헝그리!

게다가 트랜스포머를 GAN에 활용할때는, image generation이라는 태스크의 어려움도 있는데다가 GAN training이 stable하지 못하다는 난이도 요소가 존재한다.

그럼에도 저자의 연구를 통해 어느정도 강력한 CNN기반 GAN모델에 필적하는 성능으로 가능성을 보였고, 이를 개선 해볼만한 포인트를 제안한다.

 

  • tokenizing을 걍 8*8 14*14 이런식으로 했는데, 좀 더 잘 할수는 없을까?
  • pretraining을 잘 하면?
  • attention을 좀 더 효율적으로 잘 주는 법? 특히 memory 성능 trade-off 때문에.
  • conditional image generation에서는 어떻게 하지

 

 


그렇다.

이제 다음 논문으로 넘어가자.

arxiv sanity에 한달 기준 top recent에 있길래 업어온 논문이다.

 

얘가 그 Discriminator는 CNN, Generator를 transformer를 사용한 연구다.

대놓고 abstract에서부터 transGAN을 언급하고 있다.

저렇게 조합을 했다니까 뭔가 성능을 잘 올렸을 것 같은 기분이 팍팍 들지 않는가?

 

아카이브 https://arxiv.org/pdf/2105.10189v1.pdf   

저자는 독일인들 (=transGAN과 다른 저자)

연구를 빨리 했구만

 

Combining Transformer Generators with Convolutional Discriminators

이 논문 주제가 Discriminator CNN, Generator Transformer이다.

그럼 Discriminator는 SOTA 모델들 것으로 고정하고, Generator를 transformer vs SOTA 모델의 것을 비교하는게 핵심이겠구나!

 

근데 논문 abstract부터 어조가 그렇게 신나보이지는 않는게, 결과가 그저 그럴 것 같다는 느낌이 드는건 기분탓일까

 

일단 본문 걍 뻔한 얘기 다 스킵함

아키텍쳐

- Generator 똑같음 스킵

- Discriminator는 CNN을 쓰겠지? SNGAN 사용함.

 

여러 Discriminator를 실험했다고함.

StyleGANv2같은것도 실험 했던데 18년도 논문인 SNGAN이 제일 조합이 좋았다고함.

아래는 SNGAN 설명.

"It consists of residual blocks (ResBlock) followed by down-sampling layers using average pooling. The ResBlock itself consists of multiple convolutional layers stacked successively with residual connections, spectral layer normalization [27] and ReLU non-linearity [28]."

 

실험

 

- Discriminator architectures: DCGAN, SNGAN, SAGAN, AutoGAN, SytleGANv2 and TransGAN

- Datasets: CIFAR-10, CIFAR-100, STL-10 resized to 48×48, tiny ImageNet [7] resized to 32×32.

- Evaluation Metric: FID, IS

 

와 사실 메트릭 뭐 계산하는지 몰랐는데 설명 나옴

* Inception Score (IS)

IS computes the KL divergence between the conditional class distribution and the marginal class distribution over the generated data

 

* Fr´echet Inception Distance (FID)

FID calculates Fr´echet distance between multivariate Gaussian fitted to the intermediate activations of the Inception-v3 network [38] of generated and real images.

 

 

 

읽다보니 지루하다.

그냥 있는것들 조합해서 돌려보고 분석했다는게 끝이지, 성능을 끌어올리기 위해 transformer Generator를 새로 고안했다거나 학습 기법을 새로 적용했다는 내용이 아닌 것으로 보임. 왜냐면 실험을 보면 이렇기 때문에.

 

 

실험1: Discriminator topology가 CIFAR10에서 어떤 역할과 영향이 있는지 알아보기 위해 실험을 함

음... 이미 Discriminator는 TransGAN에서 transformer 기반 애들이 안좋은게 밝혀지지 않았나?

그걸 해결하려고 열심히 연구한 논문이 TransGAN인데요

 

여기서는 data augmentation, auxiliary tasks, locality-aware initialization 다 안 썼다고 한다.

 

  • 만약 좋은 D(스타일겐v2 처럼)를 쓰면 Transformer G가 빨리 수렴하지 못해서 성능이 구리다.
  • 근데 D가 작으면 (DCGAN) Generator의 학습에 별로 도움이 안된다
  • SNGAN이 잘 나온댔는데, 거기서 Spectral Normalization을 D에 쓴게 성능에 있어서 큰 도움을 준건 아니다.
    (라고 주장하는데 별로 동의하진 않는다)
  • 모델 크기는 중요하다. 근데 어느정도 크기에 가면 수렴

 

 

뭐지...

아니 augmentation 등 다른거 안쓰면 트랜스포머 D가 구린거 이미 밝혀졌는데 뭘 실험하는거지...?

SNGAN이 좋으면 그걸로 Discriminator 고정하고 TransGAN을 어떻게 modification하는게 중요하지 않나 싶다

 

 

Frequency 분석도 했다고 한다.

합친게 제일 real에 가까웠다고 한다.

 

 

드는 생각

 

1. 제목은 저렇게 잡아놓고 딱히 Transformer에 대한 고찰을 한건 아님

2. 그렇다고 해서 SNGAN이 왜 좋은지 이유를 설명하는 부분도 없음 (내가 못본걸수도 있음. 논문에 하도 평이한 문장이 많아서 스킵하면서 읽음)

3. augmentation 등을 사용하지 않고 비슷하다는건 장점이 아님. 왜냐면 이건 트랜스포머 Discriminator의 단점보완용이었기 때문. vanilla도 성능수준 비슷한거야 이미 TransGAN에 나왔음. 차라리 (augmentation같은) 추가적인 것들을 연구하여 성능을 더 올리려는 시도를 통해 분석을 했어야함

4. 전반적으로, 여러 조합을 실험한 후 그 결과를 두고 1차원적인 분석이 끝인 것 같음

 

 

불만을 좀 많이 적긴 했는데 독일인이니까 못보겠지? 보고있으면 쏘리

근데 제목은 기대감 만땅이면서 별로 재미 없어서 어쩔 수 없었음.

abstract부터 쎄함을 느꼈어야했음..

대체 얘가 왜 top recent (심지어 1달기준)에 있는건지 이해할수가 없음

 

 

할말은 많지만 이만 줄이겠다

 

누가 이걸로 멋진 연구를 해서 논문 내줬으면 좋겠다!

반응형

댓글