티스토리 뷰

반응형

W2V-BERT: Combining Contrastive Learning and Masked Language Modeling for Self-Supervised Speech Pre-Training


  • Masked Language Modeling을 self-supervised speech representation learning에 적용할 수 있음
  • W2V-BERT
    • Contrastive learning과 masked language modeling을 combine
    • 2가지의 self-supervised task를 end-to-end fashion으로 optimize
  • 논문 (ASRU 2021) : Paper Link

1. Introduction

  • Large-scale unannotated speech를 사용하여 Automatic Speech Recognition (ASR) task를 향상할 수 있음
    • 먼저 pseudo-labeling의 경우 available labeled data를 사용하여 teacher model을 training 한 다음, teacher model을 사용하여 unlabeled data를 labeling 함
    • Wav2Vec과 같은 Unsupervised/Self-Supervised pre-training의 경우, unlabeled data만 사용하는 proxy task를 complete 하도록 model을 training 함
      1. 해당 proxy task는 supervised data에 대해 training 되기 전에 good starting point를 제공함
      2. 특히 Unsupervised/Self-Supervised 방식은 downstream ASR task에 대해 우수한 성능을 달성할 수 있음

-> 그래서 ASR task의 unsupervised pre-training을 더욱 향상할 수 있는 W2V-BERT를 제안

 

  • W2V-BERT
    • Wav2Vec 2.0contrastive task를 사용하여 discriminative, discretized speech unit에 대한 finite set을 생성
    • 이후 BERT의 Masked Language Modeling (MLM)과 같은 방식으로 contextualized speech representation learning을 수행
      - 이때 contrastive task와 masked prediction task를 simultaneously optimize 함

< Overall of W2V-BERT >

  • Wav2Vec 2.0과 BERT를 combine 한 end-to-end self-supervised speech representation learning framework
  • 결과적으로 기존보다 뛰어난 성능을 달성

2. Method

- Model Architecture

  • W2V-BERT는 다음과 같이 구성됨
    • Raw acoustic input에서 latent speech representation을 추출하는 feature encoder
    • Wav2Vec 2.0의 contrastive task를 solve 하여 discretized speech token set을 얻는 contrastive module
    • Contextualized speech representation을 학습하기 위한 masked prediction module
  • Feature Encoder
    • Feature encoder는 2개의 2D convolution layer로 구성된 convolutional sub-sampling block을 활용함
      - 이때 두 layer 모두 stride가 $(2,2)$로 설정되므로 acoustic input sequence length는 $4\times$ reduce 됨
    • Log-mel spectrogram을 input으로 주어지면, feature encoder는 subsequent contrastive module에 input 될 latent speech representation을 추출함 
  • Contrastive Module
    • Contrastive module은 linear projection layer, conformer block stack으로 구성됨
      - 각각은 multi-head self-attention, depth-wise convolution, feed-forward layer를 가짐
    • Contrastive module은 feature encoder output을 representative speech unit의 finite set으로 discretize 하기 위해 quantization mechanism을 활용함 
      1. 결과적으로 feature encoder output은 linear projection layer로 전달되고, masking 이후 conformer block stack을 통과하여 context vector를 생성한 다음,
      2. Masking 없이 quantizer에 전달하여 quantized vector와 assigned token ID를 생성함
    • Quantized vector는 masked position에 해당하는 context vector와 함께 Wav2Vec 2.0의 contrastive task를 solve 하고 contrastive module을 optimize 함
      - Assigned token ID는 subsequent masked prediction module에서 prediction target으로 사용됨
  • Masked Prediction Module
    • Masked prediction module은 각 block이 contrastive module의 configuration과 identical 한 conformer block stack으로 구성됨
    • 해당 module은 contrastive module에 의해 생성된 contrastive vector를 사용하여 high-level contextualized speech representation을 추출함 

