Volver al blog
Programar una red neuronal recurrente (RNN) desde cero con PyTorch
ML10 min lectura23 ene 2025

Programar una red neuronal recurrente (RNN) desde cero con PyTorch

Construye una RNN desde cero con PyTorch. Nuestra guía hace sencillo el código de RNN para todos los niveles. ¡Empieza con deep learning ya!
SD
SolarDevs Team
Liderazgo técnico

RNN vs arquitectura feedforward

Personalmente me resulta más fácil entender las RNN cuando las comparo con redes feedforward, porque es un concepto conocido y solo voy sumando ideas nuevas; por eso las compararé a menudo.

A diferencia de las redes feedforward, el mecanismo de las RNN es un poco más complejo. Dentro de una sola capa de red neuronal recurrente hay 3 matrices de pesos, además de 2 tensores de entrada y 2 tensores de salida.

Suele decirse que las RNN son feedforward con estado interno; pero con un diagrama sencillo se ve que no es tan simple. Los componentes son bastante más complejos en una red recurrente; no te preocupes, intentaré explicarte cómo funciona y, con el código, podrás entenderlo.

Arquitectura de una capa RNN

Las redes recurrentes introducen el concepto de estado oculto: básicamente otra entrada que depende de las salidas anteriores de la capa. Si depende de salidas anteriores, ¿de dónde sale en la primera pasada? Fácil: inícialo con ceros.

Las RNN se alimentan distinto a las feedforward. Como trabajamos con secuencias, el orden en que entran los datos importa; por eso en cada paso alimentamos la red con un solo elemento de la secuencia. Por ejemplo: si son precios de acciones, entramos el precio de cada día; si es texto, entramos una letra o palabra cada vez.

Entramos paso a paso porque en cada iteración hay que calcular el estado oculto; ese estado guarda información previa para que la siguiente entrada use datos de pasadas anteriores mediante la suma de las matrices.

Entradas

  • Tensor de entrada: Debe ser un solo paso de la secuencia. Si tu secuencia tiene 100 caracteres, la entrada será un solo carácter.
  • Tensor de estado oculto: Es el estado oculto. En la primera pasada de cada secuencia completa, este tensor se rellena con ceros. Siguiendo el ejemplo: si tienes 10 secuencias de 100 caracteres (1000 caracteres en total), para cada secuencia generarás un estado oculto inicializado en ceros.

Matrices de pesos

  • Dense de entrada: Matriz densa para la entrada (como en feedforward).
  • Dense oculta: Matriz densa para la entrada del estado oculto.
  • Dense de salida: Matriz densa para el resultado de activation(input_dense + hidden_dense).
activation(input_dense + hidden_dense)

Salidas

  • Nuevo estado oculto: Tensor de estado oculto activation(input_dense + hidden_dense). Lo usarás como entrada en la siguiente iteración de la secuencia.
  • Salida: activation(output_dense). Es tu vector de predicción, análogo a la salida de una red feedforward.

Código de una capa RNN

import torch
import torch.nn as nn

class RNN(nn.Module):
    """ Bloque RNN básico. Representa una sola capa de RNN """
    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        """
        input_size: Número de características del vector de entrada
        hidden_size: Número de neuronas ocultas
        output_size: Número de características del vector de salida
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_state) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Devuelve la salida calculada y tanh(i2h + h2h)
        Entradas
        -------
        x: Vector de entrada
        hidden_state: Estado oculto anterior
        Salidas
        -------
        out: Salida lineal (sin activación por cómo funciona PyTorch)
        hidden_state: Nueva matriz de estado oculto
        """
        x = self.i2h(x)
        hidden_state = self.h2h(hidden_state)
        hidden_state = torch.tanh(x + hidden_state)
        out = self.h2o(hidden_state)
        return out, hidden_state

    def init_zero_hidden(self, batch_size=1) -> torch.Tensor:
        """
        Función auxiliar. Devuelve un estado oculto con el tamaño de lote indicado. Por defecto 1.
        """
        return torch.zeros(batch_size, self.hidden_size, requires_grad=False)

Entrenamiento por lotes

Alimentar una RNN por lotes suele ser mucho más rápido (fácilmente 10x), y las RNN no son la excepción. Entrenar por lotes no mejora el rendimiento del modelo en sí: si tu red no funciona con un ejemplo a la vez, tampoco con 10 o 100.

En el ejemplo la RNN se entrena con texto, un carácter a la vez; la función de entrenamiento debe dar un carácter del texto en cada paso. Hacerlo por lotes ahorra mucho tiempo, así que se pueden procesar varios lotes por época.

def train(model: RNN, data: DataLoader, epochs: int, optimizer: optim.Optimizer, loss_fn: nn.Module) -> None:
    """ Entrena el modelo durante el número de épocas indicado """
    train_losses = {}
    model.to(device)
    model.train()
    print("=> Iniciando entrenamiento")
    for epoch in range(epochs):
        epoch_losses = list()
        for X, Y in data:
            if X.shape[0] != model.batch_size:
                continue
            hidden = model.init_zero_hidden(batch_size=model.batch_size)
            X, Y, hidden = X.to(device), Y.to(device), hidden.to(device)
            
            model.zero_grad()
            loss = 0
            for c in range(X.shape[1]):
                out, hidden = model(X[:, c].reshape(X.shape[0],1), hidden)
                l = loss_fn(out, Y[:, c].long())
                loss += l
            
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()
            epoch_losses.append(loss.detach().item() / X.shape[1])
        
        train_losses[epoch] = torch.tensor(epoch_losses).mean()
        print(f'=> época: {epoch + 1}, pérdida: {train_losses[epoch]}')

Programar esto a mano ayuda mucho a entender las operaciones y el flujo; además es muy satisfactorio ver cómo la RNN aprende del texto y genera nuevo texto.

Construye tu futuro.

¿Listo para transformar tu infraestructura con agentes de IA inteligentes?

Iniciar descubrimiento