CN114723998B - 基于大边界贝叶斯原型学习的小样本图像分类方法及装置 - Google Patents
基于大边界贝叶斯原型学习的小样本图像分类方法及装置 Download PDFInfo
- Publication number
- CN114723998B CN114723998B CN202210482490.8A CN202210482490A CN114723998B CN 114723998 B CN114723998 B CN 114723998B CN 202210482490 A CN202210482490 A CN 202210482490A CN 114723998 B CN114723998 B CN 114723998B
- Authority
- CN
- China
- Prior art keywords
- sample
- class
- model
- distribution
- formula
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
- G06F18/24155—Bayesian classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于大边界贝叶斯原型学习的小样本图像分类方法及装置,方法包括以下步骤:S1、对数据进行预处理,其中数据包括训练集和测试集;S2、基于卷积神经网络和高斯分布模型,构建贝叶斯原型分布模型;S3、基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;S4、利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能。本发明引入了大边界分类和变分推理的思想,基于分布估计建模的方式进行研究,建立了一种基于大边界贝叶斯原型学习的小样本图像分类方法及装置,解决了小样本图像分类中存在的原型偏差问题,改善了图像的分类效果,具有很高的实用价值。
Description
技术领域
本发明涉及图像分类技术领域,尤其涉及一种基于大边界贝叶斯原型学习的小样本图像分类方法及装置。
背景技术
近年来,随着计算机技术的发展,人们浏览的信息日益丰富,每天都有大量图片被上传到网络,由于数量巨大,人工已经无法对此进行分类。在很多大样本图像分类任务上,机器的识别性能已经超越人类。然而,当样本量比较少时,机器的识别水平仍与人类存在较大差距。因此,研究高效可靠的图片分类算法有很迫切的社会需求。
人类具体通过极少量样本识别一个新物体的能力,例如小朋友只需要看过书中的个别图片,就可以准确的判断什么是“香蕉”或者是“草莓”。小样本学习指的是研究人员希望机器学习模型在学习一定类别的大量数据后,遇到新的类别,只需要少量的数据就可以快速的学习,实现“小样本学习”。
小样本分类属于小样本学习范畴,往往包含类别空间不相交的两类数据,即基类数据和新类数据。小样本分类旨在利用基类数据学习的知识和新类数据的少量标记样本来学习分类规则,准确预测新类任务中未标记样本的类别。
原先在度量学习中是利用固定的点,一般利用中心点,使用点估计的方法预测类,但是单点的度量学习方法有固有的噪声脆弱性和容易产生偏差。例如当有限的支撑点在嵌入空间中分布不均匀时,估计这种特定的点是很困难的,因为很容易受到噪声的影响。此外,这些点也缺乏可解释性,因为单个嵌入不足以表示一个类,我们认为同一个类的每个点在嵌入空间中不是孤立的,而是从高维分布中采样的,每个类的分布比几个点有更好的描述能力。所以我们基于点估计方法引起的问题,提出一种基于变分推理的分布估计框架,虽然精确的推理在计算上很困难,但是变分推理是一种具有理论吸引力且适合计算的方法。
但是变分推理的方法需要通过一个具有完整数据集的模型,这就导致低效地嵌入所有元素,因此分布不好生成,类原型存在的偏差大,MAML (与模型无关的元学习算法)优化方式会对梯度产生影响。因此我们基于前人的工作在支持集和目标集之间共享生成机制,然后将其合并到大型的特定类的分布中,这样通过每个类的分布,可以直观地计算目标点所在类的置信度,我们就可以利用置信度进行分类,就有利于避免样本受到偏差的影响。
发明内容
本发明针对上述小样本图像分类中的原型偏差问题,提出一种基于大边界贝叶斯原型学习的小样本图像分类方法及装置,引入了大边界分类和变分推理的思想,利用分布建模的方式,改善MAML(与模型无关的元学习算法)的优化方式,增加类内相似度,减少类间相似度,发明基于变分推理的方法,度量学习中以类的估计代替点估计以消除单点度量学习方法固有的噪声脆弱和偏差,引入少量的随机变量用于表示类原型,在此基础上,利用变分推理的方式学习样本的后验分布,在支持集和目标集之间共享生成机制,将其合并到大型的特定的类方法中,通过估计类的分布,直观地计算目标点所在类的置信度,再利用置信度进行分类,获取预测结果,有利于避免样本受到偏差的影响,增强分类精度和鲁棒性的同时也减少成本。
为了实现上述目的,本发明提供如下技术方案:
一方面,本发明提供了一种基于大边界贝叶斯原型学习的小样本图像分类方法,包括以下步骤:
S1、对数据进行预处理,其中数据包括训练集和测试集;
S2、基于卷积神经网络和高斯分布模型,构建贝叶斯原型分布模型,模型由嵌入模块fθ和分布映射模块组成;其中,嵌入模块用于提取样本特征,包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;分布映射模块由一层全连接层和转换层组成,用于将样本的深层特征转换为多维高斯分布;
S3、基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;
S4、利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能。
进一步地,步骤S1的预处理方法为:
S12,对于C-way K-shot分类任务,从Dtrain中随机选出C个类别,每个类别中随机选出M个样本,其中K个样本作为支持样本Si,其余M-K 个样本作为查询样本Qi,Si和Qi构成一个任务Ti;同样地,对于Dtest有任务
进一步地,步骤S11中遵循四层卷积架构来形成特征提取器F,其中每个卷积块包含一个带有64个滤波器的3×3的卷积,一个批量归一化,一个relu非线性层,一个2×2最大池化层,裁剪了最后两个块的最大池化层,特征提取器的输出大小为64×5×5=1600;利用128维全连接层构建原始生成器,最后通过聚合规划获得支持集和目标集的64维分布。
进一步地,步骤S3具体包括:
S33、支持样本和查询样本经过全连接层之后获取样本的全局特征;
S35、利用公式(1)和(2)获取的支持样本的类原型;
式(1)中xi代表样本,Sc代表支持样本中第C类别,μ(xi)代表样本xi的均值,μc代表类别C的均值即第c个类的分布,式(1)整体代表求类别C的加权调和平均数,代表模型中不同类原型的位置,式(2)代表求类别C的方差,表示样本波动的程度,式(1)和(2)结合在一起表示样本的类原型;
S36、利用公式(3)计算所有支持样本的预测概率Pr(μ(xi)|q(z|Sc)):
式(3)中q(z|Sc)为第C类支持集样本的分布,Pr(X)这里指代预测概率;式(3)代表所有支持样本c的预测概率由高斯分布得到,其中均值是μc,代表了所有维度中向量的均值,同时也代表了样本的位置,集中了从样本中提取的类信息,方差是代表不同维度的显著性;
S37,利用预测概率可以得到整个任务T的分类损失LR(S)为:
式(4)代表的是整个任务T的分类损失,S代表支持集样本,c代表第c个样本类,Pr代表预测概率,Sc代表支持样本c,等式左侧代表了支持样本损失函数,等式右侧代表交叉熵损失函数;
S38,利用式(5)计算不同任务之间的损失,即判断预测样本和真实样本之间的相似程度,即不同类别之间分布的KL散度:
S39,利用公式(6)计算总的损失L:
L=LR+Linter (6)
S310,采用MAML的优化方式,利用公式(6)中支持样本的损失函数,进行多步随机梯度下降,得到新的网络参数;
S311,利用公式(4)计算查询样本在该参数下的分类预测损失,并利用Adam优化器更新模型。
进一步地,步骤S4的具体步骤为:
式(1)中xi代表样本,Sc代表支持样本中第C类别,μ(xi)代表样本xi的均值,μc代表类别C的均值即第c个类的分布,式(1)整体代表求类别C 的加权调和平均数,代表模型中不同类原型的位置,式(2)代表求类别C 的方差,表示样本波动的程度,式(1)和(2)结合在一起表示样本的类原型;
S42,将查询样本输入训练好的嵌入模块fθ中;然后,将嵌入模块输出的矩阵特征经过贝叶斯原型分布模型,获得查询样本下的每个类的分布,直观地计算目标点所在类的置信度,再利用置信度进行分类,获取预测结果。
另一方面,本发明还提供了一种基于大边界贝叶斯原型学习的小样本图像分类装置,用以实现上述的任一项方法,包括以下模块:
数据预处理模块:用于对数据进行预处理,将数据划分成为训练集和测试集并且确定模型的训练方式;
网络模型构建模块:用于引入卷积神经网络模型和多维高斯分布,构建面贝叶斯原型分布模型,模型由嵌入模块fθ和分布映射模块组成;其中,嵌入模块用于提取样本特征,包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;分布映射模块由一层全连接层和转换层组成,用于将样本的深层特征转换为多维高斯分布;
训练模型参数模块:基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;
测试模型性能模块:利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能。
与现有技术相比,本发明的有益效果为:
本发明引入了大边界分类和变分推理的思想,基于分布估计建模的方式进行研究,建立一种基于大边界贝叶斯原型学习的小样本图像分类方法及装置,引入少量的随机变量用于表示类原型,在此基础上,利用变分推理的方式学习样本的后验分布,在支持集和目标集之间共享生成机制,将其合并到大型的特定的类方法中,通过估计类的分布,直观计算所在类的置信度,直观地计算目标点所在类的置信度,再利用置信度进行分类,获取预测结果,有利于避免样本受到偏差的影响,增强分类精度和鲁棒性的同时也减少成本。总的来说,本发明这种允许在模型中使用完整的贝叶斯分析的方法,相比较前人的工作更容易解释,消除了点估计的偏差,提高了预测性能。这种基于大边界贝叶斯原型学习的小样本图像方法,巧妙地应用了变分推理地思想,利用MAML的优化方式,强化类固有地特征,抑制不相关的特征,解决了小样本图像分类中存在的原型偏差问题,改善了图像的分类效果,具有很高的实用价值。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的计算类原型的变分贝叶斯原型学习网络。
图2为本发明实施例提供的基于大边界贝叶斯原型学习的小样本图像分类装置模块流程图。
图3为本发明实施例提供的卷积神经网络流程图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。本发明中的实施例,本领域技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
根据本文公开的一个方面,提供了一种基于大边界贝叶斯原型学习的小样本图像分类方法,如图1所示,包括以下阶段步骤:
S1、对数据进行预处理,其中数据包括训练集和测试集;
具体地,步骤S1的预处理方法为:
S11,将数据分为训练集/> 和测试集/>两部分,且这两部分的类别空间互斥,将Dtrain作为基类数据训练模型,Dtest作为新类数据测评模型性能,遵循四层卷积架构来形成特征提取器F,卷积神经网络中每个卷积块包含一个带有64个滤波器的3×3的卷积,一个批量归一化,一个relu 非线性层,一个2×2最大池化层,裁剪了最后两个块的最大池化层。特征提取器的输出大小为64×5×5=1600。利用128维全连接层构建原始生成器,最后通过聚合规划获得支持集和目标集的64维分布。
S12,对于C-way K-shot分类任务,从Dtrain中随机选出C个类别,每个类别中随机选出M个样本,其中K个样本作为支持样本Si,其余M-K 个样本作为查询样本Qi,Si和Qi构成一个任务Ti;同样地,对于Dtest有任务全连接层将样本深层特征映射为一个向量,该向量一部分作为样本均值μ,另一部分作为样本原始方差向量/>转换层将原始的方差向量转换为样本的方差向量σ2。
S2、基于卷积神经网络和高斯分布模型,构建贝叶斯原型分布模型。
具体地,模型由嵌入模块fθ和分布映射模块组成;其中,嵌入模块用于提取样本特征,包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;分布映射模块一层全连接层和转换层组成,用于将样本的深层特征转换为多维高斯分布。
S3、基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;
具体地,步骤S3包括:
S33、支持样本和查询样本经过全连接层之后获取样本的全局特征;
S35、利用公式(1)和(2)获取的支持样本的类原型;
式(1)中xi代表样本,Sc代表支持样本中第C类别,μ(xi)代表样本xi的均值,μc代表类别C的均值即第c个类的分布,式(1)整体代表求类别C的加权调和平均数,代表模型中不同类原型的位置,式(2)代表求类别C的方差,表示样本的波动程度,式(1)和(2)结合在一起表示样本的类原型;
S36、利用公式(3)计算所有支持样本的预测概率Pr(μ(xi)|q(z|Sc)):
式(3)中q(z|Sc)为第C类支持集样本的分布,Pr(X)这里指代预测概率;式(3)代表所有支持样本c的预测概率由高斯分布得到,其中均值是μc,代表了所有维度中向量的均值,同时也代表了样本的位置,集中了从样本中提取的类信息,方差是代表不同维度的显著性;
S37,利用预测概率可以得到整个任务T的分类损失LR(S)为:
式(4)代表的是整个任务T的分类损失,S代表支持集样本,c代表第c个样本类,Pr代表预测概率,Sc代表支持样本c,等式左侧代表了支持样本损失函数,等式右侧代表交叉熵损失函数;
S38,利用式(5)计算不同任务之间的损失,即判断预测样本和真实样本之间的相似程度,即不同类别之间分布的KL散度:
S39,利用公式(6)计算总的损失L:
L=LR+Linter (6)
S310,采用MAML的优化方式,利用公式(6)中支持样本的损失函数,进行多步随机梯度下降,得到新的网络参数;
S311,利用公式(4)计算查询样本在该参数下的分类预测损失,并利用Adam优化器更新模型。
S4、利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能。
具体地,步骤S4的步骤为:
式(1)中xi代表样本,Sc代表支持样本中第C类别,μ(xi)代表样本xi的均值,μc代表类别C的均值即第c个类的分布,式(1)整体代表求类别C 的加权调和平均数,代表模型中不同类原型的位置,式(2)代表求类别C 的方差,表示样本的波动程度,式(1)和(2)结合在一起表示样本的类原型;
S42,将查询样本输入训练好的嵌入模块fθ中;然后,将嵌入模块输出的矩阵特征经过贝叶斯原型分布模型,获得查询样本下的每个类的分布,直观地计算目标点所在类的置信度,再利用置信度进行分类,获取预测结果。
另一方面,本发明还提供了一种基于大边界贝叶斯原型学习的小样本图像分类装置,用以实现上述的任一项方法,包括以下模块:
数据预处理模块:用于对数据进行预处理,将数据划分成为训练集和测试集并且确定模型的训练方式;
网络模型构建模块:用于引入卷积神经网络模型和多维高斯分布,构建贝叶斯原型分布模型,模型由嵌入模块fθ和分布映射模块组成;其中,嵌入模块用于提取样本特征,包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;分布映射模块一层全连接层和转换层组成,用于将样本的深层特征转换为多维高斯分布;
训练模型参数模块:基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;
测试模型性能模块:利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能。
本发明方法和装置具有以下优点:
1、针对单点的度量学习方法产生的固有的噪声脆弱和容易产生偏差的问题,利用随机变分推理消除偏离的类特定分布,利用新样本的分布统计量,实现无分类器预测,增强小样本学习分类框架的分类精度和鲁棒性;
2、基于有限的支撑点在嵌入空间中分布不均匀时,估计特定的点很困难,单个的点也缺乏可解释性,因为其不足以表示一个类,所以由高维分布中采样,采取基于变分推理的分布估计框架;
3、基于前人在变分推理思想中的支持集和数据集之间生成共享生成机制,将其结合到大型的特定类的分布中,结合贝叶斯分析方法,提出变分贝叶斯框架,灵活地使用MAML的优化方式对模型进行优化,接着通过直观观察所在类的置信度且利用置信度进行分类,强化了类固有的特征,抑制不相关的特征,从而减小偏差。
以上结合附图对所提出的基于大边界贝叶斯原型学习的小样本图像分类方法及模型的具体实施方式进行了阐述。通过以上实施方式的描述,所属领域的技术人员可以清楚的了解该方法以及装置的实施。
需要说明的是,在附图和说明书正文中,未描述的实现方式,均为所属技术领域中普通技术人员所知的形式,未进行详细说明。此外,上述对各元件和方法的定义并不仅限于实例中提到的各种具体结构、形状或方式,本领域普通技术人员可对其进行简单地更改或替换。
此外,除非特别描述或必须依序发生地步骤,上述步骤地顺序并无限制于以上所列,且可根据所需设计而变化或重新安排。并且上述实例可基于设计及可靠度地考虑,彼此混合搭配使用或与其他实例混合搭配使用,即不同实施中的技术特征可以自由组合形成更多地实施例子。在此提供的算法和显示不与任何特定计算机、虚拟系统或者其他设备固有相关。各种通用系统也可以与基于在此地启示一起使用。根据上面的描述,构造这类系统所要求的结构是显而易见的。此外,本文公开的也不针对任何特定的编程语言。但是应当了解,可以利用各种编程语言实现在此描述的本文公开的内容,并且上面对特定语言所做的描述是为了披露本文公开的最佳实施方式。
类似的,应当理解,为了使本文尽量精简并且帮助理解各个公开方面中的一个或多个,在上面对本文公开的示例性实施例的描述中,本文公开的各个特征有时被一起分组到单个实施例、图、或者对其的描述中。然而,并不应将该公开的方法解释成反映如下示意图:即要求所保护的本文公开的要求比在每个权利要求中所明确记载的特征具有更多的特征。更确切地说,如下面的权利要求书所反映的那样,公开方面在于少于前面公开的单个实施例的所有特征。因此,遵循具体实施方式的权利要求书由此明确地并入该具体实施方式,其中每个权利要求本身都作为本发明公开的单独实施例子。
以上所述实施例,仅为本申请的具体实施方式,用以说明本申请的技术方案,而非对其限制,本申请的保护范围并不局限于此,尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特殊进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本申请实施例技术方案的精神和范围。都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应所述以权利要求的保护范围为准。
Claims (5)
1.一种基于大边界贝叶斯原型学习的小样本图像分类方法,其特征在于,包括以下步骤:
S1、对数据进行预处理,其中数据包括训练集和测试集;
S2、基于卷积神经网络和高斯分布模型,构建贝叶斯原型分布模型,模型由嵌入模块fθ和分布映射模块组成;其中,嵌入模块用于提取样本特征,包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;分布映射模块由一层全连接层和转换层组成,用于将样本的深层特征转换为多维高斯分布;
S3、基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;
步骤S3具体包括:
S33、支持样本和查询样本经过全连接层之后获取样本的全局特征;
S35、利用公式(1)和(2)获取的支持样本的类原型;
式(1)中xi代表样本,Sc代表支持样本中第C类别,μ(xi)代表样本xi的均值,μc代表类别C的均值即第c个类的分布,式(1)整体代表求类别C的加权调和平均数,代表模型中不同类原型的位置,式(2)代表求类别C的方差,表示样本的波动幅度,式(1)和(2)结合在一起表示样本的类原型;
S36、利用公式(3)计算所有支持样本的预测概率Pr(μ(xi)|q(z|Sc)):
式(3)中q(z|Sc)为第C类支持集样本的分布,Pr(X)这里指代预测概率;式(3)代表所有支持样本c的预测概率由高斯分布得到,其中均值是μc,代表了所有维度中向量的均值,同时也代表了样本的位置,集中了从样本中提取的类信息,方差是代表不同维度的显著性;
S37,利用预测概率得到整个任务T的分类损失LR(S)为:
式(4)代表的是整个任务T的分类损失,S代表支持集样本,c代表第c个样本类,Pr代表预测概率,Sc代表支持样本c,等式左侧代表了支持样本损失函数,等式右侧代表交叉熵损失函数;
S38,利用式(5)计算不同任务之间的损失,即判断预测样本和真实样本之间的相似程度,即不同类别之间分布的KL散度:
S39,利用公式(6)计算总的损失L:
L=LR+Linter (6)
S310,采用MAML的优化方式,利用公式(6)中支持样本的损失函数,进行多步随机梯度下降,得到新的网络参数;
S311,利用公式(4)计算查询样本在该参数下的分类预测损失,并利用Adam优化器更新模型;
S4、利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能;
步骤S4的具体步骤为:
式(1)中xi代表样本,Sc代表支持样本中第C类别,μ(xi)代表样本xi的均值,μc代表类别C的均值即第c个类的分布,式(1)整体代表求类别C的加权调和平均数,代表模型中不同类原型的位置,式(2)代表求类别C的方差,表样本的波动程度,式(1)和(2)结合在一起表示样本的类原型;
S42,将查询样本输入训练好的嵌入模块fθ中;然后,将嵌入模块输出的矩阵特征经过贝叶斯原型分布模型,获得查询样本下的每个类的分布,直观地计算目标点所在类的置信度,再利用置信度进行分类,获取预测结果;具体过程为:
利用公式(3)计算所有支持样本的预测概率Pr(μ(xi)|q(z|Sc)):
式(3)中q(z|Sc)为第C类支持集样本的分布,Pr(X)这里指代预测概率;式(3)代表所有支持样本c的预测概率由高斯分布得到,其中均值是μc,代表了所有维度中向量的均值,同时也代表了样本的位置,集中了从样本中提取的类信息,方差是代表不同维度的显著性;
利用预测概率可以得到整个任务T的分类损失LR(S)为:
式(4)代表的是整个任务T的分类损失,S代表支持集样本,c代表第c个样本类,Pr代表预测概率,Sc代表支持样本c,等式左侧代表了支持样本损失函数,等式右侧代表交叉熵损失函数;
利用式(5)计算不同任务之间的损失,即判断预测样本和真实样本之间的相似程度,即不同类别之间分布的KL散度:
利用公式(6)计算总的损失L:
L=LR+Linter (6)
采用MAML的优化方式,利用公式(6)中支持样本的损失函数,进行多步随机梯度下降,得到新的网络参数;
利用公式(3)计算查询样本在该参数下的分类预测结果。
3.根据权利要求2所述的基于大边界贝叶斯原型学习的小样本图像分类方法,其特征在于,步骤S11中遵循四层卷积架构来形成特征提取器F,其中每个卷积块包含一个带有64个滤波器的3×3的卷积,一个批量归一化,一个relu非线性层,一个2×2最大池化层,裁剪了最后两个块的最大池化层,特征提取器的输出大小为64×5×5=1600;利用128维全连接层构建原始生成器,最后通过聚合规划获得支持集和目标集的64维分布。
5.一种基于大边界贝叶斯原型学习的小样本图像分类装置,其特征在于,用以实现权利要求1-4所述的任一项方法,包括以下模块:
数据预处理模块:用于对数据进行预处理,将数据划分成为训练集和测试集并且确定模型的训练方式;
网络模型构建模块:用于引入卷积神经网络模型和多维高斯分布,构建基于大边界贝叶斯原型学习的小样本图像分类模型,模型由嵌入模块fθ和分布映射模块组成;其中,嵌入模块用于提取样本特征,包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;分布映射模块一层全连接层和转换层组成,用于将样本的深层特征转换为多维高斯分布;
训练模型参数模块:基于优化后的贝叶斯原型分布模型的目标函数并求解模型参数;
测试模型性能模块:利用优化后的贝叶斯原型分布模型对测试集图像进行分类,测评模型的性能。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210482490.8A CN114723998B (zh) | 2022-05-05 | 2022-05-05 | 基于大边界贝叶斯原型学习的小样本图像分类方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210482490.8A CN114723998B (zh) | 2022-05-05 | 2022-05-05 | 基于大边界贝叶斯原型学习的小样本图像分类方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114723998A CN114723998A (zh) | 2022-07-08 |
CN114723998B true CN114723998B (zh) | 2023-06-20 |
Family
ID=82231221
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210482490.8A Active CN114723998B (zh) | 2022-05-05 | 2022-05-05 | 基于大边界贝叶斯原型学习的小样本图像分类方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114723998B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116500575B (zh) * | 2023-05-11 | 2023-12-22 | 兰州理工大学 | 一种基于变分贝叶斯理论的扩展目标跟踪方法和装置 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110414600A (zh) * | 2019-07-27 | 2019-11-05 | 西安电子科技大学 | 一种基于迁移学习的空间目标小样本识别方法 |
WO2021051987A1 (zh) * | 2019-09-18 | 2021-03-25 | 华为技术有限公司 | 神经网络模型训练的方法和装置 |
Family Cites Families (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111476292B (zh) * | 2020-04-03 | 2021-02-19 | 北京全景德康医学影像诊断中心有限公司 | 医学图像分类处理人工智能的小样本元学习训练方法 |
KR102564285B1 (ko) * | 2020-06-19 | 2023-08-08 | 한국전자통신연구원 | 온라인 베이지안 퓨샷 학습 방법 및 장치 |
CN112633382B (zh) * | 2020-12-25 | 2024-02-13 | 浙江大学 | 一种基于互近邻的少样本图像分类方法及系统 |
CN114124437B (zh) * | 2021-09-28 | 2022-09-23 | 西安电子科技大学 | 基于原型卷积网络的加密流量识别方法 |
-
2022
- 2022-05-05 CN CN202210482490.8A patent/CN114723998B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110414600A (zh) * | 2019-07-27 | 2019-11-05 | 西安电子科技大学 | 一种基于迁移学习的空间目标小样本识别方法 |
WO2021051987A1 (zh) * | 2019-09-18 | 2021-03-25 | 华为技术有限公司 | 神经网络模型训练的方法和装置 |
Non-Patent Citations (2)
Title |
---|
Variational Feature Disentangling for Fine-Grained Few-Shot Classification;Jingyi Xu 等;《Proceedings of the IEEE/CVF International Conference on Computer Vision 》;8812-8821 * |
一种改进的小批量手写体字符识别算法;李远沐;王展青;;《小型微型计算机系统》;第41卷(第07期);1541-1546 * |
Also Published As
Publication number | Publication date |
---|---|
CN114723998A (zh) | 2022-07-08 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Candanedo et al. | Machine learning predictive model for industry 4.0 | |
CN109947086A (zh) | 基于对抗学习的机械故障迁移诊断方法及系统 | |
Zhou et al. | A belief-rule-based model for information fusion with insufficient multi-sensor data and domain knowledge using evolutionary algorithms with operator recommendations | |
Li et al. | Two-class 3D-CNN classifiers combination for video copy detection | |
Mulongo et al. | Anomaly detection in power generation plants using machine learning and neural networks | |
CN113761259A (zh) | 一种图像处理方法、装置以及计算机设备 | |
CN113469470A (zh) | 基于电力大脑中枢的用能数据与碳排放量关联分析方法 | |
CN115631396A (zh) | 一种基于知识蒸馏的YOLOv5目标检测方法 | |
Mai et al. | Optimization of interval type-2 fuzzy system using the PSO technique for predictive problems | |
CN114723998B (zh) | 基于大边界贝叶斯原型学习的小样本图像分类方法及装置 | |
Berghout et al. | Auto-NAHL: A neural network approach for condition-based maintenance of complex industrial systems | |
Hariri et al. | Tipburn disorder detection in strawberry leaves using convolutional neural networks and particle swarm optimization | |
Fakhrmoosavy et al. | A modified brain emotional learning model for earthquake magnitude and fear prediction | |
CN114943859B (zh) | 面向小样本图像分类的任务相关度量学习方法及装置 | |
Mikhaylenko et al. | Analysis of the predicting neural network person recognition system by picture image | |
Zhang et al. | Zero-small sample classification method with model structure self-optimization and its application in capability evaluation | |
Goyal et al. | Hierarchical class-based curriculum loss | |
Gao et al. | A fault diagnosis method for rolling bearings based on graph neural network with one-shot learning | |
Liu et al. | Missing data imputation and classification of small sample missing time series data based on gradient penalized adversarial multi-task learning | |
Huang et al. | An intelligent fault identification method of rolling bearings based on SVM optimized by improved GWO | |
CN114818945A (zh) | 融入类别自适应度量学习的小样本图像分类方法及装置 | |
Chander et al. | Auto-encoder—lstm-based outlier detection method for wsns | |
Shen et al. | A Fault Diagnosis Method under Data Imbalance Based on Generative Adversarial Network and Long Short-Term Memory Algorithms for Aircraft Hydraulic System | |
Kim et al. | A sensory control system for adjusting group emotion using Bayesian networks and reinforcement learning | |
CN108960406A (zh) | 一种基于bfo小波神经网络的mems陀螺随机误差预测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |