上一講薈薈講了一下張量切片、逐元素運(yùn)算、廣播,這一講再說說張量點(diǎn)積和張量變形,集齊這五種裝備,基本可以開始愉快地玩耍了。 張量點(diǎn)積前面我們說過逐元素運(yùn)算,要將兩個(gè)張量進(jìn)行逐元素相乘,我們使用“*“運(yùn)算符,比如將兩個(gè)矩陣A和B進(jìn)行逐元素相乘 而在神經(jīng)網(wǎng)絡(luò)中,經(jīng)常需要類似圖1的乘法方式,我們將輸入向量和權(quán)值向量的對(duì)應(yīng)元素進(jìn)行相乘后再求和,這種方式我們稱之為張量點(diǎn)積。 圖1 如圖2所示,如果有多個(gè)神經(jīng)元需要輸出,則可以把輸入X寫成一個(gè)列向量,它的轉(zhuǎn)置為 輸入權(quán)值矩陣可寫為 該矩陣的每一列,代表輸入到中間層某一個(gè)神經(jīng)元的權(quán)值向量,如下圖中輸入到z2的權(quán)值正是上述矩陣的第二列。從圖中可以看出,我們得到的中間層神經(jīng)元的輸出(這里為了簡單起見省略了非線性激活函數(shù),如ReLU)也是一個(gè)列向量,如下: 其中的第n個(gè)元素正是輸入的每個(gè)元素和W的第n列對(duì)應(yīng)元素一一相乘并相加的結(jié)果。如果向更高維度擴(kuò)展,輸入不是一個(gè)向量,而是一個(gè)矩陣,比如我們將100個(gè)輸入樣本向量輸入剛才討論的神經(jīng)網(wǎng)絡(luò),那么XT的形狀就是(100,3),ZT的形狀就是(100,4)。 圖2 看到這里,學(xué)過線性代數(shù)的童鞋可能已經(jīng)看出,這不就是我們學(xué)過的矩陣乘法嘛,沒錯(cuò),這里說的張量點(diǎn)積,如果是2d張量(矩陣),就是我們最常見的矩陣乘法,語法如下。 只不過它還可以繼續(xù)拓展,比如要在神經(jīng)網(wǎng)絡(luò)里處理視頻數(shù)據(jù),形狀通常是(幀,長,寬,顏色),如果需要進(jìn)行點(diǎn)積(對(duì)應(yīng)元素相乘相加),這時(shí)候如果有一個(gè)底層優(yōu)化過的張量點(diǎn)積運(yùn)算,我們就可以保持其他維度不變,直接用一次dot把所有幀和顏色的張量都進(jìn)行相同的點(diǎn)積運(yùn)算,而按照傳統(tǒng)方法我們則需要進(jìn)行多重循環(huán)。 在Numpy中,如果A為M維張量,B為N為張量,兩個(gè)張量的點(diǎn)積就是將A張量的最后一個(gè)軸中的所有元素,與B張量中倒數(shù)第二個(gè)軸的所有元素對(duì)應(yīng)相乘后相加的結(jié)果,也就是dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])。有點(diǎn)暈菜對(duì)不對(duì),下面用矩陣來解釋一下,如圖3所示,矩陣x的第一個(gè)軸是行,第二個(gè)軸是列,矩陣一共就兩個(gè)軸,所以x的最后一個(gè)軸就是列,同理y的倒數(shù)第二個(gè)軸是行??梢娋仃噚和y的點(diǎn)積就是將x第i行中的所有列(x的最后一個(gè)軸)和y第j列中的所有行(y的倒數(shù)第二個(gè)軸)對(duì)應(yīng)相乘并相加的結(jié)果返回給z(i,j)。從這個(gè)例子中也可以發(fā)現(xiàn),這就要求x的列數(shù)必須等于y的行數(shù),也就是A張量最后一個(gè)軸的元素?cái)?shù)量必須等于B張量倒數(shù)第二個(gè)軸中的元素?cái)?shù)量。 圖3 那么張量點(diǎn)積后的形狀和輸入張量形狀的關(guān)系又是什么呢,以一個(gè)例子來說明一下,如果A的形狀為(a,b,c,d),B的形狀為(d,e),那么numpy.dot(A,B)的形狀為(a,b,c,e)。如果A的形狀為(a,b,c,d), B的形狀為(c,d),那么就呵呵了,因?yàn)锳的最后一個(gè)軸元素?cái)?shù)量為d,而B的倒數(shù)第二個(gè)軸元素?cái)?shù)量為c,不匹配,會(huì)報(bào)錯(cuò)。 下面給出了一個(gè)矩陣乘法的實(shí)際案例,小伙伴們?nèi)プ屑?xì)研究一下吧。 此外,Numpy還提供了一種更為靈活的張量點(diǎn)乘函數(shù)numpy.tensordot(a,b,axes),可以指定需要和并的軸,但是不太容易理解,這里就不寫出來混淆試聽了,有興趣的朋友們可以看看Chenxiao Ma的博客,寫的還挺清楚的https://www./blog/tensordot 對(duì)于初學(xué)者來說知道numpy.dot就是求矩陣乘法也基本夠用了。 張量變形
張量變形其實(shí)咱們?cè)诘诙v中已經(jīng)見過,當(dāng)時(shí)我們需要把一個(gè)形狀為(60000,28,28)的手寫字符圖像張量輸入一個(gè)中間層為784個(gè)神經(jīng)元的密集連接型神經(jīng)網(wǎng)絡(luò)。所以需要將這個(gè)張量變形為(60000,28*28)的形狀,才好和后面的全連接神經(jīng)網(wǎng)絡(luò)進(jìn)行加權(quán)點(diǎn)乘。當(dāng)時(shí)我們的預(yù)處理代碼是: train_images =train_images.reshape((60000, 28 * 28)) reshape函數(shù)很簡單,輸入?yún)?shù)就是你想要的新形狀,只要新張量中的元素個(gè)數(shù)和原始張量中的元素個(gè)數(shù)是一樣的即可。元素的安排順序默認(rèn)和c語言是一致的,即依次進(jìn)行排列。下面兩個(gè)例子可以讓你清楚地理解這一操作。
另一種常見的張量變形就是轉(zhuǎn)置了,即把矩陣的行換成列,列換成行,也就是x[i,:]變?yōu)閤[:,i],語法為np.transpose(x)。
|
|