PyTorch平方根报错的处理方案

问题描述

初步使用PyTorch进行平方根计算,通过range()创建一个张量,然后对其求平方根。

a = torch.tensor(list(range(9)))
b = torch.sqrt(a)

报出以下错误:

RuntimeError: sqrt_vml_cpu not implemented for 'Long'

原因

Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后, 数据类型变成了Long。

print(a.dtype)

torch.int64

解决方法

提前将数据类型指定为浮点型, 重新执行:

b = torch.sqrt(a.to(torch.double))
print(b)

tensor([0.0000, 1.0000, 1.4142, 1.7321, 2.0000, 2.2361, 2.4495, 2.6458, 2.8284], dtype=torch.float64)

补充:pytorch10 pytorch常见运算详解

矩阵与标量

这个是矩阵(张量)每一个元素与标量进行操作。

import torch
a = torch.tensor([1,2])
print(a+1)
>>> tensor([2, 3])

哈达玛积

这个就是两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积,也成为element wise。

a = torch.tensor([1,2])
b = torch.tensor([2,3])
print(a*b)
print(torch.mul(a,b))
>>> tensor([2, 6])
>>> tensor([2, 6])

这个torch.mul()和*是等价的。

当然,除法也是类似的:

a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
print(a/b)
print(torch.div(a/b))
>>> tensor([0.5000, 0.6667])
>>> tensor([0.5000, 0.6667])

我们可以发现的torch.div()其实就是/, 类似的:torch.add就是+,torch.sub()就是-,不过符号的运算更简单常用。

矩阵乘法

如果我们想实现线性代数中的矩阵相乘怎么办呢?

这样的操作有三个写法:

torch.mm()

torch.matmul()

@,这个需要记忆,不然遇到这个可能会挺蒙蔽的

a = torch.tensor([[1.],[2.]])
b = torch.tensor([2.,3.]).view(1,2)
print(torch.mm(a, b))
print(torch.matmul(a, b))
print(a @ b)

这是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul()可以使用。等等,多维张量怎么进行矩阵的乘法?在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:

