pytorch实现图像识别(实战)

目录
  • 1. 代码讲解
    • 1.1 导库
    • 1.2 标准化、transform、设置GPU
    • 1.3 预处理数据
    • 1.4 建立模型
    • 1.5 训练模型
    • 1.6 测试模型
    • 1.7结果

1. 代码讲解

1.1 导库

import os.path
from os import listdir
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import AdaptiveAvgPool2d
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

1.2 标准化、transform、设置GPU

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize])  # 转换

1.3 预处理数据

class DogDataset(Dataset):
# 定义变量
    def __init__(self, img_paths, img_labels, size_of_images):  
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.size_of_images = size_of_images

# 多少长图片
    def __len__(self):
        return len(self.img_paths)

# 打开每组图片并处理每张图片
    def __getitem__(self, index):
        PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images)
        TENSOR_IMAGE = transform(PIL_IMAGE)
        label = self.img_labels[index]
        return TENSOR_IMAGE, label

print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train')))
print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')))
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test')))

train_paths = []
test_paths = []
labels = []
# 训练集图片路径
train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'
for path in listdir(train_paths_lir):
    train_paths.append(os.path.join(train_paths_lir, path))  
# 测试集图片路径
labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')
labels_data = pd.DataFrame(labels_data)  
# 把字符标签离散化,因为数据有120种狗,不离散化后面把数据给模型时会报错:字符标签过多。把字符标签从0-119编号
size_mapping = {}
value = 0
size_mapping = dict(labels_data['breed'].value_counts())
for kay in size_mapping:
    size_mapping[kay] = value
    value += 1
# print(size_mapping)
labels = labels_data['breed'].map(size_mapping)
labels = list(labels)
# print(labels)
print(len(labels))
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2)

train_set = DogDataset(X_train, y_train, (32, 32))
test_set = DogDataset(X_test, y_test, (32, 32))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)

1.4 建立模型

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 120)
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.features(x)
        x = x.view(batch_size, -1)
        x = self.classifier(x)
        return x

model = LeNet().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters())
TRAIN_LOSS = []  # 损失
TRAIN_ACCURACY = []  # 准确率

1.5 训练模型

def train(epoch):
    model.train()
    epoch_loss = 0.0 # 损失
    correct = 0  # 精确率
    for batch_index, (Data, Label) in enumerate(train_loader):
    # 扔到GPU中
        Data = Data.to(device)
        Label = Label.to(device)
        output_train = model(Data)
    # 计算损失
        loss_train = criterion(output_train, Label)
        epoch_loss = epoch_loss + loss_train.item()
    # 计算精确率
        pred = torch.max(output_train, 1)[1]
        train_correct = (pred == Label).sum()
        correct = correct + train_correct.item()
    # 梯度归零、反向传播、更新参数
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
    print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))

1.6 测试模型

和训练集差不多。

def test():
    model.eval()
    correct = 0.0
    test_loss = 0.0
    with torch.no_grad():
        for Data, Label in test_loader:
            Data = Data.to(device)
            Label = Label.to(device)
            test_output = model(Data)
            loss = criterion(test_output, Label)
            pred = torch.max(test_output, 1)[1]
            test_correct = (pred == Label).sum()
            correct = correct + test_correct.item()
            test_loss = test_loss + loss.item()
    print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))

1.7结果

epoch = 10
for n_epoch in range(epoch):
    train(n_epoch)
test()

