必威体育Betway必威体育官网
当前位置:首页 > IT技术

【解决办法】torch交叉熵使用时遇到 Dimension out of range

时间:2019-10-01 23:10:00来源:IT技术作者:seo实验室小编阅读:61次「手机版」
 

out of range

简述

其实这个问题我很久以前用pytorch写程序的时候就遇到过这个问题,当时纠结了很久之后最后解决了。当时本来就想来写个东西来记录下避免其他人也遇到这样的问题。但后面我个菜鸡就完全忘记了emmmm

展示说明

不同于我以往直接给出结论。这次我会通过展示的方式较为详细的讲解下。

  • 导入包
import torch
import torch.nn as nn
loss  = nn.CrossEntropyLoss()
  • 随机生成下这个数据A
A = torch.randn(3, 5, requires_grad=True)

内容如下:

A
tensor([[ 1.0483, -1.4251,  1.0502,  0.2437,  0.1477],
        [-0.1300, -1.0798, -1.0835, -1.0473,  1.2076],
        [-0.1647, -1.1366,  0.7088,  1.3719,  1.9135]], requires_grad=True)
  • 随机生成3个label
target = torch.empty(3, dtype=torch.long).random_(5)

内容如下:

target
tensor([2, 2, 3])

没有问题

output = loss(A, target)
  • 这时候就是没有问题的
output
tensor(1.6934, grad_fn=<NllLossBackward>)

有问题的情况演示

  • 将前面创建的A在一维上取argmax。得到对应的label
A_arg = torch.argmax(A, dim=1)

内容如下:

A_arg
tensor([2, 4, 4])
  • 调用,出现题目所示的bug
output = loss(A_arg, target)

报错非常一大串emmm。关键就是下面这句。

runtimeERROR: Dimension out of range (expected to be in range of [-1, 0], but got 1)

解释

其实就是因为torch的交叉熵的输入第一个位置的输入应该是在每个label下的概率, 而不是对应的label。

所以直接写成label的你,就出现上面所说的错误了。

相关阅读

Orange Country橘子郡 欧美女装名店

淘宝名店中就不缺女装,女生选择的网店也很多,今天介绍个欧美女装网店:橘子郡。小编很喜欢的风格,穿起来很有感觉的衣服,评分也不错,介

详解多维标度法(MDS,Multidimensional scaling)

流形学习(Manifold Learning)是机器学习中一大类算法的统称,而MDS就是其中非常经典的一种方法。多维标度法(Multidimensional Scaling

HDFS加密存储(HDP、Ranger、Ranger KMS实现)

HDFS加密存储,在CSDN上可以看到很多的前辈整理的博客,但是按照https://blog.csdn.net/linlinv3/article/details/44963429所介绍的

java.lang.StringIndexOutOfBoundsException: String

字符串截取下标越界   出错代码 @GetMapping("/edit") //@RequiresPermissions("erp:enquirySheet:edit") public String

漫谈格兰杰因果关系(Granger Causality)

#目录文章目录#简介格兰杰因果关系作为一种可以衡量时间序列之间相互影响关系的方法,最近十几年备受青睐。无论是经济学[1],气象科

分享到:

栏目导航

推荐阅读

热门阅读