a = torch.rand((1,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([1, 2, 64, 64])

a = torch.rand((3,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([3, 2, 64, 64])

这样也是可以相乘的,因为这里涉及一个自动传播Broadcasting机制,这个在后面会讲,这里就知道,如果这种情况下,会把b的第一维度复制3次 ,然后变成和a一样的尺寸,进行矩阵相乘。

幂与开方

print('幂运算')
a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
c1 = a ** b
c2 = torch.pow(a, b)
print(c1,c2)
>>> tensor([1., 8.]) tensor([1., 8.])

和上面一样,不多说了。开方运算可以用torch.sqrt(),当然也可以用a**(0.5)。

对数运算

在上学的时候,我们知道ln是以e为底的,但是在pytorch中,并不是这样。

pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。

import numpy as np
print('对数运算')
a = torch.tensor([2,10,np.e])
print(torch.log(a))
print(torch.log2(a))
print(torch.log10(a))
>>> tensor([0.6931, 2.3026, 1.0000])
>>> tensor([1.0000, 3.3219, 1.4427])
>>> tensor([0.3010, 1.0000, 0.4343])

近似值运算

.ceil() 向上取整

.floor()向下取整

.trunc()取整数

.frac()取小数

.round()四舍五入

.ceil() 向上取整.floor()向下取整.trunc()取整数.frac()取小数.round()四舍五入

a = torch.tensor(1.2345)
print(a.ceil())
>>>tensor(2.)
print(a.floor())
>>> tensor(1.)
print(a.trunc())
>>> tensor(1.)
print(a.frac())
>>> tensor(0.2345)
print(a.round())
>>> tensor(1.)

剪裁运算

这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。

a = torch.rand(5)
print(a)
print(a.clamp(0.3,0.7))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch masked_fill报错的解决

    如下所示: import torch.nn.functional as F import numpy as np a = torch.Tensor([1,2,3,4]) a = a.masked_fill(mask = torch.ByteTensor([1,1,0,0]), value=-np.inf) print(a) b = F.softmax(a) print(b) tensor([-inf, -inf, 3., 4.]) d:/pycharmdaima/star-transformer

  • 解决pytorch 数据类型报错的问题

    pytorch报错: RuntimeError: Expected object of type Variable[torch.LongTensor] but found type Variable[torch.cuda.ByteTensor] for argument #1 'argument1' 解决方法: pytorch框架在存储labels时,采用LongTensor来存储,所以在一开始dataset返回label时,就要返回与LongTensor对应的数据类型,即numpy.int64

  • Pytorch Tensor基本数学运算详解

    1. 加法运算 示例代码: import torch # 这两个Tensor加减乘除会对b自动进行Broadcasting a = torch.rand(3, 4) b = torch.rand(4) c1 = a + b c2 = torch.add(a, b) print(c1.shape, c2.shape) print(torch.all(torch.eq(c1, c2))) 输出结果: torch.Size([3, 4]) torch.Size([3, 4]) tensor(1, dt

  • PyTorch平方根报错的处理方案

    问题描述 初步使用PyTorch进行平方根计算,通过range()创建一个张量,然后对其求平方根. a = torch.tensor(list(range(9))) b = torch.sqrt(a) 报出以下错误: RuntimeError: sqrt_vml_cpu not implemented for 'Long' 原因 Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后,

  • Android为textView设置setText的时候报错的讲解方案

    在对中TextView setText 覆值int 时报错,网上查下原因是setText整型表明是设值R.id.xxx,当然找不到. 解决方法是将int转化为string,用String.valueOf(xxx) 一.我的代码如下:就是我textView设置值 if (list != null) { for (Student stu : list) { //如果一下子赋值的话是不正确的 tv_name.setText(stu.getName()); tv_sex.setText(stu.getS

  • mybatis like模糊查询特殊字符报错转义处理方式

    目录 like模糊查询特殊字符报错转义处理 方案1 方案2 like模糊查询中包含有特殊字符(_.\.%) 处理 注意 like模糊查询特殊字符报错转义处理 方案1     <if test="projectName!=null and projectName!=''">             <bind name="projectName_" value="'%'+projectName+'%'"/>        

  • Spring Boot 2.6.x整合Swagger启动失败报错问题的完美解决办法

    目录 问题 原因 解决方案 方案一(治标) 方案二(治本) 总结 问题 Spring Boot 2.6.x版本引入依赖 springfox-boot-starter (Swagger 3.0) 后,启动容器会报错: Failed to start bean ‘ documentationPluginsBootstrapper ‘ ; nested exception… 原因 Springfox 假设 Spring MVC 的路径匹配策略是 ant-path-matcher,而 Spring Bo

  • Vue router/Element重复点击导航路由报错问题及解决

    目录 Vue router/Element重复点击导航路由报错 解决方法如下 Vue使用element-UI路由报错问题 报错代码 修改方案 Vue router/Element重复点击导航路由报错 虽然此报错并不会影响项目运行,但是作为一个强迫症的码农的确受不了error 解决方法如下 方法1:在项目目录下运行 npm i vue-router@3.0 -S 将vue-router改为3.0版本即可: 方法2:若不想更换版本解决方法 在router.js中加入以下代码就可以 记住插入的位置 c

  • PyTorch训练LSTM时loss.backward()报错的解决方案

    训练用PyTorch编写的LSTM或RNN时,在loss.backward()上报错: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. 千万别改成loss.backward(retain_graph=Tru

  • 解决pytorch报错:AssertionError: Invalid device id的问题

    在服务器上训练的网络放到本地台式机进行infer,结果出现报错: AssertionError: Invalid device id 仔细检查后发现原来服务器有多个GPU,当时开启了两个进行加速运算. net1 = nn.DataParallel(net1, device_ids=[0, 1]) 而本地台式机只有一个GPU,调用数量超出所以报错. 改为 net1 = nn.DataParallel(net1, device_ids=[0]) 问题解决. 以上这篇解决pytorch报错:Asser

  • 安装pytorch报错torch.cuda.is_available()=false问题的解决过程

    问题介绍 在安装torch之后,命令行(Anaconda Powershell Prompt)运行这三行代码: python # python import torch torch.cuda.is_available() 返回结果始终为False. 出错原因 原因有多个,可以参考文章最后的链接[1] 他的很清晰,如果按我的没有解决可以看一下. 主要就是以下两个: CUDA.cudnn.torch版本不对应.(解决方法参考链接[1]) 一个坑:是通过清华源下载的!检查是不是清华源下载导致的问题:

随机推荐