




版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報(bào)或認(rèn)領(lǐng)
文檔簡介
第PyTorch搭建雙向LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測目錄I.前言II.原理InputsOutputsbatch_first輸出提取III.訓(xùn)練和預(yù)測IV.源碼及數(shù)據(jù)
I.前言
前面幾篇文章中介紹的都是單向LSTM,這篇文章講一下雙向LSTM。
系列文章:
PyTorch搭建LSTM實(shí)現(xiàn)多變量多步長時(shí)序負(fù)荷預(yù)測
PyTorch搭建LSTM實(shí)現(xiàn)多變量時(shí)序負(fù)荷預(yù)測
PyTorch深度學(xué)習(xí)LSTM從input輸入到Linear輸出
PyTorch搭建LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測
II.原理
關(guān)于LSTM的輸入輸出在深入理解PyTorch中LSTM的輸入和輸出(從input輸入到Linear輸出)中已經(jīng)有過詳細(xì)敘述。
關(guān)于nn.LSTM的參數(shù),官方文檔給出的解釋為:
總共有七個(gè)參數(shù),其中只有前三個(gè)是必須的。由于大家普遍使用PyTorch的DataLoader來形成批量數(shù)據(jù),因此batch_first也比較重要。LSTM的兩個(gè)常見的應(yīng)用場景為文本處理和時(shí)序預(yù)測,因此下面對(duì)每個(gè)參數(shù)我都會(huì)從這兩個(gè)方面來進(jìn)行具體解釋。
input_size:在文本處理中,由于一個(gè)單詞沒法參與運(yùn)算,因此我們得通過Word2Vec來對(duì)單詞進(jìn)行嵌入表示,將每一個(gè)單詞表示成一個(gè)向量,此時(shí)input_size=embedding_size。比如每個(gè)句子中有五個(gè)單詞,每個(gè)單詞用一個(gè)100維向量來表示,那么這里input_size=100;在時(shí)間序列預(yù)測中,比如需要預(yù)測負(fù)荷,每一個(gè)負(fù)荷都是一個(gè)單獨(dú)的值,都可以直接參與運(yùn)算,因此并不需要將每一個(gè)負(fù)荷表示成一個(gè)向量,此時(shí)input_size=1。但如果我們使用多變量進(jìn)行預(yù)測,比如我們利用前24小時(shí)每一時(shí)刻的[負(fù)荷、風(fēng)速、溫度、壓強(qiáng)、濕度、天氣、節(jié)假日信息]來預(yù)測下一時(shí)刻的負(fù)荷,那么此時(shí)input_size=7。hidden_size:隱藏層節(jié)點(diǎn)個(gè)數(shù)??梢噪S意設(shè)置。num_layers:層數(shù)。nn.LSTMCell與nn.LSTM相比,num_layers默認(rèn)為1。batch_first:默認(rèn)為False,意義見后文。
Inputs
關(guān)于LSTM的輸入,官方文檔給出的定義為:
可以看到,輸入由兩部分組成:input、(初始的隱狀態(tài)h_0,初始的單元狀態(tài)c_0)?
其中input:
input(seq_len,batch_size,input_size)
seq_len:在文本處理中,如果一個(gè)句子有7個(gè)單詞,則seq_len=7;在時(shí)間序列預(yù)測中,假設(shè)我們用前24個(gè)小時(shí)的負(fù)荷來預(yù)測下一時(shí)刻負(fù)荷,則seq_len=24。batch_size:一次性輸入LSTM中的樣本個(gè)數(shù)。在文本處理中,可以一次性輸入很多個(gè)句子;在時(shí)間序列預(yù)測中,也可以一次性輸入很多條數(shù)據(jù)。input_size:見前文。
(h_0,c_0):
h_0(num_directions*num_layers,batch_size,hidden_size)
c_0(num_directions*num_layers,batch_size,hidden_size)
h_0和c_0的shape一致。
num_directions:如果是雙向LSTM,則num_directions=2;否則num_directions=1。num_layers:見前文。batch_size:見前文。hidden_size:見前文。
Outputs
關(guān)于LSTM的輸出,官方文檔給出的定義為:
可以看到,輸出也由兩部分組成:otput、(隱狀態(tài)h_n,單元狀態(tài)c_n)
其中output的shape為:
output(seq_len,batch_size,num_directions*hidden_size)
h_n和c_n的shape保持不變,參數(shù)解釋見前文。
batch_first
如果在初始化LSTM時(shí)令batch_first=True,那么input和output的shape將由:
input(seq_len,batch_size,input_size)
output(seq_len,batch_size,num_directions*hidden_size)
變?yōu)椋?/p>
input(batch_size,seq_len,input_size)
output(batch_size,seq_len,num_directions*hidden_size)
即batch_size提前。
輸出提取
假設(shè)最后我們得到了output(batch_size,seq_len,2*hidden_size),我們需要將其輸入到線性層,有以下兩種方法可以參考:
(1)直接輸入
和單向一樣,我們可以將output直接輸入到Linear。在單向LSTM中:
self.linear=nn.Linear(self.hidden_size,self.output_size)
而在雙向LSTM中:
self.linear=nn.Linear(2*self.hidden_size,self.output_size)
模型:
classBiLSTM(nn.Module):
def__init__(self,input_size,hidden_size,num_layers,output_size,batch_size):
super().__init__()
self.input_size=input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.output_size=output_size
self.num_directions=2
self.batch_size=batch_size
self.lstm=nn.LSTM(self.input_size,self.hidden_size,self.num_layers,batch_first=True,bidirectional=True)
self.linear=nn.Linear(self.num_directions*self.hidden_size,self.output_size)
defforward(self,input_seq):
h_0=torch.randn(self.num_directions*self.num_layers,self.batch_size,self.hidden_size).to(device)
c_0=torch.randn(self.num_directions*self.num_layers,self.batch_size,self.hidden_size).to(device)
#print(input_seq.size())
seq_len=input_seq.shape[1]
#input(batch_size,seq_len,input_size)
input_seq=input_seq.view(self.batch_size,seq_len,self.input_size)
#output(batch_size,seq_len,num_directions*hidden_size)
output,_=self.lstm(input_seq,(h_0,c_0))
#print(self.batch_size*seq_len,self.hidden_size)
output=output.contiguous().view(self.batch_size*seq_len,self.num_directions*self.hidden_size)#(5*30,64)
pred=self.linear(output)#pred()
pred=pred.view(self.batch_size,seq_len,-1)
pred=pred[:,-1,:]
returnpred
(2)處理后再輸入
在LSTM中,經(jīng)過線性層后的output的shape為(batch_size,seq_len,output_size)。假設(shè)我們用前24個(gè)小時(shí)(1to24)預(yù)測后2個(gè)小時(shí)的負(fù)荷(25to26),那么seq_len=24,output_size=2。根據(jù)LSTM的原理,最終的輸出中包含了所有位置的預(yù)測值,也就是((23),(34),(45)(2526))。很顯然我們只需要最后一個(gè)預(yù)測值,即output[:,-1,:]。
而在雙向LSTM中,一開始o(jì)utput(batch_size,seq_len,2*hidden_size),這里面包含了所有位置的兩個(gè)方向的輸出。簡單來說,output[0]為序列從左往右第一個(gè)隱藏層狀態(tài)輸出和序列從右往左最后一個(gè)隱藏層狀態(tài)輸出的拼接;output[-1]為序列從左往右最后一個(gè)隱藏層狀態(tài)輸出和序列從右往左第一個(gè)隱藏層狀態(tài)輸出的拼接。
如果我們想要同時(shí)利用前向和后向的輸出,我們可以將它們從中間切割,然后求平均。比如output的shape為(30,24,2*64),我們將其變成(30,24,2,64),然后在dim=2上求平均,得到一個(gè)shape為(30,24,64)的輸出,此時(shí)就與單向LSTM的輸出一致了。
具體處理方法:
output=output.contiguous().view(self.batch_size,seq_len,self.num_directions,self.hidden_size)
output=torch.mean(output,dim=2)
模型代碼:
classBiLSTM(nn.Module):
def__init__(self,input_size,hidden_size,num_layers,output_size,batch_size):
super().__init__()
self.input_size=input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.output_size=output_size
self.num_directions=2
self.batch_size=batch_size
self.lstm=nn.LSTM(self.input_size,self.hidden_size,self.num_layers,batch_first=True,bidirectional=True)
self.linear=nn.Linear(self.hidden_size,self.output_size)
defforward(self,input_seq):
h_0=torch.randn(self.num_directions*self.num_layers,self.batch_size,self.hidden_size).to(device)
c_0=torch.randn(self.num_directions*self.num_layers,self.batch_size,self.hidden_size).to(device)
#print(input_seq.size())
seq_len=input_seq.shape[1]
#input(batch_size,seq_len,input_size)
input_seq=input_seq.view(self.batch_size,seq_len,self.input_size)
#output(batch_size,seq_len,num_directions*hidden_size)
output,_=self.lstm(input_seq,(h_0,c_0))
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。
最新文檔
- 車輛銷售代理合同協(xié)議
- 輪胎銷售賣方合同協(xié)議
- 未獲征地協(xié)議書
- 車位防撞欄施工合同協(xié)議
- 邁克爾杰克遜協(xié)議合同
- 鄭州企業(yè)聘任合同協(xié)議
- 消防設(shè)施安全使用注意事項(xiàng)試題及答案
- 護(hù)理基礎(chǔ)知識(shí)分析試題及答案
- 2025年中級(jí)會(huì)計(jì)試題及答案全方位分析
- 安全管理人員的外語能力測試試題及答案
- 2025年入團(tuán)考試一覽無遺試題及答案
- 公司檔案及文件管理制度
- 2025年四川筠連縣國有資本投資運(yùn)營有限公司招聘筆試參考題庫含答案解析
- 2024年貴州遵義公開招聘社區(qū)工作者考試試題答案解析
- 2025年全國低壓電工證(復(fù)審)考試筆試試題(300題)含答案
- 2025至2030中國注射用重組人腦利鈉肽行業(yè)運(yùn)行態(tài)勢及未來趨勢研究報(bào)告
- 文言常識(shí)測試題及答案
- 入團(tuán)考試測試題及答案
- 中班早期閱讀《跑跑鎮(zhèn)》課件
- 【語文試卷+答案 】上海市崇明區(qū)2025屆高三第二學(xué)期第二次模擬考試(崇明二模)
- Unit 4 第5課時(shí) B learn學(xué)習(xí)任務(wù)單
評(píng)論
0/150
提交評(píng)論