文獻(xiàn)標(biāo)識(shí)碼: A
DOI:10.16157/j.issn.0258-7998.181958
中文引用格式: 馬治楠,韓云杰,彭琳鈺,等. 基于深層卷積神經(jīng)網(wǎng)絡(luò)的剪枝優(yōu)化[J].電子技術(shù)應(yīng)用,2018,44(12):119-122,126.
英文引用格式: Ma Zhinan,Han Yunjie,Peng Linyu,et al. Pruning optimization based on deep convolution neural network[J]. Application of Electronic Technique,2018,44(12):119-122,126.
0 引言
深度學(xué)習(xí)起源于人工神經(jīng)網(wǎng)絡(luò),后來(lái)LECUN Y提出了卷積神經(jīng)網(wǎng)絡(luò)LeNet-5[1],用于手寫(xiě)數(shù)字識(shí)別,并取得了較好的成績(jī),但當(dāng)時(shí)并沒(méi)有引起人們足夠的注意。隨后BP算法被指出梯度消失的問(wèn)題,當(dāng)網(wǎng)絡(luò)反向傳播時(shí),誤差梯度傳遞到前面的網(wǎng)絡(luò)層基本接近于0,導(dǎo)致無(wú)法進(jìn)行有效的學(xué)習(xí)。2006年HINTON G E提出多隱層的網(wǎng)絡(luò)可以通過(guò)逐層預(yù)訓(xùn)練來(lái)克服深層神經(jīng)網(wǎng)絡(luò)在訓(xùn)練上的困難[2],隨后深度學(xué)習(xí)迎來(lái)了高速發(fā)展期。一些新型的網(wǎng)絡(luò)結(jié)構(gòu)不斷被提出(如AlexNet、VGGNet、GoogleNet、ResNet等),網(wǎng)絡(luò)結(jié)構(gòu)不斷被優(yōu)化,性能不斷提升,用于圖像識(shí)別可以達(dá)到很好的效果。然而這些網(wǎng)絡(luò)大都具有更多的網(wǎng)絡(luò)層,對(duì)計(jì)算機(jī)處理圖像的能力要求很高,需要更多的計(jì)算資源,一般使用較好的GPU來(lái)提高訓(xùn)練速度,不利于在硬件資源(內(nèi)存、處理器、存儲(chǔ))較低的設(shè)備運(yùn)行,具有局限性。
深度學(xué)習(xí)發(fā)展到目前階段,其研究大體可以分為兩個(gè)方向:(1)設(shè)計(jì)復(fù)雜的網(wǎng)絡(luò)結(jié)構(gòu),提高網(wǎng)絡(luò)的性能;(2)對(duì)網(wǎng)絡(luò)模型進(jìn)行壓縮,減少計(jì)算復(fù)雜度。在本文將討論第二種情況,去除模型中冗余的參數(shù),減少計(jì)算量,提高程序運(yùn)行速度。
目前很多網(wǎng)絡(luò)都具有更復(fù)雜的架構(gòu)設(shè)計(jì),這就造成網(wǎng)絡(luò)模型中存在很多的參數(shù)冗余,增加了計(jì)算復(fù)雜度,造成不必要的計(jì)算資源浪費(fèi)。模型壓縮大體有以下幾個(gè)研究方向:(1)設(shè)計(jì)更為精細(xì)的網(wǎng)絡(luò)結(jié)構(gòu),讓網(wǎng)絡(luò)的性能更為簡(jiǎn)潔高效,如MobileNet網(wǎng)絡(luò)[3];(2)對(duì)模型進(jìn)行裁剪,越是結(jié)構(gòu)復(fù)雜的網(wǎng)絡(luò)越存在大量參數(shù)冗余,因此可以尋找一種有效的評(píng)判方法,對(duì)訓(xùn)練好的模型進(jìn)行裁剪;(3)為了保持?jǐn)?shù)據(jù)的精度,一般常見(jiàn)的網(wǎng)絡(luò)模型的權(quán)重,通常將其保存為32 bit長(zhǎng)度的浮點(diǎn)類(lèi)型,這就大大增加了數(shù)據(jù)的存儲(chǔ)和計(jì)算復(fù)雜度。因此,可以將數(shù)據(jù)進(jìn)行量化,或者對(duì)數(shù)據(jù)二值化,通過(guò)數(shù)據(jù)的量化或二值化從而大大降低數(shù)據(jù)的存儲(chǔ)。除此之外,還可以對(duì)卷積核進(jìn)行核的稀疏化,將卷積核的一部分誘導(dǎo)為0,從而減少計(jì)算量[4]。
本文著重討論第二種方法,對(duì)模型的剪枝,通過(guò)對(duì)無(wú)用權(quán)重參數(shù)的裁剪,減少計(jì)算量。
1 CNN卷積神經(jīng)網(wǎng)絡(luò)
卷積神經(jīng)網(wǎng)絡(luò)是一種前饋式網(wǎng)絡(luò),網(wǎng)絡(luò)結(jié)構(gòu)由卷積層、池化層、全連接層組成[5]。卷積層的作用是從輸入層提取特征圖,給定訓(xùn)練集:
在卷積層后面一般會(huì)加一個(gè)池化層,池化又稱(chēng)為降采樣,池化層可以用來(lái)降低輸入矩陣的緯度,而保存顯著的特征,池化分為最大池化和平均池化,最大池化即給出相鄰矩陣區(qū)域的最大值。池化層具有減小網(wǎng)絡(luò)規(guī)模和參數(shù)冗余的作用。
2 卷積神經(jīng)網(wǎng)絡(luò)剪枝
2.1 模型壓縮的方法
本文用以下方法修剪模型:(1)首先使用遷移學(xué)習(xí)的方法對(duì)網(wǎng)絡(luò)訓(xùn)練,然后對(duì)網(wǎng)絡(luò)進(jìn)行微調(diào),使網(wǎng)絡(luò)收斂并達(dá)到最優(yōu),保存模型;(2)對(duì)保存的模型進(jìn)行修剪,并再次訓(xùn)練,對(duì)修剪后的模型參數(shù)通過(guò)訓(xùn)練進(jìn)行微調(diào),如此反復(fù)進(jìn)行,直到檢測(cè)不到可供裁剪的卷積核;(3)對(duì)上一步裁剪后的模型再次訓(xùn)練,直到訓(xùn)練的次數(shù)達(dá)到設(shè)定的標(biāo)準(zhǔn)為止。具體的流程如圖2所示。
上述的處理流程比較簡(jiǎn)單,重點(diǎn)是如何評(píng)判網(wǎng)絡(luò)模型中神經(jīng)元的重要性。本文用價(jià)值函數(shù)C(W)作為評(píng)判重要性的工具。對(duì)于數(shù)據(jù)集D,經(jīng)訓(xùn)練后得到網(wǎng)絡(luò)模型Model,其中的權(quán)重參數(shù)為:
2.2 參數(shù)評(píng)估
網(wǎng)絡(luò)參數(shù)的評(píng)估在模型壓縮中有著非常重要的作用。一般采用下面的這種方法,通過(guò)比較權(quán)重參數(shù)的l2范數(shù)的大小,刪除l2范數(shù)較小的卷積核[8]。除此之外,還可以通過(guò)激活驗(yàn)證的方法對(duì)參數(shù)進(jìn)行評(píng)判,將數(shù)據(jù)集通過(guò)網(wǎng)絡(luò)前向傳播,對(duì)于某個(gè)網(wǎng)絡(luò)節(jié)點(diǎn),若有大量通過(guò)激活函數(shù)后的數(shù)值為0或者小于一定的閾值,則將其舍去。
2.2.1 最小化l2范數(shù)
3 實(shí)驗(yàn)結(jié)果
3.1 訓(xùn)練和剪枝結(jié)果
本設(shè)計(jì)在Ubuntu16.04系統(tǒng),搭載1080Ti顯卡的高性能服務(wù)器上進(jìn)行實(shí)驗(yàn),使用Pytorch深度學(xué)習(xí)框架進(jìn)行訓(xùn)練和測(cè)試。本設(shè)計(jì)使用VGG16網(wǎng)絡(luò),對(duì)16類(lèi)常見(jiàn)的路面障礙物圖片進(jìn)行訓(xùn)練,其中數(shù)據(jù)集中的訓(xùn)練集有24 000張圖片,訓(xùn)練集12 000張圖片。在VGG16網(wǎng)絡(luò)中有16個(gè)卷積網(wǎng)絡(luò)層,共4 224個(gè)卷積核。采用遷移學(xué)習(xí)的方法對(duì)其進(jìn)行訓(xùn)練,設(shè)置epoch為30,訓(xùn)練的結(jié)果如圖3所示。
圖3縱軸表示訓(xùn)練的準(zhǔn)確率,橫軸表示迭代次數(shù),最后的訓(xùn)練準(zhǔn)確率為97.97%。
將上面的訓(xùn)練參數(shù)保存為模型,對(duì)其進(jìn)行剪枝,分5次對(duì)其修剪,首先會(huì)根據(jù)l2范數(shù)最小值篩選出要修剪的網(wǎng)絡(luò)層中的卷積核,每次去除512個(gè)卷積核,修剪后模型中剩余的卷積核數(shù)量如圖4所示。
圖4中縱軸表示模型中保留的卷積核的數(shù)量,從最初的4 224降到1 664,裁剪率達(dá)到60.6%。5次迭代修剪后的準(zhǔn)確率如圖5所示。
對(duì)修剪后的網(wǎng)絡(luò)重新訓(xùn)練得到最終的修剪模型,訓(xùn)練過(guò)程如圖6所示。
最后達(dá)到98.7%的準(zhǔn)確率。剪枝前模型大小為512 MB,剪枝后模型可以縮小到162 MB,將模型的內(nèi)存占用降低了68.35%。
3.2 嵌入式平臺(tái)下的移植測(cè)試
在嵌入式平臺(tái)樹(shù)莓派3代B型上移植Pytorch框架,樹(shù)莓派3b擁有1.2 GHz的四核BCM2837 64位ARM A53處理器,1 GB運(yùn)行內(nèi)存,板載BCM43143WiFi。由于樹(shù)莓派運(yùn)行內(nèi)存有限,故增加2 GB的swap虛擬內(nèi)存,用于編譯Pytorch框架源碼。將在GPU服務(wù)器上訓(xùn)練好的網(wǎng)絡(luò)模型移植到嵌入式平臺(tái),對(duì)其進(jìn)行測(cè)試。對(duì)123張測(cè)試圖片進(jìn)行檢測(cè)分類(lèi),載入裁剪前的原始模型,用時(shí)109.47 s,準(zhǔn)確率為95.08%。載入剪枝后的模型,同樣對(duì)123張圖片進(jìn)行測(cè)試,用時(shí)41.85 s,準(zhǔn)確率達(dá)到96.72%。結(jié)果如圖7所示,可以看到對(duì)模型裁剪后時(shí)間上減少了61%,速度有了很大提升。
4 結(jié)論
目前深度學(xué)習(xí)是一個(gè)熱門(mén)的研究方向,在圖像檢測(cè)、分類(lèi)、語(yǔ)音識(shí)別等方面取得了前所未有的成功,但這些依賴(lài)于高性能高配置的計(jì)算機(jī),也出現(xiàn)了各種深度學(xué)習(xí)框架以及網(wǎng)絡(luò)模型,但是可以預(yù)見(jiàn)深度學(xué)習(xí)即將邁入一個(gè)發(fā)展平緩期,如果不能有一個(gè)寬闊的應(yīng)用領(lǐng)域,深度學(xué)習(xí)的發(fā)展將很快被擱淺。誠(chéng)然,將其應(yīng)用于嵌入式平臺(tái)將會(huì)是一個(gè)非常好的發(fā)展方向。相信未來(lái)深度學(xué)習(xí)在嵌入式領(lǐng)域會(huì)有一個(gè)更大的突破,部署于移動(dòng)平臺(tái)將不再是一個(gè)難題。
參考文獻(xiàn)
[1] LECUN Y,BOTTOU L,BENGIO Y,et al.Gradient-based learning applied to document recognition[C].Proceedings of the IEEE,1998,86(11):2278-2324.
[2] HINTON G E,SALAKHUTDINOV R R.Reducing the dimensionality of data with neural networks[J].Science,2006,313(5786):504-507.
[3] HOWARD A G,ZHU M,CHEN B,et al.MobileNets:efficient convolutional neural networks for mobile vision applications[Z].arXiv preprint arXiv:1704.04861,2017.
[4] HAN S,MAO H,DALLY W J.Deep compression:compressing deep neural networks with pruning, trained quantization and Huffman coding[J].Fiber,2015,56(4):3-7.
[5] 周飛燕,金林鵬,董軍.卷積神經(jīng)網(wǎng)絡(luò)研究綜述[J].計(jì)算機(jī)學(xué)報(bào),2017,40(6):1229-1251.
[6] ANWAR S,HWANG K,SUNG W,et al.Structured pruning of deep convolutional neural networks[J].JETC,2017,13(3):1-18.
[7] AYINDE B O,ZURADA J M.Building efficient ConvNets using redundant feature pruning[Z].arXiv preprint arXiv:1802.07653,2018.
[8] LI H,KADAV A,DURDANOVIC I,et al.Pruning filters for efficient ConvNets[C].ICLR 2017,2017.
作者信息:
馬治楠1,韓云杰2,彭琳鈺1,周進(jìn)凡1,林付春1,劉宇紅1
(1.貴州大學(xué) 大數(shù)據(jù)與信息工程學(xué)院,貴州 貴陽(yáng)550025;2.貴陽(yáng)信息技術(shù)研究院,貴州 貴陽(yáng)550081)