티스토리 뷰

반응형

Balanced-Wav2Vec: Enhancing Stability and Robustness of Representation Learning through Sample Reweighting Techniques


  • Self-Supervised Learning model은 mode collapse, dimension collapse로 인해 expressiveness가 떨어짐
  • Balanced-Wav2Vec
    • Over-represented mode의 emergence를 suppress 하는 balanced-infoNCE loss를 도입
    • Wav2Vec 2.0의 highly-skewed codebook distribution을 방지하고 stable convergence를 지원
  • 논문 (INTERSPEECH 2024) : Paper Link

1. Introduction

  • Self-Supervised Learning (SSL)은 unlabeled data를 pre-training에 활용하여 labeled data에 대한 의존성을 줄이고 small labeled data 만으로도 fine-tuning이 가능함 
    • 특히 speech SSL의 경우, Wav2Vec 2.0, VQ-Wav2Vec과 같은 contrastive learning, HuBERT, WavLM과 같은 discrete hidden unit prediction, Data2Vec과 같은 continuous representation prediction을 고려할 수 있음
    • 한편으로 SSL model의 성능을 향상하기 위해서는 다음의 condition을 만족해야 함:
      1. Alignment
        - Input space의 slight perturbation이 output space에 대한 dynamic change를 발생시키지 않아야 함
        - 즉, representation model은 positive pair를 feature space의 proximate point로 mapping 해야 함
      2. Diversity/Uniformity
        - Feature vector는 feature space에 evenly distribute 되어야 함
    • BUT, 대부분의 SSL representation model은 feature space에서 mode의 subeset에만 fit 하는 mode collapse, dimension collapse 문제를 가짐
      - 이로 인해 model의 expressiveness, computational efficiency, generalization이 저하됨

-> 그래서 SSL의 mode collapse 문제를 해결한 Balanced-Wav2Vec을 제안

 

  • Balanced-Wav2Vec
    • Codebook distribution의 high-skewness가 mode collapse에 관여한다는 것을 검증
    • Over-represented mode의 emergence를 suppress 하기 위해 balanced-infoNCE loss를 도입

< Overall of Balanced-Wav2Vec >

  • Balanced-infoNCE loss를 Wav2Vec 2.0에 적용한 speech SSL model 
  • 결과적으로 기존보다 뛰어난 성능을 달성

2. Background

- Model Structure of Wav2Vec 2.0

  • Wav2Vec 2.0은 raw waveform $X\in\mathcal{X}$를 input으로 사용함
    • 이를 기반으로 length $T$의 latent speech representation $Z=[z_{1},...,z_{T}]\in\mathcal{Z}$, quantized representation $Q=[q_{1},...,q_{t}]\in\mathcal{Q}$, contextual representation $C=[c_{1},...,c_{T}]\in\mathcal{C}$를 생성함
    • 구조적으로는 3가지 module로 구성됨:
      1. Feature encoder $f:\mathcal{X}\rightarrow \mathcal{Z}$는 convolutional neural network로 구성되어 raw waveform의 length를 reduce 하기 위해 subsampling을 수행하고 local feature를 추출함
      2. Quantization module $h:\mathcal{Z}\rightarrow \mathcal{Q}$는 Gumbel softmax를 사용하여 latent speech representation을 quantize 함
      3. Context network $g:\mathcal{Z}\rightarrow \mathcal{C}$는 Transformer를 사용하여 latent speech representation의 sequential dependency를 고려한 contextual representation을 생성함

