Pytorch学习笔记DCGAN极简入门教程

目录
  • 1.图片分类网络
  • 2.图片生成网络
    • 首先是图片分类网络:
    • 重点是生成网络
  • 每一个step分为三个步骤:

1.图片分类网络

这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片

2.图片生成网络

输入是一个随机噪声,输出是一张图片,使用的是反卷积层

相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了

首先是图片分类网络:

简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络,比如bgg,resnet等

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.main(input)

重点是生成网络

代码如下,其实就是反卷积+bn+relu

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
    def forward(self, input):
        return self.main(input)

讲道理,以上两个网络都挺简单。

真正的重点到了,怎么训练

每一个step分为三个步骤:

  • 训练二分类网络
    1.输入真实图片,经过二分类,希望判定为真实图片,更新二分类网络
    2.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为虚假图片,更新二分类网络
  • 训练生成网络
    3.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为真实图片,更新生成网络

不多说直接上代码

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1

以上就是Pytorch学习笔记DCGAN极简入门教程的详细内容,更多关于Pytorch学习DCGAN入门教程的资料请关注我们其它相关文章!

(0)

相关推荐

  • Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

    原始生成对抗网络Generative Adversarial Networks GAN包含生成器Generator和判别器Discriminator,数据有真实数据groundtruth,还有需要网络生成的"fake"数据,目的是网络生成的fake数据可以"骗过"判别器,让判别器认不出来,就是让判别器分不清进入的数据是真实数据还是fake数据.总的来说是:判别器区分真实数据和fake数据的能力越强越好:生成器生成的数据骗过判别器的能力越强越好,这个是矛盾的,所以只能

  • PyTorch安装与基本使用详解

    什么要学习PyTorch? 有的人总是选择,选择的人最多的框架,来作为自己的初学框架,比如Tensorflow,但是大多论文的实现都是基于PyTorch的,如果我们要深入论文的细节,就必须选择学习入门PyTorch 安装PyTorch 一行命令即可 官网 pip install torch===1.6.0 torchvision===0.7.0 - https://download.pytorch.org/whl/torch_stable.html 时间较久,耐心等待 测试自己是否安装成功 运行

  • 使用Pytorch搭建模型的步骤

    本来是只用Tenorflow的,但是因为TF有些Numpy特性并不支持,比如对数组使用列表进行切片,所以只能转战Pytorch了(pytorch是支持的).还好Pytorch比较容易上手,几乎完美复制了Numpy的特性(但还有一些特性不支持),怪不得热度上升得这么快. 1  模型定义 和TF很像,Pytorch也通过继承父类来搭建自定义模型,同样也是实现两个方法.在TF中是__init__()和call(),在Pytorch中则是__init__()和forward().功能类似,都分别是初始化

  • Pytorch学习笔记DCGAN极简入门教程

    目录 1.图片分类网络 2.图片生成网络 首先是图片分类网络: 重点是生成网络 每一个step分为三个步骤: 1.图片分类网络 这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片 2.图片生成网络 输入是一个随机噪声,输出是一张图片,使用的是反卷积层 相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了 首先是图片分类网络: 简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络

  • JavaScript极简入门教程(一):基础篇

    阅读本文需要有其他语言的编程经验. 开始学习之前 大多数的编程语言都存在好的部分和差的部分.本文只讲述 JavaScript 中好的部分,这是因为: 1.仅仅学习好的部分能够缩短学习时间 2.编写的代码更加健壮 3.编写的代码更加易读 4.编写的代码更加易于维护 弱类型和强类型 通常来说,越早的修复错误,为之付出的代价就越小.强类型语言的编译器可以在编译时检查某些错误.而 JavaScript 是一门弱类型语言,其解释器无法检查类型错误,但实践表明: 1.强类型能够避免的错误并不是那些关键性错误

  • Golang极简入门教程(一):基本概念

    安装 Golang 在 http://golang.org/dl/ 可以下载到 Golang.安装文档:http://golang.org/doc/install. Hello Go 我们先创建一个文件 hello.go: 复制代码 代码如下: package main   import "fmt"   func main() {     fmt.Printf("hello Golang\n"); } 执行此程序: 复制代码 代码如下: go run hello.g

  • Golang极简入门教程(四):编写第一个项目

    workspace Golang 的代码必须放置在一个 workspace 中.一个 workspace 是一个目录,此目录中包含几个子目录: 1.src 目录.包含源文件,源文件被组织为包(一个目录一个包) 2.pkg 目录.包含包对象(package objects) 3.bin 目录.包含可执行的命令 包源文件(package source)被编译为包对象(package object),命令源文件(command source)被编译为可执行命令(command executable).

  • Golang极简入门教程(二):方法和接口

    方法 在 Golang 中没有类,不过我们可以为结构体定义方法.我们看一个例子: 复制代码 代码如下: package main   import (     "fmt"     "math" )   type Vertex struct {     X, Y float64 }   // 结构体 Vertex 的方法 // 这里的方法接收者(method receiver)v 的类型为 *Vertex func (v *Vertex) Abs() float64

  • Nodejs极简入门教程(三):进程

    Node 虽然自身存在多个线程,但是运行在 v8 上的 JavaScript 是单线程的.Node 的 child_process 模块用于创建子进程,我们可以通过子进程充分利用 CPU.范例: 复制代码 代码如下: var fork = require('child_process').fork; // 获取当前机器的 CPU 数量 var cpus = require('os').cpus(); for (var i = 0; i < cpus.length; i++) {     // 生

  • Nodejs极简入门教程(二):定时器

    setTimeout 和 clearTimeout 复制代码 代码如下: var obj = setTimeout(cb, ms); setTimeout 用于设置一个回调函数 cb,其在最少 ms 毫秒后被执行(并非在 ms 毫秒后马上执行).setTimeout 返回值可以作为 clearTimeout 的参数,clearTimeout 用于停止定时器,这样回调函数就不会被执行了. setInterval 和 clearInterval 复制代码 代码如下: var obj = setInt

  • Nodejs极简入门教程(一):模块机制

    JavaScript 规范(ECMAScript)没有定义一套完善的能适用于大多数程序的标准库.CommonJS 提供了一套 JavaScript 标准库规范.Node 实现了 CommonJS 规范. 模块基础 在 Node 中,模块和文件是一一对应的.我们定义一个模块: 复制代码 代码如下: // circle.js var PI = Math.PI;   // 导出函数 area exports.area = function(r) {     return PI * r * r; }  

  • Golang极简入门教程(三):并发支持

    Golang 运行时(runtime)管理了一种轻量级线程,被叫做 goroutine.创建数十万级的 goroutine 是没有问题的.范例: 复制代码 代码如下: package main   import (     "fmt"     "time" )   func say(s string) {     for i := 0; i < 5; i++ {         time.Sleep(100 * time.Millisecond)       

  • JavaScript极简入门教程(三):数组

    阅读本文需要有其他语言的编程经验. 在 JavaScript 中数组是对象(而非线性分配的内存). 通过数组 literal 来创建数组: 复制代码 代码如下: var empty = []; var numbers = [     'zero', 'one', 'two', 'three', 'four',     'five', 'six', 'seven', 'eight', 'nine' ]; empty[1] // undefined numbers[1] // 'one' empty

随机推荐