Introducción a las GAN con Python y TensorFlow

Los modelos generativos son una familia de arquitecturas de IA cuyo objetivo es crear muestras de datos desde cero. Logran esto al capturar las distribuciones de datos de t ...

Introducción

Los modelos generativos son una familia de arquitecturas de IA cuyo objetivo es crear muestras de datos desde cero. Lo logran capturando las distribuciones de datos del tipo de cosas que queremos generar.

Este tipo de modelos se están investigando mucho y hay una gran cantidad de publicidad a su alrededor. Solo mire el gráfico que muestra la cantidad de artículos publicados en el campo en los últimos años:

Gan papers

Desde 2014, cuando se publicó el primer artículo sobre redes adversarias generativas, los modelos generativos se están volviendo increíblemente poderosos y ahora podemos generar muestras de datos hiperrealistas para una amplia gama de distribuciones: imágenes, videos, música, escritos, etc.

Estos son algunos ejemplos de imágenes generadas por un GAN:

Una cara generada con GANs

Imágenes generadas por GAN

¿Qué son los modelos generativos?

El marco GAN

El marco más exitoso propuesto para los modelos generativos, al menos en los últimos años, toma el nombre de Generative Adversarial Networks (GANs).

En pocas palabras, una GAN se compone de dos modelos separados, representados por redes neuronales: un generador G y un discriminador D. El objetivo del discriminador es decir si una muestra de datos proviene de una distribución de datos real o si, en cambio, es generada por G.

El objetivo del generador es generar muestras de datos para engañar al discriminador.

El generador no es más que una red neuronal profunda. Toma como entrada un vector de ruido aleatorio (generalmente gaussiano o de una distribución uniforme) y genera una muestra de datos de la distribución que queremos capturar.

El discriminador es, nuevamente, solo una red neuronal. Su objetivo es, como su nombre lo indica, discriminar entre muestras reales y falsas. En consecuencia, su entrada es una muestra de datos, ya sea proveniente del generador o de la distribución de datos real.

La salida es un número simple, que representa la probabilidad de que la entrada sea real. Una alta probabilidad significa que el discriminador confía en que las muestras que está recibiendo son genuinas. Por el contrario, una probabilidad baja muestra una confianza alta en el hecho de que la muestra proviene de la red del generador:

El marco

Imagine un falsificador de arte que intenta crear obras de arte falsas y un crítico de arte que necesita distinguir entre pinturas adecuadas y falsas.

En este escenario, el crítico actúa como nuestro discriminador, y el falsificador es el generador, tomando retroalimentación del crítico para mejorar sus habilidades y hacer que su arte falsificado parezca más convincente:

Marco simplificado

Capacitación

Entrenar un GAN puede ser algo doloroso. La inestabilidad del entrenamiento siempre ha sido un problema, y ​​muchas investigaciones se han centrado en hacer que el entrenamiento sea más estable.

La función objetivo básica de un modelo Vanilla GAN es la siguiente:

Función de pérdida de GAN

Aquí, D se refiere a la red discriminadora, mientras que G obviamente se refiere al generador.

Como muestra la fórmula, el generador se optimiza para confundir al máximo al discriminador, al intentar que genere altas probabilidades de muestras de datos falsos.

Por el contrario, el discriminador trata de mejorar en la distinción de muestras provenientes de G de muestras provenientes de la distribución real.

El término * adversario * proviene exactamente de la forma en que se entrenan los GANS, enfrentando a las dos redes entre sí.

Una vez que hemos entrenado nuestro modelo, el discriminador ya no es necesario. Todo lo que tenemos que hacer es alimentar al generador con un vector de ruido aleatorio y, con suerte, obtendremos una muestra de datos artificiales y realistas como resultado.

Problemas de GAN

Entonces, ¿por qué las GAN son tan difíciles de entrenar? Como se indicó anteriormente, las GAN son muy difíciles de entrenar en su forma estándar. Veremos brevemente por qué este es el caso.

Equilibrio de Nash difícil de alcanzar

Dado que estas dos redes se envían información entre sí, podría representarse como un juego en el que uno adivina si la entrada es real o no.

