티스토리 뷰
Paper/Representation
[Paper 리뷰] Balanced-Wav2Vec: Enhancing Stability and Robustness of Representation Learning through Sample Reweighting Techniques
feVeRin 2025. 6. 12. 17:27반응형
Balanced-Wav2Vec: Enhancing Stability and Robustness of Representation Learning through Sample Reweighting Techniques
- Self-Supervised Learning model은 mode collapse, dimension collapse로 인해 expressiveness가 떨어짐
- Balanced-Wav2Vec
- Over-represented mode의 emergence를 suppress 하는 balanced-infoNCE loss를 도입
- Wav2Vec 2.0의 highly-skewed codebook distribution을 방지하고 stable convergence를 지원
- 논문 (INTERSPEECH 2024) : Paper Link
1. Introduction
- Self-Supervised Learning (SSL)은 unlabeled data를 pre-training에 활용하여 labeled data에 대한 의존성을 줄이고 small labeled data 만으로도 fine-tuning이 가능함
- 특히 speech SSL의 경우, Wav2Vec 2.0, VQ-Wav2Vec과 같은 contrastive learning, HuBERT, WavLM과 같은 discrete hidden unit prediction, Data2Vec과 같은 continuous representation prediction을 고려할 수 있음
- 한편으로 SSL model의 성능을 향상하기 위해서는 다음의 condition을 만족해야 함:
- Alignment
- Input space의 slight perturbation이 output space에 대한 dynamic change를 발생시키지 않아야 함
- 즉, representation model은 positive pair를 feature space의 proximate point로 mapping 해야 함 - Diversity/Uniformity
- Feature vector는 feature space에 evenly distribute 되어야 함
- Alignment
- BUT, 대부분의 SSL representation model은 feature space에서 mode의 subeset에만 fit 하는 mode collapse, dimension collapse 문제를 가짐
- 이로 인해 model의 expressiveness, computational efficiency, generalization이 저하됨
-> 그래서 SSL의 mode collapse 문제를 해결한 Balanced-Wav2Vec을 제안
- Balanced-Wav2Vec
- Codebook distribution의 high-skewness가 mode collapse에 관여한다는 것을 검증
- Over-represented mode의 emergence를 suppress 하기 위해 balanced-infoNCE loss를 도입
< Overall of Balanced-Wav2Vec >
- Balanced-infoNCE loss를 Wav2Vec 2.0에 적용한 speech SSL model
- 결과적으로 기존보다 뛰어난 성능을 달성
2. Background
- Model Structure of Wav2Vec 2.0
- Wav2Vec 2.0은 raw waveform $X\in\mathcal{X}$를 input으로 사용함
- 이를 기반으로 length $T$의 latent speech representation $Z=[z_{1},...,z_{T}]\in\mathcal{Z}$, quantized representation $Q=[q_{1},...,q_{t}]\in\mathcal{Q}$, contextual representation $C=[c_{1},...,c_{T}]\in\mathcal{C}$를 생성함
- 구조적으로는 3가지 module로 구성됨:
- Feature encoder $f:\mathcal{X}\rightarrow \mathcal{Z}$는 convolutional neural network로 구성되어 raw waveform의 length를 reduce 하기 위해 subsampling을 수행하고 local feature를 추출함
- Quantization module $h:\mathcal{Z}\rightarrow \mathcal{Q}$는 Gumbel softmax를 사용하여 latent speech representation을 quantize 함
- Context network $g:\mathcal{Z}\rightarrow \mathcal{C}$는 Transformer를 사용하여 latent speech representation의 sequential dependency를 고려한 contextual representation을 생성함
- Training Strategy of Wav2Vec 2.0
- Wav2Vec 2.0은 contrastive loss를 사용하는 mask prediction method로 training 됨
- 이때 다음 objective를 minimize 함:
(Eq. 1) $\mathcal{L}_{W2V2}=\mathcal{L}_{info}+\alpha_{1}\mathcal{L}_{div}+\alpha_{2}\mathcal{L}_{L2}$
- $\mathcal{L}_{info}$ : contrastive loss에 해당하는 infoNCE loss
- $\mathcal{L}_{div}$ : auxiliary diversity loss
- $\mathcal{L}_{L2}$ : $L2$ loss, $(\alpha_{1},\alpha_{2})$ : tuning factor - Wav2Vec 2.0의 objective function은 infoNCE loss를 사용하여 positive pair $(q_{t},c_{t})$의 mutual information을 maximize 하고 diversity loss를 사용하여 codebook diversity를 increase 함
- 이때 다음 objective를 minimize 함:
- InfoNCE Loss
- Speech signal의 continuous nature와 adjacent sample 간의 correlation으로 인해 mode collapse에 susceptible 함
- 이때 infoNCE loss는 same utterance 내에서 negative, positive sample을 모두 sampling 하여 in-utterance diversity를 향상함:
(Eq. 2) $\mathcal{L}_{info}(q_{t},c_{t})=-\log \frac{\exp\left(\text{sim}(q_{t},c_{t})/\kappa\right)}{\sum_{\tilde{q}\in Q_{t}}\exp\left(\text{sim}(\tilde{q}_{t},c_{t})/\kappa\right)} $
- $\text{sim}(\cdot, \cdot)$ : cosine-similarity, $\kappa$ : tuning factor
- $Q_{t}$ : 하나의 positive vector $q_{t}$와 $K$ negative quantized vector로 구성된 set - Negative quantized vector는 utterance 내의 masked time step에서 uniformly sample 됨
- Diversity Loss and Others
- Contrastive loss 만으로는 sufficient diversity를 달성하기 어려움
- 따라서 codebook distribution의 negative entropy를 나타내는 diversity loss를 contrastive loss와 함께 사용하여 codebook diversity를 향상함:
(Eq. 3) $\mathcal{L}_{div}=\frac{1}{GV}\sum_{i=1}^{G}\sum_{j=1}^{V}\hat{p}_{i,j}\log \hat{p}_{i,j}$
- $G$ : group 수, $V$ : codebook entry 수 - $\alpha_{1}$을 증가시키면 codebook entropy가 증가하지만, Wav2Vec 2.0의 overall loss landscape 측면에서 diversity loss의 dominance가 증가하면 stability가 감소함
3. Method
- Problem Setup
- Stochastic gradient descent는 $\mathcal{L}_{W2V2}$를 minimize 하는 parameter를 찾기 위해 사용됨
- 여기서 model은 training set $\mathcal{D}=\{(X_{i})\}_{i=1}^{M}$을 사용하여 empirical risk $\hat{R}(\mathcal{D})$를 minimize 하도록 training 됨
- $T_{i}$ : $X_{i}$에서 derive 된 quantized vector sequence $Q_{i}$의 length, $M$ : mini-batch size - BUT, Empirical risck는 over-represented mode에 higher weight를 assign 하므로 codebook diversity가 감소할 수 있음
- Mini-batch에서 얻어진 empirical risk는 $\mathcal{V}$ 내의 individual codebook entry에 대한 expectation으로 주어짐:
(Eq. 4) $\hat{R}(\mathcal{D})=\mathbb{E}_{\mathcal{D}}\left[\mathcal{L}_{info}(q_{i,t},c_{i,t})\right] = \mathbb{E}_{v\in\mathcal{V}}\mathbb{E}_{(i,t)\in\mathcal{D}_{v}}\left[\mathcal{L}_{info}(q_{i,i},c_{i,t})\right]$
$\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,=\sum_{v=1}^{V}p(v)\sum_{(i,t)\in\mathcal{D}_{v}}p(i,t|q_{i,t}=v)\mathcal{L}_{info}(q_{i,t},c_{i,t})$
- $\mathcal{D}_{v}=\cup_{i=1}^{M}\left\{(i,t)|q_{i,t}=v,\forall t\in[T_{i}] \right\}$
- $[T_{i}]=\{1,2,...,T_{i}\}, [M]=\{1,2,...,M\}$ - Wav2Vec 2.0의 training scenario에서 $\hat{p}(v)$와 $\hat{p}(i,t|q_{i,t}=v)$는 frequentist approach를 통해 mini-batch level에서 estimate 됨:
(Eq. 5) $\hat{p}(v)=\frac{N_{v}}{N}\,\,\,\text{and}\,\,\,\hat{p}(i,t|q_{i,t}=v)=\frac{1}{N_{v}}$
- $N_{v}=\sum_{i\in[M]}\sum_{t\in[T_{i}]}\mathbf{1}(q_{i,t}=v),N=\sum_{i\in[M]}T_{i}$
- $\mathbf{1}(\cdot)$ : indicator function - (Eq. 4)의 $\hat{p}(v)$와 $\hat{p}(i,t|q_{i,t}=v)$를 substitute 하면:
(Eq. 6) $\hat{R}(\mathcal{D})=\sum_{v=1}^{V}\frac{N_{v}}{N}\sum_{(i,t)\in\mathcal{D}_{v}}\frac{1}{N_{v}} \mathcal{L}_{info}(q_{i,t},c_{i,t})=\frac{1}{N}\sum_{(i,t)\in\mathcal{D}}\mathcal{L}_{info}(q_{i,t},c_{i,t})$
- Mini-batch에서 얻어진 empirical risk는 $\mathcal{V}$ 내의 individual codebook entry에 대한 expectation으로 주어짐:
- Original Wav2Vec은 (Eq. 6)과 같이 각 iteration에서 model parameter를 update 하여 estimated empirical risk를 estimate 하지만, 다음의 한계점이 존재함:
- 아래 그림과 같이 over-represented mode는 $\hat{p}(v)$가 highly skew 되게 만듦
- $\hat{p}(v)$는 in-utterance sample 간의 correlation이 높은 mini-batch의 limited size로 estimate 됨
- $\hat{p}(v)$의 high variance는 model training의 stability를 저해하고 high skewness는 codebook distribution에서 over-represented mode로 probability mass를 shift 함
- 결과적으로 $\hat{p}(v)$의 high skewness는 iterative parameter update를 통해 mode collapse를 intensify 하므로 논문은 이를 방지하기 위해 $\hat{p}(v)$를 smoothing함
- 여기서 model은 training set $\mathcal{D}=\{(X_{i})\}_{i=1}^{M}$을 사용하여 empirical risk $\hat{R}(\mathcal{D})$를 minimize 하도록 training 됨
- Balanced-infoNCE Loss
- Mini-batch에서 estimate 된 codebook distribution $\hat{p}(v)$를 smooth 하기 위해 $0\leq \tau \leq 1$의 smoothing factor를 사용함
- 이를 통해 over-represented mode의 emergence를 suppress 할 수 있음:
(Eq. 7) $\hat{p}(v;\tau)=\left(\frac{N_{v}}{N}\right)^{\tau}$ - (Eq. 4)의 $\hat{p}(v;\tau)$와 $\hat{p}(i,t|q_{i,t}=v)=\frac{1}{N_{v}}$를 substituting 함으로써 empirical risk를 recalculate 할 수 있음:
(Eq. 8) $\hat{R}(\mathcal{D})=\sum_{v=1}^{V}\left(\frac{N_{v}}{N}\right)^{\tau}\sum_{(i,t)\in\mathcal{D}_{v}} \frac{1}{N_{v}}\mathcal{L}_{info}(q_{i,t},c_{i,t})$
$\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,=\frac{1}{N} \sum_{(i,t)\in\mathcal{D}}\left(\frac{N_{v}}{N}\right)^{\tau-1}\mathcal{L}_{info}(q_{i,t},c_{i,t})=\frac{1}{N}\sum_{(i,t)\in\mathcal{D}}\mathcal{L}_{B}$ - Balanced-infoNCE loss $\mathcal{L}_{B}$는 infoNCE loss $\mathcal{L}_{info}$에 sample weight $\left(\frac{N_{v}}{N}\right)^{\tau-1}$을 multiply 한 것과 동일함
- $\tau=1$일 때 balanced-infoNCE loss는 original infoNCE loss와 동일함
- $\tau=0$일 때 sample weight는 mini-batch 내에서 class를 share 하는 sample frequency에 대해 inversely decrease 하므로 over-represented model에 smaller weight가 assign 됨
- 여기서 논문은 group 수를 1로 설정하고, 2개 이상인 경우 group의 average sample weight를 사용함
- 이를 통해 over-represented mode의 emergence를 suppress 할 수 있음:
- Interpretation
- Balanced-infoNCE loss는 sparsely occurring mode에 higher weight를 assign 함
- 특히 balanced-infoNCE loss는 codebook distribution에 대해 non-informative prior (uniform distirbution)을 가정하는 Bayesian approach로 볼 수 있음
- 이때 codebook distribution에 대한 prior knowledge가 부족하므로 insufficient reason principle에 따라 uniform distribution으로 설정하는 것은 reasonable 함
4. Experiments
- Settings
- Dataset : LibriSpeech, CommonVoice 5.1
- Comparisons : Wav2Vec 2.0
- Results
- In-domain dataset (LibriSpeech)에 대해 Balanced-Wav2Vec의 성능이 가장 우수함
- Out-of-domain dataset에 대해서도 Balanced-Wav2Vec이 가장 뛰어난 성능을 달성함
반응형
'Paper > Representation' 카테고리의 다른 글
댓글