티스토리 뷰
Paper/Representation
[Paper 리뷰] STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models
feVeRin 2025. 8. 27. 17:00반응형
STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models
- Transformer-based Speech Self-Supervised Learning model은 large parameter size와 computational cost를 가짐
- STaR
- Speech temporal relation을 distilling 하여 Speech Self-Supervised Learning model을 compress
- 특히 speech frame 간의 temporal relation을 transfer 하여 lightweight student를 얻음
- 논문 (ICASSP 2024) : Paper Link
1. Introduction
- Transformer-based Speech Self-Supervised Learning (SSL) model은 상당한 computational resource를 요구함
- 이를 해결하기 위해 DistilHuBERT, FitHuBERT, ARMHuBERT, LightHuBERT 등은 Knowledge Distillation을 활용해 SSL model을 compress 함
- BUT, 해당 방식들은 다음의 한계점이 있음:
- Student의 weak representation capability를 neglect 하고 additional linear head를 도입해 complex teacher representation과 directly match 함
- 따라서 student에 적합한 alternative distillation objective가 필요함 - 기존 방식은 parameter size 외의 computaitonal cost를 결정할 수 없음
- 즉, computational overhead가 더 높아질 수 있음
- Student의 weak representation capability를 neglect 하고 additional linear head를 도입해 complex teacher representation과 directly match 함
-> 그래서 더 나은 SSL model compression을 위한 distillation method인 STaR를 제안
- STaR
- Lightweight student model을 얻기 위해 speech frame 간의 temporal relation을 capture 하는 distillation objective를 도입
- Computationally efficient 한 model을 구축하기 위해 distillation 시 additional parameter를 사용하지 않음
< Overall of STaR >
- Speech temporal relation을 distillation 하는 SSL compression strategy
- 결과적으로 기존 lightweight SSL model보다 우수한 성능을 달성
2. Method
- Speech SSL model은 speech frame을 각 time step에 대한 specific acoustic unit의 feature로 represent 할 수 있음
- 이는 SSL model이 pre-training에서 각각의 masked frame에 대한 cluster를 predict 하도록 학습되기 때문
- 즉, speech SSL model은 specific acoustic unit과 closely tie 된 frame 별 representation을 생성함 - BUT, 해당 tearcher representation을 directly learning 하는 것은 limited representation capacity를 가진 student에게는 over-constraint 함
- 따라서 논문은 speech frame 간의 relation에 focus 하여 teahcer knowledge를 flexible manner로 distill 하기 위해, 다음 3가지의 Speech Temporal Relation (STaR) distillation objective를 고려함
- Temporal relation을 distill 하는 Average Attention Map Distillation
- Layer-wise Temporal Gram Matrix (TGM) Distillation
- Intra-layer TGM Distillation
- 이는 SSL model이 pre-training에서 각각의 masked frame에 대한 cluster를 predict 하도록 학습되기 때문

- Average Attention Map Distillation
- Transformer의 attention map은 key, query 간의 temporal relation을 capture 함
- 이때 attention map의 각 entry는 sequence의 두 frame 간의 relationship level을 나타내므로 해당 map을 distill 하여 teacher로부터 temporal relation을 transfer 할 수 있음
- Multi-Head Self-Attention (MHSA)의 head $h$에 대한 key, query matrix $\mathbf{K}_{h},\mathbf{Q}_{h}\in\mathbb{R}^{d_{h}\times N}$이 주어졌을 때, 해당 attention map $\mathbf{A}_{h}\in\mathbb{R}^{N\times N}$은 $\mathbf{A}_{h}=\text{softmax}\left(\mathbf{Q}_{h}^{\top}\mathbf{K}_{h}/\sqrt{d_{h}}\right)$와 같음
- $d_{h}$ : head $h$에 대한 key matrix의 width, $N$ : sequence length - BUT, 기존에는 knowledge distillation을 위해 last Transformer layer의 모든 head를 활용하므로, speech SSL model에서 computational overhead가 발생함
- 따라서 논문은 각 Transformer layer의 모든 head에 대한 Averaged Attention Map을 사용함 - 그러면 각 Transformer layer에 대한 loss는 teacher $T$와 student $S$의 average attention map에 대한 Kullback-Leibler (KL) divergence로 얻어짐:
(Eq. 1) $ \mathcal{L}_{avg\text{-}attn}=\sum_{\ell=1}^{L}\sum_{t=1}^{N}D_{KL}\left( \left.\left.\frac{1}{H^{T}}\sum_{h=1}^{H^{T}}\mathbf{A}_{h,t}^{\ell,T}\right|\right| \frac{1}{H^{S}}\sum_{h=1}^{H^{S}}\mathbf{A}_{h,t}^{\ell,S}\right)$
- $H, L$ : total attention head 수, Transformer layer 수
- $\mathbf{A}_{h,t}^{\ell}$ : time step $t$에서 Transformer layer $\ell$의 attention head $h$에 대한 attention distribution
- Temporal Gram Matrix Distillation
- 논문은 attention map 외에도 stronger hint를 제공하기 위해, 각 Transformer layer output의 temporal relation에 대한 distillation objective를 고려함
- 이때 sample 간의 correlation은 sample representation 간의 inner product로 정의되는 Gram matrix로 나타낼 수 있음
- $\mathbf{F}\in\mathbb{R}^{d\times N}$을 channel width $d$의 representation이라고 하면, Gram matrix $\tilde{\mathbf{G}}$는:
(Eq. 2) $\tilde{G}_{i,j}=\sum_{k}F_{ik}F_{jk},\,\,\,\tilde{\mathbf{G}}\in\mathbb{R}^{d\times d}$
- $F_{\cdot k}$ : $k$-th time step의 representation - Gram matrix는 주로 data property를 represent 하기 위해 사용됨
- 대표적으로 style loss에서는 (Eq. 2)와 같은 sample 내의 positional information을 aggregate 하여 channel 간의 correlation을 minimize 함
- 그 외에도 Gram matrix는 Knowledge Distillation을 위한 objective로써 사용할 수 있음
- 따라서 논문은 speech frame 간의 temporal relation을 반영하기 위해, 2개의 time step에서 channel information을 aggregate 하는 Temporal Gram Matrix (TGM)을 도입함 - 이때 TGM $\mathbf{G}$는:
(Eq. 3) $G_{ij}=\sum_{k}F_{ki}F_{kj},\,\,\, \mathbf{G}\in\mathbb{R}^{N\times N}$
- Layer-wise TGM Distillation
- Layer-wise TGM distillation은 각 Transformer layer output에 대한 TGM을 사용함
- 이때 front-end convolutional layer를 directly distill 하기 위해 first Transformer layer의 input을 포함함 - 해당 approach는 SSL model의 각 layer 내에 encode 된 information을 student의 해당 layer로 transfer 함
- First Transformer layer input을 zero-th output으로 두면, layer-wise TGM distillation loss는 모든 layer에서 teacher, student TGM 간의 Mean Squared Error (MSE)와 같음:
(Eq. 4) $\mathcal{L}_{layer\text{-}wise}=\sum_{\ell=0}^{L}\left|\left| \mathbf{G}^{\ell,T}-\mathbf{G}^{\ell,S}\right|\right|_{2}^{2}$
- Layer-wise TGM distillation은 각 Transformer layer output에 대한 TGM을 사용함
- Intra-layer TGM Distillation
- Student model의 parameter size를 더욱 줄이기 위해 논문은 intra-layer TGM distillation을 도입함
- 이를 위해 solution procedure matrix를 따라 TGM을 single Transformer layer의 input, output 간의 temporal relation을 compute 하는 것으로 modify 함
- 이때 matrix는 각 individual layer 내에서 speech representation의 progression을 capture 하고, 2개의 서로 다른 representation을 기반으로 flexible objective를 제공함
- 결과적으로 얻어지는 Transformer layer $\ell$의 modified TGM $\breve{\mathbf{G}}$과 intra-layer TGM distillation loss는:
(Eq. 5) $\breve{G}_{ij}^{\ell}=\sum_{k}F_{ki}^{\ell-1}F_{kj}^{\ell},\,\,\,\breve{\mathbf{G}}^{\ell} \in\mathbb{R}^{N\times N}$
(Eq. 6) $\mathcal{L}_{intra\text{-}layer}=\sum_{\ell=1}^{L}\left|\left| \breve{\mathbf{G}}^{\ell,T}-\breve{\mathbf{G}}^{\ell,S}\right|\right|_{2}^{2}$ - 해당 objective는 distillation을 위한 additional parameter를 요구하지 않음
- Student model의 channel width가 teacher model과 다르더라도 time length $N$이 동일하다면, 해당 loss를 additional linear projection 없이 formulate 할 수 있음
- 따라서 이를 통해 compact student를 얻을 수 있고, teacher의 temporal relation knowledge를 fully transfer 할 수 있음
- Student model의 parameter size를 더욱 줄이기 위해 논문은 intra-layer TGM distillation을 도입함
3. Experiments
- Settings
- Dataset : LibriSpeech
- Comparisons : HuBERT, LightHuBERT, DPHuBERT, ARMHuBERT, FitHuBERT
- Results
- 전체적으로 STaRHuBERT의 성능이 가장 우수함

- Selection of STaR Loss
- $\mathcal{L}_{layer\text{-}wise}+\mathcal{L}_{intra\text{-}layer}$의 조합이 가장 적합함

- Examination on Universality
- STaR는 Wav2Vec 2.0, WavLM에도 효과적으로 적용될 수 있음

- Parameter Size
- STaR를 사용하면 작은 parameter에서도 성능을 안정적인 성능을 얻을 수 있음

반응형
'Paper > Representation' 카테고리의 다른 글
댓글