El marco GAN es un juego no cooperativo, de dos jugadores, no convexo, con parámetros continuos de alta dimensión, en el que cada jugador quiere minimizar su función de costo. El óptimo de este proceso toma el nombre de Equilibrio de Nash - donde cada jugador no se desempeñará mejor cambiando una estrategia, dado que el otro jugador no cambia su estrategia.

Sin embargo, las GAN generalmente se entrenan usando técnicas de descenso de gradiente que están diseñadas para encontrar el valor bajo de una función de costo y no encontrar el Equilibrio de Nash de un juego.

Modo Colapso

La mayoría de las distribuciones de datos son multimodales. Tome el conjunto de datos MNIST: hay 10 "modos" de datos, que se refieren a los diferentes dígitos entre 0 y 9.

Un buen modelo generativo sería capaz de producir muestras con suficiente variabilidad, pudiendo así generar muestras de todas las diferentes clases.

Sin embargo, esto no siempre sucede.

Digamos que el generador se vuelve realmente bueno produciendo el dígito "3". Si las muestras producidas son lo suficientemente convincentes, es probable que el discriminador les asigne altas probabilidades.

Como resultado, el generador será empujado hacia la producción de muestras que provengan de ese modo específico, ignorando las otras clases la mayor parte del tiempo. Esencialmente, enviará spam al mismo número y con cada número que pase el discriminador, este comportamiento solo se aplicará más.

Un ejemplo de colapso de modo

