Pytorch四维Tensor转图片并保存方式(维度顺序调整)

目录
  • Pytorch四维Tensor转图片并保存
    • 1.维度顺序转换
    • 2.转为numpy数组
    • 3.根据第一维度batch_size逐个读取中间结果,并存储到磁盘中
  • Pytorch中Tensor介绍
    • torch.Tensor或torch.tensor注意事项
    • 创建tensor的四种主要方法
  • 总结

Pytorch四维Tensor转图片并保存

最近在复现一篇论文代码的过程中,想要输出中间图片的结果图,通过debug发现在pytorch网络中是用Tensor存储的四维张量。

1.维度顺序转换

第一维代表的是batch_size,然后是通道数和图像尺寸,首先要进行维度顺序的转换

通过permute函数实现

outputRs = outputR.permute(0,2,3,1)

shape转为96 * 128 * 3

2.转为numpy数组

#由于代码中的中间结果是带有梯度的要进行detach()操作
k = outputRs.cpu().detach().numpy()

3.根据第一维度batch_size逐个读取中间结果,并存储到磁盘中

Image需导入from PIL import Image

		for i in range(10):
			res = k[i] #得到batch中其中一步的图片
			image = Image.fromarray(np.uint8(res)).convert('RGB')
			#image.show()
			#通过时间命名存储结果
			timestamp = datetime.datetime.now().strftime("%M-%S")
			savepath = timestamp + '_r.jpg'
			image.save(savepath)

Pytorch中Tensor介绍

PyTorch中的张量(Tensor)如同数组和矩阵一样,是一种特殊的数据结构。在PyTorch中,神经网络的输入、输出以及网络的参数等数据,都是使用张量来进行描述。

torch包中定义了10种具有CPU和GPU变体的tensor类型。

torch.Tensor或torch.tensor是一种包含单一数据类型元素的多维矩阵。

torch.Tensor或torch.tensor注意事项

(1). torch.Tensor是默认tensor类型torch.FloatTensor的别名。

(2). torch.tensor总是拷贝数据。

(3).每一个tensor都有一个关联的torch.Storage,它保存着它的数据。

(4).改变tensor的方法是使用下划线后缀标记,如torch.FloatTensor.abs_()就地(in-place)计算绝对值并返回修改后的tensor,而torch.FloatTensor.abs()在新tensor中计算结果。

(5).有几百种tensor相关的运算操作,包括各种数学运算、线性代数、随机采样等。

创建tensor的四种主要方法

(1).要使用预先存在的数据创建tensor,使用torch.tensor()。

(2).要创建具有特定大小的tensor,使用torch.*,如torch.rand()。

(3).要创建与另一个tensor具有相同大小(和相似类型)的tensor,使用torch.*_like,如torch.rand_like()。

(4).要创建与另一个tensor类型相似但大小不同的tensor,使用tensor.new_*,如tensor.new_ones()。

以上内容及以下测试代码主要参考:

1. torch.Tensor — PyTorch 1.10.0 documentation

2. https://pytorch.apachecn.org/#/docs/1.7/03

tensor具体用法见以下test_tensor.py测试代码:

import torch
import numpy as np

var = 2

