Understanding Generative Adversarial Networks (GANs) in 50 Lines of Code

This post uses simple PyTorch code to implement a generative adversarial network (GAN) that generates, and learns to discriminate, simple sine waves. To keep it easy to follow, I have stripped out as much non-essential code as possible. Some familiarity with PyTorch is assumed; after working through this small example, you should be able to build a simple GAN of your own. The full code follows (environment: Python 2.7, Jupyter Notebook).

First, we import the required modules:

%matplotlib inline
import torch
import numpy as np
import pylab as pl
from torch import nn
torch.set_default_tensor_type(torch.FloatTensor) # set PyTorch's default tensor type

Next, we define a function that generates real sine-wave data. Here the waves have a fixed period and a random phase; size is the number of data points per sample, and each call returns one batch of training data:

def true_data_gen(size=20, batch_size=5):
    t = np.linspace(0, 1, size).reshape(1, -1).repeat(batch_size, 0)          # time axis, shape (batch_size, size)
    phase = 2 * np.pi * np.random.random(batch_size).reshape(batch_size, -1)  # one random phase per sample
    return torch.from_numpy(np.sin(2 * np.pi * t + phase)).float()            # sine waves, shape (batch_size, size)

Below is one randomly generated batch (batch_size=5, so there are five curves; plotting code: pl.plot(true_data_gen().numpy().T,'.-')):

[image: samples]
Figure 1. Five randomly generated real sine-wave samples.
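For reference, the plotting call quoted above can also be run as its own cell; the shape comments are my own annotations:

real_batch = true_data_gen()          # tensor of shape (5, 20): 5 waveforms of 20 points each
pl.plot(real_batch.numpy().T, '.-')   # transpose so each waveform is drawn as one curve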

Next, we define two neural networks: a generator (Generator) and a discriminator (Discriminator). The code is straightforward, and readers familiar with PyTorch should be able to follow it without much explanation:

class Generator(nn.Module):
    def __init__(self, in_size=5, out_size=20, n_hidden=3, hidden_size=20):
        super(Generator, self).__init__()
        self.input = nn.Linear(in_size, hidden_size)
        self.hidden = nn.Linear(hidden_size, hidden_size)  # a single hidden layer, reused n_hidden times below
        self.output = nn.Linear(hidden_size, out_size)
        self.n_hidden = n_hidden
        self.in_size = in_size
    def forward(self, x):
        x = torch.tanh(self.input(x))
        for n in range(self.n_hidden):
            x = torch.tanh(self.hidden(x))
        x = torch.tanh(self.output(x))  # tanh keeps the output in [-1, 1], matching the range of a sine wave
        return x

class Discriminator(nn.Module):
    def __init__(self, in_size=20, out_size=1, n_hidden=3, hidden_size=20):
        super(Discriminator, self).__init__()
        self.input = nn.Linear(in_size, hidden_size)
        self.hidden = nn.Linear(hidden_size, hidden_size)  # a single hidden layer, reused n_hidden times below
        self.output = nn.Linear(hidden_size, out_size)
        self.n_hidden = n_hidden
    def forward(self, x):
        x = torch.tanh(self.input(x))
        for n in range(self.n_hidden):
            x = torch.tanh(self.hidden(x))
        x = torch.sigmoid(self.output(x))  # sigmoid output: estimated probability that the input is real
        return x
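Before moving on, it can help to confirm the tensor shapes the two networks exchange. The quick check below is my own addition and simply uses the default constructor arguments from above:

g = Generator()
d = Discriminator()
z = torch.randn(5, 5)      # a batch of 5 noise vectors, each of length in_size=5
x = g(z)                   # torch.Size([5, 20]): 5 candidate waveforms of 20 points each
p = d(x)                   # torch.Size([5, 1]): one probability of being real per waveform
print(x.shape, p.shape)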

Next, we instantiate and train a concrete model:

G = Generator(in_size=5, out_size=20, n_hidden=3, hidden_size=20)
D = Discriminator(in_size=20, out_size=1, n_hidden=3, hidden_size=20)

# Generate a batch of fake data from the (not yet trained) Generator.
# detach=True blocks gradients from flowing back into G, which is what we want when
# the fake samples are used to train the Discriminator; the size argument is unused and
# kept only for symmetry with true_data_gen (the sample length is fixed by G's out_size).
def fake_data_gen(size=20, batch_size=5, in_size=5, detach=True):
    if detach:
        return G(torch.randn(batch_size, in_size)).detach()
    else:
        return G(torch.randn(batch_size, in_size))

# Train the model
## Basic hyperparameters
d_learning_rate = 1e-3
g_learning_rate = 1e-3
optim_betas = (0.9, 0.999)
batch_size = 10
num_epochs = 1001
disp_interval = 200
d_steps = 10
g_steps = 5  # in practice, using g_steps < d_steps tends to train better

## Define the loss function
criterion = nn.BCELoss()  # binary cross entropy, see http://pytorch.org/docs/nn.html#bceloss
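# As an illustrative aside (my addition, not in the original post): for a prediction p in (0, 1)
# and a target y in {0, 1}, BCELoss computes -(y*log(p) + (1-y)*log(1-p)), averaged over the
# batch, so the call below returns -(log(0.9) + log(0.8)) / 2, roughly 0.164:
print(criterion(torch.tensor([[0.9], [0.2]]), torch.tensor([[1.], [0.]])))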

## Define two optimizers, one for the Discriminator and one for the Generator
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

fig = pl.figure(1,figsize=(6,4))

## Start training: each epoch alternates between updating the Discriminator and the Generator
for epoch in range(num_epochs):
    # 1. Train the Discriminator
    for d in range(d_steps):
        D.zero_grad()
        # real (positive) samples
        real_sample = true_data_gen(batch_size=batch_size, size=20)
        d_real_err = criterion(D(real_sample), torch.ones(batch_size, 1))  # label 1 means "real data"
        d_real_err.backward()
        # fake (negative) samples
        fake_sample = fake_data_gen(batch_size=batch_size, size=20)
        d_fake_err = criterion(D(fake_sample), torch.zeros(batch_size, 1))  # label 0 means "fake data produced by the Generator"
        d_fake_err.backward()
        # update the Discriminator's parameters
        d_optimizer.step()
    # 2. Train the Generator
    for g in range(g_steps):
        G.zero_grad()
        fake_sample = fake_data_gen(batch_size=batch_size, size=20, detach=False)
        g_err = criterion(D(fake_sample), torch.ones(batch_size, 1))  # treat D as fixed and push G's samples toward being classified as real
        g_err.backward()
        g_optimizer.step()  # update the Generator's parameters
    if epoch % disp_interval == 0:
        print(r'epoch %d' % epoch)
        pl.plot(G(torch.randn(G.in_size)).detach().numpy(),'.-',label=('epoch %d' % epoch))
        pl.legend(loc='best', prop={'size':10})

The results are shown below:

[image: training]
Figure 2. As training proceeds, the samples drawn from the Generator look more and more like the real data (sine waves).

Below are samples drawn at random from the Generator after 1000 training steps (code: pl.plot(G(torch.randn(batch_size, G.in_size)).detach().numpy().T,'.-')):

[image: generated]
Figure 3. Ten sine-wave samples drawn at random from the Generator after 1000 training steps.
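For convenience, the plotting call quoted above can also be written out as a standalone cell; the figure size is my own choice:

pl.figure(figsize=(6, 4))
generated = G(torch.randn(batch_size, G.in_size)).detach().numpy()  # 10 generated waveforms (batch_size=10)
pl.plot(generated.T, '.-')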
