본문 바로가기
딥러닝 어쩌구/Challenge 기록

[연구일지] ALBEF 공식 코드를 뜯어요 (VLP, text-to-image retrieval)

by 포숑은 맛있어 2022. 7. 26.
반응형

 

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

 

공식 코드

https://github.com/salesforce/ALBEF

 

GitHub - salesforce/ALBEF: Code for ALBEF: a new vision-language pre-training method

Code for ALBEF: a new vision-language pre-training method - GitHub - salesforce/ALBEF: Code for ALBEF: a new vision-language pre-training method

github.com

논문

https://arxiv.org/pdf/2107.07651.pdf

 

 

[내가 중점을 둔 것]

  • Retrieval.py
    : image text retrieval 문제만 풀거고, 특정 데이터셋에서 잘 되게 하는 게 목표이다. 그래서 딱히 pretrain을 돌릴 이유는 없다.
  • bert encoder, image encoder
    : 바꾸려고요
  • 몇가지 디버깅

 

[설치]
데이터는 논문에서 얘기하는거 말고 내가 사용하려는 것으로 변경하였음. image, caption pair가 json file에 정의되어있다.

  • 깃허브 requirements 맞춰주기, transformers 4.8.1 / timm 0.4.9
  • 특히 transformer 버전 반드시 맞춰야함. 안 그러면 에러난다

 

데이터셋 관련 얘기는 챌린지 끝나고 기회가 되면 써야지

이 분야 처음 하는 주제에 일단 설렁설렁 하는중이다

 

코드 바꾼 것들이랑 자세한 설명 (technical report?) 그런건 챌린지 끝날 때 기회가 되면 올려보겠다. 


Issues

issue 1: evaluation()에서 out of memory

[문제 상황]

이제 셋업을 끝냈으니 상쾌한 마음으로 학습을 돌려봤다.
엥? 얍삽하게 1일 뒤에 에러뜸. 심지어 checkpoint 저장이 evaluation 코드 다음에 있어서 몽땅 날렸습니다. 우효!
text, image를 모두 불러온 후 similarity 계산을 하는데 이때 (텍스트는 괜찮지만) image 불러오는데서 gpu 메모리가 터진다. 계산기 때려보면 대충 82GB정도 소요되기 때문.

울면서 htop으로 메모리 찍어보면서 확인했다.

 

[해결방법]
Retrieval.py 파일 고치기.

이미지 읽어올때 .cpu()로 옮겨준다. similarity 계산할때만 .cuda()에 올려서 처리 하였음.

 

 

issue 2: 학습 중에 자꾸 메모리가 증가한다

[문제 상황]

메모리가 야금야금 증가한다. 자라나라 모리모리

 

-> 정상이랍니다. (깃허브 이슈에 저자 왈.)

고려 해서 적절한 batch size를 선택하였음

MoCo queue 때문인 것 같은데 이거 미리 메모리 할당하지 않나.. 음

 

 

배치 사이즈 관련
train batch size=32 하면 GPU당 35GB 소모 → 펑!
그런데 배치 사이즈가 queue size를 나눌때 숫자가 떨어져야 해서, 2의 제곱이어야함.
→ 배치 사이즈를 울면서 16으로 수정....... (configs/Retrieval.yml)

 

 

pth 저장할때...

main process에서 하도록 하자. 그냥 저장하면 나중에 load시에 serialization error가 뜬다. 못 쓰는 파일인거.

 


어떤 코드를 돌려야 합니까?

무작정 시작했으면 가장 먼저 고민해야할 사항이다. VLP분야는 downstream task들도 많고 그러니깐 더 헷갈린다.

 

1) Pretraining

  • 원래 시나리오라고 하면 먼저 Pretrain.py를 돌려야함. 그 다음에 Downstream task에서 돌린다. (Retrieval.py, VQZ.py 등)
  • ALBEF 공식
    • text, image encoder 각각에는 pretrained BERT와 ImageNet-1K를 사용
    • 그리고 모델 전체 pretrain은 몇가지 Image-text pair 데이터셋들을 사용한다. (논문 및 공식코드 pretrain.yml 참고)
    • 위에서 말한 pair dataset들에 pretrain한 것을 공식 깃허브에서 체크포인트로 지원한다.
  • 내 경우
    • 챌린지에 규칙에 의하면 image encoder는 ImageNet을 포함하여 몇가지 지정해준 데이터셋에서만 pretrain 가능하다고 한다. text encoder에는 제한이 없다. 따라서 우리는 공식 코드에서 지원하는 checkpoint 파일을 사용해서는 안된다.
    • 의문: 그러면 Pretrain.py 따로 돌려야하나?
      -> 팀원 왈: 굳이?
    • 그래도 혹시 모르니까 pretrain과 finetuning setting이 다른지 확인해보는걸로.

2) Image-text Retrieval

  • Retrieval.py를 돌리면 된다. checkpoint 안 불러오고 그냥 from scratch로.
  • 공식 코드: MSCOCO, Flickr30k를 사용
  • 우리는 챌린지 데이터에서 학습해야함. config의 Retrieval.yml에 json 파일 갈아끼워주자

 

그러면 돌려야하는 Retrieval.py 생김새 대충 보기

main()

  • 분산처리 해주고
  • create_dataset(), create_loader()
  • tokenizer 정의, ALBEF 모델 정의
  • checkpoint 있으면 visual_encoder, visual_encoder_m (momentum model) positional embedding 학습했던걸 interpolation함. input image 달라질 수 있어서.
  • epoch 단위로 train() 돌리고, 다 끝나면 evaluation()하여 score 계산.
    • 그런데 우리 데이터셋은 val, test 구분이 없어 똑같기 때문에 하나 지워야함
  • 이걸로 itm_eval()하여 val_result를 뽑으며, 이 dict를 출력하는게 log_stats. 이게 log.txt에 append mode로 json dump
  • r_mean 기준으로 best보다 값이 클 경우 checkpoint를 새롭게 저장하도록 구현되어있음

evaluation()

  • 모델이야 뭐 알아서 잘 돌아갈테니 eval을 봐주자.
  • text_encoder를 통해 text_embeds, feats, atts를 뽑음
  • visual_encoder를 통해 img_feats, embeds (projection 거친 것. 논문에서 설명한 linear transform)를 뽑음.
  • 데이터가 크다보니 여기서 에러가 발생했던 것. 텍스트는 괜찮지만 Image feature가 문제다.
    특히나 torch.cat()을 할때는 붙인 사이즈만큼의 메모리를 재할당 해야하기 때문에 필요 이상으로 많이 잡아먹는다.

  • 이를 해결하기 위해 위에서 말한 것 처럼 .cpu() 해서 gpu 메모리에 저장하지 않게 바꿨었고, 그래도 사이즈가 크기 때문에 미리 배열 정의해서 채워 넣도록 인덱싱 해버림. validation set 사이즈는 이미 알고 있으니깐.
  • 그리고 text to image retrieval만 필요하고 image-to-text는 안 할거니까 eval 코드에서 빼줬다.

  • 그러면 이제 대충 괜찮아지긴 한데 또 문제는 distributed로 돌리는데 GPU 개수가 늘어나면 또 메모리를 더 씀.
    GPU 4개는 램 메모리 터지고 2개로 하면 멀쩡함. (음?) 설마 공간을 따로 잡는건지 뭔지...
    -> 팀원 분이 해결해주심

해야할 것 & 성능 개선 방향

코드를 어떻게 고쳐야 성능이 오를까?

의견을 나누다보니 나는 주로 모델 내부를 건들려고 하고, 팀원들은 모델 외부에서 해결을 하려는 경향이 있었다.

모두들 이 분야가 처음이라 짧은 기간동안 다양하게 많은 시도를 해야했기 때문에 다같이 나눠서 진행하게되어 좋았다.

아무튼 그래서 나는 주로 모델 내부 관련된 얘기를 할거다.

 

[Dataset 관련, 학습 방법론적인 것] -> 이건 팀원들이 진행함

데이터 규모가 커서 에폭을 많이 못 돌린다.

continual learning이라든가 데이터 전처리 등 모델 외적인 방법으로 성능향상을 시도하고 계신다.

 


[Optimizer와 Scheduler]

