TensorFlow / Python: Generación de Imágenes Con Redes Neuronales Generativas Adversariales (GAN)

Una Red Neuronal Generativa Adversarial (GAN) es un tipo de modelo de aprendizaje profundo utilizado en la generación de datos, como imágenes, música, texto, entre otros. Una GAN consta de dos redes neuronales, que son esencialmente dos modelos que se entrenan juntos pero tienen objetivos opuestos:

  1. Generador (Generator): La primera red es el generador, que tiene como tarea crear datos que sean indistinguibles de los datos reales. En el contexto de imágenes, el generador toma como entrada un vector de ruido aleatorio y genera una imagen que debería parecerse a las imágenes reales del conjunto de datos que se está utilizando. Su objetivo es producir datos falsos que sean lo más realistas posible.
  2. Discriminador (Discriminator): La segunda red es el discriminador, que actúa como un clasificador binario. Su tarea es determinar si una entrada es real (proveniente del conjunto de datos original) o falsa (generada por el generador). El objetivo del discriminador es separar correctamente los datos reales de los generados por el generador.

El proceso de entrenamiento de una GAN implica una competencia entre el generador y el discriminador. El generador intenta producir datos que sean cada vez más difíciles de distinguir de los datos reales, mientras que el discriminador intenta mejorar su capacidad para hacer esta distinción. Ambas redes se entrenan simultáneamente y, en última instancia, el generador debe llegar a ser lo suficientemente bueno como para generar datos que el discriminador no pueda distinguir de los datos reales.

Una vez que la GAN se entrena con éxito, el generador se puede utilizar para crear nuevos datos que sean similares a los datos de entrenamiento originales. Esto se ha utilizado en diversas aplicaciones, como generación de imágenes, mejora de resolución de imágenes, síntesis de voz, creación de texto y mucho más.

Código de Ejemplo:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

# Datos de entrenamiento (MNIST, números escritos a mano)
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 127.5 - 1.0
x_train = x_train.reshape(x_train.shape[0], 784)

# Tamaño del espacio latente (z)
latent_dim = 100

# Construir el generador
generator = Sequential()
generator.add(Dense(256, input_dim=latent_dim, activation='relu'))
generator.add(Dense(784, activation='tanh'))

# Construir el discriminador
discriminator = Sequential()
discriminator.add(Dense(256, input_dim=784, activation='relu'))
discriminator.add(Dense(1, activation='sigmoid'))

# Compilar el discriminador
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Congelar el discriminador durante el entrenamiento del generador
discriminator.trainable = False

# Construir la GAN
gan = Sequential([generator, discriminator])
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Función para entrenar la GAN
def train_gan(gan, generator, discriminator, x_train, latent_dim, n_epochs=10000, batch_size=64):
    for epoch in range(n_epochs):
        # Entrenar el discriminador
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]
        labels = np.ones((batch_size, 1))
        fake_images = generator.predict(np.random.normal(0, 1, (batch_size, latent_dim)))
        d_loss_real = discriminator.train_on_batch(real_images, labels)
        labels = np.zeros((batch_size, 1))
        d_loss_fake = discriminator.train_on_batch(fake_images, labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Entrenar la GAN (generador)
        labels = np.ones((batch_size, 1))
        g_loss = gan.train_on_batch(np.random.normal(0, 1, (batch_size, latent_dim)), labels)

        # Mostrar el progreso
        if epoch % 100 == 0:
            print(f"Epoch: {epoch}, D Loss: {d_loss[0]}, G Loss: {g_loss}")

        # Guardar imágenes generadas
        if epoch % 1000 == 0:
            generate_and_save_images(generator, epoch)

# Función para generar y guardar imágenes generadas
def generate_and_save_images(model, epoch, examples=10, dim=(1, 10), figsize=(10, 1)):
    noise = np.random.normal(0, 1, (examples, latent_dim))
    generated_images = model.predict(noise)
    generated_images = 0.5 * generated_images + 0.5
    plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i].reshape(28, 28), interpolation='nearest', cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'gan_generated_image_epoch_{epoch}.png')
    plt.show()

# Entrenar la GAN
train_gan(gan, generator, discriminator, x_train, latent_dim)

El código anterior entrena una GAN para generar imágenes de dígitos escritos a mano. La GAN consta de un generador y un discriminador, y se entrenan de manera adversarial. La función train_gan realiza el entrenamiento de la GAN, mientras que generate_and_save_images genera y guarda imágenes generadas a lo largo del proceso de entrenamiento.

