A Geometric Understanding of Deep Learning
Research
Artificial Intelligence—Article
Article history: Received 2 March 2019; Revised 31 August 2019; Accepted 11 September 2019; Available online 11 January 2020

Keywords: Generative adversarial networks; Deep learning; Optimal transportation; Mode collapse

Abstract: This work introduces an optimal transportation (OT) view of generative adversarial networks (GANs). Natural datasets have intrinsic patterns, which can be summarized as the manifold distribution principle: the distribution of a class of data is close to a low-dimensional manifold. GANs mainly accomplish two tasks: manifold learning and probability distribution transformation. The latter can be carried out using the classical OT method. From the OT perspective, the generator computes the OT map, while the discriminator computes the Wasserstein distance between the generated data distribution and the real data distribution; both can be reduced to a convex geometric optimization process. Furthermore, OT theory discovers the intrinsic collaborative—instead of competitive—relation between the generator and the discriminator, and the fundamental reason for mode collapse. We also propose a novel generative model, which uses an autoencoder (AE) for manifold learning and OT map for probability distribution transformation. This AE–OT model improves the theoretical rigor and transparency, as well as the computational stability and efficiency; in particular, it eliminates the mode collapse. The experimental results validate our hypothesis, and demonstrate the advantages of our proposed model.

© 2020 THE AUTHORS. Published by Elsevier LTD on behalf of Chinese Academy of Engineering and Higher Education Press Limited Company. This is an open access article under the CC BY-NC-ND license (http://creativecommons.org/licenses/by-nc-nd/4.0/).
Generative adversarial networks (GANs) have emerged as one of the dominant approaches for unconditional image generation. When trained on several datasets, GANs are able to produce realistic and visually appealing samples. GAN methods train an unconditional generator that regresses real images from random noises and a discriminator that measures the difference between the generated samples and real images. GANs have received various improvements. One breakthrough was achieved by combining optimal transportation (OT) theory with GANs, such as the Wasserstein GAN (WGAN) [1]. In the WGAN framework, the generator computes the OT map from the white noise to the data distribution, and the discriminator computes the Wasserstein distance between the generated data distribution and the real data distribution.

The great success of GANs can be explained by the fact that GANs effectively discover the intrinsic structures of real datasets, which can be formulated as the manifold distribution hypothesis: a specific class of natural data is concentrated on a low-dimensional manifold embedded in the high-dimensional background space [2].

Fig. 1 shows the manifold structure of the MNIST database. Each handwritten digit image has the dimensions 28 × 28, and is treated as a point in the image space ℝ⁷⁸⁴. The MNIST database is concentrated close to a low-dimensional manifold. By using the t-SNE manifold embedding algorithm [3], the MNIST database is mapped onto a planar domain, and each image is mapped onto a single point. The images representing the same digit are mapped onto one cluster, and the 10 clusters are color encoded. This demonstrates that the MNIST database is distributed close to a two-dimensional (2D) surface embedded in the unit cube in ℝ⁷⁸⁴.

* Corresponding author. E-mail address: [email protected] (X. Gu).
# These authors contributed equally to this work.
https://doi.org/10.1016/j.eng.2019.09.010
Fig. 1. Manifold distribution of the MNIST database. (a) Some handwritten digits in the MNIST database; (b) the embedding of the digits in the two-dimensional (2D) plane computed by the t-SNE algorithm. The x and y relative coordinates are normalized.
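An embedding of the kind shown in Fig. 1(b) can be reproduced with a few lines of Python; the snippet below is a minimal sketch. The subsample size and the t-SNE hyperparameters are illustrative assumptions, not the settings used by the authors.

```python
# Minimal sketch: embed a subsample of MNIST into the plane with t-SNE, illustrating
# that each digit class concentrates near a low-dimensional cluster.
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
idx = np.random.default_rng(0).choice(len(X), size=2000, replace=False)
X_sub, y_sub = X[idx] / 255.0, y[idx].astype(int)   # points in the unit cube of R^784

emb = TSNE(n_components=2, perplexity=30, init="pca", random_state=0).fit_transform(X_sub)

plt.scatter(emb[:, 0], emb[:, 1], c=y_sub, cmap="tab10", s=5)
plt.title("t-SNE embedding of MNIST: ten color-encoded clusters")
plt.show()
```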
Fig. 2 illustrates the theoretic model of GANs. The real data distribution ν is concentrated on a manifold Σ embedded in the ambient space χ; (Σ, ν) together show the intrinsic structure of the real datasets. A GAN model computes a generator map g_θ from the latent space Z to the manifold Σ, where θ represents the parameter of a deep neural network (DNN). ζ is a Gaussian distribution in the latent space, and g_θ pushes forward ζ to μ_θ. The discriminator calculates a distance between the real data distribution ν and the generated distribution μ_θ, such as the Wasserstein distance W_c(μ_θ, ν), which is equivalent to the Kantorovich potential φ_ξ (ξ: the parameter of the discriminator).

Fig. 2. The theoretic model of GANs. G: generator; D: discriminator.

Despite GANs' advantages, they have critical drawbacks. In theory, the understanding of the fundamental principles of deep learning remains primitive. In practice, the training of GANs is tricky and sensitive to hyperparameters, and GANs suffer from mode collapse. Recently, Mescheder et al. [4] studied nine different GAN models and variants, showing that gradient-descent-based GAN optimization is not always locally convergent.

According to the manifold distribution hypothesis, a natural dataset can be represented as a probability distribution on a manifold. Therefore, GANs mainly accomplish two tasks: ① manifold learning—namely, computing the decoding/encoding maps between the latent space and the ambient space; and ② probability distribution transformation, either in the latent or image space, which involves transformation between the given white noise and the data distribution.

Fig. 3 shows the decomposition of the generator map g_θ = h ∘ T, where h: Z → Σ is the decoding map from the latent space to the data manifold Σ in the ambient space, and T: Z → Z is the probability distribution transformation map. The decoding map h is for manifold learning, and the map T is for measure transportation.

Fig. 3. The generator map is decomposed into a decoding map h and a transportation map T. T_#ζ is the push-forward measure induced by T.

1.3. Optimal transportation view

OT theory [5] studies the problem of transforming one probability distribution into another distribution in the most economical way. OT provides rigorous and powerful ways to compute the optimal mapping to transform one probability distribution into another distribution, and to determine the distance between them [6].

As mentioned before, GANs accomplish two major tasks: manifold learning and probability distribution transformation. The latter task can be fully carried out by OT methods directly. In detail, in Fig. 3, the probability distribution transformation map T can be computed using OT theory. The discriminator computes the Wasserstein distance W_c(μ_θ, ν) between the generated data distribution and the real data distribution, which can be calculated directly using the OT method.

From the theoretical point of view, the OT interpretation of GANs makes part of the black box transparent: the probability distribution transformation is reduced to a convex optimization process using OT theory, the existence and uniqueness of the solution have theoretical guarantees, and the convergence rate and approximation accuracy are fully analyzed.

The OT interpretation also explains the fundamental reason for mode collapse. According to the regularity theory of the Monge–Ampère equation, the transportation map is discontinuous on some singular sets. However, a DNN can only model continuous functions/mappings. Therefore, the target transportation mapping is outside of the functional space representable by GANs. This intrinsic conflict makes mode collapse unavoidable.
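The conflict between continuous DNN mappings and discontinuous OT maps can already be seen in one dimension. The sketch below is an illustrative example (not taken from the paper): for quadratic cost on the line, the OT map is the composition of the target quantile function with the source cumulative distribution function, and it jumps across the gap between two separated target modes, which no continuous generator can reproduce without placing samples in the gap.

```python
# Illustrative 1D example (assumption: quadratic cost, hence a monotone OT map).
# For c(x, y) = |x - y|^2 / 2 on the line, the OT map is T = F_target^{-1} o F_source,
# which is discontinuous when the target support has a gap (two separated modes).
import numpy as np

rng = np.random.default_rng(0)
source = rng.uniform(0.0, 1.0, size=10_000)                 # uniform source on [0, 1]
target = np.concatenate([rng.normal(-3.0, 0.1, 5_000),      # mode 1
                         rng.normal(+3.0, 0.1, 5_000)])     # mode 2
target_sorted = np.sort(target)

def ot_map(x):
    """Monotone (quadratic-cost) OT map: target quantile of the source CDF."""
    u = np.clip(x, 0.0, 1.0)                                 # F_source(x) = x on [0, 1]
    return np.quantile(target_sorted, u)

xs = np.linspace(0.0, 1.0, 11)
print(np.round(ot_map(xs), 2))   # values jump from about -3 to about +3 near x = 0.5
```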
methods is generally tricky and inefficient. Later, a huge breakthrough was achieved from the scheme of variational AEs (VAEs) [17], where the decoders approximate real data distributions from a Gaussian distribution using a variational approach [17,18]. Various recent works following this scheme have been proposed, including adversarial AEs (AAEs) [19] and Wasserstein AEs (WAEs) [20]. Although VAEs are relatively simple to train, the images they generate look blurry. To some extent, this is because the explicitly expressed density functions may fail to represent the complexity of a real data distribution and learn the high-dimensional data distribution [21,22]. Other non-adversarial training models have been proposed, including PixelCNN [23], PixelRNN [24], and WaveNet [25]. However, due to their auto-regressive nature, the generation of new samples cannot be parallelized.

2.4. Evaluation of generative models

The evaluation of generative models remains challenging. Early works include probabilistic criteria [29]. However, recent generative models (particularly GANs) are not amenable to such evaluation. Traditionally, the evaluation of GANs relies on visual inspection of a handful of examples or a user study. Recently, several quantitative evaluation criteria were proposed. The inception score (IS) [30] measures both diversity and image quality. However, it is not a distance metric. To overcome the shortcomings of the IS, the Fréchet inception distance (FID) was introduced in Ref. [31]. The FID has been shown to be robust to image corruption, and correlates well with visual fidelity. In a more recent work [32], precision and recall for distributions (PRD) was introduced to measure both precision and recall between the generated data distribution and the real data distribution. In order to fairly compare the GANs, a large-scale comparison was performed in Ref. [33], where seven different GANs and VAEs were compared under a uniform network architecture, and a common baseline for evaluation was established.

2.5. Non-adversarial models

Various non-adversarial models have also been proposed

image distribution through the Gibbs distribution by representing the energy function with DNNs. These methods alternately generate fake samples using the current models, and then optimize the model parameters with the generated fake samples and real samples.

3. Optimal transportation theory

In this section, we introduce basic concepts and theorems in classic OT theory, with a focus on Brenier's approach and its generalization to the discrete setting. Details can be found in Villani's book [5].

3.1. Monge's problem

We only consider maps that preserve the measure.

Definition 3.1 (measure-preserving map). A map T: X → Y is measure preserving if, for any measurable set B ⊂ Y, the set T⁻¹(B) is μ-measurable and μ[T⁻¹(B)] = ν(B), that is,

\int_{T^{-1}(B)} f(x)\,\mathrm{d}x = \int_{B} g(y)\,\mathrm{d}y    (2)

The measure-preserving condition is denoted as T_#μ = ν, where T_#μ is the push-forward measure induced by T.

Given a cost function c(x, y): X × Y → ℝ≥0, which indicates the cost of moving each unit mass from the source to the target, the total transport cost (C_t) of the map T: X → Y is defined to be

C_t = \int_X c[x, T(x)]\,\mathrm{d}\mu(x)    (3)

Monge's problem of OT arises from finding the measure-preserving map that minimizes the total transport cost.

Problem 3.2 (Monge's [43]; MP). Given a transport cost function c(x, y): X × Y → ℝ≥0, find the measure-preserving map T: X → Y that minimizes the total transport cost:

(\mathrm{MP})\quad \min_{T_\#\mu=\nu} \int_X c[x, T(x)]\,\mathrm{d}\mu(x)    (4)
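When both measures are empirical with the same number of points and uniform weights, the measure-preserving maps are exactly the permutations of the target points, so a small instance of Monge's problem in Eq. (4) can be solved exactly with the Hungarian algorithm. The sketch below is illustrative only; the point sets and sample sizes are arbitrary assumptions.

```python
# Discrete Monge problem: find the measure-preserving map (a permutation, since both
# point sets carry uniform weights) minimizing the total quadratic cost of Eq. (4).
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
X = rng.uniform(size=(64, 2))                  # source samples x_i, weights 1/64
Y = rng.normal(size=(64, 2))                   # target samples y_j, weights 1/64

C = 0.5 * cdist(X, Y, metric="sqeuclidean")    # c(x, y) = ||x - y||^2 / 2
rows, cols = linear_sum_assignment(C)          # optimal map T: x_i -> y_{cols[i]}
total_cost = C[rows, cols].mean()              # empirical total transport cost (Eq. (3))
print(f"optimal transport cost: {total_cost:.4f}")
```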
Kantorovich relaxed transport maps to transport plans, namely joint probability measures ρ(x, y) such that the marginal probability of ρ is equal to μ and ν, respectively. Let the projection maps formally be π_x(x, y) = x and π_y(x, y) = y; then define the joint measure class as follows:

\Pi(\mu, \nu) = \{\rho(x, y): X \times Y \to \mathbb{R} : (\pi_x)_\#\rho = \mu,\ (\pi_y)_\#\rho = \nu\}    (6)

Problem 3.4 (Kantorovich's; KP). Given a transport cost function c(x, y): X × Y → ℝ≥0, find the joint probability measure ρ(x, y): X × Y → ℝ≥0 that minimizes the total transport cost:

(\mathrm{KP})\quad W_c(\mu, \nu) = \min_{\rho \in \Pi(\mu, \nu)} \int_{X \times Y} c(x, y)\,\mathrm{d}\rho(x, y)    (7)

KP can be solved using the linear programming (LP) method. Due to the duality of LP, Eq. (7) (the KP equation) can be reformulated as the duality problem (DP) as follows:

Problem 3.5 (duality; DP). Given a transport cost function c(x, y): X × Y → ℝ≥0, find the real functions φ: X → ℝ and ψ: Y → ℝ, such that

(\mathrm{DP})\quad \max_{\varphi, \psi}\left\{\int_X \varphi(x)\,\mathrm{d}\mu + \int_Y \psi(y)\,\mathrm{d}\nu : \varphi(x) + \psi(y) \le c(x, y)\right\}    (8)

The maximum value of Eq. (8) gives the Wasserstein distance. Most existing WGAN models are based on the duality formulation under the L¹ cost function.

Definition 3.6 (c-transformation). The c-transformation of φ: X → ℝ is defined as φᶜ: Y → ℝ:

\varphi^c(y) = \inf_{x \in X}[c(x, y) - \varphi(x)]    (9)

Theorem 3.7 (Brenier's [44]). Suppose X and Y are subsets of the Euclidean space ℝᵈ and the transportation cost is the quadratic Euclidean distance c(x, y) = ½‖x − y‖². Furthermore, μ is absolutely continuous, and μ and ν have finite second-order moments,

\int_X \|x\|^2\,\mathrm{d}\mu(x) + \int_Y \|y\|^2\,\mathrm{d}\nu(y) < +\infty    (11)

Then there exists a convex function u: X → ℝ, the so-called Brenier's potential, whose gradient map ∇u gives the solution to MP:

(\nabla u)_\#\mu = \nu    (12)

The Brenier's potential is unique up to a constant; hence, the optimal mass transportation map is unique.

Assuming that the Brenier potential is C² smooth, it is the solution to the following Monge–Ampère equation:

\det\left(\frac{\partial^2 u(x)}{\partial x_i \partial x_j}\right) = \frac{f(x)}{g(\nabla u(x))}    (13)

The Legendre transformation of a convex function u (Definition 3.8) is given by

u^*(y) = \sup_x\left[\langle x, y\rangle - u(x)\right]    (14)

It can be shown that the following relation holds when c(x, y) = ½‖x − y‖²:

\frac{1}{2}\|y\|^2 - \varphi^c(y) = \left[\frac{1}{2}\|x\|^2 - \varphi(x)\right]^*    (15)

Theorem 3.9 (Brenier's polar factorization [44]). Suppose X and Y are the Euclidean space ℝᵈ, μ is absolutely continuous with respect to the Lebesgue measure, and a mapping φ: X → Y pushes μ forward to ν, φ_#μ = ν. Then there exists a convex function u: X → ℝ such that φ = ∇u ∘ s, where s: X → X is measure preserving, s_#μ = μ. Furthermore, this factorization is unique.

The following theorem is well known in OT theory:

Theorem 3.10 (Villani [5]). Given μ and ν on a compact convex domain Ω ⊂ ℝᵈ, there exists an OT plan ρ for the cost c(x, y) = h(x − y), with h strictly convex. It is unique and of the form (id, T)_#μ (id: identity map), provided that μ is absolutely continuous and ∂Ω is negligible. Moreover, there exists a Kantorovich potential φ, and T can be represented as follows:

T(x) = x - (\nabla h)^{-1}[\nabla\varphi(x)]

When c(x, y) = ½‖x − y‖², we have

T(x) = x - \nabla\varphi(x) = \nabla\left(\frac{1}{2}\|x\|^2 - \varphi(x)\right) = \nabla u(x)

In this case, the Brenier's potential u and the Kantorovich's potential φ are related by the following:

u(x) = \frac{1}{2}\|x\|^2 - \varphi(x)

3.4.1. Convex target domain

Definition 3.11 (Hölder continuous). A real or complex-valued function f on a d-dimensional Euclidean space satisfies a Hölder condition, or is Hölder continuous, when there are nonnegative real constants C, α > 0, such that |f(x) − f(y)| ≤ C‖x − y‖^α for all x and y in the domain of f.

Definition 3.12 (Hölder space). The Hölder space C^{k,α}(Ω), where Ω is an open subset of some Euclidean space and k ≥ 0 is an integer, consists of those functions on Ω having continuous derivatives up to order k and such that the kth partial derivatives are Hölder continuous with exponent α, where 0 < α ≤ 1. C^{k,α}_loc(Ω) means that the above conditions hold on any compact subset of Ω.

Theorem 3.13 (Caffarelli [45]). If the target domain Λ is convex, then the Brenier's potential u is strictly convex; furthermore,
(1) If λ ≤ f, g ≤ 1/λ for some λ > 0, then u ∈ C^{1,α}_loc(Ω).
(2) If f ∈ C^{k,α}_loc(Ω) and g ∈ C^{k,α}_loc(Λ), with f, g > 0, then u ∈ C^{k+2,α}_loc(Ω) (k ≥ 0, α ∈ (0, 1)).

For the L² transportation cost c(x, y) = ½‖x − y‖² in ℝᵈ, the singular-set decomposition of the source domain in this example is

\Sigma_0 = \Omega \setminus (\Sigma_1 \cup \Sigma_2),\qquad \Sigma_1 = \bigcup_{k=0}^{3}\gamma_k,\qquad \Sigma_2 = \{x_0, x_1\}

The subgradient of x₀, ∂u(x₀), is the entire inner hole of Λ, while ∂u(x₁) is the shaded triangle. For each point on γ_k(t), ∂u[γ_k(t)] is a line segment outside Λ. x₁ is the bifurcation point of γ₁, γ₂, and γ₃. The Brenier's potential on Σ₁ and Σ₂ is not differentiable, and the OT map ∇u on them is discontinuous.

4. Computational algorithm

Brenier's theorem can be directly generalized to the discrete situation. In GAN models, the source measure μ is given as a uniform (or Gaussian) distribution defined on a compact convex domain Ω; the target measure ν is represented as the empirical measure, which is the sum of the Dirac measures:

\nu = \sum_{i=1}^{n}\nu_i\,\delta(y - y_i)    (17)

where Y = {y₁, y₂, ..., y_n} are the training samples, with the weights satisfying \sum_{i=1}^{n}\nu_i = \mu(\Omega); δ is the characteristic function. Each training sample y_i corresponds to a supporting plane of the Brenier's potential, denoted as follows:

The gradient map of the discrete Brenier's potential sends each cell of the induced cell decomposition to the corresponding target point:

\nabla u_h: W_i(h) \to y_i,\quad i = 1, 2, ..., n    (22)

Given the target measure ν in Eq. (17), there exists a discrete Brenier's potential in Eq. (19) whose projected μ volume of each facet, w_i(h), is equal to the given target measure ν_i. This was proved by Alexandrov [46] in convex geometry.

Theorem 4.1 (Alexandrov [46]). Suppose Ω is a compact convex polytope with a non-empty interior in ℝⁿ, n₁, ..., n_k ⊂ ℝⁿ⁺¹ are k distinct unit vectors whose (n + 1)th coordinates are negative, and ν₁, ..., ν_k > 0 such that \sum_{i=1}^{k}\nu_i = vol(Ω). Then there exists a convex polytope P ⊂ ℝⁿ⁺¹ with exactly k codimension-1 faces F₁, ..., F_k such that n_i is the normal vector to F_i and the intersection between Ω and the projection of F_i has the volume ν_i. Furthermore, such a P is unique up to vertical translation.

Alexandrov's proof for the existence of the solution is based on algebraic topology, which is not constructive. Recently, Gu et al. [6] provided a constructive proof based on the variational approach.

Theorem 4.2 (Ref. [6]). Let μ be a probability measure defined on a compact convex domain Ω in ℝᵈ, and let Y = {y₁, y₂, ..., y_n} be a set of distinct points in ℝᵈ. Then for any ν₁, ν₂, ..., ν_n > 0 with \sum_{i=1}^{n}\nu_i = \mu(\Omega), there exists h = (h₁, h₂, ..., h_n) ∈ ℝⁿ, unique up to adding a constant (c, c, ..., c), so that w_i(h) = ν_i for all i. The vector h is the unique minimum argument of the following convex energy:

E(h) = \int_{0}^{h}\sum_{i=1}^{n} w_i(\eta)\,\mathrm{d}\eta_i - \sum_{i=1}^{n} h_i\nu_i    (23)

defined on an open convex set

H = \{h \in \mathbb{R}^n : w_i(h) > 0,\ i = 1, 2, ..., n\}    (24)

Furthermore, ∇u_h minimizes the quadratic cost

\frac{1}{2}\int_{\Omega}\|x - T(x)\|^2\,\mathrm{d}\mu(x)    (25)

among all transport maps T_#μ = ν.

The gradient of the above convex energy in Eq. (23) is given by the following:

\nabla E(h) = [w_1(h) - \nu_1, w_2(h) - \nu_2, ..., w_n(h) - \nu_n]^{\mathrm{T}}    (26)

The ith row and jth column element of the Hessian of the energy is given by the following:

\frac{\partial w_i}{\partial h_j} = -\frac{\mu(W_i \cap W_j \cap \Omega)}{\|y_i - y_j\|}\ (j \ne i),\qquad \frac{\partial w_i}{\partial h_i} = -\sum_{j \ne i}\frac{\partial w_i}{\partial h_j}    (27)

As shown in Fig. 6, the Hessian matrix has an explicit geometric interpretation. Fig. 6(a) shows the discrete Brenier's potential u_h, while Fig. 6(b) shows its Legendre transformation u_h^* obtained using Definition 3.8. The Legendre transformation can be constructed geometrically: For each supporting plane π_{h,i}, we construct the dual point π*_{h,i} = (y_i, −h_i); the convex hull of the dual points {π*_{h,1}, π*_{h,2}, ..., π*_{h,n}} is the graph of the Legendre transformation u_h^*.

The projection of u_h^* induces a triangulation of Y = {y₁, y₂, ..., y_n}, which is the weighted Delaunay triangulation. As shown in Fig. 7, the power diagram in Eq. (20) and the weighted Delaunay triangulation are Poincaré dual to each other: If, in the power diagram, W_i(h) and W_j(h) intersect at a (d − 1)-dimensional cell, then in the weighted Delaunay triangulation, y_i connects with y_j. The element of the Hessian matrix in Eq. (27) is the ratio between the μ volume of the (d − 1) cell in the power diagram and the length of the dual edge in the weighted Delaunay triangulation.

The conventional power diagram can be closely related to the above theorem.

Definition 4.3 (power distance). Given a point y_i ∈ ℝᵈ with a power weight w_i, the power distance is given by the following:

\mathrm{pow}(x, y_i) = \|x - y_i\|^2 - w_i    (28)

Definition 4.4 (power diagram). Given the weighted points (y₁, w₁), ..., (y_k, w_k), the power diagram is the cell decomposition of ℝᵈ:

\mathbb{R}^d = \bigcup_{i=1}^{k} W_i(w)    (29)

where each cell is a convex polytope:

W_i(w) = \{x \in \mathbb{R}^d \mid \mathrm{pow}(x, y_i) \le \mathrm{pow}(x, y_j),\ \forall j\}    (30)

The weighted Delaunay triangulation, denoted as T(w), is the Poincaré dual to the power diagram; if W_i(w) ∩ W_j(w) ≠ ∅, then there is an edge connecting y_i and y_j in the weighted Delaunay triangulation. Note that pow(x, y_i) ≤ pow(x, y_j) is equivalent to

\langle x, y_i\rangle + \frac{1}{2}\left(w_i - \|y_i\|^2\right) \ge \langle x, y_j\rangle + \frac{1}{2}\left(w_j - \|y_j\|^2\right)    (31)

Let h_i = ½(w_i − ‖y_i‖²); then we can rewrite the definition of W_i(w) as follows:

W_i(w) = \{x \in \mathbb{R}^d \mid \langle x, y_i\rangle + h_i \ge \langle x, y_j\rangle + h_j,\ \forall j\}    (32)

In practice, our goal is to compute the discrete Brenier's potential in Eq. (19) by optimizing the convex energy in Eq. (23). For low-dimensional cases, we can directly use Newton's method by computing the gradient in Eq. (26) and the Hessian matrix in Eq. (27). For deep learning applications, direct computation of the Hessian matrix is unfeasible; instead, we can use the gradient descent method or quasi-Newton's method with superlinear convergence. The key to the gradient is to estimate the μ volume w_i(h). This can be done using the Monte Carlo method: We draw n random samples from the distribution μ and count the number of samples falling within W_i(h); the ratio converges to the μ volume. This method is purely parallel and can be implemented using a GPU. Moreover, we can use a hierarchical method to further improve the efficiency: First, we classify the target samples into clusters, and compute the OT map to the mass centers of the clusters; second, for each cluster, we compute the OT map from the corresponding cell to the original target samples within the cluster.

In order to avoid mode collapse, we need to find the singularity sets in Ω. As shown in Fig. 8, the target Dirac measure has two clusters; the source is the uniform distribution on the unit planar disk. The graph of the Brenier's potential function is a convex polyhedron with a ridge in the middle. The projection of the ridge on the disk is the singularity set Σ₁(u), and the optimal mapping is discontinuous on Σ₁. In general cases, if two cells W_i(h) and W_j(h) are adjacent, then we compute the angle between the normals to the corresponding supporting planes:

\theta_{i,j} = \frac{\langle y_i, y_j\rangle}{\|y_i\|\,\|y_j\|}

The gradient map ∇u gives the OT map from (g_θ^k)_#ζ to ν. Therefore, we obtain the following:

\nu = (\nabla u)_\#\left(g_\theta^k\right)_\#\zeta = \left(\nabla u \circ g_\theta^k\right)_\#\zeta = \left[\left(\mathrm{id} - \nabla\varphi_\xi\right)\circ g_\theta^k\right]_\#\zeta
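The variational algorithm of Section 4 is easy to prototype. The snippet below is a minimal NumPy sketch, not the authors' GPU implementation: it minimizes the convex energy by plain gradient descent using Eq. (26), with the cell volumes w_i(h) estimated by Monte Carlo sampling of the source measure; the quasi-Newton updates and the hierarchical clustering scheme described above are omitted, and the uniform cube source and toy targets are assumptions for the example.

```python
# Minimal sketch of the semi-discrete OT solver of Section 4 (not the authors' code):
# minimize E(h) by gradient descent, where grad E(h)_i = w_i(h) - nu_i (Eq. (26)) and
# the cell volumes of W_i(h) = {x : <x, y_i> + h_i >= <x, y_j> + h_j} are estimated by
# Monte Carlo sampling of the source measure mu (uniform on [0, 1]^d here).
import numpy as np

def solve_semi_discrete_ot(Y, nu, n_mc=200_000, lr=0.2, n_iter=500, seed=0):
    rng = np.random.default_rng(seed)
    n, d = Y.shape
    h = np.zeros(n)
    X = rng.uniform(size=(n_mc, d))                 # samples of mu: uniform on the unit cube
    for _ in range(n_iter):
        cell = np.argmax(X @ Y.T + h, axis=1)       # cell W_i(h) containing each sample
        w = np.bincount(cell, minlength=n) / n_mc   # Monte Carlo estimate of w_i(h)
        h -= lr * (w - nu)                          # gradient step with Eq. (26)
        h -= h.mean()                               # h is unique only up to a constant
    return h

# Toy usage: 10 random target points with uniform weights nu_i = 1/10.
Y = np.random.default_rng(1).uniform(size=(10, 2))
nu = np.full(10, 1.0 / 10)
h = solve_semi_discrete_ot(Y, nu)
print(np.round(h, 3))
```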
Fig. 9. Discontinuous OT map, produced by a GPU implementation of an algorithm based on Theorem 4.2: (a) is the source domain and (b) is the target domain. The middle line in (a) is the singularity set Σ₁.
Fig. 10. Discontinuous OT map, produced by a GPU implementation of an algorithm based on Theorem 4.2: (a) is the source domain and (b) is the target domain. γ₁ and γ₂ in (a) are two singularity sets.
Fig. 11. OT from the Stanford bunny to a solid ball. The singular sets are the foldings on the boundary surface. (a–d) show the deformation procedure.
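Section 4 detects such discontinuities of the semi-discrete OT map by thresholding the quantity θ_{i,j} computed from the normals of adjacent supporting planes (Section 6.1 later uses thresholds of 0.75 or 0.68, depending on the dataset). The snippet below is a small sketch of that check for a given list of adjacent cell pairs; the adjacency list, the comparison direction, and the example points are assumptions for illustration.

```python
# Sketch of the singularity-set detection of Section 4: two adjacent cells W_i(h), W_j(h)
# are flagged as lying across a discontinuity when the cosine-like quantity theta_{i,j}
# between the normals y_i, y_j of their supporting planes falls below a threshold.
import numpy as np

def singular_pairs(Y, adjacent_pairs, threshold=0.75):
    flagged = []
    for i, j in adjacent_pairs:
        theta = Y[i] @ Y[j] / (np.linalg.norm(Y[i]) * np.linalg.norm(Y[j]))
        if theta < threshold:        # targets point in very different directions
            flagged.append((i, j))
    return flagged

# Toy usage: three latent targets; the pair (0, 2) is nearly opposite and gets flagged.
Y = np.array([[1.0, 0.1], [0.9, 0.2], [-1.0, 0.0]])
print(singular_pairs(Y, [(0, 1), (0, 2), (1, 2)]))
```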
outside Λ. In practice, this will induce the phenomenon of generating unrealistic samples, as shown in the middle frame of Fig. 12. Therefore, in theory, it is impossible to approximate OT maps directly using DNNs.

5.3. AE–OT model

As shown in Fig. 4, we separate the two main tasks of GANs: manifold learning and probability distribution transformation. The first task is carried out by an AE to compute the encoding/decoding maps f_θ and g_ξ; the second task is accomplished using the explicit variational method to compute the OT map T in the latent space. The real data distribution ν is pushed forward by the encoding map f_θ, inducing (f_θ)_#ν. In the latent space, T maps the uniform distribution μ to (f_θ)_#ν.

The AE–OT model has many advantages. In essence, finding the OT map is a convex optimization problem; the existence and the uniqueness of the solution are guaranteed. The training process is stable and has superlinear convergence by using quasi-Newton's method. The number of unknowns is equal to that of the training samples, avoiding over-parameterization. The parallel OT map algorithm can be implemented using a GPU. The error bound of the OT map can be controlled by the sampling density in the Monte Carlo method. The hierarchical algorithm with self-adaptivity further improves the efficiency. In particular, the AE–OT model can eliminate mode collapse.

6. Experimental results

In this section, we report our experimental results.

6.1. Training process

The training of the AE–OT model mainly includes two steps: training the AE and finding the OT map. The OT step is accomplished using a GPU implementation of the algorithm, as described in Section 4.
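Once the encoder f_θ, the decoder g_ξ, and the OT map T in the latent space are available, generation in the AE–OT model is a simple composition. The sketch below illustrates this with PyTorch; the decoder module, the latent dimension, the encoded feature set, and the heights h are placeholders standing in for whatever the trained model and the OT solver provide.

```python
# Sketch of AE-OT generation: z ~ uniform latent noise, T maps z to the latent code of
# the cell W_i(h) containing z (the semi-discrete OT map of Section 4), and the decoder
# g_xi maps the transported code back to image space. "decoder", Y, h, and latent_dim
# are illustrative placeholders, not the authors' trained networks or solver output.
import torch

latent_dim = 100
decoder = torch.nn.Sequential(                 # placeholder for the trained decoder g_xi
    torch.nn.Linear(latent_dim, 28 * 28), torch.nn.Sigmoid()
)
Y = torch.rand(5000, latent_dim)               # placeholder encoded features (f_theta)_# nu
h = torch.zeros(5000)                          # heights from the semi-discrete OT solver

def ot_map(z):
    """Semi-discrete OT map: send z to the target code y_i of the cell containing it."""
    i = torch.argmax(z @ Y.T + h, dim=1)
    return Y[i]

z = torch.rand(16, latent_dim)                 # samples of the uniform source measure
images = decoder(ot_map(z)).reshape(16, 28, 28)
print(images.shape)
```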
Fig. 12. Facial images generated by an AE–OT model. (a) Generated realistic facial images; (b) a path through a singularity. The image in the center of (b) shows that the
transportation map is discontinuous.
In the AE step, during the training process, we adopt the Adam algorithm [49] to optimize the parameters of the neural network, with a learning rate of 0.003, β₁ = 0.5, and β₂ = 0.999. When the L2 loss stops descending, which means that the network has found a good encoding map, we freeze the encoder part and continue to train the network for the decoding map. The training loss before and after the freezing of the encoder is shown in Table 1. Next, in order to find the OT map from the given distribution (here, we use the uniform distribution) to the distribution of latent features, we randomly sample 100N random points from the uniform distribution to compute the gradient of the energy. Here, N is the number of latent features of the dataset. Also, in the experiment, the angle threshold on θ_{i,j} is set differently for different datasets. To be specific, for the MNIST and Fashion-MNIST datasets, the threshold is set to be 0.75, while for the CIFAR-10 and CelebA datasets, it is set to be 0.68 and 0.75, respectively.

Our AE–OT model was implemented using PyTorch on a Linux platform. All the experiments were conducted on a GTX 1080Ti.

In this experiment, we want to test our hypothesis: In most real applications, the support of the target measure is non-convex, the singularity set is non-empty, and the probability distribution map is discontinuous along the singularity set.

As shown in Fig. 12, we use an AE to compute the encoding/decoding maps from the CelebA dataset (Σ, ν) to the latent space Z; the encoding map f_θ: Σ → Z pushes forward ν to (f_θ)_#ν on the latent space. In the latent space, we compute the OT map based on the algorithm described in Section 4, T: Z → Z, where T maps the uniform distribution ζ in a unit cube to (f_θ)_#ν. Then we randomly draw a sample z from the distribution ζ and use the decoding map g_ξ: Z → Σ to map T(z) to a generated human facial image g_ξ ∘ T(z). Fig. 12(a) demonstrates the realistic facial images generated by this AE–OT framework.

Table 1
The L2 loss of the AEs before and after the freezing of the encoder.

Situation   MNIST    Fashion-MNIST   CIFAR-10   CelebA
Before      0.0013   0.0026          0.0023     0.0077
After       0.0005   0.0011          0.0018     0.0074

If the support of the push-forward measure (f_θ)_#ν in the latent space is non-convex, there will be a singularity set Σ_k, where k > 0. We would like to detect the existence of Σ_k. We randomly draw line segments in the unit cube in the latent space, and then densely interpolate along each line segment to generate facial images. As shown in Fig. 12(b), we find a line segment γ, and generate a morphing sequence between a boy with a pair of brown eyes and a girl with a pair of blue eyes. In the middle, we generate a face with one blue eye and one brown eye, which is definitely unrealistic and outside Σ. This result means that the line segment γ goes through a singularity set Σ_k, where the transportation map T is discontinuous. This also shows that our hypothesis is correct: The support of the encoded human facial image measure on the latent space is non-convex.

As a byproduct, we find that this AE–OT framework improves the training speed by a factor of five and increases the convergence stability, since the OT step is a convex optimization. Thus, it provides a promising way to improve existing GANs.

Since a synthetic dataset consists of explicit distributions and known modes, mode collapse can be accurately measured. We chose synthetic datasets that have been studied or proposed in prior works [50,51], including a 2D grid dataset.

For a choice of the measurement metric of mode collapse, we adopted three previously used metrics [50,51]. Number of modes counts the quantity of modes captured by the samples produced by a generative model. In this metric, a mode is considered as lost if no sample is generated within three standard deviations of that mode. Percentage of high-quality samples measures the proportion of samples that are generated within three standard deviations of the nearest mode. The third metric, used in Ref. [51], is the reverse Kullback–Leibler (KL) divergence. In this metric, each generated sample is assigned to its nearest mode, and we count the histogram of samples assigned to each mode. This histogram then forms a discrete distribution, whose KL divergence with the histogram formed by real data is then calculated. Intuitively, this measures how well the generated samples balance among all modes regarding the real distribution.

In Ref. [51], the authors evaluated GAN [26], adversarially learned inference (ALI) [52], minibatch discrimination (MD) [30], and PacGAN [51] on synthetic datasets with the above three metrics.
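The three metrics above are straightforward to compute for a synthetic 2D grid; the snippet below is an illustrative sketch only. The 5 × 5 grid, unit spacing, and σ = 0.05 are assumptions for the example, not the exact settings of Refs. [50,51].

```python
# Illustrative mode-collapse metrics for a 2D grid of Gaussian modes (assumed 5 x 5 grid,
# sigma = 0.05): number of captured modes, % high-quality samples, and reverse KL.
import numpy as np

modes = np.array([[i, j] for i in range(5) for j in range(5)], dtype=float)  # 25 modes
sigma = 0.05

def mode_collapse_metrics(samples, real):
    d = np.linalg.norm(samples[:, None, :] - modes[None, :, :], axis=2)
    nearest = d.argmin(axis=1)
    high_quality = d.min(axis=1) < 3 * sigma              # within 3 std of nearest mode
    captured = len(np.unique(nearest[high_quality]))       # number of modes hit
    # reverse KL between the generated and real mode histograms
    hist_g = np.bincount(nearest, minlength=len(modes)) / len(samples)
    d_real = np.linalg.norm(real[:, None, :] - modes[None, :, :], axis=2)
    hist_r = np.bincount(d_real.argmin(axis=1), minlength=len(modes)) / len(real)
    eps = 1e-12
    rev_kl = np.sum(hist_g * np.log((hist_g + eps) / (hist_r + eps)))
    return captured, high_quality.mean(), rev_kl

rng = np.random.default_rng(0)
real = modes[rng.integers(0, 25, 10_000)] + rng.normal(0, sigma, (10_000, 2))
fake = modes[rng.integers(0, 20, 10_000)] + rng.normal(0, sigma, (10_000, 2))  # drops 5 modes
print(mode_collapse_metrics(fake, real))
```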
Each experiment was trained under the same generator architecture with a total of approximately 4 × 10⁵ training parameters. The networks were trained on 1 × 10⁵ samples for 400 epochs. For the AE–OT experiment, since the source space and target space are both 2D, there is no need to train an AE. We directly compute a semi-discrete OT map between the uniform distribution on the unit square and the empirical real data distribution. Theoretically, the minimum number of real samples needed for OT to recover all modes is one sample per mode. However, this may lead to the generation of low-quality samples during the interpolation process. Therefore, for OT computation, we take 512 real samples, and new samples are generated based on this map. We note that, in this case, there are only 512 parameters to optimize in OT computing, and the optimization process is stable due to the existence of the convex positive-definite Hessian. Our results are provided in Table 2, and benchmarks of previous methods are copied from Ref. [51]. For illustration purposes, we plotted our results on synthetic datasets along with those of GAN and PacGAN in Fig. 13.

Table 2
Mode collapse comparison for the 2D grid dataset.

Method    Number of modes   High-quality samples   Reverse KL
GAN       17.3 ± 0.8        94.8 ± 0.7%            0.70 ± 0.07
ALI       24.1 ± 0.4        95.7 ± 0.6%            0.14 ± 0.03
MD        23.8 ± 0.5        79.9 ± 3.2%            0.17 ± 0.03
PacGAN2   23.8 ± 0.7        91.3 ± 0.8%            0.13 ± 0.04
PacGAN3   24.6 ± 0.4        94.2 ± 0.4%            0.06 ± 0.02
PacGAN4   24.8 ± 0.2        93.6 ± 0.6%            0.04 ± 0.01
AE–OT     25.0 ± 0.0        99.8 ± 0.2%            0.007 ± 0.002

6.4. Comparison with the state of the art

We designed experiments to compare our proposed AE–OT model with state-of-the-art generative models, including the adversarial models evaluated by Lucic et al. in Ref. [33], and the non-adversarial models studied by Hoshen and Malik in Ref. [36]. For the purpose of fair comparison, we used the same testing datasets and network architecture. The datasets included MNIST [53], Fashion-MNIST [54], CIFAR-10 [55], and CelebA [56], similar to those tested in Refs. [31,36]. The network architecture was similar to that used by Lucic et al. in Ref. [33]. In particular, in our AE–OT model, the network architecture of the decoder was the same as that of the generators of GANs in Ref. [33], and the encoder was symmetric to the decoder.

We compared our model with state-of-the-art generative models using the FID score [31] and PRD curve as the evaluation criteria. The FID score measures the visual fidelity of the generated results and is robust to image corruption. However, the FID score is sensitive to mode addition and dropping [33]. Hence, we also used the PRD curve, which can quantify the degree of mode dropping and mode inventing on real datasets [32].

6.4.1. Comparison with FID score

The FID score is computed as follows: ① Extract the visually meaningful features of both the generated and real images by running the inception network [30]; ② fit the real and generated feature distributions with Gaussian distributions; and ③ compute the distance between the two Gaussian distributions using the following formula:

\mathrm{FID} = \|\mu_r - \mu_g\|_2^2 + \mathrm{Tr}\left[\Sigma_r + \Sigma_g - 2\left(\Sigma_r\Sigma_g\right)^{1/2}\right]    (34)

where μ_r and μ_g represent the means of the real and generated distributions, respectively; and Σ_r and Σ_g represent the covariances of these distributions.

The comparison results are summarized in Tables 3 and 4. The statistics of various GANs come from Lucic et al. [33], and those of the non-adversarial generative models come from Hoshen and Malik [36]. In general, our proposed model achieves better FID scores than the other state-of-the-art generative models. Theoretically, the FID scores of our AE–OT model should be close to those of the pre-trained AEs; this is also validated by our experiments.

The fixed network architecture of our AE was adopted from Lucic et al. [33]; its capacity is not large enough to encode CIFAR-10 or CelebA, so we had to down-sample these datasets. We randomly selected 2.5 × 10⁴ images from CIFAR-10 and 1 × 10⁴ images from CelebA to train our model. Even so, our model obtained the best FID score on CIFAR-10. Due to the limited capacity of the InfoGAN-style architecture, the AE for CelebA is not ideal (its FID is 67.5), which further causes the FID of the generated dataset to be 68.4. By adding two more convolutional layers to the AE architecture, the L2 loss on CelebA was less than 0.03, and the FID score beat all other models (28.6, as shown in the bracket of Table 4).

6.4.2. Comparison with the PRD curve

The FID score is an effective method to measure the difference between the generated distribution and the real data distribution, but it mainly focuses on precision, and cannot accurately capture what portion of the real data a generative model can cover. The method proposed in Ref. [32] disentangles the divergence between distributions into two components: precision and recall.

Given a reference distribution P and a learned distribution Q, the precision intuitively measures the quality of samples from Q, while the recall measures the proportion of P that is covered by Q. We used the concept of (F8, F1/8) introduced by Sajjadi et al. in Ref. [32] to quantify the relative importance of precision and recall. Fig. 14 summarizes the comparison results. Each dot represents a specific model with a set of hyperparameters. The closer a dot is to the upper-right corner, the better the performance of the model is.
Fig. 13. Mode collapse comparison on a 2D grid dataset. (a) GAN; (b) PacGAN4; (c) AE–OT. Orange marks are real samples and green marks are generated ones.
Table 3
Quantitative comparison with FID-I. All listed models are adversarial.

Dataset         MM GAN   NS GAN   LSGAN   WGAN   BEGAN
MNIST           9.8      6.8      7.8     6.7    13.1
Fashion-MNIST   29.6     26.5     30.7    21.5   22.9
CIFAR-10        72.7     58.5     87.1    55.2   71.4
CelebA          65.6     55.0     53.9    41.3   38.9

The best result is shown in bold. MM: manifold matching; NS: non-saturating; LSGAN: least squares GAN; BEGAN: boundary equilibrium GAN.
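The FID values compared in Tables 3 and 4 follow Eq. (34) of Section 6.4.1, which can be evaluated directly once the inception features of the real and generated images are available. The snippet below is a minimal sketch operating on two feature matrices; obtaining the features from an inception network is omitted, and the random arrays in the usage example merely stand in for such features.

```python
# Minimal FID sketch (Eq. (34)): fit Gaussians to real/generated feature matrices
# and compute ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2}).
import numpy as np
from scipy import linalg

def fid(feat_real, feat_gen):
    mu_r, mu_g = feat_real.mean(axis=0), feat_gen.mean(axis=0)
    s_r = np.cov(feat_real, rowvar=False)
    s_g = np.cov(feat_gen, rowvar=False)
    covmean, _ = linalg.sqrtm(s_r @ s_g, disp=False)
    covmean = covmean.real                    # discard tiny imaginary parts from sqrtm
    return float(np.sum((mu_r - mu_g) ** 2) + np.trace(s_r + s_g - 2.0 * covmean))

# Toy usage with random "features" standing in for inception activations.
rng = np.random.default_rng(0)
print(fid(rng.normal(0.0, 1.0, (2048, 64)), rng.normal(0.1, 1.0, (2048, 64))))
```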
Table 4
Quantitative comparison with FID-II.
The blue and green dots show the GANs and VAEs evaluated in Ref. [32], the khaki dot represents the GLANN model in Ref. [36], and the red dot is our AE–OT model.

It is clear that our proposed model outperforms the others for MNIST and Fashion-MNIST. For the CIFAR-10 dataset, the precision of our model is slightly lower than those of GANs and GLANN, but the recall is the highest. For the CelebA dataset, due to the limited capacity of the AE, the performance of our model is not impressive. However, after adding two more convolutional layers to the AE, our model achieves the best score.

6.4.3. Visual comparison

Fig. 15 shows a visual comparison between the images generated by our proposed method and those generated by the GANs studied by Lucic et al. in Ref. [33] and the non-adversarial models studied by Hoshen and Malik in Ref. [36].
Fig. 14. A comparison of the precision–recall pair in (F8, F1/8) in the four datasets. (a) MNIST; (b) Fashion-MNIST; (c) CIFAR-10; (d) CelebA. The khaki dots are the results of Ref.
[36]. The red dots are the results of the proposed method. The purple dot in the fourth subfigure corresponds to the results of the architecture with two more convolutional
layers.
Fig. 15. A visual comparison of the four datasets. The first column (a) shows the real data; the second column (b) is generated by an AE; the third column (c) illustrates the
generating results of the GANs [33] with the highest precision-recall scores of (F8, F1/8), corresponding to the B dots in Fig. 14; the fourth column (d) gives the results of Ref.
[36]; and the last column (e) shows the results of the proposed method.
The first column shows the original images, the second column shows the results generated by the AE, the third column shows the best generating results of the GANs in Lucic et al. [33], the fourth column displays the results generated by the models of Hoshen and Malik [36], and the fifth column displays the results from our method. It is clear that our method generates high-quality images and covers all modes.

7. Conclusion

This work uses OT theory to interpret GANs. According to the data manifold distribution hypothesis, GANs mainly accomplish two tasks: manifold learning and probability distribution transformation. The latter task can be carried out using the OT method directly. This theoretical understanding explains the fundamental reason for mode collapse, and shows that the intrinsic relation between the generator and the discriminator should be collaboration instead of competition. Furthermore, we propose an AE–OT model, which improves the theoretical rigor, training stability, and efficiency, and eliminates mode collapse.

Our experiment validates our assumption that if the distribution transportation map is discontinuous, then the existence of the singularity set leads to mode collapse. Furthermore, when our proposed model is compared with the state of the art, our method eliminates the mode collapse and outperforms the other models in terms of the FID score and PRD curve.

In the future, we will explore the theoretical understanding of the manifold learning stage, and use a rigorous method to make this part of the black box transparent.

Acknowledgements

The project is partially supported by the National Natural Science Foundation of China (61936002, 61772105, 61432003, 61720106005, and 61772379), US National Science Foundation (NSF) CMMI-1762287 collaborative research "computational framework for designing conformal stretchable electronics, Ford URP topology optimization of cellular mesostructures' nonlinear behaviors for crash safety," and NSF DMS-1737812 collaborative research "ATD: theory and algorithms for discrete curvatures on network data from human mobility and monitoring."

Compliance with ethics guidelines

Na Lei, Dongsheng An, Yang Guo, Kehua Su, Shixia Liu, Zhongxuan Luo, Shing-Tung Yau, and Xianfeng Gu declare that they have no conflicts of interest or financial conflicts to disclose.

References

[1] Arjovsky M, Chintala S, Bottou L. Wasserstein generative adversarial networks. In: Proceedings of the 34th International Conference on Machine Learning; 2017 Aug 6–11; Sydney, Australia; 2017. p. 214–23.
[2] Tenenbaum JB, de Silva V, Langford JC. A global geometric framework for nonlinear dimensionality reduction. Science 2000;290(5500):2319–23.
[3] van der Maaten L, Hinton G. Visualizing data using t-SNE. J Mach Learn Res 2008;9(11):2579–605.
[4] Mescheder L, Geiger A, Nowozin S. Which training methods for GANs do actually converge? In: Proceedings of the 35th International Conference on Machine Learning; 2018 Jul 10–15; Stockholmsmässan, Sweden; 2018. p. 3478–87.
[5] Villani C. Optimal transport: old and new. Berlin: Springer Science & Business Media; 2008.
[6] Gu DX, Luo F, Sun J, Yau ST. Variational principles for Minkowski type problems, discrete optimal transport, and discrete Monge–Ampère equations. Asian J Math 2016;20(2):383–98.
[7] Peyré G, Cuturi M. Computational optimal transport. Found Trends Mach Learn 2019;11(5–6):355–607.
[8] Solomon J. Optimal transport on discrete domains. 2018. arXiv:1801.07745.
[9] Cuturi M. Sinkhorn distances: lightspeed computation of optimal transportation distances. Adv Neural Inf Process Syst 2013;26:2292–300.
[10] Solomon J, de Goes F, Peyré G, Cuturi M, Butscher A, Nguyen A, et al. Convolutional Wasserstein distances: efficient optimal transportation on geometric domains. ACM Trans Graph 2015;34(4):66.
[11] Lei N, Su K, Cui L, Yau ST, Gu XD. A geometric view of optimal transportation and generative model. Comput Aided Geom Des 2019;68:1–21.
[12] Benamou JD, Brenier Y, Guittet K. The Monge–Kantorovitch mass transfer and its computational fluid mechanics formulation. Int J Numer Methods Fluids 2002;40(1–2):21–30.
[13] Benamou JD, Froese BD, Oberman AM. Numerical solution of the optimal transportation problem using the Monge–Ampère equation. J Comput Phys 2014;260:107–26.
[14] Papadakis N, Peyré G, Oudet E. Optimal transport with proximal splitting. SIAM J Imaging Sci 2014;7(1):212–38.
[15] Bengio Y, Mesnil G, Dauphin Y, Rifai S. Better mixing via deep representations. In: Proceedings of the 30th International Conference on Machine Learning; 2013 Jun 16–21; Atlanta, GA, USA; 2013. p. 552–60.
[16] Salakhutdinov R, Larochelle H. Efficient learning of deep Boltzmann machines. In: Proceedings of the 13th International Conference on Artificial Intelligence and Statistics; 2010 May 13–15; Chia Laguna Resort, Italy; 2010. p. 693–700.
[17] Kingma DP, Welling M. Auto-encoding variational Bayes. 2013. arXiv:1312.6114.
[18] Rezende DJ, Mohamed S, Wierstra D. Stochastic backpropagation and approximate inference in deep generative models. 2014. arXiv:1401.4082.
[19] Makhzani A, Shlens J, Jaitly N, Goodfellow I, Frey B. Adversarial autoencoders. 2015. arXiv:1511.05644.
[20] Tolstikhin I, Bousquet O, Gelly S, Schoelkopf B. Wasserstein auto-encoders. 2017. arXiv:1711.01558.
[21] He X, Yan S, Hu Y, Niyogi P, Zhang HJ. Face recognition using Laplacianfaces. IEEE Trans Pattern Anal Mach Intell 2005;27(3):328–40.
[22] Arandjelović O. Unfolding a face: from singular to manifold. In: Proceedings of the 9th Asian Conference on Computer Vision; 2009 Sep 23–27; Xi'an, China; 2009. p. 203–13.
[23] Salimans T, Karpathy A, Chen X, Kingma DP. PixelCNN++: improving the PixelCNN with discretized logistic mixture likelihood and other modifications. 2017. arXiv:1701.05517.
[24] Oord A, Kalchbrenner N, Kavukcuoglu K. Pixel recurrent neural networks. 2016. arXiv:1601.06759.
[25] Van Den Oord A, Dieleman S, Zen H, Simonyan K, Vinyals O, Graves A, et al. WaveNet: a generative model for raw audio. 2016. arXiv:1609.03499.
[26] Goodfellow I, Pouget-Abadie J, Mirza M, Xu B, Warde-Farley D, Ozair S, et al. Generative adversarial nets. 2014. arXiv:1406.2661.
[27] Gulrajani I, Ahmed F, Arjovsky M, Dumoulin V, Courville AC. Improved training of Wasserstein GANs. 2017. arXiv:1704.00028.
[28] Miyato T, Kataoka T, Koyama M, Yoshida Y. Spectral normalization for generative adversarial networks. 2018. arXiv:1802.05957.
[29] Zoran D, Weiss Y. From learning models of natural image patches to whole image restoration. In: Proceedings of the 2011 International Conference on Computer Vision; 2011 Jun 6–11; Barcelona, Spain; 2011. p. 479–86.
[30] Salimans T, Goodfellow I, Zaremba W, Cheung V, Radford A, Chen X. Improved techniques for training GANs. 2016. arXiv:1606.03498.
[31] Heusel M, Ramsauer H, Unterthiner T, Nessler B, Klambauer G, Hochreiter S. GANs trained by a two time-scale update rule converge to a Nash equilibrium. 2017. arXiv:1706.08500.
[32] Sajjadi MS, Bachem O, Lucic M, Bousquet O, Gelly S. Assessing generative models via precision and recall. 2018. arXiv:1806.00035.
[33] Lucic M, Kurach K, Michalski M, Gelly S, Bousquet O. Are GANs created equal? A large-scale study. 2018. arXiv:1711.10337.
[34] Bojanowski P, Joulin A, Lopez-Paz D, Szlam A. Optimizing the latent space of generative networks. 2017. arXiv:1707.05776.
[35] Li K, Malik J. Implicit maximum likelihood estimation. 2018. arXiv:1809.09087.
[36] Hoshen Y, Malik J. Non-adversarial image synthesis with generative latent nearest neighbors. 2018. arXiv:1812.08985.
[37] Dinh L, Krueger D, Bengio Y. NICE: non-linear independent components estimation. 2014. arXiv:1410.8516.
[38] Dinh L, Sohl-Dickstein J, Bengio S. Density estimation using real NVP. 2017. arXiv:1605.08803.
[39] Kingma DP, Dhariwal P. Glow: generative flow with invertible 1 × 1 convolutions. 2018. arXiv:1807.03039.
[40] LeCun Y, Chopra S, Hadsell R, Ranzato MA, Huang FJ. A tutorial on energy-based learning. In: Bakir G, Hofman T, Schölkopf B, Smola A, Taskar B, editors. Predicting structured data. Cambridge: The MIT Press; 2006.
[41] Dai J, Lu Y, Wu Y. Generative modeling of convolutional neural networks. In: Proceedings of the 3rd International Conference on Learning Representations; 2015 May 7–9; San Diego, CA, USA; 2015.
[42] Nijkamp E, Hill M, Zhu S, Wu Y. On learning non-convergent non-persistent short-run MCMC toward energy-based model. 2019. arXiv:1904.09770.
[43] Bonnotte N. From Knothe's rearrangement to Brenier's optimal transport map. SIAM J Math Anal 2013;45(1):64–87.
[44] Brenier Y. Polar factorization and monotone rearrangement of vector-valued functions. Commun Pure Appl Math 1991;44(4):375–417.
[45] Caffarelli L. Some regularity properties of solutions of Monge–Ampère equation. Commun Pure Appl Math 1991;44(8–9):965–9.
[46] Alexandrov AD. Convex polyhedra. New York: Springer; 2005.
[47] Guo X, Hong J, Lin T, Yang N. Relaxed Wasserstein with applications to GANs. 2017. arXiv:1705.07164.
[48] Lei N, Guo Y, An D, Qi X, Luo Z, Gu X, et al. Mode collapse and regularity of optimal transportation maps. 2019. arXiv:1902.02934.
[49] Kingma DP, Ba J. Adam: a method for stochastic optimization. 2014. arXiv:1412.6980.
[50] Srivastava A, Valkov L, Russell C, Gutmann MU, Sutton C. VeeGAN: reducing mode collapse in GANs using implicit variational learning. 2017. arXiv:1705.17761.
[51] Lin Z, Khetan A, Fanti G, Oh S. PacGAN: the power of two samples in generative adversarial networks. 2017. arXiv:1712.04086.
[52] Dumoulin V, Belghazi I, Poole B, Mastropietro O, Lamb A, Arjovsky M, et al. Adversarially learned inference. 2016. arXiv:1606.00704.
[53] LeCun Y, Cortes C, Burges CJC. The MNIST database of handwritten digits. Available from: http://yann.lecun.com/exdb/mnist/.
[54] Xiao H, Rasul F, Vollgraf R. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. 2017. arXiv:1708.07747.
[55] Krizhevsky A. Learning multiple layers of features from tiny images. Technical report. Toronto: University of Toronto; 2009.
[56] Zhang Z, Luo P, Loy CC, Tang X. From facial expression recognition to interpersonal relation prediction. Int J Comput Vis 2018;126(5):550–69.