0%

Wav2Vec2 + CTC = 語音辨識? 簡單易懂的 CTC loss 介紹

語音辨識到底在做什麼,這個模型的loss會怎麼樣設計?這裡會嘗試解答這些問題>

transformers==4.2.0版本裡面,新增了一個Wav2Vec2ForCTC模型,短短的幾行code就可以做到語音辨識:

import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import soundfile as sf

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)

input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])

結果
'A MAN SAID TO THE UNIVERSE SIR I EXIST'

語音辨識到底在做什麼,這個模型的loss會怎麼樣設計?接下來會嘗試解答這些問題>

語音辨識在做甚麼?

語音辨識的輸入會是一段錄音,經過處理之後變成一張Spectrogram(時頻譜),來表示每一個時刻,不同頻段的能量。然後將Spectrogram輸入到模型中。
Spectrogram

模型接收到Spectrogram之後,每一段時刻都會做一次預測(例如100ms一個結果),然後根據預測的結果得到對應的文本。

但在模型訓練的時候,由於輸入只有錄音和錄音文本,並沒有提供某一段時刻應該輸出哪一個段文本,我們就不會知道預測的結果應該跟哪一個字算loss,模型也就不能訓練了。

為了解決長輸入對應到短輸出loss不能計算的問題,就有Connectionist Temporal Classification Loss(CTC Loss)的出現。

CTC Loss

首先,我們需要找一個機制,讓允許長的序列可以找到對應的結果。

  1. 允許重複輸出
    在例子上來看,模型在不確定哪一個時刻應該輸出g的時候,允許模型在多於一個時刻的預測輸出how。換句話說,模型不需要煩惱how輸出過了沒,更多關注按照當前時刻的特徵應該輸出甚麼就好。
  2. 合併輸出
    下一步呢就需要合併這些重複的輸出,也加入空格符號,讓詞之間可以分隔開。

模型的訓練便是讓它的輸出的結果能輸出符合這個規則的sequence
用一個例子來看這個過程,輸入一段音頻,要預測g這個字。假設模型會decode出三個狀態,每一個狀態都會給我們所有token的機率,然後我們選取最高機率的結果,得到文本以後,再按照上述文本合併輸出:

所以呢,訓練的目標就是希望模型可以輸出這些組合的其中一個,就可以輸出到正確的文本。一個直觀的方法是窮舉所有的組合,對這些組合都算loss,模型就會漸漸收斂到其中的一個結果。
在之前的例子上,一樣是會輸出三個state,然後希望decode出g, 我們修改一下機率,模擬訓練中的狀況,然後我們把所有可能的path列出來,算loss:

在Pytorch CTC Loss上驗證結果有沒有一樣:

import torch

input = torch.log(torch.tensor([
    [[ 0.4, 0.6,]],
    [[ 0.3, 0.7,]],
    [[ 0.2, 0.8,]],
], dtype=torch.float, requires_grad=True))
target = torch.tensor([[1]])
input_lengths = torch.tensor([3])
target_lengths = torch.tensor([1])

ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)
# tensor(0.1839, grad_fn=<MeanBackward0>)

不過呢,這個窮舉的方法還有一個缺陷,組合的數量會隨著輸入的長度增加而呈指數增長,效率太低,難以在大量的資料上訓練。為了增加訓練的效率,其實並不需要將所有結果窮舉出來,token與token之間移動的方向和範圍其實都是有限制的,我們可能根據他們的限制高效找尋所有合理的結果。

因此接下來會改為用Dynamic Program的方式得出最終的loss。它的主要想法是將問題一層層分解,每一層的結果是根據上一層的輸出得到。

首先,我們可以將所有可能出現的結果列出來,做成一個表格,嘗試將所有可行的結果在這個表格裡面表示出來,表格裡面的值則會是當前的機率,這個機率會根據模型輸出的結果(左邊的表格)和上一個時刻的輸出得到的。

  • T1
    在T1的時候,給他們對應的機率(這裡會有兩個_,一個是表示在g之前出現,另一個則表示在g之後出現):

    開頭只會是_g,因此也只有前兩個我們需要在模型輸出那邊,找到對應的機率,而其他的機率則都會是0。

  • T2
    然後,根據T1的結果推算T2。T2裡面就有可能出現兩個_,我們用圓形表示g之前的_,用三角形表示g之後的_。根據之前窮舉的結果,(T2,_,圓形)只會從(T1,_,圓形)過來。(T2,g)的結果則可能從(T1,_,圓形)和(T1,g)過來。(T2,_,三角形)則會是(T1,g)之後的結果,然後我們將T1的結果和T2的乘對應的機率再相加。

  • T3
    也是一樣按照上面的規則,算出T3的結果。

  • 最終的機率
    T3的兩個情況都可能發生,相加取log就是對應的loss。也可以看到算出來的結果跟之前的結果是一致。

在上面的過程,會發現有跡可循的地方:

  • 開頭一定會是左上角的前兩格,結束一定會是右下方的後兩格,也就是開頭一定是_或第一個字,結束一定是最後一個字或_
  • T轉到T+1的時候只會是往右方移動,而且只會平移(同一個字)或者往下一格或者兩格走,表示移到下一個字或者空格符號(例子只有一個字所以沒有來移兩格到情況,實在抱歉)

所以按照這個規則,我們也可以從後往前把機率算出來,結果也是一樣的:

其實將 從前往後(Forward) 跟 從後往前(Backward) 並不是算好玩的,這個有助於我們推算出每一個T的時候的每一個token的loss應該是甚麼,然後反向傳播回去算loss。 例如 (T2,g) 這個token的機率就是根據Forward跟backward得到的:

以上就是CTC Loss的大致想法,解決了長sequence對應到短sequence沒有alignment不能訓練這一問題,使得語音辨識的訓練能用更大量的資料訓練,得到更加好的結果。侷限就是輸入sequence一定要長於輸出,且輸入sequence越長越不好訓練。

接下來會介紹wav2vec2的模型,以及其多語言版本XLSR~