Aquí una explicación más detallada:

  1. Datos de Entrenamiento: Los datos de entrenamiento provienen del conjunto de datos MNIST, que contiene imágenes en escala de grises de números escritos a mano (dígitos del 0 al 9). Estas imágenes se escalan para que los valores de los píxeles estén en el rango de -1 a 1 y se aplanan a un vector de 784 dimensiones.
  2. Generador (Generator): El generador es una red neuronal que toma un vector de ruido (el espacio latente) como entrada y genera imágenes falsas. En este caso, consta de dos capas densas: una capa de entrada con 256 unidades y activación ReLU, y una capa de salida con 784 unidades y activación tangente hiperbólica (tanh). El generador toma muestras aleatorias del espacio latente y produce imágenes de dígitos.
  3. Discriminador (Discriminator): El discriminador es otra red neuronal que actúa como un clasificador binario. Toma una imagen (real o falsa) como entrada y debe distinguir si la imagen es real o falsa. Está formado por dos capas densas: una capa de entrada con 256 unidades y activación ReLU, y una capa de salida con una unidad y activación sigmoide. La salida se interpreta como la probabilidad de que la imagen sea real.
  4. Compilación del Discriminador: El discriminador se compila con una función de pérdida de entropía cruzada binaria (binary_crossentropy) y el optimizador Adam. Se ajusta para distinguir entre imágenes reales y falsas.
  5. Congelación del Discriminador: Para el entrenamiento de la GAN, el discriminador se congela y no se entrena mientras se entrena el generador. Esto significa que el discriminador no se actualiza con el generador, lo que permite que el generador aprenda a generar imágenes que engañen al discriminador.
  6. Construcción de la GAN: La GAN se crea combinando el generador y el discriminador en una secuencia. La GAN se compila con la misma función de pérdida de entropía cruzada binaria y el optimizador Adam.
  7. Función de Entrenamiento: La función train_gan es donde ocurre el entrenamiento de la GAN. Durante cada iteración, se entrena primero el discriminador para distinguir entre imágenes reales y falsas. Luego, se entrena la GAN para engañar al discriminador y hacer que crea que las imágenes generadas son reales. Esto crea una competencia entre el generador y el discriminador. La función muestra el progreso y, cada cierto número de épocas, genera y guarda imágenes de dígitos generados.
  8. Generación y Guardado de Imágenes: La función generate_and_save_images genera imágenes de dígitos utilizando el generador y las guarda en un archivo de imagen. Estas imágenes se generan durante el entrenamiento para observar el progreso de la GAN.
  9. Entrenamiento de la GAN: Finalmente, se llama a la función train_gan para entrenar la GAN. Durante el entrenamiento, tanto el generador como el discriminador mejoran sus habilidades, y el generador aprende a generar imágenes realistas de dígitos escritos a mano.

Epoch: 0, D Loss: 0.9474876374006271, G Loss: 1.3809154033660889 1/1 [==============================] – 0s 32ms/step

Epoch: 1000, D Loss: 0.04573569446802139, G Loss: 5.032070159912109 1/1 [==============================] – 0s 14ms/step

Epoch: 2000, D Loss: 0.2332010306417942, G Loss: 4.198168754577637 1/1 [==============================] – 0s 16ms/step

Epoch: 3000, D Loss: 0.09189734607934952, G Loss: 8.080753326416016 1/1 [==============================] – 0s 13ms/step

Epoch: 4000, D Loss: 0.2200010120868683, G Loss: 4.917149543762207 1/1 [==============================] – 0s 18ms/step

Epoch: 5000, D Loss: 0.08954114839434624, G Loss: 5.541191101074219 1/1 [==============================] – 0s 13ms/step

Epoch: 6000, D Loss: 0.1730344444513321, G Loss: 3.601317882537842 1/1 [==============================] – 0s 13ms/step

Epoch: 7000, D Loss: 0.23358163982629776, G Loss: 3.3358240127563477 1/1 [==============================] – 0s 16ms/step

Epoch: 8000, D Loss: 0.19538309052586555, G Loss: 4.939027309417725 1/1 [==============================] – 0s 18ms/step

Epoch: 9000, D Loss: 0.1658102571964264, G Loss: 4.499600887298584 1/1 [==============================] – 0s 14ms/step

A medida que la GAN se entrena, el generador aprende a generar imágenes de dígitos cada vez más realistas. Este es un ejemplo muy sencillo, pero puedes ajustar los hiperparámetros y el número de épocas para obtener mejores resultados.

Deja un comentario