datarekha
Deep Learning · Hard · Asked at Google, Meta, OpenAI, NVIDIA

Why does depth help more than width for learning complex functions?

The short answer

Depth enables hierarchical feature composition: each layer can reuse and recombine features from the layer below, so a deep network can represent some functions with far fewer parameters than a single wide layer would need. For ReLU networks, the number of linear regions (a standard measure of expressiveness) grows only polynomially with width but can grow exponentially with depth.

How to think about it

The shallow-vs-deep equivalence argument:

A wide single-hidden-layer network can, in theory, approximate any continuous function on a compact domain to arbitrary accuracy (Universal Approximation Theorem; Hornik et al., 1989). But “in theory” hides a brutal constant: for some piecewise linear functions, a shallow network needs width exponential in L to match what a depth-L network computes (Montufar et al., 2014). In practice that means an infeasible number of neurons.
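A one-dimensional toy makes the exponential gap concrete. The sketch below uses the composed "tent map" sawtooth, a standard depth-separation construction often attributed to Telgarsky (2016), rather than the Montufar et al. bound itself: two ReLU units per layer fold [0, 1] in half, so L layers produce 2^L linear pieces, whereas a one-hidden-layer ReLU network with k units on a scalar input has at most k + 1 pieces (each unit adds at most one breakpoint). The function names are illustrative, and exact `Fraction` arithmetic is an assumption made so that rounding cannot disturb the piece count.

```python
from fractions import Fraction


def relu(x):
    return max(x, 0)


def tent(x):
    # Two ReLU units: 2x on [0, 1/2] and 2 - 2x on [1/2, 1]
    return 2 * relu(x) - 4 * relu(x - Fraction(1, 2))


def deep_sawtooth(x, depth):
    # Depth-L, width-2 ReLU network: compose the tent map L times
    for _ in range(depth):
        x = tent(x)
    return x


def count_linear_pieces(f, depth, samples_per_piece=4):
    # Exact rational grid; every breakpoint k / 2**depth lies on a grid point,
    # so each grid interval sits inside a single linear piece
    n = samples_per_piece * 2 ** depth
    xs = [Fraction(i, n) for i in range(n + 1)]
    ys = [f(x, depth) for x in xs]
    slopes = [(ys[i + 1] - ys[i]) * n for i in range(n)]
    # A new piece starts wherever the slope changes
    return 1 + sum(1 for a, b in zip(slopes, slopes[1:]) if a != b)


for depth in range(1, 7):
    pieces = count_linear_pieces(deep_sawtooth, depth)
    deep_units = 2 * depth
    # One hidden layer with k units has at most k + 1 pieces
    shallow_units = pieces - 1
    print(f"L={depth}: {pieces} pieces, deep uses {deep_units} units, "
          f"shallow needs >= {shallow_units}")
```

For L = 6 this reports 64 pieces from 12 deep units versus at least 63 shallow units; the shallow requirement doubles with every extra layer, while the deep network adds only two units.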

Depth provides hierarchical re-use:

In a convolutional vision model:

  • Layer 1: edges and color blobs
  • Layer 2: corners, curves, simple textures
  • Layer 3: object parts (eyes, wheels)
  • Layer 4+: whole objects

Each layer’s features are built from the previous layer’s features without duplicating computation. A single wide layer must detect every pattern independently — no sharing, no composition.
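The growing scope of features in the layer list above can be quantified with receptive fields: each stacked convolution widens the patch of input pixels one unit depends on, and strided layers multiply that growth. A minimal sketch, assuming a simple stack described by hypothetical (kernel_size, stride) pairs along one spatial axis, with dilation ignored (padding shifts where the field sits but not its size):

```python
def receptive_fields(layers):
    """layers: list of (kernel_size, stride) pairs for stacked conv/pool layers.
    Returns the receptive field (in input pixels) of one unit after each layer."""
    rf, jump = 1, 1  # an input pixel sees itself; adjacent units are 1 pixel apart
    fields = []
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each new layer reaches (k - 1) more steps out
        jump *= stride             # striding spreads later units further apart
        fields.append(rf)
    return fields


# Hypothetical stacks of 3x3 convolutions
print(receptive_fields([(3, 1)] * 4))  # stride 1 everywhere
print(receptive_fields([(3, 2)] * 4))  # stride 2 after every layer
```

With stride 1 the field grows linearly (3, 5, 7, 9 pixels); with stride 2 it roughly doubles per layer (3, 7, 15, 31), which is one reason deeper layers can respond to object parts and whole objects rather than edges.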

Expressiveness argument (piecewise linear networks):

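A ReLU network with L hidden layers of width n ≥ d on a d-dimensional input can realize $\Omega\big((n/d)^{(L-1)d}\, n^{d}\big)$ linear regions (Montufar et al., 2014), whereas a single hidden layer of N units has at most $\sum_{j=0}^{d}\binom{N}{j} = O(N^{d})$. The count grows exponentially in depth but only polynomially in width, so at a comparable parameter budget more depth yields far more regions, and therefore more complex decision boundaries.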
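To see the bound at work, the sketch below compares the Montufar et al. lower bound for a deep ReLU network with the most regions any single hidden layer can reach when given the same number of hidden-layer parameters. Because one side is a lower bound on the deep network's maximum and the other is an upper bound for the shallow one, the deep advantage is genuine whenever the first number is larger. The sizes d = 2 and n = 16 are arbitrary illustrative choices, the helper names are hypothetical, and output-layer parameters are ignored.

```python
from math import comb


def shallow_regions_upper(d, width):
    # Max regions cut by `width` hyperplanes in general position in R^d
    return sum(comb(width, j) for j in range(d + 1))


def deep_regions_lower(d, n, depth):
    # Montufar et al. (2014) lower bound, valid for width n >= d:
    # (floor(n/d)^d)^(depth-1) * sum_{j<=d} C(n, j)
    assert n >= d
    return (n // d) ** (d * (depth - 1)) * shallow_regions_upper(d, n)


def hidden_params(d, n, depth):
    # Weights + biases of the hidden layers only (output layer omitted)
    return (d * n + n) + (depth - 1) * (n * n + n)


d, n = 2, 16
for depth in range(1, 5):
    budget = hidden_params(d, n, depth)
    shallow_width = budget // (d + 1)  # one hidden layer with the same budget
    print(f"L={depth}: params={budget}, "
          f"deep >= {deep_regions_lower(d, n, depth)}, "
          f"shallow <= {shallow_regions_upper(d, shallow_width)}")
```

With these sizes the two counts tie at L = 1 (137 regions each), and by L = 3 the deep lower bound (561,152) is already about 29 times the shallow upper bound (19,504).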

But depth has costs:

  • Deeper networks are harder to optimize (vanishing/exploding gradients).
  • They require careful initialization, normalization, and often residual connections.
  • Inference latency grows with depth, because layers must run one after another and cannot be parallelized across depth.
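A scalar toy illustrates both the optimization cost and why residual connections help. By the chain rule, the input gradient of a stack is the product of per-layer derivatives. A sigmoid's derivative is at most 1/4, so with unit weights the product shrinks at least as fast as (1/4)^L, while a residual update h + f(h) adds 1 to every factor. The chain, its unit weight, and the starting input are hypothetical choices for illustration, not a model of any real network.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def input_gradient(depth, w=1.0, residual=False, x=0.5):
    """d(output)/d(input) for a hypothetical scalar chain h <- sigmoid(w*h),
    or h <- h + sigmoid(w*h) with a residual connection, via the chain rule."""
    grad, h = 1.0, x
    for _ in range(depth):
        s = sigmoid(w * h)
        local = w * s * (1.0 - s)  # d sigmoid(w*h) / dh, at most w / 4
        if residual:
            local += 1.0           # the identity path contributes derivative 1
            h = h + s
        else:
            h = s
        grad *= local
    return grad


for depth in (1, 5, 10, 20):
    print(depth, input_gradient(depth), input_gradient(depth, residual=True))
```

Real networks use learned weights, ReLUs, and normalization, so the printed numbers mean little on their own; the point is the multiplicative structure of the chain rule and how an identity path keeps each factor from falling below 1.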

Practical heuristic: for structured (tabular) data on small datasets, a wide shallow MLP often does as well as a deeper one. For images, audio, text, or anything with spatial or temporal hierarchy, deeper architectures typically outperform shallow-but-wide ones at the same parameter budget.

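import torch.nn as nn

# Hypothetical sizes, chosen so both models have exactly 197,386 parameters
d, out = 246, 10

# Deep (3 hidden layers, 256 units each) vs wide (1 hidden layer, 768 units)
deep = nn.Sequential(nn.Linear(d, 256), nn.ReLU(),
                     nn.Linear(256, 256), nn.ReLU(),
                     nn.Linear(256, 256), nn.ReLU(),
                     nn.Linear(256, out))

wide = nn.Sequential(nn.Linear(d, 768), nn.ReLU(),
                     nn.Linear(768, out))
# Equal budgets; per the heuristic above, depth tends to pay off on
# hierarchical data such as images and text.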
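Whether the deep and wide models above really have similar budgets depends on d and out. Counting weights plus biases per dense layer (as `nn.Linear` does by default), the wide model has 512(d + out - 256) more parameters than the deep one, so the budgets match exactly only when d + out = 256; for 784-dimensional inputs and 10 outputs, the wide model is about 1.8 times as large. A small stdlib-only sketch, with hypothetical helper names:

```python
def mlp_params(sizes):
    # Each dense layer a -> b holds a*b weights plus b biases
    return sum(a * b + b for a, b in zip(sizes, sizes[1:]))


def deep_vs_wide(d, out):
    # Same layer sizes as the PyTorch models above
    deep = mlp_params([d, 256, 256, 256, out])
    wide = mlp_params([d, 768, out])
    return deep, wide


for d in (64, 246, 784):
    deep, wide = deep_vs_wide(d, out=10)
    print(f"d={d}: deep={deep}, wide={wide}, wide-deep={wide - deep}")
```

A fair depth-vs-width comparison should therefore pick the wide layer's size from the actual input and output dimensions rather than from a fixed 3 × 256 rule.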
