티스토리 뷰
Paper/ASR
[Paper 리뷰] Whisper in Medusa's Ear: Multi-head Efficient Decoding for Transformer-based ASR
feVeRin 2025. 5. 22. 17:45반응형
Whisper in Medusa's Ear: Multi-head Efficient Decoding for Transformer-based ASR
- Large Transformer-based model은 self-attention mechanism으로 인해 computationally intensive 함
- Whisper-Medusa
- Whisper architecture를 extend 하여 iteration 마다 multiple token을 predict
- Word Error Rate에 대한 영향을 최소화하면서 latency를 50% 절감
- 논문 (ICASSP 2025) : Paper Link
1. Introduction
- Whisper와 같은 Transformer-based supervised model은 Automatic Speech Recognition (ASR)에서 우수한 transcription accuracy를 달성함
- BUT, large version 기준 1.5B의 parameter를 가지므로 slow inference speed의 한계가 있음
- 이를 해결하기 위해 knowledge distillation나 quantization을 적용할 수 있지만 Word Error Rate (WER)의 저하가 크게 발생함 - 한편으로 natural language processing에서 large model은 inference speed를 향상하기 위해 Speculative Decoding을 주로 활용함
- 이때 Speculative Decoding은 multiple decoding step을 parallel 하게 수행하여 computationally expansive operation의 수를 줄일 수 있음
- BUT, large version 기준 1.5B의 parameter를 가지므로 slow inference speed의 한계가 있음
-> 그래서 speculative decoding을 ASR에 적용한 Whisper-Meduas를 제안
- Whisper-Medusa
- Encoder-Decoder Transformer-based ASR architecture에서 Speculative Decoding을 extand 하여 decoding step 당 multiple token을 predict
- Encoder로부터 entire speech를 처리하는 Transformer decoder를 활용하여 decoder-only archcitecture 보다 더 나은 efficiency를 달성
< Overall of Whisper-Medusa >
- Speculative Decoding을 Whisper에 적용한 ASR model
- 결과적으로 ASR task에서 기존보다 빠른 추론 속도를 달성
2. Method
- Transformer-based ASR architecture는 encoder, decoder로 구성됨
- Encoder는 input audio waveform을 high-dimensional embedding sequence로 변환하고 decoder는 해당 embedding을 기반으로 sub-word unit의 token sequence를 생성함
- 이때 decoder는 autoregressive 하게 동작하고 한 번에 하나의 token을 predict 함
- 즉, 각 step에서 possible token의 entire set에 대한 probability distribution을 estimate 하고 most likely token을 select 함
- 해당 process는 end-of-sequence token이 생성될 때까지 반복됨
- 여기서 논문은 해당 decoder behavior를 modify 하는 것을 목표로, 각 iteration에서 single token을 predict 하지 않고 decoder가 $K+1$ token을 simultaneously predict 하도록 함
- 이를 통해 efficiency를 향상하고 output sequence에서 long-range dependency를 potentially capture 함
- Encoder는 input audio waveform을 high-dimensional embedding sequence로 변환하고 decoder는 해당 embedding을 기반으로 sub-word unit의 token sequence를 생성함
- $y\in\mathcal{Y}$를 token set $\mathcal{Y}$에서의 token, $\mathbf{y}_{<t}=(y_{0},y_{1},...,y_{t-1})$을 $0$-th token $y_{0}$부터 $t-1$-th token $y_{t-1}$ 까지의 token sequence라고 하자
- Decoder는 predicted token sequence $\hat{\mathbf{y}}_{<t}$와 encoder embedding $\mathbf{z}$를 고려하여 next token $y_{t}$의 probability distribution을 estimate 함
- 즉, $p(y_{t}|\hat{\mathbf{y}}_{<t},\mathbf{z})$ - 이때 논문은 각 token이 distribution에서 highest probability를 가지는 token을 select 하여 predict 하는 greedy decoding을 고려함:
(Eq. 1) $ \hat{y}_{t}=\arg\max_{y_{t}}p(y_{t}|\hat{\mathbf{y}}_{<t},\mathbf{z})$ - $K+1$ prediction head가 있다고 하면 base head는 original decoder prediction head에 해당함
- Base head는 preceding predicted token sequence $\hat{\mathbf{y}}_{<t}$를 condition으로하여 $t$-th token에 대한 token set의 probability distribution $p_{0}$를 생성함
- $k$-th head에 대한 probability distribution은 previous token sequence $\hat{\mathbf{y}}_{<t}$를 condition으로 하는 $y_{t+k}$ token에 해당함
- 즉, $p_{k}(y_{t+k}|\hat{\mathbf{y}}_{<t},\mathbf{z})$와 같음
- $K+1$ head를 사용하는 inference process는 token prediction과 verification의 two-phase로 구성됨
- First phase에서는 모든 $K+1$ head의 distribution을 estimate 함:
(Eq. 2) $p_{0}(y_{t}|\hat{\mathbf{y}}_{<t},\mathbf{z}),p_{k}(y_{t+k}|\hat{\mathbf{y}}_{<t},\mathbf{z})\,\,\, \text{for}\,\, 1\leq k\leq K$
- 이후 해당 probability를 maximize 하는 token을 select 하여 subsequent $K+1$ token set $\{\hat{y}_{t},\hat{y}_{t+1},...,\hat{y}_{t+K}\}$을 identify 함 - Second phase에서는 first phase에서 predict 된 token을 base head에 전달하고, resulted probability가 threshold 이상인 모든 head $0\leq k\leq K$를 select 함:
(Eq. 3) $p_{0}(y_{t+k}|\hat{\mathbf{y}}_{<t},\hat{y}_{t},...,\hat{y}_{t+k-1},\mathbf{z})>\min\{\epsilon,\alpha\tilde{p}_{\max}\}$
- $\epsilon=0.09, \alpha=0.3$, $\tilde{p}_{\max}$ : distribution $p_{0}$의 entropy function에 대한 exponent
- First phase에서는 모든 $K+1$ head의 distribution을 estimate 함:
- Decoder는 predicted token sequence $\hat{\mathbf{y}}_{<t}$와 encoder embedding $\mathbf{z}$를 고려하여 next token $y_{t}$의 probability distribution을 estimate 함
- Architecture 측면에서 논문은 다음 2가지를 고려함
- $K$ head를 가지는 Medusa-Linear
- 각 head는 residual connection이 있는 single linear layer와 shared vocabulary projection layer를 포함함
- 이때 논문은 final decoder layer와 head만 update 하고, cross-entropy (CE) loss를 각 head에 적용하여 base head와 additional $K$ head를 training 함
- Total loss는 individual loss의 average로 얻어지고, original ASR model의 probability distribution에 대한 weighted KL-divergence loss와 함께 combine 됨
- 모든 $K$ head에 대해 share 되는 additional decoder block을 포함하는 Medusa-Block
- Block 다음에는 single linear layer와 각 head에 대한 residual connection이 추가되고 output은 shared vocabulary projection layer로 전달됨
- 해당 architecture는 모든 ASR model weight가 froze 되고 Medusa weight만 update 됨
- 즉, base head는 train 되지 않고 KL loss 없이 $K$ weighted CE loss만을 사용하여 training 됨
- $K$ head를 가지는 Medusa-Linear
3. Experiments
- Settings
- Dataset : LibriSpeech, Voxpopuli
- Comparisons : Whisper
- Results
- Whisper-Medusa는 최대 1.48배의 speed up이 가능함
- LibriSpeech에 대해서도 1.40배 이상의 speed up이 가능함
- Self-Supervised
- Self-Supervised setup에서 Czech, Finnish, Dutch dataset에 대해 우수한 성능을 보임
- 특히 target sequence length가 길어질수록 속도가 향상됨
- Ablation
- Head 수가 많아질수록 Whisper-Medusa 성능은 저하됨
반응형
'Paper > ASR' 카테고리의 다른 글
댓글