그냥 내가 꼼수를 부렸다. 에폭을 적게 돌리는게 왜 문제인건가. 어차피 데이터는 왕창 큰데. 스케줄러만 잘 쓰면 되는거 아닌가.

  • optimizer: AdamW
    • 딱히 안 바꾸는 게 나을 듯. ViT에서는 SGD 대신에 AdamW 쓰는게 국룰이다.
  • scheduler가 epoch based로 구현되어있음
    • 현재: cosine scheduler이며 학습 에폭이 5로 설정되어있음
    • lr을 점점 줄이거나, SGDR처럼 여러번 튕겨서 앙상블 효과를 주는 등의 기법으로 성능을 올릴 수도 있는데 지금 데이터는 너무 커서 에폭을 여러번 못돌리므로 그런 효과를 주기 어려움
    • 일정 Iter마다 scheduler step 하나씩 해주면 데이터가 많아 여러 에폭을 돌릴 수 없는 문제가 어느정도 해결될 것임
  • 기타 코드 수정: 에폭, 혹은 일정 iteration마다 pth 저장하도록 변경 필요. 데이터셋이 너무 커서 날리면 아깝다.

 

[Language Model 관련]

이미지는 특정 도메인이 아니고서야 거기서 거기지만 텍스트는 경우가 달랐다.

이 챌린지 성능향상의 핵심은 언어처리라고 생각했다.

 

ALBEF에서 쓰는게 text encoder가 그냥 pretrained BERT base (uncased) 버전이다.

그런데 우리는 알파벳이 포함되어있긴 하지만 순수하게 '영어'라고 보기 어렵다.

 

그렇다면 먼저 BERT-base-uncased의 능력이 어느정도인지 이해할 필요가 있다.

학습한 데이터가 무엇인지, 이 모델의 빈칸 추론 실력은 어떠한지 등등 살펴봤는데 챌린지 데이터셋에 사용하기에 무리가 있는 것으로 판단하였다.

참고로 이거 찾아보는 와중에 hugging face라는 사이트를 알게 되었는데, nlp쪽은 이런데서 pretrain 모델들을 땡겨서 쓰기 좋게 되어있구나 싶었다. 웹에서 인퍼런스도 가능하다.

https://huggingface.co/bert-base-uncased?text=I+wanna+go+%5BMASK%5D+right+now.+I%27m+so+tired 

 

bert-base-uncased · Hugging Face

👁 multimodalart/latentdiffusion 🤌 clip-italian/clip-italian-demo 👁️‍🗨️ flax-community/Multilingual-VQA 🤯 flax-community/clip-reply-demo 🌍 Gladiator/Text-Summarizer 🚀 erc/entity-referring-classifier ⚡ ysharma/text-to-image-to-vi

huggingface.co

바로 이렇게.

아무튼 조사 결과로는,

1) 우리 도메인에 맞는 pretrained model은 없었고,

2) 그나마 vocab.txt 리스트 참고해서 가장 많이 포함하는걸 (+vocab 전체 사이즈도 고려해서 적당히) 고르긴 했다.

3) 그런데 그렇게 해도 우리 데이터에서 필요한 vocab들의 상당수가 unknown token으로 뭉뚱그려지는 문제가 있다.

 

그래서 생각한 해결 방법으로는,

1) 우리 데이터에 필요한 vocab들을 적당히 정의해본다. 다 넣진 않았고 몇백개 정도 추가해줬고, tokenizer에 반영하도록 코딩했다.

2) Retrieval.py 돌릴 때 MLM loss를 추가하여 챌린지 데이터셋의 언어 또한 bert가 잘 커버해주도록 바꾼다.

 

1만 진행했을때는 성능 향상 잘 모르겠는데, 2를 하니까 확실히 오른 게 확인 되었다.

 

 

[Image Encoder 관련]

ALBEF의 loss는 3가지다. ITC, ITM, MLM. 심지어 downstream task에서는 MLM을 뺀 2가지다. (물론 위에서 언급했듯이 retrieval.py에도 MLM 추가했다.)

image encoder의 domain transfer가 완전히 잘 되지 않을 수도 있지 않나 싶다. generalization이 잘 되는 모델임에는 틀림 없으나 image representation을 새 데이터에도 잘 되게 하는 것은 아니라 생각한다.

 

그리고 구현 코드를 봤는데 너무 익숙한데요. 그냥 초창기 16*16 뭐시기 vit 원 논문 유행하기 시작할때 돌던 코드다. 그냥 vanilla ViT를 (ViT-B 12 layers) 가져다가 DeiT pretrained parameters (ImageNet 1K 사용하여 학습됨) 불러와 사용한다.

 

위 두가지 이유 때문에 image encoder를 개선해서 성능을 더 올릴 여지가 있다고 판단하였다.

1) swin transformer와 같이 좀더 연산이 줄거나 개선된 vision transformer 아키텍쳐를 사용하거나,

