ConvNeXt实战之实现植物幼苗分类

目录
  • 前言
  • ConvNeXt残差模块
  • 数据增强Cutout和Mixup
  • 项目结构
  • 数据集
  • 导入模型文件
  • 安装库,并导入需要的库
  • 设置全局参数
  • 数据预处理
  • 设置模型
  • 定义训练和验证函数
  • 测试
    • 第一种写法
    • 第二种写法

前言

ConvNeXts 完全由标准 ConvNet 模块构建,在准确性和可扩展性方面与 Transformer 竞争,实现 87.8% ImageNet top-1 准确率,在 COCO 检测和 ADE20K 分割方面优于 Swin Transformers,同时保持标准 ConvNet 的简单性和效率。

论文链接:https://arxiv.org/pdf/2201.03545.pdf

代码链接:https://github.com/facebookresearch/ConvNeXt

如果github不能下载,可以使用下面的连接:

https://gitcode.net/hhhhhhhhhhwwwwwwwwww/ConvNeXt

ConvNexts的特点;

使用7×7的卷积核,在VGG、ResNet等经典的CNN模型中,使用的是小卷积核,但是ConvNexts证明了大卷积和的有效性。作者尝试了几种内核大小,包括 3、5、7、9 和 11。网络的性能从 79.9% (3×3) 提高到 80.6% (7×7),而网络的 FLOPs 大致保持不变, 内核大小的好处在 7×7 处达到饱和点。

使用GELU(高斯误差线性单元)激活函数。GELUs是 dropout、zoneout、Relus的综合,GELUs对于输入乘以一个0,1组成的mask,而该mask的生成则是依概率随机的依赖于输入。实验效果要比Relus与ELUs都要好。下图是实验数据:

使用LayerNorm而不是BatchNorm。

倒置瓶颈。图 3 (a) 至 (b) 说明了这些配置。尽管深度卷积层的 FLOPs 增加了,但由于下采样残差块的快捷 1×1 卷积层的 FLOPs 显着减少,这种变化将整个网络的 FLOPs 减少到 4.6G。成绩从 80.5% 提高到 80.6%。在 ResNet-200/Swin-B 方案中,这一步带来了更多的收益(81.9% 到 82.6%),同时也减少了 FLOP。

ConvNeXt残差模块

残差模块是整个模型的核心。如下图:

代码实现:

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        x = input + self.drop_path(x)
        return x

数据增强Cutout和Mixup

ConvNext使用了Cutout和Mixup,为了提高成绩我在我的代码中也加入这两种增强方式。官方使用timm,我没有采用官方的,而选择用torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout

# 数据预处理

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

Mixup实现,在train方法中。需要导入包:from torchtoolbox.tools import mixup_data, mixup_criterion

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
        optimizer.zero_grad()
        output = model(data)
        loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()

项目结构

使用tree命令,打印项目结构

数据集

数据集选用植物幼苗分类,总共12类。数据集连接如下:

链接  提取码:syng

在工程的根目录新建data文件夹,获取数据集后,将trian和test解压放到data文件夹下面,如下图:

导入模型文件

从官方的链接中找到convnext.py文件,将其放入Model文件夹中。如图:

安装库,并导入需要的库

模型用到了timm库,如果没有需要安装,执行命令:

pip install timm

新建train_connext.py文件,导入所需要的包:

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from Model.convnext import convnext_tiny
from torchtoolbox.tools import mixup_data, mixup_criterion
from torchtoolbox.transform import Cutout

设置全局参数

设置使用GPU,设置学习率、BatchSize、epoch等参数。

# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 8
EPOCHS = 300
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据预处理

数据处理比较简单,没有做复杂的尝试,有兴趣的可以加入一些处理。

# 数据预处理

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

数据读取

然后我们在dataset文件夹下面新建 init.py和dataset.py,在mydatasets.py文件夹写入下面的代码:

说一下代码的核心逻辑。

第一步 建立字典,定义类别对应的ID,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

代码如下:

# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split

Labels = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,
          'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,
          'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}

