OOP in Machine Learning
Most beginner code is procedural: a long list of instructions that runs from top to bottom. Production machine learning code, by contrast, is largely object-oriented. Object-oriented programming (OOP) lets us bundle data (like model weights) and functions (like the training logic) into a single unit called an object.
1. Classes vs. Objects
Think of a Class as a blueprint and an Object as the actual house built from that blueprint.
- Class: The template for a "Model" (e.g., defines that all models need fit and predict methods).
- Object: A specific instance (e.g., a RandomForest trained on housing data).
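As a rough illustration of the blueprint idea, the sketch below defines one class and builds two independent objects from it. MeanModel and its toy data are hypothetical names made up for this example; they do not come from any library.

```python
class MeanModel:
    """Blueprint: every MeanModel knows how to fit and predict."""

    def __init__(self):
        self.mean = None  # state is empty until fit() is called

    def fit(self, y):
        # "Training" here is just averaging the targets
        self.mean = sum(y) / len(y)
        return self

    def predict(self, n):
        # Predict the stored mean for n new samples
        return [self.mean] * n


# Two objects built from the same class, each with its own state
house_prices = MeanModel().fit([100.0, 200.0, 300.0])
rent_prices = MeanModel().fit([1.0, 2.0, 3.0])

print(house_prices.predict(2))  # [200.0, 200.0]
print(rent_prices.predict(2))   # [2.0, 2.0]
```

Both objects share the same methods, but fitting one has no effect on the other.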
2. The Core Components: Attributes and Methods
In ML, an object typically consists of:
- Attributes (Data): The "State" of the model (e.g., self.weights, self.learning_rate).
- Methods (Behavior): The "Actions" the model can take (e.g., self.fit(), self.predict()).
class SimpleModel:
    def __init__(self, lr):
        # Attribute: Initializing the state
        self.learning_rate = lr
        self.weights = None

    def fit(self, X, y):
        # Method: Defining behavior
        print(f"Training with LR: {self.learning_rate}")
3. The Four Pillars of OOP in ML
A. Encapsulation
Hiding the internal complexity. You don't need to know the calculus inside .fit() to use it; you just call the method. It "encapsulates" the math away from the user.
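One way to picture encapsulation is a small model whose gradient math sits in an internal helper, while the caller only touches fit() and predict(). The LineFitter class, its _gradient helper, and the default lr and steps values are assumptions chosen for this sketch.

```python
class LineFitter:
    """Fits y = w * x by gradient descent; the caller never sees the math."""

    def __init__(self, lr=0.01, steps=500):
        self.lr = lr
        self.steps = steps
        self.w = 0.0

    def _gradient(self, xs, ys):
        # Internal detail: derivative of the mean squared error w.r.t. w
        n = len(xs)
        return sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n

    def fit(self, xs, ys):
        # Public method: repeatedly take a small step against the gradient
        for _ in range(self.steps):
            self.w -= self.lr * self._gradient(xs, ys)
        return self

    def predict(self, xs):
        return [self.w * x for x in xs]


# The data follows y = 2 * x, so the learned w should end up close to 2
model = LineFitter().fit([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(round(model.w, 3))
```

The leading underscore in _gradient is a Python convention that marks the method as internal. The language does not enforce it.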
B. Inheritance
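Creating a new class based on an existing one, so the new class reuses the parent's attributes and methods and only adds or overrides what differs. In PyTorch, for example, your custom neural network inherits from the base class torch.nn.Module.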
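A dependency-free sketch of the same idea is shown below. BaseModel is a hypothetical stand-in for a library base class (it is not PyTorch's Module), and MeanRegressor is an invented subclass that supplies only the model-specific parts.

```python
class BaseModel:
    """Hypothetical base class: shared bookkeeping lives here once."""

    def __init__(self, name):
        self.name = name
        self.is_fitted = False

    def fit(self, xs, ys):
        self._fit(xs, ys)      # subclass-specific work
        self.is_fitted = True  # shared bookkeeping
        return self

    def _fit(self, xs, ys):
        raise NotImplementedError("subclasses must implement _fit")


class MeanRegressor(BaseModel):
    """Inherits fit() and is_fitted; only adds what is specific to it."""

    def __init__(self):
        super().__init__(name="mean")
        self.mean = None

    def _fit(self, xs, ys):
        self.mean = sum(ys) / len(ys)

    def predict(self, xs):
        return [self.mean for _ in xs]


model = MeanRegressor().fit([1, 2, 3], [10.0, 20.0, 30.0])
print(model.name, model.is_fitted, model.predict([4, 5]))
# mean True [20.0, 20.0]
```

MeanRegressor never defines fit() itself; it gets that method, along with the is_fitted flag, from the parent class.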
C. Polymorphism
The ability for different objects to be used through the same interface. For example, you can loop through a list of different models and call .predict() on each of them, regardless of their internal math.
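The loop described above might look like the following sketch. The three model classes are toy examples invented here; all they have in common is a predict() method with the same signature.

```python
class AlwaysZero:
    def predict(self, xs):
        return [0.0 for _ in xs]


class Doubler:
    def predict(self, xs):
        return [2.0 * x for x in xs]


class Thresholder:
    def __init__(self, cutoff):
        self.cutoff = cutoff

    def predict(self, xs):
        return [1.0 if x > self.cutoff else 0.0 for x in xs]


models = [AlwaysZero(), Doubler(), Thresholder(cutoff=1.5)]
inputs = [1.0, 2.0]

# Same call on every object; each one answers with its own logic
results = {type(m).__name__: m.predict(inputs) for m in models}
print(results)
# {'AlwaysZero': [0.0, 0.0], 'Doubler': [2.0, 4.0], 'Thresholder': [0.0, 1.0]}
```

In Python the classes do not need a shared parent for this to work; having a method with the same name and arguments is enough.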
D. Abstraction
Using simple interfaces to represent complex tasks. An "Optimizer" object abstracts away the specific update rules (SGD, Adam, RMSProp).
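To make the optimizer example concrete, here is a minimal sketch using Python's abc module. The Optimizer interface, the SGD and Momentum classes, and the minimize helper are simplified illustrations written for this text, not the API of any particular library.

```python
from abc import ABC, abstractmethod


class Optimizer(ABC):
    """Hypothetical interface: callers only ever use step()."""

    def __init__(self, lr):
        self.lr = lr

    @abstractmethod
    def step(self, weights, grads):
        """Return updated weights given the current gradients."""


class SGD(Optimizer):
    def step(self, weights, grads):
        return [w - self.lr * g for w, g in zip(weights, grads)]


class Momentum(Optimizer):
    def __init__(self, lr, beta=0.9):
        super().__init__(lr)
        self.beta = beta
        self.velocity = None

    def step(self, weights, grads):
        if self.velocity is None:
            self.velocity = [0.0] * len(weights)
        # Accumulate a running direction, then move along it
        self.velocity = [self.beta * v + g for v, g in zip(self.velocity, grads)]
        return [w - self.lr * v for w, v in zip(weights, self.velocity)]


def minimize(optimizer, w=5.0, steps=200):
    # Minimize f(w) = w**2 (gradient 2*w) without knowing the update rule
    weights = [w]
    for _ in range(steps):
        grads = [2 * x for x in weights]
        weights = optimizer.step(weights, grads)
    return weights[0]


# Both optimizers drive w toward the minimum at 0
print(minimize(SGD(lr=0.1)), minimize(Momentum(lr=0.01)))
```

The minimize function never inspects which optimizer it received, so swapping update rules is a one-argument change.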
4. Why use OOP for ML?
- Organization: Keeps weights and training logic together. Without OOP, you'd have to pass weights as an argument to every single function.
- Reproducibility: You can save a model's learned state (in PyTorch, its "state_dict") and reload it later to restore the same parameters and get the same predictions.
- Extensibility: Want to try a new loss function? You can create a subclass and just override one method without rewriting the whole training loop.
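The reproducibility point can be sketched with a toy class. The method names state_dict and load_state_dict are borrowed from PyTorch, but TinyModel below is a hypothetical plain-Python class, and the "learned" values are set by hand for illustration.

```python
import json


class TinyModel:
    """Hypothetical model whose whole learned state is two numbers."""

    def __init__(self):
        self.weight = 0.0
        self.bias = 0.0

    def predict(self, x):
        return self.weight * x + self.bias

    def state_dict(self):
        # Everything needed to rebuild the trained model
        return {"weight": self.weight, "bias": self.bias}

    def load_state_dict(self, state):
        self.weight = state["weight"]
        self.bias = state["bias"]


trained = TinyModel()
trained.weight, trained.bias = 1.5, -0.25  # pretend these were learned

saved = json.dumps(trained.state_dict())   # could be written to a file

restored = TinyModel()
restored.load_state_dict(json.loads(saved))

print(restored.predict(2.0) == trained.predict(2.0))  # True
```

Because the state lives in one object, saving and restoring it is a single call in each direction.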
5. Standard ML Pattern: The Class Structure
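import numpy as np

class MyNeuralNet:
    def __init__(self, input_size, lr=0.01):
        # State: one small random weight per input (one possible initialization)
        self.weights = np.random.randn(input_size) * 0.01
        self.lr = lr  # step size for updates (illustrative default)

    def forward(self, x):
        return x @ self.weights  # Behavior 1: compute the prediction

    def backward(self, grad):
        # Behavior 2: gradient-descent update; grad is d(loss)/d(weights)
        self.weights -= self.lr * grad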
Now that you understand how objects work, you can begin to navigate the source code of major ML libraries. But before we build complex classes, we need to master the math engine that powers them.