- Training Strategy of Wav2Vec 2.0

  • Wav2Vec 2.0은 contrastive loss를 사용하는 mask prediction method로 training 됨
    • 이때 다음 objective를 minimize 함:
      (Eq. 1) $\mathcal{L}_{W2V2}=\mathcal{L}_{info}+\alpha_{1}\mathcal{L}_{div}+\alpha_{2}\mathcal{L}_{L2}$
      - $\mathcal{L}_{info}$ : contrastive loss에 해당하는 infoNCE loss
      - $\mathcal{L}_{div}$ : auxiliary diversity loss
      - $\mathcal{L}_{L2}$ : $L2$ loss, $(\alpha_{1},\alpha_{2})$ : tuning factor
    • Wav2Vec 2.0의 objective function은 infoNCE loss를 사용하여 positive pair $(q_{t},c_{t})$의 mutual information을 maximize 하고 diversity loss를 사용하여 codebook diversity를 increase 함
  • InfoNCE Loss
    • Speech signal의 continuous nature와 adjacent sample 간의 correlation으로 인해 mode collapse에 susceptible 함
    • 이때 infoNCE loss는 same utterance 내에서 negative, positive sample을 모두 sampling 하여 in-utterance diversity를 향상함:
      (Eq. 2) $\mathcal{L}_{info}(q_{t},c_{t})=-\log \frac{\exp\left(\text{sim}(q_{t},c_{t})/\kappa\right)}{\sum_{\tilde{q}\in Q_{t}}\exp\left(\text{sim}(\tilde{q}_{t},c_{t})/\kappa\right)} $
      - $\text{sim}(\cdot, \cdot)$ : cosine-similarity, $\kappa$ : tuning factor
      - $Q_{t}$ : 하나의 positive vector $q_{t}$와 $K$ negative quantized vector로 구성된 set
    • Negative quantized vector는 utterance 내의 masked time step에서 uniformly sample 됨
  • Diversity Loss and Others
    • Contrastive loss 만으로는 sufficient diversity를 달성하기 어려움
    • 따라서 codebook distribution의 negative entropy를 나타내는 diversity loss를 contrastive loss와 함께 사용하여 codebook diversity를 향상함:
      (Eq. 3) $\mathcal{L}_{div}=\frac{1}{GV}\sum_{i=1}^{G}\sum_{j=1}^{V}\hat{p}_{i,j}\log \hat{p}_{i,j}$
      - $G$ : group 수, $V$ : codebook entry 수
    • $\alpha_{1}$을 증가시키면 codebook entropy가 증가하지만, Wav2Vec 2.0의 overall loss landscape 측면에서 diversity loss의 dominance가 증가하면 stability가 감소함

3. Method

- Problem Setup

  • Stochastic gradient descent는 $\mathcal{L}_{W2V2}$를 minimize 하는 parameter를 찾기 위해 사용됨
    • 여기서 model은 training set $\mathcal{D}=\{(X_{i})\}_{i=1}^{M}$을 사용하여 empirical risk $\hat{R}(\mathcal{D})$를 minimize 하도록 training 됨
      - $T_{i}$ : $X_{i}$에서 derive 된 quantized vector sequence $Q_{i}$의 length, $M$ : mini-batch size
    • BUT, Empirical risck는 over-represented mode에 higher weight를 assign 하므로 codebook diversity가 감소할 수 있음
      1. Mini-batch에서 얻어진 empirical risk는 $\mathcal{V}$ 내의 individual codebook entry에 대한 expectation으로 주어짐:
        (Eq. 4) $\hat{R}(\mathcal{D})=\mathbb{E}_{\mathcal{D}}\left[\mathcal{L}_{info}(q_{i,t},c_{i,t})\right] = \mathbb{E}_{v\in\mathcal{V}}\mathbb{E}_{(i,t)\in\mathcal{D}_{v}}\left[\mathcal{L}_{info}(q_{i,i},c_{i,t})\right]$
        $\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,=\sum_{v=1}^{V}p(v)\sum_{(i,t)\in\mathcal{D}_{v}}p(i,t|q_{i,t}=v)\mathcal{L}_{info}(q_{i,t},c_{i,t})$
        - $\mathcal{D}_{v}=\cup_{i=1}^{M}\left\{(i,t)|q_{i,t}=v,\forall t\in[T_{i}] \right\}$
        - $[T_{i}]=\{1,2,...,T_{i}\}, [M]=\{1,2,...,M\}$
      2. Wav2Vec 2.0의 training scenario에서 $\hat{p}(v)$와 $\hat{p}(i,t|q_{i,t}=v)$는 frequentist approach를 통해 mini-batch level에서 estimate 됨:
        (Eq. 5) $\hat{p}(v)=\frac{N_{v}}{N}\,\,\,\text{and}\,\,\,\hat{p}(i,t|q_{i,t}=v)=\frac{1}{N_{v}}$
        - $N_{v}=\sum_{i\in[M]}\sum_{t\in[T_{i}]}\mathbf{1}(q_{i,t}=v),N=\sum_{i\in[M]}T_{i}$
        - $\mathbf{1}(\cdot)$ : indicator function
      3. (Eq. 4)의 $\hat{p}(v)$와 $\hat{p}(i,t|q_{i,t}=v)$를 substitute 하면:
        (Eq. 6) $\hat{R}(\mathcal{D})=\sum_{v=1}^{V}\frac{N_{v}}{N}\sum_{(i,t)\in\mathcal{D}_{v}}\frac{1}{N_{v}} \mathcal{L}_{info}(q_{i,t},c_{i,t})=\frac{1}{N}\sum_{(i,t)\in\mathcal{D}}\mathcal{L}_{info}(q_{i,t},c_{i,t})$
    • Original Wav2Vec은 (Eq. 6)과 같이 각 iteration에서 model parameter를 update 하여 estimated empirical risk를 estimate 하지만, 다음의 한계점이 존재함:
      1. 아래 그림과 같이 over-represented mode는 $\hat{p}(v)$가 highly skew 되게 만듦
      2. $\hat{p}(v)$는 in-utterance sample 간의 correlation이 높은 mini-batch의 limited size로 estimate 됨
    • $\hat{p}(v)$의 high variance는 model training의 stability를 저해하고 high skewness는 codebook distribution에서 over-represented mode로 probability mass를 shift 함
      - 결과적으로 $\hat{p}(v)$의 high skewness는 iterative parameter update를 통해 mode collapse를 intensify 하므로 논문은 이를 방지하기 위해 $\hat{p}(v)$를 smoothing함

