[go: up one dir, main page]
More Web Proxy on the site http://driver.im/

CN116363445A - 图像角度分类模型训练方法、装置、设备及存储介质 - Google Patents

图像角度分类模型训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN116363445A
CN116363445A CN202211608633.1A CN202211608633A CN116363445A CN 116363445 A CN116363445 A CN 116363445A CN 202211608633 A CN202211608633 A CN 202211608633A CN 116363445 A CN116363445 A CN 116363445A
Authority
CN
China
Prior art keywords
classification
image
angle
regression
classification model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211608633.1A
Other languages
English (en)
Inventor
张洪
肖嵘
王孝宇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Intellifusion Technologies Co Ltd
Original Assignee
Shenzhen Intellifusion Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Intellifusion Technologies Co Ltd filed Critical Shenzhen Intellifusion Technologies Co Ltd
Priority to CN202211608633.1A priority Critical patent/CN116363445A/zh
Publication of CN116363445A publication Critical patent/CN116363445A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/766Arrangements for image or video recognition or understanding using pattern recognition or machine learning using regression, e.g. by projecting features on hyperplanes
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computing Systems (AREA)
  • Databases & Information Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种图像角度分类模型训练方法、装置、设备以及存储介质以及一种图像角度判定方法,所述方法包括:将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;基于所述样本图像的真实角度和所述回归预测结果,计算得到回归损失值;基于所述样本图像的真实标签向量和分类预测结果,计算得到分类损失值;基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件;在满足训练停止条件的情况下,确定得到训练好的所述图像角度分类模型。

Description

