本发明涉及神经网络,特别是涉及一种网络剪枝方法、装置、终端及计算机可读存储介质。
背景技术:
1、深度神经网络已应用到生活和工业的各个领域,在使用深度神经网络模型时,通常更大的模型可以获得更优的算法指标,但往往需要更高的推理成本,或对推理硬件平台要求较高。为了获得更适于应用部署的模型,通常使用较小模型结构,如果直接训练小模型,算法指标会有较大幅度下降,为了在小模型上获得相对优异的算法指标,目前主流的处理方式之一是基于大参数模型进行剪枝微调,但是当前的网络剪枝方法剪枝后的深度神经网络的性能较差。
技术实现思路
1、本发明主要解决的技术问题是提供一种网络剪枝方法、装置、终端及计算机可读存储介质,解决现有技术中剪枝后的模型性能差的问题。
2、为解决上述技术问题,本发明采用的第一个技术方案是:提供一种网络剪枝方法,网络剪枝方法适用于图像处理、文本处理、语音处理中的任一应用场景,网络剪枝方法包括:
3、获取数据信息和初始神经网络;初始神经网络包括多个处理模块,处理模块具有节点;
4、将初始神经网络中的各处理模块依次作为当前处理模块,将当前处理模块中的各节点分别作为操作节点,断开操作节点与其他处理模块中各节点之间的连接得到操作节点对应的预处理神经网络;
5、将数据信息分别输入初始神经网络和操作节点对应的预处理神经网络,确定操作节点的损失值;
6、基于当前处理模块中各节点的损失值,对当前处理模块中的节点进行剪枝处理得到目标神经网络。
7、其中,网络剪枝方法包括:
8、从初始神经网络的第一个处理模块开始,按从前往后的顺序依次将每个处理模块作为当前处理模块,并对处理模块通过基于贪心策略的通道选择方法进行剪枝,直至最后一个处理模块中的节点剪枝完毕。
9、其中,网络剪枝方法包括:
10、利用训练样本集训练剪枝后的目标神经网络,以优化剪枝后的目标神经网络的参数。
11、其中,断开操作节点与其他处理模块中各节点之间的连接得到操作节点对应的预处理神经网络,包括:
12、将当前处理模块中的操作节点与相邻连接的处理模块的各节点之间的连接参数均置为零,得到操作节点对应的预处理神经网络;预处理神经网络中的当前处理模块之前的处理模块已完成剪枝处理。
13、其中,将数据信息分别输入初始神经网络和操作节点对应的预处理神经网络,确定操作节点的损失值,包括:
14、通过初始神经网络对数据信息进行处理,得到数据信息的第一处理结果;
15、通过预处理神经网络对数据信息进行处理,得到数据信息的第二处理结果;
16、将数据信息对应的第一处理结果和第二处理结果之间的差值作为操作节点的损失值。
17、其中,初始神经网络具有预设剪枝率;
18、基于当前处理模块中各节点的损失值,对当前处理模块中的节点进行剪枝处理得到目标神经网络,包括:
19、将当前处理模块中包含的节点根据损失值从小到大依次进行排序;
20、根据初始神经网络的预设剪枝率,选取排序靠前的预设数量的节点作为待剪枝节点;
21、将当前处理模块中待剪枝节点与其他节点之间的连接断开,并对待剪枝节点进行剪枝处理得到目标神经网络。
22、其中,基于当前处理模块中各节点的损失值,对当前处理模块中的节点进行剪枝处理得到目标神经网络,包括:
23、将当前处理模块中最小损失值对应的节点与其他处理模块中的节点之间的连接断开,并对最小损失值对应的节点进行剪枝处理得到目标神经网络。
24、其中,网络剪枝方法还包括:
25、对初始神经网络预先进行训练;在训练过程中初始神经网络中的各节点连接有dropout层;待初始神经网络训练完成后,将dropout层去除。
26、为解决上述技术问题,本发明采用的第二个技术方案是:提供一种网络剪枝装置,网络剪枝装置适用于图像处理、文本处理、语音处理中的任一应用场景,网络剪枝装置包括:
27、获取模块,用于获取数据信息和初始神经网络;初始神经网络包括多个处理模块,处理模块具有节点;
28、处理模块,用于将初始神经网络中的各处理模块依次作为当前处理模块,将当前处理模块中的各节点分别作为操作节点,断开操作节点与其他处理模块中各节点之间的连接得到操作节点对应的预处理神经网络;
29、计算模块,用于将数据信息分别输入初始神经网络和操作节点对应的预处理神经网络,确定操作节点的损失值;
30、剪枝模块,用于基于当前处理模块中各节点的损失值,对当前处理模块中的节点进行剪枝处理得到目标神经网络。
31、为解决上述技术问题,本发明采用的第三个技术方案是:提供一种终端,终端包括存储器、处理器以及存储于存储器中并在处理器上运行的计算机程序,处理器用于执行程序数据以实现如上述的网络剪枝方法中的步骤。
32、为解决上述技术问题,本发明采用的第四个技术方案是:提供一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,计算机程序被处理器执行时实现如上述的网络剪枝方法中的步骤。
33、本发明的有益效果是:区别于现有技术的情况,提供的一种网络剪枝方法、装置、终端及计算机可读存储介质,网络剪枝方法包括:获取数据信息和初始神经网络;初始神经网络包括多个处理模块,处理模块具有节点;将初始神经网络中的各处理模块依次作为当前处理模块,将当前处理模块中的各节点分别作为操作节点,断开操作节点与其他处理模块中各节点之间的连接得到操作节点对应的预处理神经网络;将数据信息分别输入初始神经网络和操作节点对应的预处理神经网络,确定操作节点的损失值;基于当前处理模块中各节点的损失值,对当前处理模块中的节点进行剪枝处理得到目标神经网络。本申请通过对初始神经网络进行预处理生成操作节点对应的预处理神经网络,再基于预处理神经网络和初始神经网络分别对同一数据信息进行处理,根据两个神经网络对应的处理结果之间的差值确定操作节点对初始神经网络的贡献度,以便于根据操作节点对初始神经网络的贡献度对初始神经网络进行剪枝,减小剪枝处理对目标神经网络性能的影响。
1.一种网络剪枝方法,其特征在于,所述网络剪枝方法适用于图像处理、文本处理、语音处理中的任一应用场景,所述网络剪枝方法包括:
2.根据权利要求1所述的网络剪枝方法,其特征在于,所述网络剪枝方法包括:
3.根据权利要求1所述的网络剪枝方法,其特征在于,所述网络剪枝方法包括:
4.根据权利要求1~3中任一项所述的网络剪枝方法,其特征在于,
5.根据权利要求4所述的网络剪枝方法,其特征在于,
6.根据权利要求1所述的网络剪枝方法,其特征在于,所述初始神经网络具有预设剪枝率;
7.根据权利要求1所述的网络剪枝方法,其特征在于,
8.根据权利要求1所述的网络剪枝方法,其特征在于,所述网络剪枝方法还包括:
9.一种网络剪枝装置,其特征在于,所述网络剪枝装置适用于图像处理、文本处理、语音处理中的任一应用场景,所述网络剪枝装置包括:
10.一种终端,其特征在于,所述终端包括存储器、处理器以及存储于所述存储器中并在所述处理器上运行的计算机程序,所述处理器用于执行程序数据以实现如权利要求1~8任一项所述的网络剪枝方法中的步骤。
11.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1~8任一项所述的网络剪枝方法中的步骤。