2020/05/13

# Implementing a Conditional GAN with Jittor

## Loss Functions

### The GAN Loss Function

In a GAN [1], the generator G maps random noise to synthetic samples, while the discriminator D learns to tell real samples from generated ones. D and G are trained jointly in this adversarial fashion, with the goal that G's generative ability and D's discriminative ability both grow ever stronger.
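Concretely, the minimax objective from Goodfellow et al. [1], and the conditional variant from Mirza and Osindero [2] in which both networks additionally receive the class label $y$, can be written as:

```latex
% GAN minimax objective [1]
\min_G \max_D \; V(D, G) =
    \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]

% Conditional GAN objective [2]: condition both D and G on the label y
\min_G \max_D \; V(D, G) =
    \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```

In the code below, the conditioning on $y$ is implemented by concatenating a learned label embedding to the inputs of both networks.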

## Generating Digits with Jittor

```python
import jittor as jt
from jittor import nn
import numpy as np
import pylab as pl

%matplotlib inline

# Length of the latent vector
latent_dim = 100
# Number of classes
n_classes = 10
# Image size
img_size = 32
# Number of image channels
channels = 1
# Shape of the image tensor
img_shape = (channels, img_size, img_size)
```

```python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Embed each class label into an n_classes-dimensional vector
        self.label_emb = nn.Embedding(n_classes, n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh())

    def execute(self, noise, labels):
        # Concatenate the label embedding with the noise vector
        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
        img = self.model(gen_input)
        img = img.view((img.shape[0], *img_shape))
        return img
```
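As a shape sanity check (a NumPy-only sketch, independent of Jittor): the generator's input is the label embedding concatenated with the noise vector, and its final linear layer emits `np.prod(img_shape)` values that are then reshaped into an image:

```python
import numpy as np

latent_dim, n_classes = 100, 10
img_shape = (1, 32, 32)
batch_size = 4

# Stand-ins for the label embedding and the noise vector
emb = np.zeros((batch_size, n_classes), dtype=np.float32)
noise = np.zeros((batch_size, latent_dim), dtype=np.float32)

# Concatenation along dim=1, as in Generator.execute
gen_input = np.concatenate([emb, noise], axis=1)
assert gen_input.shape == (batch_size, latent_dim + n_classes)  # (4, 110)

# The final Linear layer produces a flat image of np.prod(img_shape) values
flat = np.zeros((batch_size, int(np.prod(img_shape))))  # (4, 1024)
img = flat.reshape((batch_size, *img_shape))
assert img.shape == (4, 1, 32, 32)
```

This is why the first `block` takes `latent_dim + n_classes` input features and the last `Linear` has `int(np.prod(img_shape))` outputs.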

```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(n_classes, n_classes)
        self.model = nn.Sequential(
            nn.Linear(n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1))

    def execute(self, img, labels):
        # Concatenate the flattened image with the label embedding
        d_in = jt.contrib.concat((img.view((img.shape[0], -1)), self.label_embedding(labels)), dim=1)
        validity = self.model(d_in)
        return validity
```

```python
# Download the provided pretrained parameters
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl
```

```python
# Define the models
generator = Generator()
discriminator = Discriminator()
generator.eval()
discriminator.eval()

# Load the pretrained parameters
generator.load('./generator_last.pkl')
discriminator.load('./discriminator_last.pkl')

# A string of digits to generate
number = "201962517"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)

# Stitch the generated digits into one horizontal strip and display them
pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)))
```
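The transpose/reshape in the `imshow` call is what stitches the `n_row` generated digits into one horizontal strip. A NumPy sketch with tiny 4×4 "images" shows why, assuming the same `(N, C, H, W)` layout:

```python
import numpy as np

# Three fake 1-channel 4x4 "images", filled with 0, 1, 2 respectively
batch = np.stack([np.full((1, 4, 4), i, dtype=np.float32) for i in range(3)])

# (N, C, H, W) -> (C, H, N, W); take channel 0; flatten (N, W) within each row
strip = batch.transpose((1, 2, 0, 3))[0].reshape((batch.shape[2], -1))

# Row h of the strip holds row h of image 0, then image 1, then image 2
assert strip.shape == (4, 12)
assert (strip[:, 0:4] == 0).all()
assert (strip[:, 4:8] == 1).all()
assert (strip[:, 8:12] == 2).all()
```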

## Training a Conditional GAN from Scratch

```python
# Download the training script
!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py
!python3.7 ./cgan.py --help

# Pick a suitable batch size and try running it, e.g.:
# !python3.7 ./cgan.py --batch_size 8
```

```python
# This snippet only illustrates the training setup and is not runnable on its
# own; to actually train, run the complete file cgan.py.

# Define the loss: cgan.py uses a least-squares adversarial loss
adversarial_loss = nn.MSELoss()

# Define the models
generator = Generator()
discriminator = Discriminator()

# MNIST dataset and preprocessing
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
    transform.Resize(opt.img_size),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

# One Adam optimizer per network
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
```
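The full cgan.py trains with a least-squares adversarial loss, i.e. the mean squared error between the discriminator's score and a target of 1 ("valid") or 0 ("fake"); this follows the PyTorch-GAN reference implementation it is ported from, so treat the exact loss as an assumption here. A NumPy-only sketch of that reduction:

```python
import numpy as np

def mse_loss(pred, target):
    # Least-squares (LSGAN-style) adversarial loss: mean squared residual
    return float(np.mean((pred - target) ** 2))

# Discriminator scores for three generated images, against the "valid" target
validity = np.array([0.9, 0.2, 0.7], dtype=np.float32)
valid = np.ones_like(validity)

# Minimizing this pushes G to make D output values near 1
g_loss = mse_loss(validity, valid)
assert abs(g_loss - (0.01 + 0.64 + 0.09) / 3) < 1e-5
```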

```python
# This snippet only illustrates the training loop and is not runnable on its
# own; to actually train, run the complete file cgan.py.

# Target labels: valid means "real", fake means "generated"
valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()

# Real images and their labels
real_imgs = jt.array(imgs)
labels = jt.array(labels)

#########################################################
#   Train the generator G
#       - G wants its generated images to be judged valid by D
#########################################################

# Random noise z and randomly drawn labels
z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()

# Generate images from z and the random labels. G is optimized so that the
# discriminator judges the generated images as consistent with the given
# labels, i.e. as real.
gen_imgs = generator(z, gen_labels)
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)

#########################################################
#   Train the discriminator D
#       - recognize real_imgs as valid
#       - recognize gen_imgs as fake
#########################################################

# Real images with their labels should be judged close to valid
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)

# Images generated by G, with their labels, should be judged close to fake;
# stop_grad() keeps this update from flowing back into the generator
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)

d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
optimizer_D.step(d_loss)
```

## 参考文献

[1] Goodfellow, Ian, et al. “Generative adversarial nets.” Advances in neural information processing systems. 2014.

[2] Mirza, Mehdi, and Simon Osindero. “Conditional generative adversarial nets.” arXiv preprint arXiv:1411.1784 (2014).