CN118313445A - A federated incremental learning method and system based on constrained gradient updating
- Publication number: CN118313445A
- Application number: CN202410529143.5A
- Authority: CN (China)
- Legal status: Pending (the legal status is an assumption and is not a legal conclusion)
Classifications
- G06N3/098—Distributed learning, e.g. federated learning
- G06F18/2415—Classification techniques based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
- G06F18/2431—Multiple classes
- G06N3/0464—Convolutional networks [CNN, ConvNet]
- G06N3/0499—Feedforward networks
- G06N3/084—Backpropagation, e.g. using gradient descent
- G06N3/096—Transfer learning
- G06N3/048—Activation functions
Description
Technical Field
The present invention relates to the field of federated learning, and in particular to a federated class-incremental learning method and system based on constrained gradient updating.
Background
Federated learning has developed into a prominent distributed learning paradigm. The classic federated learning scenario assumes that each client trains and tests the model on statically distributed local data, i.e., the tasks and the amount of each client's local data are preset and remain unchanged. In practice, however, the task distribution of a client's local data changes dynamically over time: each client continuously generates data for new tasks, so the global model must continuously update its knowledge to adapt to new tasks. This is federated class-incremental learning.
In federated learning, clients are lightweight mobile devices with limited storage capacity, and continuously storing new data would exhaust their memory. Yet if a client keeps only the newly arriving data and trains the model only on the new task's data, the global model becomes unable to distinguish the new task from earlier tasks, resulting in catastrophic forgetting of the earlier tasks. Catastrophic forgetting of earlier tasks in federated class-incremental learning therefore urgently needs to be solved.
At present, a distillation learning method is usually adopted, which retains the knowledge of historical tasks by constraining the soft labels of historical tasks to be close to the output of the current task's model. The problem with this approach is that when a client has many tasks or the task shift is large, the soft labels of the historical tasks degrade erroneously, which increases the noise in learning the current task, weakens the distillation effect, and causes catastrophic forgetting of early tasks.
Summary of the Invention
In view of this, the present invention provides a federated class-incremental learning method and system based on constrained gradient updating, to solve the problem of catastrophic forgetting of early tasks that arises when clients continuously generate data for new tasks.
In a first aspect, the present invention provides a federated class-incremental learning method based on constrained gradient updating, applied to a client. The method includes:
receiving the personalized global model and the cross-task collaboration loss of the previous task sent by the server, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task;
when local data of a new task is obtained, updating the model parameters of the previous task along the direction orthogonal to their gradient space, based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, to obtain the model parameters of the new task in the current round of training;
taking a weighted average of the class prototypes of the new task to obtain the class-average prototype of the new task in the current round of training;
repeating the above updates of the model parameters and the class-average prototype for a preset number of training rounds and, after each round of training is completed, sending the model parameters and class-average prototype of the new task to the server (a sketch of the core update appears below).
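As an illustration only, and not the patented implementation, the following sketch shows the kind of constrained update described above: the new task's gradient is projected onto the orthogonal complement of the subspace spanned by the previous tasks' basis vectors before the parameters are stepped. The linear softmax stand-in model, the learning rate, and the function names are all assumptions.

```python
import numpy as np

def cross_entropy_grad(theta, X, y_true):
    """Gradient of the cross-entropy loss for a linear softmax
    classifier; a stand-in for the real model's backward pass."""
    logits = X @ theta
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = exp / exp.sum(axis=1, keepdims=True)
    return X.T @ (probs - y_true) / len(X)

def constrained_update(theta, X, y_true, U_prev, lr=0.1):
    """One constrained gradient step: remove the gradient component
    lying inside span(U_prev), the previous tasks' gradient space,
    so previously learned knowledge is disturbed as little as possible."""
    g = cross_entropy_grad(theta, X, y_true).reshape(-1)
    g_orth = g - U_prev @ (U_prev.T @ g)   # orthogonal projection
    return theta - lr * g_orth.reshape(theta.shape)
```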
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, the personalized global model obtained by training on the previous task can classify the previous task of all clients as well as the historical tasks before it, and the cross-task collaboration loss obtained for the previous task can improve the performance of the personalized global model. Therefore, when local data of a new task is obtained, based on the personalized global model and cross-task collaboration loss of the previous task sent by the server, the gradient update direction can be restricted, from the perspective of the gradient space, to the direction orthogonal to the gradient space of the previous task when updating the model parameters, so as to protect previously learned knowledge. The client then sends the model parameters and class-average prototype obtained in each round of training to the server, so that the knowledge learned from the new task is used in the next round of training. By receiving and updating the personalized global model and cross-task collaboration loss of the previous task and repeatedly sending the update results of each round of training, catastrophic forgetting of early tasks is avoided as clients continuously generate data for new tasks, and the model drift problem among clients is also avoided.
In an optional implementation, before updating the model parameters of the previous task along the direction orthogonal to their gradient space to obtain the model parameters of the new task in the current round of training, the method further includes:
converting the local data of the new task into high-dimensional vectors;
determining the predicted labels of the new task based on the model parameters of the previous task and the high-dimensional vectors of the new task (a minimal illustration follows).
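A minimal illustration of this prediction step, reusing the linear softmax stand-in from the sketch above (the function name and model form are assumptions):

```python
import numpy as np

def predict(theta, X):
    """Predicted labels for the new task's high-dimensional vectors X,
    computed from the previous task's parameters theta."""
    logits = X @ theta
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)  # softmax probabilities
```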
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, determining the predicted labels of the new task makes it possible to prepare for updating the model parameters of the new task based on the difference between the true and predicted labels of the new task.
In an optional implementation, the method further includes:
after the new task has been trained for the preset number of rounds, obtaining the representation matrix of the new task based on the gradients of the new task obtained in the last round of training and the historical-task basis vectors, where the historical-task basis vectors contain the basis vectors of all historical tasks of the new task;
performing singular value decomposition on the representation matrix to obtain the basis vectors of the new task, forming the orthogonal gradient space of the new task;
concatenating the basis vectors of the new task with the historical-task basis vectors to obtain the basis vectors covering all current tasks, which form a new orthogonal gradient space used for training the next task (see the sketch after this paragraph).
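A minimal sketch of this basis-extension step, assuming NumPy, a hypothetical helper name, and an energy threshold for choosing how many singular vectors to keep; the text does not fix these concrete choices.

```python
import numpy as np

def extend_basis(grads, U_hist, energy=0.95):
    """grads:  (d, n) matrix whose columns are gradients from the last
               round of training on the new task.
    U_hist:    (d, m) orthonormal basis of all historical tasks
               (m may be 0 for the first task).
    Returns the concatenated basis covering all tasks seen so far."""
    # representation matrix: the part of the gradients that U_hist
    # does not already explain
    R = grads - U_hist @ (U_hist.T @ grads)
    # singular value decomposition of the representation matrix
    U, S, _ = np.linalg.svd(R, full_matrices=False)
    # keep enough leading singular vectors to capture `energy`
    # of the squared spectrum (threshold is an assumption)
    k = int(np.searchsorted(np.cumsum(S**2) / np.sum(S**2), energy)) + 1
    return np.concatenate([U_hist, U[:, :k]], axis=1)
```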
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, since the historical-task basis vectors are obtained by concatenating the basis vectors of the previous task with the basis vectors of all historical tasks of the previous task, the basis vectors covering all current tasks, obtained from the historical-task basis vectors and the basis vectors of the new task, make it possible to learn the knowledge of the new task without disturbing the knowledge of early tasks, thereby avoiding catastrophic forgetting of early tasks.
In an optional implementation, before receiving the personalized global model and cross-task collaboration loss of the previous task sent by the server, the method further includes:
for each client, receiving the initial model sent by the server;
performing the first round of training for the initial task based on the client's local data of the initial task and the initial model, to obtain the initial model parameters and initial predicted labels;
obtaining the initial class-average prototype of the client's initial task in the first round of training based on the client's initial task and the class prototypes of the initial task;
determining the initial gradient corresponding to the initial task based on the predicted labels of the initial model, the true labels of the initial model, and the initial model parameters;
obtaining the orthogonal initial basis vectors of the gradient space of the initial task based on the initial gradient;
sending the initial model parameters and initial class-average prototype obtained in the first round of training for the initial task to the server (a sketch of this first round follows).
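A compact end-to-end sketch of this first round, reusing `constrained_update` and `cross_entropy_grad` from the earlier sketch; the model shape, synthetic data, and prototype computation are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(scale=0.01, size=(64, 10))    # initial model from the server
X = rng.normal(size=(32, 64))                    # initial task data as float vectors
y = np.eye(10)[rng.integers(0, 10, size=32)]     # one-hot true labels

# first round: no historical basis yet, so the constraint set is empty
theta = constrained_update(theta, X, y, U_prev=np.zeros((64 * 10, 0)))
proto = X.mean(axis=0)                           # stand-in for the class-average prototype
# initial gradient and the orthogonal initial basis of its gradient space
g0 = cross_entropy_grad(theta, X, y).reshape(-1, 1)
U0, _, _ = np.linalg.svd(g0, full_matrices=False)
# theta and proto are then sent to the server; U0 is kept locally
```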
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, by receiving and training the initial model sent by the server and sending the trained initial model parameters and initial class-average prototype to the server, updates can be made on the basis of this initial model when a new task appears, thereby avoiding catastrophic forgetting of early tasks.
In an optional implementation, the method further includes:
after each round of training for the new task ends, updating the optimization objective based on the cross-task collaboration loss of the previous round of training for the new task and the predicted and true labels of the new task, where the optimization objective indicates the expected value of the training loss obtained by the client during training;
performing the next round of training for the new task based on the updated optimization objective (one reading of this objective is given below).
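One consistent reading of this objective, with a weighting coefficient $\lambda$ introduced here purely for illustration (the text does not specify how the two terms are combined):

$$\min_{\theta^{t+1}} \; \mathbb{E}\left[\,\mathcal{L}_{CE}\big(y_k^{t+1}, \hat{y}_k^{t+1}\big) + \lambda\, p^{t}\,\right]$$

where $p^{t}$ is the cross-task collaboration loss from the previous round of training.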
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, by including the cross-task collaboration loss as a term of the optimization objective, knowledge beneficial to the previous task is integrated into the training process of the new task, alleviating the model drift problem among clients.
In a second aspect, the present invention provides a federated class-incremental learning method based on constrained gradient updating, applied to a server. The method includes:
sending the personalized global model and the cross-task collaboration loss of the previous task to each client, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task;
receiving the updated model parameters and class-average prototype that each client, after obtaining the data of a new task, updates in each round of training and sends upon completing each round, where the model parameters are obtained by updating the model parameters of the previous task along the direction orthogonal to their gradient space based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, and the class-average prototype is obtained by taking a weighted average of the class prototypes of the new task.
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, by sending the personalized global model and cross-task collaboration loss of the previous task to the clients, each client can train on a new task based on the personalized global model and cross-task collaboration loss of the previous task, and the server receives each client's updated model parameters and class-average prototype after every round of training to obtain the personalized global model and cross-task collaboration loss of the new task. In this way, as clients continuously take on new tasks, new tasks can be learned while the early tasks are preserved and left undisturbed, avoiding catastrophic forgetting of early tasks.
In an optional implementation, before sending the personalized global model and cross-task collaboration loss of the previous task to each client, the method further includes:
for each client, after the client has trained the previous task for the preset number of rounds, receiving the model parameters and class-average prototype of the client's previous task, obtained in the last round of training on the previous task;
determining multiple first distances between the client and the other clients, where a first distance represents the similarity between the class-average prototype of the client's previous task and the class-average prototype of another client's previous task;
performing personalized aggregation based on the multiple first distances and the model parameters of the previous task, to obtain the personalized global model of the client's previous task;
determining, from each historical task, multiple cross-task clients whose second distance to the client is greater than a distance threshold, where a second distance represents the similarity between the class-average prototype of the client's previous task and the target class-average prototype of a cross-task client;
determining the cross-task collaboration loss of the previous task based on the target model parameters of the multiple cross-task clients, the second distances corresponding to the multiple cross-task clients, and the model parameters of the previous task (a sketch of this server-side step follows).
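A minimal server-side sketch of these two steps, under stated assumptions: the first distances are turned into aggregation weights via an exponential of the Euclidean distance, the second distances are modeled as cosine similarities, and the collaboration loss is a similarity-weighted parameter discrepancy. None of these concrete choices is fixed by the text.

```python
import numpy as np

def personalized_aggregate(protos, params, tau=1.0):
    """protos: (K, d) class-average prototypes of the previous task.
    params:    (K, p) flattened model parameters, one row per client.
    Returns one personalized global model per client, weighting the
    other clients by prototype similarity (the 'first distances')."""
    dists = np.linalg.norm(protos[:, None, :] - protos[None, :, :], axis=-1)
    weights = np.exp(-dists / tau)            # closer prototype, larger weight
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ params                   # row k: personalized model for client k

def cross_task_loss(theta_k, proto_k, protos_hist, params_hist, threshold=0.5):
    """Cross-task collaboration loss for one client: select historical-task
    clients whose prototype similarity (the 'second distance') exceeds the
    threshold, and penalize the discrepancy to their model parameters."""
    sims = (protos_hist @ proto_k) / (
        np.linalg.norm(protos_hist, axis=1) * np.linalg.norm(proto_k) + 1e-12)
    mask = sims > threshold                   # the cross-task clients
    if not mask.any():
        return 0.0
    w = sims[mask] / sims[mask].sum()
    diffs = np.linalg.norm(params_hist[mask] - theta_k, axis=1) ** 2
    return float(w @ diffs)
```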
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, the cross-task collaboration loss can draw on the knowledge in other clients' historical tasks that benefits the personalized global model, improving the performance of the personalized global model and thereby alleviating the model drift problem among clients.
In a third aspect, the present invention provides a federated class-incremental learning system based on constrained gradient updating. The system includes clients and a server and performs the following steps:
the server sends the personalized global model and the cross-task collaboration loss of the previous task to each client, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task;
the client receives the personalized global model and cross-task collaboration loss of the previous task sent by the server;
when local data of a new task is obtained, the client updates the model parameters of the previous task along the direction orthogonal to their gradient space, based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, to obtain the model parameters of the new task in the current round of training;
the client takes a weighted average of the class prototypes of the new task to obtain the class-average prototype of the new task in the current round of training;
the client repeats the above updates of the model parameters and the class-average prototype for a preset number of training rounds and, after each round of training is completed, sends the model parameters and class-average prototype of the new task to the server;
the server receives the updated model parameters and class-average prototype that the client, after obtaining the data of the new task, updates in each round of training and sends upon completing each round.
In a fourth aspect, the present invention provides a federated class-incremental learning apparatus based on constrained gradient updating, applied to a client. The apparatus includes:
a first receiving module, configured to receive the personalized global model and the cross-task collaboration loss of the previous task sent by the server, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task;
an update module, configured to, when local data of a new task is obtained, update the model parameters of the previous task along the direction orthogonal to their gradient space, based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, to obtain the model parameters of the new task in the current round of training;
an acquisition module, configured to take a weighted average of the class prototypes of the new task to obtain the class-average prototype of the new task in the current round of training;
a first training module, configured to repeat the above updates of the model parameters and the class-average prototype for a preset number of training rounds and, after each round of training is completed, send the model parameters and class-average prototype of the new task to the server.
In a fifth aspect, the present invention provides a federated class-incremental learning apparatus based on constrained gradient updating, applied to a server. The apparatus includes:
a sending module, configured to send the personalized global model and the cross-task collaboration loss of the previous task to each client, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task;
a first receiving module, configured to receive the updated model parameters and class-average prototype that each client, after obtaining the data of a new task, updates in each round of training and sends upon completing each round, where the model parameters are obtained by updating the model parameters of the previous task along the direction orthogonal to their gradient space based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, and the class-average prototype is obtained by taking a weighted average of the class prototypes of the new task.
In a sixth aspect, the present invention provides a computer device, including a memory and a processor communicatively connected to each other, where the memory stores computer instructions, and the processor executes the computer instructions so as to perform the federated class-incremental learning method based on constrained gradient updating of the first aspect or any corresponding implementation thereof.
In a seventh aspect, the present invention provides a computer-readable storage medium storing computer instructions, where the computer instructions are used to cause a computer to perform the federated class-incremental learning method based on constrained gradient updating of the first aspect or any corresponding implementation thereof.
Brief Description of the Drawings
In order to more clearly illustrate the specific embodiments of the present invention or the technical solutions in the prior art, the drawings needed in the description of the specific embodiments or the prior art are briefly introduced below. Obviously, the drawings described below show some embodiments of the present invention, and a person of ordinary skill in the art may obtain other drawings from them without creative effort.
FIG. 1 is a flowchart of a federated class-incremental learning method based on constrained gradient updating applied to a client according to an embodiment of the present invention;
FIG. 2 is a schematic diagram of a constrained gradient update according to an embodiment of the present invention;
FIG. 3 is a flowchart of a federated class-incremental learning method based on constrained gradient updating applied to a server according to an embodiment of the present invention;
FIG. 4 is a flowchart of the interaction between a client and a server in a federated class-incremental learning system based on constrained gradient updating according to an embodiment of the present invention;
FIG. 5 is a schematic flowchart of a federated class-incremental learning method based on constrained gradient updating according to an embodiment of the present invention;
FIG. 6 is a structural block diagram of a federated class-incremental learning apparatus based on constrained gradient updating applied to a client according to an embodiment of the present invention;
FIG. 7 is a structural block diagram of a federated class-incremental learning apparatus based on constrained gradient updating applied to a server according to an embodiment of the present invention;
FIG. 8 is a schematic diagram of the hardware structure of a computer device according to an embodiment of the present invention.
Detailed Description
To make the objectives, technical solutions, and advantages of the embodiments of the present invention clearer, the technical solutions in the embodiments of the present invention are described below clearly and completely with reference to the drawings in the embodiments of the present invention. Obviously, the described embodiments are some, not all, of the embodiments of the present invention. All other embodiments obtained by a person skilled in the art based on the embodiments of the present invention without creative effort fall within the protection scope of the present invention.
The embodiments of the present invention provide a federated class-incremental learning method based on constrained gradient updating, applied to clients and a server, which solves the problem in the related art that performing federated class-incremental learning via distillation learning causes catastrophic forgetting of early tasks. By restricting the gradient update direction and adding a cross-task collaboration loss among clients, the knowledge of new tasks can be learned without disturbing the knowledge of early tasks, thereby avoiding catastrophic forgetting of early tasks while also avoiding the model drift problem caused by differences among clients.
According to an embodiment of the present invention, an embodiment of a federated class-incremental learning method based on constrained gradient updating is provided. It should be noted that the steps shown in the flowcharts of the drawings may be executed in a computer system such as a set of computer-executable instructions, and although a logical order is shown in the flowcharts, in some cases the steps shown or described may be executed in an order different from that shown here.
This embodiment provides a federated class-incremental learning method based on constrained gradient updating, applied to a client. FIG. 1 is a flowchart of a federated class-incremental learning method based on constrained gradient updating applied to a client according to an embodiment of the present invention. As shown in FIG. 1, the flow includes the following steps:
Step S101: receive the personalized global model and the cross-task collaboration loss of the previous task sent by the server, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task. Specifically, the personalized global model is a neural network model, such as LeNet-5, AlexNet (a deep convolutional neural network), ResNet (a residual neural network), MobileNet, or an MLP (Multilayer Perceptron); the embodiments of the present invention impose no restriction on this. Taking AlexNet as an example, its structure includes an input layer and five convolutional layers, each convolutional layer followed by a max-pooling layer and a local response normalization layer; the convolutional layers are followed by three fully connected layers; the convolutional and fully connected layers use the ReLU activation function; and finally a Softmax layer outputs the classification result, realizing the classification task. The present invention is described by taking lightweight mobile devices as the clients and a specific application service provider as the server; for example, the server provides classification tasks for data such as images, videos, or text. As data collectors, clients continuously generate data for new tasks but, constrained by limited storage, cannot keep the data of all tasks. A client therefore needs to keep training the model on the data of newly added tasks while ensuring that the trained model still classifies well on the data of early tasks that have been deleted or become inaccessible. If the model is trained on the new tasks' data in the traditional way, the classification accuracy of the server's global model on early tasks' data drops rapidly, leading to catastrophic forgetting of the early tasks.
The present invention receives the personalized global model and cross-task collaboration loss of the previous task sent by the server and performs the update on the basis of the previous task's personalized global model, so that new tasks can be classified while early tasks are still classified accurately, avoiding catastrophic forgetting of early tasks. Moreover, since the cross-task collaboration loss is derived from the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, the model drift problem caused by differences among clients can be avoided.
Step S102: when local data of a new task is obtained, update the model parameters of the previous task along the direction orthogonal to their gradient space, based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, to obtain the model parameters of the new task in the current round of training. Specifically, when the client has a newly added task, i.e., local data of a new task is obtained, the model parameters of the previous task can be updated along the direction orthogonal to their gradient space through the following formula (1), yielding the gradient in the gradient space of the new task.
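In a standard orthogonal-projection form, reconstructed here to be consistent with the symbol definitions below rather than taken verbatim from the original, formula (1) reads:

$$\tilde{\nabla}_{\theta^{t+1}} = \nabla_{\theta^{t+1}}\,\mathcal{L}_{CE}\big(y_k^{t+1}, \hat{y}_k^{t+1}\big) - U^{t}\big(U^{t}\big)^{\top}\,\nabla_{\theta^{t+1}}\,\mathcal{L}_{CE}\big(y_k^{t+1}, \hat{y}_k^{t+1}\big) \quad (1)$$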
where $t$ denotes the previous task; $t+1$ denotes the new task; $\nabla$ denotes the gradient; $\theta^{t+1}$ denotes the model parameters of the new task; $k$ denotes the $k$-th client; $y_k^{t+1}$ denotes the true labels of the $k$-th client on the new task and $\hat{y}_k^{t+1}$ denotes its predicted labels on the new task; $\mathcal{L}_{CE}(y_k^{t+1}, \hat{y}_k^{t+1})$ denotes the cross-entropy loss between the true and predicted labels of the new task; and $U^{t}$ denotes the basis vectors of the gradient space of the previous task.
Optionally, the cross-entropy loss between the true and predicted labels of the new task can be determined by the following formula (2).
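In its standard form, reconstructed here to match the surrounding notation:

$$\mathcal{L}_{CE}\big(y_k^{t+1}, \hat{y}_k^{t+1}\big) = -\sum_{i=1}^{N} y_{k,i}^{t+1}\,\log \hat{y}_{k,i}^{t+1} \quad (2)$$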
In some optional implementations, based on the cross-task collaboration loss of the previous task and the cross-entropy loss between the true and predicted labels of the new task, the model parameters in the personalized global model of the previous task can be updated along the direction orthogonal to their gradient space through the error back-propagation algorithm, yielding the model parameters of the new task in the current round of training, so that the updated model parameters learn the knowledge of the new task.
In some optional implementations, FIG. 2 is a schematic diagram of a constrained gradient update according to an embodiment of the present invention. As shown in FIG. 2, $\theta^{t}$ denotes the model parameters of the previous task and $\theta^{t+1}$ denotes the model parameters of the new task. After a client finishes training each task, the gradients of the trained personalized global model can be regarded as retaining the knowledge of all historical tasks learned so far, and the gradients of the current task and the historical tasks together constitute a gradient space. When the client obtains the local data of a new task, the gradient needs to be updated. Since the gradient space of the previous task is jointly formed by the gradient of the previous task and the gradients of all its historical tasks, if the gradient update direction of the new task does not lie in the direction orthogonal to the gradient space of the previous task, it will disturb the knowledge of the previous task, i.e., the updated gradient space forgets the knowledge of early tasks. Therefore, by constraining the gradient update direction to the direction orthogonal to the gradient space of the previous task, the disturbance to the knowledge of the historical tasks is minimized; the model parameters are then updated according to the updated gradient, preventing the knowledge of historical tasks from being forgotten and alleviating the early-forgetting problem.
Step S103: take a weighted average of the class prototypes of the new task to obtain the class-average prototype of the new task in the current round of training. Specifically, a class prototype is the intermediate result output by the feature extractor for the local data of the new task, i.e., the result obtained after the local data of the new task passes through the computation of the last fully connected layer. Obtaining the class-average prototype prepares for the subsequent determination of the personalized global model of the new task.
Optionally, the class-average prototype of the new task in the current round of training can be determined by the following formula (3).
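One consistent reading of formula (3), reconstructed from the symbol definitions below and written with uniform weights (the original weighted average may use non-uniform weights):

$$\bar{e}_k^{t+1} = \frac{1}{N} \sum_{i=1}^{N} e_{k,i}^{t+1} \quad (3)$$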
where $t$ denotes the previous task; $t+1$ denotes the new task; $k$ denotes the $k$-th client; $\bar{e}_k^{t+1}$ denotes the class-average prototype of the $k$-th client on the new task; $N$ denotes the amount of local data of the new task; $e$ denotes a class prototype of the new task in the current round of training; and $i$ indexes the local data of the new task.
Step S104: repeat the above updates of the model parameters and class-average prototype for a preset number of training rounds and, after each round of training is completed, send the model parameters and class-average prototype of the new task to the server. Specifically, each client has multiple tasks, and each task requires a preset number of training rounds. Steps S101 to S104 describe the execution of any round of training for any task other than the initial task. In that round, the personalized global model and cross-task collaboration loss obtained in the previous round of training are received, the current round of training is then performed on their basis, and the model parameters and class-average prototype of the current round are obtained and sent to the server. The present invention does not limit the number of training rounds for any task. Performing the preset rounds of training on the new task and sending the training results to the server safeguards the learning effect on the new task's local data.
In the federated class-incremental learning method based on constrained gradient updating provided by the embodiments of the present invention, the personalized global model obtained by training on the previous task can classify the previous task of all clients as well as the historical tasks before it, and the cross-task collaboration loss obtained for the previous task can improve the performance of the personalized global model. Therefore, when local data of a new task is obtained, based on the personalized global model and cross-task collaboration loss of the previous task sent by the server, the gradient update direction can be restricted, from the perspective of the gradient space, to the direction orthogonal to the gradient space of the previous task when updating the model parameters, so as to protect previously learned knowledge. The client then sends the model parameters and class-average prototype obtained in each round of training to the server, so that the knowledge learned from the new task is used in the next round of training. By receiving and updating the personalized global model and cross-task collaboration loss of the previous task and repeatedly sending the update results of each round of training, catastrophic forgetting of early tasks is avoided as the client continuously generates data for new tasks, and the model drift problem among clients is also avoided.
This embodiment provides a federated class-incremental learning method based on constrained gradient updating, applied to a server. FIG. 3 is a flowchart of a federated class-incremental learning method based on constrained gradient updating applied to a server according to an embodiment of the present invention. As shown in FIG. 3, the flow includes the following steps:
Step S301: send the personalized global model and the cross-task collaboration loss of the previous task to each client, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients obtained by training on their respective local data of the previous task, the cross-task collaboration loss represents the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task. Specifically, the previous task is executed by all clients, so in each round of training on the previous task, every client receives the personalized global model corresponding to the round of training preceding that round and can thus perform the next round of training on its basis. Since the server obtains the personalized global model of the previous task from the model parameters and class-average prototypes that all clients obtained by training on the previous task, the personalized global model is capable of classifying the previous task for all clients. Since the server derives the cross-task collaboration loss from the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, the cross-task collaboration loss can reduce the differences among clients to avoid the model drift problem. Optionally, for the first round of training of each task, the server may send the client the personalized global model and cross-task collaboration loss obtained in the last round of training of that task's previous task. For the second round of training of each task until the preset rounds of training end, the server may send the client either the personalized global model obtained in the previous round of the current task or the model parameters obtained in the previous round; the embodiments of the present invention impose no restriction on this.
Step S302: receive the updated model parameters and class-average prototype that each client, after obtaining the data of a new task, updates in each round of training and sends upon completing each round, where the model parameters are obtained by updating the model parameters of the previous task along the gradient direction orthogonal to their gradient space based on the basis vectors of the gradient space of the previous task, the predicted and true labels of the new task, and the cross-task collaboration loss, and the class-average prototype is obtained by taking a weighted average of the class prototypes of the new task. Specifically, there are multiple clients, and each client has multiple tasks. Each task requires a preset number of training rounds: a client sends the model parameters and class-average prototype obtained in each round of training, and the server receives all model parameters and all class-average prototypes obtained by all clients in that round of training on that task, in preparation for the next round of training.
In some optional implementations, FIG. 5 is a schematic flowchart of a federated class-incremental learning method based on constrained gradient updating according to an embodiment of the present invention. As shown in FIG. 5, taking any round of training of any task other than the initial task as an example, the symbols in FIG. 5 denote, respectively, the local data, the model parameters, and the class-average prototype of the first client in the current round of training for the task; $D$ denotes an optional distance metric function; $p^{t}$ denotes the cross-task collaboration loss; the personalized global model of the current round of training for the task is also shown; and $k$ denotes the $k$-th client. First, all clients perform local training on the task, obtain the model parameters and class-average prototypes of the current round of training, and send them to the server. The server then computes the distances between historical tasks and the current task through the class-average prototypes and determines the cross-task collaboration loss. The server then performs personalized aggregation of the current round's model parameters to obtain the personalized global model of the current round. Finally, the server sends the cross-task collaboration loss and personalized global model obtained in the current round of training to the clients for the next round of training.
The federated class-incremental learning method based on constrained gradient updates provided by the embodiments of the present invention sends the client the personalized global model and cross-task collaboration loss of the previous task, so that under a new task the client can train on that basis; after each round of client training, the server receives the client's updated model parameters and class-average prototype, from which the personalized global model and cross-task collaboration loss of the new task are obtained. As the client's tasks keep changing, new tasks can thus be learned while the knowledge of earlier tasks is preserved, avoiding catastrophic forgetting of the early tasks.
This embodiment provides a federated class-incremental learning system based on constrained gradient updates, comprising clients and a server. FIG. 4 is an interaction flow chart between a client and the server in such a system according to an embodiment of the present invention. As shown in FIG. 4, the server sends each client the personalized global model and cross-task collaboration loss of the previous task. The client receives them and, when a new task arrives, trains the newly added task: based on the basis vectors of the previous task's gradient space, the predicted and true labels of the new task, and the cross-task collaboration loss, the client updates the previous task's model parameters along a gradient direction orthogonal to that gradient space, obtaining the new task's model parameters for the current round of training. The client then takes a weighted average of the new task's class prototypes to obtain the class-average prototype of the new task in the current round. The client repeats these updates of the model parameters and class-average prototype, sending the updated values to the server after each round of training, and the server receives the client's updated model parameters and class-average prototype after every round.
In some optional implementations, the federated class-incremental learning system based on constrained gradient updates comprises clients and a server, and the detailed workflow of the system includes the following steps:
Step S401: each client receives the initial model sent by the server. Specifically, the initial model consists of randomly initialized model parameters, denoted θ_init.
Step S402: the client performs the first round of training of the initial task based on the initial model and the local data of its initial task, obtaining the initial model parameters and initial predicted labels. Specifically, all clients initialize their own local data. Assuming each client has T tasks that arrive in sequence, the t-th task can be expressed as $D_k^t = \{(x_{k,i}^t, y_{k,i}^t)\}_{i=1}^{N}$, where $D_k^t$ denotes the local dataset of the t-th task; k denotes the k-th client; i denotes the i-th local sample; t denotes the t-th task; N denotes the number of local samples per task; and T denotes the total number of tasks of the client. The local data of the initial task can accordingly be expressed as $D_k^1$. The first round of training of the initial task then proceeds as follows: first, the local data of the initial task is read as high-dimensional vectors of floating-point numbers $x_k^1$; the data $x_k^1$ is passed through the model parameters θ_init to obtain the predicted labels $\hat{y}_k^1$; then the cross-entropy loss between the true labels $y_k^1$ and the predicted labels $\hat{y}_k^1$ is computed via the above formula (2); finally, the model parameters θ_init are modified and updated via the error and the back-propagation algorithm, yielding the updated initial model parameters $\theta_k^1$.
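For illustration only, the following is a minimal numpy sketch of one such local training round, in which a linear softmax classifier stands in for the client model; the function name and the learning rate are illustrative assumptions, not part of the claimed method.

```python
import numpy as np

def one_local_round(X, y, theta, lr=0.1):
    """One local training round (step S402): forward pass, cross-entropy
    loss as in formula (2), and a back-propagation parameter update.
    X:     (N, d) local data read as high-dimensional float vectors.
    y:     (N,)   integer true labels.
    theta: (d, C) parameters of a linear classifier (a stand-in model)."""
    N = X.shape[0]
    logits = X @ theta
    logits -= logits.max(axis=1, keepdims=True)           # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    ce_loss = -np.log(probs[np.arange(N), y] + 1e-12).mean()
    grad = probs.copy()
    grad[np.arange(N), y] -= 1.0                          # dL/dlogits
    grad = X.T @ grad / N                                 # dL/dtheta
    return theta - lr * grad, ce_loss, grad
```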
Step S403: based on the client's initial task and the class prototypes of the initial task, the initial class-average prototype of the client's initial task in the first round of training is obtained. Specifically, with reference to step S103 above, the initial class-average prototype can be obtained via the above formula (3).

Step S404: based on the predicted labels, the true labels, and the initial model parameters, the initial gradient corresponding to the initial task is determined. Specifically, the initial gradient can be obtained via the following formula (4):

$$g_k^1 = \nabla_{\theta^1}\,\mathcal{L}_{CE}\big(y_k^1,\ \hat{y}_k^1\big) \qquad (4)$$

where ∇ denotes the gradient; θ^1 denotes the model parameters of the first round of training of the initial task; $\mathcal{L}_{CE}(y_k^1, \hat{y}_k^1)$ denotes the cross-entropy loss between the true and predicted labels of the initial task in the first round of training; k denotes the k-th client; $y_k^1$ denotes the true labels of the k-th client in the first round of training of the initial task; and $x_k^1$ denotes the high-dimensional vectors of the k-th client's local data for the initial task.

Step S405: based on the initial gradient, the client obtains the orthogonal initial basis vectors of the gradient space of the initial task. Specifically, the initial gradients obtained in the first round of training of the initial task are arranged into the representation matrix shown in formula (5) below; singular value decomposition of the representation matrix yields formula (6); the left singular matrix is then defined as the orthogonal initial basis vectors of the gradient space of the initial task in the first round of training, as in formula (7). Optionally, other matrix decomposition methods may be used to obtain the basis vectors, which the embodiments of the present application do not limit.

$$Z^1 = [\,g_1^1,\ g_2^1,\ \dots,\ g_L^1\,] \qquad (5)$$

where Z denotes the representation matrix of the first training of the initial task, and L denotes the number of gradient vectors stacked into it.

$$Z^1 = U^1 S^1 (V^1)^T \qquad (6)$$

where Z denotes the representation matrix of the first training of the initial task; U denotes the left singular matrix of the first training of the initial task; S denotes the singular values of the first training of the initial task; and V denotes the right singular matrix of the first training of the initial task.

$$\mathcal{U}_k^1 = \mathrm{span}(U^1) \qquad (7)$$

where U denotes the orthogonal initial basis vectors of the gradient space of the initial task in the first round of training; span() denotes the spanning-space function; and k denotes the k-th client.
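As a hedged illustration of steps S404–S405, the sketch below stacks gradient vectors into a representation matrix, applies SVD, and keeps the leading left singular vectors as the orthogonal basis; the energy-based truncation rule is an assumption, since the text only states that the left singular matrix is taken as the basis.

```python
import numpy as np

def gradient_space_basis(grads, energy=0.95):
    """Build an orthogonal basis of a task's gradient space (formulas (5)-(7)).
    grads:  list of L gradient vectors, each of dimension d.
    energy: fraction of spectral energy retained when truncating (assumption)."""
    Z = np.stack(grads, axis=1)                      # d x L representation matrix, (5)
    U, S, _ = np.linalg.svd(Z, full_matrices=False)  # SVD, formula (6)
    cum = np.cumsum(S ** 2) / np.sum(S ** 2)
    r = int(np.searchsorted(cum, energy)) + 1        # smallest rank covering `energy`
    return U[:, :r]                                  # orthogonal basis vectors, (7)
```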
Step S406: the client sends the server the initial model parameters and the initial class-average prototype obtained in the first round of training of the initial task. Specifically, sending these allows the second round of training of the initial task to build on the knowledge learned in the first round.

It should be noted that steps S401 to S406 above implement the first round of training of the initial task for all clients. After each client executes steps S401 to S406, the server receives the initial model parameters and initial class-average prototypes obtained in the first round of training and determines the personalized global model of the first round of the initial task, which it then sends to the clients for the second round of training. From the second round of training of the initial task until its last round, the client follows steps S401 to S406: it receives the personalized global model produced by the server in the previous round for use in the current round and, after each round, sends the round's model parameters and class-average prototype to the server; the server likewise determines the current round's personalized global model after every round and sends it to the clients for the next round of training. The number of training rounds of the initial task may be, for example, 5, 10, or 20, which the embodiments of the present invention do not limit. Once the preset number of training rounds of the initial task is complete, if the client has a newly added task, the client and server no longer execute steps S401 to S406 for it but instead execute the following steps to train and learn the new task.
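The per-round message flow described above can be made concrete with the toy simulation below; plain averaging and a linear least-squares step stand in for the personalized aggregation and the local training, purely to show the order of operations, and none of the names are from the patent.

```python
import numpy as np

def run_task_rounds(client_data, theta0, rounds=5, lr=0.1):
    """Toy per-task loop: each client trains locally, sends parameters,
    the server aggregates, and broadcasts a model for the next round.
    client_data: list of (X, y) pairs, one per client.
    theta0:      initial model parameters sent by the server."""
    thetas = [theta0.copy() for _ in client_data]
    for _ in range(rounds):
        for k, (X, y) in enumerate(client_data):
            grad = X.T @ (X @ thetas[k] - y) / len(y)    # local gradient step
            thetas[k] = thetas[k] - lr * grad
        theta_global = np.mean(thetas, axis=0)           # server-side aggregation
        thetas = [theta_global.copy() for _ in thetas]   # broadcast for next round
    return theta_global
```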
In some optional implementations, taking as an example the case where a client has finished the preset rounds of training of the previous task and a newly added task appears, the federated class-incremental learning method based on constrained gradient updates provided by the embodiments of the present application is described through steps S407 to S424 below. The server processes the training results of the completed previous task through steps S407 to S411, and the client and server begin training the new task from step S412.

Step S407: for each client, after the client has trained the previous task for the preset number of rounds, the server receives the client's model parameters and class-average prototype of the previous task, which were obtained in the last round of training of that task. Specifically, the previous task may be any task other than the initial task. Since every round of training of every task yields that round's model parameters and class-average prototype, the values the server receives after the preset rounds of training of the previous task are those obtained in its last round. By receiving the previous task's model parameters and class-average prototype, the server can determine a personalized global model that has absorbed the knowledge of the previous task, in preparation for training the next task.

Step S408: the server determines multiple first distances between the client and the other clients, where a first distance expresses the similarity between the client's class-average prototype of the previous task and another client's class-average prototype of the previous task. Specifically, the server can determine the first distance via the following formula (8):

$$d_{k,j}^t = D\big(p_k^t,\ p_j^t\big),\quad j \in K \qquad (8)$$

where t denotes the previous task; k denotes the k-th client; j indexes the other clients; $d_{k,j}^t$ denotes the first distance; D denotes an optional distance metric function; p denotes a class-average prototype; and K denotes all clients. Optionally, other distance metrics over Euclidean space may be used to determine the distances between class-average prototypes, which the embodiments of the present application do not limit.

Step S409: based on the multiple first distances and the model parameters of the previous task, the server performs personalized aggregation to obtain the client's personalized global model of the previous task, via formula (9). Through personalized aggregation over the first distances, the target model parameters, and the target class-average prototypes, the resulting personalized global model can classify not only the client's own tasks but also the tasks of other clients, improving the model's classification ability. In formula (9), t denotes the previous task; k denotes the k-th client; $\hat{\theta}_k^t$ denotes the personalized global model of the previous task; $\theta^t$ denotes the model parameters of the previous task; j indexes the other clients; K denotes all clients; and $d_{k,j}^t$ denotes the first distance.
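A hedged sketch of steps S408–S409 follows; the softmax-over-negative-distances weighting is an assumption standing in for the unreproduced formula (9), and Euclidean distance stands in for the optional metric D.

```python
import numpy as np

def personalized_aggregate(protos, thetas, k):
    """Personalized aggregation for client k (steps S408-S409).
    protos: list of class-average prototypes, one (d,) vector per client.
    thetas: list of flattened model parameters, one per client."""
    d = np.array([np.linalg.norm(protos[k] - p) for p in protos])  # first distances, (8)
    w = np.exp(-d) / np.exp(-d).sum()      # distance-derived weights (assumption)
    return sum(wi * th for wi, th in zip(w, thetas))
```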
Step S410: from each historical task, the server determines multiple cross-task clients whose second distance to the client exceeds a distance threshold, where the second distance expresses the similarity between the client's class-average prototype of the previous task and the target class-average prototype of a cross-task client. Specifically, the server can use a historical-prototype selection mechanism to find the historical tasks that most improve the previous task, i.e., it determines the cross-task clients via the second distances obtained from each historical task according to formula (8) above, so as to improve the performance of the previous task's personalized global model on the historical tasks. For example, in each historical task the server selects for client k the two other clients at the smallest distance; when the client is executing the t-th task, the number of selected historical tasks of other clients is α = 2·(t−1). Optionally, the cross-task clients may be determined in other ways, which the embodiments of the present invention do not limit.
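The two-nearest-clients example above might be implemented as follows; the data layout (a dict mapping each historical task to its stored client prototypes) is an assumption.

```python
import numpy as np

def select_cross_task_clients(proto_k, hist_protos, per_task=2):
    """Historical prototype selection (step S410): per historical task, pick
    the clients whose stored prototypes lie closest to client k's prototype.
    hist_protos: dict mapping task id -> {client id: prototype vector}."""
    selected = []
    for task, protos in hist_protos.items():
        dists = {c: np.linalg.norm(proto_k - p) for c, p in protos.items()}
        nearest = sorted(dists, key=dists.get)[:per_task]
        selected += [(task, c, dists[c]) for c in nearest]
    return selected   # alpha = per_task * number of historical tasks entries
```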
Step S411: based on the target model parameters of the multiple cross-task clients, the second distances corresponding to those clients, and the model parameters of the previous task, the server determines the cross-task collaboration loss of the previous task. Specifically, since the cross-task clients are highly similar to the client, determining the cross-task collaboration loss from the cross-task clients' target model parameters via formula (10) below can reduce the model drift caused by differences between clients.

In formula (10), t denotes the previous task; k denotes the k-th client; $\theta_k^t$ denotes the model parameters of the previous task; θ_old denotes the target model parameters of a cross-task client determined from the historical tasks; $d_i$ denotes the second distance corresponding to a cross-task client; α denotes the number of selected historical tasks; and i denotes the i-th historical task.
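Since formula (10) is not reproduced, the sketch below uses a distance-weighted squared gap between the current parameters and each selected cross-task client's parameters; this functional form is a plausible reading of the description, not the patent's exact loss.

```python
import numpy as np

def cross_task_loss(theta_k, selected):
    """Cross-task collaboration loss (step S411, formula (10)).
    theta_k:  flattened current model parameters of client k.
    selected: list of (second_distance, theta_old) pairs for the alpha
              cross-task clients chosen in step S410."""
    alpha = max(len(selected), 1)
    return sum(d * np.sum((theta_k - th_old) ** 2)
               for d, th_old in selected) / alpha
```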
Step S412: the server sends each client the personalized global model and cross-task collaboration loss of the previous task; the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients trained on their respective local data of the previous task, the cross-task collaboration loss expresses the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task. For details, see step S301 of the embodiment shown in FIG. 3, which is not repeated here.

Step S413: the client receives the personalized global model and cross-task collaboration loss of the previous task sent by the server; the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients trained on their respective local data of the previous task, the cross-task collaboration loss expresses the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task. For details, see step S101 of the embodiment shown in FIG. 1, which is not repeated here.

Step S414: the client converts the local data of the new task into high-dimensional vectors. Specifically, each high-dimensional vector consists of floating-point numbers; converting the new task's local data into high-dimensional vectors prepares for determining the predicted labels of the new task.

Step S415: based on the model parameters of the previous task and the high-dimensional vectors of the new task, the client determines the predicted labels of the new task, with reference to step S402 above.

Step S416: based on the basis vectors of the previous task's gradient space, the predicted and true labels of the new task, and the cross-task collaboration loss, the client updates the model parameters of the previous task along a gradient direction orthogonal to that gradient space, obtaining the model parameters of the new task in the current round of training. For details, see step S102 of the embodiment shown in FIG. 1, which is not repeated here.
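The orthogonal update of step S416 corresponds to the standard projection below: every gradient component lying in the span of the previous tasks' basis is removed before the step is taken.

```python
import numpy as np

def constrained_update(theta, grad, U, lr=0.1):
    """Constrained gradient update (step S416).
    theta: flattened model parameters.
    grad:  flattened gradient of the new task's loss (CE + cross-task term).
    U:     (d, r) orthogonal basis of the previous tasks' gradient space."""
    grad_orth = grad - U @ (U.T @ grad)   # keep only the orthogonal component
    return theta - lr * grad_orth
```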
Step S417: the client takes a weighted average of the class prototypes of the new task to obtain the class-average prototype of the new task in the current round of training. For details, see step S103 of the embodiment shown in FIG. 1, which is not repeated here.
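A minimal sketch of the prototype computation, assuming per-class means combined with class-frequency weights; the patent states only that the class prototypes are weighted-averaged, so the weighting scheme here is an assumption.

```python
import numpy as np

def class_average_prototype(features, labels):
    """Class-average prototype (step S417).
    features: (N, d) feature vectors of the new task's local data.
    labels:   (N,)   integer class labels."""
    classes, counts = np.unique(labels, return_counts=True)
    protos = np.stack([features[labels == c].mean(axis=0) for c in classes])
    weights = counts / counts.sum()        # class-frequency weights (assumption)
    return weights @ protos                # single class-average prototype
```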
Step S418: after each round of training of the new task, the optimization objective is updated based on the cross-task collaboration loss of the new task's previous round of training and the predicted and true labels of the new task; the optimization objective indicates the expected value of the training loss obtained by the client during training. Specifically, the optimization objective indicates the minimum loss value the client expects to reach when training the task, and can be determined via formula (11) below.

In formula (11), t+1 denotes the new task; k denotes the k-th client; $\hat{\theta}_k^{t+1}$ denotes the personalized global model of the new task; the cross-entropy term denotes the loss between the predicted and true labels of the new task, computed via formula (2) above; and the cross-task collaboration term denotes the loss obtained in the previous round of training of the new task via formula (10), with reference to step S411 above.
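Read literally, the objective combines the two named terms; the additive form and the weight `lam` below are assumptions, since formula (11) is not reproduced.

```python
def training_objective(ce_loss, ct_loss, lam=1.0):
    """Optimization objective of step S418 (formula (11)): cross-entropy on
    the new task plus the previous round's cross-task collaboration loss,
    combined additively with an assumed weight `lam`."""
    return ce_loss + lam * ct_loss
```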
Step S419: based on the updated optimization objective, the client performs the next round of training of the new task. Specifically, since the optimization objective reflects the minimum value the training loss is expected to reach, guiding the next round of training by it improves the training effect.

Step S420: the client repeats the above updates of the model parameters and class-average prototype for the preset number of training rounds and, after each round of training, sends the server the model parameters and class-average prototype of the new task. For details, see step S104 of the embodiment shown in FIG. 1, which is not repeated here.

Step S421: after the new task has been trained for the preset number of rounds, the client obtains the representation matrix of the new task based on the gradient obtained in the last round of training of the new task and the historical-task basis vectors, where the historical-task basis vectors contain the basis vectors of all historical tasks of the new task. Specifically, the historical-task basis vectors are obtained by concatenating the basis vectors of all historical tasks of the new task. Based on the last-round gradient of the new task and the historical-task basis vectors, the representation matrix of the new task can be obtained via formula (12) below, so that the new task can be learned without interfering with the knowledge of the historical tasks, avoiding catastrophic forgetting of the early tasks.

$$\hat{Z}^{t+1} = Z^{t+1} - U^t (U^t)^T Z^{t+1} \qquad (12)$$

where t denotes the previous task; t+1 denotes the new task; U^t denotes the historical-task basis vectors; Z^{t+1} denotes the representation matrix obtained, with reference to step S405 above, from the gradient of the new task in its last round of training; and $\hat{Z}^{t+1}$ denotes the representation matrix of the new task with the components of the historical tasks removed.

Step S422: the client performs singular value decomposition on the representation matrix to obtain the basis vectors of the new task, forming the orthogonal gradient space of the new task. Specifically, the singular value decomposition of the representation matrix is given by formula (13) below, and the basis vectors of the new task are then obtained via formula (14).

$$Z^{t+1} = U^{t+1} S^{t+1} (V^{t+1})^T \qquad (13)$$

where t+1 denotes the new task; Z^{t+1} denotes the representation matrix of the new task in this round of training; U^{t+1} denotes the left singular matrix of the new task in this round of training; S^{t+1} denotes the singular values of the new task in this round of training; and V^{t+1} denotes the right singular matrix of the new task in this round of training. Formula (14) takes the left singular matrix of this decomposition as the basis vectors of the new task, i.e., U^{t+1} denotes both the left singular matrix of the new task in this round of training and the basis vectors of the new task.

Step S423: the client concatenates the basis vectors of the new task with the historical-task basis vectors, obtaining basis vectors that contain all current tasks; these form a new orthogonal gradient space used for training the next task. Specifically, the basis vectors containing all current tasks are given by formula (15) below. By concatenating the basis vectors of the new task with those of the historical tasks, the knowledge of the new task is joined with that of the historical tasks; when a task is newly added, training can proceed on the basis vectors containing all current tasks, so that the knowledge of old and new tasks coexists and catastrophic forgetting of the early tasks is avoided.

$$U^k = [\,U^t,\ U^{t+1}\,] \qquad (15)$$

where t denotes the previous task; t+1 denotes the new task; U^t denotes the basis vectors of the previous task; U^{t+1} denotes the basis vectors (the left singular matrix) of the new task in this round of training; and k denotes the k-th client.
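Steps S421–S423 can be sketched end to end as follows, following the reconstruction of formulas (12)–(15) above; the energy threshold used to truncate the new basis is an assumption.

```python
import numpy as np

def extend_basis(U_hist, Z_new, energy=0.95):
    """Update the basis after finishing a task (steps S421-S423).
    U_hist: (d, r) concatenated basis vectors of all historical tasks.
    Z_new:  (d, L) representation matrix from the new task's last-round gradients."""
    Z_hat = Z_new - U_hist @ (U_hist.T @ Z_new)            # formula (12)
    U, S, _ = np.linalg.svd(Z_hat, full_matrices=False)    # formula (13)
    cum = np.cumsum(S ** 2) / max(np.sum(S ** 2), 1e-12)
    r = int(np.searchsorted(cum, energy)) + 1
    U_new = U[:, :r]                                       # formula (14)
    return np.concatenate([U_hist, U_new], axis=1)         # formula (15)
```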
Step S424: the server receives the updated model parameters and class-average prototype that the client, after acquiring new task data, produces in every round of training and sends after completing each round; the model parameters are obtained by updating the previous task's model parameters along a gradient direction orthogonal to that task's gradient space, based on the basis vectors of that gradient space, the new task's predicted and true labels, and the cross-task collaboration loss, and the class-average prototype is obtained by weighted averaging of the new task's class prototypes. For details, see step S302 of the embodiment shown in FIG. 3, which is not repeated here.

It should be noted that, for the first round of training of the initial task, steps S401 to S406 above are executed. For the second round of training of the initial task through its last round, the client follows steps S401 to S406, receiving the personalized global model produced by the server in the previous round for use in the current round and sending the round's model parameters and class-average prototype to the server after each round; the server sends the client the personalized global model of the previous round after every round of training of the initial task. For any task other than the initial task, the server determines the personalized global model and cross-task collaboration loss of that task's previous task through steps S407 to S411. Then the server executes step S412 and the client executes steps S413 to S420 to perform the first round of training of the new task. From the second round of training of the new task through its last round, the server, following steps S407 to S412, updates the model parameters and class-average prototype obtained in the preceding round and sends the resulting personalized global model and cross-task collaboration loss to the client; the client receives them and repeats steps S413 to S420, while the server repeats step S424 to update the model parameters and class-average prototype of the current round. When the client performs the last round of training of the new task, steps S421 to S423 are additionally executed.

The federated class-incremental learning method based on constrained gradient updates provided by the embodiments of the present invention works as follows: since the personalized global model trained on the previous task can classify the previous task and its historical tasks for all clients, and the cross-task collaboration loss obtained from the previous task improves that model's performance, when local data of a new task is acquired the client can, starting from the personalized global model and cross-task collaboration loss sent by the server, constrain the gradient update direction with respect to the previous task's gradient space and update the model parameters so as to protect previously learned knowledge. The client then sends the model parameters and class-average prototype of each round of training to the server, so that the knowledge learned on the new task feeds the next round of training. By receiving and updating the previous task's personalized global model and cross-task collaboration loss, and repeatedly sending each round's updated results, catastrophic forgetting of early tasks is avoided even as clients keep generating data for new tasks, and model drift between clients is avoided as well.
This embodiment further provides a federated class-incremental learning apparatus based on constrained gradient updates, which is used to implement the above embodiments and preferred implementations; what has already been explained is not repeated. As used below, the term "module" may be a combination of software and/or hardware that implements a predetermined function. Although the apparatus described in the following embodiments is preferably implemented in software, implementation in hardware, or in a combination of software and hardware, is also possible and contemplated.

This embodiment provides a federated class-incremental learning apparatus based on constrained gradient updates, applied to a client and, as shown in FIG. 6, comprising:
a first receiving module 601, configured to receive the personalized global model and cross-task collaboration loss of the previous task sent by the server, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients trained on their respective local data of the previous task, the cross-task collaboration loss expresses the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task;

an updating module 602, configured to, when new task data is acquired, update the model parameters of the previous task along a direction orthogonal to its gradient space based on the basis vectors of the previous task's gradient space, the predicted and true labels of the new task, and the cross-task collaboration loss, obtaining the model parameters of the new task in the current round of training;

an acquisition module 603, configured to take a weighted average of the class prototypes of the new task to obtain the class-average prototype of the new task in the current round of training; and

a first training module 604, configured to repeat the above updates of the model parameters and class-average prototype for the preset number of training rounds and, after each round of training, send the model parameters and class-average prototype of the new task to the server.

In some optional implementations, upstream of the updating module 602, the apparatus further includes:

a conversion module, configured to convert the local data of the new task into high-dimensional vectors; and

a first determination module, configured to determine the predicted labels of the new task based on the model parameters of the previous task and the high-dimensional vectors of the new task.

In some optional implementations, the apparatus further includes:

a second determination module, configured to, after the new task has been trained for the preset number of rounds, obtain the representation matrix of the new task based on the gradient obtained in the last round of training of the new task and the historical-task basis vectors, where the historical-task basis vectors contain the basis vectors of all historical tasks of the new task;

a decomposition module, configured to perform singular value decomposition on the representation matrix to obtain the basis vectors of the new task, forming the orthogonal gradient space of the new task; and

a concatenation module, configured to concatenate the basis vectors of the new task with the historical-task basis vectors, obtaining basis vectors containing all current tasks for training the next task.

In some optional implementations, upstream of the first receiving module 601, the apparatus further includes:

a second receiving module, configured to, for each client, receive the initial model sent by the server;

a second training module, configured to perform the first round of training of the initial task based on the initial model and the local data of the client's initial task, obtaining the initial model parameters and initial predicted labels;

a third determination module, configured to obtain the initial class-average prototype of the client's initial task in the first round of training, based on the client's initial task and the class prototypes of the initial task;

a fourth determination module, configured to determine the initial gradient corresponding to the initial task based on the predicted labels of the initial model, the true labels of the initial model, and the initial model parameters;

a fifth determination module, configured to obtain the orthogonal initial basis vectors of the gradient space of the initial task based on the initial gradient; and

a sending module, configured to send the server the initial model parameters and initial class-average prototype obtained in the first round of training of the initial task.
This embodiment provides a federated class-incremental learning apparatus based on constrained gradient updates, applied to a server and, as shown in FIG. 7, comprising:

a sending module 701, configured to send each client the personalized global model and cross-task collaboration loss of the previous task, where the personalized global model is obtained by the server through personalized aggregation of the model parameters and class-average prototypes that all clients trained on their respective local data of the previous task, the cross-task collaboration loss expresses the similarity between each client's model parameters of the previous task and the model parameters of the historical tasks, and the historical tasks are all tasks trained before the previous task; and

a first receiving module 702, configured to receive the updated model parameters and class-average prototype that the client, after acquiring new task data, produces in every round of training and sends after completing each round, where the model parameters are obtained by updating the previous task's model parameters along a direction orthogonal to its gradient space based on the basis vectors of that gradient space, the predicted and true labels of the new task, and the cross-task collaboration loss, and the class-average prototype is obtained by weighted averaging of the new task's class prototypes.

In some optional implementations, upstream of the sending module 701, the apparatus further includes:

a second receiving module, configured to, for each client, after the client has trained the previous task for the preset number of rounds, receive the client's model parameters and class-average prototype of the previous task, obtained in the last round of training of that task;

a third determination module, configured to determine multiple first distances between the client and the other clients, where a first distance expresses the similarity between the client's class-average prototype of the previous task and another client's class-average prototype of the previous task;

an aggregation module, configured to perform personalized aggregation based on the multiple first distances and the model parameters of the previous task, obtaining the client's personalized global model of the previous task;

a fourth determination module, configured to determine, from each historical task, multiple cross-task clients whose second distance to the client exceeds a distance threshold, where the second distance expresses the similarity between the client's class-average prototype of the previous task and the target class-average prototype of a cross-task client; and

a fifth determination module, configured to determine the cross-task collaboration loss of the previous task based on the target model parameters of the multiple cross-task clients, the second distances corresponding to those clients, and the model parameters of the previous task.
Further functional descriptions of the above modules and units are the same as in the corresponding embodiments above and are not repeated here.

The federated class-incremental learning apparatus based on constrained gradient updates in this embodiment is presented in the form of functional units, where a unit refers to an ASIC (Application Specific Integrated Circuit) circuit, a processor and memory executing one or more pieces of software or firmware, and/or other devices capable of providing the above functions.

An embodiment of the present invention further provides a computer device having the federated class-incremental learning apparatus based on constrained gradient updates shown in FIG. 6 and FIG. 7 above.
Referring to FIG. 8, FIG. 8 is a schematic structural diagram of a computer device provided by an optional embodiment of the present invention. As shown in FIG. 8, the computer device includes one or more processors 10, a memory 20, and interfaces for connecting the components, including high-speed and low-speed interfaces. The components are communicatively connected to one another over different buses and may be mounted on a common motherboard or in other ways as required. The processor can process instructions executed within the computer device, including instructions stored in or on the memory for displaying graphical information of a GUI on an external input/output apparatus (such as a display device coupled to an interface). In some optional implementations, multiple processors and/or multiple buses may be used together with multiple memories, if required. Likewise, multiple computer devices may be connected, with each device providing part of the necessary operations (for example, as a server array, a group of blade servers, or a multiprocessor system). FIG. 8 takes one processor 10 as an example.

The processor 10 may be a central processing unit, a network processor, or a combination thereof. The processor 10 may further include a hardware chip. The hardware chip may be an application-specific integrated circuit, a programmable logic device, or a combination thereof; the programmable logic device may be a complex programmable logic device, a field-programmable gate array, generic array logic, or any combination thereof.

The memory 20 stores instructions executable by the at least one processor 10, so that the at least one processor 10 executes the methods shown in the above embodiments.

The memory 20 may include a program storage area and a data storage area, where the program storage area may store the operating system and the applications required by at least one function, and the data storage area may store data created through the use of the computer device, among other things. In addition, the memory 20 may include high-speed random-access memory and may also include non-transient memory, such as at least one magnetic-disk storage device, flash-memory device, or other non-transient solid-state storage device. In some optional implementations, the memory 20 may optionally include memory located remotely from the processor 10, connected to the computer device over a network; examples of such networks include, but are not limited to, the Internet, intranets, local area networks, mobile communication networks, and combinations thereof.

The memory 20 may include volatile memory, such as random-access memory; it may also include non-volatile memory, such as flash memory, a hard disk, or a solid-state drive; and it may also include a combination of the above kinds of memory.
An embodiment of the present invention further provides a computer-readable storage medium. The method according to the embodiments of the present invention may be implemented in hardware or firmware, or as computer code recordable on a storage medium, or as computer code originally stored on a remote storage medium or a non-transitory machine-readable storage medium and downloaded over a network to be stored on a local storage medium, so that the method described here can be processed by software stored on a storage medium using a general-purpose computer, a dedicated processor, or programmable or dedicated hardware. The storage medium may be a magnetic disk, an optical disc, read-only memory, random-access memory, flash memory, a hard disk, a solid-state drive, or the like; further, the storage medium may include a combination of the above kinds of memory. It will be understood that a computer, processor, microprocessor controller, or programmable hardware includes a storage component capable of storing or receiving software or computer code which, when accessed and executed by the computer, processor, or hardware, implements the methods shown in the above embodiments.

Although the embodiments of the present invention have been described with reference to the accompanying drawings, those skilled in the art may make various modifications and variations without departing from the spirit and scope of the present invention, and such modifications and variations all fall within the scope defined by the appended claims.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410529143.5A CN118313445A (en) | 2024-04-29 | 2024-04-29 | A federated incremental learning method and system based on constrained gradient updating |
Publications (1)
Publication Number | Publication Date |
---|---|
CN118313445A true CN118313445A (en) | 2024-07-09 |
Family ID: 91722053
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410529143.5A Pending CN118313445A (en) | 2024-04-29 | 2024-04-29 | A federated incremental learning method and system based on constrained gradient updating |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN118313445A (en) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118586475A (en) * | 2024-08-05 | 2024-09-03 | 浙江浙能电力股份有限公司萧山发电厂 | A federated category incremental learning modeling method based on stable feature prototypes |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||