# reference: https://pytorch.apachecn.org/#/docs/1.7/03
if var == 1: # 张量初始化
    # 1.直接生成张量, 注意: torch.tensor与torch.Tensor的区别: torch.Tensor是torch.FloatTensor的别名;而torch.tensor则根据输入数据推断数据类型
    data = [[1, 2], [3, 4]]
    x_data = torch.tensor(data); print(f"x_data: {x_data}, type: {x_data.type()}") # type: torch.LongTensor
    y_data = torch.Tensor(data); print(f"y_data: {y_data}, type: {y_data.type()}") # type: torch.FloatTensor
    z_data = torch.IntTensor(data); print(f"z_data: {z_data}, type: {z_data.type()}") # type: torch.IntTensor

    # 2.通过Numpy数组来生成张量,反过来也可以由张量生成Numpy数组
    np_array = np.array(data)
    x_np = torch.from_numpy(np_array); print("x_np:\n", x_np)
    y_np = torch.tensor(np_array); print("y_np:\n", y_np) # torch.tensor总是拷贝数据
    z_np = torch.as_tensor(np_array); print("z_np:\n", z_np) # 使用torch.as_tensor可避免拷贝数据

    # 3.通过已有的张量来生成新的张量: 新的张量将继承已有张量的属性(结构、类型),也可以重新指定新的数据类型
    x_ones = torch.ones_like(x_data); print(f"x_ones: {x_ones}, type: {x_ones.type()}") # 保留x_data的属性
    x_rand = torch.rand_like(x_data, dtype=torch.float); print(f"x_rand: {x_rand}, type: {x_rand.type()}") # 重写x_data的数据类型: long -> float

    tensor = torch.tensor((), dtype=torch.int32); print(f"shape of tensor: {tensor.shape}, type: {tensor.type()}")
    new_tensor = tensor.new_ones((2, 3)); print(f"shape of new_tensor: {new_tensor.shape}, type: {new_tensor.type()}")

    # 4.通过指定数据维度来生成张量
    shape = (2, 3) # shape是元组类型,用来描述张量的维数
    rand_tensor = torch.rand(shape); print(f"rand_tensor: {rand_tensor}, type: {rand_tensor.type()}")
    ones_tensor = torch.ones(shape, dtype=torch.int); print(f"ones_tensor: {ones_tensor}, type: {ones_tensor.type()}")
    zeros_tensor = torch.zeros(shape, device=torch.device("cpu")); print("zeros_tensor:", zeros_tensor)

    # 5.可以使用requires_grad=True创建张量,以便torch.autograd记录对它们的操作以进行自动微分
    x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
    out = x.pow(2).sum(); print(f"out: {out}")
    # out.backward(); print(f"x: {x}\nx.grad: {x.grad}")
elif var == 2: # 张量属性: 从张量属性我们可以得到张量的维数、数据类型以及它们所存储的设备(CPU或GPU)
    tensor = torch.rand(3, 4)
    print(f"shape of tensor: {tensor.shape}")
    print(f"datatype of tensor: {tensor.dtype}") # torch.float32
    print(f"device tensor is stored on: {tensor.device}") # cpu或cuda
    print(f"tensor layout: {tensor.layout}") # tensor如何在内存中存储
    print(f"tensor dim: {tensor.ndim}") # tensor维度
elif var == 3: # 张量运算: 有超过100种张量相关的运算操作,例如转置、索引、切片、数学运算、线性代数、随机采样等
    # 所有这些运算都可以在GPU上运行(相对于CPU来说可以达到更高的运算速度)
    tensor = torch.rand((4, 4), dtype=torch.float); print(f"src: {tensor}")

    # 判断当前环境GPU是否可用,然后将tensor导入GPU内运行
    if torch.cuda.is_available():
        tensor = tensor.to("cuda")

    # 1.张量的索引和切片
    tensor[:, 1] = 0; print(f"index: {tensor}") # 将第1列(从0开始)的数据全部赋值为0

    # 2.张量的拼接: 可以通过torch.cat方法将一组张量按照指定的维度进行拼接,也可以参考torch.stack方法,但与torch.cat稍微有点不同
    cat = torch.cat([tensor, tensor], dim=1); print(f"cat:\n {cat}")

    # 3.张量的乘积和矩阵乘法
    print(f"tensor.mul(tensor):\n {tensor.mul(tensor)}") # 逐个元素相乘结果
    print(f"tensor * tensor:\n {tensor * tensor}") # 等价写法

    print(f"tensor.matmul(tensor.T):\n {tensor.matmul(tensor.T)}") # 张量与张量的矩阵乘法
    print(f"tensor @ tensor.T:\n {tensor @ tensor.T}") # 等价写法

    # 4.自动赋值运算: 通常在方法后有"_"作为后缀,例如:x.copy_(y), x.t_()操作会改变x的取值(in-place)
    print(f"tensor:\n {tensor}")
    print(f"tensor:\n {tensor.add_(5)}")
