Python利用 SVM 算法实现识别手写数字

目录
  • 前言
  • 使用 SVM 进行手写数字识别
  • 参数 C 和 γ 对识别手写数字精确度的影响
  • 完整代码

前言

支持向量机 (Support Vector Machine, SVM) 是一种监督学习技术,它通过根据指定的类对训练数据进行最佳分离,从而在高维空间中构建一个或一组超平面。在博文《OpenCV-Python实战(13)——OpenCV与机器学习的碰撞》中,我们已经学习了如何在 OpenCV 中实现和训练 SVM 算法,同时通过简单的示例了解了如何使用 SVM 算法。在本文中,我们将学习如何使用 SVM 分类器执行手写数字识别,同时也将探索不同的参数对于模型性能的影响,以获取具有最佳性能的 SVM 分类器。

使用 SVM 进行手写数字识别

我们已经在《利用 KNN 算法识别手写数字》中介绍了 MNIST 手写数字数据集,以及如何利用 KNN 算法识别手写数字。并通过对数字图像进行预处理( desew() 函数)并使用高级描述符( HOG 描述符)作为用于描述每个数字的特征向量来获得最佳分类准确率。因此,对于相同的内容不再赘述,接下来将直接使用在《利用 KNN 算法识别手写数字》中介绍预处理和 HOG 特征,利用 SVM 算法对数字图像进行分类。

首先加载数据,并将其划分为训练集和测试集:

# 加载数据
(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)
# 预处理函数
def deskew(img):
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)

    return img
# HOG 高级描述符
def get_hog():
    hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)

    print("hog descriptor size: {}".format(hog.getDescriptorSize()))

    return hog
# 数据打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]

hog = get_hog()

hog_descriptors = []
for img in train_dataset:
    hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)

results = defaultdict(list)
# 数据划分
split_values = np.arange(0.1, 1, 0.1)

接下来,初始化 SVM,并进行训练:

# 模型初始化函数
def svm_init(C=12.5, gamma=0.50625):
    model = cv2.ml.SVM_create()
    model.setGamma(gamma)
    model.setC(C)
    model.setKernel(cv2.ml.SVM_RBF)
    model.setType(cv2.ml.SVM_C_SVC)
    model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))

    return model
# 模型训练函数
def svm_train(model, samples, responses):
    model.train(samples, cv2.ml.ROW_SAMPLE, responses)
    return model
# 模型预测函数
def svm_predict(model, samples):
    return model.predict(samples)[1].ravel()
# 模型评估函数
def svm_evaluate(model, samples, labels):
    predictions = svm_predict(model, samples)
    acc = (labels == predictions).mean()
    print('Percentage Accuracy: %.2f %%' % (acc * 100))
    return acc *100
# 使用不同训练集、测试集划分方法进行训练和测试
for split_value in split_values:
    partition = int(split_value * len(hog_descriptors))
    hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
    labels_train, labels_test = np.split(train_labels, [partition])

    print('Training SVM model ...')
    model = svm_init(C=12.5, gamma=0.50625)
    svm_train(model, hog_descriptors_train, labels_train)

    print('Evaluating model ... ')
    acc = svm_evaluate(model, hog_descriptors_test, labels_test)
    results['svm'].append(acc)

从上图所示,使用默认参数的 SVM 模型在使用 70% 的数字图像训练算法时准确率可以达到 98.60%,接下来我们通过修改 SVM 模型的参数 C 和 γ 来测试模型是否还有提升空间。

参数 C 和 γ 对识别手写数字精确度的影响

SVM 模型在使用 RBF 核时,有两个重要参数——C 和 γ,上例中我们使用 C=12.5 和 γ=0.50625 作为参数值,C 和 γ 的设定依赖于特定的数据集。因此,必须使用某种方法进行参数搜索,本例中使用网格搜索合适的参数 C 和 γ。

for C in [1, 10, 100, 1000]:
    for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
        model = svm_init(C, gamma)
        svm_train(model, hog_descriptors_train, labels_train)
        acc = svm_evaluate(model, hog_descriptors_test, labels_test)
        print(" {}".format("%.2f" % acc))
        results[C].append(acc)

最后,可视化结果:

fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]

for key in results:
    ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))

plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show()

程序的运行结果如下所示:

如图所示,通过使用不同参数,准确率可以达到 99.25% 左右。通过比较 KNN 分类器和 SVM 分类器在手写数字识别任务中的表现,我们可以得出在手写数字识别任务中 SVM 优于 KNN 分类器的结论。

完整代码

程序的完整代码如下所示:

import cv2
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import keras

(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)

def deskew(img):
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)

    return img

def get_hog():
    hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)

    print("hog descriptor size: {}".format(hog.getDescriptorSize()))

    return hog

def svm_init(C=12.5, gamma=0.50625):
    model = cv2.ml.SVM_create()
    model.setGamma(gamma)
    model.setC(C)
    model.setKernel(cv2.ml.SVM_RBF)
    model.setType(cv2.ml.SVM_C_SVC)
    model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))

    return model

