Pytorch实现WGAN用于动漫头像生成

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset

batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'

# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)

def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out

dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

class Generator(nn.Module):
  def __init__(self):
    super().__init__()

    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )

  def forward(self, x):
    x = self.gen(x)
    return x

  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)

class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32

      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16

      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8

      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4

      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )

  def forward(self, x):
    x = self.dis(x)
    return x

  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)

def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")

import QuickModelBuilder as builder

if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one

  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()

  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()

  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)

  fake_img = None

  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假

      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)

      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)

      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数

      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)

      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False

      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0

      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数

      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch 保存模型生成图片方式

    三通道数组转成彩色图片 img=np.array(img1) img=img.reshape(3,img1.shape[2],img1.shape[3]) img=(img+0.5)*255##img做过归一化处理,[-0.5,0.5] img_path='/home/isee/wei/image/imageset/result.jpg' img=cv2.merge(img) cv2.imwrite(img_path,img) 单通道数组转化成灰度图 Img_mask=np.array(img_

  • Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

    CGAN的全拼是Conditional Generative Adversarial Networks,条件生成对抗网络,在初始GAN的基础上增加了图片的相应信息. 这里用传统的卷积方式实现CGAN. import torch from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision import transforms from torch import opti

  • pytorch::Dataloader中的迭代器和生成器应用详解

    在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader. 为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器. 这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念. 本文的内容主要有: 解释python中的迭代器和生成器概念 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数

  • pytorch GAN生成对抗网络实例

    我就废话不多说了,直接上代码吧! import torch import torch.nn as nn from torch.autograd import Variable import numpy as np import matplotlib.pyplot as plt torch.manual_seed(1) np.random.seed(1) BATCH_SIZE = 64 LR_G = 0.0001 LR_D = 0.0001 N_IDEAS = 5 ART_COMPONENTS =

  • Pytorch实现基于CharRNN的文本分类与生成示例

    1 简介 本篇主要介绍使用pytorch实现基于CharRNN来进行文本分类与内容生成所需要的相关知识,并最终给出完整的实现代码. 2 相关API的说明 pytorch框架中每种网络模型都有构造函数,在构造函数中定义模型的静态参数,这些参数将对模型所包含weights参数的维度进行设置.在运行时,模型的实例将接收动态的tensor数据并调用forword,在得到模型输出之后便可以和真实的标签数据进行误差计算,并通过优化器进行反向传播以调整模型的参数.下面重点介绍NLP常用到的模型和相关方法. 2

  • PyTorch 随机数生成占用 CPU 过高的解决方法

    PyTorch 随机数生成占用 CPU 过高的问题 今天在使用 pytorch 的过程中,发现 CPU 占用率过高.经过检查,发现是因为先在 CPU 中生成了随机数,然后再调用.to(device)传到 GPU,这样导致效率变得很低,并且CPU 和 GPU 都被消耗. 查阅PyTorch文档后发现,torch.randn(shape, out)可以直接在GPU中生成随机数,只要shape是tensor.cuda.Tensor类型即可.这样,就可以避免在 CPU 中生成过大的矩阵,而 shape

  • Pytorch实现WGAN用于动漫头像生成

    WGAN与GAN的不同 去除sigmoid 使用具有动量的优化方法,比如使用RMSProp 要对Discriminator的权重做修整限制以确保lipschitz连续约 WGAN实战卷积生成动漫头像 import torch import torch.nn as nn import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.utils import s

  • Python人工智能学习PyTorch实现WGAN示例详解

    目录 1.GAN简述 2.生成器模块 3.判别器模块 4.数据生成模块 5.判别器训练 6.生成器训练 7.结果可视化 1.GAN简述 在GAN中,有两个模型,一个是生成模型,用于生成样本,一个是判别模型,用于判断样本是真还是假.但由于在GAN中,使用的JS散度去计算损失值,很容易导致梯度弥散的情况,从而无法进行梯度下降更新参数,于是在WGAN中,引入了Wasserstein Distance,使得训练变得稳定.本文中我们以服从高斯分布的数据作为样本. 2.生成器模块 这里从2维数据,最终生成2

  • Python 将 QQ 好友头像生成祝福语的实现代码

    本文我们来看一下如何使用 Python 将 QQ 好友头像拼成"五一快乐"四个字.我们可以将整个实现过程分为两步:爬取 QQ 好友头像.利用好友头像生成文字. 爬取头像 爬取 QQ 好友头像我们需要借助于 QQ 邮箱,首先我们从浏览器上登录 QQ 邮箱,之后按 F12 键打开开发者工具并用鼠标选中 Network 选项,如下图所示: 再接着我们按 F5 键刷新一下网页,然后在 Filter 中输入 laddr_lastlist ,如下图所示: 我们再点 Name 下的链接,点击之后右侧

  • Pytorch保存模型用于测试和用于继续训练的区别详解

    保存模型 保存模型仅仅是为了测试的时候,只需要 torch.save(model.state_dict, path) path 为保存的路径 但是有时候模型及数据太多,难以一次性训练完的时候,而且用的还是 Adam优化器的时候, 一定要保存好训练的优化器参数以及epoch state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch } torch.save(state, pat

  • .NET Core利用skiasharp文字头像生成方法教程(基于docker发布)

    一.问题背景 目前.NET Core下面针对于图像处理的库微软并没有集成,在.NET FrameWork下我们已经习惯使用System.Drawing类库做简单的图像处理,到了.NET Core下一脸懵逼的我,只能百度+谷歌看看有没啥解决方案,好在网上资料也多,.NET Core下的图像处理还是有些开源库的,我目前使用的其中一个:SkiaSharp,介绍反正大家自己网上找找都有,下面就用该库实现一个文字头像的小功能,话不多说了,来一起看看详细的介绍吧. 二.简单的设计要求 对于输入的名字得解析(

  • Python生成截图选餐GIF动画

    目录 python生成文字动图 下载表情图片到本地 分析动图 生成单张图片 爬取菜品数据 生成菜名动图 PIL操作gif的其他操作 Gif拆分 GIF倒放 之前群里有小伙伴问今天中午该吃什么,然后另一位小伙伴发了一张下面的动图: 我个人觉得还挺有意思的,截图还真像抽奖一样随机选一个菜名.考虑到这张动图中的菜名候选并不见得都是我们能够吃的菜.我们可以用python根据菜名列表生成这样的动图玩玩. 之前还看到什么截图选头像之类的动图,那类通过图片生成的动图都比较简单,通过文中提到的Imagine的动

  • pytorch下tensorboard的使用程序示例

    目录 一.tensorboard程序实例: 1.代码 2.在命令提示符中操作 3.在浏览器中打开网址 4.效果 二.writer.add_scalar()与writer.add_scalars()参数说明 1.概述 2.参数说明 3.writer.add_scalar()效果 4.writer.add_scalars()效果 我们都知道tensorflow框架可以使用tensorboard这一高级的可视化的工具,为了使用tensorboard这一套完美的可视化工具,未免可以将其应用到Pytorc

  • 基于Python实现在线二维码生成工具

    目录 1.环境搭建 2.二维码生成功能的封装 3.网页应用的搭建 在今天的教程中,费老师我将为大家展示如何通过纯Python编程的方式,开发出一个网页应用,从而帮助用户直接通过浏览器访问,即可基于输入的网址等文字内容,完成常规二维码.静态底图二维码以及动图底图二维码的快捷生成,先来看一看应用的主要功能操作演示: 只写Python开发这样精致的工具应用非常简单,下面我来带大家从搭建环境开始,学习整个过程: 1.环境搭建 首先我们来创建应用的虚拟开发环境,建议使用Conda,命令如下: 创建虚拟环境

  • Java OpenSSL生成的RSA公私钥进行数据加解密详细介绍

    Java中使用OpenSSL生成的RSA公私钥进行数据加解密 RSA是什么:RSA公钥加密算法是1977年由Ron Rivest.Adi Shamirh和LenAdleman在(美国麻省理工学院)开发的.RSA取名来自开发他们三者的名字.RSA是目前最有影响力的公钥加密算法,它能够抵抗到目前为止已知的所有密码攻击,已被ISO推荐为公钥数据加密标准.目前该加密方式广泛用于网上银行.数字签名等场合.RSA算法基于一个十分简单的数论事实:将两个大素数相乘十分容易,但那时想要对其乘积进行因式分解却极其困

随机推荐