Image Segmentation

While Image Classification tells us what is in an image and Object Detection tells us where it is, Image Segmentation provides a pixel-level understanding of the scene.

It is the process of partitioning a digital image into multiple segments (sets of pixels), changing its representation into something more meaningful and easier to analyze.

1. Types of Segmentation

Not all segmentation tasks are the same. We generally categorize them into three levels of complexity:

A. Semantic Segmentation

Every pixel is assigned a class label (e.g., "Road," "Sky," "Car"). However, it does not differentiate between multiple instances of the same class. Two cars parked next to each other will appear as a single connected "blob."

B. Instance Segmentation

This goes a step further by detecting and delineating each distinct object of interest. If there are five people in a photo, instance segmentation will give each person a unique color/ID.

C. Panoptic Segmentation

The "holy grail" of segmentation. It combines semantic and instance segmentation to provide a total understanding of the scene—identifying individual objects (cars, people) and background textures (sky, grass).

2. The Architecture: Encoder-Decoder (U-Net)

Traditional CNNs progressively lose spatial resolution through pooling. To produce an output mask with the same height and width as the input, we use an Encoder-Decoder architecture (a minimal code sketch follows the list below):

  1. Encoder (The "What"): A standard CNN that downsamples the image to extract high-level features.
  2. Bottleneck: The compressed representation of the image.
  3. Decoder (The "Where"): Uses Transposed Convolutions (Upsampling) to recover the spatial dimensions.
  4. Skip Connections: These are the "secret sauce" of the U-Net architecture. They pass high-resolution information from the encoder directly to the decoder to help refine the boundaries of the mask.
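Here is a minimal encoder-decoder sketch in PyTorch with a single downsampling/upsampling stage and one skip connection. The layer sizes and the TinyUNet name are illustrative; this is not the original U-Net configuration.

import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Illustrative encoder-decoder with one skip connection (not the full U-Net)."""

    def __init__(self, in_channels=3, num_classes=21):
        super().__init__()
        # Encoder: extract features, then halve spatial resolution with pooling
        self.enc = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True))
        self.pool = nn.MaxPool2d(2)
        # Bottleneck: compressed representation of the image
        self.bottleneck = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True))
        # Decoder: transposed convolution doubles the spatial resolution back
        self.up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        # After the skip connection the decoder sees 64 (upsampled) + 64 (skip) channels
        self.dec = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True))
        # 1x1 convolution maps features to per-pixel class scores
        self.head = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        skip = self.enc(x)                    # high-resolution features ("where")
        x = self.bottleneck(self.pool(skip))  # low-resolution features ("what")
        x = self.up(x)                        # recover spatial dimensions
        x = torch.cat([x, skip], dim=1)       # skip connection refines boundaries
        return self.head(self.dec(x))

logits = TinyUNet()(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 21, 224, 224])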

3. Loss Functions for Segmentation

Because we are classifying every pixel, standard accuracy can be misleading (especially if 90% of the image is just background). Instead, we rely on overlap-based measures, which serve both as evaluation metrics and, in differentiable form, as loss functions (a short code sketch follows the formula):

  • Intersection over Union (IoU) / Jaccard Index: Measures the overlap between the predicted mask and the ground truth.
  • Dice Coefficient: Similar to IoU, it measures the similarity between two sets of data and is more robust to class imbalance.
IoU = \frac{\text{Area of Overlap}}{\text{Area of Union}}
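Below is a minimal sketch of both measures for binary masks. The function names (iou_score, dice_score) are illustrative, and a small eps term is added here purely to avoid division by zero on empty masks.

import torch

def iou_score(pred, target, eps=1e-6):
    """IoU / Jaccard index for two binary masks of the same shape."""
    pred, target = pred.bool(), target.bool()
    intersection = (pred & target).sum().float()
    union = (pred | target).sum().float()
    return (intersection + eps) / (union + eps)

def dice_score(pred, target, eps=1e-6):
    """Dice coefficient: 2 * |A ∩ B| / (|A| + |B|)."""
    pred, target = pred.bool(), target.bool()
    intersection = (pred & target).sum().float()
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

pred = torch.tensor([[1, 1, 0], [0, 1, 0]])
target = torch.tensor([[1, 0, 0], [0, 1, 1]])
print(iou_score(pred, target))   # 2 / 4 = 0.5
print(dice_score(pred, target))  # 2*2 / (3 + 3) ≈ 0.667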

4. Real-World Applications

  • Medical Imaging: Identifying tumors or mapping organs in MRI and CT scans.
  • Self-Driving Cars: Identifying the exact boundaries of lanes, sidewalks, and drivable space.
  • Satellite Imagery: Mapping land use, deforestation, or urban development.
  • Portrait Mode: Separating the person (subject) from the background to apply a "bokeh" blur effect.

5. Popular Models

Model       | Type     | Best For
U-Net       | Semantic | Medical imaging and biomedical research.
Mask R-CNN  | Instance | Detecting objects and generating masks (e.g., counting individual cells).
DeepLabV3+  | Semantic | State-of-the-art results using Atrous (Dilated) Convolutions.
SegNet      | Semantic | Efficient scene understanding for autonomous driving.
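For the instance-segmentation entry in the table, torchvision ships a pre-trained Mask R-CNN. A minimal loading sketch follows; the output keys ('masks', 'scores', etc.) are those documented for torchvision's detection models, and the input size here is arbitrary.

import torch
from torchvision import models

# Pre-trained Mask R-CNN for instance segmentation (COCO classes).
# Note: newer torchvision versions prefer the weights= argument over pretrained=True.
model = models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()

with torch.no_grad():
    # Detection models take a list of (C, H, W) images with values in [0, 1]
    predictions = model([torch.rand(3, 480, 640)])

# One dict per image: 'boxes', 'labels', 'scores', and 'masks' (one mask per detected instance)
masks = predictions[0]["masks"]  # shape: [num_instances, 1, 480, 640]
print(masks.shape)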

6. Implementation Sketch (PyTorch)

Using a pre-trained segmentation model from torchvision:

import torch
from torchvision import models

# Load a pre-trained DeepLabV3 model
model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

# Input: (Batch, Channels, Height, Width)
dummy_input = torch.randn(1, 3, 224, 224)

# Output: Returns a dictionary containing 'out' - the pixel-wise class predictions
with torch.no_grad():
    output = model(dummy_input)['out']

print(f"Output shape: {output.shape}")
# Shape will be [1, 21, 224, 224] (for 21 Pascal VOC classes)
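To turn these per-class scores into an actual segmentation mask, take the arg-max over the class dimension. This short usage sketch continues from the code above.

# Per-pixel class prediction: index of the highest-scoring class at each pixel
mask = output.argmax(dim=1)  # shape: [1, 224, 224], values in 0..20
print(mask.unique())         # class IDs present in the prediction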

Segmentation provides a high level of detail, but it's computationally expensive. How do we make these models faster for real-time applications?