当前位置: 首页 > news >正文

分类模型为什么使用交叉熵作为损失函数

推导过程

让推理更有体感,进行下面假设:

  1. 假设要对进行图片识别分类
  2. 假设模型输出 y y y,是一个几率,表示是猫的概率

训练资料如下:

x n x^n xn类别 y ^ n \widehat{y}^n y n
x 1 x^1 x11
x 2 x^2 x21
x 3 x^3 x30

注: x 1 x^1 x1是第一组训练资料它是属于猫,因为我们使用one-hot来表示目标类别,所以 y ^ i n \widehat{y}^n_i y in要么等于0,要么等于1

损失函数怎么定义比较好?我们优先想到的是判断结果是否和真实值相等

L o s s = [ f ( x 1 ) ≠ y ^ 1 ] + [ f ( x 2 ) ≠ y ^ 2 ] + [ f ( x 3 ) ≠ y ^ 3 ] Loss=[f(x^1) \neq \widehat{y}^1]+[f(x^2) \neq \widehat{y}^2]+[f(x^3) \neq \widehat{y}^3] Loss=[f(x1)=y 1]+[f(x2)=y 2]+[f(x3)=y 3]
f ( x n ) = { 1 , y n > 0.5 0 , y n < = 0.5 f(x^n)= \begin{dcases} 1 , y^n > 0.5 \\ 0 ,y^n <=0.5 \end{dcases} f(xn)={1yn>0.50yn<=0.5
只需要找到 w ∗ , b ∗ = arg ⁡ m i n w , b L ( w , b ) w^*,b^*=\arg\underset{w,b}{min}L(w,b) w,b=argw,bminL(w,b) 使得Loss最小即可
但是上面的 f ( x n ) f(x^n) f(xn)无法进行微分,不能计算梯度

所以重新寻找Loss函数:
L o s s = f ( x 1 ) + f ( x 2 ) + ( 1 − f ( x 3 ) ) = y 1 + y 2 + ( 1 − y 3 ) Loss=f(x^1)+f(x^2)+(1-f(x^3))= y^1+y^2+(1-y^3) Loss=f(x1)+f(x2)+(1f(x3))=y1+y2+(1y3)
Loss越大说明和训练集越相似,效果越好,我们希望找到一个 w ∗ , b ∗ = arg ⁡ m a x w , b L ( w , b ) w^*,b^*=\arg\underset{w,b}{max}L(w,b) w,b=argw,bmaxL(w,b) 使得Loss最大

但是Loss多大算大?我们还是希望找到一个最小Loss,最好趋近于0
所以对Loss再次变形,对Loss加一个 负 l n ln ln,我们就可以求Loss的最小值了
w ∗ , b ∗ = arg ⁡ m a x w , b L ( w , b ) = arg ⁡ m i n w , b − l n L ( w , b ) w^*,b^*=\arg\underset{w,b}{max}L(w,b)=\arg\underset{w,b}{min}-lnL(w,b) w,b=argw,bmaxL(w,b)=argw,bminlnL(w,b)

推导:
L o s s = − [ l n f ( x 1 ) + l n f ( x 2 ) + ( 1 − l n f ( x 3 ) ) ] Loss=-[lnf(x^1)+lnf(x^2)+(1-lnf(x^3))] Loss=[lnf(x1)+lnf(x2)+(1lnf(x3))]

因为 y ^ i n \widehat{y}^n_i y in要么等于0,要么等于1,所以可得: { l n f ( x n ) = y ^ n l n f ( x n ) + ( 1 − y ^ n ) l n ( 1 − l n f ( x n ) ) 1 − l n f ( x n ) = y ^ n l n f ( x n ) + ( 1 − y ^ n ) l n ( 1 − l n f ( x n ) ) \begin{dcases} lnf(x^n)=\widehat{y}^nlnf(x^n)+(1-\widehat{y}^n)ln(1-lnf(x^n)) \\ 1-lnf(x^n)=\widehat{y}^nlnf(x^n)+(1-\widehat{y}^n)ln(1-lnf(x^n)) \end{dcases} {lnf(xn)=y nlnf(xn)+(1y n)ln(1lnf(xn))1lnf(xn)=y nlnf(xn)+(1y n)ln(1lnf(xn))

= − [ y ^ 1 l n f ( x 1 ) + ( 1 − y ^ 1 ) l n ( 1 − l n f ( x 1 ) ) + y ^ 2 l n f ( x 2 ) + ( 1 − y ^ 2 ) l n ( 1 − l n f ( x 2 ) ) + y ^ 3 l n f ( x 3 ) + ( 1 − y ^ 3 ) l n ( 1 − l n f ( x 3 ) ) ] =-[\widehat{y}^1lnf(x^1)+(1-\widehat{y}^1)ln(1-lnf(x^1))+\widehat{y}^2lnf(x^2)+(1-\widehat{y}^2)ln(1-lnf(x^2))+\widehat{y}^3lnf(x^3)+(1-\widehat{y}^3)ln(1-lnf(x^3))] =[y 1lnf(x1)+(1y 1)ln(1lnf(x1))+y 2lnf(x2)+(1y 2)ln(1lnf(x2))+y 3lnf(x3)+(1y 3)ln(1lnf(x3))]

= − ∑ [ y ^ n l n f ( x n ) + ( 1 − y ^ n ) l n ( 1 − l n f ( x n ) ) ] =-\sum[\widehat{y}^nlnf(x^n)+(1-\widehat{y}^n)ln(1-lnf(x^n))] =[y nlnf(xn)+(1y n)ln(1lnf(xn))]

设: p ( x ) p(x) p(x)为二项分布,其中 p ( 1 ) = y ^ n p(1)=\widehat{y}^n p(1)=y n p ( 0 ) = 1 − y ^ n p(0)=1-\widehat{y}^n p(0)=1y n
设: q ( x ) q(x) q(x)为二项分布,其中 q ( 1 ) = f ( x n ) q(1)=f(x^n) q(1)=f(xn) q ( 0 ) = 1 − f ( x n ) q(0)=1-f(x^n) q(0)=1f(xn)
= − ∑ i = 1 N p ( x ) l n ( q ( x ) ) =-\displaystyle\sum_{i=1}^Np(x)ln(q(x)) =i=1Np(x)ln(q(x))
= ∑ i = 1 N p ( x ) l n ( 1 q ( x ) ) =\displaystyle\sum_{i=1}^Np(x)ln(\frac{1}{q(x)}) =i=1Np(x)ln(q(x)1)

**交叉熵(cross entropy)**的数学公式如下:
![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/9dcfc80e77e349bdb2f256cbb4097b49.png

  • 如果我们把模型输出和真实值看做是一个二项分布的话,那么Loss的最终定义就是这两个二项分布越接近越好

交叉熵可在神经网络中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量p与q的相似性。交叉熵作为损失函数还有一个好处是使用sigmoid函数在梯度下降时能避免均方误差损失函数学习速率降低的问题,因为学习速率可以被输出的误差所控制。

问题

为什么不使用平均方差作为Loss函数呢?

  1. 假设 L o s s = 1 2 ∑ ( f ( x n ) − y ^ n ) 2 Loss=\frac{1}{2}\sum(f(x^n)-\widehat{y}^n)^2 Loss=21(f(xn)y n)2
  2. 假设用的是sigmoid函数

求导之后:
在这里插入图片描述
注: f ( x ) f(x) f(x)的导数是 f ( x ) ( 1 − f ( x ) ) f(x)(1-f(x)) f(x)(1f(x))
对于红色字体中的式子来说:
当y=1,f(x) = 1的时候,gradient=0,那么暂停训练是合理的
但y=1,f(x) = 0,这个时候和实际值有差距,应该继续训练,但gradient=0了
在这里插入图片描述

分类如果超过2维怎么办?

如果还是使用one-hot来表示目标类别,那么输出变为了多个,经过softmax函数之后都是0~1之间的数,把它看做是概率,就是N个二项式概率分布的相似度求和。


http://www.mrgr.cn/news/83523.html

相关文章:

  • [QCustomPlot] 交互示例 Interaction Example
  • 操作手册:集成钉钉审批实例消息监听配置
  • 《HeadFirst设计模式》笔记(下)
  • Leetcode 967 Numbers With Same Consecutive Differences
  • wordpress 房产网站筛选功能
  • C++实现设计模式---单例模式 (Singleton)
  • Spring——几个常用注解
  • mybatis分页插件:PageHelper、mybatis-plus-jsqlparser(解决SQL_SERVER2005连接分页查询OFFSET问题)
  • 【leetcode刷题】:双指针篇(有效三角形的个数、和为s的两个数)
  • 文献阅读分享:XSimGCL - 极简图对比学习在推荐系统中的应用
  • 【大数据】Apache Superset:可视化开源架构
  • PatchTST:通道独立的、切片的 时序 Transformer
  • 【JVM-2.3】深入解析JVisualVM:Java性能监控与调优利器
  • 25/1/12 嵌入式笔记 学习esp32
  • Elasticsearch快速入门
  • 浅谈云计算03 | 云计算的技术支撑(云使能技术)
  • 现代 CPU 的高性能架构与并发安全问题
  • AWS简介
  • 【Excel/WPS】根据平均值,随机输出三个范围在80到100的随机值。
  • 从预训练的BERT中提取Embedding
  • 机械燃油车知识图谱、知识大纲、知识结构(持续更新...)
  • 【Rust自学】11.9. 单元测试
  • qt QPainter setViewport setWindow viewport window
  • 如何在Jupyter中快速切换Anaconda里不同的虚拟环境
  • 【数通】MPLS
  • 【Bluedroid】HFP连接流程源码分析(一)