- Pre-Training

  • Pre-training 시에는 unlabeled data만 사용됨
  • Contrastive Loss
    • Contrastive loss는 quantizer와 함께 contrastive module을 train 하는 데 사용됨
      1. 이를 통해 contrastive module은 subsequent masked prediction module의 input인 adequate context vector를 생성하고,
      2. Quantizer는 masked prediction module의 prediction target으로 사용될 discriminative discretized speech token을 생성함
        - 이때 논문은 Wav2Vec 2.0의 contrastive task와 quantization mechanism을 채택함
    • Feature encoder가 raw acoustic input을 latent speech representation으로 변환한 다음, 논문은 some time step을 randomly select 하여 mask 함
      1. Wav2Vec 2.0에서는 masked position의 latent vector가 shared learnable feature vector로 replace 되지만, W2V-BERT에서는 random vector로 replace 함
      2. Masked feature encoder output은 contrastive module로 전달되어 context vector를 생성하고, feature encoder output은 masking 없이 quantizer에 전달되어 quantized vector를 생성함
      3. 즉, masked time step $t$의 context vector $c_{t}$에 대해, model은 same utterance의 other masked time step의 uniformly sampled $K$ distractor set $\{\tilde{q}_{1},\tilde{q}_{2},...,\tilde{q}_{K}\}$ 중에서 true quantized vector $q_{t}$를 identify 함 
    • 결과적으로 얻어지는 final contrastive loss는:
      (Eq. 1) $\mathcal{L}_{c}=\mathcal{L}_{w}+\alpha\cdot \mathcal{L}_{d}$
      - $\mathcal{L}_{w}$ : contrastive loss, $\mathcal{L}_{d}$ : codebook diversity loss, $\alpha=0.1$
  • Masked Prediction Loss
    • Contrastive module에서 생성된 context vector는 masked prediction task를 수행하는 데 사용되는 final context vector를 생성하기 위해 masked prediction module에 directly pass 됨
      - 이때 module의 last conformer block에는 softmax layer가 append 됨
    • Final layer의 context vector가 masked position에 해당하는 경우, softmax layer는 context vector를 input으로 하여 해당하는 token ID를 predict 함
      - 해당 token ID는 contrastive module에서 quantizer에 의해 assign 됨
    • Masked prediction task에 대한 cross-entropy loss를 $\mathcal{L}_{m}$이라고 하면, final training loss는:
      (Eq. 2) $\mathcal{L}_{p}=\beta\cdot \mathcal{L}_{c}+\gamma \cdot\mathcal{L}_{m}$
      - $\beta=\gamma=1$

W2V-BERT Pre-Training

- Fine-Tuning

  • Pre-trained W2V-BERT를 labeled data를 통해 fine-tuning 한 다음, LibriSpeech, Voice Search task에 적용함
    • 먼저 ASR network는 pre-trained W2V-BERT model과 LSTM decoder로 구성된 sequence transducer와 같음
    • 이때 pre-trained W2V-BERT model과 LSTM decoder 사이에 Swish activation과 batch normalization을 포함한 linear layer를 insert 하여 projection block으로 사용함

W2V-BERT Model Parameter

3. Experiments

- Settings

- Results

  • 전체적으로 W2V-BERT의 성능이 가장 우수함

Model 성능 비교

  • Necessity of Contrastive Module
    • Contrastive module이 없는 경우 feature encoder output이 masked prediction module에 directly fed 됨
      - 이때 해당 module은 useful representation을 학습하지 않고도 masked prediction task를 수행할 수 있음
    • 실제로 아래 그림의 (a)와 같이 contrastive module이 없는 W2V-BERT의 masked prediction loss는 early stage에서 $0$으로 quickly decrease 함
      - 관련하여 (b)의 해당 stage에서 model은 $100\%$의 prediction accuracy를 달성함
    • 한편으로 (c)의 diversity loss는 $1$에 quickly increase 하므로 code collapse가 나타남

Training Loss 비교

  • Impact of Contrastive Module Capacity
    • Contrastive module을 enlarging 할수록 better representation을 얻을 수 있음

Module Capacity 비교

  • Voice Search Traffic
    • Real-world audio traffic 측면에서도 W2V-BERT는 뛰어난 성능을 달성함

Voice Search Data에 대한 성능

 

반응형
댓글
최근에 올라온 글
최근에 달린 댓글
«   2025/09   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30
Total
Today
Yesterday