本发明涉及机器学习,特别是一种基于改进聚合算法的异构数据联邦学习方法。
背景技术:
1、随着终端设备的大规模普及,如手机、平板电脑、家用电器等设备的广泛联网,网络数据流量呈爆炸式增长。如何从海量数据中提取关键特征成为一个重大挑战。机器学习特别是深度学习能够有效处理和分析大规模数据,为解决大数据特征提取难题带来了曙光。然而,利用大规模数据进行机器学习时面临“数据孤岛”问题,即数据被少数产业巨头垄断,行业内外的数据共享困难。这种现象在银行、政府部门及供应链管理等领域尤为突出。
2、为解决这一问题,mcmahan等人在2016年提出了联邦学习的概念。联邦学习是一种基于分布式数据集的机器学习框架,允许客户端在中央服务器的协调下协作训练模型。模型的相关信息可以在各参与方之间交换,但本地训练数据不会离开本地,从而在保护隐私的同时进行有效的机器学习建模。
3、尽管联邦学习解决了数据隐私问题,但仍面临数据非独立同分布(non-iid)性带来的挑战。由于用户使用习惯和行为偏好不同,各设备上的数据种类和分布差异很大,导致训练数据的非独立同分布性。在这种情况下,本地模型只能很好地拟合本地数据,但在全局数据集合上表现不佳,影响全局模型的精度。因此,缓解非独立同分布数据对联邦学习的影响成为重要议题。
4、针对这一问题,研究者提出了如fedprox和scaffold等解决方案。fedprox通过添加正则化项来解决数据异构性问题,但需要精细调节且在极端情况下效果有限。scaffold通过控制变量减少训练偏差,但其维护成本高且通信开销大。因此,如何在保持低通信开销的情况下解决极端非独立同分布问题仍是一个挑战。
技术实现思路
1、本发明要解决的技术问题是针对上述现有技术的不足,而提供一种基于改进聚合算法的异构数据联邦学习方法。
2、为解决上述技术问题,本发明采用的技术方案是:
3、一种基于改进聚合算法的异构数据联邦学习方法,包括以下步骤:
4、s1、构建联邦学习网络;单个服务器和k个客户端协作构建一个联邦学习网络,服务器和客户端分别拥有本地数据集,服务器初始化一个神经网络模型θglobal作为全局网络模型,并为每一个客户端分配一个唯一标识符uk;
5、s2、选择客户端;服务器随机选择r个客户端,r≤k,并将当前全局网络模型θglobal的参数wgobal下发给被选中的客户端;
6、s3、训练客户端;被选中的客户端uk接收来自服务器的全局网络模型,使用本地私有数据集dk对该全局网络模型θglobal进行训练得到本地网络模型θk,将训练后的本地网络模型θk的参数wk、训练数据集dk的大小nk和客户端标识符uk一并上传给服务器;
7、s4、服务器聚合:服务器收到各客户端的模型参数wk后,加权聚合各个客户端的模型参数,得到新的全局网络模型参数wgobal,据此更新得到新的全局模型θglobal;
8、s5、测试模型的准确率:服务器使用本地数据集测试当前全局模型的准确率,并与阈值进行对比,若准确率小于阈值,则返回s2,继续训练;若准确率大于等于阈值,则学习结束,并将最终的全局模型θglobal广播给参与学习的全部客户端。
9、作为本发明的进一步优选,所述s3包括s31、计算模型训练过程中的损失值;客户端接收到来自服务器的全局网络模型θglobal后,将其作为当前的本地网络模型θk,
10、客户端uk的本地私有数据集dk表示为:
11、
12、其中,xm表示数据集dk中第m条数据的特征向量,ym表示数据集dk中第m条数据对应的真实标签;nk表示数据集dk的长度;
13、将数据集dk的第m条数据的特征向量xm输入到本地模型θk中,计算得到标签预测值表示为:
14、
15、其中,表示客户端uk的第m条数据的预测输出值,σ表示激活函数,wk表示客户端uk的网络模型参数向量,表示客户端uk的第m条数据的特征向量;
16、使用交叉熵函数作为损失函数,计算客户端uk第m条数据的标签预测损失值表示为:
17、
18、其中,表示给定网络模型参数wk和数据样本下的标签预测损失值,表示客户端uk第m条数据的真实标签值,表示客户端uk第m条数据的标签预测值。
19、作为本发明的进一步优选,所述s3还包括s32、更新本地网络模型参数;客户端uk通过批量将训练数据输入到本地模型中,利用梯度下降算法迭代优化本地模型参数wk,本地模型参数的单次更新公式为:
20、
21、其中,η表示学习率,表示损失函数l关于模型参数wk在数据样本上的梯度,n表示每次训练中的数据批量大小。
22、作为本发明的进一步优选,所述s4中采用联邦聚合算法fedpw加权聚合各个客户端的模型参数;所述联邦聚合算法fedpw包括如下步骤:
23、s41、计算所选客户端重合度;
24、s42、计算客户端模型参数差异并归一化;
25、s43、计算各客户端模型参数的聚合权值;
26、s44、加权聚合各客户端的模型参数。
27、作为本发明的进一步优选,所述s41的具体步骤为:
28、计算当前第t个学习轮次所选客户端与第t-1个学习轮次所选客户端的重合度α,表示为:
29、
30、其中,r为当前轮次所选客户端与上一轮所选客户端相同的个数,r≤r,k为参与联邦学习的客户端总数。
31、作为本发明的进一步优选,所述s42的具体步骤为:
32、计算当前轮次参与联邦学习的客户端uk的网络模型参数差异向量δwk为:
33、δwk=wgoball-wk(6)
34、其中,表示当前学习轮次中客户端uk上传给服务器的网络模型参数向量,表示客户端uk的网络模型中输入层第j个神经元与输出层第i个神经元的连接权;表示当前学习轮次服务器本地的网络模型参数向量,表示全局网络模型中输入层第j个神经元与输出层第i个神经元的连接权;
35、根据客户端的网络模型参数差异向量,计算出客户端网络模型中各参数对应的重要性得分表示为:
36、
37、其中,表示客户端uk的网络模型中第j输入、i输出链路的参数差异值,nk表示为第k个客户端的训练数据集的大小。
38、作为本发明的进一步优选,所述s43的具体步骤为:遍历k得到所有客户端uk的网络模型参数重要性得分计算各个客户端网络模型参数聚合权值为:
39、
40、作为本发明的进一步优选,所述s44的具体步骤为:服务器使用客户端网络模型参数聚合权值根据式(9)对各客户端模型参数进行加权,得到聚合后的全局网络模型参数其中表示全局网络模型中输入层第j个神经元与输出层第i个神经元的连接权重,计算为:
41、
42、服务器基于聚合后的网络模型参数wglobal'更新全局神经网络θglobal;r是随机选择的客户端数、e是自然数、α表示重合度。
43、本发明具有如下有益效果:
44、本发明提出一种基于改进聚合算法的异构数据联邦学习方案,适用于在数据非独立同分布情况较为严重的场景下进行联邦学习。该方案建立在传统联邦平均算法(fedavg)的基础上,首先评估参与本轮次学习的客户端相较于上一轮次参与客户端的重合度;然后根据各客户端上传模型的更新差异计算各参数的重要性得分,再通过归一化各客户端之间相同参数位置上的得分得到加权系数;最后根据加权系数和模型重合度对各客户端模型进行加权得到更新后的全局网络模型,在保证模型收敛的同时降低客户端本地的计算复杂度,能够有效缓解训练数据分布不平衡导致的全局模型训练不足和通信开销大的问题。
1.一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:包括以下步骤:
2.根据权利要求1所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s3包括s31、计算模型训练过程中的损失值;客户端接收到来自服务器的全局网络模型θglobal后,将其作为当前的本地网络模型θk,
3.根据权利要求2所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s3还包括s32、更新本地网络模型参数;客户端uk通过批量将训练数据输入到本地模型中,利用梯度下降算法迭代优化本地模型参数wk,本地模型参数的单次更新公式为:
4.根据权利要求1所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s4中采用联邦聚合算法fedpw加权聚合各个客户端的模型参数;所述联邦聚合算法fedpw包括如下步骤:
5.根据权利要求4所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s41的具体步骤为:
6.根据权利要求5所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s42的具体步骤为:
7.根据权利要求6所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s43的具体步骤为:遍历k得到所有客户端uk的网络模型参数重要性得分计算各个客户端网络模型参数聚合权值为:
8.根据权利要求7所述的一种基于改进聚合算法的异构数据联邦学习方法,其特征在于:所述s44的具体步骤为:服务器使用客户端网络模型参数聚合权值根据式(9)对各客户端模型参数进行加权,得到聚合后的全局网络模型参数其中表示全局网络模型中输入层第j个神经元与输出层第i个神经元的连接权重,计算为: