Aprendizaje automático: el sobreajuste es su amigo, no su enemigo

En este artículo, se presenta el Argumento del Sobreajuste Amistoso, así como también cuando no se sostiene. ¿El sobreajuste en el aprendizaje automático y la inteligencia artificial es realmente tan malo como la gente lo pinta?

Permítanme prologar el título potencialmente provocativo con:

Es verdad, nadie quiere modelos finales con sobreajuste, al igual que nadie quiere modelos finales con ajuste insuficiente.

Los modelos sobreajustados funcionan muy bien con los datos de entrenamiento, pero no se pueden generalizar bien a instancias nuevas. Lo que termina es un modelo que se acerca a un modelo completamente codificado y adaptado a un conjunto de datos específico.

Los modelos inadecuados no pueden generalizar a nuevos datos, pero tampoco pueden modelar el conjunto de entrenamiento original.

El modelo correcto es aquel que ajusta los datos de tal manera que se desempeña bien en la predicción de valores en el conjunto de entrenamiento, validación y prueba, así como en nuevas instancias.

Sobreajuste frente a científicos de datos

Se destaca la lucha contra el sobreajuste porque es más ilusorio y más tentador para un novato crear modelos sobreajustados cuando comienza su viaje de aprendizaje automático. A lo largo de libros, publicaciones de blog y cursos, se presenta un escenario común:

"¡Este modelo tiene una tasa de precisión del 100 %! ¡Es perfecto! O no. En realidad, simplemente sobreajusta el conjunto de datos y, cuando lo prueba en nuevas instancias, funciona con **solo X% **, que es igual a adivinar al azar."

Después de estas secciones, se dedican capítulos completos de libros y cursos a combatir el sobreajuste y cómo evitarlo. La palabra en sí se estigmatizó como algo generalmente malo. Y aquí es donde surge la concepción general:

"Debo evitar a toda costa el sobreajuste."

Se le da mucha más atención que a la falta de ajuste, que es igual de "malo". Vale la pena señalar que "malo" es un término arbitrario, y ninguna de estas condiciones es inherentemente "buena" o "mala". Algunos pueden afirmar que los modelos sobreajustados son técnicamente más útiles, porque al menos funcionan bien con algunos datos, mientras que los modelos desadaptados funcionan bien con sin datos, pero la ilusión del éxito es un buen candidato para compensar este beneficio.

Como referencia, consultemos Google Trends y Google Ngram Viewer. Google Trends muestra las tendencias de los datos de búsqueda, mientras que Google Ngram Viewer cuenta la cantidad de ocurrencias de n-gramas (secuencias de n elementos, como palabras) en la literatura, analizando una gran cantidad de libros a lo largo de los años:

overfitting vs underfitting search trends and ngram viewer

Todo el mundo habla sobre el sobreajuste y, sobre todo, en el contexto de evitarlo, lo que a menudo lleva a la gente a la idea general de que es inherentemente algo malo.

Esto es verdadero, hasta cierto punto*. Sí, no desea que el modelo final se ajuste demasiado, de lo contrario, es prácticamente inútil. Pero no llega al modelo final de inmediato: lo modifica varias veces, con varios hiperparámetros. Durante este proceso es donde no debería importarle que ocurra un sobreajuste; sin embargo, es una buena señal, no es un buen resultado.

