




版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)
文檔簡介
第Pytorch從0實現(xiàn)Transformer的實踐目錄摘要一、構(gòu)造數(shù)據(jù)1.1句子長度1.2生成句子1.3生成字典1.4得到向量化的句子二、位置編碼2.1計算括號內(nèi)的值2.2得到位置編碼三、多頭注意力3.1selfmask
摘要
Withthecontinuousdevelopmentoftimeseriesprediction,Transformer-likemodelshavegraduallyreplacedtraditionalmodelsinthefieldsofCVandNLPbyvirtueoftheirpowerfuladvantages.Amongthem,theInformerisfarsuperiortothetraditionalRNNmodelinlong-termprediction,andtheSwinTransformerissignificantlystrongerthanthetraditionalCNNmodelinimagerecognition.AdeepgraspofTransformerhasbecomeaninevitablerequirementinthefieldofartificialintelligence.ThisarticlewillusethePytorchframeworktoimplementthepositionencoding,multi-headattentionmechanism,self-mask,causalmaskandotherfunctionsinTransformer,andbuildaTransformernetworkfrom0.
隨著時序預(yù)測的不斷發(fā)展,Transformer類模型憑借強大的優(yōu)勢,在CV、NLP領(lǐng)域逐漸取代傳統(tǒng)模型。其中Informer在長時序預(yù)測上遠超傳統(tǒng)的RNN模型,SwinTransformer在圖像識別上明顯強于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領(lǐng)域的必然要求。本文將用Pytorch框架,實現(xiàn)Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網(wǎng)絡(luò)。
一、構(gòu)造數(shù)據(jù)
1.1句子長度
#關(guān)于wordembedding,以序列建模為例
#輸入句子有兩個,第一個長度為2,第二個長度為4
src_len=torch.tensor([2,4]).to(32)
#目標句子有兩個。第一個長度為4,第二個長度為3
tgt_len=torch.tensor([4,3]).to(32)
print(src_len)
print(tgt_len)
輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標句子(tgt_len)有兩個。第一個長度為4,第二個長度為3
1.2生成句子
用隨機數(shù)生成句子,用0填充空白位置,保持所有句子長度一致
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
print(src_seq)
print(tgt_seq)
src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數(shù)字?在做中英文翻譯時,每個中文或英文對應(yīng)的也是一個數(shù)字,只有這樣才便于處理。
1.3生成字典
在該字典中,總共有8個字(行),每個字對應(yīng)8維向量(做了簡化了的)。注意在實際應(yīng)用中,應(yīng)當有幾十萬個字,每個字可能有512個維度。
#構(gòu)造wordembedding
src_embedding_table=nn.Embedding(9,model_dim)
tgt_embedding_table=nn.Embedding(9,model_dim)
#輸入單詞的字典
print(src_embedding_table)
#目標單詞的字典
print(tgt_embedding_table)
字典中,需要留一個維度給classtoken,故是9行。
1.4得到向量化的句子
通過字典取出1.2中得到的句子
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
該階段總程序
importtorch
#句子長度
src_len=torch.tensor([2,4]).to(32)
tgt_len=torch.tensor([4,3]).to(32)
#構(gòu)造句子,用0填充空白處
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
#構(gòu)造字典
src_embedding_table=nn.Embedding(9,8)
tgt_embedding_table=nn.Embedding(9,8)
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統(tǒng)RNN的時序信息,增強了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)
2.1計算括號內(nèi)的值
#得到分子pos的值
pos_mat=torch.arange(4).reshape((-1,1))
#得到分母值
i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((1,-1))/8)
print(pos_mat)
print(i_mat)
2.2得到位置編碼
#初始化位置編碼矩陣
pe_embedding_table=torch.zeros(4,8)
#得到偶數(shù)行位置編碼
pe_embedding_table[:,0::2]=torch.sin(pos_mat/i_mat)
#得到奇數(shù)行位置編碼
pe_embedding_table[:,1::2]=torch.cos(pos_mat/i_mat)
pe_embedding=nn.Embedding(4,8)
#設(shè)置位置編碼不可更新參數(shù)
pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)
三、多頭注意力
3.1selfmask
有些位置是空白用0填充的,訓(xùn)練時不希望被這些位置所影響,那么就需要用到selfmask。selfmask的原理是令這些位置的值為無窮小,經(jīng)過softmax后,這些值會變?yōu)?,不會再影響結(jié)果。
3.1.1得到有效位置矩陣
#得到有效位置矩陣
vaild_encoder_pos=torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0)forLinsrc_len]),2)
valid_encoder_pos_matrix=torch.bmm(vaild_encoder_pos,vaild_encoder_pos.transpose(1,2))
print(valid_encoder_pos_matrix)
3.1.2得到無效位置矩陣
invalid_encoder_pos_matrix=1-valid_encoder_pos_matrix
mask_encoder_self_attention=invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)
True代表需要對該位置mask
3.1.3得到mask矩陣
用極小數(shù)填充需要被mask的位置
#初始化mask矩陣
score=torch.randn(2,max(
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責。
- 6. 下載文件中如有侵權(quán)或不適當內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 醫(yī)療科技助力精準醫(yī)療的突破與挑戰(zhàn)
- 企業(yè)健康保險與醫(yī)療保險的結(jié)合管理
- 2025年互聯(lián)網(wǎng)個人工作總結(jié)模版
- 醫(yī)療技術(shù)轉(zhuǎn)移與商業(yè)化中的知識產(chǎn)權(quán)挑戰(zhàn)與對策
- 醫(yī)療器械行業(yè)中的項目管理挑戰(zhàn)與機遇
- 嘉善物業(yè)公司今冬明春火災(zāi)防控工作總結(jié)模版
- AI技術(shù)在商業(yè)決策分析中的應(yīng)用價值
- 絲綢加工合同范例
- 公司電腦轉(zhuǎn)讓合同范例
- 倉庫保潔合同范例
- 2025直播帶貨主播簽約合作合同(范本)
- 人事檔案管理系統(tǒng)驗收報告文檔
- 《刑事訴訟法學(xué)教學(xué)》課件
- 2025年高考物理復(fù)習(xí)之小題狂練600題(解答題):機械波(10題)
- 首都經(jīng)濟貿(mào)易大學(xué)《中級微觀經(jīng)濟學(xué)》2023-2024學(xué)年第一學(xué)期期末試卷
- 2018年高考英語全國一卷(精校+答案+聽力原文)
- 工程決算書(結(jié)算書)模板
- 零星工程維修 投標方案(技術(shù)方案)
- 統(tǒng)編版 高中語文 必修下冊 第六單元《促織》
- 2024年房屋代持協(xié)議書范本
- 2024廚房改造合同范本
評論
0/150
提交評論