2) masked auto encoder 논문처럼 1 layer vit decoder를 붙여서 patch를 맞추도록 하여 (마치 NLP의 MLM같은) MSE loss로 self-supervised learning을 할 수 있겠다.

연산량이 늘어날 수도 있겠지만 더 좋은 정보를 준다는건 수렴 속도가 빨라진다는 의미도 되기 때문에 이렇게 두가지를 고려했다. 실제로 MAE 논문에서 학습속도가 3배 정도 빨라졌다는걸 본 것 같다.

 


Pretrain 거르고 Retrieval 코드만 본다.

 

Pretrain.py와 Retrieval.py의 차이?

  • pretrain에서는 from models.xbert import BertConfig, BertForMaskedLM의 BertforLM을 쓴다.
    BertforLM.bert 이런식으로 text_encoder를 불러와서 씀. 아래에서 말할 BertModel이랑 head를 묶어서 정의한게 BertforLM이기 때문이다.
  • retrieval에서는 from models.xbert import BertConfig, BertModel의 BertModel을 사용.
    .bert 할 필요 없다.

 

LM전용과 그냥 인코더를 이렇게 나눠놔서 그렇다.

@add_start_docstrings(
    """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

 

 

아무튼 Retrieval.py를 보기 위해 model_retrieval.py의 ALBEF 클래스 본다.

-image encoder: vit

-text encoder: bert

-위에 것들 각각을 정해진 embedding size로 맞추기 위한 proj. (nn.Linear) 논문에서 linear transform 한다는게 이걸 의미함

-그리고 위에 4개의 것들에 대해 각각 momentum model을 가지므로 총 8개.

- 추가로 image queue, text queuee도 정의되어있다.

 

참고로 text encoder + fusion model이 합쳐서 오리지널 BERT_base 인코더이다.

6 레이어씩 절반 잘라서 앞부분은 text encoder로 사용, 뒷부분은 fusion용으로 사용한다. 그러다보니 cross modal attention을 계산하려고 image encoder 부분 것을 fusion 중간에 넣어주는거다.

그러므로 text encoder를 정의한 이후에, forward 할때에는 mode='text' or 'fusion'으로 구분해서 쓴다.

 

 

forward()image, text, alpha, idx를 받는다.

 

1.

먼저 unimodal loss. momentum model을 가지고 distillation loss 발생시켜주고 (i->t, t->i)

 

2.

그 다음에 fusion 부분.

positive sample에 대해 output 만들고

negative sample에 대해 output 만든다.

negative는 batch size만큼 돌아가면서 1개씩 뽑을건데, b번째 데이터 입장에서 얘랑 가장 헷갈렸던 것을 더 높은 확률로 뽑는다. weight들이 multinomial 확률값이다. 각 이미지에 대해 negative text를 뽑을때는 weights_i2t를 참고하며, 텍스트에 대한 네거티브 이미지를 뽑을때는 반대이다.

아무튼 pos, neg 샘플들의 last hidden state에서 임베딩의 0번째 부분을 뽑아 itm head에 넣어 itm loss를 발생시킨다. image-caption pair가 positive인지 negative인지 맞추는 classification loss이다. 당연히 Pos sample들은 1, neg는 0이라고 맞춰야한다.

 

이렇게 itc, itm loss만 발생시켜 학습한다.

mlm은 pretrain시에만 쓴다. retrieval.py 돌릴때는 text encoder finetuning 되는건 맞는데 MLM을 안 쓰는 것 뿐인 듯.

(나는 사용하도록 바꿨지만 아무튼.)

 

    def forward(self, image, text, alpha, idx):
        
        image_embeds = self.visual_encoder(image) 
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 
        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,                      
                                        return_dict = True, mode = 'text')            
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 

        idx = idx.view(-1,1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)  
        pos_idx = torch.eq(idx, idx_all).float()       
        sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)     

        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image) 
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
            image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,             
                                                return_dict = True, mode = 'text')    
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
            text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

            if self.distill:               
                sim_i2t_m = image_feat_m @ text_feat_all / self.temp 
                sim_t2i_m = text_feat_m @ image_feat_all / self.temp   

                sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
                sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 

        sim_i2t = image_feat @ text_feat_all / self.temp 
        sim_t2i = text_feat @ image_feat_all / self.temp           

        if self.distill:
            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
        else:
            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean()
            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean()   

        loss_ita = (loss_i2t+loss_t2i)/2

        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

        ###=================================###
        # forward the positve image-text pair
        output_pos = self.text_encoder(encoder_embeds = text_embeds, 
                                        attention_mask = text.attention_mask,
                                        encoder_hidden_states = image_embeds,
                                        encoder_attention_mask = image_atts,      
                                        return_dict = True,
                                        mode = 'fusion',
                                       )            
        with torch.no_grad():
            bs = image.size(0)      
            weights_i2t = F.softmax(sim_i2t[:,:bs]+1e-4,dim=1)
            weights_t2i = F.softmax(sim_t2i[:,:bs]+1e-4,dim=1)

            mask = torch.eq(idx, idx.T)
            weights_i2t.masked_fill_(mask, 0)
            weights_t2i.masked_fill_(mask, 0) 

        # select a negative image for each text
        image_embeds_neg = []    
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   

        # select a negative text for each image
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text.attention_mask[neg_idx])
        text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   
        text_atts_neg = torch.stack(text_atts_neg,dim=0)      

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)     

        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
        image_atts_all = torch.cat([image_atts,image_atts],dim=0)

        output_neg = self.text_encoder(encoder_embeds = text_embeds_all, 
                                        attention_mask = text_atts_all,
                                        encoder_hidden_states = image_embeds_all,
                                        encoder_attention_mask = image_atts_all,      
                                        return_dict = True,
                                        mode = 'fusion',
                                       )                         

        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
        vl_output = self.itm_head(vl_embeddings)            

        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)     

        return loss_ita, loss_itm

 

 

음.

암튼 ALBEF retrieval 모델 동작 원리 대략 알았고, 수정도 해놔서 이제 Image encoder를 건드리고 싶다.

정말 다행스럽게도 image encoder 바꾸기 매우 쉬워보인다.

 

이미지 인코더는 정말 딱 image embed 만드는 용도로만 쓴다.

중간 레이어를 끌어와서 쓰고 그런건 아니니까 feature embedding output size만 맞춰주면 끝.

vit.py만 갈아끼우면 된다.

 

참고로 pretrain의 경우 DeiT Pretrain을 가져오지만 ViT는 그렇지 않다.

 

 

-

 

text encoder 보자.

 

_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

 

 

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )

 

token 추가

 

add tokens: 이거 참고

https://github.com/huggingface/transformers/issues/1413

 

Adding New Vocabulary Tokens to the Models · Issue #1413 · huggingface/transformers

❓ Questions & Help Hi, How could I extend the vocabulary of the pre-trained models, e.g. by adding new tokens to the lookup table? Any examples demonstrating this?

github.com

 

기존의 30522개 토큰에다가 새로운 토큰을 추가했고, 늘어난만큼 뒷부분에 임베딩 추가 생성하는데 random init한 값이 채워진다.

NLP가 처음이라 모르겠는데, 새로 추가한 단어들에 대해서는 zero init을 해야하는지 random init을 해야하는지 의문이긴 하다.

 

인터넷에서 병음 긁어다가 407가지 있는거 추가했더니 토큰 개수가 30522 -> 30696가 된다. 중복 걸러져서.

 

 


인퍼런스 얘기 잠깐 추가.

i2t, t2i를 하도록 구현 되어있는데 난 text to image만 쓴다.

텍스트를 보고 가장 적합해보이는 이미지를 어떻게 뽑아올까?

이걸 다시 풀어서 쓰자면, 모든 텍스트에 대해 인퍼런스를 해야하는데, 각 텍스트마다 매칠될법한 topk개의 이미지가 있다는거다.

 

ALBEF 논문에서 itm loss 발생시키는 head 부분을 생각해보자.

hard example mining 하여 가장 헷갈릴 것들을 골라 그중에서 positive인지 negative match인지 판단했다. 여기서 hard example은 image, text 각각의 unimodal embedding을 가지고 similarity를 구하여 가장 유사한 것들 topk를 고르는 방식이었다. (ITC loss 발생시키는 구간에서 계산)

유사도는 similarity matrix t2i = matmul(text embedding, image embedding)으로 구했으며 여기서 topk similarity를 가지는 샘플들의 인덱스를 구한다. (k 개수는 config에) 이걸 topk image라 생각하는거다.

그러니 fusion파트(BERT 뒷단)에서 cross attention 할때는 해당 topk image의 feature들을 긁어와 사용하는거다.

오.. 하나 맞추는데 연산이 많겠네요? 네 많네요...

반응형

댓글