Pytorch從0實現(xiàn)Transformer的實踐_第1頁
Pytorch從0實現(xiàn)Transformer的實踐_第2頁
Pytorch從0實現(xiàn)Transformer的實踐_第3頁
Pytorch從0實現(xiàn)Transformer的實踐_第4頁
Pytorch從0實現(xiàn)Transformer的實踐_第5頁
已閱讀5頁,還剩2頁未讀, 繼續(xù)免費閱讀

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)

文檔簡介

第Pytorch從0實現(xiàn)Transformer的實踐目錄摘要一、構(gòu)造數(shù)據(jù)1.1句子長度1.2生成句子1.3生成字典1.4得到向量化的句子二、位置編碼2.1計算括號內(nèi)的值2.2得到位置編碼三、多頭注意力3.1selfmask

摘要

Withthecontinuousdevelopmentoftimeseriesprediction,Transformer-likemodelshavegraduallyreplacedtraditionalmodelsinthefieldsofCVandNLPbyvirtueoftheirpowerfuladvantages.Amongthem,theInformerisfarsuperiortothetraditionalRNNmodelinlong-termprediction,andtheSwinTransformerissignificantlystrongerthanthetraditionalCNNmodelinimagerecognition.AdeepgraspofTransformerhasbecomeaninevitablerequirementinthefieldofartificialintelligence.ThisarticlewillusethePytorchframeworktoimplementthepositionencoding,multi-headattentionmechanism,self-mask,causalmaskandotherfunctionsinTransformer,andbuildaTransformernetworkfrom0.

隨著時序預(yù)測的不斷發(fā)展,Transformer類模型憑借強大的優(yōu)勢,在CV、NLP領(lǐng)域逐漸取代傳統(tǒng)模型。其中Informer在長時序預(yù)測上遠超傳統(tǒng)的RNN模型,SwinTransformer在圖像識別上明顯強于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領(lǐng)域的必然要求。本文將用Pytorch框架,實現(xiàn)Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網(wǎng)絡(luò)。

一、構(gòu)造數(shù)據(jù)

1.1句子長度

#關(guān)于wordembedding,以序列建模為例

#輸入句子有兩個,第一個長度為2,第二個長度為4

src_len=torch.tensor([2,4]).to(32)

#目標句子有兩個。第一個長度為4,第二個長度為3

tgt_len=torch.tensor([4,3]).to(32)

print(src_len)

print(tgt_len)

輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4

目標句子(tgt_len)有兩個。第一個長度為4,第二個長度為3

1.2生成句子

用隨機數(shù)生成句子,用0填充空白位置,保持所有句子長度一致

src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])

tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])

print(src_seq)

print(tgt_seq)

src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。

為什么句子是數(shù)字?在做中英文翻譯時,每個中文或英文對應(yīng)的也是一個數(shù)字,只有這樣才便于處理。

1.3生成字典

在該字典中,總共有8個字(行),每個字對應(yīng)8維向量(做了簡化了的)。注意在實際應(yīng)用中,應(yīng)當有幾十萬個字,每個字可能有512個維度。

#構(gòu)造wordembedding

src_embedding_table=nn.Embedding(9,model_dim)

tgt_embedding_table=nn.Embedding(9,model_dim)

#輸入單詞的字典

print(src_embedding_table)

#目標單詞的字典

print(tgt_embedding_table)

字典中,需要留一個維度給classtoken,故是9行。

1.4得到向量化的句子

通過字典取出1.2中得到的句子

#得到向量化的句子

src_embedding=src_embedding_table(src_seq)

tgt_embedding=tgt_embedding_table(tgt_seq)

print(src_embedding)

print(tgt_embedding)

該階段總程序

importtorch

#句子長度

src_len=torch.tensor([2,4]).to(32)

tgt_len=torch.tensor([4,3]).to(32)

#構(gòu)造句子,用0填充空白處

src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])

tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])

#構(gòu)造字典

src_embedding_table=nn.Embedding(9,8)

tgt_embedding_table=nn.Embedding(9,8)

#得到向量化的句子

src_embedding=src_embedding_table(src_seq)

tgt_embedding=tgt_embedding_table(tgt_seq)

print(src_embedding)

print(tgt_embedding)

二、位置編碼

位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統(tǒng)RNN的時序信息,增強了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)

2.1計算括號內(nèi)的值

#得到分子pos的值

pos_mat=torch.arange(4).reshape((-1,1))

#得到分母值

i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((1,-1))/8)

print(pos_mat)

print(i_mat)

2.2得到位置編碼

#初始化位置編碼矩陣

pe_embedding_table=torch.zeros(4,8)

#得到偶數(shù)行位置編碼

pe_embedding_table[:,0::2]=torch.sin(pos_mat/i_mat)

#得到奇數(shù)行位置編碼

pe_embedding_table[:,1::2]=torch.cos(pos_mat/i_mat)

pe_embedding=nn.Embedding(4,8)

#設(shè)置位置編碼不可更新參數(shù)

pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False)

print(pe_embedding.weight)

三、多頭注意力

3.1selfmask

有些位置是空白用0填充的,訓(xùn)練時不希望被這些位置所影響,那么就需要用到selfmask。selfmask的原理是令這些位置的值為無窮小,經(jīng)過softmax后,這些值會變?yōu)?,不會再影響結(jié)果。

3.1.1得到有效位置矩陣

#得到有效位置矩陣

vaild_encoder_pos=torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0)forLinsrc_len]),2)

valid_encoder_pos_matrix=torch.bmm(vaild_encoder_pos,vaild_encoder_pos.transpose(1,2))

print(valid_encoder_pos_matrix)

3.1.2得到無效位置矩陣

invalid_encoder_pos_matrix=1-valid_encoder_pos_matrix

mask_encoder_self_attention=invalid_encoder_pos_matrix.to(torch.bool)

print(mask_encoder_self_attention)

True代表需要對該位置mask

3.1.3得到mask矩陣

用極小數(shù)填充需要被mask的位置

#初始化mask矩陣

score=torch.randn(2,max(

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責。
  • 6. 下載文件中如有侵權(quán)或不適當內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論