Gemstone Recognition with a DNN in Python

Contents
  • Task Description
  • Deep Neural Networks (DNN)
  • Dataset
    • 1. Data Preparation
    • 2. Defining the Model
    • 3. Training the Model
    • 4. Model Evaluation
    • 5. Model Prediction

Task Description

This exercise is a multi-class image classification task: the model must look at a photo of a gemstone and identify which class of gemstone it shows.

Platform: Baidu AI Studio, PaddlePaddle 1.8.0 dynamic graph (dygraph).

Deep Neural Networks (DNN)

A deep neural network (Deep Neural Network, DNN) is the foundation of deep learning. Its structure consists of an input layer, one or more hidden layers, and an output layer, with every pair of adjacent layers fully connected.
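In PaddlePaddle 1.8 dygraph this input → hidden → output structure maps directly onto stacked fluid.dygraph.Linear layers. A minimal sketch (the layer sizes here are illustrative, not the ones used for the gemstone model below):

import numpy as np
import paddle.fluid as fluid

# A tiny fully connected network: input -> hidden -> output,
# every adjacent pair of layers fully connected.
with fluid.dygraph.guard():
    hidden = fluid.dygraph.Linear(8, 4, act='relu')    # input layer -> hidden layer
    out = fluid.dygraph.Linear(4, 2, act='softmax')    # hidden layer -> output layer
    x = fluid.dygraph.to_variable(np.random.rand(1, 8).astype('float32'))
    print(out(hidden(x)).shape)  # [1, 2]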

Dataset

  • The dataset ships as two archives: archive_train.zip and archive_test.zip.
  • It contains images of gemstones from 25 different classes.
  • The classes are already split into training and test data.
  • The images vary in size and are stored in .jpeg format.
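Because the images vary in size, it is worth spot-checking a few after extraction (the extraction step comes in section 1 below). An illustrative sketch, assuming the extracted layout used throughout this article:

from PIL import Image
import glob

# Print the size of a few extracted images to confirm they differ.
for p in glob.glob('/home/aistudio/data/dataset/*/*.jpg')[:5]:
    print(p, Image.open(p).size)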

# View the mounted dataset directory. Changes under this directory are restored automatically when the environment is reset.
!ls /home/aistudio/data

data55032 dataset

# Import the required packages
import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import matplotlib.pyplot as plt

1. Data Preparation

'''
Configuration
'''
train_parameters = {
    "input_size": [3, 64, 64],                           # shape of the input images
    "class_dim": -1,                                     # number of classes (filled in later)
    'augment_path' : '/home/aistudio/augment',           # directory for augmented images
    "src_path":"data/data55032/archive_train.zip",       # path to the raw dataset archive
    "target_path":"/home/aistudio/data/dataset",        # extraction target path
    "train_list_path": "./train_data.txt",              # path to train_data.txt
    "eval_list_path": "./val_data.txt",                  # path to val_data.txt
    "label_dict":{},                                    # label dictionary
    "readme_path": "/home/aistudio/data/readme.json",   # path to readme.json
    "num_epochs": 20,                                    # number of training epochs
    "train_batch_size": 64,                             # batch size
    "learning_strategy": {                              # optimizer settings
        "lr": 0.001                                     # learning rate
    }
}
def unzip_data(src_path,target_path):

    '''
    Unzip the raw dataset: extract the zip at src_path into target_path.
    '''

    if(not os.path.isdir(target_path)):
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()
    else:
        print("Archive already extracted")
def get_data_list(target_path,train_list_path,eval_list_path, augment_path):
    '''
    Generate the data list files
    '''
    # metadata for every class
    class_detail = []
    # folder names, one per class
    data_list_path=target_path
    class_dirs = os.listdir(data_list_path)
    if '__MACOSX' in class_dirs:
        class_dirs.remove('__MACOSX')
    # total number of images
    all_class_images = 0
    # current class label
    class_label=0
    # number of classes
    class_dim = 0
    # lines to be written to eval.txt and train.txt
    trainer_list=[]
    eval_list=[]
    # walk through every class
    for class_dir in class_dirs:
        if class_dir != ".DS_Store":
            class_dim += 1
            # metadata for this class
            class_detail_list = {}
            eval_sum = 0
            trainer_sum = 0
            # number of images in this class
            class_sum = 0
            # path of the class directory
            path = os.path.join(data_list_path,class_dir)
            # list all images of the class
            img_paths = os.listdir(path)
            for img_path in img_paths:                                  # iterate over the images in the folder
                if img_path =='.DS_Store':
                    continue
                name_path = os.path.join(path,img_path)                       # full path of this image
                if class_sum % 15 == 0:                                 # every 15th image is held out for validation
                    eval_sum += 1                                       # eval_sum counts validation images
                    eval_list.append(name_path + "\t%d" % class_label + "\n")
                else:
                    trainer_sum += 1
                    trainer_list.append(name_path + "\t%d" % class_label + "\n") # trainer_sum counts training images
                class_sum += 1                                          # images in this class
                all_class_images += 1                                   # images across all classes
            # ---------------------------------- augmented data ----------------------------------
            aug_path = os.path.join(augment_path, class_dir)
            for img_path in os.listdir(aug_path):                                  # iterate over the augmented images
                name_path = os.path.join(aug_path,img_path)                       # full path of this image
                trainer_sum += 1
                trainer_list.append(name_path + "\t%d" % class_label + "\n") # augmented images go to the training set only
                all_class_images += 1                                   # images across all classes
            # ----------------------------------------------------------------------------
            # class_detail entry for the readme json file
            class_detail_list['class_name'] = class_dir             # class name
            class_detail_list['class_label'] = class_label          # class label
            class_detail_list['class_eval_images'] = eval_sum       # validation images of this class
            class_detail_list['class_trainer_images'] = trainer_sum # training images of this class
            class_detail.append(class_detail_list)
            # register the label
            train_parameters['label_dict'][str(class_label)] = class_dir
            class_label += 1

    # record the number of classes
    train_parameters['class_dim'] = class_dim
    print(train_parameters)
    # shuffle
    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as f:
        for eval_image in eval_list:
            f.write(eval_image)
    # shuffle
    random.shuffle(trainer_list)
    with open(train_list_path, 'a') as f2:
        for train_image in trainer_list:
            f2.write(train_image)

    # readme json metadata
    readjson = {}
    readjson['all_class_name'] = data_list_path                  # parent directory of the dataset
    readjson['all_class_images'] = all_class_images
    readjson['class_detail'] = class_detail
    jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'],'w') as f:
        f.write(jsons)
    print('Data lists generated!')
def data_reader(file_list):
    '''
    Custom data reader
    '''
    def reader():
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]
            for line in lines:
                img_path, lab = line.strip().split('\t')
                img = Image.open(img_path)
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img = img.resize((64, 64), Image.BILINEAR)
                img = np.array(img).astype('float32')
                img = img.transpose((2, 0, 1))  # HWC to CHW
                img = img/255                   # normalize pixel values to [0, 1]
                yield img, int(lab)
    return reader
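Once the list files have been generated (further below), the reader can be exercised directly. An illustrative check:

# Pull one sample from the reader and confirm the CHW shape,
# dtype and integer label (assumes train_data.txt already exists).
r = data_reader('./train_data.txt')
img, lab = next(r())
print(img.shape, img.dtype, lab)  # (3, 64, 64) float32 <label>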
!pip install Augmentor
Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already satisfied: Augmentor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.2.8)
Requirement already satisfied: tqdm>=4.9.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (4.36.1)
Requirement already satisfied: future>=0.16.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (0.18.0)
Requirement already satisfied: numpy>=1.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (1.16.4)
Requirement already satisfied: Pillow>=5.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (7.1.2)

'''
Parameter initialization
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
train_list_path=train_parameters['train_list_path']
eval_list_path=train_parameters['eval_list_path']
batch_size=train_parameters['train_batch_size']
augment_path = train_parameters['augment_path']
'''
Unzip the raw data to the target path
'''
unzip_data(src_path,target_path)

Archive already extracted

def proc_img(src):
    '''
    Convert any non-RGB images in place so that downstream code can assume RGB.
    '''
    for root, dirs, files in os.walk(src):
        if '__MACOSX' in root: continue
        for file in files:
            src = os.path.join(root, file)
            img = Image.open(src)
            if img.mode != 'RGB':
                img = img.convert('RGB')
                img.save(src)

if __name__=='__main__':
    proc_img(r"data/dataset")
import os, Augmentor
import shutil, glob

if not os.path.exists(augment_path): # guard so the data is not augmented twice
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            path_ = os.path.join(root, name)
            if '__MACOSX' in path_:continue
            print('Augmenting:', path_)
            p = Augmentor.Pipeline(os.path.join(root, name),output_directory='output')
            p.rotate(probability=0.6, max_left_rotation=2, max_right_rotation=2)
            p.zoom(probability=0.6, min_factor=0.9, max_factor=1.1)
            p.random_distortion(probability=0.4, grid_height=2, grid_width=2, magnitude=1)

            # top each class up to 1000 images
            count = 1000 - len(glob.glob(pathname=path_+'/*.jpg'))
            p.sample(count, multi_threaded=False)
            p.process()

    print('Moving generated images into place')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in files:
            path_ = os.path.join(root, name)
            if path_.rsplit('/',3)[2] == 'output':
                type_ = path_.rsplit('/',3)[1]
                dest_dir = os.path.join(augment_path ,type_)
                if not os.path.exists(dest_dir):os.makedirs(dest_dir)
                dest_path_ = os.path.join(augment_path ,type_, name)
                shutil.move(path_, dest_path_)
    print('Removing all output directories')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            if name == 'output':
                path_ = os.path.join(root, name)
                shutil.rmtree(path_)
    print('Augmentation complete')
Augmenting: data/dataset/Kunzite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Kunzite/output.
Processing kunzite_15.jpg: 100%|██████████| 968/968 [00:15<00:00, 61.57 Samples/s]

(the same pipeline output repeats for the other 24 classes: Almandine, Emerald, Sapphire Blue, Malachite, Alexandrite, Zircon, Onyx Black, Rhodochrosite, Diamond, Benitoite, Pearl, Beryl Golden, Labradorite, Fluorite, Iolite, Quartz Beer, Garnet Red, Danburite, Cats Eye, Hessonite, Carnelian, Jade, Variscite, Tanzanite)

Moving generated images into place
Removing all output directories
Augmentation complete

# Clear train.txt and eval.txt before generating the data lists
# (opening a file in 'w' mode truncates it)
open(train_list_path, 'w').close()
open(eval_list_path, 'w').close()

# Generate the data lists
get_data_list(target_path,train_list_path,eval_list_path,augment_path)

'''
Build the batched data readers
'''
train_reader = paddle.batch(data_reader(train_list_path),
                            batch_size=batch_size,
                            drop_last=True)
eval_reader = paddle.batch(data_reader(eval_list_path),
                            batch_size=batch_size,
                            drop_last=True)
{'input_size': [3, 64, 64], 'class_dim': 25, 'augment_path': '/home/aistudio/augment', 'src_path': 'data/data55032/archive_train.zip', 'target_path': '/home/aistudio/data/dataset', 'train_list_path': './train_data.txt', 'eval_list_path': './val_data.txt', 'label_dict': {'0': 'Kunzite', '1': 'Almandine', '2': 'Emerald', '3': 'Sapphire Blue', '4': 'Malachite', '5': 'Alexandrite', '6': 'Zircon', '7': 'Onyx Black', '8': 'Rhodochrosite', '9': 'Diamond', '10': 'Benitoite', '11': 'Pearl', '12': 'Beryl Golden', '13': 'Labradorite', '14': 'Fluorite', '15': 'Iolite', '16': 'Quartz Beer', '17': 'Garnet Red', '18': 'Danburite', '19': 'Cats Eye', '20': 'Hessonite', '21': 'Carnelian', '22': 'Jade', '23': 'Variscite', '24': 'Tanzanite'}, 'readme_path': '/home/aistudio/data/readme.json', 'num_epochs': 20, 'train_batch_size': 64, 'learning_strategy': {'lr': 0.001}}
Data lists generated!
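As an optional sanity check before training, one batch can be pulled from train_reader to confirm it carries 64 CHW images and 64 integer labels (illustrative):

# Fetch a single batch and verify its shape.
batch = next(train_reader())
images = np.array([x[0] for x in batch])
labels = np.array([x[1] for x in batch])
print(images.shape, labels.shape)  # expected: (64, 3, 64, 64) (64,)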
Batch=0
Batchs=[]
all_train_accs=[]
def draw_train_acc(Batchs, train_accs):
    title="training accs"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("acc", fontsize=14)
    plt.plot(Batchs, train_accs, color='green', label='training accs')
    plt.legend()
    plt.grid()
    plt.show()

all_train_loss=[]
def draw_train_loss(Batchs, train_loss):
    title="training loss"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.plot(Batchs, train_loss, color='red', label='training loss')
    plt.legend()
    plt.grid()
    plt.show()

2. Defining the Model

### Complete the DNN definition in the cell below ###

# define the network
class MyDNN(fluid.dygraph.Layer):
    '''
    Fully connected deep neural network
    '''
    def __init__(self):
        super(MyDNN,self).__init__()
        self.hidden1=fluid.dygraph.Linear(3*64*64,1000, act='relu')
        self.hidden2=fluid.dygraph.Linear(1000,500, act='relu')
        self.hidden3=fluid.dygraph.Linear(500,100, act='relu')
        self.out = fluid.dygraph.Linear(input_dim=100, output_dim=25, act='softmax')

    def forward(self,input):
        x = fluid.layers.reshape(input,shape=[-1,3*64*64]) # flatten each CHW image into a vector
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)
        return x
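Before training, a quick smoke test (illustrative) confirms that a random batch flows through MyDNN and comes out as one probability distribution per image:

# Run a random 4-image batch through the network: the output should be
# a [4, 25] matrix of class probabilities, each row summing to ~1
# because of the softmax on the output layer.
with fluid.dygraph.guard():
    net = MyDNN()
    fake = fluid.dygraph.to_variable(np.random.rand(4, 3, 64, 64).astype('float32'))
    probs = net(fake)
    print(probs.shape)             # [4, 25]
    print(probs.numpy()[0].sum())  # ~1.0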

3. Training the Model

with fluid.dygraph.guard(place = fluid.CUDAPlace(0)):
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    model=MyDNN() # instantiate the model
    model.train() # training mode
    opt=fluid.optimizer.SGDOptimizer(learning_rate=train_parameters['learning_strategy']['lr'], parameter_list=model.parameters()) # SGD optimizer with learning rate 0.001
    epochs_num=train_parameters['num_epochs'] # number of epochs

    for pass_num in range(epochs_num):
        for batch_id,data in enumerate(train_reader()):
            images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
            labels = np.array([x[1] for x in data]).astype('int64')
            labels = labels[:, np.newaxis]

            image=fluid.dygraph.to_variable(images)
            label=fluid.dygraph.to_variable(labels)

            predict=model(image) # forward pass

            loss=fluid.layers.cross_entropy(predict,label)
            avg_loss=fluid.layers.mean(loss) # mean loss over the batch

            acc=fluid.layers.accuracy(predict,label) # batch accuracy

            if batch_id!=0 and batch_id%5==0:
                Batch = Batch+5
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
                all_train_accs.append(acc.numpy()[0])

                print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num,batch_id,avg_loss.numpy(),acc.numpy()))

            avg_loss.backward()
            opt.minimize(avg_loss)    # update the parameters
            model.clear_gradients()   # reset the gradients
    fluid.save_dygraph(model.state_dict(),'MyDNN') # save the model

draw_train_acc(Batchs,all_train_accs)
draw_train_loss(Batchs,all_train_loss)

train_pass:19,batch_id:400,train_loss:[0.24890603],train_acc:[0.96875]

4. Model Evaluation

# model evaluation
with fluid.dygraph.guard():
    accs = []
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) # load the saved parameters
    model.eval() # evaluation mode
    for batch_id,data in enumerate(eval_reader()): # validation set
        images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
        labels = np.array([x[1] for x in data]).astype('int64')
        labels = labels[:, np.newaxis]
        image=fluid.dygraph.to_variable(images)
        label=fluid.dygraph.to_variable(labels)
        predict=model(image)
        acc=fluid.layers.accuracy(predict,label)
        accs.append(acc.numpy()[0])
    avg_acc = np.mean(accs)
    print(avg_acc)

0.96875

5. Model Prediction

import os
import zipfile

def unzip_infer_data(src_path,target_path):
    '''
    Unzip the test (inference) dataset
    '''
    if(not os.path.isdir(target_path)):
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()

def load_image(img_path):
    '''
    Preprocess an image for prediction
    '''
    img = Image.open(img_path)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img = img.resize((64, 64), Image.BILINEAR)
    img = np.array(img).astype('float32')
    img = img.transpose((2, 0, 1))  # HWC to CHW
    img = img/255                # normalize pixel values to [0, 1]
    return img

infer_src_path = '/home/aistudio/data/data55032/archive_test.zip'
infer_dst_path = '/home/aistudio/data/archive_test'
unzip_infer_data(infer_src_path,infer_dst_path)
label_dic = train_parameters['label_dict']
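With the test archive extracted, load_image can be spot-checked on a single file (illustrative; the path is the same sample predicted below):

# Preprocess one test image and confirm the CHW float32 array the
# model expects, with pixel values scaled into [0, 1].
arr = load_image('data/archive_test/alexandrite_3.jpg')
print(arr.shape, arr.dtype, arr.max())  # (3, 64, 64) float32 <= 1.0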

'''
Model prediction
'''
with fluid.dygraph.guard():
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) # load the saved parameters
    model.eval() # evaluation mode

    # show the image to be predicted
    infer_path='data/archive_test/alexandrite_3.jpg'
    img = Image.open(infer_path)
    plt.imshow(img)          # render the image array
    plt.show()               # display it

    # preprocess the image
    infer_imgs = []
    infer_imgs.append(load_image(infer_path))
    infer_imgs = np.array(infer_imgs)

    for i in range(len(infer_imgs)):
        data = infer_imgs[i]
        dy_x_data = np.array(data).astype('float32')
        dy_x_data=dy_x_data[np.newaxis,:, : ,:]
        img = fluid.dygraph.to_variable(dy_x_data)
        out = model(img)
        lab = np.argmax(out.numpy())  # argmax(): index of the largest probability

        print("Sample {} predicted as: {}, ground-truth label: {}".format(i+1,label_dic[str(lab)],infer_path.split('/')[-1].split("_")[0]))

print("Done")

Sample 1 predicted as: Malachite, ground-truth label: alexandrite
Done

That concludes this detailed walkthrough of gemstone recognition with a DNN in Python. For more on Python DNN gemstone recognition, see our other related articles!
