티스토리 뷰
Paper/Representation
[Paper 리뷰] W2V-BERT: Combining Contrastive Learning and Masked Language Modeling for Self-Supervised Speech Pre-Training
feVeRin 2025. 5. 26. 17:42반응형
W2V-BERT: Combining Contrastive Learning and Masked Language Modeling for Self-Supervised Speech Pre-Training
- Masked Language Modeling을 self-supervised speech representation learning에 적용할 수 있음
- W2V-BERT
- Contrastive learning과 masked language modeling을 combine
- 2가지의 self-supervised task를 end-to-end fashion으로 optimize
- 논문 (ASRU 2021) : Paper Link
1. Introduction
- Large-scale unannotated speech를 사용하여 Automatic Speech Recognition (ASR) task를 향상할 수 있음
- 먼저 pseudo-labeling의 경우 available labeled data를 사용하여 teacher model을 training 한 다음, teacher model을 사용하여 unlabeled data를 labeling 함
- Wav2Vec과 같은 Unsupervised/Self-Supervised pre-training의 경우, unlabeled data만 사용하는 proxy task를 complete 하도록 model을 training 함
- 해당 proxy task는 supervised data에 대해 training 되기 전에 good starting point를 제공함
- 특히 Unsupervised/Self-Supervised 방식은 downstream ASR task에 대해 우수한 성능을 달성할 수 있음
- 해당 proxy task는 supervised data에 대해 training 되기 전에 good starting point를 제공함
-> 그래서 ASR task의 unsupervised pre-training을 더욱 향상할 수 있는 W2V-BERT를 제안
- W2V-BERT
- Wav2Vec 2.0의 contrastive task를 사용하여 discriminative, discretized speech unit에 대한 finite set을 생성
- 이후 BERT의 Masked Language Modeling (MLM)과 같은 방식으로 contextualized speech representation learning을 수행
- 이때 contrastive task와 masked prediction task를 simultaneously optimize 함
< Overall of W2V-BERT >
- Wav2Vec 2.0과 BERT를 combine 한 end-to-end self-supervised speech representation learning framework
- 결과적으로 기존보다 뛰어난 성능을 달성
2. Method
- Model Architecture
- W2V-BERT는 다음과 같이 구성됨
- Raw acoustic input에서 latent speech representation을 추출하는 feature encoder
- Wav2Vec 2.0의 contrastive task를 solve 하여 discretized speech token set을 얻는 contrastive module
- Contextualized speech representation을 학습하기 위한 masked prediction module
- Feature Encoder
- Feature encoder는 2개의 2D convolution layer로 구성된 convolutional sub-sampling block을 활용함
- 이때 두 layer 모두 stride가 $(2,2)$로 설정되므로 acoustic input sequence length는 $4\times$ reduce 됨 - Log-mel spectrogram을 input으로 주어지면, feature encoder는 subsequent contrastive module에 input 될 latent speech representation을 추출함
- Feature encoder는 2개의 2D convolution layer로 구성된 convolutional sub-sampling block을 활용함
- Contrastive Module
- Contrastive module은 linear projection layer, conformer block stack으로 구성됨
- 각각은 multi-head self-attention, depth-wise convolution, feed-forward layer를 가짐 - Contrastive module은 feature encoder output을 representative speech unit의 finite set으로 discretize 하기 위해 quantization mechanism을 활용함
- 결과적으로 feature encoder output은 linear projection layer로 전달되고, masking 이후 conformer block stack을 통과하여 context vector를 생성한 다음,
- Masking 없이 quantizer에 전달하여 quantized vector와 assigned token ID를 생성함
- Quantized vector는 masked position에 해당하는 context vector와 함께 Wav2Vec 2.0의 contrastive task를 solve 하고 contrastive module을 optimize 함
- Assigned token ID는 subsequent masked prediction module에서 prediction target으로 사용됨
- Contrastive module은 linear projection layer, conformer block stack으로 구성됨
- Masked Prediction Module
- Masked prediction module은 각 block이 contrastive module의 configuration과 identical 한 conformer block stack으로 구성됨
- 해당 module은 contrastive module에 의해 생성된 contrastive vector를 사용하여 high-level contextualized speech representation을 추출함
- Pre-Training
- Pre-training 시에는 unlabeled data만 사용됨
- Contrastive Loss
- Contrastive loss는 quantizer와 함께 contrastive module을 train 하는 데 사용됨
- 이를 통해 contrastive module은 subsequent masked prediction module의 input인 adequate context vector를 생성하고,
- Quantizer는 masked prediction module의 prediction target으로 사용될 discriminative discretized speech token을 생성함
- 이때 논문은 Wav2Vec 2.0의 contrastive task와 quantization mechanism을 채택함
- Feature encoder가 raw acoustic input을 latent speech representation으로 변환한 다음, 논문은 some time step을 randomly select 하여 mask 함
- Wav2Vec 2.0에서는 masked position의 latent vector가 shared learnable feature vector로 replace 되지만, W2V-BERT에서는 random vector로 replace 함
- Masked feature encoder output은 contrastive module로 전달되어 context vector를 생성하고, feature encoder output은 masking 없이 quantizer에 전달되어 quantized vector를 생성함
- 즉, masked time step $t$의 context vector $c_{t}$에 대해, model은 same utterance의 other masked time step의 uniformly sampled $K$ distractor set $\{\tilde{q}_{1},\tilde{q}_{2},...,\tilde{q}_{K}\}$ 중에서 true quantized vector $q_{t}$를 identify 함
- 결과적으로 얻어지는 final contrastive loss는:
(Eq. 1) $\mathcal{L}_{c}=\mathcal{L}_{w}+\alpha\cdot \mathcal{L}_{d}$
- $\mathcal{L}_{w}$ : contrastive loss, $\mathcal{L}_{d}$ : codebook diversity loss, $\alpha=0.1$
- Contrastive loss는 quantizer와 함께 contrastive module을 train 하는 데 사용됨
- Masked Prediction Loss
- Contrastive module에서 생성된 context vector는 masked prediction task를 수행하는 데 사용되는 final context vector를 생성하기 위해 masked prediction module에 directly pass 됨
- 이때 module의 last conformer block에는 softmax layer가 append 됨 - Final layer의 context vector가 masked position에 해당하는 경우, softmax layer는 context vector를 input으로 하여 해당하는 token ID를 predict 함
- 해당 token ID는 contrastive module에서 quantizer에 의해 assign 됨 - Masked prediction task에 대한 cross-entropy loss를 $\mathcal{L}_{m}$이라고 하면, final training loss는:
(Eq. 2) $\mathcal{L}_{p}=\beta\cdot \mathcal{L}_{c}+\gamma \cdot\mathcal{L}_{m}$
- $\beta=\gamma=1$
- Contrastive module에서 생성된 context vector는 masked prediction task를 수행하는 데 사용되는 final context vector를 생성하기 위해 masked prediction module에 directly pass 됨
- Fine-Tuning
- Pre-trained W2V-BERT를 labeled data를 통해 fine-tuning 한 다음, LibriSpeech, Voice Search task에 적용함
- 먼저 ASR network는 pre-trained W2V-BERT model과 LSTM decoder로 구성된 sequence transducer와 같음
- 이때 pre-trained W2V-BERT model과 LSTM decoder 사이에 Swish activation과 batch normalization을 포함한 linear layer를 insert 하여 projection block으로 사용함
3. Experiments
- Settings
- Dataset : LibriLight
- Comparisons : Wav2Vec 2.0, HuBERT
- Results
- 전체적으로 W2V-BERT의 성능이 가장 우수함
- Necessity of Contrastive Module
- Contrastive module이 없는 경우 feature encoder output이 masked prediction module에 directly fed 됨
- 이때 해당 module은 useful representation을 학습하지 않고도 masked prediction task를 수행할 수 있음 - 실제로 아래 그림의 (a)와 같이 contrastive module이 없는 W2V-BERT의 masked prediction loss는 early stage에서 $0$으로 quickly decrease 함
- 관련하여 (b)의 해당 stage에서 model은 $100\%$의 prediction accuracy를 달성함 - 한편으로 (c)의 diversity loss는 $1$에 quickly increase 하므로 code collapse가 나타남
- Contrastive module이 없는 경우 feature encoder output이 masked prediction module에 directly fed 됨
- Impact of Contrastive Module Capacity
- Contrastive module을 enlarging 할수록 better representation을 얻을 수 있음
- Voice Search Traffic
- Real-world audio traffic 측면에서도 W2V-BERT는 뛰어난 성능을 달성함
반응형
'Paper > Representation' 카테고리의 다른 글
댓글