LSTMs: Long Short-Term Memory
Standard RNNs have a major weakness: they have a very short memory. Because of the Vanishing Gradient problem, they struggle to connect information that is far apart in a sequence.
LSTMs, introduced by Hochreiter & Schmidhuber (1997), were specifically designed to overcome this. They introduce a "Cell State" (a long-term memory track) and a series of "Gates" that control what information is kept and what is discarded.
1. The Core Innovation: The Cell State
The "Secret Sauce" of the LSTM is the Cell State (). You can imagine it as a conveyor belt that runs straight down the entire chain of sequences, with only some minor linear interactions. It is very easy for information to just flow along it unchanged.
2. The Three Gates of an LSTM
An LSTM uses three specialized gates to protect and control the cell state. Each gate is composed of a Sigmoid neural net layer and a point-wise multiplication operation.
A. The Forget Gate (f_t)
This gate decides what information we are going to throw away from the cell state.
- Input: h_{t-1} (the previous hidden state) and x_t (the current input).
- Output: a number between 0 (completely forget) and 1 (completely keep) for each entry in the cell state, as sketched below.
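A minimal sketch of the forget gate computation, assuming hypothetical weight and bias names W_f and b_f and the usual concatenation of h_{t-1} and x_t:

```python
import torch

input_size, hidden_size = 10, 20

x_t = torch.randn(input_size)      # current input
h_prev = torch.randn(hidden_size)  # previous hidden state h_{t-1}

# Hypothetical (randomly initialized) parameters, for illustration only
W_f = torch.randn(hidden_size, hidden_size + input_size)
b_f = torch.zeros(hidden_size)

# Sigmoid squashes each entry to (0, 1): 0 = completely forget, 1 = completely keep
f_t = torch.sigmoid(W_f @ torch.cat([h_prev, x_t]) + b_f)
print(f_t.shape)  # torch.Size([20])
```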
B. The Input Gate (i_t)
This gate decides which new information we're going to store in the cell state. It works in tandem with a tanh layer that creates a vector of new candidate values (C̃_t).
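A sketch in the same style (W_i and W_C are hypothetical weight names, and the bias terms are omitted for brevity):

```python
import torch

input_size, hidden_size = 10, 20

x_t = torch.randn(input_size)
h_prev = torch.randn(hidden_size)
combined = torch.cat([h_prev, x_t])

# Hypothetical parameters, for illustration only
W_i = torch.randn(hidden_size, hidden_size + input_size)
W_C = torch.randn(hidden_size, hidden_size + input_size)

i_t = torch.sigmoid(W_i @ combined)   # which cell-state entries to update
C_tilde = torch.tanh(W_C @ combined)  # candidate values that could be added
```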
C. The Output Gate (o_t)
This gate decides what the next hidden state (h_t) should be. The hidden state carries information about previous inputs and is also what the network uses to make predictions.
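And a sketch of the final step, where the hidden state is a filtered view of the cell state (W_o is again a hypothetical weight name, and C_t stands in for the freshly updated cell state):

```python
import torch

input_size, hidden_size = 10, 20

x_t = torch.randn(input_size)
h_prev = torch.randn(hidden_size)
C_t = torch.randn(hidden_size)  # stand-in for the updated cell state

# Hypothetical parameters, for illustration only
W_o = torch.randn(hidden_size, hidden_size + input_size)

o_t = torch.sigmoid(W_o @ torch.cat([h_prev, x_t]))  # what to expose
h_t = o_t * torch.tanh(C_t)                          # next hidden state
```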
3. Advanced Architectural Logic (Mermaid)
The flow within a single LSTM cell is highly structured. The "Cell State" acts as the horizontal spine, while gates regulate the vertical flow of information.
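A minimal sketch of this flow, with the cell state running left to right and the gates feeding into it:

```mermaid
graph LR
    Cprev["C_(t-1): previous cell state"] -- "× f_t (forget gate)" --> Ct["C_t: updated cell state"]
    Hprev["h_(t-1): previous hidden state"] --> G["sigmoid / tanh gate layers"]
    Xt["x_t: current input"] --> G
    G -- "+ i_t × C̃_t (new candidates)" --> Ct
    Ct -- "tanh, × o_t (output gate)" --> Ht["h_t: next hidden state"]
```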
4. LSTM vs. Standard RNN
| Feature | Standard RNN | LSTM |
|---|---|---|
| Architecture | Simple (Single Tanh layer) | Complex (4 interacting layers) |
| Memory | Short-term only | Long and Short-term |
| Gradient Flow | Suffers from Vanishing Gradient | Resists Vanishing Gradient via the Cell State |
| Complexity | Low | High (More parameters to train) |
5. Implementation with PyTorch
In PyTorch, the nn.LSTM module automatically handles the complex gating logic and cell state management.
```python
import torch
import torch.nn as nn

# input_size=10, hidden_size=20, num_layers=1 (the default)
lstm = nn.LSTM(10, 20, batch_first=True)

# Input shape with batch_first=True: (batch_size, seq_len, input_size)
input_seq = torch.randn(1, 5, 10)

# Initial hidden state (h0) and cell state (c0): (num_layers, batch_size, hidden_size)
h0 = torch.zeros(1, 1, 20)
c0 = torch.zeros(1, 1, 20)

# Forward pass returns output and a tuple (hn, cn)
output, (hn, cn) = lstm(input_seq, (h0, c0))

print(f"Output shape: {output.shape}")        # [1, 5, 20]
print(f"Final Cell State shape: {cn.shape}")  # [1, 1, 20]
```
References
- Colah's Blog: Understanding LSTM Networks (Essential Reading)
- Stanford CS224N: RNNs and LSTMs
LSTMs are powerful but computationally expensive: the three gates add extra weight matrices and many more parameters to train. Is there a way to simplify this without losing the memory benefits?