Python如何加载模型并查看网络

目录
  • 加载模型并查看网络
    • 打开终端
  • 神经网络_模型的保存,模型的加载
    • 模型的保存(torch.save)
    • 模型的加载(torch.load)

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • PyTorch深度学习模型的保存和加载流程详解

    一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt..pth或.pkl). torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 . torch.nn.Module.state_dict()函数

  • python深度学习TensorFlow神经网络模型的保存和读取

    目录 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用.为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西. TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型. 下面代码给出了保存TensorFlow模型的方法: import tensorflow as tf # 声明两个变量

  • python使用tensorflow保存、加载和使用模型的方法

    使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用.介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍: http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 我对这篇文章进行了整理和汇总. 首先是模型的保存.直接上代码: #!/usr/bin/env python #-*- c

  • Python 实现LeNet网络模型的训练及预测

    目录 1.LeNet模型训练脚本 (1).下载CIFAR10数据集 (2).图像增强 (3).加载数据集 (4).显示部分图像 (5).初始化模型 (6).训练模型及保存模型参数 2.预测脚本 1.LeNet模型训练脚本 整体的训练代码如下,下面我会为大家详细讲解这些代码的意思 import torch import torchvision from torchvision.transforms import transforms import torch.nn as nn from torch

  • Python如何加载模型并查看网络

    目录 加载模型并查看网络 打开终端 神经网络_模型的保存,模型的加载 模型的保存(torch.save) 模型的加载(torch.load) 加载模型并查看网络 加载模型,以vgg19为例. 打开终端 > python Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit (AMD64)] on win32 Type "help", "copyright"

  • 解决python 无法加载downsample模型的问题

    downsample 在最新版本里面修改了位置 from theano.tensor.single import downsample (旧版本) 上面以上的的import会有error raise: from theano.tensor.signal import downsample ImportError: cannot import name 'downsample' 找到from theano.tensor.single import downsample所在文件,如: ...\lib

  • Python实现加载及解析properties配置文件的方法

    本文实例讲述了Python实现加载及解析properties配置文件的方法.分享给大家供大家参考,具体如下: 这里参考前面一篇:http://www.jb51.net/article/137393.htm 我们都是在java里面遇到要解析properties文件,在python中基本没有遇到这中情况,今天用python跑深度学习的时候,发现有些参数可以放在一个global.properties全局文件中,这样使用的时候更加方便.原理都是加载文件,然后用line方法进行解析判断"=",自

  • Tensorflow加载模型实现图像分类识别流程详解

    目录 前言 正文 VGG19网络介绍 总结 前言 深度学习框架在市面上有很多.比如Theano.Caffe.CNTK.MXnet .Tensorflow等.今天讲解的就是主角Tensorflow.Tensorflow的前身是Google大脑项目的一个分布式机器学习训练框架,它是一个十分基础且集成度很高的系统,它的目标就是为研究超大型规模的视觉项目,后面延申到各个领域.Tensorflow 在2015年正式开源,开源的一个月内就收获到1w多的starts,这足以说明Tensorflow的优越性以及

  • python动态加载包的方法小结

    本文实例总结了python动态加载包的方法.分享给大家供大家参考,具体如下: 动态加载模块有三种方法 1. 使用系统函数__import_() stringmodule = __import__('string') 2. 使用imp 模块 import imp stringmodule = imp.load_module('string',*imp.find_module('string')) imp.load_source("TYACMgrHandler_"+app.upper(),

  • TensorFlow获取加载模型中的全部张量名称代码

    核心代码如下: [tensor.name for tensor in tf.get_default_graph().as_graph_def().node] 实例代码:(加载了Inceptino_v3的模型,并获取该模型所有节点的名称) # -*- coding: utf-8 -*- import tensorflow as tf import os model_dir = 'C:/Inception_v3' model_name = 'output_graph.pb' # 读取并创建一个图gr

  • Ajax bootstrap美化网页并实现页面的加载删除与查看详情

    Bookstrap:美化页面: Bootstrap是Twitter推出的一个开源的用于前端开发的工具包. 它由Twitter的设计师Mark Otto和Jacob Thornton合作开发,是一个CSS/HTML框架. Bootstrap提供了优雅的HTML和CSS规范,它即是由动态CSS语言Less写成. Bootstrap一经推出后颇受欢迎,一直是GitHub上的热门开源项目,包括NASA的MSNBC(微软全国广播公司)的Breaking News都使用了该项目. 只需要引用一些定义好的类,

  • python 动态加载的实现方法

    脚本语言都有一个优点,就是动态加载.lua语言有这个优点,python也有这个特性.说简单点就是,如果开发者发现自己的代码有bug,那么他可以在不关闭原来代码的基础之上,动态替换模块.替换方法一般用reload来完成. 1.reload的基本原理 reload主要做了两个动作,删除原来的模块,添加新的模块 2.reload的等效代码 del sys.modules[module_name] __import__(module_name) 3.reload使用的时候要注意什么 3.1 reload

  • python+django加载静态网页模板解析

    接着前面Django入门使用示例 今天我们来看看Django是如何加载静态html的? 我们首先来看一看什么是静态HTML,什么是动态的HTML?二者有什么区别? 静态HTML指的是使用单纯的HTML或者结合CSS制作的包括图片.文字等的只供用户浏览但不包含任何脚本.不含有任何交互功能的网页! 动态的HTML指的是网页不仅提供给用户浏览,网页本身还有交互功能,存在着在脚本如JAVASCRIPT,并利用某种服务器端语言如PHP等实现如用户注册,用户登录,上传文件,下载文件等功能 接下来,了解下加载

  • python pyinstaller 加载ui路径方法

    如下所示: class Login(QMainWindow): """登录窗口""" global status_s global connect_signal def __init__(self, *args): super(Login, self).__init__(*args) if getattr(sys,'frozen',False): bundle_dir = sys._MEIPASS else: bundle_dir = os.pa

随机推荐