pytorch自定义二值化网络层方式

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • pytorch自定义初始化权重的方法

    在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化.但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值. 核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值.但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable. import torch import torch.nn as nn import torch.optim as optim imp

  • pytorch构建网络模型的4种方法

    利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种. 假设构建一个网络模型如下: 卷积层-->Relu层-->池化层-->全连接层-->Relu层-->全连接层 首先导入几种方法用到的包: import torch import torch.nn.functional as F from collections import OrderedDict 第一种方法 # Method 1 --------------------------------------

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch自定义二值化网络层方式

    任务要求: 自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下: import torch from torch.autograd import Function from torch.autograd import Variable 定义二值化函数 class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) a = torch

  • python实现图片二值化及灰度处理方式

    我就废话不多说了,直接上代码吧! 集成环境:win10 pycharm #!/usr/bin/env python3.5.2 # -*- coding: utf-8 -*- '''4图片灰度调整及二值化: 集成环境:win10 python3 Pycharm ''' from PIL import Image # load a color image im = Image.open('picture\\haha.png' )#当前目录创建picture文件夹 # convert to grey

  • python图片二值化提高识别率代码实例

    这篇文章主要介绍了python图片二值化提高识别率代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 代码如下 import cv2from PIL import Imagefrom pytesseract import pytesseractfrom PIL import ImageEnhanceimport reimport string def createFile(filePath,newFilePath): img = Image

  • python 图片二值化处理(处理后为纯黑白的图片)

    先随便招一张图片test.jpg做案例 然后对图片进行处理 # 图片二值化 from PIL import Image img = Image.open('test.jpg') # 模式L"为灰色图像,它的每个像素用8个bit表示,0表示黑,255表示白,其他数字表示不同的灰度. Img = img.convert('L') Img.save("test1.jpg") # 自定义灰度界限,大于这个值为黑色,小于这个值为白色 threshold = 200 table = []

  • c#实现图片二值化例子(黑白效果)

    C#将图片2值化示例代码,原图及二值化后的图片如下: 原图: 二值化后的图像: 实现代码: using System; using System.Drawing; namespace BMP2Grey { class Program { static void ToGrey(Bitmap img1) { for (int i = 0; i < img1.Width; i++) { for (int j = 0; j < img1.Height; j++) { Color pixelColor

  • 基于c#图像灰度化、灰度反转、二值化的实现方法详解

    图像灰度化:将彩色图像转化成为灰度图像的过程成为图像的灰度化处理.彩色图像中的每个像素的颜色有R.G.B三个分量决定,而每个分量有255中值可取,这样一个像素点可以有1600多万(255*255*255)的颜色的变化范围.而灰度图像是R.G.B三个分量相同的一种特殊的彩色图像,其一个像素点的变化范围为255种,所以在数字图像处理种一般先将各种格式的图像转变成灰度图像以使后续的图像的计算量变得少一些.灰度图像的描述与彩色图像一样仍然反映了整幅图像的整体和局部的色度和亮度等级的分布和特征.图像的灰度

  • C#数字图像处理之图像二值化(彩色变黑白)的方法

    本文实例讲述了C#数字图像处理之图像二值化(彩色变黑白)的方法.分享给大家供大家参考.具体如下: //定义图像二值化函数 private static Bitmap PBinary(Bitmap src,int v) { int w = src.Width; int h = src.Height; Bitmap dstBitmap = new Bitmap(src.Width ,src.Height ,System .Drawing .Imaging .PixelFormat .Format24

  • python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别

    前言 写爬虫有一个绕不过去的问题就是验证码,现在验证码分类大概有4种: 图像类 滑动类 点击类 语音类 今天先来看看图像类,这类验证码大多是数字.字母的组合,国内也有使用汉字的.在这个基础上增加噪点.干扰线.变形.重叠.不同字体颜色等方法来增加识别难度. 相应的,验证码识别大体可以分为下面几个步骤: 灰度处理 增加对比度(可选) 二值化 降噪 倾斜校正分割字符 建立训练库 识别 由于是实验性质的,文中用到的验证码均为程序生成而不是批量下载真实的网站验证码,这样做的好处就是可以有大量的知道明确结果

  • python opencv 二值化 计算白色像素点的实例

    贴部分代码 #! /usr/bin/env python # -*- coding: utf-8 -*- import cv2 import numpy as np from PIL import Image area = 0 def ostu(img): global area image=cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 转灰度 blur = cv2.GaussianBlur(image,(5,5),0) # 阈值一定要设为 0 !高斯模糊 re

  • Android实现图像灰度化、线性灰度变化和二值化处理方法

    1.图像灰度化: public Bitmap bitmap2Gray(Bitmap bmSrc) { // 得到图片的长和宽 int width = bmSrc.getWidth(); int height = bmSrc.getHeight(); // 创建目标灰度图像 Bitmap bmpGray = null; bmpGray = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); // 创建画布 Canvas c = ne

随机推荐