본문 바로가기
AIML 분야/KD, MultiTask, Foundation Model, Fusion등

Born-Again Neural Network 코드 조지기

by 포숑은 맛있어 2021. 1. 4.
반응형

의식의 흐름대로 아래 코드를 보는 중.

재택근무라 사람이 없어 외로운 나머지 여기에 주절거리며 코드를 리뷰하는 중이다.

 


 

사실 아직 시작도 안했다.

이 글을 다 쓸때 쯤이면 코드를 이해했겠지?

 

난 이미지 도메인을 다룰 게 아니어서 그냥 적당히 구경하고 넘어가서 바로 짤거다.

 

github.com/nocotan/born_again_neuralnet.git

 

nocotan/born_again_neuralnet

Unofficial pytorch implementation of Born-Again Neural Networks. - nocotan/born_again_neuralnet

github.com

 

Readme를 보니까 train.py, inference.py 를 돌리는구나. train.py를 보러가자.

 

train.py

updater를 봐야하는구나. 깔끔해서 딴건 어려운 게 없어보인다!

for문에 gen 변수값을 늘려가며 루프를 돌리는걸 봐서, 이게 student 만드는 횟수인 것 같다.

 

updater.py

파라미터는 이렇게 세개.

모델, last 모델, optimizer가 있다. BAN에서 teacher-student 모델은 동일하다. 그래서 모델 하나 불러와서 파라미터로 전달.

그리고 n_gen과 gen이 있는데 이건 코드를 더 봐야겠다. n_gen은 main 실행할 때 디폴트가 3. 3번 student를 만드나?

update()

  1. 처음에는 따라할 teacher가 없으니 target만 사용. 이때 모델의 학습이 teacher가 되는 시나리오다. 걍 스킵하고 pretrained 불러와도 되겠지.
  2. 그 다음 사이클부터는 teacher를 사용하여 kd_loss()를 넣고 backprop한다.
    kd_loss()는 nn.KLDivLoss() + student의 CE loss로 정의되어있다. 난 딴거 쓸거다.

 

정리하면, (student) model 업뎃을 위한 로스.

그리고 teacher model은 updater.last_model이다.

register_last_model()

get_model()로 새로 만들어서 weight 저장한다.

이 코드에서는 걍 MLP를 인스턴스화해서 넣어놨다.

 

register_last_model()를 이 클래스 함수에서는 안 건드리니 다시 train.py를 보러가자.

student model의 에폭이 전부 돌고나면 아래 코드가 실행된다.

  • 기존 모델은 (loss가 가장 적은 에폭의 모델) last_model로 지정된다.
    이건 updater에서 모델 인스턴스 새로 만들어 학습된 weight을 할당해준다.
  • 모델을 새로 인스턴스화하여 student 모델로 지정.

 

뭐야 괜히 쫄았잖아? 왕간단하다.

 

그냥 원래 내꺼 비디오 코드에다가 updater만 하나 얹어야겠다. train 코드에 for 루프 하나 추가하고.

loss는 label 여부에 따라 여러가지 로스 적용해서 실험 해봐야겠다.

그리고 난 pretrained model 쓸거라 첨부터 트레이닝할필욘 없음. resume_gen = 1로 시작하면 됨.

 

이제 딱 1사이클만 돌려서 테이블을 뽑아보러 가자.

반응형

댓글