elif var == 4: # Tensor与Numpy的转化: 张量和Numpy array数组在CPU上可以共用一块内存区域,改变其中一个另一个也会随之改变
    # 1.由张量变换为Numpy array数组
    t = torch.ones(5); print(f"t: {t}")
    n = t.numpy(); print(f"n: {n}")

    t.add_(1) # 修改张量的值,则Numpy array数组值也会随之改变
    print(f"t: {t}")
    print(f"n: {n}")

    # 2.由Numpy array数组转为张量
    n = np.ones(5); print(f"n: {n}")
    t = torch.from_numpy(n); print(f"t: {t}")

    np.add(n, 1, out=n) # 修改Numpy array数组的值,则张量值也会随之改变
    print(f"n: {n}")
    print(f"t: {t}")

print("test finish")

GitHub:GitHub - fengbingchun/PyTorch_Test: PyTorch's usage

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch 和 Tensorflow v1 兼容的环境搭建方法

    Github 上很多大牛的代码都是Tensorflow v1 写的,比较新的文章则喜欢用Pytorch,这导致我们复现实验或者对比实验的时候需要花费大量的时间在搭建不同的环境上.这篇文章是我经过反复实践总结出来的环境配置教程,亲测有效! 首先最基本的Python 环境配置如下: conda create -n py37 python=3.7 python版本不要设置得太高也不要太低,3.6~3.7最佳,适用绝大部分代码库.(Tensorflow v1 最高支持的python 版本也只有3.7)

  • Python tensorflow与pytorch的浮点运算数如何计算

    目录 1. 引言 2. 模型结构 3. 计算模型的 FLOPs 3.1. tensorflow 1.12.0 3.2. tensorflow 2.3.1 3.3. pytorch 1.10.1+cu102 3.4. 结果对比 4. 总结 1. 引言 FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度.本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对

  • pytorch使用-tensor的基本操作解读

    目录 一.tensor加减乘除 二.tensor矩阵运算 四.tensor切片操作 五.tensor改变形状 六.tensor 和 numpy.array相互转换 七.tensor 转到GPU上 总结 一.tensor加减乘除 加法操作 import torch x = torch.randn(2, 3) y = torch.randn(2, 3) z = x + y print(z) z = torch.add(x, y) print(z) y.add_(x) print(y) 其他操作类似:

  • Pytorch如何把Tensor转化成图像可视化

    目录 Pytorch把Tensor转化成图像可视化 pytorch标准化的Tensor转图像问题 总结 Pytorch把Tensor转化成图像可视化 在调试程序的时候经常想把tensor可视化成来看看,可以这样操作: from torchvision import transforms unloader = transforms.ToPILImage() image = original_tensor.cpu().clone()  # clone the tensor image = image

  • 如何计算 tensorflow 和 pytorch 模型的浮点运算数

    目录 1. 引言 2. 模型结构 3. 计算模型的 FLOPs 3.1. tensorflow 1.12.0 3.2. tensorflow 2.3.1 3.3. pytorch 1.10.1+cu102 3.4. 结果对比 4. 总结 本文主要讨论如何计算 tensorflow 和 pytorch 模型的 FLOPs.如有表述不当之处欢迎批评指正.欢迎任何形式的转载,但请务必注明出处. 1. 引言 FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用

  • Pytorch实现List Tensor转Tensor,reshape拼接等操作

    目录 一.List Tensor转Tensor (torch.cat) 高维tensor 二.List Tensor转Tensor (torch.stack) 持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作. 其它Tensor操作如 einsum等见:待更新. 用到两个函数: torch.cat torch.stack 一.List Tensor转Tensor (torch.cat) // An highlighted

  • Pytorch四维Tensor转图片并保存方式(维度顺序调整)

    目录 Pytorch四维Tensor转图片并保存 1.维度顺序转换 2.转为numpy数组 3.根据第一维度batch_size逐个读取中间结果,并存储到磁盘中 Pytorch中Tensor介绍 torch.Tensor或torch.tensor注意事项 创建tensor的四种主要方法 总结 Pytorch四维Tensor转图片并保存 最近在复现一篇论文代码的过程中,想要输出中间图片的结果图,通过debug发现在pytorch网络中是用Tensor存储的四维张量. 1.维度顺序转换 第一维代表的

  • pytorch中tensor张量数据类型的转化方式

    1.tensor张量与numpy相互转换 tensor ----->numpy import torch a=torch.ones([2,5]) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]]) # ********************************** b=a.numpy() array([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]], dtype=float32) numpy --

  • pytorch下大型数据集(大型图片)的导入方式

    使用torch.utils.data.Dataset类 处理图片数据时, 1. 我们需要定义三个基本的函数,以下是基本流程 class our_datasets(Data.Dataset): def __init__(self,root,is_resize=False,is_transfrom=False): #这里只是个参考.按自己需求写. self.root=root self.is_resize=is_resize self.is_transfrom=is_transfrom self.i

  • 详解Python下载图片并保存本地的两种方式

    一:使用Python中的urllib类中的urlretrieve()函数,直接从网上下载资源到本地,具体代码: import os,stat import urllib.request img_url="https://timgsa.baidu.com/timg?image&quality=80&size=b9999_10000&sec=1516371301&di=d99af0828bb301fea27c2149a7070" \ "d44&am

  • 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth或.pkl.用相同的torch.save()语句保存出来的模型文件没有什么不同. 在pytorch官方的文档/代码里,有用.pt的,也有用.pth的.一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也

  • Pytorch中Tensor与各种图像格式的相互转化详解

    前言 在pytorch中经常会遇到图像格式的转化,例如将PIL库读取出来的图片转化为Tensor,亦或者将Tensor转化为numpy格式的图片.而且使用不同图像处理库读取出来的图片格式也不相同,因此,如何在pytorch中正确转化各种图片格式(PIL.numpy.Tensor)是一个在调试中比较重要的问题. 本文主要说明在pytorch中如何正确将图片格式在各种图像库读取格式以及tensor向量之间转化的问题.以下代码经过测试都可以在Pytorch-0.4.0或0.3.0版本直接使用. 对py

  • Java从网络读取图片并保存至本地实例

    本文实例为大家分享了Java从网络读取图片并保存至本地的具体代码,供大家参考,具体内容如下 package getUrlPic; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileOutputStream; import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; public cl

  • 分享PHP源码批量抓取远程网页图片并保存到本地的实现方法

    做为一个仿站工作者,当遇到网站有版权时甚至加密的时候,WEBZIP也熄火,怎么扣取网页上的图片和背景图片呢.有时候,可能会想到用火狐,这款浏览器好像一个强大的BUG,文章有版权,屏蔽右键,火狐丝毫也不会被影响. 但是作为一个热爱php的开发者来说,更多的是喜欢自己动手.所以,我就写出了下面的一个源码,php远程抓取图片小程序.可以读取css文件并抓取css代码中的背景图片,下面这段代码也是针对抓取css中图片而编写的. <?php header("Content-Type: text/ht

  • Android异步加载数据和图片的保存思路详解

    把从网络获取的图片数据保存在SD卡上, 先把权限都加上 网络权限 android.permission.INTERNET SD卡读写权限 android.permission.MOUNT_UNMOUNT_FILESYSTEMS android.permission.WRITE_EXTERNAL_STORAGE 总体布局 写界面,使用ListView,创建条目的布局文件,水平摆放的ImageView TextView 在activity中获取到ListView对象,调用setAdapter()方法

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

随机推荐