nn.CrossEntropyLoss
pytorch
本文字数:1.2k 字 | 阅读时长 ≈ 5 min

nn.CrossEntropyLoss

pytorch
本文字数:1.2k 字 | 阅读时长 ≈ 5 min

如果想直接看 CrossEntropyLoss 的作用可以直接看第三节

1. torch.nn.LogSoftmax

torch.nn.LogSoftmax(dim=None)

此函数计算输入向量的 Softmax 并取对数
$$
LogSoftmax(x_{i})=log(\frac{exp(x_{i})}{\sum_{j}exp(x_{j})})
$$

举例:

# 定义tensor向量
batch_size = 2
class_num = 3
inputs = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        inputs[i][j] = (i + 1) * (j + 1)
print("inputs:", inputs)
'''
inputs: tensor([[1., 2., 3.],
                [2., 4., 6.]])
'''


# 使用softmax函数求解
softmax = nn.Softmax(dim=1) # 对行求softmax
probs = softmax(inputs)
print("probs:\n", probs)
'''
取第一个值进行验证
print(exp(1)/(exp(1)+exp(2)+exp(3))):0.09003057317038046
probs:
 tensor([[0.0900, 0.2447, 0.6652],
         [0.0159, 0.1173, 0.8668]])
'''


# 使用LogSoftmax函数求解
LogSoftmax = nn.LogSoftmax(dim=1) # 对行求softmax之后求log
log_probs = LogSoftmax(inputs)
print("log_probs:\n", log_probs)
'''
log_probs:
  tensor([[-2.4076, -1.4076, -0.4076],
          [-4.1429, -2.1429, -0.1429]])
'''


# 下面我们将Softmax的结果取log看看与上述结果是否相同
print(torch.log(probs))
'''
tensor([[-2.4076, -1.4076, -0.4076],
        [-4.1429, -2.1429, -0.1429]])
'''

nn.Softmax()函数和nn.LogSoftmax()函数的唯一区别是nn.LogSoftmax函数在求出Sofmax值之后会取自然对数e

2. torch.nn.NLLLOSS

torch.nn.NLLLOSS(reduction='mean')

直接举例:

# 首先定义tensor
batch_size = 2
class_num = 3
inputs = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        inputs[i][j] = (i + 1) * (j + 1)
print("inputs:", inputs)
'''
inputs: tensor([[1., 2., 3.],
                [2., 4., 6.]])
'''

nlloss = nn.NLLLoss()
target = torch.empty(2, dtype=torch.long).random_(3)
output = nlloss(inputs, target)
print(target)
print(output)
'''
tensor([2, 2])
tensor(-4.5000)
'''

NLLLOSS取了对应target的数值,然后相加求平均取负,比如上述例子,target分别为2和2,所以tensor的第一行和第二行分别取3和6,然后相加取平均最后加负号就是最终结果-4.5

3. torch.nn.CrossEntropyLoss

torch.nn.CrossEntropyLoss(reduction='mean')

明白上述两个函数LogSoftmax以及NLLLoss之后,CrossEntropyLoss其实就是将上述两个函数进行了结合

3.1 首先讲解交叉熵的原理

在分类问题中,假设我们有两个样本,这两个样本有三个类别可以分。比如我们有两只小动物,现在有猫狗猪三个类别,经过网络计算结果如下

predict real
(0.1, 0.2, 0.7) (0, 0, 1) 正确
(0.5, 0.2, 0.3) (0, 1, 0) 错误

使用交叉熵损失公式来计算我们的结果

二元分类:
$$
L = \frac{1}{N}\sum_{i}L_{i} = -\frac{1}{N}\sum_{i}[y_{i}log(p_{i})+(1-y_{i})log(1-p_{i})]
$$

多元分类:
$$
L = \frac{1}{N}\sum_{i}L_{i} = -\frac{1}{N}\sum_{i}\sum_{c=1}^{M}y_{ic}log(p_{ic})
$$

上述损失:
$$
\begin{aligned}
& sample~1:loss = -[0\times log(0.1)+0\times log(0.2)+1\times log(0.7)] = 0.36 \
& sample~2:loss = -[0\times log(0.5)+1\times log(0.2)+0\times log(0.3)] = 1.61 \
& loss_{ave} = \frac{0.36+1.61}{2} = 0.96
\end{aligned}
$$

3.2 代码解释

明白原理之后我们来看一下具体的代码公式

# 定义input和target
batch_size = 2
class_num = 3
input = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        input[i][j] = (i + 1) * (j + 1)
print("inputs:", input)
target = torch.empty(2, dtype=torch.long).random_(3)
print("targets:", target)
'''
inputs: tensor([[1., 2., 3.],
                [2., 4., 6.]])
targets: tensor([2, 1])
'''


# 计算CrossEntropyLoss
loss = nn.CrossEntropyLoss()
output = loss(input, target)
print(output)
'''
tensor(1.2753)
'''

下面我们用LogSoftmax以及NLLLoss拆分来看看CrossEntropyLoss具体作了什么

  1. LogSoftmax函数
LogSoftmax = nn.LogSoftmax(dim=1)
log_probs = LogSoftmax(input)
print("log_probs:\n", log_probs)
'''
log_probs:
 tensor([[-2.4076, -1.4076, -0.4076],
         [-4.1429, -2.1429, -0.1429]])
'''

这一步主要进行了两个操作,一个是取Softmax,然后进行log

求解之后和上述LogSoftmax函数结果一致

Softmax = nn.Softmax(dim=1)
probs = Softmax(input)
print(torch.log(probs))
'''
tensor([[-2.4076, -1.4076, -0.4076],
        [-4.1429, -2.1429, -0.1429]])
'''
  1. NLLLoss函数

此函数在第二节介绍过,这里我们在第一步LogSoftmax函数的基础上直接用此函数

nllloss = nn.NLLLoss()
output2 = nllloss(log_probs, target)
print(output2)
'''
tensor(1.2753)
'''

发现这两步合起来的结果和只用CrossEntropyLoss的结果一致,我们看上述结果,1.2753是怎么的出来的呢?注意看

# 经过LogSoftmax处理后的结果
tensor([[-2.4076, -1.4076, -0.4076],
        [-4.1429, -2.1429, -0.1429]])
# 目标target
targets: tensor([2, 1])

$L=\frac{1}{N}\sum_{i}L_{i}=-\frac{1}{N}\sum_{i}\sum_{c=1}^{M}y_{ic}log(p_{ic})$ 由于正确样本target为1,错误的为0,因此只需要计算正确的损失即可
即$L=\frac{1}{N}\sum_{i}L_{i}=-\frac{1}{N}\sum_{i}y_{target}log(p_{target})~~Where i is the number of samples$

$$
\begin{aligned}
loss = -\frac{-0.4076-2.1429}{2} = 1.2753 \
= -[1\times (-0.4076) + 1\times (-2.1429)]
\end{aligned}
$$

4月 06, 2025
3月 10, 2025
12月 31, 2024