Gradiente decreciente {#gradiente decreciente}

Muy similar al ejemplo anterior, el discriminador puede tener demasiado éxito en distinguir muestras de datos. Cuando eso es cierto, el gradiente del generador se desvanece, comienza a aprender cada vez menos y no logra converger.

Este desequilibrio, al igual que el anterior, se puede producir si entrenamos las redes por separado. La evolución de las redes neuronales puede ser bastante impredecible, lo que puede llevar a que una esté muy por delante de la otra. Si los entrenamos juntos, en su mayoría nos aseguramos de que estas cosas no sucedan.

Estado del arte

Sería imposible dar una visión completa de todas las mejoras y desarrollos que hicieron que las GAN fueran más potentes y estables en los últimos años.

Lo que haré en su lugar será compilar una lista de las arquitecturas y técnicas más exitosas, proporcionando enlaces a recursos relevantes para profundizar más.

DCGAN

Las GAN convolucionales profundas (DCGAN) introdujeron convoluciones en las redes generadoras y discriminadoras.

Sin embargo, no se trataba simplemente de agregar capas convolucionales al modelo, ya que el entrenamiento se volvió aún más inestable.

Se tuvieron que aplicar varios trucos para que los DCGAN fueran útiles:

  • La normalización de lotes se aplicó tanto al generador como a la red discriminadora.
  • El abandono se utiliza como técnica de regularización.
  • El generador necesitaba una forma de aumentar la muestra del vector de entrada aleatoria a una imagen de salida. Aquí se emplea la transposición de capas convolucionales.
  • LeakyRelu y [bronceado](http://mathworld.wolfram.com /HyperbolicTangent.html) las activaciones se utilizan en ambas redes

DCGAN

WGAN

GAN de Wasserstein (WGANs) tienen como objetivo mejorar la estabilidad del entrenamiento. Hay una gran cantidad de matemáticas detrás de este tipo de modelo. Una explicación más accesible se puede encontrar aquí.

La idea básica aquí fue proponer una nueva función de costo que tenga un gradiente más suave en todas partes.

La nueva función de costo usa una métrica llamada distancia de Wasserstein, que tiene un gradiente más suave en todas partes.

Como resultado, el discriminador, que ahora se llama crítico, genera valores de confianza que ya no deben interpretarse como una probabilidad. Los valores altos significan que el modelo confía en que la entrada es real.

Dos mejoras significativas para WGAN son:

  • No tiene signos de colapso de modo en los experimentos.
  • El generador aún puede aprender cuando el crítico se desempeña bien

de SAGAN

GAN de autoatención (SAGANs) introducen un mecanismo de atención al framework GAN.

Los mecanismos de atención permiten utilizar localmente la información global. Lo que esto significa es que podemos capturar el significado de diferentes partes de una imagen y usar esa información para producir mejores muestras.

Esto proviene de la observación de que las convoluciones son bastante malas para capturar dependencias a largo plazo en muestras de entrada, ya que la convolución es una operación local cuyo campo receptivo depende del tamaño espacial del kernel.

Esto significa que, por ejemplo, no es posible que una salida en la posición superior izquierda de una imagen tenga alguna relación con la salida en la parte inferior derecha.

Una forma de solucionar este problema sería utilizar kernels de mayor tamaño, para poder capturar más información. Sin embargo, esto haría que el modelo fuera computacionalmente ineficiente y muy lento para entrenar.

La autoatención resuelve este problema, proporcionando una forma eficiente de capturar información global y usarla localmente cuando pueda resultar útil.

BigGAN

BigGAN son, en el momento de escribir este artículo, considerados más o menos avanzados, en lo que respecta a la calidad de las muestras generadas.

Lo que hicieron los investigadores aquí fue reunir todo lo que había estado funcionando hasta ese momento y luego ampliarlo masivamente. Su modelo base era de hecho un SAGAN, al que le añadieron algunos trucos para mejorar la estabilidad.

Demostraron que las GAN se benefician drásticamente del escalado, incluso cuando no se introducen más mejoras funcionales en el modelo, como se cita en el documento original:

Hemos demostrado que las redes adversarias generativas entrenadas para modelar imágenes naturales de múltiples categorías se benefician mucho del escalado, tanto en términos de fidelidad como de variedad de las muestras generadas. Como resultado, nuestros modelos establecen un nuevo nivel de rendimiento entre los modelos ImageNet GAN, mejorando el estado del arte por un amplio margen.

Una GAN simple en Python

Implementación de código {#implementación de código}

Dicho todo esto, avancemos e implementemos un GAN simple que genere dígitos del 0 al 9, un ejemplo bastante clásico:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

# Sample z from uniform distribution
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

Ahora podemos definir el marcador de posición para nuestras muestras de entrada y vectores de ruido:

1
2
3
4
5
# Input image, for discriminator model.
X = tf.placeholder(tf.float32, shape=[None, 784])

# Input noise for generator.
Z = tf.placeholder(tf.float32, shape=[None, 100])

Ahora, definimos nuestras redes generadoras y discriminadoras. Son perceptrones simples con una sola capa oculta.

Usamos activaciones relu en las neuronas de la capa oculta, y sigmoides para las capas de salida.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def generator(z):
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(z, 128, activation=tf.nn.relu)
        x = tf.layers.dense(z, 784)
        x = tf.nn.sigmoid(x)
    return x

def discriminator(x):
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(x, 128, activation=tf.nn.relu)
        x = tf.layers.dense(x, 1)
        x = tf.nn.sigmoid(x)
    return x

Ahora podemos definir nuestros modelos, funciones de pérdida y optimizadores:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
# Generator model
G_sample = generator(Z)

# Discriminator models
D_real = discriminator(X)
D_fake = discriminator(G_sample)


# Loss function
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

# Select parameters
disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("disc")]
gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("gen")]

# Optimizers
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=disc_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=gen_vars)

Finalmente, podemos escribir la rutina de entrenamiento. En cada iteración, realizamos un paso de optimización para el discriminador y otro para el generador.

Cada 100 iteraciones guardamos algunas muestras generadas para que podamos ver nuestro progreso.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Batch size
mb_size = 128

# Dimension of input noise
Z_dim = 100

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

if not os.path.exists('out2/'):
    os.makedirs('out2/')

i = 0

for it in range(1000000):

    # Save generated images every 1000 iterations.
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        plt.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)


    # Get next batch of images. Each batch has mb_size samples.
    X_mb, _ = mnist.train.next_batch(mb_size)


    # Run disciminator solver
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})

    # Run generator solver
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

    # Print loss
    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))

Resultados y posibles mejoras

Durante las primeras iteraciones, todo lo que vemos es ruido aleatorio:

Primeras iteraciones

Aquí, las redes no aprendieron nada todavía. Aunque, después de solo un par de minutos, ¡ya podemos ver cómo nuestros dígitos están tomando forma!

iteración 68000

Recursos

Si desea jugar con el código, está disponible en GitHub!