到此这篇关于pytorch实现图像识别(实战)的文章就介绍到这了,更多相关pytorch实现图像识别内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch实现图像识别之数字识别(附详细注释)

    使用了两个卷积层加上两个全连接层实现 本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份 详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上 代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释 Python实现代码: import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transf

  • pytorch实现图像识别(实战)

    目录 1. 代码讲解 1.1 导库 1.2 标准化.transform.设置GPU 1.3 预处理数据 1.4 建立模型 1.5 训练模型 1.6 测试模型 1.7结果 1. 代码讲解 1.1 导库 import os.path from os import listdir import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn import torch.nn.

  • PyTorch实现图像识别实战指南

    目录 概述 预处理 导包 数据读取与预处理 数据可视化 主体 加载参数 建立模型 设置哪些层需要训练 优化器设置 训练模块 开始训练 测试 测试网络效果 测试训练好的模型 测试数据预处理 展示预测结果 总结 概述 今天我们要来做一个进阶的花分类问题. 不同于之前做过的鸢尾花, 这次我们会分析 102 中不同的花. 是不是很上头呀. 预处理 导包 常规操作, 没什么好解释的. 缺模块的同学自行pip -install. import numpy as np import time from mat

  • PyTorch一小时掌握之图像识别实战篇

    目录 概述 预处理 导包 数据读取与预处理 数据可视化 主体 加载参数 建立模型 设置哪些层需要训练 优化器设置 训练模块 开始训练 测试 测试网络效果 测试训练好的模型 测试数据预处理 展示预测结果 概述 今天我们要来做一个进阶的花分类问题. 不同于之前做过的鸢尾花, 这次我们会分析 102 中不同的花. 是不是很上头呀. 预处理 导包 常规操作, 没什么好解释的. 缺模块的同学自行pip -install. import numpy as np import time from matplo

  • PyTorch CNN实战之MNIST手写数字识别示例

    简介 卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功的模型都是基于CNN的. 卷积神经网络CNN的结构一般包含这几个层: 输入层:用于数据的输入 卷积层:使用卷积核进行特征提取和特征映射 激励层:由于卷积也是一种线性运算,因此需要增加非线性映射 池化层:进行下采样,对特征图稀疏处理,减少数据运算量. 全连接层:通常在CNN的尾部进行重新拟合,减

  • 简单易懂Pytorch实战实例VGG深度网络

    简单易懂Pytorch实战实例VGG深度网络 模型VGG,数据集cifar.对照这份代码走一遍,大概就知道整个pytorch的运行机制. 来源 定义模型: '''VGG11/13/16/19 in Pytorch.''' import torch import torch.nn as nn from torch.autograd import Variable cfg = {     'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M

  • PyTorch线性回归和逻辑回归实战示例

    线性回归实战 使用PyTorch定义线性回归模型一般分以下几步: 1.设计网络架构 2.构建损失函数(loss)和优化器(optimizer) 3.训练(包括前馈(forward).反向传播(backward).更新模型参数(update)) #author:yuquanle #data:2018.2.5 #Study of LinearRegression use PyTorch import torch from torch.autograd import Variable # train

  • 使用pytorch完成kaggle猫狗图像识别方式

    kaggle是一个为开发商和数据科学家提供举办机器学习竞赛.托管数据库.编写和分享代码的平台,在这上面有非常多的好项目.好资源可供机器学习.深度学习爱好者学习之用. 碰巧最近入门了一门非常的深度学习框架:pytorch,所以今天我和大家一起用pytorch实现一个图像识别领域的入门项目:猫狗图像识别. 深度学习的基础就是数据,咱们先从数据谈起.此次使用的猫狗分类图像一共25000张,猫狗分别有12500张,我们先来简单的瞅瞅都是一些什么图片. 我们从下载文件里可以看到有两个文件夹:train和t

  • CoAtNet实战之对植物幼苗图像进行分类(pytorch)

    目录 前言 项目结构 数据集 安装库,并导入需要的库 设置全局参数 数据预处理 数据读取 设置模型 测试 前言 虽然Transformer在CV任务上有非常强的学习建模能力,但是由于缺少了像CNN那样的归纳偏置,所以相比于CNN,Transformer的泛化能力就比较差.因此,如果只有Transformer进行全局信息的建模,在没有预训练(JFT-300M)的情况下,Transformer在性能上很难超过CNN(VOLO在没有预训练的情况下,一定程度上也是因为VOLO的Outlook Atten

  • python pytorch图像识别基础介绍

    目录 一.数据集爬取 二.数据处理 三.开始识别 四.模型测试 总结 一.数据集爬取 现在的深度学习对数据集量的需求越来越大了,也有了许多现成的数据集可供大家查找下载,但是如果你只是想要做一下深度学习的实例以此熟练一下或者找不到好的数据集,那么你也可以尝试自己制作数据集——自己从网上爬取图片,下面是通过百度图片爬取数据的示例. import os import time import requests import re def imgdata_set(save_path,word,epoch)

随机推荐