Codebook Distribution

- Balanced-infoNCE Loss

  • Mini-batch에서 estimate 된 codebook distribution $\hat{p}(v)$를 smooth 하기 위해 $0\leq \tau \leq 1$의 smoothing factor를 사용함
    • 이를 통해 over-represented mode의 emergence를 suppress 할 수 있음:
      (Eq. 7) $\hat{p}(v;\tau)=\left(\frac{N_{v}}{N}\right)^{\tau}$
    • (Eq. 4)의 $\hat{p}(v;\tau)$와 $\hat{p}(i,t|q_{i,t}=v)=\frac{1}{N_{v}}$를 substituting 함으로써 empirical risk를 recalculate 할 수 있음:
      (Eq. 8) $\hat{R}(\mathcal{D})=\sum_{v=1}^{V}\left(\frac{N_{v}}{N}\right)^{\tau}\sum_{(i,t)\in\mathcal{D}_{v}} \frac{1}{N_{v}}\mathcal{L}_{info}(q_{i,t},c_{i,t})$
      $\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,=\frac{1}{N} \sum_{(i,t)\in\mathcal{D}}\left(\frac{N_{v}}{N}\right)^{\tau-1}\mathcal{L}_{info}(q_{i,t},c_{i,t})=\frac{1}{N}\sum_{(i,t)\in\mathcal{D}}\mathcal{L}_{B}$
    • Balanced-infoNCE loss $\mathcal{L}_{B}$는 infoNCE loss $\mathcal{L}_{info}$에 sample weight $\left(\frac{N_{v}}{N}\right)^{\tau-1}$을 multiply 한 것과 동일함
      1. $\tau=1$일 때 balanced-infoNCE loss는 original infoNCE loss와 동일함
      2. $\tau=0$일 때 sample weight는 mini-batch 내에서 class를 share 하는 sample frequency에 대해 inversely decrease 하므로 over-represented model에 smaller weight가 assign 됨
    • 여기서 논문은 group 수를 1로 설정하고, 2개 이상인 경우 group의 average sample weight를 사용함

- Interpretation

  • Balanced-infoNCE loss는 sparsely occurring mode에 higher weight를 assign 함
    • 특히 balanced-infoNCE loss는 codebook distribution에 대해 non-informative prior (uniform distirbution)을 가정하는 Bayesian approach로 볼 수 있음
    • 이때 codebook distribution에 대한 prior knowledge가 부족하므로 insufficient reason principle에 따라 uniform distribution으로 설정하는 것은 reasonable 함

4. Experiments

- Settings

  • Dataset : LibriSpeech, CommonVoice 5.1
  • Comparisons : Wav2Vec 2.0

- Results

  • In-domain dataset (LibriSpeech)에 대해 Balanced-Wav2Vec의 성능이 가장 우수함

In-Domain Dataset 성능

  • Out-of-domain dataset에 대해서도 Balanced-Wav2Vec이 가장 뛰어난 성능을 달성함

Out-of-Domain Dataset 성능

 

반응형
댓글
최근에 올라온 글
최근에 달린 댓글
«   2025/07   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30 31
Total
Today
Yesterday