[논문 리뷰] TPT :: Test-Time Prompt Tuning for Zero-Shot Generalization in Vision-Language Models
논문 Test-Time Prompt Tuning for Zero-Shot Generalization in Vision-Language Models을 읽고, 최대한 다른 분들이 보셨을 때 이해가 잘 되도록 요약 및 핵심 내용 설명 위주로 Paper Review를 진행할 예정입니다.
논문에서의 순서를 일정하게 지키는 것보다, 각 목차에서 필요할 것 같다 생각하는 부분을 먼저 언급하거나, 뒷 부분에서 언급할 수 있습니다. 보시고 피드백도 자유롭게 주셨으면 좋겠습니다!
⏰ Abstract & Introduction
CLIP과 같은 VL pre-training 기술의 발전은 image-text 쌍에 대한 훈련을 통해 visual concept을 인코딩할 수 있다. 그렇기에 특정 작업 데이터 없이도 downstream task에 zero shot으로 적용할 수 있었다.
그러나 이때 사용되는 hand-crafted prompts는 도메인에 특화된 경험이 필요하며, 특히 이렇게 만든 prompt가 최적(optimal)이 아닐 수도 있다는 문제점이 있다.
그래서 요즘은 prompt를 직접 학습시키는 prompt tuning 방법을 많이 제안하고 있다. (CoOp, CoCoOp)
TPT가 나오기 이전의 Prompt Tuning
prompt embedding을 모델 입력의 한 부분으로 사용하고, 모델의 parameter를 조정하는 것처럼, prompt를 fine-tuning하는 방식이다. Hand-crafted prompt 방식보다 좋은 성능을 보였다. 그러나 이 방법들에도 한계점이 있다.
- prompt는 training data의 분포와 task에 국한되어 있기 때문에 일반화에 한계가 있다.
- training data에 annotation이 필요하다. 이는 비용이 들고, zero-shot task에도 적절하지 않다.
Their Approach & Result Preview 🌟핵심🌟
주어진 test sample(단일 샘플임)만을 사용하여, 즉석에서 prompt를 조정하고, 각 작업에 적응되어 어떠한 task-specific training dataset이나 annotation 없이도 zero-shot 일반화를 하도록 한다.
논문의 전반적인 내용이 궁금하거나, 이해하기 어려울 때 이렇게 approach 부분을 보면 도움이 많이 되는 것 같다.이 부분은 Sorn Chottananurak의 유튜브 영상을 참고하면 아주 좋을 것 같다. (2:22 ~) 본 논문의 TPT 적용 내용을 augmented view를 생성하는 것부터 confidence selection을 하는 부분까지 시각적으로 매우 설명이 잘 되어 있는 것 같다.
➡️ Image Classification Task에서의 TPT zero-shot 일반화
핵심은 confidence selection이다. 단일 샘플에 대해 여러 무작위 증강 뷰(random augmented view)를 생성하고, 다양한 증강 뷰에 걸쳐 일관된 예측을 할 수 있도록 text prompt를 최적화한다. 이때 증강 뷰 중 high prediction entropy인 view들을 지워 confidence selection을 진행한다.
high prediction entropy : entropy는 불명확함을 나타내는 단어로, high prediction entropy는 말그대로 ‘높은 불명확성 예측’, 즉 low confidence를 의미한다. low confidence인 view는 image의 중요한 정보가 부족하기 때문에 entropy를 증가시킨다. 즉 이들을 제거하여 예측 확률을 높이는 것이다.
본 Task는 [1] Natural Distribution Shift와 [2] Cross Dataset Generalization에 대해 실험된다.
- Natural Distribution Shift (데이터 분포 변화 상황)
- zero-shot을 사용했을 때 최대 3.6%, few-shot을 사용했을 때 최대 6.9%까지 성능이 향상되었다.
- 추가적인 data와 annotation 없이 SOTA를 달성하였다.
- Unseen Class에서의 Cross Dataset Generalization (데이터 교차 일반화)
- 역시 추가적인 data와 annotation 없이 SOTA를 달성하였다.
➡️ Context-dependent Visual Reasoning Task에서의 TPT
test sample : support image 두 세트와 평가를 위한 query image로 이루어져 있다.
두 세트의 support image는 HOI 개념의 존재(presence)와 부재(absence)를 의미하는 것이다. 이때 support image는 추가적인 데이터로 간주되지 않는다. (training data가 아니기 때문이다.)
HOI(Human-Object Interaction) : 사람과 물체 사이의 상호작용을 의미한다. 에시로 ‘자전거 타기’가 있다. 존재의 경우, 사람이 자전거를 타고 있어 인간과 객체(자전거)의 상호작용이 보이는 사진이다. 부재의 경우, 자전거만 있는 사진이나 사람이 자전거 옆에 서 있지만 타지 않는 모습을 예시로 들 수 있다.
이러한 query image를 더 잘 분류하는 데에 TPT가 사용된다.
- SOTA보다 4.1% 능가하는 성능을 보였다.
- 역시 추가적인 data와 annotation을 사용하지 않았다.
Contributions
- 어떠한 training data와 annotation를 사용하지 않고, prompt를 최적화하는 TPT를 제안하였다. 이는 zero-shot 방식으로 단일 test sample에 prompt tuning을 수행하는 최초의 연구이다.
- TPT 방식으로 confidence selection 방식을 채택하였다. low-confidence의 예측을 유발하는 증강 뷰를 필터링 하여 ‘엔트로피 최소화’를 향상시켰다.
- natural distribution shift, cross-dataset generalization, context-dependent visual reasoning 같은 경우에서 광범위한 실험을 진행했고, 기존의 zero-shot CLIP의 성능을 뛰어 넘었다.
⏰ Related Works
Prompt Tuning의 기존 연구들
- CoOp : training data에서 prompt를 tuning하면서, downstream task에 대해 좋은 성능을 보였다.
- CoCoOp : CoOp의 문제점이었던
out-of-distribution data
에 대해 일반화 성능이 부족한 점을 보완하였다. 모델 input에 따라 조건화하여 문제를 해결하였다.
그러나 이들은 모두 training data와 annotation의 접근을 요구하여, zero-shot 지식 전달을 제한한다.
Generalization under data distribution shifts
신뢰할 수 있는 ML 모델들은 real-world의 데이터 분포 변화에서도 잘 수행되어야 한다. Pre-trained VL model의 경우, zero-shot 방식에서 다양한 데이터 분포 변화가 있는 downstream task에 잘 일반화된다. TPT는 이에 대해 다음과 같이 작용한다.
- TPT는 downstream task 또는 target dataset에 맞게 적용시킨다는 개념보다, CLIP 기본 모델 자체를 더 좋게 향상시키는 것이 목표이다.
- 네트워크의 출력이 작은 변동(augmentation)에도 불변하도록 만든다. (Consistency regularization based methods)
Test-time Optimization
말그대로 test time에 대한 최적화를 계산한다. TENT와 같이 배치별 예측 확률 분포의 entropy를 최소화하면서 특정 훈련 과정에 의존하지 않고 다양한 모델에 적용할 수 있다. 그러나 이의 문제점은 한 개 이상의 test sample이 필요하다는 것이다.
또 다른 challenge는 올바른 parameter group을 선택해야 한다는 것이다. BN Layer는 이미지 데이터에서 도메인 차이를 잡지만, 이를 통해 모델에 국한되어 좋지 않은 방안이 된다.
그러나 TPT는 이를 모두 해결한다.
- 모델을 그대로 유지(freeze)하면서, text prompt를 최적화한다. 즉, 사전 훈련된 feature를 왜곡하지 않고, pre-trained의 zero-shot generalization 능력을 보존한다.
- confidence selection 방식을 선택하여 단일 sample에 대한 해결을 하였다.
⏰ TPT: Test-Time Prompt Tuning
기존의 Prompt Tuning (using downstream training data)
fine tuning의 문제점 : domain-specific behaviors(도메인에 특호되어 있는) 결과를 초래하고, generalization 성능을 더 잃게 하여 견고성도 떨어진다.
그러나 prompt를 사용하면, 모델의 외부에서 작용하기 때문에 pre-trained feature를 왜곡하지 않아 성능을 유지할 수 있게 한다.
TPT의 Formula
위의 연산과 다르게, 어떠한 annotation과 training data를 사용하지 않고, prompt를 최적화하는 것을 볼 수 있다.
TPT for image classification
- Loss : 비지도 학습 손실 사용 (label이 없기 때문)
- 다양한 증강 뷰에 걸쳐 모델의 예측 일관성을 증진
- 방법
- 무작위 augmentation 집합 A를 활용하여 무작위 증강 뷰 생성
- entropy(불명확성)을 최소화 ⇒ confidence selection
high-entropy(low confidence) prediction을 생성하는 view들은 중요한 정보들이 부족하기 때문에 걸러내야 한다. 이때 prediction entropy < τ인 sample에 대해 살려둔다.
- 엔트로피를 기준으로 신뢰도가 높은 sample을 선택하는 방법: cutoff percentile ρ를 사용한다. 이는 정렬 후 상위 ρ만큼의 샘플들을 사용하는 것이다.
TPT for context-dependent visual reasoning
binary label로 결과가 도출되고, 수동적으로 label을 직접 붙여줬어야 하지만 TPT는 최적의 label token cls
를 직접 학습할 수 있다. 아래는 M개의 support image가 있을 때 context-dependent reasoning의 prompt p이다.
방법과 구조는 다음과 같다.
- positive에
1
, negative에0
label을 부여한다. - binary label token
cls={cls1, cls2}
와 prompt p를 동시에 조정한다. - 각 이미지마다 CLIP의 text input으로
T={T1, T2 | Ti={p, cls_i}}
를 조립하여 사용한다.
⏰ Experiment
[1] Robustness to Natural Distribution Shifts
- Dataset : 분포 외 데이터(OOD - out of distribution)로 간주되었던 ImageNet 변형체 사용
- prompt ensemble과 기존의 few-shot prompt tuning(CoOp, CoCoOp)들의 성능을 능가하는 것을 확인할 수 있다. 초록색
- CoOp과 CoCoOp에 TPT를 적용한 것이 더 높은 성능을 보여주었다. 노란색
- ❓ CoCoOp + TPT보다 CoOp + TPT의 성능이 더 좋은 이유에 대한 생각 : CoCoOp은 CoOp의 일반화 능력 부족을 보완하기 위해 고안된 것으로, 일반 class에 대한 성능을 분석해보면 CoOp > CoCoOp인 것을 볼 수 있다.
- 이때, TPT가 CoOp의 일반화 성능을 높여주기 때문에, CoCoOp보다 높았던 기본 class에서의 CoOp 성능을 유지할 수 있기 때문이라고 생각한다.
- 🗣️ 선배 분께서 공유해주신 의견
- 실험적으로 16개의 샘플만 뽑아서 활용하기 때문에, 정확하게 분석하기는 어려운 것 같다. 뽑힌 샘플의 이미지가 얼마나 유사했는지에 따라 이미지 기반으로 prompt를 뽑아내는 CoOpOp의 경우 성능 등락이 많이 변할 것 같다.
[2] Cross-Datasets Generalization
➡️ 실험 1. source는 ImageNet, target은 fine-grained dataset
- TPT는 zero-shot 방식이기 때문에 ImageNet을 학습하지 않는다.
- 그런데도 ImageNet으로 학습한 CoCoOp과 비슷한 성능을 내었다.
➡️ 실험 2[more challenging]. source는 specialized fine-grained dataset, target도 fine-grained dataset
source와 target의 중복이 없다. 다음은 baseline CLIP에 대해 얼마나 개선이 되었는지를 표시한 그래프이다.
- TPT를 제외하고는 CoOp과 CoCoOp 모두 성능이 오히려 감소한 것을 확인할 수 있다. (Cross-Dataset에서)
- 즉, TPT의 Cross-Dataset에 대한 일반화 능력이 뛰어난 것을 볼 수 있다. (양수)
[3] Context-dependent visual reasoning on Bongard-HOI
비교를 위한 세가지 방법
- CNN-baseline : Bongard-HOI training data로 훈련된 단순한 분류기
- 전체 sample들을 이진 output과 매핑하여 query image가 적합한 concept을 포함하고 있는지 확인
- Meta-baseline : Bongard-HOI의 샘플들을 few-shot task로 간주하고, meta 목표와 함께 훈련 데이터로 훈련 시켜, 새로운 task에 빠르게 모델이 적응할 수 있도록
- HOITrans : 이전 최고 방법
- 다양한 HOI 검출 벤치마크에서 최신 정확도를 달성하는 트랜스포머 기반 HOI 검출 모델
- 쿼리 이미지의 검출된 HOI를 support image와 비교함으로써 Bongard-HOI를 해결
TPT의 성능이 기존 가장 성능을 좋게 보였던 HOITrans보다 평균 4.13% 증가한 것을 볼 수 있다. 특히 unseen act / unseen obj에서도 65.48%의 성능을 보이는 것이 일반화 능력이 뛰어난 것을 확인할 수 있다.
⏰ Ablation Study
[1] Test-time Optimization
parameter group을 대상으로 test-time optimization 연구를 진행한 결과, Prompt를 최적화 하는 것이 가장 높은 성능 향상을 가져오는 것을 볼 수 있다.
- ❓ 반면, Visaul Encoder 최적화가 가장 낮은 성능을 보였는데, 이는 pre-trained feature가 왜곡되기 때문이라고 생각한다.
- 🗣️ 선배 분께서 공유해주신 의견
- text의 경우는 test-time시에 N개로 고정되어있어 variation이 크지 않지만, 이미지의 경우 샘플하나하나가 모두 달라 variation이 text에 비하여 매우 크기 때문에 학습이 어려운 경우라고 생각
[2] Confidence Selection의 효과
- Confidence selection의 사용 여부에 따라 성능을 확인
- Cutoff ρ-percentile에 따라 성능을 확인
[3] 효율성과 성능(정확도) trade off
- TPT의 효율성에 영향을 미치는 요소 : augmented view 수, 역전파에 의해 유발되는 runtime과 memory 사용량을 증가시키는 optimization step 수
- augmented view가 8 개일 때도, CLIP에 2% 이상의 성능 향상을 보인다.
- TPT step을 1에서 2로 증가시키면 0.4% 정도 증가를 하지만, 그 이상을 수행해도 유의미한 성능 향상은 없었다.
- 따라서 step을 1로 설정하여 위 실험들을 진행하였다.
⏰ Conclusion
Contribution
- 적응형 prompt를 단일 test sample로 즉석에서 학습할 수 있는새로운 prompt tuning 방법인 TPT(Test-time Prompt Tuning)을 개발하였다.
- Natural Distribution Shift에 대한 강인성과 Cross-Dataset의 일반화, Context-dependent visual reasoning에 대한 성능 향상을 입증하였다.
- Zero-shot generalization 능력 향상을 training data와 annotation 없이 이뤘다.
Limitation
test time의 prompt를 최적화할 때, 한 단계의 역전파가 필요하다. 또한 TPT는 단일 test sample에 대해 여러 augmented view를 생성하기 때문에 메모리 비용을 증가시킨다.
⏰ 본 논문에 대한 나의 생각
prompt tuning에 대한 논문들을 읽어왔지만, test-time에 zero-shot으로 즉각 처리할 수 있는 TPT는 너무 새로웠다. 인상 깊었던 부분은 Related Work 부분이었는데, zero-shot에 대한 고민과 test-time optimization에 대한 아이디어들을 합쳐 더 좋은 성능을 낼 수 있는 prompt tuning을 고안해낸 것이 존경스러웠다. 데이터 전처리로만 사용될 줄 알았던 augmentation이 test-time에서 단일 test sample을 여러 sample을 받아온 것처럼 사용할 수 있다는 것이 가장 신기하였다. 무언가를 새롭게 만들어내기 위해 단순히 있는 방법들을 합친 것이 아니라, 각각의 파트에 대해 새로운 관점으로 접근하는 것이 중요한 것 같다.
Leave a comment