Cómo el sobreajuste no es tan malo como parece {#cómo el sobreajuste no es tan malo como parece}

Un modelo y una arquitectura que tiene la capacidad de sobreajustarse, es más probable que tenga la capacidad de generalizar bien a nuevas instancias, si lo simplifica (y/o modifica los datos).

  • [A veces, no se trata solo del modelo, como veremos un poco más adelante.]{.small}

Si un modelo puede sobreajustarse, tiene suficiente capacidad entrópica para extraer características (de una manera significativa y no significativa) de los datos. A partir de ahí, es que el modelo tiene más de la capacidad entrópica requerida (complejidad/potencia) o que los datos en sí no son suficientes (caso muy común).

La declaración inversa también puede ser cierta, pero más raramente. Si un modelo o una arquitectura determinados no se ajustan bien, puede intentar ajustar el modelo para ver si recoge ciertas características, pero el tipo de modelo podría ser simplemente incorrecto para la tarea y no podrá ajustar los datos con él. no importa lo que hagas. Algunos modelos simplemente se atascan en cierto nivel de precisión, ya que simplemente no pueden extraer suficientes características para distinguir entre ciertas clases o predecir valores.

En cocinar, se puede crear una analogía inversa. Es mejor desalar el guiso desde el principio, ya que siempre se puede agregar sal más tarde al gusto, pero es difícil quitarlo una vez que ya está puesto.

En Aprendizaje automático, es todo lo contrario. Es mejor tener un modelo sobreajustado, luego simplificarlo, cambiar los hiperparámetros, aumentar los datos, etc. para que se generalice bien, pero es más difícil (en entornos prácticos) hacer lo contrario. Evitar el sobreajuste antes de que suceda podría impedirle encontrar el modelo y/o la arquitectura correctos durante un período de tiempo más prolongado.

En la práctica, y en algunos de los casos de uso más fascinantes de Machine Learning y Deep Learning, trabajará en conjuntos de datos que tendrá problemas para sobreajustar. Estos serán conjuntos de datos que rutinariamente se ajustarán de forma inadecuada, sin la capacidad de encontrar modelos y arquitecturas que puedan generalizar bien y extraer características.

También vale la pena señalar la diferencia entre lo que llamo sobreajuste verdadero y sobreajuste parcial. Un modelo que sobreajusta un conjunto de datos y logra un 60 % de precisión en el conjunto de entrenamiento, con solo un 40 % en los conjuntos de validación y prueba, está sobreajustando una parte de los datos. Sin embargo, no es verdaderamente sobreajustado en el sentido de eclipsar todo el conjunto de datos y lograr una tasa de precisión cercana al 100 % (falsa), mientras que sus conjuntos de validación y prueba son bajos, digamos, ~40 %. .

Un modelo que se sobreajusta parcialmente no es uno que pueda generalizarse bien con la simplificación, ya que no tiene suficiente capacidad entrópica para realmente (sobre)ajustarse. Una vez que lo hace, se aplica mi argumento, aunque no garantiza el éxito, como se aclara en las secciones anteriores.

Estudio de caso: Argumento amistoso de sobreajuste

El Conjunto de datos de dígitos escritos a mano del MNIST, compilado por Yann LeCun, es uno de los conjuntos de datos de referencia clásicos utilizados para entrenar modelos de clasificación. LeCun es ampliamente considerado uno de los padres fundadores del aprendizaje profundo, con contribuciones al campo que la mayoría no puede poner bajo su cinturón, y el conjunto de datos de dígitos escritos a mano del MNIST fue uno de los primeros puntos de referencia importantes utilizados para las primeras etapas de Neural convolucional. Redes.

También es el conjunto de datos más utilizado en exceso, potencialmente nunca.

No hay nada malo con el conjunto de datos en sí, ni con LeCun, quien lo creó; en realidad, es bastante bueno, pero encontrar un ejemplo tras otro en el mismo conjunto de datos en línea es aburrido. En un momento, ** nos adaptamos demasiado ** al mirarlo. ¿Cuánto cuesta? Aquí está mi intento de enumerar los primeros diez dígitos MNIST de la parte superior de mi cabeza:

1
5, 0, 4, 1, 9, 2, 2, 4, 3

¿Cómo lo hice?

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Import and normalize the images, splitting out a validation set
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.mnist.load_data()

X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]

X_test = X_test/255.0

# Print out the first ten digits
fig, ax = plt.subplots(1, 10, figsize=(10,2))
for i in range(10):
    ax[i].imshow(X_train_full[i])
    ax[i].axis('off')
    plt.subplots_adjust(wspace=1) 

plt.show()

Casi ahí.

Aprovecharé esta oportunidad para hacer un llamado público a todos los creadores de contenido para que no usen en exceso este conjunto de datos más allá de las partes introductorias, donde la simplicidad del conjunto de datos se puede usar para reducir la barrera de entrada. Por favor.

