Author: András Béres
Date created: 2021/10/28
Last modified: 2021/10/28
Description: Generating images from limited data using the Caltech Birds dataset.
Generative Adversarial Networks (GANs) are a popular class of generative deep learning models, commonly used for image generation. They consist of a pair of dueling neural networks, called the discriminator and the generator. The discriminator's task is to distinguish real images from generated (fake) ones, while the generator network tries to fool the discriminator by generating more and more realistic images. If the discriminator is however too easy or too hard to fool, it might fail to provide a useful learning signal for the generator; therefore training GANs is usually considered a difficult task.
Data augmentation, a popular technique in deep learning, is the process of randomly applying semantics-preserving transformations to the input data to generate multiple realistic versions of it, thereby effectively multiplying the amount of training data available. The simplest example is left-right flipping an image, which preserves its contents while generating a second unique training sample. Data augmentation is commonly used in supervised learning to prevent overfitting and enhance generalization.
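As a toy illustration (not part of the training pipeline below), the left-right flip example is a one-liner in TensorFlow:

```python
import tensorflow as tf

image = tf.random.uniform(shape=(64, 64, 3))  # a placeholder image
flipped = tf.image.flip_left_right(image)  # same content, a new training sample
```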
The authors of StyleGAN2-ADA show that discriminator overfitting can be an issue in GANs, especially when only a small amount of training data is available. They propose Adaptive Discriminator Augmentation to mitigate this issue.
Applying data augmentation to GANs however is not straightforward. Since the generator is updated using the discriminator's gradients, if the generated images are augmented, the augmentation pipeline has to be differentiable and also has to be GPU-compatible for computational efficiency. Luckily, the Keras image augmentation layers fulfill both these requirements, and are therefore very well suited for this task.
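As a quick sanity check, one can verify that gradients indeed flow through such a pipeline; the layer choice and tensor shapes below are arbitrary, for illustration only:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# a toy augmentation pipeline and a random image batch
augmenter = keras.Sequential([layers.RandomTranslation(0.1, 0.1)])
images = tf.random.uniform(shape=(4, 64, 64, 3))

with tf.GradientTape() as tape:
    tape.watch(images)  # images is a plain tensor, so it must be watched
    loss = tf.reduce_sum(augmenter(images, training=True))

# a non-None gradient confirms the pipeline is differentiable end-to-end
print(tape.gradient(loss, images) is not None)
```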
A possible difficulty when using data augmentation in generative models is the issue of "leaky augmentations" (section 2.2), namely when the model generates images that are already augmented. This would mean that it was not able to separate the augmentation from the underlying data distribution, which can be caused by using non-invertible data transformations. For example, if either 0, 90, 180 or 270 degree rotations are performed with equal probability, the original orientation of the images is impossible to infer, and this information is destroyed.
A simple trick to make data augmentations invertible is to only apply them with some probability. That way the original version of the images will be more common, and the data distribution can be inferred. By properly choosing this probability, one can effectively regularize the discriminator without making the augmentations leaky.
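A minimal sketch of this trick, mirroring the logic used later in the AdaptiveAugmenter class (the tensors here are hypothetical):

```python
import tensorflow as tf

def apply_with_probability(images, augmented_images, p):
    # per-image coin flips: True means "use the augmented version"
    batch_size = tf.shape(images)[0]
    use_augmented = tf.random.uniform(shape=(batch_size, 1, 1, 1)) < p
    return tf.where(use_augmented, augmented_images, images)
```

With p below 1, the un-augmented images remain the most common version of each sample, which is what keeps the original data distribution recoverable.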
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
# data
num_epochs = 10 # train for 400 epochs for good results
image_size = 64
# resolution of Kernel Inception Distance measurement, see related section
kid_image_size = 75
padding = 0.25
dataset_name = "caltech_birds2011"
# adaptive discriminator augmentation
max_translation = 0.125
max_rotation = 0.125
max_zoom = 0.25
target_accuracy = 0.85
integration_steps = 1000
# architecture
noise_size = 64
depth = 4
width = 128
leaky_relu_slope = 0.2
dropout_rate = 0.4
# optimization
batch_size = 128
learning_rate = 2e-4
beta_1 = 0.5 # not using the default value of 0.9 is important
ema = 0.99
In this example, we will use the Caltech Birds (2011) dataset for generating images of birds. It is a diverse natural dataset containing fewer than 6000 images for training. When working with such a small amount of data, one has to take extra care to keep data quality as high as possible. To this end, we use the provided bounding boxes of the birds to cut them out with square crops, while preserving their aspect ratios when possible.
def round_to_int(float_value):
return tf.cast(tf.math.round(float_value), dtype=tf.int32)
def preprocess_image(data):
# unnormalize bounding box coordinates
height = tf.cast(tf.shape(data["image"])[0], dtype=tf.float32)
width = tf.cast(tf.shape(data["image"])[1], dtype=tf.float32)
bounding_box = data["bbox"] * tf.stack([height, width, height, width])
# calculate center and length of longer side, add padding
target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
target_size = tf.maximum(
(1.0 + padding) * (bounding_box[2] - bounding_box[0]),
(1.0 + padding) * (bounding_box[3] - bounding_box[1]),
)
# modify crop size to fit into image
target_height = tf.reduce_min(
[target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
)
target_width = tf.reduce_min(
[target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
)
# crop image
image = tf.image.crop_to_bounding_box(
data["image"],
offset_height=round_to_int(target_center_y - 0.5 * target_height),
offset_width=round_to_int(target_center_x - 0.5 * target_width),
target_height=round_to_int(target_height),
target_width=round_to_int(target_width),
)
# resize and clip
# for image downsampling, area interpolation is the preferred method
image = tf.image.resize(
image, size=[image_size, image_size], method=tf.image.ResizeMethod.AREA
)
return tf.clip_by_value(image / 255.0, 0.0, 1.0)
def prepare_dataset(split):
# the validation dataset is shuffled as well, because data order matters
# for the KID calculation
return (
tfds.load(dataset_name, split=split, shuffle_files=True)
.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
.cache()
.shuffle(10 * batch_size)
.batch(batch_size, drop_remainder=True)
.prefetch(buffer_size=tf.data.AUTOTUNE)
)
train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")
After preprocessing, the training images look like the following:
Kernel Inception Distance (KID) was proposed as a replacement for the popular Fréchet Inception Distance (FID) metric for measuring image generation quality. Both metrics measure the difference between the generated and training distributions in the representation space of an InceptionV3 network pretrained on ImageNet.
According to the paper, KID was proposed because FID has no unbiased estimator: its expected value is higher when it is measured on fewer images. KID is more suitable for small datasets because its expected value does not depend on the number of samples it is measured on. In my experience it is also computationally lighter, numerically more stable, and simpler to implement because it can be estimated in a per-batch manner.
In this example, the images are evaluated at the minimal possible resolution of the Inception network (75x75 instead of 299x299), and the metric is only measured on the validation set for computational efficiency.
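For reference, the quantity estimated per batch in the implementation below is the unbiased squared maximum mean discrepancy (MMD) between the real feature batch $x$ and the generated feature batch $y$ (both of size $n$), using the polynomial kernel $k(a, b) = \left(a^\top b / d + 1\right)^3$, where $d$ is the feature dimension:

$$\widehat{\mathrm{KID}} = \frac{1}{n(n-1)} \sum_{i \neq j} k(x_i, x_j) + \frac{1}{n(n-1)} \sum_{i \neq j} k(y_i, y_j) - \frac{2}{n^2} \sum_{i,j} k(x_i, y_j)$$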
class KID(keras.metrics.Metric):
def __init__(self, name="kid", **kwargs):
super().__init__(name=name, **kwargs)
# KID is estimated per batch and is averaged across batches
self.kid_tracker = keras.metrics.Mean()
# a pretrained InceptionV3 is used without its classification layer
# transform the pixel values to the 0-255 range, then use the same
# preprocessing as during pretraining
self.encoder = keras.Sequential(
[
layers.InputLayer(input_shape=(image_size, image_size, 3)),
layers.Rescaling(255.0),
layers.Resizing(height=kid_image_size, width=kid_image_size),
layers.Lambda(keras.applications.inception_v3.preprocess_input),
keras.applications.InceptionV3(
include_top=False,
input_shape=(kid_image_size, kid_image_size, 3),
weights="imagenet",
),
layers.GlobalAveragePooling2D(),
],
name="inception_encoder",
)
def polynomial_kernel(self, features_1, features_2):
feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0
def update_state(self, real_images, generated_images, sample_weight=None):
real_features = self.encoder(real_images, training=False)
generated_features = self.encoder(generated_images, training=False)
# compute polynomial kernels using the two sets of features
kernel_real = self.polynomial_kernel(real_features, real_features)
kernel_generated = self.polynomial_kernel(
generated_features, generated_features
)
kernel_cross = self.polynomial_kernel(real_features, generated_features)
# estimate the squared maximum mean discrepancy using the average kernel values
batch_size = tf.shape(real_features)[0]
batch_size_f = tf.cast(batch_size, dtype=tf.float32)
mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
batch_size_f * (batch_size_f - 1.0)
)
mean_kernel_generated = tf.reduce_sum(
kernel_generated * (1.0 - tf.eye(batch_size))
) / (batch_size_f * (batch_size_f - 1.0))
mean_kernel_cross = tf.reduce_mean(kernel_cross)
kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
# update the average KID estimate
self.kid_tracker.update_state(kid)
def result(self):
return self.kid_tracker.result()
def reset_state(self):
self.kid_tracker.reset_state()
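As a quick usage sketch (the two batches below are hypothetical random tensors, for illustration only), the metric accumulates a per-batch estimate and reports the running mean:

```python
# hypothetical image batches with values in the 0-1 range
real_batch = tf.random.uniform(shape=(batch_size, image_size, image_size, 3))
generated_batch = tf.random.uniform(shape=(batch_size, image_size, image_size, 3))

kid = KID()
kid.update_state(real_batch, generated_batch)
print("KID estimate:", kid.result().numpy())
```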
The authors of StyleGAN2-ADA propose to change the augmentation probability adaptively during training. Though it is explained differently in the paper, they use integral control on the augmentation probability to keep the discriminator's accuracy on real images close to a target value. Note that their controlled variable is actually the average sign of the discriminator logits (r_t in the paper), which corresponds to 2 * accuracy - 1.
This method requires two hyperparameters:

1. target_accuracy: the target value for the discriminator's accuracy on real images. I recommend selecting its value from the 80-90% range.
2. integration_steps: the number of update steps required for an accuracy error of 100% to transform into an augmentation probability increase of 100%. To give an intuition, this defines how slowly the augmentation probability is changed. I recommend setting this to a relatively high value (1000 in this case), so that the augmentation strength is only adjusted slowly.

The main motivation for this procedure is that the optimal value of the target accuracy is similar across different dataset sizes (see figures 4 and 5 in the paper), so it does not have to be re-tuned, because the process automatically applies stronger data augmentation when it is needed. A standalone numeric sketch of the update follows below.
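To make the dynamics concrete, here is the resulting integral controller in isolation (the accuracy readings are made up; the same update is implemented in AdaptiveAugmenter.update() below):

```python
# hypothetical accuracy readings from three consecutive training steps
probability = 0.0
for accuracy in [0.95, 0.90, 0.70]:
    error = accuracy - target_accuracy  # positive -> discriminator too strong
    # integrate the error; each step moves the probability by at most
    # 1 / integration_steps = 0.001, so the augmentation strength changes slowly
    probability = min(max(probability + error / integration_steps, 0.0), 1.0)
```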
# "hard sigmoid", useful for binary accuracy calculation from logits
def step(values):
# negative values -> 0.0, positive values -> 1.0
return 0.5 * (1.0 + tf.sign(values))
# augments images with a probability that is dynamically updated during training
class AdaptiveAugmenter(keras.Model):
def __init__(self):
super().__init__()
# stores the current probability of an image being augmented
self.probability = tf.Variable(0.0)
# the corresponding augmentation names from the paper are shown above each layer
# the authors show (see figure 4), that the blitting and geometric augmentations
# are the most helpful in the low-data regime
self.augmenter = keras.Sequential(
[
layers.InputLayer(input_shape=(image_size, image_size, 3)),
# blitting/x-flip:
layers.RandomFlip("horizontal"),
# blitting/integer translation:
layers.RandomTranslation(
height_factor=max_translation,
width_factor=max_translation,
interpolation="nearest",
),
# geometric/rotation:
layers.RandomRotation(factor=max_rotation),
# geometric/isotropic and anisotropic scaling:
layers.RandomZoom(
height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
),
],
name="adaptive_augmenter",
)
def call(self, images, training):
if training:
augmented_images = self.augmenter(images, training)
# during training either the original or the augmented images are selected
# based on self.probability
augmentation_values = tf.random.uniform(
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
)
augmentation_bools = tf.math.less(augmentation_values, self.probability)
images = tf.where(augmentation_bools, augmented_images, images)
return images
def update(self, real_logits):
current_accuracy = tf.reduce_mean(step(real_logits))
# the augmentation probability is updated based on the discriminator's
# accuracy on real images
accuracy_error = current_accuracy - target_accuracy
self.probability.assign(
tf.clip_by_value(
self.probability + accuracy_error / integration_steps, 0.0, 1.0
)
)
Here we specify the architecture of the two networks:
GANs tend to be sensitive to the network architecture. I implemented a DCGAN architecture in this example, because it is relatively stable during training while being simple to implement. We use a constant number of filters throughout the network, use a sigmoid instead of tanh in the last layer of the generator, and use default initialization instead of random normal as further simplifications.
As a good practice, we disable the learnable scale parameter in the batch normalization layers, for two reasons. On the one hand, the following relu + convolutional layers make it redundant (as noted in the documentation). On the other hand, it should be disabled based on theory when using spectral normalization (section 4.1), which is not used here, but is common in GANs. We also disable the bias in the fully connected and convolutional layers, because the following batch normalization makes it redundant.
# DCGAN generator
def get_generator():
noise_input = keras.Input(shape=(noise_size,))
x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
x = layers.BatchNormalization(scale=False)(x)
x = layers.ReLU()(x)
x = layers.Reshape(target_shape=(4, 4, width))(x)
for _ in range(depth - 1):
x = layers.Conv2DTranspose(
width, kernel_size=4, strides=2, padding="same", use_bias=False,
)(x)
x = layers.BatchNormalization(scale=False)(x)
x = layers.ReLU()(x)
image_output = layers.Conv2DTranspose(
3, kernel_size=4, strides=2, padding="same", activation="sigmoid",
)(x)
return keras.Model(noise_input, image_output, name="generator")
# DCGAN discriminator
def get_discriminator():
image_input = keras.Input(shape=(image_size, image_size, 3))
x = image_input
for _ in range(depth):
x = layers.Conv2D(
width, kernel_size=4, strides=2, padding="same", use_bias=False,
)(x)
x = layers.BatchNormalization(scale=False)(x)
x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
x = layers.Flatten()(x)
x = layers.Dropout(dropout_rate)(x)
output_score = layers.Dense(1)(x)
return keras.Model(image_input, output_score, name="discriminator")
class GAN_ADA(keras.Model):
def __init__(self):
super().__init__()
self.augmenter = AdaptiveAugmenter()
self.generator = get_generator()
self.ema_generator = keras.models.clone_model(self.generator)
self.discriminator = get_discriminator()
self.generator.summary()
self.discriminator.summary()
def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
super().compile(**kwargs)
# separate optimizers for the two networks
self.generator_optimizer = generator_optimizer
self.discriminator_optimizer = discriminator_optimizer
self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
self.kid = KID()
@property
def metrics(self):
return [
self.generator_loss_tracker,
self.discriminator_loss_tracker,
self.real_accuracy,
self.generated_accuracy,
self.augmentation_probability_tracker,
self.kid,
]
def generate(self, batch_size, training):
latent_samples = tf.random.normal(shape=(batch_size, noise_size))
# use ema_generator during inference
if training:
generated_images = self.generator(latent_samples, training)
else:
generated_images = self.ema_generator(latent_samples, training)
return generated_images
def adversarial_loss(self, real_logits, generated_logits):
# this is usually called the non-saturating GAN loss
real_labels = tf.ones(shape=(batch_size, 1))
generated_labels = tf.zeros(shape=(batch_size, 1))
# the generator tries to produce images that the discriminator considers as real
generator_loss = keras.losses.binary_crossentropy(
real_labels, generated_logits, from_logits=True
)
# the discriminator tries to determine if images are real or generated
discriminator_loss = keras.losses.binary_crossentropy(
tf.concat([real_labels, generated_labels], axis=0),
tf.concat([real_logits, generated_logits], axis=0),
from_logits=True,
)
return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)
def train_step(self, real_images):
real_images = self.augmenter(real_images, training=True)
# use persistent gradient tape because gradients will be calculated twice
with tf.GradientTape(persistent=True) as tape:
generated_images = self.generate(batch_size, training=True)
# gradient is calculated through the image augmentation
generated_images = self.augmenter(generated_images, training=True)
# separate forward passes for the real and generated images, meaning
# that batch normalization is applied separately
real_logits = self.discriminator(real_images, training=True)
generated_logits = self.discriminator(generated_images, training=True)
generator_loss, discriminator_loss = self.adversarial_loss(
real_logits, generated_logits
)
# calculate gradients and update weights
generator_gradients = tape.gradient(
generator_loss, self.generator.trainable_weights
)
discriminator_gradients = tape.gradient(
discriminator_loss, self.discriminator.trainable_weights
)
self.generator_optimizer.apply_gradients(
zip(generator_gradients, self.generator.trainable_weights)
)
self.discriminator_optimizer.apply_gradients(
zip(discriminator_gradients, self.discriminator.trainable_weights)
)
# update the augmentation probability based on the discriminator's performance
self.augmenter.update(real_logits)
self.generator_loss_tracker.update_state(generator_loss)
self.discriminator_loss_tracker.update_state(discriminator_loss)
self.real_accuracy.update_state(1.0, step(real_logits))
self.generated_accuracy.update_state(0.0, step(generated_logits))
self.augmentation_probability_tracker.update_state(self.augmenter.probability)
# track the exponential moving average of the generator's weights to decrease
# variance in the generation quality
for weight, ema_weight in zip(
self.generator.weights, self.ema_generator.weights
):
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
# KID is not measured during the training phase for computational efficiency
return {m.name: m.result() for m in self.metrics[:-1]}
def test_step(self, real_images):
generated_images = self.generate(batch_size, training=False)
self.kid.update_state(real_images, generated_images)
# only KID is measured during the evaluation phase for computational efficiency
return {self.kid.name: self.kid.result()}
def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
# plot random generated images for visual evaluation of generation quality
if epoch is None or (epoch + 1) % interval == 0:
num_images = num_rows * num_cols
generated_images = self.generate(num_images, training=False)
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
plt.imshow(generated_images[index])
plt.axis("off")
plt.tight_layout()
plt.show()
plt.close()
One should see from the metrics during training that if the real accuracy (the discriminator's accuracy on real images) is below the target accuracy, the augmentation probability is increased, and vice versa. In my experience, during a healthy GAN training, the discriminator accuracy should stay in the 80-95% range. Below that, the discriminator is too weak; above that, it is too strong.
Note that we track the exponential moving average of the generator's weights, and use that for image generation and KID evaluation.
# create and compile the model
model = GAN_ADA()
model.compile(
generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
)
# save the best model based on the validation KID metric
checkpoint_path = "gan_model"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
monitor="val_kid",
mode="min",
save_best_only=True,
)
# run training and plot generated images periodically
model.fit(
train_dataset,
epochs=num_epochs,
validation_data=val_dataset,
callbacks=[
keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
checkpoint_callback,
],
)
Model: "generator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 64)] 0
_________________________________________________________________
dense (Dense) (None, 2048) 131072
_________________________________________________________________
batch_normalization (BatchNo (None, 2048) 6144
_________________________________________________________________
re_lu (ReLU) (None, 2048) 0
_________________________________________________________________
reshape (Reshape) (None, 4, 4, 128) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 8, 8, 128) 262144
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 128) 384
_________________________________________________________________
re_lu_1 (ReLU) (None, 8, 8, 128) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 128) 262144
_________________________________________________________________
batch_normalization_2 (Batch (None, 16, 16, 128) 384
_________________________________________________________________
re_lu_2 (ReLU) (None, 16, 16, 128) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 128) 262144
_________________________________________________________________
batch_normalization_3 (Batch (None, 32, 32, 128) 384
_________________________________________________________________
re_lu_3 (ReLU) (None, 32, 32, 128) 0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 64, 64, 3) 6147
=================================================================
Total params: 930,947
Trainable params: 926,083
Non-trainable params: 4,864
_________________________________________________________________
Model: "discriminator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 64, 64, 3)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 32, 32, 128) 6144
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 128) 384
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 32, 32, 128) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 16, 16, 128) 262144
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 128) 384
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 128) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 8, 8, 128) 262144
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 128) 384
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 8, 8, 128) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 4, 4, 128) 262144
_________________________________________________________________
batch_normalization_7 (Batch (None, 4, 4, 128) 384
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 4, 4, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 2048) 0
_________________________________________________________________
dropout (Dropout) (None, 2048) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 2049
=================================================================
Total params: 796,161
Trainable params: 795,137
Non-trainable params: 1,024
_________________________________________________________________
Epoch 1/10
46/46 [==============================] - 36s 307ms/step - g_loss: 3.3293 - d_loss: 0.1576 - real_acc: 0.9387 - gen_acc: 0.9579 - aug_p: 0.0020 - val_kid: 9.0999
Epoch 2/10
46/46 [==============================] - 10s 215ms/step - g_loss: 4.9824 - d_loss: 0.0912 - real_acc: 0.9704 - gen_acc: 0.9798 - aug_p: 0.0077 - val_kid: 8.3523
Epoch 3/10
46/46 [==============================] - 10s 218ms/step - g_loss: 5.0587 - d_loss: 0.1248 - real_acc: 0.9530 - gen_acc: 0.9625 - aug_p: 0.0131 - val_kid: 6.8116
Epoch 4/10
46/46 [==============================] - 10s 221ms/step - g_loss: 4.2580 - d_loss: 0.1002 - real_acc: 0.9686 - gen_acc: 0.9740 - aug_p: 0.0179 - val_kid: 5.2327
Epoch 5/10
46/46 [==============================] - 10s 225ms/step - g_loss: 4.6022 - d_loss: 0.0847 - real_acc: 0.9655 - gen_acc: 0.9852 - aug_p: 0.0234 - val_kid: 3.9004
Epoch 6/10
46/46 [==============================] - 10s 224ms/step - g_loss: 4.9362 - d_loss: 0.0671 - real_acc: 0.9791 - gen_acc: 0.9895 - aug_p: 0.0291 - val_kid: 6.6020
Epoch 7/10
46/46 [==============================] - 10s 222ms/step - g_loss: 4.4272 - d_loss: 0.1184 - real_acc: 0.9570 - gen_acc: 0.9657 - aug_p: 0.0345 - val_kid: 3.3644
Epoch 8/10
46/46 [==============================] - 10s 220ms/step - g_loss: 4.5060 - d_loss: 0.1635 - real_acc: 0.9421 - gen_acc: 0.9594 - aug_p: 0.0392 - val_kid: 3.1381
Epoch 9/10
46/46 [==============================] - 10s 219ms/step - g_loss: 3.8264 - d_loss: 0.1667 - real_acc: 0.9383 - gen_acc: 0.9484 - aug_p: 0.0433 - val_kid: 2.9423
Epoch 10/10
46/46 [==============================] - 10s 219ms/step - g_loss: 3.4063 - d_loss: 0.1757 - real_acc: 0.9314 - gen_acc: 0.9475 - aug_p: 0.0473 - val_kid: 2.9112
<keras.callbacks.History at 0x7fefcc2cb9d0>
# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()
By running the training for 400 epochs (which takes 2-3 hours in a Colab notebook), one can get high quality image generations using this code example.
The evolution of a random batch of images over a 400 epoch training (ema=0.999 for animation smoothness):
Latent-space interpolation between a batch of selected images:
I also recommend trying out training on other datasets, such as CelebA. In my experience, good results can be achieved without changing any hyperparameters (though discriminator augmentation might not be necessary).
My goal with this example was to find a good tradeoff between ease of implementation and generation quality for GANs. During preparation, I have run numerous ablations using this repository.
In this section I list the lessons learned and my recommendations in my subjective order of importance.
I recommend checking out the DCGAN paper, this NeurIPS talk, and this large scale GAN study for others' takes on this subject.
Other GAN-related Keras code examples:
Modern GAN architecture-lines:
Concurrent papers on discriminator data augmentation: 1, 2, 3
Recent literature overview on GANs: talk