pytorch分類模型繪制混淆矩陣以及可視化詳解_第1頁
pytorch分類模型繪制混淆矩陣以及可視化詳解_第2頁
pytorch分類模型繪制混淆矩陣以及可視化詳解_第3頁
pytorch分類模型繪制混淆矩陣以及可視化詳解_第4頁
全文預(yù)覽已結(jié)束

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)

文檔簡(jiǎn)介

第pytorch分類模型繪制混淆矩陣以及可視化詳解目錄Step1.獲取混淆矩陣Step2.混淆矩陣可視化其它分類指標(biāo)的獲取總結(jié)

Step1.獲取混淆矩陣

#首先定義一個(gè)分類數(shù)*分類數(shù)的空混淆矩陣

conf_matrix=torch.zeros(Emotion_kinds,Emotion_kinds)

#使用torch.no_grad()可以顯著降低測(cè)試用例的GPU占用

withtorch.no_grad():

forstep,(imgs,targets)inenumerate(test_loader):

#imgs:torch.Size([50,3,200,200])torch.FloatTensor

#targets:torch.Size([50,1]),torch.LongTensor多了一維,所以我們要把其去掉

targets=targets.squeeze()#[50,1]-----[50]

#將變量轉(zhuǎn)為gpu

targets=targets.cuda()

imgs=imgs.cuda()

#print(step,imgs.shape,imgs.type(),targets.shape,targets.type())

out=model(imgs)

#記錄混淆矩陣參數(shù)

conf_matrix=confusion_matrix(out,targets,conf_matrix)

conf_matrix=conf_matrix.cpu()

混淆矩陣的求取用到了confusion_matrix函數(shù),其定義如下:

defconfusion_matrix(preds,labels,conf_matrix):

preds=torch.argmax(preds,1)

forp,tinzip(preds,labels):

conf_matrix[p,t]+=1

returnconf_matrix

在當(dāng)我們的程序執(zhí)行結(jié)束test_loader后,我們可以得到本次數(shù)據(jù)的混淆矩陣,接下來就要計(jì)算其識(shí)別正確的個(gè)數(shù)以及混淆矩陣可視化:

conf_matrix=np.array(conf_matrix.cpu())#將混淆矩陣從gpu轉(zhuǎn)到cpu再轉(zhuǎn)到np

corrects=conf_matrix.diagonal(offset=0)#抽取對(duì)角線的每種分類的識(shí)別正確個(gè)數(shù)

per_kinds=conf_matrix.sum(axis=1)#抽取每個(gè)分類數(shù)據(jù)總的測(cè)試條數(shù)

print("混淆矩陣總元素個(gè)數(shù):{0},測(cè)試集總個(gè)數(shù):{1}".format(int(np.sum(conf_matrix)),test_num))

print(conf_matrix)

#獲取每種Emotion的識(shí)別準(zhǔn)確率

print("每種情感總個(gè)數(shù):",per_kinds)

print("每種情感預(yù)測(cè)正確的個(gè)數(shù):",corrects)

print("每種情感的識(shí)別準(zhǔn)確率為:{0}".format([rate*100forrateincorrects/per_kinds]))

執(zhí)行此步的輸出結(jié)果如下所示:

Step2.混淆矩陣可視化

對(duì)上邊求得的混淆矩陣可視化

#繪制混淆矩陣

Emotion=8#這個(gè)數(shù)值是具體的分類數(shù),大家可以自行修改

labels=['neutral','calm','happy','sad','angry','fearful','disgust','surprised']#每種類別的標(biāo)簽

#顯示數(shù)據(jù)

plt.imshow(conf_matrix,cmap=plt.cm.Blues)

#在圖中標(biāo)注數(shù)量/概率信息

thresh=conf_matrix.max()/2#數(shù)值顏色閾值,如果數(shù)值超過這個(gè),就顏色加深。

forxinrange(Emotion_kinds):

foryinrange(Emotion_kinds):

#注意這里的matrix[y,x]不是matrix[x,y]

info=int(conf_matrix[y,x])

plt.text(x,y,info,

verticalalignment='center',

horizontalalignment='center',

color="white"ifinfothreshelse"black")

plt.tight_layout()#保證圖不重疊

plt.yticks(range(Emotion_kinds),labels)

plt.xticks(range(Emotion_kinds),labels,rotation=45)#X軸字體傾斜45°

plt.show()

plt.clo

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。

評(píng)論

0/150

提交評(píng)論