利用pytorch实现对CIFAR-10数据集的分类

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np

###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算

#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch实现mnist分类的示例讲解

    torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具. torchvision.datasets中包含了以下数据集 MNIST COCO(用于图像标注和目标检测)(Captioning and Detection) LSUN Classification ImageFolder Imagenet-12 CIFAR10 and CIFAR100 STL10 torchvision.models torchvision.models模块的 子模块中包含以下模型结构. Ale

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • pytorch 数据处理:定义自己的数据集合实例

    数据处理 版本1 #数据处理 import os import torch from torch.utils import data from PIL import Image import numpy as np #定义自己的数据集合 class DogCat(data.Dataset): def __init__(self,root): #所有图片的绝对路径 imgs=os.listdir(root) self.imgs=[os.path.join(root,k) for k in imgs

  • 关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu18.04 显卡:GTX1080ti python版本:2.7(3.7) 网络架构 具有4层的CNN具有以下架构. 输入层:784个节点(MNIST图像大小) 第一卷积层:5x5x32 第一个最大池层 第二卷积层:5x5x64 第二个最大池层 第三个完全连接层:1024个节点 输出层:10个节点(M

  • pytorch 批次遍历数据集打印数据的例子

    我就废话不多说了,直接上代码吧! from os import listdir import os from time import time import torch.utils.data as data import torchvision.transforms as transforms from torch.utils.data import DataLoader def printProgressBar(iteration, total, prefix='', suffix='', d

  • PyTorch读取Cifar数据集并显示图片的实例讲解

    首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image

  • 利用pytorch实现对CIFAR-10数据集的分类

    步骤如下: 1.使用torchvision加载并预处理CIFAR-10数据集. 2.定义网络 3.定义损失函数和优化器 4.训练网络并更新网络参数 5.测试网络 运行环境: windows+python3.6.3+pycharm+pytorch0.3.0 import torchvision as tv import torchvision.transforms as transforms import torch as t from torchvision.transforms import

  • 在python中利用KNN实现对iris进行分类的方法

    如下所示: from sklearn.datasets import load_iris iris = load_iris() print iris.data.shape from sklearn.cross_validation import train_test_split X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size = 0.25, random_state = 3

  • 利用python实现对web服务器的目录探测的方法

    一.python Python是一种解释型.面向对象.动态数据类型的高级程序设计语言. python 是一门简单易学的语言,并且功能强大也很灵活,在渗透测试中的应用广泛,让我们一起打造属于自己的渗透测试工具 二.web服务器的目录探测脚本打造 1.在渗透时如果能发现web服务器中的webshell,渗透是不是就可以变的简单一点尼 通常情况下御剑深受大家的喜爱,但是今天在测试的时候webshell不知道为什么御剑扫描不到 仔细查看是webshell有防爬功能,是检测User-Agent头,如果没有

  • 利用JWT如何实现对API的授权访问详解

    什么是JWT JWT(JSON Web Token)是一个开放标准(RFC 7519),它定义了一种紧凑且独立的方式,可以在各个系统之间用JSON作为对象安全地传输信息,并且可以保证所传输的信息不会被篡改. 「JWT」由三部分构成 信息头:指定了使用的签名算法 声明部分:其中也可以包含超时时间 基于指定的算法生成的签名 通过这三部分信息,API 服务端可以根据「JWT」信息头和声明部分的信息重新生成签名.之所以可以这样做,是因为生成签名需要的秘钥存放在服务器端. jwtauth.New("HS2

  • java中利用List的subList方法实现对List分页(简单易学)

    以下是介绍利用List的subList方法实现对List分页,废话不多说了,直接看代码把 /** *//** * List分页 * 实现:利用List的获取子List方法,实现对List的分页 * @author 显武 * @date 2010-1-8 16:27:31 * */ import java.util.ArrayList; import java.util.List; public class PageModel { private int page = 1; // 当前页 publ

  • 详解利用Pytorch实现ResNet网络

    目录 正文 评估模型 训练 ResNet50 模型 正文 每个 batch 前清空梯度,否则会将不同 batch 的梯度累加在一块,导致模型参数错误. 然后我们将输入和目标张量都移动到所需的设备上,并将模型的梯度设置为零.我们调用model(inputs)来计算模型的输出,并使用损失函数(在此处为交叉熵)来计算输出和目标之间的误差.然后我们通过调用loss.backward()来计算梯度,最后调用optimizer.step()来更新模型的参数. 在训练过程中,我们还计算了准确率和平均损失.我们

  • 原生JavaScript来实现对dom元素class的操作方法(推荐)

    jQuery操作class的方式非常强大 写了一个利用原生js来实现对dom元素class的操作方法 1.addClass:为指定的dom元素添加样式 2.removeClass:删除指定dom元素的样式 3.toggleClass:如果存在(不存在),就删除(添加)一个样式 4.hasClass:判断样式是否存在 下面为一toggleClass的测试例子 [html] view plain copy <style type="text/css"> div.testClas

  • SpringBoot拦截器实现对404和500等错误的拦截

    今天给大家介绍一下SpringBoot中拦截器的用法,相比Struts2中的拦截器,SpringBoot的拦截器就显得更加方便简单了. 只需要写几个实现类就可以轻轻松松实现拦截器的功能了,而且不需要配置任何多余的信息,对程序员来说简直是一种福利啊. 废话不多说,下面开始介绍拦截器的实现过程: 第一步:创建我们自己的拦截器类并实现 HandlerInterceptor 接口. package example.Interceptor; import javax.servlet.http.HttpSe

  • 用ASP实现对ORACLE数据库的操作

    ASP(Active Server Pages)是微软公司为开发互联网应用程序所提出的工具之一,ASP与数据库的联接一般通过ADO(Activex Data Object)来实现的,就象<计算机世界>2000年3月20日的<用ASP对SQL Server数据库操作>文章介绍的一样,ADO可以完全支持Microsoft SQL Server ,但对应用更加广泛.机制更加复杂的ORACLE 数据库服务就有一些困难,如果想作一些简单的查询功能,ADO是足够的,如要想更好地发挥ORACLE

  • 详解jdbc实现对CLOB和BLOB数据类型的操作

    详解jdbc实现对CLOB和BLOB数据类型的操作 1. 读取操作 CLOB  //获得数据库连接 Connection con = ConnectionFactory.getConnection(); con.setAutoCommit(false); Statement st = con.createStatement(); //不需要"for update" ResultSet rs = st.executeQuery("select CLOBATTR from TES

随机推荐