티스토리 뷰
Paper/TTS
[Paper 리뷰] JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech
feVeRin 2024. 3. 24. 10:59반응형
JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech
- Text-to-Speech는 2-stage 방식이나 개별적으로 training 된 모델의 cascade로 학습됨
- BUT, training pipeline은 최적의 성능을 위해서 fine-tuning이나 speech-text alignment를 요구함
- JETS
- Simplified pipeline을 구성해 개별적으로 학습된 모델들보다 뛰어난 성능을 발휘하는 end-to-end 모델을 제시
- Alignment module을 사용하여 FastSpeech2와 HiFi-GAN을 jointly training
- Alignment learning objective를 채택하여 external alignment tool에 대한 의존성을 제거
- 논문 (INTERSPEECH 2022) : Paper Link
1. Introduction
- Text-to-Speech (TTS)는 일반적으로 acoustic feature generator와 vocoder로 나누어져 각각의 sub-task를 담당함
- Acoustic feature generator가 먼저 input text로부터 acoustic feature를 생성하고, vocoder가 acoustic feature로부터 raw waveform을 합성하는 방식
- 따라서 2-stage system의 각 모델은 개별적으로 training된 다음, 추론 시 결합됨 - BUT, 두 모델을 개별적으로 training하면 acoustic feature mismatch로 인해 합성 품질의 저하가 발생함
- 특히 vocoder는 training을 위해 ground-truth acoustic feature를 취하고 추론 시에는 acoustic feature generator에서 예측된 acoustic feature를 사용함
- 이때 성능 향상을 위해 fine-tuning을 수행하거나 처음부터 예측된 acoustic feature를 사용할 수 있음
- 결과적으로 두 방식 모두 training pipeline을 복잡하게 만드는 단점이 있음 - 한편으로 End-to-End TTS (E2E-TTS)는 acoustic feature generator와 vocoder를 구분하지 않고 single stage로 동작하는 방식임
- E2E-TTS는 acoustic feature mismatch 문제가 발생하지 않기 때문에 fine-tuning과 같은 추가적인 training 과정이 필요하지 않음
- 특히 external alignment tool에 대한 의존성을 제거하여 pipeline을 더욱 단순화할 수 있음
- Acoustic feature generator가 먼저 input text로부터 acoustic feature를 생성하고, vocoder가 acoustic feature로부터 raw waveform을 합성하는 방식
-> 그래서 단순화된 pipeline과 고품질 합성을 지원하는 E2E-TTS 모델인 JETS를 제안
- JETS
- Acoustic feature generator와 vocoder를 joint training 함
- Intermediate mel-spectrogram 없이 input text로부터 raw waveform을 직접적으로 합성
- External aligner에 의존하지 않고 sinlge stage로써 training 될 수 있도록 alignment learning objective를 incorporate 함
< Overall of JETS >
- FastSpeech2와 HiFi-GAN을 joint training 하여 고품질 E2E-TTS 모델을 구성
- Alignment learning framework를 활용하여 training 중에 token duration을 얻음
- 결과적으로 기존 모델들보다 우수한 성능을 달성
2. Method
- JETS는 alignment module을 사용하여 FastSpeech2와 HiFi-GAN을 jointly training 하는 E2E-TTS 모델
- FastSpeech2
- JETS의 component 중 하나로써 FastSpeech2 architecture를 채택
- FastSpeech2는 빠르고 고품질의 합성을 지원하는 non-autoregressive acoustic feature generator
- Duration predictor를 통해 token duration을 explicitly modeling 하여 phoneme repeat/skip과 같은 synthesis error에 대한 robustness를 향상
- 이전 버전인 FastSpeech와 달리 pitch, energy 등의 variance information을 사용하여 품질을 크게 향상함 - 구조적으로는 feed-forward Transformer encoder, decoder, 1D-convolutional variance adaptor로 구성됨
- 이때 encoder는 input text를 text embedding $\mathbf{h}$로 encoding 하고, variance adaptor는 text embedding에 variance information을 추가하고 decoder의 각 token duration에 따라 expand 함 - Variance adaptor는 pitch, energy, duration predictor로 구성되고, 기존 frame-wise 방식 대신 FastPitch와 같이 token-wise pitch, energy를 최소화하도록 training 함
- Training 중에 필요한 token-wise pitch, energy $\mathbf{p}, \mathbf{e}$는 token duration $\mathbf{d}$에 따라 frame-wise ground-truth pitch, energy를 평균하여 계산됨
- Token duration은 각 input text token에 할당된 mel-frame 수로 정의되고, alignment module로부터 얻어짐
- Text embedding에 pitch, energy가 추가된 다음, token duration에 따라 length regulator (LR)에 의해 expand 됨
- 여기서 vanilla upsamping 대신 fixed temperature Gaussian samping인 softmax aligner를 사용
- JETS가 intermediate mel-spectrogram 없이 input text에서 raw waveform을 직접 합성할 수 있도록, mel-spectrogram loss는 제외되고, $L_{2}$ loss로 각 variance를 최소화하는 variance loss는 유지함:
(Eq. 1) $L_{var}=|| \mathbf{d}-\hat{\mathbf{d}}||_{2}+|| \mathbf{p}-\hat{\mathbf{p}}||_{2}+|| \mathbf{e}-\hat{\mathbf{e}}||_{2}$
- $\mathbf{d}, \mathbf{p}, \mathbf{e}$ : ground-truth duration, pitch, energy feature sequence
- $\hat{\mathbf{d}}, \hat{\mathbf{p}},\hat{\mathbf{e}}$ : 예측된 feature seqeuence
- FastSpeech2는 빠르고 고품질의 합성을 지원하는 non-autoregressive acoustic feature generator
- HiFi-GAN
- HiFi-GAN은 빠르고 효율적인 병렬 합성이 가능한 Generative Adversarial Network (GAN) 기반의 neural vocoder
- GAN framework에서 generator는 discriminator를 속이고 discriminator는 generator의 sample을 discriminate 하는 adversarial feedback으로 training 됨
- HiFi-GAN의 discriminator는 fidelity를 향상하기 위해, Multi-Period Discriminator (MPD)와 Multi-Scale Discriminator (MSD)를 활용함
- MPD는 periodic pattern을 처리하고, MSD는 넓은 receptive field를 가지는 다양한 scale의 consecutive waveform에서 동작 - JETS는 decoder output에서 raw waveform을 합성하기 위해 HiFi-GAN generator를 채택함
- HiFi-GAN generator는 decoder output이 ground-truth waveform의 mel-spectrogram과 동일한 length를 가지는 raw waveform의 length와 일치하도록 transposed covolution을 통해 decoder output을 upsampling 함
- 이때 adversarial loss 외에 feature matching loss와 mel-spectrogram loss를 auxiliary loss로써 사용함
- 여기서 auxiliary mel-spectrogram loss는 ground-truth와 예측 mel-spectrogram 간의 $L_{1}$ loss로 앞선 FastSpeech2의 mel-spectrogram loss와는 다름
- 결과적으로 HiFi-GAN의 training objective는 LSGAN을 따르고, 이때 generator loss는:
(Eq. 2) $L_{g}=L_{g,adv}+\lambda_{fm}L_{fm}+\lambda_{mel}L_{mel}$
- $L_{g,adv}$ : least-square 기반 adversarial loss
- $\lambda_{fm}, \lambda_{mel}$ : feature matching, mel-spectrogram loss에 대한 scaling factor
- Alignment Learning Framework
- Speech-text alignment는 training을 위해 explicit duration이 필요한 duration informed network에서 중요한 요소
- JETS의 각 token duration $\mathbf{d}$는 duration predictor training, token-averaged pitch 계산, frame-wise energy 계산, text embedding upsampling에 사용됨
- 일반적으로 token duration은 montreal forced aligner (MFA), pre-trained autoregressive TTS 모델 등에서 얻어지지만, 이 과정을 training pipeline에 통합하여 더욱 단순화할 수 있음
- 따라서 JETS는 training 중에 필요한 token duration $\mathbf{d}$를 얻기 위해 alignment learning framework를 통합함
- 이때 alignment learning objective는 forward-sum algorithm을 통해 효율적으로 계산될 수 있음
- 먼저 aligmnet module은 text embedding $\mathbf{h}$와 mel-spectrogram $\mathbf{m}$을 각각 2개, 3개의 1D convolution layer를 사용하여 $\mathbf{h}^{enc}, \mathbf{m}^{enc}$로 encoding 함
- 이후 모든 text token과 mel-frame 간의 learned pairwise affinity를 기반으로 text domain 전체에 걸쳐 normalized softmax로 계산되는 soft alignment 분포 $\mathcal{A}_{soft}$를 얻음:
(Eq. 3) $D_{i,j}=dist_{L2}(\mathbf{h}_{i}^{enc},\mathbf{m}_{j}^{enc})$
(Eq. 4) $\mathcal{A}_{soft}=\mathrm{softmax}(-D,dim=0)$
- $\mathbf{h}_{i}^{enc}, \mathbf{m}_{j}^{enc}$ : time step $i,j$에서 encoding 된 text embedding과 mel-spectrogram - Soft alignment 분포 $\mathcal{A}_{soft}$에서 모든 valid monotonic alignment가 최대화되는 likelihood를 계산할 수 있음:
(Eq. 5) $P(S(\mathbf{h})|\mathbf{m})=\sum_{s\in S(\mathbf{h})}\prod_{t=1}^{T}P(s_{t}|m_{t})$
- $s$ : text와 mel-spectrogram 간의 specific alignment (e.g. $s_{1}=h_{1},s_{2}=h_{2},...,s_{T}=h_{T}$)
- $S(\mathbf{h})$ : 모든 valid monotonic alignment의 set
- $T, N$ : 각각 mel-spectrogram, text token length - Alignment learning objective를 계산하기 위해 forward-sum algorithm이 사용되고, 해당 algorithm의 negative를 forward-sum loss $L_{forward\textrm{_}sum}$으로 정의
- 이는 off-the-shelf CTC loss를 통해 효율적으로 training 될 수 있음
- 이후 모든 text token과 mel-frame 간의 learned pairwise affinity를 기반으로 text domain 전체에 걸쳐 normalized softmax로 계산되는 soft alignment 분포 $\mathcal{A}_{soft}$를 얻음:
- Token duration $\mathbf{d}$를 얻기 위해, Montonic Alignment Search (MAS)는 soft alignment $\mathcal{A}_{soft}$를 monotonic binarized hard alignment $\mathcal{A}_{hard}$로 변환함
- 이때 $\sum_{j=1}^{T}\mathcal{A}_{hard,i,j}$는 각 token duration을 나타냄
- 따라서 각 token duration은 input text token 각각에 할당된 mel-frame 수이고, duration의 합은 mel-spectrogram의 length와 같음
- 이를 위해 KL-divergence를 최소화하여 $\mathcal{A}_{soft}, \mathcal{A}_{hard}$를 match하는 additional binarization loss $L_{bin}$을 사용:
(Eq. 6) $L_{bin}=-\mathcal{A}_{hard}\odot \log \mathcal{A}_{soft}$ - 추가적으로 beta-binomial alignment prior를 사용하여 $\mathcal{A}_{soft}$ 이전에 2D static을 곱하여, near-diagonal path를 유도함으로써 alignment learning을 가속화함
(Eq. 7) $L_{align}=L_{forward \textrm{_} sum}+L_{bin}$
- $\odot$ : Hadamard product, $L_{align}$ : alignment에 대한 final loss
- 이를 위해 KL-divergence를 최소화하여 $\mathcal{A}_{soft}, \mathcal{A}_{hard}$를 match하는 additional binarization loss $L_{bin}$을 사용:
- Final Loss
- JETS는 encoder, variance adaptor, decoder, HiFi-GAN generator, alignment module로 구성됨
- 이때 alignment module은 training 시에만 사용되고, GAN training framework에서 intermediate mel-spectrogram loss 없이 input text로부터 raw waveform을 직접 합성하도록 training 됨
- 여기서 training을 위해 HiFi-GAN의 discriminator를 사용 - 결과적으로 JETS의 final loss는 variance loss와 alignment loss를 결합한 GAN training loss:
(Eq. 8) $L=L_{g}+\lambda_{var}L_{var}+\lambda_{align}L_{align}$
- 논문에서는 $\lambda_{var} =1, \lambda_{align}=2$로 설정
- 이때 alignment module은 training 시에만 사용되고, GAN training framework에서 intermediate mel-spectrogram loss 없이 input text로부터 raw waveform을 직접 합성하도록 training 됨
3. Experiments
- Settings
- Dataset : LJSpeech
- Comparisons : ESPNet2-TTS (CF2), VITS
- Results
- JETS는 다른 모델과 비교하여 더 높은 MOS와 정량적 metric 결과를 보임
- End-to-End 방식이 단순한 joint training이나 joint fine-tuning 보다 더 효과적이라고 볼 수 있음
반응형
'Paper > TTS' 카테고리의 다른 글
댓글