Además, este conjunto de datos dificulta la creación de un modelo que no se ajuste. Es demasiado simple, e incluso un clasificador Perceptrón multicapa (MLP) bastante pequeño construido con un número intuitivo de capas y neuronas por capa puede alcanzar fácilmente más del 98 % de precisión en el conjunto de entrenamiento, prueba y validación. . Aquí hay un cuaderno Jupyter de un MLP simple que logró ~98 % de precisión tanto en el entrenamiento , conjuntos de validación y prueba, que hice girar con valores predeterminados razonables.

Ni siquiera me he molestado en intentar ajustarlo para que funcione mejor que la configuración inicial.

Los conjuntos de datos CIFAR10 y CIFAR100

Usemos un conjunto de datos que es más complicado que los dígitos escritos a mano del MNIST, y que hace un ajuste insuficiente de MLP simple, pero que es lo suficientemente simple como para permitir que una CNN de tamaño decente realmente se ajuste en exceso. Un buen candidato es el Conjunto de datos CIFAR.

Hay 10 clases de imágenes en CIFAR10 y 100 en CIFAR100. Además, el conjunto de datos CIFAR100 tiene 20 familias de clases similares, lo que significa que la red también tiene que aprender las diferencias mínimas entre clases similares pero diferentes. Estas se conocen como "etiquetas finas" (100) y "etiquetas gruesas" (20) y predecirlas equivale a predecir la clase específica, o simplemente la familia a la que pertenece.

Por ejemplo, aquí hay una superclase (etiqueta gruesa) y sus subclases (etiquetas finas):


Subclases de superclase recipientes para alimentos botellas, tazones, latas, tazas, platos


Una taza es un cilindro, similar a una lata de refresco, y algunas botellas también pueden serlo. Dado que estas características de bajo nivel son relativamente similares, es fácil colocarlas todas en la categoría "recipiente de comida", pero se requiere una abstracción de mayor nivel para adivinar correctamente si algo es una "taza\ “ o una "lata".

Lo que hace que este trabajo sea aún más difícil es que CIFAR10 tiene 6000 imágenes por clase, mientras que CIFAR100 tiene 600 imágenes por clase, lo que le da a la red menos imágenes de las que aprender las sutiles diferencias. Existen tazas sin asas y latas sin rebordes también. Desde un perfil: puede que no sea demasiado fácil distinguirlos.

Aquí es donde, digamos, un perceptrón multicapa simplemente no tiene el poder de abstracción para aprender, y está condenado al fracaso, terriblemente inadecuado. Las Redes Neuronales Convolucionales se construyen en base al neocognitrón, que tomó pistas de la neurociencia y el reconocimiento de patrones jerárquicos que realiza el cerebro. Estas redes pueden extraer características como esta y sobresalir en la tarea. Tanto es así que a menudo se ajustan mal y no se pueden usar como están al final, donde normalmente sacrificamos algo de precisión en aras de la capacidad de generalización.

Entrenemos dos arquitecturas de red diferentes en el conjunto de datos CIFAR10 y CIFAR100 como ilustración de mi punto.

Aquí es también donde podremos ver cómo, incluso cuando una red se adapta en exceso, no hay garantía de que la red en sí se generalice bien si se simplifica; es posible que no se pueda generalizar si se simplifica, aunque hay una tendencia. La red puede ser correcta, pero los datos pueden no ser suficientes.

En el caso de CIFAR100, solo 500 imágenes para entrenamiento (y 100 para prueba) por clase no son suficientes para que una CNN simple realmente generalice bien las 100 clases completas, y tendremos que realizar un aumento de datos para ayudar a lo largo. Incluso con el aumento de datos, es posible que no obtengamos una red muy precisa, ya que hay mucho que puede hacer con los datos. Si la misma arquitectura funciona bien en CIFAR10, pero no en CIFAR100, significa que simplemente no se puede distinguir de algunos de los detalles más finos que marcan la diferencia entre los objetos cilíndricos que llamamos "copa", " lata" y "botella", por ejemplo.

La gran mayoría de arquitecturas de red avanzadas que logran una alta precisión en el conjunto de datos CIFAR100 realizan el aumento de datos o amplían el conjunto de entrenamiento.

La mayoría de ellos tienen que hacerlo, y eso no es una señal de mala ingeniería. De hecho, el hecho de que podamos expandir estos conjuntos de datos y ayudar a las redes a generalizar mejor es una señal de ingenio de ingeniería.

Además, invitaría a cualquier humano a intentar adivinar cuáles son, si están convencidos de que la clasificación de imágenes no es demasiado difícil con imágenes tan pequeñas como 32x32:

¿Son Imagen 4 algunas naranjas? ¿Pelotas de ping pong? ¿Yemas de huevo? Bueno, probablemente no las yemas de huevo, pero eso requiere un conocimiento previo sobre qué son los "huevos" y si es probable que encuentre yemas sobre la mesa, algo que una red no tendrá. Considere la cantidad de conocimiento previo que pueda tener sobre el mundo y cuánto afecta lo que ve.

Importación de los datos

Usaremos Keras como la biblioteca de aprendizaje profundo preferida, pero puede seguir con otras bibliotecas o incluso con sus modelos personalizados si está dispuesto a hacerlo.

Pero antes que nada, carguémoslo, separemos los datos en un conjunto de entrenamiento, prueba y validación, normalizando los valores de la imagen a 0..1:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Starting with CIFAR10
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.cifar10.load_data()

X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]

X_test = X_test/255.0

Luego, visualicemos algunas de las imágenes en el conjunto de datos para tener una idea de a lo que nos enfrentamos:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
fig, ax = plt.subplots(5, 5, figsize=(10, 10))
ax = ax.ravel()

# Labels come as numbers of [0..9], so here are the class names for humans
class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

for i in range(25):
    ax[i].imshow(X_train_full[i])
    ax[i].set_title(class_names[Y_train_full[i][0]])
    ax[i].axis('off')
    plt.subplots_adjust(wspace=1) 

plt.show()

Perceptrón multicapa inadecuado {#perceptrón multicapa inadecuado}

Más o menos no importa lo que hagamos, el MLP no funcionará tan bien. Definitivamente alcanzará cierto nivel de precisión en función de las secuencias sin procesar de información que ingresa, pero este número tiene un límite y probablemente no sea demasiado alto.

La red comenzará a sobreajustarse en un punto, aprendiendo las secuencias concretas de datos que denotan imágenes, pero aún tendrá poca precisión en el conjunto de entrenamiento incluso cuando se sobreajuste, que es el mejor momento para dejar de entrenarlo, ya que simplemente no cabe. bien los datos. [La formación en redes tiene una huella de carbono, ¿sabes?] {.small}

Agreguemos una devolución de llamada EarlyStopping para evitar ejecutar la red más allá del punto de sentido común, y configuremos epochs en un número más allá del que ejecutaremos (para que EarlyStopping pueda activarse) .

Usaremos la API secuencial para agregar un par de capas con BatchNormalization y un poco de Dropout. Ayudan con la generalización y queremos al menos intentar que este modelo aprenda algo.

Los principales hiperparámetros que podemos modificar aquí son el número de capas, sus tamaños, las funciones de activación, los inicializadores del kernel y las tasas de abandono, y aquí hay una configuración con un rendimiento "decente":

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
checkpoint = keras.callbacks.ModelCheckpoint("simple_dense.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.Sequential([
  keras.layers.Flatten(input_shape=[32, 32, 3]),
  keras.layers.BatchNormalization(),
  keras.layers.Dense(75),
    
  keras.layers.Dense((50), activation='elu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),
    
  keras.layers.Dense((50), activation='elu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),
    
  keras.layers.Dense(10, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
              metrics=["accuracy"])

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150, 
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])

Veamos si la hipótesis inicial es cierta: comenzará aprendiendo y generalizando hasta cierto punto, pero terminará teniendo poca precisión tanto en el conjunto de entrenamiento como en el conjunto de prueba y validación, lo que resultará en un bajo nivel general. precisión.

Para CIFAR10, la red realiza "bien"-ish:

1
2
3
4
5
Epoch 1/150
1407/1407 [==============================] - 5s 3ms/step - loss: 1.9706 - accuracy: 0.3108 - val_loss: 1.6841 - val_accuracy: 0.4100
...
Epoch 50/150
1407/1407 [==============================] - 4s 3ms/step - loss: 1.2927 - accuracy: 0.5403 - val_loss: 1.3893 - val_accuracy: 0.5122

Echemos un vistazo a la historia de su aprendizaje:

1
2
3
4
pd.DataFrame(history.history).plot()
plt.show()

model.evaluate(X_test, Y_test)

1
2
313/313 [==============================] - 0s 926us/step - loss: 1.3836 - accuracy: 0.5058
[1.383605718612671, 0.5058000087738037]

La precisión general llega hasta ~50 % y la red llega bastante rápido y comienza a estabilizarse. 5/10 imágenes clasificadas correctamente suena como tirar una moneda al aire, pero recuerde que hay 10 clases aquí, por lo que si estuviera adivinando al azar, en promedio adivinaría una sola imagen de cada diez. Pasemos al conjunto de datos CIFAR100, que también requiere una red con al menos un poco más de potencia, ya que hay menos instancias de entrenamiento por clase, así como una cantidad mucho mayor de clases:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
checkpoint = keras.callbacks.ModelCheckpoint("bigger_dense.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

# Changing the loaded data
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.cifar100.load_data()

# Modify the model
model1 = keras.Sequential([
  keras.layers.Flatten(input_shape=[32, 32, 3]),
  keras.layers.BatchNormalization(),
  keras.layers.Dense(256, activation='relu', kernel_initializer="he_normal"),
    
  keras.layers.Dense(128, activation='relu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),

  keras.layers.Dense(100, activation='softmax')
])


model1.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
              metrics=["accuracy"])

history = model1.fit(X_train, 
                    Y_train, 
                    epochs=150, 
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])

La red funciona bastante mal:

1
2
3
4
5
Epoch 1/150
1407/1407 [==============================] - 13s 9ms/step - loss: 4.2260 - accuracy: 0.0836 - val_loss: 3.8682 - val_accuracy: 0.1238
...
Epoch 24/150
1407/1407 [==============================] - 12s 8ms/step - loss: 2.3598 - accuracy: 0.4006 - val_loss: 3.3577 - val_accuracy: 0.2434

Y tracemos el historial de su progreso, así como también evaluémoslo en el conjunto de prueba (que probablemente funcionará tan bien como el conjunto de validación):

1
2
3
4
pd.DataFrame(history.history).plot()
plt.show()

model.evaluate(X_test, Y_test)

1
2
313/313 [==============================] - 0s 2ms/step - loss: 3.2681 - accuracy: 0.2408
[3.2681326866149902, 0.24079999327659607]

Como era de esperar, la red no pudo captar bien los datos. Terminó teniendo una precisión de sobreajuste del 40 % y una precisión real de ~24 %.

La precisión se limitó al 40%: no fue realmente capaz de sobreajustar el conjunto de datos, incluso si sobreajustó algunas partes que pudo discernir dada la arquitectura limitada. Este modelo no tiene la capacidad entrópica necesaria para que realmente se sobreajuste por el bien de mi argumento.

Este modelo y su arquitectura simplemente no son adecuados para esta tarea, y aunque técnicamente podríamos hacer que se (sobre) ajuste más, seguirá teniendo problemas a largo plazo. Por ejemplo, convirtámoslo en una red más grande, lo que teóricamente le permitiría reconocer patrones más complejos:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
model2 = keras.Sequential([
  keras.layers.Flatten(input_shape=[32, 32, 3]),
  keras.layers.BatchNormalization(),
  keras.layers.Dense(512, activation='relu', kernel_initializer="he_normal"),
    
  keras.layers.Dense(256, activation='relu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),
    
  keras.layers.Dense(128, activation='relu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),

  keras.layers.Dense(100, activation='softmax')
])

Sin embargo, esto no funciona mucho mejor en absoluto:

1
2
Epoch 24/150
1407/1407 [==============================] - 28s 20ms/step - loss: 2.1202 - accuracy: 0.4507 - val_loss: 3.2796 - val_accuracy: 0.2528

Es mucho más complejo (la densidad explota), pero simplemente no puede extraer mucho más:

1
2
model1.summary()
model2.summary()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
Model: "sequential_17"
...
Total params: 845,284
Trainable params: 838,884
Non-trainable params: 6,400
_________________________________________________________________
Model: "sequential_18"
...
Total params: 1,764,324
Trainable params: 1,757,412
Non-trainable params: 6,912

Sobreajuste de red neuronal convolucional en CIFAR10 {#sobreajuste de red neuronal convolucional encifar10}

Ahora, intentemos hacer algo diferente. Cambiar a una CNN ayudará significativamente a extraer características del conjunto de datos, lo que permitirá que el modelo verdaderamente se sobreajuste, alcanzando una precisión mucho mayor (ilusoria).

Desactivaremos la devolución de llamada EarlyStopping para dejar que haga lo suyo. Además, no usaremos capas Dropout y, en su lugar, intentaremos forzar a la red a aprender las características a través de más capas.

{.icon aria-hidden=“true”}

Nota: Fuera del contexto de tratar de probar el argumento, este sería un consejo horrible. Esto es lo contrario de lo que te gustaría hacer al final. El abandono ayuda a las redes a generalizar mejor, al obligar a las neuronas que no se han perdido a tomar el relevo. Obligar a la red a aprender a través de más capas es más probable que conduzca a un modelo sobreajustado.

La razón por la que estoy haciendo esto a propósito es permitir que la red se sobreajuste horriblemente como un signo de su capacidad para discernir realmente las características, antes de simplificarla y agregar Dropout para permitir que realmente se generalice. Si alcanza una precisión alta (ilusoria), puede extraer mucho más que el modelo MLP, lo que significa que podemos comenzar a simplificarlo.

Usemos una vez más la API secuencial para construir una CNN, primero en el conjunto de datos CIFAR10:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
checkpoint = keras.callbacks.ModelCheckpoint("overcomplicated_cnn_cifar10.h5", save_best_only=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(64, 3, activation='relu', 
                        kernel_initializer="he_normal", 
                        kernel_regularizer=keras.regularizers.l2(l=0.01), 
                        padding='same', 
                        input_shape=[32, 32, 3]),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

model.summary()

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint])

¡Impresionante, se sobreajustó bastante rápido! En solo unas pocas épocas, comenzó a sobreajustar los datos y, en la época 31, llegó al 98 %, con una precisión de validación más baja:

1
2
3
4
5
Epoch 1/150
704/704 [==============================] - 149s 210ms/step - loss: 1.9561 - accuracy: 0.4683 - val_loss: 2.5060 - val_accuracy: 0.3760
...
Epoch 31/150
704/704 [==============================] - 149s 211ms/step - loss: 0.0610 - accuracy: 0.9841 - val_loss: 1.0433 - val_accuracy: 0.6958

Dado que solo hay 10 clases de salida, aunque intentamos sobreajustarlo mucho creando una CNN innecesariamente grande, la precisión de la validación sigue siendo bastante alta.

Simplificación de la red neuronal convolucional en CIFAR10 {#simplificación de la red neuronal convolucional en cifar10}

Ahora, simplifiquemos para ver cómo le irá con una arquitectura más razonable. Agregaremos BatchNormalization y Dropout ya que ambos ayudan con la generalización:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
checkpoint = keras.callbacks.ModelCheckpoint("simplified_cnn_cifar10.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(32, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.5),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(32, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

model.summary()

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])

Este modelo tiene un recuento (modesto) de 323 146 parámetros entrenables, en comparación con 1 579 178 de la CNN anterior. ¿Cómo funciona?

1
2
3
4
5
Epoch 1/150
704/704 [==============================] - 91s 127ms/step - loss: 2.1327 - accuracy: 0.3910 - val_loss: 1.5495 - val_accuracy: 0.5406
...
Epoch 52/150
704/704 [==============================] - 89s 127ms/step - loss: 0.4091 - accuracy: 0.8648 - val_loss: 0.4694 - val_accuracy: 0.8500

¡En realidad logra una precisión bastante decente de ~85%! La navaja de Occam ataca de nuevo. Echemos un vistazo a algunos de los resultados:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
y_preds = model.predict(X_test)
print(y_preds[1])
print(np.argmax(y_preds[1]))

fig, ax = plt.subplots(6, 6, figsize=(10, 10))
ax = ax.ravel()

for i in range(0, 36):
    ax[i].imshow(X_test[i])
    ax[i].set_title("Actual: %s\nPred: %s" % (class_names[Y_test[i][0]], class_names[np.argmax(y_preds[i])]))
    ax[i].axis('off')
    plt.subplots_adjust(wspace=1)
    
plt.show()

Los principales errores de clasificación son dos imágenes en este pequeño conjunto: un perro fue clasificado erróneamente como un ciervo (lo suficientemente respetable), pero un primer plano de un pájaro emú fue clasificado como un gato (lo suficientemente divertido, así que lo dejaremos pasar).

Sobreajuste de la red neuronal convolucional en CIFAR100

¿Qué sucede cuando buscamos el conjunto de datos CIFAR100?

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
checkpoint = keras.callbacks.ModelCheckpoint("overcomplicated_cnn_model_cifar100.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(32, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(256, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.BatchNormalization(),
    
    keras.layers.Dense(100, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

model.summary()

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint])
1
2
3
4
5
Epoch 1/150
704/704 [==============================] - 97s 137ms/step - loss: 4.1752 - accuracy: 0.1336 - val_loss: 3.9696 - val_accuracy: 0.1392
...
Epoch 42/150
704/704 [==============================] - 95s 135ms/step - loss: 0.1543 - accuracy: 0.9572 - val_loss: 4.1394 - val_accuracy: 0.4458

¡Maravilloso! ~96% de precisión en el conjunto de entrenamiento! No te preocupes por la precisión de validación del ~44 % todavía. Simplifiquemos el modelo muy rápido para que se generalice mejor.

Fallo al generalizar después de la simplificación

Y aquí es donde queda claro que la capacidad de sobreajuste no garantiza que el modelo pueda generalizarse mejor cuando se simplifica. En el caso de CIFAR100, no hay muchas instancias de capacitación por clase, y esto probablemente impedirá que una versión simplificada del modelo anterior aprenda bien. Vamos a probarlo:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
checkpoint = keras.callbacks.ModelCheckpoint("simplified_cnn_model_cifar100.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(32, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.5),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(256, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(100, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])
1
2
3
4
5
Epoch 1/150
704/704 [==============================] - 96s 135ms/step - loss: 4.4432 - accuracy: 0.1112 - val_loss: 3.7893 - val_accuracy: 0.1702
...
Epoch 48/150
704/704 [==============================] - 92s 131ms/step - loss: 1.2550 - accuracy: 0.6370 - val_loss: 1.7147 - val_accuracy: 0.5466

Se está estabilizando y realmente no puedo llegar a generalizar los datos. En este caso, es posible que no sea culpa del modelo; tal vez sea el adecuado para la tarea, especialmente dada la alta precisión del conjunto de datos CIFAR10, que tiene la misma forma de entrada e imágenes similares en el conjunto de datos. Parece que el modelo puede ser razonablemente preciso con las formas generales, pero no con la distinción entre formas finas.

El modelo más simple en realidad funciona mejor que el más complicado en términos de precisión de validación, por lo que la CNN más compleja no obtiene estos detalles mucho mejor. Aquí, lo más probable es que el problema resida en el hecho de que solo hay 500 imágenes de entrenamiento por clase, lo que realmente no es suficiente. En la red más compleja, esto conduce a un ajuste excesivo, porque no hay suficiente diversidad; cuando se simplifica para evitar el ajuste excesivo, esto provoca un ajuste insuficiente porque, de nuevo, no hay diversidad.

Esta es la razón por la que la gran mayoría de los artículos vinculados antes y la gran mayoría de las redes aumentan los datos del conjunto de datos CIFAR100.

Realmente no es un conjunto de datos en el que sea fácil obtener una alta precisión, a diferencia del conjunto de datos de dígitos escritos a mano del MNIST, y una CNN simple como la que estamos construyendo probablemente no sea suficiente para lograr una alta precisión. Solo recuerde la cantidad de clases bastante específicas, cuán poco informativas son algunas de las imágenes y * cuánto conocimiento previo tienen los humanos para discernir entre ellas *.

Hagamos nuestro mejor esfuerzo aumentando algunas imágenes y expandiendo artificialmente los datos de entrenamiento, para al menos tratar de obtener una mayor precisión. Tenga en cuenta que el CIFAR100 es, nuevamente, un conjunto de datos genuinamente difícil para obtener una alta precisión con modelos simples. Los modelos de última generación utilizan técnicas diferentes y novedosas para eliminar errores, y muchos de estos modelos ni siquiera son CNN, son Transformers.

Si quieres echar un vistazo al panorama de estos modelos, PapelesConCódigo ha hecho una preciosa recopilación de papers, código fuente y resultados.

Aumento de datos con la clase ImageDataGenerator de Keras

¿Ayudará el aumento de datos? Por lo general, lo hace, pero con una seria falta de datos de entrenamiento como la que enfrentamos, hay mucho que puede hacer con rotaciones aleatorias, voltear, recortar, etc. Si una arquitectura no puede generalizar bueno, en un conjunto de datos, es probable que lo aumente a través del aumento de datos, pero probablemente no sea mucho.

Dicho esto, usemos la clase ImageDataGenerator de Keras para intentar generar nuevos datos de entrenamiento con cambios aleatorios, con la esperanza de mejorar la precisión del modelo. Si mejora, no debería ser en gran medida, y es probable que vuelva a sobreajustar parcialmente el conjunto de datos sin la capacidad de generalizar bien o sobreajustar completamente los datos.

Dadas las variaciones aleatorias constantes en los datos, es menos probable que el modelo se sobreajuste en el mismo número de épocas, ya que las variaciones hacen que siga ajustándose a los datos "nuevos". Vamos a ejecutarlo durante, digamos, 300 épocas, que es significativamente más que el resto de las redes que hemos entrenado. Esto es posible sin un sobreajuste importante, nuevamente, debido a las modificaciones aleatorias realizadas en las imágenes mientras fluyen:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
checkpoint = keras.callbacks.ModelCheckpoint("augmented_cnn.h5", save_best_only=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(64, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(512, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(100, activation='softmax')
])

    
train_datagen = ImageDataGenerator(rotation_range=30,
        height_shift_range=0.2,
        width_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest')

valid_datagen = ImageDataGenerator()

train_datagen.fit(X_train)
valid_datagen.fit(X_valid)

train_generator = train_datagen.flow(X_train, Y_train, batch_size=128)
valid_generator = valid_datagen.flow(X_valid, Y_valid, batch_size=128)

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3, decay=1e-6),
              metrics=["accuracy"])

history = model.fit(train_generator, 
                    epochs=300,
                    batch_size=128,
                    steps_per_epoch=len(X_train)//128,
                    validation_data=valid_generator,
                    callbacks=[checkpoint])
1
2
3
4
5
Epoch 1/300
351/351 [==============================] - 16s 44ms/step - loss: 5.3788 - accuracy: 0.0487 - val_loss: 5.3474 - val_accuracy: 0.0440
...
Epoch 300/300
351/351 [==============================] - 15s 43ms/step - loss: 1.0571 - accuracy: 0.6895 - val_loss: 2.0005 - val_accuracy: 0.5532

El modelo se está desempeñando con ~55 % en el conjunto de validación, y aún sobreajusta parcialmente los datos. El val_loss ha dejado de bajar y es bastante inestable, incluso con un batch_size más alto.

Esta red simplemente no puede aprender y ajustar los datos con alta precisión, a pesar de que las variaciones fuera de ella tienen la capacidad entrópica para sobreajustar los datos.

¿Conclusión?

El sobreajuste no es inherentemente algo malo, es solo una cosa. No, no desea modelos finales sobreajustados, pero no debe tratarse como una plaga e incluso puede ser una buena señal de que un modelo podría funcionar mejor con más datos y un paso de simplificación. Esto no está garantizado, de ninguna manera, y el conjunto de datos CIFAR100 se ha utilizado como ejemplo de un conjunto de datos que no es fácil de generalizar bien.

El punto de esta divagación es, nuevamente, no ser contrario, sino incitar a la discusión sobre el tema, que no parece estar teniendo mucho lugar.

¿Quién soy yo para hacer esta afirmación?

Solo alguien que se sienta en casa, practicando el oficio, con una profunda fascinación por el mañana.

¿Tengo la capacidad de equivocarme?

Mucho.

¿Cómo debes tomar esta pieza?

Tómalo como puedas, piensa por ti mismo si tiene sentido o no. Si no crees que estoy fuera de mi lugar por notar esto, házmelo saber. Si cree que me equivoco en esto, por supuesto, hágamelo saber y no se ande con rodeos. :)