18.17. Retropropagación y entrenamiento de redes neuronales: redes neuronales recurrentes (RNN) y retropropagación a través del tiempo (BPTT)
Las redes neuronales recurrentes (RNN) son una clase de redes neuronales que son especialmente efectivas para procesar secuencias de datos, como series temporales, lenguaje natural o cualquier tipo de datos donde el orden temporal sea relevante. A diferencia de las redes neuronales feedforward, donde la información fluye en una sola dirección, las RNN tienen conexiones que forman ciclos, lo que permite "retener" la información en la red durante algún tiempo. Esto es crucial para tareas donde el contexto y el orden de los datos son importantes.
Para entrenar RNN, utilizamos una técnica conocida como retropropagación a través del tiempo (BPTT). BPTT es una generalización del algoritmo de retropropagación para redes con ciclos. Exploremos cómo funciona BPTT y cómo se aplica al entrenamiento de RNN.
Cómo funcionan los RNN
En un RNN, cada neurona o unidad tiene una conexión recurrente consigo misma. Esto permite que la neurona conserve el estado anterior como una especie de "memoria", que influye en la salida actual basándose no sólo en la entrada actual, sino también en las entradas anteriores. Matemáticamente, esto se expresa mediante la siguiente fórmula, donde ht es el estado oculto en el momento t, xt es la entrada en el momento t, y W y U son pesos que deben aprenderse: p>
ht = f(W * ht-1 + U * xt + b)
La función f suele ser una función de activación no lineal, como tanh o ReLU. El vector de estado oculto ht se actualiza en cada paso de tiempo, capturando información sobre la secuencia hasta el momento actual.
Retropropagación a través del tiempo (BPTT)
BPTT es un proceso que adapta el algoritmo de retropropagación estándar para redes con conexiones recurrentes. El principio básico es desplegar el RNN en el tiempo, transformándolo en una red de avance profundo, donde cada "capa" corresponde a un paso de tiempo en la secuencia de entrada. Esto permite calcular los gradientes para cada paso de tiempo, teniendo en cuenta las dependencias temporales.
Para realizar el BPTT, seguimos los siguientes pasos:
- Propagación hacia adelante: la entrada se procesa secuencialmente, y cada estado oculto se calcula en función del estado anterior y la entrada actual.
- Cálculo de error: después de la propagación hacia adelante, el error se calcula en la salida final de la secuencia o en cada paso de tiempo, dependiendo de la tarea.
- Propagación hacia atrás: el error se propaga hacia atrás a través de la red desplegada, calculando los gradientes para cada paso de tiempo.
- Actualización de pesos: Los pesos se actualizan en función de los gradientes calculados, normalmente utilizando un optimizador como SGD, Adam, entre otros.
Una de las dificultades con BPTT es que la red desplegada puede volverse muy profunda durante secuencias largas, lo que puede provocar problemas como la desaparición o la explosión de gradientes. Los gradientes que desaparecen ocurren cuando el gradiente se vuelve tan pequeño que el entrenamiento no progresa, mientras que los gradientes explosivos pueden hacer que los pesos se vuelvan demasiado grandes e inestables.
Variantes de RNN y soluciones a problemas de BPTT
Para abordar estos problemas, se han desarrollado variantes de RNN, como la memoria a corto plazo (LSTM) y las unidades recurrentes cerradas (GRU). Estas arquitecturas incluyen mecanismos de puerta que controlan el flujo de información, lo que permite a la red aprender cuándo "recordar" u "olvidar" información pasada, lo que ayuda a mitigar el problema de los gradientes que desaparecen.
Además, se utilizan técnicas como el recorte de degradado para evitar que los degradados exploten al recortarlos cuando superan un determinado valor.
Conclusión
Entrenar RNN con BPTT es una técnica poderosa para aprender dependencias temporales en datos secuenciales. Aunque son un desafío debido a problemas como la desaparición y la explosión de gradientes, los avances en las arquitecturas RNN y las técnicas de optimización continúan mejorando la efectividad de los RNN en una amplia variedad de tareas, desde el reconocimiento de voz hasta la generación de texto. Comprender y aplicar BPTT es esencial para cualquiera que quiera trabajar con aprendizaje automático y aprendizaje profundo para datos de secuencia.