
The Basic Principles of Generative Adversarial Networks (GANs)

When it comes to GANs, the first paper to read is of course Ian Goodfellow's Generative Adversarial Networks (arXiv: https://arxiv.org/abs/1406.2661), the seminal work that founded the field.

The basic principle of a GAN is actually quite simple; let us use image generation as the running example. Suppose we have two networks, G (Generator) and D (Discriminator). As their names suggest, their roles are:

  • G is the generator network: it takes a random noise vector z as input and produces an image from it, denoted G(z).
  • D is the discriminator network: it judges whether an image is "real". Its input x is an image, and its output D(x) is the probability that x is a real image: an output of 1 means the image is certainly real, while 0 means it cannot possibly be real. (A minimal sketch of these two interfaces follows this list.)
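To make the two roles concrete, here is a minimal sketch of the interfaces. The one-layer networks are placeholder stand-ins of my own choosing; the actual MNIST models appear in the implementation section below.

import torch
import torch.nn as nn

latent_size, image_size = 64, 784

# Placeholder networks, just to show input/output shapes
G = nn.Sequential(nn.Linear(latent_size, image_size), nn.Tanh())  # noise -> image
D = nn.Sequential(nn.Linear(image_size, 1), nn.Sigmoid())         # image -> probability

z = torch.randn(16, latent_size)  # a batch of 16 random noise vectors
fake = G(z)                       # G(z): generated "images", shape (16, 784)
p_real = D(fake)                  # D(G(z)): probability each image is real, in (0, 1)
print(fake.shape, p_real.shape)   # torch.Size([16, 784]) torch.Size([16, 1])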

During training, the generator G tries to produce images realistic enough to fool the discriminator D, while D tries to tell G's generated images apart from real ones. G and D thus play a dynamic adversarial game.

What is the outcome of this game? In the ideal case, G can produce images G(z) that pass for real. D then cannot tell whether an image produced by G is real or not, so D(G(z)) = 0.5.
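Why exactly 0.5? For a fixed G, the paper (its Proposition 1) shows the optimal discriminator has a closed form:

$$D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}$$

At the ideal equilibrium the generator's distribution matches the data distribution, $p_g = p_{\text{data}}$, so $D^*(x) = 1/2$ for every input, including $x = G(z)$.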

At that point our goal is achieved: we have obtained a generative model G that can be used to generate images.

The above is only a rough sketch of the core idea. How can it be described mathematically? Quoting the objective directly from the paper:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Let us briefly unpack this formula:

  • The expression consists of two terms. Here x denotes a real image, z denotes the noise fed into the G network, and G(z) denotes the image generated by G.
  • D(x) is the probability the D network assigns to a real image being real (since x is indeed real, D wants this value as close to 1 as possible), while D(G(z)) is the probability D assigns to an image generated by G being real.
  • G's goal: as mentioned above, D(G(z)) is D's estimate that G's output is real, and G wants its generated images to look "as close to real as possible". In other words, G wants D(G(z)) to be as large as possible, which makes V(D, G) smaller. That is why the leading operator for G is min_G.
  • D's goal: the stronger D is, the larger D(x) should be and the smaller D(G(z)) should be, which makes V(D, G) larger. For D the expression is therefore maximized (max_D). (The note after this list shows how these two directions map onto the loss used in the code.)
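Both directions are implemented below through the binary cross-entropy (BCE) loss. With prediction p and label y, BCE(p, y) = −y log p − (1 − y) log(1 − p), so labeling real images 1 and fake images 0 gives the discriminator loss

$$L_D = -\log D(x) - \log\bigl(1 - D(G(z))\bigr)$$

and minimizing $L_D$ over D is exactly maximizing the two terms of $V(D, G)$.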

Figure 1 of the paper depicts this process well: over the course of training, the generator's distribution is gradually pulled toward the data distribution, while the discriminator's output flattens toward 1/2.

So how are D and G actually trained with stochastic gradient descent? The paper gives the procedure as Algorithm 1:
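Paraphrasing Algorithm 1 from the paper (minibatch SGD training; the number of discriminator steps k per iteration is a hyperparameter, set to 1 in the paper's experiments). Each training iteration:

  • For k steps: sample m noise vectors $z^{(1)}, \dots, z^{(m)}$ and m real examples $x^{(1)}, \dots, x^{(m)}$, and update D by ascending its stochastic gradient:

$$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \left[\log D(x^{(i)}) + \log\left(1 - D(G(z^{(i)}))\right)\right]$$

  • Sample a fresh minibatch of m noise vectors and update G by descending its stochastic gradient:

$$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log\left(1 - D(G(z^{(i)}))\right)$$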

The parts to pay extra attention to are the two gradient updates. In the first step we train D; D wants V(D, G) to be as large as possible, so its gradient is added (ascending). In the second step we train G; G wants V(D, G) to be as small as possible, so its gradient is subtracted (descending). The two updates alternate throughout training.
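One practical detail worth flagging before the code: early in training, when D rejects fakes easily, the gradient of $\log(1 - D(G(z)))$ is close to zero and G learns slowly. The paper therefore suggests (last paragraph of its section 3), and the implementation below follows, the "non-saturating" variant: train G to maximize $\log D(G(z))$, i.e. minimize

$$L_G = -\log D(G(z))$$

which is exactly binary cross-entropy with "real" labels applied to fake images.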

A PyTorch Implementation of a GAN


 
 
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

# Create the output directory if it does not exist
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),  # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],  # 1 for greyscale channels
                                     std=[0.5])])

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

# Generator
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross-entropy loss and optimizers
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):
    # Map images from [-1, 1] back to [0, 1] for saving
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)

        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================ #
        #                      Train the discriminator                     #
        # ================================================================ #

        # Compute BCE loss using real images, where
        # BCE_Loss(x, y) = -y * log(D(x)) - (1-y) * log(1 - D(x)).
        # The second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Compute BCE loss using fake images.
        # The first term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================================================================ #
        #                        Train the generator                       #
        # ================================================================ #

        # Compute the loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z))).
        # For the reason, see the last paragraph of section 3 of
        # https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save real images once
    if (epoch + 1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled images every epoch
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))

# Save the model checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
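Once training finishes, the saved checkpoint can be reused without rerunning the script. Here is a minimal sampling sketch, assuming 'G.ckpt' was produced by the code above; the generator must be rebuilt with the same architecture before loading the weights, and the output filename is my own choice.

import torch
import torch.nn as nn
from torchvision.utils import save_image

latent_size, hidden_size, image_size = 64, 256, 784

# Rebuild the generator exactly as in the training script
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh())
G.load_state_dict(torch.load('G.ckpt'))
G.eval()

with torch.no_grad():
    z = torch.randn(64, latent_size)           # 64 fresh noise vectors
    samples = G(z).reshape(-1, 1, 28, 28)      # back to image shape
    samples = ((samples + 1) / 2).clamp(0, 1)  # undo the [-1, 1] normalization
    save_image(samples, 'generated.png', nrow=8)  # an 8x8 grid of digits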