李沐Softmax回归从零开始实现代码中的关于y和y_hat
原视频:李沐Softmax回归从零开始实现
其中,这段代码令人迷惑。
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])y_hat[[0, 1], y]
视频文字上的注释是:
创建一个数据
y_hat
,其中包含2个样本在3个类别上的预测概率,使用y
作为y_hat
中概率的索引。
为什么介绍这段代码?因为为了介绍交叉熵。
在之前的课程中提到,对真实 y
进行独热编码。
比如,共有 3 类,则真实输出 y = [ 0 , 0 , 1 ] \bold y = [0, 0, 1] y=[0,0,1],即表示:真实的类别是第3类。
最后发现,交叉熵损失等于 − l o g ( y y ^ ) -log(\hat{y_y}) −log(yy^),就是 i = y 真实类别的预测概率 y ^ \hat{y} y^。
但是,这里的 y 不表示这个含义。这里的 y 表示 2 个样本的真实类别分别是 0 和 2(类别有 [0, 1, 2])
而之前的独热编码 y 表示为 1 个样本的真实类别:[0, 0, 1]。第 2 个是1,则表示第 2 个为真实类别。所以独热编码的y要写成上述代码的y,可以写成:y = [2]
当把 y 写成独热编码,是为了方便解释:交叉熵损失的预测概率只需要真实类别的预测概率,并对其求-log。
那么,既然如此,代码中的 y 就表示 index,就告诉你哪一个是真实类别的预测概率,那么要计算交叉熵损失就直接根据 index 在 y_hat 里面取就行。