我們先來回顧一下LSTM模型: LSTM的基本結(jié)構(gòu)LSTM的關(guān)鍵在于它的“記憶單元”,能夠選擇性地記住或者忘記信息。其核心組件包括三個(gè)門和一個(gè)記憶單元: 1. 遺忘門(Forget Gate):決定應(yīng)該丟棄哪些信息。 2. 輸入門(Input Gate):決定更新哪些新的信息。 3. 輸出門(Output Gate):決定當(dāng)前狀態(tài)如何影響輸出。 
數(shù)學(xué)公式解釋遺忘門: ft=σ(Wf?[ht?1,xt]+bf) 遺忘門決定了上一時(shí)刻的狀態(tài) Ct?1中,哪些信息需要保留,哪些需要丟棄。值域?yàn)?[0, 1],1表示完全保留,0表示完全丟棄。 輸入門: it=σ(Wi?[ht?1,xt]+bi) 輸入門決定了當(dāng)前時(shí)刻的輸入 xt 需要多少加入到記憶單元。 候選記憶單元: C~t=tanh?(WC?[ht?1,xt]+bC) 這是當(dāng)前時(shí)刻的候選記憶內(nèi)容。 更新記憶單元: Ct=ft?Ct?1+it?C~t 記憶單元通過遺忘門和輸入門結(jié)合,更新當(dāng)前的記憶狀態(tài)。 輸出門: ot=σ(Wo?[ht?1,xt]+bo) 輸出門決定了當(dāng)前記憶單元 Ct 對輸出的影響。 隱藏狀態(tài)更新: ht=ot?tanh?(Ct) 隱藏狀態(tài)通過輸出門和當(dāng)前記憶單元來更新。
有同學(xué)比較疑惑LSTM中的sigmoid函數(shù)和tanh的作用是什么?下來我來為大家解惑:Sigmoid函數(shù)(遺忘門、輸入門、輸出門):用于控制信息流。由于sigmoid的輸出值在 0 和 1 之間,表示“選擇”的強(qiáng)度。0表示完全不通過,1表示完全通過。它起到類似開關(guān)的作用。 Tanh函數(shù)(候選記憶單元、隱藏狀態(tài)):用于將輸入值縮放到 -1 到 1 之間,確保信息在網(wǎng)絡(luò)中不會增長過大或過小,幫助模型捕捉數(shù)據(jù)中的正負(fù)變化。同時(shí),tanh可以引入非線性特征,增加網(wǎng)絡(luò)表達(dá)能力。
也就是說Sigmoid函數(shù)作為**“開關(guān)”**,在LSTM的各個(gè)門(遺忘門、輸入門、輸出門)中使用,決定信息流的多少。Tanh函數(shù)用于將數(shù)值范圍縮放到-1到1之間,幫助控制記憶單元的值,確保信息的平衡和穩(wěn)定性,并用于生成隱藏狀態(tài)。 下面讓我們來用一個(gè)例子來輔助大家對模型的理解: 為了詳細(xì)講解LSTM如何工作,我們通過一個(gè)具體的例子一步步剖析每個(gè)步驟。 例子:預(yù)測序列中的下一個(gè)數(shù)我們有一個(gè)簡單的序列數(shù)據(jù): 1,2,3,4,5,6,7,8,9,10 我們的目標(biāo)是訓(xùn)練一個(gè)LSTM模型,讓它能夠根據(jù)之前的數(shù)字預(yù)測下一個(gè)數(shù)字。例如,輸入[1, 2, 3]時(shí),模型應(yīng)該輸出4。 1. 輸入表示LSTM處理的是時(shí)間序列數(shù)據(jù),我們可以將每個(gè)數(shù)字視為一個(gè)時(shí)間步(time step)。對于輸入序列1,2,3,我們需要在每個(gè)時(shí)間步都輸入一個(gè)數(shù)字: 時(shí)間步1:輸入1 時(shí)間步2:輸入2 時(shí)間步3:輸入3
在每一個(gè)時(shí)間步,LSTM會使用前一步的隱藏狀態(tài)以及當(dāng)前輸入來更新它的記憶單元C和隱藏狀態(tài)h。 2. LSTM的核心機(jī)制LSTM有三個(gè)關(guān)鍵的門:遺忘門、輸入門和輸出門。這三個(gè)門控制了信息如何在LSTM單元中流動(dòng)。讓我們看一下當(dāng)我們輸入1,2,3時(shí),LSTM內(nèi)部發(fā)生了什么。 時(shí)間步1:輸入1遺忘門 遺忘門的作用是決定上一個(gè)時(shí)間步的信息要保留多少。因?yàn)檫@是第一個(gè)時(shí)間步,之前沒有信息,所以LSTM的記憶單元初始為0。假設(shè)此時(shí)遺忘門計(jì)算出的值為0.8,這意味著LSTM會保留80%的之前的記憶狀態(tài)(雖然此時(shí)沒有實(shí)際的歷史狀態(tài))。 輸入門 輸入門決定新信息要多少被寫入記憶單元。假設(shè)輸入門給出的值為0.9,意味著我們會把當(dāng)前輸入的信息的90%加入到記憶單元。 更新記憶單元 LSTM單元會計(jì)算候選記憶內(nèi)容。假設(shè)此時(shí)的候選內(nèi)容為1(通過激活函數(shù)計(jì)算得到),結(jié)合遺忘門和輸入門,更新記憶單元: C1=0.8?0+0.9?1=0.9
輸出門 輸出門決定記憶單元如何影響隱藏狀態(tài)。假設(shè)輸出門給出的值為0.7,隱藏狀態(tài)通過以下公式計(jì)算: h1=0.7?tanh?(0.9)≈0.63 這個(gè)隱藏狀態(tài)會作為下一時(shí)間步的輸入。
時(shí)間步2:輸入2遺忘門 遺忘門決定如何處理前一個(gè)時(shí)間步的記憶單元。假設(shè)遺忘門的值為0.7,意味著70%的上一步記憶將被保留。 輸入門 假設(shè)此時(shí)輸入門的值為0.8,意味著會將當(dāng)前輸入2的80%加入到記憶單元。 更新記憶單元 候選內(nèi)容通過激活函數(shù)計(jì)算,假設(shè)此時(shí)候選內(nèi)容為2。結(jié)合遺忘門和輸入門: C2=0.7?0.9+0.8?2=2.03
輸出門 假設(shè)輸出門的值為0.6,隱藏狀態(tài)為: h2=0.6?tanh?(2.03)≈0.56
時(shí)間步3:輸入3遺忘門 假設(shè)遺忘門的值為0.75,保留75%的上一個(gè)記憶。 輸入門 假設(shè)輸入門的值為0.85,意味著會將當(dāng)前輸入3的85%加入到記憶單元。 更新記憶單元 假設(shè)候選內(nèi)容為3,更新記憶單元: C3=0.75?2.03+0.85?3=3.5
輸出門 假設(shè)輸出門的值為0.65,隱藏狀態(tài)為: h3=0.65?tanh?(3.5)≈0.58
3. LSTM如何預(yù)測經(jīng)過這三個(gè)時(shí)間步,LSTM得到了一個(gè)隱藏狀態(tài)h3,代表了模型對序列1,2,3的理解。接下來,隱藏狀態(tài)會通過一個(gè)全連接層或線性層,輸出預(yù)測值。 假設(shè)線性層的輸出是4,這意味著LSTM模型根據(jù)序列1,2,3,預(yù)測接下來是4。 總結(jié)通過這些門的組合,LSTM可以有效地在長序列數(shù)據(jù)中學(xué)習(xí)到哪些信息是重要的,哪些是可以忽略的,從而解決傳統(tǒng)RNN的長依賴問題。 接下來會繼續(xù)深入講解LSTM模型,比如如果輸入一串中文,LSTM該如何處理怎么預(yù)測,比如各個(gè)門的權(quán)重如何更新,如何決定遺忘多少,輸入多少。
|