티스토리 뷰

반응형

STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models


  • Transformer-based Speech Self-Supervised Learning model은 large parameter size와 computational cost를 가짐
  • STaR
    • Speech temporal relation을 distilling 하여 Speech Self-Supervised Learning model을 compress
    • 특히 speech frame 간의 temporal relation을 transfer 하여 lightweight student를 얻음
  • 논문 (ICASSP 2024) : Paper Link

1. Introduction

  • Transformer-based Speech Self-Supervised Learning (SSL) model은 상당한 computational resource를 요구함
    • 이를 해결하기 위해 DistilHuBERT, FitHuBERT, ARMHuBERT, LightHuBERT 등은 Knowledge Distillation을 활용해 SSL model을 compress 함
    • BUT, 해당 방식들은 다음의 한계점이 있음:
      1. Student의 weak representation capability를 neglect 하고 additional linear head를 도입해 complex teacher representation과 directly match 함
        - 따라서 student에 적합한 alternative distillation objective가 필요함
      2. 기존 방식은 parameter size 외의 computaitonal cost를 결정할 수 없음
        - 즉, computational overhead가 더 높아질 수 있음

-> 그래서 더 나은 SSL model compression을 위한 distillation method인 STaR를 제안

 

  • STaR
    • Lightweight student model을 얻기 위해 speech frame 간의 temporal relation을 capture 하는 distillation objective를 도입
    • Computationally efficient 한 model을 구축하기 위해 distillation 시 additional parameter를 사용하지 않음

< Overall of STaR >

  • Speech temporal relation을 distillation 하는 SSL compression strategy
  • 결과적으로 기존 lightweight SSL model보다 우수한 성능을 달성

2. Method

  • Speech SSL model은 speech frame을 각 time step에 대한 specific acoustic unit의 feature로 represent 할 수 있음
    • 이는 SSL model이 pre-training에서 각각의 masked frame에 대한 cluster를 predict 하도록 학습되기 때문
      - 즉, speech SSL model은 specific acoustic unit과 closely tie 된 frame 별 representation을 생성함
    • BUT, 해당 tearcher representation을 directly learning 하는 것은 limited representation capacity를 가진 student에게는 over-constraint 함 
    • 따라서 논문은 speech frame 간의 relation에 focus 하여 teahcer knowledge를 flexible manner로 distill 하기 위해, 다음 3가지의 Speech Temporal Relation (STaR) distillation objective를 고려함 
      1. Temporal relation을 distill 하는 Average Attention Map Distillation
      2. Layer-wise Temporal Gram Matrix (TGM) Distillation
      3. Intra-layer TGM Distillation

Overview

- Average Attention Map Distillation

  • Transformer의 attention map은 key, query 간의 temporal relation을 capture 함
    • 이때 attention map의 각 entry는 sequence의 두 frame 간의 relationship level을 나타내므로 해당 map을 distill 하여 teacher로부터 temporal relation을 transfer 할 수 있음
    • Multi-Head Self-Attention (MHSA)의 head $h$에 대한 key, query matrix $\mathbf{K}_{h},\mathbf{Q}_{h}\in\mathbb{R}^{d_{h}\times N}$이 주어졌을 때, 해당 attention map $\mathbf{A}_{h}\in\mathbb{R}^{N\times N}$은 $\mathbf{A}_{h}=\text{softmax}\left(\mathbf{Q}_{h}^{\top}\mathbf{K}_{h}/\sqrt{d_{h}}\right)$와 같음
      - $d_{h}$ : head $h$에 대한 key matrix의 width, $N$ : sequence length
    • BUT, 기존에는 knowledge distillation을 위해 last Transformer layer의 모든 head를 활용하므로, speech SSL model에서 computational overhead가 발생함
      - 따라서 논문은 각 Transformer layer의 모든 head에 대한 Averaged Attention Map을 사용함
    • 그러면 각 Transformer layer에 대한 loss는 teacher $T$와 student $S$의 average attention map에 대한 Kullback-Leibler (KL) divergence로 얻어짐:
      (Eq. 1) $ \mathcal{L}_{avg\text{-}attn}=\sum_{\ell=1}^{L}\sum_{t=1}^{N}D_{KL}\left( \left.\left.\frac{1}{H^{T}}\sum_{h=1}^{H^{T}}\mathbf{A}_{h,t}^{\ell,T}\right|\right| \frac{1}{H^{S}}\sum_{h=1}^{H^{S}}\mathbf{A}_{h,t}^{\ell,S}\right)$
      - $H, L$ : total attention head 수, Transformer layer 수
      - $\mathbf{A}_{h,t}^{\ell}$ : time step $t$에서 Transformer layer $\ell$의 attention head $h$에 대한 attention distribution