图像角度分类模型训练方法、装置、设备及存储介质
技术领域
本发明涉及深度学习技术领域,尤其涉及一种图像角度分类模型训练方法、装置、设备及存储介质
背景技术
图像中物体的角度估计的在深度学习的可以看作多分类问题来实现,将不同的角度作为不同的分类类别,然后按照多分类任务来进行模型训练。但是,在训练过程中,当图像角度分类模型预测错误时,针对同一张图像得到除正确预测结果之外的其他错误预测结果,其他错误预测结果对应的损失值是相同的,比如,将图像角度为0度的样本图像错误预测为45度和180度时,其对应的损失值是相同的,这样,在模型训练过程中连续多次出现预测错误,且对应的损失值不再改变的情况下,存在模型训练提前结束的风险,导致图像角度分类模型得不到充分训练,存在准确率较低的问题。
发明内容
本发明实施例公开了一种图像角度分类模型训练方法、装置、计算机设备及存储介质,以解决现有的图像角度分类模型存在准确率较低的问题。
一种图像角度分类模型训练方法,所述方法包括:
将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
基于所述样本图像的真实角度和所述回归预测结果,计算得到回归损失值;
基于所述样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件;
在满足训练停止条件的情况下,确定得到训练好的所述图像角度分类模型。
上述方法,可选的,所述基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件,包括:
基于所述分类损失值和所述回归损失值,得到总损失值;
在预设的训练次数内,所述总损失值不再下降,或者,达到所述预设的训练次数,确定所述图像角度分类模型满足所述训练停止条件。
上述方法,可选的,所述基于所述真实角度值和所述回归预测结果,计算得到回归损失值,包括:
将所述真实角度值和所述回归预测结果输入到回归损失函数,得到回归损失值。
上述方法,可选的,所述回归损失函数包括第一回归损失函数和第二回归损失函数;
其中,将所述真实角度值和所述回归预测结果输入到回归损失函数,得到回归损失值,包括:
将所述真实角度值和所述回归预测结果,分别输入到所述第一回归损失函数和所述第二回归损失函数,得到第一回归损失值和第二回归损失值。
上述方法,可选的,所述分类预测结果通过以下方式得到:
将所述样本图像输入到主干网络,获取到所述样本图像中的特征图像;
对所述特征图像进行池化处理,得到所述样本图像的角度特征;
将所述角度特征输入到分类分支,得到所述样本图像的分类预测结果。
上述方法,可选的,所述将所述角度特征分别输入到分类分支,得到样本图像的分类预测结果,包括:
将所述角度特征输入到分类分支,得到每一个标签向量的预测概率值;每个所述标签向量对应一个角度值,或者一个角度值范围;
输出最大的所述预测概率对应的所述标签向量,作为所述分类预测结果。
一种图像角度判定方法,所述方法包括:获取到目标图像,将所述目标图像输入到上述任一所述的图像角度分类模型训练方法所得到的图像角度分类模型,获得所述分类分支输出的所述目标图像的预测结果。
一种图像角度分类模型训练装置,包括:
预测结果输出单元,用于将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
回归损失值计算单元,用于基于所述样本图像的真实角度和所述回归预测结果,计算得到回归损失值;
分类损失值计算单元,用于基于所述样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
训练停止判断单元,用于基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件;在满足训练停止条件的情况下,确定得到训练好的所述图像角度分类模型。
一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上述任一项所述图像角度分类模型训练方法。
一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上述任一项所述图像角度分类模型训练方法。
上述图像角度分类模型训练方法、装置、计算机设备及存储介质,通过在图像角度分类模型的训练过程中,在计算分类分支的分类损失值得基础上,增加回归分支的回归损失值,以分类损失值和回归损失值来综合判断图像角度分类模型训练是否满足训练停止条件,由此,得到训练好的图像角度分类模型。可见,本发明中通过同时基于分类损失值和回归损失值来判定图像角度分类模型的训练效果,相较于仅基于分类损失值来判定图像角度分类模型的训练效果,可以达到提高图像角度分类模型的准确率的目的。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例一公开的一种图像角度分类模型训练方法的流程示意图;
图2是本发明实施例一公开的一种图像角度分类模型训练方法的部分流程示意图;
图3是本发明实施例一公开的一种图像角度分类模型训练方法的部分流程示意图;
图4是本发明实施例一公开的一种图像角度分类模型训练方法的模型结构示意图;
图5是本发明实施例一公开的一种图像角度分类模型训练方法的部分流程示意图;
图6是本发明实施例二公开的一种图像角度分类模型训练装置的结构示意图;
图7是本发明实施例四公开的一种计算机设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例公开的一种图像角度分类模型训练方法、装置、计算机设备及存储介质,通过在图像角度分类模型的训练过程中,在计算分类分支的分类损失值得基础上,增加回归分支的回归损失值,以分类损失值和回归损失值来综合判断图像角度分类模型训练是否满足训练停止条件,由此,得到训练好的图像角度分类模型。可以理解的,在图像角度分类模型训练过程中,由于常用的分类损失函数,在图像角度分类模型预测得到除正确预测结果之外的其他错误预测结果的损失值是相同的,因此,本发明中增加回归分支,以根据回归分支计算得到的回归预测结果,来计算得到回归损失值,由此,基于回归损失值和分类损失值,来判断图像角度分类模型是否满足训练停止条件,进而得到训练好的图像角度分类模型。可见,本发明中通过同时基于分类损失值和回归损失值来判定图像角度分类模型的训练效果,相较于仅基于分类损失值来判定图像角度分类模型的训练效果,可以达到提高图像角度分类模型的预测准确率的目的。
实施例一
如图1所示,为本发明实施例一公开的一种图像角度分类模型训练方法的实现流程图,该方法在图像角度分类模型训练过程中,在基于分类分支的分类损失值判定图像角度分类模型是否满足训练停止条件的基础上,增加回归分支的回归损失值,由此,通过回归损失值和分类损失值综合判定图像角度分类模型是否满足训练停止条件,由此,可以达到提高图像角度分类模型的预测准确率的目的。
S101:将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果。
在具体实现中,本实施例中的图像角度分类模型主要由三部分组成,包括:主干网络、池化层以及两个全连接层组成。其中,分类分支为两个全连接层中一个全连接层,回归分支为两个全连接层中的另一个全连接层。样本图像在经过主干网络和池化层处理后,得到样本图像的角度特征,将角度特征分别输入到分类分支和回归分支,得到回归分支输出的回归预测结果和分类分支输出的分类预测结果。据此,可以得到分类分支输出的分类预测结果和回归分支输出的回归预测结果。
S102:基于样本图像的真实角度和回归预测结果,计算得到回归损失值。
在具体实现中,本实施例中的样本图像的真实角度,为提前给定样本图像的真实角度,或者,提前标注好的样本。例如,通过三维建模技术,基于给定的角度值,来构建对应角度值的样本图像,以使图像中物体的角度值与给定的角度值相同,由此,给定的角度值即为样本图像的真实角度。再如,预先收集一些样本图像,然后通过人工标注的方法,标注出样本图像的角度值,由此,以人工标注的角度值作为样本图像的真实角度。将样本图像的真实角度和回归分支输出的回归预测结果,输入到回归损失函数中,进而计算出回归分支的回归损失值。据此,可以得到回归损失值,以用于图像角度分类模型训练的后续步骤。
需要注意的是,回归损失函数需要满足以下特征:
当回归分支输出的回归预测结果与样本图像的真实角度的差值越大时,在将回归分支输出的回归预测结果与样本图像的真实角度输入到回归损失函数中后,得到回归损失值也将越大。
例如,以样本图像的真实角度为0度为例,将该样本图像输入到图像角度分类模型中,以得到回归分支输出的预测结果。当回归分支输出错误预测结果为45度时,回归分支对应的分类损失值,小于,回归分支输出错误预测结果为180度时,回归分支对应的分类损失值。
S103:基于样本图像的真实标签向量和分类预测结果,计算得到分类损失值。
其中,样本图像的真实标签向量对应一个分类类别,该分类类别对应样本图像的真实角度。样本图像的真实角度为提前给定样本图像的真实角度,或者,提前标注好的样本。例如,通过三维建模技术,基于给定的角度值,来构建对应角度值的样本图像,以使图像中物体的角度值与给定的角度值相同,由此,给定的角度值即为样本图像的真实角度。再如,预先收集一些样本图像,然后通过人工标注的方法,标注出样本图像的角度值,由此,以人工标注的角度值作为样本图像的真实角度。
在具体实现中,本实施例中可以将将样本图像的真实标签向量和分类分支输出的分类预测结果,输入到分类损失函数中,进而计算出分类分支的回归损失值。
示例性的,分类损失函数可以如下所示:
Figure BDA0003999600670000061
其中,softmax()是归一化指数函数,ln()是以e为底数的对数函数,
Figure BDA0003999600670000062
是标签向量,/>
Figure BDA0003999600670000063
是分类分支的输出向量。
S104:基于分类损失值和回归损失值,判断图像角度分类模型是否满足训练停止条件。
如果图像角度分类模型满足训练停止条件,执行步骤S105,即确定得到训练好的图像角度分类模型,如果图像角度分类模型不满足训练停止条件,则更新图像角度分类模型模型参数,继续对图像角度分类模型进行训练。
在具体实现中,本实施例中基于分类损失值和回归损失值计算出总的损失值,根据该损失值得变化,来判断图像角度分类模型是否满足训练停止条件。可以理解的,当分类分支连续多次输出错误的预测结果时,分类分支的分类损失值是不变的,而回归分支在连续多次输出错误的预测结果时,回归分支的回归损失值,是不断变化的,因此,当总损失值不断变化时,图像角度分类模型是不满足训练停止条件,因此,仍需继续对图像角度分类模型进行训练,以此,避免图像角度分类模型提前停止训练,可以达到提高图像角度分类模型的准确率的目的。
S105:确定得到训练好的图像角度分类模型。
在得到训练好的图像角度分类模型后,可以丢弃图像角度分类模型的回归分支,仅是以分类分支输出的分类预测结果,作为图像角度分类模型预测结果,或者,保留图像角度分类模型的回归分支,以回归分支输出的回归预测结果来评价分类分支的分类预测结果,在回归分支输出的回归预测结果与分类分支的分类预测结果不同时,对图像角度分类模型的输入的回归预测结果和分类预测结果进行再一次确认,例如,需要进行角度分类的图像再次输入图像角度分类模型,重新得到回归分支输出的回归预测结果与分类分支的分类预测结果。
综上,本发明实施例一公开的一种图像角度分类模型训练方法,通过在图像角度分类模型的训练过程中,在计算分类分支的分类损失值得基础上,增加回归分支的回归损失值,以分类损失值和回归损失值来综合判断图像角度分类模型训练是否满足训练停止条件,由此,得到训练好的图像角度分类模型。可见,本发明中同时基于分类损失值和回归损失值来判定图像角度分类模型是否满足训练停止条件,相较于仅基于分类损失值来判定图像角度分类模型是否满足训练停止条件,可以提高图像角度分类模型的训练效果,以达到提高图像角度分类模型的准确率的目的。
在基于图1的具体实现中,步骤S104具体可以通过以下步骤实现,如图2所示:
S201:基于分类损失值和回归损失值,得到总损失值。
在具体实现中,本实施例中可以将分类损失值和回归损失值输入到总损失计算公式,以得到总损失值,其中,分类损失值和回归损失值在总损失计算公式中分别具备对应的权重系数,该权重系数可以根据实际需要进行调整,本实施例中对分类损失值和回归损失值分别对应的权重系数不做限定。
示例性的,本实施例中的总损失计算公式可以如下所示:
loss=w0×lossCE+w1×lossReg1
其中,lossCE为分类损失值,w0为分类损失值对应的权重系数,lossReg1为回归损失值,w1为回归损失值对应的权重系数,loss为总损失值。
S202:在预设的训练次数内,总损失值不再下降,或者,达到预设的训练次数,确定图像角度分类模型满足训练停止条件。
在具体实现中,本实施例的中训练停止条件可以包括以下几种:
在一种实现方式中,图像角度分类模型的训练停止条件:在预设的训练次数内,总损失值不再下降,确定图像角度分类模型满足训练停止条件。
具体的,本实施例中的图像角度分类模型在预设的训练次数内,无论如何在更新图像角度分类模型的参数,总损失值都不在发生变换,认为得到了训练好的图像角度分类模型。
在另一种实现方式中,图像角度分类模型的训练停止条件:达到预设的训练次数。
也就是说,本实施例中的图像角度分类模型在达到预设的训练次数时,就认为得到了训练好的图像角度分类模型。
在基于图1的具体实现中,步骤S102具体可以通过以下步骤实现,具体如下所示:
将样本图像的真实角度值和图像角度分类模型的回归分支输出的回归预测结果,输入到回归损失函数,得到回归损失值。
示例性的,本实施例中的回归损失函数可以如下所示:
lossReg1=abs(sin(2π×sigmoid(p)-sin(q))+
abs(cos(2π×sigmoid(p))-cos(q))
其中,q是样本图像的真实角度,p是图像角度分类模型的回归分支输出的回归预测结果,sin()为正弦函数、cos()为余弦函数,sigmoid()是S型函数,abs()为绝对值函数,lossReg1为回归损失值。
在具体实现中,回归分支输出的回归预测结果为一个数值,需要对回归预测结果进行后处理,才能得到回归预测结果对应的角度,例如,将回归预测结果利用sigmoid()函数进行处理,得到一个基于0~1之间的数值,之后,将该数值乘以2π,即可得到对应的度数。
可见,在上述的第一回归损失函数中,首先计算出样本图像的回归预测结果的正弦值与样本图像的真实角度的正弦值的正弦差值,然后对该正弦差值取绝对值,即
abs(sin(2π×sigmoid(p)-sin(q))
然后计算出样本图像的回归预测结果的余弦值与样本图像的真实角度的余弦值的余弦差值,然后对该余弦差值取绝对值,即
abs(cos(2π×sigmoid(p))-cos(q))
然后求出该正弦差值的绝对值和余弦差值的绝对值和,所得结果即为回归损失值。
可以理解的是,当存在两个回归损失函数时,总损失计算公式也需要进行对应的调整。
示例性的,总损失计算公式可以如下所示
loss=w0×lossCE+w1×lossReg1+w2×lossReg2
其中,其中,lossCE为分类损失值,w0为分类损失值对应的权重系数,lossReg1为回归损失值,w1为回归损失值对应的权重系数,lossReg2为回归损失值,w2为回归损失值对应的权重系数,loss为总损失值。
在基于图1的具体实现中,步骤S102也可以通过以下步骤实现,具体如下所示:
将样本图像的真实角度值和图像角度分类模型的回归分支输出的回归预测结果,分别输入到第一回归损失函数和第二回归损失函数,得到第一回归损失值和第二回归损失值。
示例性的,本实施例中的第一回归损失函数可以如下所示:
lossReg1=abs(sin(2π×sigmoid(p)-sin(q))+
abs(cos(2π×sigmoid(p))-cos(q))
其中,q是样本图像的真实角度,p是图像角度分类模型的回归分支输出的回归预测结果,sin()为正弦函数、cos()为余弦函数,sigmoid()是S型函数,abs()为绝对值函数,lossReg1为第一回归损失值。
在具体实现中,回归分支输出的回归预测结果为一个数值,需要对回归预测结果进行后处理,才能得到回归预测结果对应的角度,例如,将回归预测结果利用sigmoid()函数进行处理,得到一个基于0~1之间的数值,之后,将该数值乘以2π,即可得到对应的度数。
可见,在上述的第一回归损失函数中,首先计算出样本图像的回归预测结果的正弦值与样本图像的真实角度的正弦值的正弦差值,然后对该正弦差值取绝对值,即
abs(sin(2π×sigmoid(p)-sin(q))
然后计算出样本图像的回归预测结果的余弦值与样本图像的真实角度的余弦值的余弦差值,然后对该余弦差值取绝对值,即
abs(cos(2π×sigmoid(p))-cos(q))
然后求出该正弦差值的绝对值和余弦差值的绝对值和,所得结果即为第一回归损失值。
示例性的,第二回归损失函数可以如下所示:
Figure BDA0003999600670000101
其中,q是样本图像的真实角度,p是回归分支的回归预测结果,sigmoid()是S型函数,abs()是绝对值函数,min()表示取较小的值,lossReg2为第二回归损失函数。
可见,在上述的第二回归损失函数中,通过sigmoid()将回归分支的回归预测结果转换为0~1的数值,然后将样本图像的真实角度通过除以2π得到一个0~1的数值,然后求出这两个数值的差值的绝对值,即
Figure BDA0003999600670000102
之后,基于上述
Figure BDA0003999600670000103
得到的绝对值,用整数1减去该绝对值,即
Figure BDA0003999600670000104
然后通过min()函数在该绝对值和用整数1减去该绝对值得到数值中,选择较小的数值作为第二回归损失值。
综上所述,本实施例中可以通过一个回归损失函数来对回归分支进行监督,也可以通过两个回归损失函数来对回归分支进行监督,以两个回归损失函数得到的损失值的差异,使得图像角度分类模型的训练效果更好。也就是说,根据两个回归损失函数得到的损失值是不同,两个回归损失函数在模型训练过程中可以起到互补的作用,基于这两个回归损失函数进行模型训练,可以得到训练效果更好的模型图像角度分类模型。
在基于图1的具体实现中,分类预测结果可以通过方式获取到,如图3所示:
本实施例中的图像角度分类模型主要包含四个部分:主干网络、池化层以及两个全连接层,其中,分类分支为这两个全连接层中的一个全连接层,回归分支为这两个全连接层中的另一个全连接层,如图4所示。
S301:将样本图像输入到主干网络,获取到样本图像中的特征图像。
将样本图像输入到图像角度分类模型的主干网络,通过主干网络在样本图像中提取出包含角度特征的图像,即特征图像,然后基于这些特征图像,执行对图像角度分类模型进行训练的后续步骤。
在具体实现中,本实施例中的主干网络可以具备图像特征提取功能的模型,例如ResNet等等,将样本图像输入到ResNet模型中,由ResNet模型对样本图像进行处理,以输出样本图像中具备角度特征的特征图像。
需要注意的是,本实施例中的图像角度分类模型的主干网络,并不限于ResNet模型一种,也可以使用其他的模型来作为图像角度分类模型的主干网络,如MobileNet模型,因此,本实施例中对图像角度分类模型的主干网络所使用的模型结构不做具体限定。
S302:对特征图像进行池化处理,得到样本图像的角度特征。
将特征图像输入到池化层进行池化处理,进而将特征图像转化为对应的角度特征,据此,得到样本图像的角度特征。
在具体实现中,本实施例中的池化层可以对特征图像进行全局平均池化,从特征图像中提取到样本图像的角度特征,或者对特征图像进行全局最大池化,从特征图像中提取到样本图像的角度特征,本实施例中对池化层对特征图像的池化处理方式不做具体限定。
S303:将角度特征输入到分类分支,得到样本图像的分类预测结果。
将样本图像的角度特征输入到分类分支的全连接层,进而根据角度特征进行角度预测,得到样本图像的分类预测结果。
在基于图3的具体实现中,步骤S303可以通过以下步骤实现,如图5所示:
S501:将角度特征输入到分类分支,得到每一个标签向量的预测概率值。
其中,每个标签向量对应一个角度值,或者一个角度值范围。
将样本图像的角度特征输入到图像角度分类模型的分类分支后,会根据预设的标签向量,输出每一个标签向量对应的预测概率值。其中,每一个标签向量对应一个分类类比,该分类类别可以为一个角度值,或者,也可以是一个角度值范围。
例如,以每一个度数为一个分类类别为例,也就是说角度为1度是一个分类类别,角度为2度是一个分类类别,角度为3度是一个分类类别,以此类推,共有360个分类类别,也就是存在360个元素的标签向量,在将样本图像的角度特征输入到图像角度分类模型的分类分支后,输出这360个标签向量中每一个标签向量对应的预测概率值,也就是说,存在360个预测概率值。据此,可以得到每一个标签向量的预测概率值。
再如,以10度的角度值范围为一个分类类别为例,也就是说角度为0~10度是一个分类类别,角度为11~20度是一个分类类别,角度为21~30度是一个分类类别,以此类推,共有36个分类类别,也就是存在36个元素的标签向量,在将样本图像的角度特征输入到图像角度分类模型的分类分支后,输出这360个标签向量中每一个标签向量对应的预测概率值,也就是说,存在36个预测概率值。据此,可以得到每一个标签向量的预测概率值。
需要注意的是,上述划分分类类别的方法,仅是本实施例中部分实现方式,也可以通过其他的方法来划分分类类别,如以5度的角度范围为一个分类类别,划分为72的类别等等,本实施例中不做具体限定。
S502:输出最大的预测概率对应的标签向量,作为分类预测结果。
根据图像角度分类模型的分类分支输出的标签向量及对应的预测概率,从中选择预测概率最大的标签向量作为分类预测结果,由此,可以得到样本图像的分类预测结果。可以理解的是,本实施例中得到的分类预测结果仅是一个标签向量,并不是具体的度数,需要根据标签向量匹配到对应的角度值,或者角度值范围,即可得到样本图像的对应的角度值,或者角度值范围。
例如,以每一个度数为一个分类类别为例,也就是说,存在360个分类类别及对应的标签向量,即存在360个标签向量。将样本图像输入到图像角度分类模型中后,图像角度分类模型的分类分支会输出360个标签向量及每个标签向量对应的预测概率值,通过softmax()函数对着360个标签向量及对应的预测概率值进行归一化处理,选取预测概率值最大的标签向量作为分类预测结果。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
实施例二
本发明实施例二公开的一种图像角度判定方法,该方法基于上述任一实施例公开的图像角度分类模型训练方法训练得到的图像角度分类模型实现,其模型结构参考图4所示。
将目标图像输入到图像角度分类模型,以实现如下步骤:
将目标图像输入到主干网络,获取到目标图像中的特征图像。
在具体实现中,将目标图像输入到图像角度分类模型的主干网络,通过主干网络在样本图像中提取出包含角度特征的图像,即特征图像,然后将这些特征图像,输入到池化层中,执行图像角度判定的后续步骤。
在具体实现中,本实施例中的主干网络可以具备图像特征提取功能的模型,例如ResNet等等,将样本图像输入到ResNet模型中,由ResNet模型对目标图像进行处理,以输出目标图像中具备角度特征的特征图像。
需要注意的是,本实施例中的图像角度判定方法所使用的图像角度分类模型的主干网络,并不限于ResNet模型一种,也可以使用其他的模型来作为图像角度分类模型的主干网络,如MobileNet模型,因此,本实施例中对图像角度分类模型的主干网络所使用的模型结构不做具体限定。
对特征图像进行池化处理,得到目标图像的角度特征。
在具体实现中,将特征图像输入到池化层进行池化处理,进而将特征图像转化为对应的角度特征,之后,将得到的角度特征输入到分类分支,执行图像角度判定的后续步骤。
在具体实现中,本实施例中的池化层可以对特征图像进行全局平均池化,从特征图像中提取到样本图像的角度特征,或者对特征图像进行全局最大池化,从特征图像中提取到样本图像的角度特征,本实施例中对池化层对特征图像的池化处理方式不做具体限定。
将角度特征输入到分类分支,得到目标图像的分类预测结果。
将目标图像的角度特征输入到分类分支的全连接层,进而根据角度特征进行角度预测,得到目标图像的分类预测结果。
将目标图像的角度特征输入到图像角度分类模型的分类分支后,会根据预设的标签向量,输出每一个标签向量对应的预测概率值,对每个所述标签向量及对应的预测概率值进行归一化处理,选择预测概率值最大的标签向量作为分类预测结果,将分类预测结果对应的角度值作为目标图像角度值。其中,每一个标签向量对应一个分类类比,即一个角度值,或者一个角度值范围。
例如,以每一个度数为一个分类类别为例,也就是说角度为1度是一个分类类别,角度为2度是一个分类类别,角度为3度是一个分类类别,以此类推,共有360个分类类别,也就是存在360个标签向量,在将目标图像的角度特征输入到图像角度分类模型的分类分支后,输出这360个标签向量中每一个标签向量对应的预测概率值,也就是说,存在360个预测概率值,然后对这360个标签向量及对应的预测概率值进行归一化处理,得到预测概率值最大的标签向量作为分类预测结果,将分类预测结果对应的角度值作为目标图像角度值。据此,可以得到每一个标签向量的预测概率值。
需要注意的是,上述图像角度判定方法所使用的图像角度分类模型在训练好之后,可以舍弃图像角度分类模型的回归分支,仅使用分类分支来进行图像角度判定。也就是说,仅以图像角度分类模型的分类分支输出的分类预测结果作为图像角度分类模型的预测结果。
实施例三
如图6所示,为本发明实施例三公开的一种图像角度分类模型训练装置的实现流程图,该装置在图像角度分类模型训练过程中,在基于分类分支的分类损失值判定图像角度分类模型是否满足训练停止条件的基础上,增加回归分支的回归损失值,由此,通过回归损失值和分类损失值综合判定图像角度分类模型是否满足训练停止条件,由此,可以达到提高图像角度分类模型的预测准确率的目的。
本实施例中的图像角度分类模型训练装置,具体可以包括以下单元:
预测结果输出单元601,用于将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
回归损失值计算单元602,用于基于样本图像的真实角度和回归预测结果,计算得到回归损失值;
分类损失值计算单元603,用于基于样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
训练停止判断单元604,用于基于分类损失值和回归损失值,判断图像角度分类模型是否满足训练停止条件;在满足训练停止条件的情况下,确定得到训练好的图像角度分类模型。
综上,本发明实施例三公开的一种图像角度分类模型训练装置,通过在图像角度分类模型的训练过程中,在计算分类分支的分类损失值得基础上,增加回归分支的回归损失值,以分类损失值和回归损失值来综合判断图像角度分类模型训练是否满足训练停止条件,由此,得到训练好的图像角度分类模型。可见,本发明中同时基于分类损失值和回归损失值来判定图像角度分类模型是否满足训练停止条件,相较于仅基于分类损失值来判定图像角度分类模型是否满足训练停止条件,可以提高图像角度分类模型的训练效果,以达到提高图像角度分类模型的准确率的目的。
在一种实现方式中,训练停止判断单元604具体用于:
基于分类损失值和回归损失值,得到总损失值;
在预设的训练次数内,总损失值不再下降,或者,达到预设的训练次数,确定图像角度分类模型满足训练停止条件。
在一种实现方式中,回归损失值计算单元602具体用于:
将真实角度值和回归预测结果输入到回归损失函数,得到回归损失值。
在一种实现方式中,回归损失函数包括第一回归损失函数和第二回归损失函数;
回归损失值计算单元602也可以用于:
将真实角度值和回归预测结果,分别输入到第一回归损失函数和第二回归损失函数,得到第一回归损失值和第二回归损失值。
在一种实现方式中,分类预测结果通过以下方式得到:
将样本图像输入到主干网络,获取到样本图像中的特征图像;
对特征图像进行池化处理,得到样本图像的角度特征;
将角度特征输入到分类分支,得到样本图像的分类预测结果。
在一种实现方式中,样本图像的分类预测结果可以通过以下方式得到:
将角度特征输入到分类分支,得到每一个标签向量的预测概率值;每个标签向量对应一个角度值,或者一个角度值范围;
输出最大的预测概率对应的标签向量,作为分类预测结果。
关于图像角度分类模型训练装置的具体限定,可以参见上文中对于图像角度分类模型训练方法的有关限定,在此不再赘述。上述图像角度分类模型训练装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
实施例四
本申请实施例四公开了一种计算机设备,该计算机设备可以是服务端,其内部结构图可以如图7所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种图像角度分类模型训练方法。
在一个实施例中,提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现以下步骤:
将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
基于样本图像的真实角度和回归预测结果,计算得到回归损失值;
基于样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
基于分类损失值和回归损失值,判断图像角度分类模型是否满足训练停止条件;
在满足训练停止条件的情况下,确定得到训练好的图像角度分类模型。
实施例五
本申请实施例五公开了一种计算机可读存储介质,当计算机可读存储介质中的指令由计算机设备中的处理器执行时,使得计算机设备能够执行如本发明公开的一种图像角度分类模型训练方法的任一实施例的各个步骤。所述计算机可读存储介质可以是非易失性,也可以是易失性。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现以下步骤:
将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
基于样本图像的真实角度和回归预测结果,计算得到回归损失值;
基于样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
基于分类损失值和回归损失值,判断图像角度分类模型是否满足训练停止条件;
在满足训练停止条件的情况下,确定得到训练好的图像角度分类模型。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,该计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。

Claims (10)

1.一种图像角度分类模型训练方法,其特征在于,所述方法包括:
将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
基于所述样本图像的真实角度和所述回归预测结果,计算得到回归损失值;
基于所述样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件;
在满足训练停止条件的情况下,确定得到训练好的所述图像角度分类模型。
2.如权利要求1所述的图像角度分类模型训练方法,其特征在于,所述基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件,包括:
基于所述分类损失值和所述回归损失值,得到总损失值;
在预设的训练次数内,所述总损失值不再下降,或者,达到所述预设的训练次数,确定所述图像角度分类模型满足所述训练停止条件。
3.如权利要求1所述的图像角度分类模型训练方法,其特征在于,所述基于所述真实角度值和所述回归预测结果,计算得到回归损失值,包括:
将所述真实角度值和所述回归预测结果输入到回归损失函数,得到回归损失值。
4.如权利要求3所述的图像角度分类模型训练方法,其特征在于,所述回归损失函数包括第一回归损失函数和第二回归损失函数;
其中,将所述真实角度值和所述回归预测结果输入到回归损失函数,得到回归损失值,包括:
将所述真实角度值和所述回归预测结果,分别输入到所述第一回归损失函数和所述第二回归损失函数,得到第一回归损失值和第二回归损失值。
5.如权利要求1所述的图像角度分类模型训练方法,其特征在于,所述分类预测结果通过以下方式得到:
将所述样本图像输入到主干网络,获取到所述样本图像中的特征图像;
对所述特征图像进行池化处理,得到所述样本图像的角度特征;
将所述角度特征输入到分类分支,得到所述样本图像的分类预测结果。
6.如权利要求5所述的图像角度分类模型训练方法,其特征在于,所述将所述角度特征分别输入到分类分支,得到样本图像的分类预测结果,包括:
将所述角度特征输入到分类分支,得到每一个标签向量的预测概率值;每个所述标签向量对应一个角度值,或者一个角度值范围;
输出最大的所述预测概率对应的所述标签向量,作为所述分类预测结果。
7.一种图像角度判定方法,其特征在于,所述方法包括:获取到目标图像,将所述目标图像输入到权利要求1-6任一所述的图像角度分类模型训练方法所得到的图像角度分类模型,获得所述分类分支输出的所述目标图像的预测结果。
8.一种图像角度分类模型训练装置,其特征在于,包括:
预测结果输出单元,用于将样本图像输入到图像角度分类模型进行训练,得到分类分支输出的分类预测结果和回归分支输出的回归预测结果;
回归损失值计算单元,用于基于所述样本图像的真实角度和所述回归预测结果,计算得到回归损失值;
分类损失值计算单元,用于基于所述样本图像的真实标签向量和分类预测结果,计算得到分类损失值;
训练停止判断单元,用于基于所述分类损失值和所述回归损失值,判断所述图像角度分类模型是否满足训练停止条件;在满足训练停止条件的情况下,确定得到训练好的所述图像角度分类模型。
9.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至6任一项所述图像角度分类模型训练方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述图像角度分类模型训练方法。
CN202211608633.1A 2022-12-14 2022-12-14 图像角度分类模型训练方法、装置、设备及存储介质 Pending CN116363445A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211608633.1A CN116363445A (zh) 2022-12-14 2022-12-14 图像角度分类模型训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211608633.1A CN116363445A (zh) 2022-12-14 2022-12-14 图像角度分类模型训练方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN116363445A true CN116363445A (zh) 2023-06-30

Family

ID=86930175

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211608633.1A Pending CN116363445A (zh) 2022-12-14 2022-12-14 图像角度分类模型训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN116363445A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117333539A (zh) * 2023-10-09 2024-01-02 南京华麦机器人技术有限公司 一种面向移动机器人的充电桩定位方法及装置

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117333539A (zh) * 2023-10-09 2024-01-02 南京华麦机器人技术有限公司 一种面向移动机器人的充电桩定位方法及装置

Similar Documents

Publication Publication Date Title
JP6483667B2 (ja) ベイズの最適化を実施するためのシステムおよび方法
CN111950638B (zh) 基于模型蒸馏的图像分类方法、装置和电子设备
CN110347971B (zh) 基于tsk模糊模型的粒子滤波方法、装置及存储介质
JP6610278B2 (ja) 機械学習装置、機械学習方法及び機械学習プログラム
US20210012183A1 (en) Method and device for ascertaining a network configuration of a neural network
CN110824587B (zh) 图像预测方法、装置、计算机设备和存储介质
CN110705718A (zh) 基于合作博弈的模型解释方法、装置、电子设备
CN111767400A (zh) 文本分类模型的训练方法、装置、计算机设备和存储介质
CN109271958A (zh) 人脸年龄识别方法及装置
CN113328908B (zh) 异常数据的检测方法、装置、计算机设备和存储介质
CN112232426A (zh) 目标检测模型的训练方法、装置、设备及可读存储介质
CN114078195A (zh) 分类模型的训练方法、超参数的搜索方法以及装置
CN114360520A (zh) 语音分类模型的训练方法、装置、设备及存储介质
CN110909784A (zh) 一种图像识别模型的训练方法、装置及电子设备
CN116363445A (zh) 图像角度分类模型训练方法、装置、设备及存储介质
CN107292323B (zh) 用于训练混合模型的方法和设备
CN111159481B (zh) 图数据的边预测方法、装置及终端设备
CN114861859A (zh) 神经网络模型的训练方法、数据处理方法及装置
CN115186062A (zh) 多模态预测方法、装置、设备及存储介质
CN115797735A (zh) 目标检测方法、装置、设备和存储介质
CN113762005B (zh) 特征选择模型的训练、对象分类方法、装置、设备及介质
CN114782758B (zh) 图像处理模型训练方法、系统、计算机设备及存储介质
CN116468967B (zh) 样本图像筛选方法、装置、电子设备及存储介质
CN117115551A (zh) 一种图像检测模型训练方法、装置、电子设备及存储介质
CN113744060A (zh) 基于融合模型的数据预测方法、装置、设备以及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination