RNN Basics: Neural Networks with Memory

Traditional neural networks (like CNNs or MLPs) are feed-forward; they assume that all inputs are independent of each other. This is a problem for data that comes in a specific order, such as:

  • Text: The meaning of a word depends on the words before it.
  • Audio: A sound wave is a continuous sequence over time.
  • Stock Prices: Today's price is highly dependent on yesterday's trend.

Recurrent Neural Networks (RNNs) solve this by introducing a "loop" that allows information to persist.

1. The Core Idea: The Hidden State

The defining feature of an RNN is the Hidden State ($h_t$). You can think of this as the "memory" of the network. As the network processes each element in a sequence, it updates this hidden state based on the current input and the previous hidden state.

The Mathematical Step

At every time step $t$, the RNN performs two operations, sketched in code just after this list:

  1. Update Hidden State: $h_t = \phi(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$
  2. Calculate Output: $y_t = W_{hy} h_t + b_y$
  • $x_t$: Input at time $t$.
  • $h_{t-1}$: Memory (hidden state) from the previous step.
  • $W_{xh}, W_{hh}, W_{hy}, b_h, b_y$: Weights and biases learned during training.
  • $\phi$: An activation function (usually Tanh or ReLU).
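
To make these two operations concrete, here is a minimal sketch of a single RNN step written by hand in PyTorch. The sizes, weight names, and random initialization are illustrative assumptions, not values from any particular library.

import torch

input_size, hidden_size, output_size = 10, 20, 5

# Illustrative weights; a real network learns these during training.
W_xh = torch.randn(hidden_size, input_size) * 0.1
W_hh = torch.randn(hidden_size, hidden_size) * 0.1
W_hy = torch.randn(output_size, hidden_size) * 0.1
b_h = torch.zeros(hidden_size)
b_y = torch.zeros(output_size)

def rnn_step(x_t, h_prev):
    # 1. Update hidden state: h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h)
    h_t = torch.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)
    # 2. Calculate output: y_t = W_hy h_t + b_y
    y_t = W_hy @ h_t + b_y
    return h_t, y_t

h = torch.zeros(hidden_size)   # the "memory", empty at the start
x = torch.randn(input_size)    # one element of the sequence
h, y = rnn_step(x, h)          # h now carries information about x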

2. Unrolling the Network

To understand how an RNN learns, we "unroll" it. Instead of looking at it as a single cell with a loop, we view it as a chain of identical cells, each passing a message to its successor.

This structure shows that RNNs are essentially a deep network where the "depth" is determined by the length of the input sequence.
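
In code, unrolling is just a loop: the same weights are applied at every time step, and the hidden state carries information from one step to the next. A minimal sketch with assumed sizes:

import torch

seq_len, input_size, hidden_size = 5, 10, 20

# One set of weights: the single "cell" that is reused at every step.
W_xh = torch.randn(hidden_size, input_size) * 0.1
W_hh = torch.randn(hidden_size, hidden_size) * 0.1
b_h = torch.zeros(hidden_size)

inputs = torch.randn(seq_len, input_size)  # one sequence of 5 elements
h = torch.zeros(hidden_size)               # initial hidden state

hidden_states = []
for x_t in inputs:                         # the unrolled chain, one cell per step
    h = torch.tanh(W_hh @ h + W_xh @ x_t + b_h)
    hidden_states.append(h)

# 5 hidden states, one per time step, all produced by the same weights
print(len(hidden_states), hidden_states[-1].shape)  # 5 torch.Size([20])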

3. RNN Input-Output Architectures

RNNs are flexible and can be configured in several ways depending on the task (a shape-level code sketch follows the table):

  Architecture | Description                       | Example
  One-to-Many  | One input, a sequence of outputs  | Image Captioning (Image → Sentence)
  Many-to-One  | A sequence of inputs, one output  | Sentiment Analysis (Sentence → Positive/Negative)
  Many-to-Many | A sequence of inputs and outputs  | Machine Translation (English → French)
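
In practice, the difference between these configurations often comes down to which outputs you keep. A hedged sketch using PyTorch's nn.RNN (introduced in section 5 below), with illustrative shapes:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20, batch_first=True)
x = torch.randn(1, 5, 10)        # (batch, sequence_length, input_size)
outputs, h_n = rnn(x)            # outputs: (1, 5, 20), h_n: (1, 1, 20)

# Many-to-Many: keep the output at every time step (e.g. translation-style tasks).
many_to_many = outputs           # shape (1, 5, 20)

# Many-to-One: keep only the last step (e.g. sentiment analysis).
many_to_one = outputs[:, -1, :]  # shape (1, 20)

# One-to-Many: feed a single input, then repeatedly feed the model's own
# output back in to generate a sequence (e.g. image captioning).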

4. The Major Flaw: Vanishing Gradients

While standard RNNs are theoretically powerful, they struggle with long-term dependencies.

Because the network is unrolled, backpropagation must travel back through every time step. If the sequence is long, the gradient (error signal) is repeatedly multiplied by the recurrent weights and activation derivatives. If these factors are small, the gradient "vanishes" before it can reach the beginning of the sequence (a toy calculation follows the list below).

  • Result: The network forgets what happened at the start of a long sentence.
  • The Solution: Specialized units like LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit).
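
A toy calculation (made-up numbers, not a real training run) shows why: if each backward step scales the error signal by a factor below 1, the gradient reaching the start of a 50-step sequence is essentially zero.

# Assumed per-step scaling factor from small weights and activation derivatives
factor = 0.5
gradient = 1.0
for step in range(50):   # backpropagating through 50 time steps
    gradient *= factor

print(gradient)          # ~8.9e-16: the error signal has effectively vanished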

5. Implementation with PyTorch

In PyTorch, the nn.RNN module handles the recurrent logic for you.

import torch
import torch.nn as nn

# Parameters: input_size=10, hidden_size=20, num_layers=1
rnn = nn.RNN(10, 20, 1, batch_first=True)

# Input shape: (batch, sequence_length, input_size)
# Example: a batch of 1, a sequence of 5 words, each a 10-dim vector
input_seq = torch.randn(1, 5, 10)

# Initial hidden state, shape (num_layers, batch, hidden_size), set to zeros
h0 = torch.zeros(1, 1, 20)

# Forward pass
output, hn = rnn(input_seq, h0)

print(f"Output shape (all steps): {output.shape}") # [1, 5, 20]
print(f"Final hidden state shape: {hn.shape}") # [1, 1, 20]

Standard RNNs have a "short-term" memory problem. To solve this, we use a more complex architecture that can decide what to remember and what to forget.