- Temporal Gram Matrix Distillation

  • 논문은 attention map 외에도 stronger hint를 제공하기 위해, 각 Transformer layer output의 temporal relation에 대한 distillation objective를 고려함
    • 이때 sample 간의 correlation은 sample representation 간의 inner product로 정의되는 Gram matrix로 나타낼 수 있음
    • $\mathbf{F}\in\mathbb{R}^{d\times N}$을 channel width $d$의 representation이라고 하면, Gram matrix $\tilde{\mathbf{G}}$는:
      (Eq. 2) $\tilde{G}_{i,j}=\sum_{k}F_{ik}F_{jk},\,\,\,\tilde{\mathbf{G}}\in\mathbb{R}^{d\times d}$
      - $F_{\cdot k}$ : $k$-th time step의 representation
    • Gram matrix는 주로 data property를 represent 하기 위해 사용됨
      1. 대표적으로 style loss에서는 (Eq. 2)와 같은 sample 내의 positional information을 aggregate 하여 channel 간의 correlation을 minimize 함
      2. 그 외에도 Gram matrix는 Knowledge Distillation을 위한 objective로써 사용할 수 있음
        - 따라서 논문은 speech frame 간의 temporal relation을 반영하기 위해, 2개의 time step에서 channel information을 aggregate 하는 Temporal Gram Matrix (TGM)을 도입함
      3. 이때 TGM $\mathbf{G}$는:
        (Eq. 3) $G_{ij}=\sum_{k}F_{ki}F_{kj},\,\,\, \mathbf{G}\in\mathbb{R}^{N\times N}$
  • Layer-wise TGM Distillation
    • Layer-wise TGM distillation은 각 Transformer layer output에 대한 TGM을 사용함
      - 이때 front-end convolutional layer를 directly distill 하기 위해 first Transformer layer의 input을 포함함
    • 해당 approach는 SSL model의 각 layer 내에 encode 된 information을 student의 해당 layer로 transfer 함
    • First Transformer layer input을 zero-th output으로 두면, layer-wise TGM distillation loss는 모든 layer에서 teacher, student TGM 간의 Mean Squared Error (MSE)와 같음:
      (Eq. 4) $\mathcal{L}_{layer\text{-}wise}=\sum_{\ell=0}^{L}\left|\left| \mathbf{G}^{\ell,T}-\mathbf{G}^{\ell,S}\right|\right|_{2}^{2}$
  • Intra-layer TGM Distillation
    • Student model의 parameter size를 더욱 줄이기 위해 논문은 intra-layer TGM distillation을 도입함
      1. 이를 위해 solution procedure matrix를 따라 TGM을 single Transformer layer의 input, output 간의 temporal relation을 compute 하는 것으로 modify 함
      2. 이때 matrix는 각 individual layer 내에서 speech representation의 progression을 capture 하고, 2개의 서로 다른 representation을 기반으로 flexible objective를 제공함
    • 결과적으로 얻어지는 Transformer layer $\ell$의 modified TGM $\breve{\mathbf{G}}$과 intra-layer TGM distillation loss는
      (Eq. 5) $\breve{G}_{ij}^{\ell}=\sum_{k}F_{ki}^{\ell-1}F_{kj}^{\ell},\,\,\,\breve{\mathbf{G}}^{\ell} \in\mathbb{R}^{N\times N}$
      (Eq. 6) $\mathcal{L}_{intra\text{-}layer}=\sum_{\ell=1}^{L}\left|\left| \breve{\mathbf{G}}^{\ell,T}-\breve{\mathbf{G}}^{\ell,S}\right|\right|_{2}^{2}$
    • 해당 objective는 distillation을 위한 additional parameter를 요구하지 않음
      1. Student model의 channel width가 teacher model과 다르더라도 time length $N$이 동일하다면, 해당 loss를 additional linear projection 없이 formulate 할 수 있음
      2. 따라서 이를 통해 compact student를 얻을 수 있고, teacher의 temporal relation knowledge를 fully transfer 할 수 있음

3. Experiments

- Settings

- Results

  • 전체적으로 STaRHuBERT의 성능이 가장 우수함

Model 성능 비교

  • Selection of STaR Loss
    • $\mathcal{L}_{layer\text{-}wise}+\mathcal{L}_{intra\text{-}layer}$의 조합이 가장 적합함

Loss Selection

  • Examination on Universality

Universality

  • Parameter Size
    • STaR를 사용하면 작은 parameter에서도 성능을 안정적인 성능을 얻을 수 있음

Parameter 별 성능

 

반응형
댓글
최근에 올라온 글
최근에 달린 댓글
«   2025/12   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30 31
Total
Today
Yesterday