# RNN Basics: Neural Networks with Memory
Traditional neural networks (like CNNs or MLPs) are feed-forward; they assume that all inputs are independent of each other. This is a problem for data that comes in a specific order, such as:
- Text: The meaning of a word depends on the words before it.
- Audio: A sound wave is a continuous sequence over time.
- Stock Prices: Today's price is highly dependent on yesterday's trend.
Recurrent Neural Networks (RNNs) solve this by introducing a "loop" that allows information to persist.
## 1. The Core Idea: The Hidden State
The defining feature of an RNN is the Hidden State ($h_t$). You can think of this as the "memory" of the network. As the network processes each element in a sequence, it updates this hidden state based on the current input and the previous hidden state.
### The Mathematical Step
At every time step $t$, the RNN performs two operations:
- Update Hidden State: $h_t = f(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$
- Calculate Output: $y_t = W_{hy} h_t + b_y$

where:
- $x_t$: Input at time $t$.
- $h_{t-1}$: Memory from the previous step.
- $f$: An activation function (usually Tanh or ReLU).
- $W_{xh}, W_{hh}, W_{hy}, b_h, b_y$: Learned weights and biases, shared across all time steps.
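To make the update rule concrete, here is a minimal sketch of a single hidden-state update written with raw tensors. The sizes (`input_size=10`, `hidden_size=20`) are arbitrary illustrative choices, and Tanh plays the role of the activation $f$:

```python
import torch

# Illustrative sizes (not prescribed by the equations above)
input_size, hidden_size = 10, 20

# Learned parameters (here just random/zero placeholders)
W_xh = torch.randn(hidden_size, input_size) * 0.1   # input-to-hidden weights
W_hh = torch.randn(hidden_size, hidden_size) * 0.1  # hidden-to-hidden weights
b_h = torch.zeros(hidden_size)                      # hidden bias

x_t = torch.randn(input_size)        # input at time t
h_prev = torch.zeros(hidden_size)    # previous hidden state h_{t-1}

# h_t = f(W_hh @ h_{t-1} + W_xh @ x_t + b_h), with f = tanh
h_t = torch.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)
print(h_t.shape)  # torch.Size([20])
```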
## 2. Unrolling the Network
To understand how an RNN learns, we "unroll" it. Instead of looking at it as a single cell with a loop, we view it as a chain of identical cells, each passing a message to its successor.
This structure shows that RNNs are essentially a deep network where the "depth" is determined by the length of the input sequence.
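In code, unrolling is simply a loop that applies the same cell (the same weights) at every time step. The sketch below uses PyTorch's `nn.RNNCell`; the sequence length of 5 and the sizes are again illustrative choices:

```python
import torch
import torch.nn as nn

# One cell, reused at every time step -- this reuse is what "unrolling" exposes.
cell = nn.RNNCell(input_size=10, hidden_size=20)

seq = torch.randn(5, 1, 10)   # 5 time steps, batch of 1, 10-dim inputs
h = torch.zeros(1, 20)        # initial hidden state

for x_t in seq:               # each iteration is one cell in the unrolled chain
    h = cell(x_t, h)          # h_t depends on x_t and h_{t-1}

print(h.shape)                # torch.Size([1, 20]) -- the final hidden state
```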
## 3. RNN Input-Output Architectures
RNNs are incredibly flexible and can be configured in several ways depending on the task:
| Architecture | Description | Example |
|---|---|---|
| One-to-Many | One input, a sequence of outputs. | Image Captioning (Image → Sentence). |
| Many-to-One | A sequence of inputs, one output. | Sentiment Analysis (Sentence → Positive/Negative). |
| Many-to-Many | A sequence of inputs and outputs. | Machine Translation (English → French). |
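As one example from the table, a many-to-one model for sentiment-style classification can read the whole sequence and attach a linear head to the final hidden state. This is a hypothetical sketch: the class name, sizes, and two-class output are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class ManyToOneRNN(nn.Module):
    """Reads a sequence, classifies it from the final hidden state."""
    def __init__(self, input_size=10, hidden_size=20, num_classes=2):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x):            # x: (batch, seq_len, input_size)
        _, hn = self.rnn(x)          # hn: (num_layers, batch, hidden_size)
        return self.head(hn[-1])     # logits: (batch, num_classes)

model = ManyToOneRNN()
logits = model(torch.randn(4, 7, 10))  # batch of 4 sequences, 7 steps each
print(logits.shape)                    # torch.Size([4, 2])
```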
## 4. The Major Flaw: Vanishing Gradients
While standard RNNs are theoretically powerful, they struggle with long-term dependencies.
Because the network is unrolled, backpropagation must travel back through every time step. At each step the gradient (error signal) is multiplied by the recurrent weights and the activation's derivative. If those factors are consistently smaller than 1, the gradient "vanishes" before it can reach the beginning of a long sequence (and if they are consistently larger than 1, it can explode instead).
- Result: The network forgets what happened at the start of a long sentence.
- The Solution: Specialized units like LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit).
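As a toy demonstration of why gradients vanish: backpropagating through many Tanh steps with a small recurrent weight leaves almost no gradient at the first input. The weight value 0.5 and the 50 steps are arbitrary choices for illustration.

```python
import torch

w = torch.tensor(0.5)                      # a "small" recurrent weight
x0 = torch.tensor(1.0, requires_grad=True)

h = x0
for _ in range(50):                        # an unrolled chain of 50 time steps
    h = torch.tanh(w * h)

h.backward()                               # backpropagate through the whole chain
print(x0.grad)                             # ~0: the error signal has vanished
```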
## 5. Implementation with PyTorch
In PyTorch, the `nn.RNN` module handles the recurrent logic for you.

```python
import torch
import torch.nn as nn

# Parameters: input_size=10, hidden_size=20, num_layers=1
rnn = nn.RNN(10, 20, 1, batch_first=True)

# Input shape: (batch, sequence_length, input_size)
# Example: 1 batch, 5 words, each represented by a 10-dim vector
input_seq = torch.randn(1, 5, 10)

# Initial hidden state, shape (num_layers, batch, hidden_size), set to zeros
h0 = torch.zeros(1, 1, 20)

# Forward pass
output, hn = rnn(input_seq, h0)

print(f"Output shape (all steps): {output.shape}")  # [1, 5, 20]
print(f"Final hidden state shape: {hn.shape}")      # [1, 1, 20]
```
Standard RNNs have a "short-term" memory problem. To solve this, we use a more complex architecture that can decide what to remember and what to forget.