def svm_train(model, samples, responses):
    model.train(samples, cv2.ml.ROW_SAMPLE, responses)
    return model

def svm_predict(model, samples):
    return model.predict(samples)[1].ravel()

def svm_evaluate(model, samples, labels):
    predictions = svm_predict(model, samples)
    acc = (labels == predictions).mean()
    return acc * 100
# 数据打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]
# 使用 HOG 描述符
hog = get_hog()
hog_descriptors = []
for img in train_dataset:
    hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)

# 训练数据与测试数据划分
partition = int(0.9 * len(hog_descriptors))
hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
labels_train, labels_test = np.split(train_labels, [partition])

print('Training SVM model ...')
results = defaultdict(list)

for C in [1, 10, 100, 1000]:
    for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
        model = svm_init(C, gamma)
        svm_train(model, hog_descriptors_train, labels_train)
        acc = svm_evaluate(model, hog_descriptors_test, labels_test)
        print(" {}".format("%.2f" % acc))
        results[C].append(acc)

fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]
for key in results:
    ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))
plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show() 

以上就是Python利用 SVM 算法实现识别手写数字的详细内容,更多关于Python SVM算法识别手写数字的资料请关注我们其它相关文章!

(0)

相关推荐

  • 使用python svm实现直接可用的手写数字识别

    目录 python svm实现手写数字识别--直接可用 1.训练 1.1.训练数据集下载--已转化成csv文件 1.2 .训练源码 2.预测单张图片 2.1.待预测图像 2.2.预测源码 2.3.预测结果 python svm实现手写数字识别--直接可用 最近在做个围棋识别的项目,需要识别下面的数字,如下图: 我发现现在网上很多代码是良莠不齐,-真是一言难尽,于是记录一下,能够运行成功并识别成功的一个源码. 1.训练 1.1.训练数据集下载--已转化成csv文件 下载地址 1.2 .训练源码 t

  • Python-OpenCV实战:利用 KNN 算法识别手写数字

    目录 前言 手写数字数据集 MNIST 介绍 基准模型--利用 KNN 算法识别手写数字 改进模型1--参数 K 对识别手写数字精确度的影响 改进模型2--训练数据量对识别手写数字精确度的影响 改进模型3--预处理对识别手写数字精确度的影响 改进模型4--使用高级描述符作为图像特征提高 KNN 算法准确率 完整代码 相关链接 前言 K-最近邻 (k-nearest neighbours, KNN) 是监督学习中最简单的算法之一,KNN 可用于分类和回归问题,在博文<Python OpenCV实战

  • python实现基于SVM手写数字识别功能

    本文实例为大家分享了SVM手写数字识别功能的具体代码,供大家参考,具体内容如下 1.SVM手写数字识别 识别步骤: (1)样本图像的准备. (2)图像尺寸标准化:将图像大小都标准化为8*8大小. (3)读取未知样本图像,提取图像特征,生成图像特征组. (4)将未知测试样本图像特征组送入SVM进行测试,将测试的结果输出. 识别代码: #!/usr/bin/env python import numpy as np import mlpy import cv2 print 'loading ...'

  • 怎么用Python识别手势数字

    前言 谷歌出了一个开源的.跨平台的.可定制化的机器学习解决方案工具包,给在线流媒体(当然也可以用于普通的视频.图像等)提供了机器学习解决方案.感兴趣的同学可以打开这个网址了解详情:mediapipe.dev/ 它提供了手势.人体姿势.人脸.物品等识别和追踪功能,并提供了C++.Python.JavaScript等编程语言的工具包以及iOS.Android平台的解决方案,今天我们就来看一下如何使用MediaPipe提供的手势识别来写一个Python代码识别手势中的数字:0-5 准备工作 电脑需要安

  • 详解Python OpenCV数字识别案例

    前言 实践是检验真理的唯一标准. 因为觉得一板一眼地学习OpenCV太过枯燥,于是在网上找了一个以项目为导向的教程学习.话不多说,动手做起来. 一.案例介绍 提供信用卡上的数字模板: 要求:识别出信用卡上的数字,并将其直接打印在原图片上.虽然看起来很蠢,但既然可以将数字打印在图片上,说明已经成功识别数字,因此也可以将其转换为数字文本保存.车牌号识别等项目的思路与此案例类似. 示例: 原图 处理后的图 二.步骤 大致分为如下几个步骤: 1.模板读入 2.模板预处理,将模板数字分开,并排序 3.输入

  • Python实战小项目之Mnist手写数字识别

    目录 程序流程分析图: 传播过程: 代码展示: 创建环境 准备数据集 下载数据集 下载测试集 绘制图像 搭建神经网络 训练模型 测试模型 保存训练模型 运行结果展示: 程序流程分析图: 传播过程: 代码展示: 创建环境 使用<pip install+包名>来下载torch,torchvision包 准备数据集 设置一次训练所选取的样本数Batch_Sized的值为512,训练此时Epochs的值为8 BATCH_SIZE = 512 EPOCHS = 8 device = torch.devi

  • Python利用 SVM 算法实现识别手写数字

    目录 前言 使用 SVM 进行手写数字识别 参数 C 和 γ 对识别手写数字精确度的影响 完整代码 前言 支持向量机 (Support Vector Machine, SVM) 是一种监督学习技术,它通过根据指定的类对训练数据进行最佳分离,从而在高维空间中构建一个或一组超平面.在博文<OpenCV-Python实战(13)--OpenCV与机器学习的碰撞>中,我们已经学习了如何在 OpenCV 中实现和训练 SVM 算法,同时通过简单的示例了解了如何使用 SVM 算法.在本文中,我们将学习如何

  • Python Opencv使用ann神经网络识别手写数字功能

    opencv中也提供了一种类似于Keras的神经网络,即为ann,这种神经网络的使用方法与Keras的很接近.关于mnist数据的解析,读者可以自己从网上下载相应压缩文件,用python自己编写解析代码,由于这里主要研究knn算法,为了图简单,直接使用Keras的mnist手写数字解析模块.本次代码运行环境为:python 3.6.8opencv-python 4.4.0.46opencv-contrib-python 4.4.0.46 下面的代码为使用ann进行模型的训练: from kera

  • python实现识别手写数字 python图像识别算法

    写在前面 这一段的内容可以说是最难的一部分之一了,因为是识别图像,所以涉及到的算法会相比之前的来说比较困难,所以我尽量会讲得清楚一点. 而且因为在编写的过程中,把前面的一些逻辑也修改了一些,将其变得更完善了,所以一切以本篇的为准.当然,如果想要直接看代码,代码全部放在我的GitHub中,所以这篇文章主要负责讲解,如需代码请自行前往GitHub. 本次大纲 上一次写到了数据库的建立,我们能够实时的将更新的训练图片存入CSV文件中.所以这次继续往下走,该轮到识别图片的内容了. 首先我们需要从文件夹中

  • python使用KNN算法识别手写数字

    本文实例为大家分享了python使用KNN算法识别手写数字的具体代码,供大家参考,具体内容如下 # -*- coding: utf-8 -*- #pip install numpy import os import os.path from numpy import * import operator import time from os import listdir """ 描述: KNN算法实现分类器 参数: inputPoint:测试集 dataSet:训练集 lab

  • Python实现识别手写数字 Python图片读入与处理

    写在前面 在上一篇文章Python徒手实现手写数字识别-大纲中,我们已经讲过了我们想要写的全部思路,所以我们不再说全部的思路. 我这一次将图片的读入与处理的代码写了一下,和大纲写的过程一样,这一段代码分为以下几个部分: 读入图片: 将图片读取为灰度值矩阵: 图片背景去噪: 切割图片,得到手写数字的最小矩阵: 拉伸/压缩图片,得到标准大小为100x100大小矩阵: 将图片拉为1x10000大小向量,存入训练矩阵中. 所以下面将会对这几个函数进行详解. 代码分析 基础内容 首先我们现在最前面定义基础

  • Python实现识别手写数字大纲

    写在前面 其实我之前写过一个简单的识别手写数字的程序,但是因为逻辑比较简单,而且要求比较严苛,是在50x50大小像素的白底图上手写黑色数字,并且给的训练材料也不够多,导致准确率只能五五开.所以这一次准备写一个加强升级版的,借此来提升我对Python处理文件与图片的能力. 这次准备加强难度: 被识别图片可以是任意大小: 不一定是白底图,只要数字颜色是黑色,周围环境是浅色就行: 加强识别手写数字的逻辑,提升准确率. 因为我还没开始正式写,并且最近专业课程学习也比较紧迫,所以可能更新的比较慢.不过放心

  • Python实现识别手写数字 简易图片存储管理系统

    写在前面 上一篇文章Python实现识别手写数字-图像的处理中我们讲了图片的处理,将图片经过剪裁,拉伸等操作以后将每一个图片变成了1x10000大小的向量.但是如果只是这样的话,我们每一次运行的时候都需要将他们计算一遍,当图片特别多的时候会消耗大量的时间. 所以我们需要将这些向量存入一个文件当中,每次先看看图库中有没有新增的图片,如果有新增的图片,那么就将新增的图片变成1x10000向量再存入文件之中,然后从文件中读取全部图片向量即可.当图库中没有新增图片的时候,那么就直接调用文件中的图片向量进

  • Python神经网络TensorFlow基于CNN卷积识别手写数字

    目录 基础理论 一.训练CNN卷积神经网络 1.载入数据 2.改变数据维度 3.归一化 4.独热编码 5.搭建CNN卷积神经网络 5-1.第一层:第一个卷积层 5-2.第二层:第二个卷积层 5-3.扁平化 5-4.第三层:第一个全连接层 5-5.第四层:第二个全连接层(输出层) 6.编译 7.训练 8.保存模型 代码 二.识别自己的手写数字(图像) 1.载入数据 2.载入训练好的模型 3.载入自己写的数字图片并设置大小 4.转灰度图 5.转黑底白字.数据归一化 6.转四维数据 7.预测 8.显示

随机推荐