pytorch GAN伪造手写体mnist数据集方式

一,mnist数据集

形如上图的数字手写体就是mnist数据集。

二,GAN原理(生成对抗网络)

GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)

一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:

在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。

三,训练代码

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda = True if torch.cuda.is_available() else False

class Generator(nn.Module):
 def __init__(self):
  super(Generator, self).__init__()

  def block(in_feat, out_feat, normalize=True):
   layers = [nn.Linear(in_feat, out_feat)]
   if normalize:
    layers.append(nn.BatchNorm1d(out_feat, 0.8))
   layers.append(nn.LeakyReLU(0.2, inplace=True))
   return layers

  self.model = nn.Sequential(
   *block(opt.latent_dim, 128, normalize=False),
   *block(128, 256),
   *block(256, 512),
   *block(512, 1024),
   nn.Linear(1024, int(np.prod(img_shape))),
   nn.Tanh()
  )

 def forward(self, z):
  img = self.model(z)
  img = img.view(img.size(0), *img_shape)
  return img

class Discriminator(nn.Module):
 def __init__(self):
  super(Discriminator, self).__init__()

  self.model = nn.Sequential(
   nn.Linear(int(np.prod(img_shape)), 512),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(512, 256),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(256, 1),
   nn.Sigmoid(),
  )

 def forward(self, img):
  img_flat = img.view(img.size(0), -1)
  validity = self.model(img_flat)
  return validity

# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
 generator.cuda()
 discriminator.cuda()
 adversarial_loss.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
 datasets.MNIST(
  "../../data/mnist",
  train=True,
  download=True,
  transform=transforms.Compose(
   [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
  ),
 ),
 batch_size=opt.batch_size,
 shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
# Training
# ----------
if __name__ == '__main__':
 for epoch in range(opt.n_epochs):
  for i, (imgs, _) in enumerate(dataloader):
   # print(imgs.shape)
   # Adversarial ground truths
   valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 全1
   fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 全0
   # Configure input
   real_imgs = Variable(imgs.type(Tensor))

   # -----------------
   # Train Generator
   # -----------------

   optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度

   # Sample noise as generator input
   z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音
   # Generate a batch of images
   gen_imgs = generator(z)
   # Loss measures generator's ability to fool the discriminator
   g_loss = adversarial_loss(discriminator(gen_imgs), valid)

   g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
   optimizer_G.step()

   # ---------------------
   # Train Discriminator
   # ---------------------

   optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度
   # Measure discriminator's ability to classify real from generated samples
   real_loss = adversarial_loss(discriminator(real_imgs), valid)
   fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
   d_loss = (real_loss + fake_loss) / 2

   d_loss.backward() # d_loss用于更新D网络的权值
   optimizer_D.step()

   print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
   )

   batches_done = epoch * len(dataloader) + i
   if batches_done % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存一个batchsize中的25张
   if (epoch+1) %2 ==0:
    print('save..')
    torch.save(generator,'g%d.pth' % epoch)
    torch.save(discriminator,'d%d.pth' % epoch)

运行结果:

一开始时,G生成的全是杂音:

然后逐渐呈现数字的雏形:

最后一次生成的结果:

四,测试代码:

导入最后保存生成器的模型:

from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor
g = torch.load('g199.pth') #导入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)

z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #输入的噪音
gen_imgs =g(z) #生产图片
save_image(gen_imgs.data[:25], "images.png" , nrow=5, normalize=True)

生成结果:

以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

    CGAN的全拼是Conditional Generative Adversarial Networks,条件生成对抗网络,在初始GAN的基础上增加了图片的相应信息. 这里用传统的卷积方式实现CGAN. import torch from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision import transforms from torch import opti

  • pytorch:实现简单的GAN示例(MNIST数据集)

    我就废话不多说了,直接上代码吧! # -*- coding: utf-8 -*- """ Created on Sat Oct 13 10:22:45 2018 @author: www """ import torch from torch import nn from torch.autograd import Variable import torchvision.transforms as tfs from torch.utils.dat

  • MNIST数据集转化为二维图片的实现示例

    本文介绍了MNIST数据集转化为二维图片的实现示例,分享给大家,具体如下: #coding: utf-8 from tensorflow.examples.tutorials.mnist import input_data import scipy.misc import os # 读取MNIST数据集.如果不存在会事先下载. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 我们把原始图片保存在MNIST

  • 详解PyTorch手写数字识别(MNIST数据集)

    MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程.虽然网上的案例比较多,但还是要自己实现一遍.代码采用 PyTorch 1.0 编写并运行. 导入相关库 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, t

  • 使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式

    简介 这是深度学习课程的第一个实验,主要目的就是熟悉 Pytorch 框架.MLP 是多层感知器,我这次实现的是四层感知器,代码和思路参考了网上的很多文章.个人认为,感知器的代码大同小异,尤其是用 Pytorch 实现,除了层数和参数外,代码都很相似. Pytorch 写神经网络的主要步骤主要有以下几步: 1 构建网络结构 2 加载数据集 3 训练神经网络(包括优化器的选择和 Loss 的计算) 4 测试神经网络 下面将从这四个方面介绍 Pytorch 搭建 MLP 的过程. 项目代码地址:la

  • 用Pytorch训练CNN(数据集MNIST,使用GPU的方法)

    听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便. 参考代码(在莫烦python的教程代码基础上修改)如下: import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import tor

  • Pytorch 神经网络—自定义数据集上实现教程

    第一步.导入需要的包 import os import scipy.io as sio import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms, ut

  • pytorch GAN伪造手写体mnist数据集方式

    一,mnist数据集 形如上图的数字手写体就是mnist数据集. 二,GAN原理(生成对抗网络) GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D) 一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的.D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值.因此G和D都在不停地更新权值.以下图为例

  • 使用tensorflow实现VGG网络,训练mnist数据集方式

    VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行. 先介绍下VGG ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet.它主要的贡献是展示出网络的深度是算法优良性能的关键部分. 他们最好的网络包含了16个卷积/全连接层.网络的结构

  • Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

    原始生成对抗网络Generative Adversarial Networks GAN包含生成器Generator和判别器Discriminator,数据有真实数据groundtruth,还有需要网络生成的"fake"数据,目的是网络生成的fake数据可以"骗过"判别器,让判别器认不出来,就是让判别器分不清进入的数据是真实数据还是fake数据.总的来说是:判别器区分真实数据和fake数据的能力越强越好:生成器生成的数据骗过判别器的能力越强越好,这个是矛盾的,所以只能

  • 关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu18.04 显卡:GTX1080ti python版本:2.7(3.7) 网络架构 具有4层的CNN具有以下架构. 输入层:784个节点(MNIST图像大小) 第一卷积层:5x5x32 第一个最大池层 第二卷积层:5x5x64 第二个最大池层 第三个完全连接层:1024个节点 输出层:10个节点(M

  • 详解如何从TensorFlow的mnist数据集导出手写体数字图片

    在TensorFlow的官方入门课程中,多次用到mnist数据集. mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件. 如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片.了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解. 下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代

  • pytorch 把MNIST数据集转换成图片和txt的方法

    本文介绍了pytorch 把MNIST数据集转换成图片和txt的方法,分享给大家,具体如下: 1.下载Mnist 数据集 import os # third-party library import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # t

随机推荐