SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient (AAAI 2017)
low (Bachman and Precup 2015; Bahdanau et al. 2016) and consider the sequence generation procedure as a sequential decision making process. The generative model is treated as an agent of reinforcement learning (RL); the state is the generated tokens so far and the action is the next token to be generated. Unlike the work in (Bahdanau et al. 2016), which requires a task-specific sequence score, such as BLEU in machine translation, to give the reward, we employ a discriminator to evaluate the sequence and feed the evaluation back to guide the learning of the generative model. To solve the problem that the gradient cannot pass back to the generative model when the output is discrete, we regard the generative model as a stochastic parametrized policy. In our policy gradient, we employ Monte Carlo (MC) search to approximate the state-action value. We directly train the policy (generative model) via policy gradient (Sutton et al. 1999), which naturally avoids the differentiation difficulty for discrete data in a conventional GAN.

Extensive experiments based on synthetic and real data are conducted to investigate the efficacy and properties of the proposed SeqGAN. In our synthetic data environment, SeqGAN significantly outperforms the maximum likelihood methods, scheduled sampling and PG-BLEU. In three real-world tasks, i.e. poem generation, speech language generation and music generation, SeqGAN significantly outperforms the compared baselines in various metrics including human expert judgement.

Related Work

Deep generative models have recently drawn significant attention, and the ability of learning over large (unlabeled) data endows them with more potential and vitality (Salakhutdinov 2009; Bengio et al. 2013). (Hinton, Osindero, and Teh 2006) first proposed to use the contrastive divergence algorithm to efficiently train deep belief nets (DBN). (Bengio et al. 2013) proposed the denoising autoencoder (DAE), which learns the data distribution in a supervised learning fashion. Both DBN and DAE learn a low-dimensional representation (encoding) for each data instance and generate it from a decoding network. Recently, the variational autoencoder (VAE), which combines deep learning with statistical inference, was proposed to represent a data instance in a latent hidden space (Kingma and Welling 2014), while still utilizing (deep) neural networks for non-linear mapping. The inference is done via variational methods. All these generative models are trained by maximizing (the lower bound of) the training data likelihood, which, as mentioned by (Goodfellow and others 2014), suffers from the difficulty of approximating intractable probabilistic computations.

(Goodfellow and others 2014) proposed an alternative training methodology for generative models, i.e. GANs, where the training procedure is a minimax game between a generative model and a discriminative model. This framework bypasses the difficulty of maximum likelihood learning and has gained striking successes in natural image generation (Denton et al. 2015). However, little progress has been made in applying GANs to discrete sequence generation problems, e.g. natural language generation (Huszár 2015). This is because the generator network in a GAN is designed to adjust its output continuously, which does not work on discrete data generation (Goodfellow 2016).

On the other hand, a lot of effort has been made to generate structured sequences. Recurrent neural networks can be trained to produce sequences of tokens in many applications such as machine translation (Sutskever, Vinyals, and Le 2014; Bahdanau, Cho, and Bengio 2014). The most popular way of training RNNs is to maximize the likelihood of each token in the training data, whereas (Bengio et al. 2015) pointed out that the discrepancy between training and generating makes maximum likelihood estimation suboptimal and proposed the scheduled sampling (SS) strategy. Later, (Huszár 2015) theorized that the objective function underlying SS is improper and explained theoretically why GANs tend to generate natural-looking samples. Consequently, GANs have great potential but are not yet practically feasible for discrete probabilistic models.

As pointed out by (Bachman and Precup 2015), sequence data generation can be formulated as a sequential decision making process, which can potentially be solved by reinforcement learning techniques. Modeling the sequence generator as a policy of picking the next token, policy gradient methods (Sutton et al. 1999) can be adopted to optimize the generator once there is an (implicit) reward function to guide the policy. For most practical sequence generation tasks, e.g. machine translation (Sutskever, Vinyals, and Le 2014), the reward signal is meaningful only for the entire sequence; for instance, in the game of Go (Silver et al. 2016), the reward signal is only given at the end of the game. In such cases, state-action evaluation methods such as Monte Carlo (tree) search have been adopted (Browne et al. 2012). By contrast, our proposed SeqGAN extends GANs with an RL-based generator to solve the sequence generation problem, where a reward signal is provided by the discriminator at the end of each episode via Monte Carlo search, and the generator picks actions and learns the policy using the estimated overall rewards.

Sequence Generative Adversarial Nets

The sequence generation problem is denoted as follows. Given a dataset of real-world structured sequences, train a θ-parameterized generative model Gθ to produce a sequence Y1:T = (y1, . . . , yt, . . . , yT), yt ∈ Y, where Y is the vocabulary of candidate tokens. We interpret this problem based on reinforcement learning. At timestep t, the state s is the currently produced tokens (y1, . . . , yt−1) and the action a is the next token yt to select. Thus the policy model Gθ(yt | Y1:t−1) is stochastic, whereas the state transition is deterministic after an action has been chosen, i.e. δ^a_{s,s'} = 1 for the next state s' = Y1:t if the current state is s = Y1:t−1 and the action is a = yt; for any other next state s'', δ^a_{s,s''} = 0.
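To make this formulation concrete, here is a small illustrative Python sketch (not from the paper) of sequence generation as a sequential decision process: the state is the prefix generated so far, the action is the next token sampled from the stochastic policy, and the transition to the next state is deterministic. The uniform policy stub and the tiny vocabulary are assumptions made purely for illustration.

```python
import random

VOCAB = ["a", "b", "c", "<eos>"]  # illustrative token vocabulary Y
T = 8                             # maximum sequence length

def G_theta(prefix):
    """Stochastic policy G_theta(y_t | Y_{1:t-1}).
    Here just a uniform stub; in SeqGAN this is an RNN."""
    return {y: 1.0 / len(VOCAB) for y in VOCAB}

def rollout(policy, max_len=T):
    state = []                             # state s = tokens produced so far
    for t in range(max_len):
        probs = policy(state)              # distribution over next actions
        action = random.choices(list(probs), weights=probs.values())[0]
        state = state + [action]           # deterministic transition s -> Y_{1:t}
        if action == "<eos>":
            break
    return state

print(rollout(G_theta))
```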
Additionally, we train a φ-parameterized discriminative model Dφ (Goodfellow and others 2014) to provide guidance for improving the generator Gθ. Dφ(Y1:T) is a probability indicating how likely a sequence Y1:T is to come from real sequence data. As illustrated in Figure 1, the discriminative model Dφ is trained by providing positive examples from the real sequence data and negative examples from the
synthetic sequences generated from the generative model Gθ. At the same time, the generative model Gθ is updated by employing a policy gradient and MC search on the basis of the expected end reward received from the discriminative model Dφ. The reward is estimated by the likelihood that it would fool the discriminative model Dφ. The specific formulation is given in the next subsection.

Figure 1: The illustration of SeqGAN. Left: D is trained over the real data and the generated data by G. Right: G is trained by policy gradient where the final reward signal is provided by D and is passed back to the intermediate action value via Monte Carlo search.

We apply a Monte Carlo search with the roll-out policy Gβ to complete a partial sequence, producing N full sequences MC^{Gβ}(Y1:t; N) = {Y^1_{1:T}, . . . , Y^N_{1:T}}, where Y^n_{1:t} = (y1, . . . , yt) and Y^n_{t+1:T} is sampled based on the roll-out policy Gβ and the current state. In our experiment, Gβ is set the same as the generator, but one can use a simplified version if speed is the priority (Silver et al. 2016). To reduce the variance and get a more accurate assessment of the action value, we run the roll-out policy starting from the current state till the end of the sequence N times to get a batch of output samples. Thus, we have:

Q^{Gθ}_{Dφ}(s = Y1:t−1, a = yt) =
    (1/N) Σ_{n=1}^{N} Dφ(Y^n_{1:T}), Y^n_{1:T} ∈ MC^{Gβ}(Y1:t; N)   for t < T
    Dφ(Y1:t)                                                        for t = T,    (4)

where we see that, when there is no intermediate reward, the function is iteratively defined as the next-state value starting from state s' = Y1:t and rolling out to the end.
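A minimal numpy sketch of this Monte Carlo estimate of the action value in Eq. (4). The uniform roll-out policy and the toy discriminator below are placeholder stand-ins for Gβ and Dφ, not the paper's LSTM and CNN.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB_SIZE, T, N = 10, 20, 16   # vocabulary size, sequence length, number of roll-outs

def rollout_policy(prefix, T):
    """G_beta: complete a prefix up to length T (here: uniform random tokens)."""
    tail = rng.integers(0, VOCAB_SIZE, size=T - len(prefix)).tolist()
    return prefix + tail

def discriminator(sequence):
    """D_phi: probability that a *complete* sequence is real (toy stand-in)."""
    return float(np.mean(np.asarray(sequence) < VOCAB_SIZE // 2))

def q_value(prefix_with_action, T, N):
    """Eq. (4): Q(s=Y_{1:t-1}, a=y_t) for the prefix Y_{1:t} that already includes y_t."""
    t = len(prefix_with_action)
    if t == T:                                    # finished sequence: use D directly
        return discriminator(prefix_with_action)
    samples = [rollout_policy(prefix_with_action, T) for _ in range(N)]
    return float(np.mean([discriminator(Y) for Y in samples]))

prefix = [1, 7, 3]                                # Y_{1:t} = (y_1, ..., y_t)
print(q_value(prefix, T, N))
```

A larger N lowers the variance of the estimate at the cost of N extra roll-outs per timestep.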
A benefit of using the discriminator Dφ as a reward function is that it can be dynamically updated to further improve the generative model iteratively. Once we have a set of more realistic generated sequences, we shall re-train the discriminator model as follows:

min_φ −E_{Y∼p_data}[log Dφ(Y)] − E_{Y∼Gθ}[log(1 − Dφ(Y))].    (5)
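Eq. (5) is a standard binary cross-entropy objective over real and generated sequences; the following small numpy sketch (with made-up discriminator outputs) shows the quantity being minimized.

```python
import numpy as np

def discriminator_loss(d_real, d_fake, eps=1e-12):
    """Eq. (5): -E_{Y~p_data}[log D(Y)] - E_{Y~G_theta}[log(1 - D(Y))],
    with the expectations replaced by minibatch averages."""
    d_real = np.clip(np.asarray(d_real), eps, 1 - eps)   # D_phi(Y) on real sequences
    d_fake = np.clip(np.asarray(d_fake), eps, 1 - eps)   # D_phi(Y) on generated ones
    return float(-np.mean(np.log(d_real)) - np.mean(np.log(1.0 - d_fake)))

# Hypothetical discriminator outputs for a batch of real and generated sequences.
print(discriminator_loss(d_real=[0.9, 0.8, 0.95], d_fake=[0.2, 0.4, 0.1]))
```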
Here RT denotes the reward for a complete sequence. The generator's parameters θ are then updated by following the policy gradient of its expected end reward,

∇θ J(θ) = E_{Y1:t−1∼Gθ}[ ∇θ Gθ(yt | Y1:t−1) · Q^{Gθ}_{Dφ}(Y1:t−1, yt) ].
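As a hedged illustration of such an update, the sketch below applies a REINFORCE-style step (Sutton et al. 1999; Williams 1992) with the log-derivative trick on a toy position-wise softmax policy; the tabular parameterization and the placeholder Q estimate are simplifications, not the paper's LSTM generator.

```python
import numpy as np

rng = np.random.default_rng(0)
V, T, LR = 5, 6, 0.1          # vocabulary size, sequence length, learning rate
theta = np.zeros((T, V))      # toy position-wise softmax policy (stands in for G_theta)

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def sample_sequence():
    return [int(rng.choice(V, p=softmax(theta[t]))) for t in range(T)]

def q_hat(prefix):
    """Placeholder for the MC estimate of Q^{G_theta}_{D_phi}(Y_{1:t-1}, y_t)."""
    return float(np.mean(prefix)) / V

def policy_gradient_step(Y):
    """REINFORCE update: theta <- theta + lr * Q * grad_theta log G_theta(y_t | Y_{1:t-1})."""
    for t, y_t in enumerate(Y):
        probs = softmax(theta[t])
        grad_log = -probs
        grad_log[y_t] += 1.0                   # d log softmax / d logits
        theta[t] += LR * q_hat(Y[: t + 1]) * grad_log

policy_gradient_step(sample_sequence())
print(theta)
```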
Algorithm 1 Sequence Generative Adversarial Nets
Require: generator policy Gθ; roll-out policy Gβ; discriminator Dφ; a sequence dataset S = {X1:T}
1: Initialize Gθ, Dφ with random weights θ, φ.
2: Pre-train Gθ using MLE on S
3: β ← θ
4: Generate negative samples using Gθ for training Dφ
5: Pre-train Dφ via minimizing the cross entropy
6: repeat
7:   for g-steps do
8:     Generate a sequence Y1:T = (y1, . . . , yT) ∼ Gθ
9:     for t in 1 : T do
10:      Compute Q(a = yt; s = Y1:t−1) by Eq. (4)
11:    end for
12:    Update generator parameters via policy gradient Eq. (8)
13:  end for
14:  for d-steps do
15:    Use current Gθ to generate negative examples and combine with given positive examples S
16:    Train discriminator Dφ for k epochs by Eq. (5)
17:  end for
18:  β ← θ
19: until SeqGAN converges
In summary, Algorithm 1 shows the full details of the proposed SeqGAN. At the beginning of training, we use maximum likelihood estimation (MLE) to pre-train Gθ on the training set S. We found that the supervised signal from the pre-trained discriminator is informative for adjusting the generator efficiently.

After the pre-training, the generator and discriminator are trained alternately. As the generator progresses through the g-step updates, the discriminator needs to be re-trained periodically to keep pace with the generator. When training the discriminator, positive examples come from the given dataset S, whereas negative examples are generated by our generator. In order to keep the balance, the number of negative examples we generate for each d-step is the same as the number of positive examples. And to reduce the variability of the estimation, we use different sets of negative samples combined with positive ones, which is similar to bootstrapping (Quinlan 1996).
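A high-level Python sketch of this alternating schedule, mirroring Algorithm 1 with the model-specific pieces left as stubs. The function names (pretrain_mle, generate, q_values, policy_gradient_update, train_discriminator) are hypothetical placeholders, not the released implementation.

```python
import copy
import random

def pretrain_mle(generator, data):            # stub: MLE pre-training of G_theta on S
    return generator

def generate(generator, n):                   # stub: sample n sequences from G_theta
    return [[random.randint(0, 9) for _ in range(8)] for _ in range(n)]

def q_values(sequence, rollout_gen, discriminator):   # stub: Eq. (4) per timestep
    return [random.random() for _ in sequence]

def policy_gradient_update(generator, sequence, q):   # stub: policy gradient step
    return generator

def train_discriminator(discriminator, pos, neg, epochs):  # stub: Eq. (5) for k epochs
    return discriminator

def seqgan_train(data, generator, discriminator,
                 g_steps=1, d_steps=1, k=3, outer_iters=5):
    generator = pretrain_mle(generator, data)            # lines 2-5 of Algorithm 1
    rollout_gen = copy.deepcopy(generator)                # beta <- theta
    discriminator = train_discriminator(discriminator, data,
                                        generate(generator, len(data)), k)
    for _ in range(outer_iters):                          # repeat ... until convergence
        for _ in range(g_steps):                          # g-steps: improve the generator
            Y = generate(generator, 1)[0]
            q = q_values(Y, rollout_gen, discriminator)
            generator = policy_gradient_update(generator, Y, q)
        for _ in range(d_steps):                          # d-steps: refresh the discriminator
            negatives = generate(generator, len(data))
            discriminator = train_discriminator(discriminator, data, negatives, k)
        rollout_gen = copy.deepcopy(generator)            # beta <- theta
    return generator, discriminator

seqgan_train(data=[[1, 2, 3]] * 4, generator={}, discriminator={})
```

The balance between g_steps, d_steps and k controls how fast the two players move relative to each other, which is the knob varied in the convergence plots shown later.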
The Generative Model for Sequences

We use recurrent neural networks (RNNs) (Hochreiter and Schmidhuber 1997) as the generative model. An RNN maps the input embedding representations x1, . . . , xT of the sequence x1, . . . , xT into a sequence of hidden states h1, . . . , hT by applying the update function g recursively:

ht = g(ht−1, xt).    (9)

Moreover, a softmax output layer z maps the hidden states into the output token distribution

p(yt | x1, . . . , xt) = z(ht) = softmax(c + V ht),    (10)

where the parameters are a bias vector c and a weight matrix V. To deal with the common vanishing and exploding gradient problem (Goodfellow, Bengio, and Courville 2016) of backpropagation through time, we leverage Long Short-Term Memory (LSTM) cells (Hochreiter and Schmidhuber 1997) to implement the update function g in Eq. (9). It is worth noticing that most RNN variants, such as the gated recurrent unit (GRU) (Cho et al. 2014) and the soft attention mechanism (Bahdanau, Cho, and Bengio 2014), can be used as the generator in SeqGAN.
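A numpy sketch of Eqs. (9)-(10), with a plain tanh RNN cell standing in for the LSTM update g; the dimensions, initialization and sampling loop are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB_SIZE, EMB_DIM, HID_DIM = 20, 8, 16

emb = rng.normal(0, 0.1, (VOCAB_SIZE, EMB_DIM))   # token embeddings x_t
W_h = rng.normal(0, 0.1, (HID_DIM, HID_DIM))
W_x = rng.normal(0, 0.1, (HID_DIM, EMB_DIM))
V = rng.normal(0, 0.1, (VOCAB_SIZE, HID_DIM))     # output weight matrix V
c = np.zeros(VOCAB_SIZE)                          # output bias vector c

def g(h_prev, x_t):
    """Eq. (9): h_t = g(h_{t-1}, x_t); a tanh cell instead of the LSTM used in the paper."""
    return np.tanh(W_h @ h_prev + W_x @ x_t)

def softmax(z):
    z = np.exp(z - z.max())
    return z / z.sum()

def sample_sequence(T=10):
    h, y, tokens = np.zeros(HID_DIM), 0, []
    for _ in range(T):
        h = g(h, emb[y])                          # update hidden state
        p = softmax(c + V @ h)                    # Eq. (10): token distribution z(h_t)
        y = int(rng.choice(VOCAB_SIZE, p=p))      # stochastic policy G_theta(y_t | Y_{1:t-1})
        tokens.append(y)
    return tokens

print(sample_sequence())
```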
The Discriminative Model for Sequences

Deep discriminative models such as the deep neural network (DNN) (Veselỳ et al. 2013), the convolutional neural network (CNN) (Kim 2014) and the recurrent convolutional neural network (RCNN) (Lai et al. 2015) have shown high performance in complicated sequence classification tasks. In this paper, we choose the CNN as our discriminator, as CNNs have recently been shown to be of great effectiveness in text (token sequence) classification (Zhang and LeCun 2015). Most discriminative models can only perform classification well for an entire sequence rather than an unfinished one. In this paper, we also focus on the situation where the discriminator predicts the probability that a finished sequence is real.²

² In our work, the generated sequence has a fixed length T, but note that CNN is also capable of variable-length sequence discrimination with the max-over-time pooling technique (Kim 2014).

We first represent an input sequence x1, . . . , xT as

E1:T = x1 ⊕ x2 ⊕ . . . ⊕ xT,    (11)

where xt ∈ R^k is the k-dimensional token embedding and ⊕ is the concatenation operator building the matrix E1:T ∈ R^{T×k}. Then a kernel w ∈ R^{l×k} applies a convolution operation to a window of l words to produce a new feature map:

ci = ρ(w ⊗ E_{i:i+l−1} + b),    (12)

where the ⊗ operator is the sum of the element-wise product, b is a bias term and ρ is a non-linear function. We can use various numbers of kernels with different window sizes to extract different features. Finally, we apply a max-over-time pooling operation over the feature maps: c̃ = max{c1, . . . , c_{T−l+1}}.

To enhance the performance, we also add the highway architecture (Srivastava, Greff, and Schmidhuber 2015) on top of the pooled feature maps. Finally, a fully connected layer with sigmoid activation is used to output the probability that the input sequence is real. The optimization target is to minimize the cross entropy between the ground-truth label and the predicted probability, as formulated in Eq. (5).

Detailed implementations of the generative and discriminative models are provided in the supplementary material.
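A numpy sketch of this discriminator's forward pass (Eq. (11), Eq. (12), max-over-time pooling, one highway layer and a sigmoid output). A single kernel and one pooled feature are used for brevity, which is a simplification of the multi-kernel CNN described above.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB_SIZE, K, L, T = 20, 8, 3, 12        # vocab, embedding dim k, window size l, length T

emb = rng.normal(0, 0.1, (VOCAB_SIZE, K))
w = rng.normal(0, 0.1, (L, K))            # one convolution kernel w in R^{l x k}
b = 0.0
W_high = rng.normal(0, 0.1, (1, 1))       # highway transform (only 1 pooled feature here)
W_gate = rng.normal(0, 0.1, (1, 1))       # highway gate
w_out, b_out = rng.normal(0, 0.1, (1,)), 0.0

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def discriminator(tokens):
    E = emb[np.asarray(tokens)]                                   # Eq. (11): E_{1:T} in R^{T x k}
    feats = [np.sum(w * E[i:i + L]) + b for i in range(T - L + 1)]
    cmap = np.maximum(0.0, np.asarray(feats))                     # Eq. (12) with rho = ReLU
    pooled = np.array([cmap.max()])                               # max-over-time pooling
    gate = sigmoid(W_gate @ pooled)                               # highway layer
    high = gate * np.maximum(0.0, W_high @ pooled) + (1 - gate) * pooled
    return float(sigmoid(w_out @ high + b_out))                   # P(sequence is real)

print(discriminator(rng.integers(0, VOCAB_SIZE, size=T)))
```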
Synthetic Data Experiments

To test the efficacy of SeqGAN and add to our understanding of it, we conduct a simulated test with synthetic data.³ To simulate real-world structured sequences, we consider a language model to capture the dependency of the tokens. We use a randomly initialized LSTM as the true model, a.k.a. the oracle, to generate the real data distribution p(xt | x1, . . . , xt−1) for the following experiments.

³ Experiment code: https://github.com/LantaoYu/SeqGAN

Evaluation Metric

The benefit of having such an oracle is that, firstly, it provides the training dataset and, secondly, it allows evaluating the exact performance of the generative models, which is not possible with real data. We know that MLE is trying to minimize the cross-entropy between the true data distribution p and our approximation q, i.e. −E_{x∼p} log q(x). However, the most accurate way of evaluating generative models is to draw some samples from them and let human observers review them based on their prior knowledge. We assume that the human observer has learned an accurate model of the natural distribution p_human(x). Then, in order to increase the chance of passing the Turing test, we actually need to minimize the exact opposite average negative log-likelihood −E_{x∼q} log p_human(x) (Huszár 2015), with the roles of p and q exchanged. In our synthetic data experiments, we can consider the oracle to be the human observer for real-world problems, thus a perfect evaluation metric should be

NLL_oracle = −E_{Y1:T∼Gθ}[ Σ_{t=1}^{T} log G_oracle(yt | Y1:t−1) ].    (13)
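A sketch of estimating NLL_oracle in Eq. (13) by sampling sequences from the generator and scoring each token under the oracle. Both conditional distributions below are toy stand-ins, not the randomly initialized LSTM oracle used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
V, T, N_SAMPLES = 10, 20, 200

def toy_conditional(prefix, temperature):
    """Toy stand-in for a conditional distribution p(y_t | Y_{1:t-1})."""
    logits = np.cos(np.arange(V) + len(prefix)) / temperature
    z = np.exp(logits - logits.max())
    return z / z.sum()

def oracle_prob(prefix):      # G_oracle(y_t | Y_{1:t-1})
    return toy_conditional(prefix, temperature=1.0)

def generator_prob(prefix):   # G_theta(y_t | Y_{1:t-1})
    return toy_conditional(prefix, temperature=2.0)

def nll_oracle(n_samples=N_SAMPLES):
    """Eq. (13): -E_{Y ~ G_theta} [ sum_t log G_oracle(y_t | Y_{1:t-1}) ]."""
    total = 0.0
    for _ in range(n_samples):
        Y = []
        for _ in range(T):
            y = int(rng.choice(V, p=generator_prob(Y)))   # sample from the generator
            total -= np.log(oracle_prob(Y)[y])            # score the token under the oracle
            Y.append(y)
    return total / n_samples

print(nll_oracle())
```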
Table 1: Sequence generation performance comparison. The p-value is between SeqGAN and the baseline from a T-test.

Algorithm   Random    MLE       SS        PG-BLEU   SeqGAN
NLL         10.310    9.038     8.985     8.946     8.736
p-value     < 10^-6   < 10^-6   < 10^-6   < 10^-6

Figure 2: Negative log-likelihood convergence w.r.t. the training epochs. The vertical dashed line represents the end of pre-training for SeqGAN, SS and PG-BLEU.
Table 2: Chinese poem generation performance comparison. The p-value is between MLE and SeqGAN.

Algorithm   Human score   BLEU-2
MLE         0.4165        0.6670
SeqGAN      0.5356        0.7389
Real data   0.6011        0.746
p-value     0.0034        < 10^-6

(a) g-steps=100, d-steps=1, k=10    (b) g-steps=30, d-steps=1, k=30

Table 3: Obama political speech generation performance. The p-value is between MLE and SeqGAN.

Algorithm   BLEU-3    BLEU-4
MLE         0.519     0.416
SeqGAN      0.556     0.427
p-value     < 10^-6   0.00014
in midi file format. We study the solo track of each piece. In our work, we use 88 numbers to represent 88 pitches, which correspond to the 88 keys on the piano. By sampling the pitch every 0.4 s⁸, we transform the midi files into sequences of numbers from 1 to 88 with length 32.

⁸ http://deeplearning.net/tutorial/rnnrbm.html
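Under stated assumptions (notes given as hypothetical (pitch, onset, offset) triples in seconds, pitches numbered 1-88, silent frames filled with an arbitrary value), a short sketch of this 0.4 s sampling into length-32 sequences; the midi parsing itself is omitted.

```python
def encode_solo_track(notes, step=0.4, length=32, rest=1):
    """Map (pitch, onset, offset) notes to a length-32 sequence of piano-key numbers 1..88.
    `rest` is an assumed filler value for silent frames (not specified in the paper)."""
    sequence = []
    for i in range(length):
        t = i * step                                   # sample the pitch every 0.4 s
        sounding = [p for (p, on, off) in notes if on <= t < off]
        sequence.append(sounding[0] if sounding else rest)
    return sequence

# Hypothetical notes: middle C (key 40) for 2 s, then E (key 44) for 2 s.
print(encode_solo_track([(40, 0.0, 2.0), (44, 2.0, 4.0)]))
```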
To model the fitness of the discrete piano key patterns, BLEU is used as the evaluation metric. To model the fitness of the continuous pitch data patterns, the mean squared error (MSE) (Manaris et al. 2007) is used for evaluation.

From Table 4, we see that SeqGAN outperforms MLE significantly on both metrics in the music generation task.
Conclusion

In this paper, we proposed a sequence generation method, SeqGAN, to effectively train generative adversarial nets for structured sequence generation via policy gradient. To the best of our knowledge, this is the first work extending GANs to generate sequences of discrete tokens. In our synthetic data experiments, we used an oracle evaluation mechanism to explicitly illustrate the superiority of SeqGAN over strong baselines. For three real-world scenarios, i.e., poem, speech language and music generation, SeqGAN showed excellent performance on generating creative sequences. We also performed a set of experiments to investigate the robustness and stability of training SeqGAN. For future work, we plan to build Monte Carlo tree search and a value network (Silver et al. 2016) to improve action decision making for large scale data and in the case of longer-term planning.

References

Bachman, P., and Precup, D. 2015. Data generation as sequential decision making. In NIPS, 3249–3257.
Bahdanau, D.; Brakel, P.; Xu, K.; et al. 2016. An actor-critic algorithm for sequence prediction. arXiv:1607.07086.
Bahdanau, D.; Cho, K.; and Bengio, Y. 2014. Neural machine translation by jointly learning to align and translate. arXiv:1409.0473.
Bengio, Y.; Yao, L.; Alain, G.; and Vincent, P. 2013. Generalized denoising auto-encoders as generative models. In NIPS, 899–907.
Bengio, S.; Vinyals, O.; Jaitly, N.; and Shazeer, N. 2015. Scheduled sampling for sequence prediction with recurrent neural networks. In NIPS, 1171–1179.
Browne, C. B.; Powley, E.; Whitehouse, D.; Lucas, S. M.; et al. 2012. A survey of Monte Carlo tree search methods. IEEE TCIAIG 4(1):1–43.
Cho, K.; Van Merriënboer, B.; Gulcehre, C.; et al. 2014. Learning phrase representations using RNN encoder-decoder for statistical machine translation. EMNLP.
Denton, E. L.; Chintala, S.; Fergus, R.; et al. 2015. Deep generative image models using a Laplacian pyramid of adversarial networks. In NIPS, 1486–1494.
Glynn, P. W. 1990. Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM 33(10):75–84.
Goodfellow, I., et al. 2014. Generative adversarial nets. In NIPS, 2672–2680.
Goodfellow, I.; Bengio, Y.; and Courville, A. 2016. Deep learning. 2015.
Goodfellow, I. 2016. Generative adversarial networks for text. http://goo.gl/Wg9DR7.
Graves, A. 2013. Generating sequences with recurrent neural networks. arXiv:1308.0850.
He, J.; Zhou, M.; and Jiang, L. 2012. Generating Chinese classical poems with statistical machine translation models. In AAAI.
Hingston, P. 2009. A Turing test for computer game bots. IEEE TCIAIG 1(3):169–186.
Hinton, G. E.; Osindero, S.; and Teh, Y.-W. 2006. A fast learning algorithm for deep belief nets. Neural Computation 18(7):1527–1554.
Hochreiter, S., and Schmidhuber, J. 1997. Long short-term memory. Neural Computation 9(8):1735–1780.
Huszár, F. 2015. How (not) to train your generative model: Scheduled sampling, likelihood, adversary? arXiv:1511.05101.
Kim, Y. 2014. Convolutional neural networks for sentence classification. arXiv:1408.5882.
Kingma, D. P., and Welling, M. 2014. Auto-encoding variational Bayes. ICLR.
Lai, S.; Xu, L.; Liu, K.; and Zhao, J. 2015. Recurrent convolutional neural networks for text classification. In AAAI, 2267–2273.
Manaris, B.; Roos, P.; Machado, P.; et al. 2007. A corpus-based hybrid approach to music analysis and composition. In NCAI, volume 22, 839.
Papineni, K.; Roukos, S.; Ward, T.; and Zhu, W.-J. 2002. BLEU: a method for automatic evaluation of machine translation. In ACL, 311–318.
Quinlan, J. R. 1996. Bagging, boosting, and C4.5. In AAAI/IAAI, Vol. 1, 725–730.
Salakhutdinov, R. 2009. Learning deep generative models. Ph.D. Dissertation, University of Toronto.
Silver, D.; Huang, A.; Maddison, C. J.; Guez, A.; Sifre, L.; et al. 2016. Mastering the game of Go with deep neural networks and tree search. Nature 529(7587):484–489.
Srivastava, N.; Hinton, G. E.; Krizhevsky, A.; Sutskever, I.; and Salakhutdinov, R. 2014. Dropout: a simple way to prevent neural networks from overfitting. JMLR 15(1):1929–1958.
Srivastava, R. K.; Greff, K.; and Schmidhuber, J. 2015. Highway networks. arXiv:1505.00387.
Sutskever, I.; Vinyals, O.; and Le, Q. V. 2014. Sequence to sequence learning with neural networks. In NIPS, 3104–3112.
Sutton, R. S.; McAllester, D. A.; Singh, S. P.; Mansour, Y.; et al. 1999. Policy gradient methods for reinforcement learning with function approximation. In NIPS, 1057–1063.
Veselỳ, K.; Ghoshal, A.; Burget, L.; and Povey, D. 2013. Sequence-discriminative training of deep neural networks. In INTERSPEECH, 2345–2349.
Wen, T.-H.; Gasic, M.; Mrksic, N.; Su, P.-H.; Vandyke, D.; and Young, S. 2015. Semantically conditioned LSTM-based natural language generation for spoken dialogue systems. arXiv:1508.01745.
Williams, R. J. 1992. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8(3-4):229–256.
Yi, X.; Li, R.; and Sun, M. 2016. Generating Chinese classical poems with RNN encoder-decoder. arXiv:1604.01537.
Zhang, X., and Lapata, M. 2014. Chinese poetry generation with recurrent neural networks. In EMNLP, 670–680.
Zhang, X., and LeCun, Y. 2015. Text understanding from scratch. arXiv:1502.01710.