class SeedlingData(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        """
        self.test = test
        self.transforms = transforms

        if self.test:
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
            self.imgs = imgs
        else:
            imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]
            imgs = []
            for imglable in imgs_labels:
                for imgname in os.listdir(imglable):
                    imgpath = os.path.join(imglable, imgname)
                    imgs.append(imgpath)
            trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
            if train:
                self.imgs = trainval_files
            else:
                self.imgs = val_files

    def __getitem__(self, index):
        """
        一次返回一张图片的数据
        """
        img_path = self.imgs[index]
        img_path = img_path.replace("\\", '/')
        if self.test:
            label = -1
        else:
            labelname = img_path.split('/')[-2]
            label = Labels[labelname]
        data = Image.open(img_path).convert('RGB')
        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

然后我们在train.py调用SeedlingData读取数据 ,记着导入刚才写的dataset.py(from mydatasets import SeedlingData)

# 读取数据
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型

设置loss函数为nn.CrossEntropyLoss()。

  • 设置模型为coatnet_0,修改最后一层全连接输出改为12(数据集的类别)。
  • 优化器设置为adam。
  • 学习率调整策略改为余弦退火
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
#criterion = SoftTargetCrossEntropy()
model_ft = convnext_tiny(pretrained=True)
num_ftrs = model_ft.head.in_features
model_ft.fc = nn.Linear(num_ftrs, 12)
model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)

定义训练和验证函数

alpha=0.2 Mixup所需的参数。

# 定义训练过程
alpha=0.2
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
        optimizer.zero_grad()
        output = model(data)
        loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

ACC=0
# 验证过程
def val(model, device, test_loader):
    global ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        if acc > ACC:
            torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
            ACC = acc

# 训练

for epoch in range(1, EPOCHS + 1):
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    cosine_schedule.step()
    val(model_ft, DEVICE, test_loader)

然后就可以开始训练了

训练10个epoch就能得到不错的结果:

测试

第一种写法

测试集存放的目录如下图:

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

transform_test = transforms.Compose([
         transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

第三步 加载model,并将模型放在DEVICE里。

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model_8_0.971.pth")
model.eval()
model.to(DEVICE)

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

测试完整代码:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model_8_0.971.pth")
model.eval()
model.to(DEVICE)

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

运行结果:

第二种写法

第二种,使用自定义的Dataset读取图片。前三步同上,差别主要在第四步。读取数据的时候,使用Dataset的SeedlingData读取。

dataset_test =SeedlingData('data/test/', transform_test,test=True)
print(len(dataset_test))
# 对应文件夹的label

for index in range(len(dataset_test)):
    item = dataset_test[index]
    img, label = item
    img.unsqueeze_(0)
    data = Variable(img).to(DEVICE)
    output = model(data)
    _, pred = torch.max(output.data, 1)
    print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))
    index += 1

运行结果:

以上就是ConvNeXt实战之实现植物幼苗分类的详细内容,更多关于ConvNeXt植物幼苗分类的资料请关注我们其它相关文章!

(0)

相关推荐

  • Unity实现植物识别示例详解

    接口介绍: 可识别超过2万种常见植物和近8千种花卉,接口返回植物的名称,并支持获取识别结果对应的百科信息:还可使用EasyDL定制训练平台,定制识别植物种类.适用于拍照识图.幼教科普.图像内容分析等场景. 创建应用: 在产品服务中搜索图像识别,创建应用,获取AppID.APIKey.SecretKey信息: 查阅官方文档,以下是植物识别接口返回数据参数详情: 定义数据结构: using System; /// <summary> /// 植物识别 /// </summary> [S

  • 基于Python实现简易的植物识别小系统

    导语 "  花草树木 皆有呈名 热爱自然,从认识自然开始 " 现在的植物爱好者,遇到不认得的植物.怎么办呢? 前几天去逛商场,一进商城一一一一门口的花店吸引了我的注意:摆放在店门口的各色鲜花植物花卉真的特别好看! 忍不住进门逛了一圈,发现我真的不认识,种类太多,对花卉的品种了解颇少. 回来之后找到了2款简单好用的植物识别APP一一一伴侣跟形色蛮好用的! 闲着也是闲着:默默用Python编写了一款简单的植物识别系统给大家正好la~ 正文 1)环境安装 本文用到的环境:Python3.7 

  • CoAtNet实战之对植物幼苗图像进行分类(pytorch)

    目录 前言 项目结构 数据集 安装库,并导入需要的库 设置全局参数 数据预处理 数据读取 设置模型 测试 前言 虽然Transformer在CV任务上有非常强的学习建模能力,但是由于缺少了像CNN那样的归纳偏置,所以相比于CNN,Transformer的泛化能力就比较差.因此,如果只有Transformer进行全局信息的建模,在没有预训练(JFT-300M)的情况下,Transformer在性能上很难超过CNN(VOLO在没有预训练的情况下,一定程度上也是因为VOLO的Outlook Atten

  • ConvNeXt实战之实现植物幼苗分类

    目录 前言 ConvNeXt残差模块 数据增强Cutout和Mixup 项目结构 数据集 导入模型文件 安装库,并导入需要的库 设置全局参数 数据预处理 设置模型 定义训练和验证函数 测试 第一种写法 第二种写法 前言 ConvNeXts 完全由标准 ConvNet 模块构建,在准确性和可扩展性方面与 Transformer 竞争,实现 87.8% ImageNet top-1 准确率,在 COCO 检测和 ADE20K 分割方面优于 Swin Transformers,同时保持标准 ConvN

  • 基于Matlab LBP实现植物叶片识别功能

    目录 一.LBP简介 1.1 课题的提出与研究意义 1.2 国内外相关研究情况 1.3 论文的主要研究工作 1.4 论文结构 二.部分源代码 三.运行结果 一.LBP简介 第一章 引言 植物在我们的身边随处可见,它们从产生发展进化到现在,其间经历了漫长的岁月.地球上的植物种类繁多.数量浩瀚,它们是生物圈的重要组成部分,在维持整个生物界的平衡方面发挥着巨大的作用:它们同时也是构成人类生存环境的重要组成部分,是人类社会延续和发展不可或缺的重要因素.由于植物对于地球和人类都具有如此重要的意义,对它们的

  • 详解vue项目构建与实战

    前言 由于vue相对来说比较平缓的学习过程和新颖的技术思路,使其受到了广大前后端开发者的青睐,同时其通俗易懂的API和数据绑定的功能也为其揽获了不少用户.本文主要讲解vue项目的构建与实战,因此不会太多涉及其API和语法部分,旨在帮助vue的入门级用户了解从零开始构建vue项目的步骤和方法. vue项目分类 首先,在构建一个vue项目之前我们需要了解vue项目的分类,这里我主要将其分为两类:(1)直接引入vue.js文件 (2)使用vue单文件组件 按以上两类来看,直接引入vue.js文件就像页

  • python机器学习理论与实战(四)逻辑回归

    从这节算是开始进入"正规"的机器学习了吧,之所以"正规"因为它开始要建立价值函数(cost function),接着优化价值函数求出权重,然后测试验证.这整套的流程是机器学习必经环节.今天要学习的话题是逻辑回归,逻辑回归也是一种有监督学习方法(supervised machine learning).逻辑回归一般用来做预测,也可以用来做分类,预测是某个类别^.^!线性回归想比大家都不陌生了,y=kx+b,给定一堆数据点,拟合出k和b的值就行了,下次给定X时,就可以计

  • python机器学习理论与实战(二)决策树

    决策树也是有监督机器学习方法. 电影<无耻混蛋>里有一幕游戏,在德军小酒馆里有几个人在玩20问题游戏,游戏规则是一个设迷者在纸牌中抽出一个目标(可以是人,也可以是物),而猜谜者可以提问题,设迷者只能回答是或者不是,在几个问题(最多二十个问题)之后,猜谜者通过逐步缩小范围就准确的找到了答案.这就类似于决策树的工作原理.(图一)是一个判断邮件类别的工作方式,可以看出判别方法很简单,基本都是阈值判断,关键是如何构建决策树,也就是如何训练一个决策树. (图一) 构建决策树的伪代码如下: Check i

  • python机器学习理论与实战(六)支持向量机

    上节基本完成了SVM的理论推倒,寻找最大化间隔的目标最终转换成求解拉格朗日乘子变量alpha的求解问题,求出了alpha即可求解出SVM的权重W,有了权重也就有了最大间隔距离,但是其实上节我们有个假设:就是训练集是线性可分的,这样求出的alpha在[0,infinite].但是如果数据不是线性可分的呢?此时我们就要允许部分的样本可以越过分类器,这样优化的目标函数就可以不变,只要引入松弛变量即可,它表示错分类样本点的代价,分类正确时它等于0,当分类错误时,其中Tn表示样本的真实标签-1或者1,回顾

  • Spring Boot集成Swagger2项目实战

    一.Swagger简介 上一篇文章中我们介绍了Spring Boot对Restful的支持,这篇文章我们继续讨论这个话题,不过,我们这里不再讨论Restful API如何实现,而是讨论Restful API文档的维护问题. 在日常的工作中,我们往往需要给前端(WEB端.IOS.Android)或者第三方提供接口,这个时候我们就需要给他们提供一份详细的API说明文档.但维护一份详细的文档可不是一件简单的事情.首先,编写一份详细的文档本身就是一件很费时费力的事情,另一方面,由于代码和文档是分离的,所

  • vue项目实战总结篇

    这篇文章把小编前段时间做的vue项目,做个完整的总结,具体内容请参考本文. 这次算是详细总结,会从项目的搭建,一直到最后的服务器上部署. 废话不多说了.干货直接上. 一. 必须node环境, 这次就不写node环境的安装了.过两天我会写个node环境的安装随笔. 二. node环境配好后.开整vue. 1. 安装vue脚手架. npm install -g vue-cli 2. 用脚手架搭项目(只是一行命令) vue init webpack-simple (项目名字) 或 vue init w

  • 利用TensorFlow训练简单的二分类神经网络模型的方法

    利用TensorFlow实现<神经网络与机器学习>一书中4.7模式分类练习 具体问题是将如下图所示双月牙数据集分类. 使用到的工具: python3.5    tensorflow1.2.1   numpy   matplotlib 1.产生双月环数据集 def produceData(r,w,d,num): r1 = r-w/2 r2 = r+w/2 #上半圆 theta1 = np.random.uniform(0, np.pi ,num) X_Col1 = np.random.unifo

随机推荐