Stochastic Gradient Descent


Stochastic Gradient Descent #

Optimization | Difficulty: ★★★★☆ | Depth: 8 | Unlocks: 7

Gradient descent with random sample estimates. Mini-batches.

Core Concepts #

Key Symbols & Notation #

ĝₜ : stochastic gradient estimate at iteration t (computed from the sampled example or mini-batch)

Essential Relationships #

Prerequisites (2) #

Gradient Descent (6 atoms), Expected Value (5 atoms)

Unlocks (4) #

Deep Learning (lvl 5), Policy Gradient Methods (lvl 5), Diffusion Models (lvl 5), Meta-Learning (lvl 5)

Advanced Learning Details

Graph Position #

Depth Cost: 89

Fan-Out (ROI): 7

Bottleneck Score: 4

Chain Length: 8

Cognitive Load #

Atomic Elements: 5

Total Elements: 44

Percentile Level: L3

Atomic Level: L3

All Concepts (19) #

Teaching Strategy #

Quick unlock - significant prerequisite investment but simple final step. Verify prerequisites first.

Full-batch gradient descent is like steering a ship using the average of all ocean currents you can measure—accurate, but slow and expensive. Stochastic Gradient Descent (SGD) steers using a small, randomly sampled set of measurements each step—noisier, but dramatically faster per update, and often better at finding solutions that generalize.

TL;DR:

Stochastic Gradient Descent minimizes an empirical risk by repeatedly updating parameters using a random (single-example or mini-batch) estimate of the true gradient: θₜ₊₁ = θₜ − ηₜ ĝₜ, where ĝₜ is computed from a randomly sampled example/mini-batch. The noise makes updates cheap and can help escape poor regions, but requires careful choices of learning rate and batch size for stable convergence.

What Is Stochastic Gradient Descent? #

Why we need something beyond full gradient descent #

Suppose you are minimizing a loss over a dataset of N examples:

A standard objective in supervised learning is the empirical risk (average loss):

L(θ) = (1/N) ∑ᵢ₌₁ᴺ ℓᵢ(θ)

Full-batch gradient descent computes the exact gradient of L:

∇L(θ) = (1/N) ∑ᵢ₌₁ᴺ ∇ℓᵢ(θ)

and then updates:

θₜ₊₁ = θₜ − ηₜ ∇L(θₜ)

This is conceptually simple, but computationally expensive when N is large, because each step requires scanning the whole dataset.

Definition: SGD in one sentence #

Stochastic Gradient Descent (SGD) is gradient descent where the gradient ∇L(θ) is replaced by a stochastic (random) estimator ĝₜ computed from a randomly sampled single example or a mini-batch.

The SGD update is:

θₜ₊₁ = θₜ − ηₜ ĝₜ

where ĝₜ ≈ ∇L(θₜ) but is computed cheaply.
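As a minimal sketch of the two updates above, the following compares the exact gradient with a sampled estimate. The scalar squared-error loss, the tiny dataset, and the names `full_gradient`, `sgd_step`, and `squared_error_grad` are illustrative assumptions, not part of the original text.

```python
import random

def full_gradient(grad_fn, theta, data):
    """Exact gradient of the empirical risk: the average of all per-example gradients."""
    return sum(grad_fn(theta, ex) for ex in data) / len(data)

def sgd_step(grad_fn, theta, data, lr, batch_size, rng):
    """One SGD update using a mini-batch drawn uniformly without replacement."""
    batch = rng.sample(data, batch_size)
    g_hat = sum(grad_fn(theta, ex) for ex in batch) / batch_size
    return theta - lr * g_hat

# Hypothetical per-example loss: 0.5 * (theta * x - y)^2 with scalar theta.
def squared_error_grad(theta, example):
    x, y = example
    return (theta * x - y) * x

# Made-up (x, y) pairs, for illustration only.
data = [(1.0, 2.0), (2.0, 0.0), (3.0, 1.0), (0.5, 1.0)]

rng = random.Random(0)
theta = 0.0
for _ in range(200):
    theta = sgd_step(squared_error_grad, theta, data, lr=0.05, batch_size=2, rng=rng)
print("theta after 200 SGD steps:", theta)
```

Setting `batch_size=len(data)` recovers a full-batch gradient descent step (up to floating-point summation order), which makes the "same step, different gradient estimate" relationship concrete.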

The stochastic gradient estimator ĝₜ #

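At iteration t, we sample either:

a single index iₜ uniformly from {1,…,N}, giving ĝₜ = ∇ℓᵢₜ(θₜ), or

a mini-batch Bₜ of m indices, giving ĝₜ = (1/m) ∑_{i ∈ Bₜ} ∇ℓᵢ(θₜ).

Mini-batches are the modern default.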

Why this is even “correct”: unbiasedness via expectation #

A key idea is that ĝₜ is often an unbiased estimator of the true gradient.

Assume iₜ is uniform. Then:

E[ĝₜ | θₜ] = E[∇ℓᵢₜ(θₜ) | θₜ]

Because iₜ is uniform:

E[∇ℓᵢₜ(θₜ) | θₜ]

= (1/N) ∑ᵢ₌₁ᴺ ∇ℓᵢ(θₜ)

= ∇L(θₜ)

So, on average, SGD points in the same direction as full gradient descent.

For a mini-batch sampled uniformly (with or without replacement), you similarly get:

E[ĝₜ | θₜ] = ∇L(θₜ)
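The unbiasedness claim can be checked exactly on a toy problem by enumerating every equally likely sample instead of drawing at random. This is a sketch under assumed values: the squared-error loss, the four data points, and θ = 0.3 are made up for illustration.

```python
from itertools import combinations

def grad(theta, example):
    """Per-example gradient of the hypothetical loss 0.5 * (theta * x - y)^2."""
    x, y = example
    return (theta * x - y) * x

data = [(1.0, 2.0), (2.0, 0.0), (3.0, 1.0), (0.5, 1.0)]  # made-up data
theta = 0.3
N = len(data)

# True gradient of L(theta) = (1/N) * sum_i loss_i(theta).
full_grad = sum(grad(theta, ex) for ex in data) / N

# Single-example estimator: each index is drawn with probability 1/N.
expected_single = sum((1.0 / N) * grad(theta, ex) for ex in data)

# Mini-batch estimator (m = 2, without replacement): every batch is equally likely.
m = 2
batches = list(combinations(data, m))
expected_batch = sum(
    sum(grad(theta, ex) for ex in batch) / m for batch in batches
) / len(batches)

print(full_grad, expected_single, expected_batch)
```

All three printed values agree up to floating-point rounding, even though any individual sample can point in a different direction from the full gradient.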

What prerequisites does SGD really depend on? #

This node assumes you already know Gradient Descent and Expected Value. To understand SGD deeply, it also helps to have:

SGD is not a different kind of optimization step; it’s the same step using a randomized gradient estimate.

Intuition: the “noise” is a feature, not only a bug #

SGD introduces gradient noise:

ĝₜ = ∇L(θₜ) + ξₜ

where E[ξₜ | θₜ] = 0.

This noise has two major effects:

  1. Cheaper steps: each update is far less expensive than summing over N examples.
  2. Exploration: noise can help move through flat regions or escape shallow/poor minima in nonconvex landscapes (common in deep learning).

But noise also causes:

That tension—cheap noisy steps vs stable accurate steps—is the heart of SGD.

Core Mechanic 1: Mini-batches, Noise, and the Stochastic Gradient #

Why mini-batches exist (not just “because GPUs”) #

Single-example SGD is maximally noisy: one example can be atypical, leading to a gradient that points away from the true average direction.

Mini-batches reduce noise by averaging several example gradients.

Let Bₜ be a mini-batch of size m. Define:

ĝₜ = (1/m) ∑_{i ∈ Bₜ} ∇ℓᵢ(θₜ)

This is still stochastic (depends on which batch you drew), but typically has smaller variance than a single-example gradient.

A variance viewpoint (the key conceptual lever) #

Think of per-example gradients as random vectors:

G = ∇ℓᵢ(θ)

where i is a random index uniform over {1,…,N}. Then:

E[G] = ∇L(θ)

The mini-batch gradient is an average of m i.i.d. (or approximately i.i.d.) samples G₁,…,Gₘ of G:

ĝ = (1/m) ∑_{k=1}^m Gₖ

A core statistical fact: averaging reduces variance.

Very informally (scalar intuition), if Var(G) = σ², then:

Var(ĝ) = σ² / m

For vectors, you can think in terms of covariance matrices; the same “divide by m” scaling appears under independence assumptions.
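The σ²/m scaling can be verified exactly for a small case by enumerating every possible batch. The sketch below assumes sampling with replacement (so the draws are independent) and uses made-up scalar per-example gradients; `batch_mean_variance` is a hypothetical helper name.

```python
from itertools import product

def variance(values):
    """Population variance of a list of equally likely values."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

# Made-up scalar per-example gradients at some fixed theta.
per_example_grads = [-2.0, 0.8, 1.5, -0.3]
sigma2 = variance(per_example_grads)

def batch_mean_variance(grads, m):
    """Exact variance of the mean of m independent uniform draws (with replacement)."""
    means = [sum(batch) / m for batch in product(grads, repeat=m)]
    return variance(means)

results = {m: batch_mean_variance(per_example_grads, m) for m in (1, 2, 4)}
for m, var in results.items():
    print(f"m={m}: Var(g_hat)={var:.6f}, sigma^2/m={sigma2 / m:.6f}")
```

Each printed pair matches up to rounding. Sampling without replacement, as in a shuffled epoch, is not covered by this independence argument; there the variance is somewhat smaller still.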

Practical meaning:

Why the noise can help (conceptual clarity) #

It’s tempting to view noise as purely harmful. But in large nonconvex problems (deep nets), the objective surface has:

Noise can:

A useful mental model is that SGD behaves like gradient descent plus random perturbations.

Epochs, iterations, and data order #

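Terminology you will see constantly: an iteration (or step) is one parameter update computed from one mini-batch, and an epoch is one full pass over the N training examples.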

If batch size is m, then steps per epoch ≈ N/m.

A common training loop:

  1. Shuffle data (to avoid correlated batches)
  2. Split into mini-batches
  3. For each mini-batch B: compute ĝ = (1/m) ∑_{i ∈ B} ∇ℓᵢ(θ), then update θ ← θ − η ĝ

The shuffle matters: if data are ordered (e.g., all cats then all dogs), batches become biased and SGD can behave poorly.
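The shuffle, split, update loop described above might look like the following sketch. The function names `minibatches` and `train`, the noiseless toy data (y = 2x), and the hyperparameters are assumptions chosen for illustration.

```python
import random

def minibatches(data, batch_size, rng):
    """Yield shuffled mini-batches that together cover the dataset exactly once."""
    indices = list(range(len(data)))
    rng.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [data[i] for i in indices[start:start + batch_size]]

def train(data, grad_fn, theta, lr, batch_size, epochs, seed=0):
    """Run mini-batch SGD for a number of epochs; returns (theta, number of steps)."""
    rng = random.Random(seed)
    steps = 0
    for _ in range(epochs):
        for batch in minibatches(data, batch_size, rng):
            # Average the per-example gradients in this batch.
            g_hat = sum(grad_fn(theta, ex) for ex in batch) / len(batch)
            theta = theta - lr * g_hat
            steps += 1
    return theta, steps

# Hypothetical model y_hat = theta * x with loss 0.5 * (theta * x - y)^2.
def grad_fn(theta, example):
    x, y = example
    return (theta * x - y) * x

# Made-up, noiseless data generated by y = 2x.
data = [(0.1 * k, 0.2 * k) for k in range(1, 11)]
theta, steps = train(data, grad_fn, theta=0.0, lr=0.5, batch_size=4, epochs=200)
print(theta, steps)
```

With N = 10 and m = 4 the last batch of each epoch holds only 2 examples, so there are 3 steps per epoch (the ceiling of N/m). Dropping or keeping that short final batch is a design choice.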

Comparing full-batch, mini-batch, and single-example #

| Method | Gradient used | Cost per step | Noise level | Typical use |
|---|---|---|---|---|
| Full-batch GD | ∇L = (1/N)∑∇ℓᵢ | High | None | Small datasets, convex problems |
| Mini-batch SGD | (1/m)∑_{i∈B}∇ℓᵢ | Medium | Medium (↓ with m) | Default in deep learning |
| Single-example SGD | ∇ℓᵢ | Low | High | Rare alone; sometimes online learning |

A subtle but important point: “SGD” often means “mini-batch SGD” #

In modern ML, people say “SGD” even when m = 128 or 1024. The defining feature is still the same: ĝₜ is computed from a random subset, not the full dataset.

Measuring noise with the gradient norm #

A quick diagnostic idea:

But there is no universal best setting: compute budget, model size, and data complexity all interact.

Core Mechanic 2: Learning Rate, Convergence Behavior, and Practical Stability #

Why learning rate matters more in SGD than in full-batch GD #

In full-batch GD, the gradient is deterministic for a given θ. In SGD, the update direction changes randomly each step.

If η is too large, noise can cause the iterates to bounce around or diverge.

If η is too small, progress becomes extremely slow.

So in SGD, η is not just “step size”; it controls the trade-off between:

A useful decomposition: drift + noise #

Write:

ĝₜ = ∇L(θₜ) + ξₜ, with E[ξₜ|θₜ] = 0.

Then the update is:

θₜ₊₁ = θₜ − ηₜ ∇L(θₜ) − ηₜ ξₜ

You can read this as a deterministic drift term, −ηₜ ∇L(θₜ), plus a zero-mean noise term, −ηₜ ξₜ.

If ηₜ stays constant, the random component never truly disappears; you often converge to a neighborhood around a minimizer.

If ηₜ decays (ηₜ ↓ 0), noise influence shrinks and you can converge more tightly.

Classical convergence intuition (high level) #

Under convexity and smoothness assumptions (and unbiased gradients with bounded variance), SGD can achieve convergence rates like:

You don’t need the full proofs to use SGD, but the intuition is essential:

Learning-rate schedules (what people do) #

Common schedules:

  1. Step decay: η drops by a factor (e.g., ×0.1) at fixed epochs.
  2. Cosine decay: smooth decay from η₀ to near 0.
  3. Warmup: start with a small η and increase over the first few epochs.

Warmup helps when early gradients are unstable (common in deep nets).
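The three schedules above can be sketched as small pure functions. The exact formulas vary between libraries; the versions below (including the default `drop_every=30` and the linear warmup shape) are one common form and should be read as assumptions rather than a reference implementation.

```python
import math

def step_decay(base_lr, epoch, drop_every=30, factor=0.1):
    """Multiply the learning rate by `factor` once every `drop_every` epochs."""
    return base_lr * (factor ** (epoch // drop_every))

def cosine_decay(base_lr, step, total_steps, min_lr=0.0):
    """Smoothly decay from base_lr at step 0 to min_lr at total_steps."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def linear_warmup(base_lr, step, warmup_steps):
    """Ramp linearly up to base_lr over the first warmup_steps updates."""
    if step >= warmup_steps:
        return base_lr
    return base_lr * (step + 1) / warmup_steps

for epoch in (0, 29, 30, 60):
    print("step decay, epoch", epoch, "->", step_decay(0.1, epoch))
for step in (0, 50, 100):
    print("cosine decay, step", step, "->", cosine_decay(0.1, step, total_steps=100))
for step in (0, 4, 9, 10):
    print("warmup, step", step, "->", linear_warmup(0.1, step, warmup_steps=10))
```

In practice warmup is often combined with a decay schedule: ramp up for the first few hundred steps, then hand over to cosine or step decay.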

Batch size and learning rate are linked #

Bigger batches reduce gradient variance, so you can often use larger η.

A rough heuristic sometimes used is “linear scaling”:

η ∝ m

when increasing batch size m, at least within some range.

But this is not a law; it depends on model, optimizer (momentum/Adam), and data.

Momentum (brief, because it often comes with SGD) #

Many practitioners mean “SGD with momentum.” Momentum reduces variance by averaging gradients over time.

A standard momentum form:

vₜ₊₁ = β vₜ + ĝₜ

θₜ₊₁ = θₜ − η vₜ₊₁

with β ∈ [0,1). Typical β is 0.9.

Conceptually:

Even if you don’t use momentum in a first implementation, you should recognize it as a key stabilization tool.
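A minimal sketch of the momentum update given in this section, for a scalar parameter. The function name `momentum_step` and the constant-gradient demonstration are illustrative assumptions.

```python
def momentum_step(theta, velocity, g_hat, lr, beta=0.9):
    """One SGD-with-momentum update: v <- beta * v + g_hat, then theta <- theta - lr * v."""
    velocity = beta * velocity + g_hat
    theta = theta - lr * velocity
    return theta, velocity

# With a constant gradient of 1.0 the velocity accumulates toward 1 / (1 - beta).
theta, velocity = 0.0, 0.0
for _ in range(200):
    theta, velocity = momentum_step(theta, velocity, g_hat=1.0, lr=0.01, beta=0.9)
print(velocity)  # close to 10 for beta = 0.9
```

Because the velocity can grow to roughly 1/(1 − β) times the gradient, the effective step size is larger than η in this formulation, which is worth keeping in mind when adding momentum to an already tuned learning rate.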

Convergence “look and feel” in practice #

When you plot training loss vs steps:

Often you care about loss vs wall-clock time, not loss vs number of steps.

SGD wins because it can take many more steps per second.

When SGD is especially appropriate #

SGD shines when:

When full-batch methods may be fine:

A practical checklist for stability #

If training diverges or is wildly unstable:

SGD is simple, but its behavior is tightly coupled to these choices.

Application/Connection: SGD in Linear Regression and in Deep Learning Workflows #

Why linear regression is the perfect “glass box” example #

Linear regression lets you see SGD without distractions:

Yet the training loop is structurally the same as in deep learning.

Empirical risk minimization template #

Most ML training loops fit this template:

L(θ) = (1/N) ∑ᵢ ℓᵢ(θ)

SGD uses the per-example (or per-batch) gradient to approximate ∇L.

This same structure appears in:

So once SGD is clear, “training deep networks” becomes less mysterious: the core loop is the same, only the gradient computation is more complex.

Practical hyperparameters: what you actually tune #

The big three:

  1. Learning rate η
  2. Batch size m
  3. Number of epochs / steps

Secondary but common:

A compact comparison of trade-offs:

| Knob | Increase it → | Benefits | Costs/Risks |
|---|---|---|---|
| Learning rate η | bigger steps | faster initial progress | divergence, overshooting |
| Batch size m | lower noise | smoother training, better hardware utilization | less regularization effect, more memory |
| Epochs | more passes | better fit | overfitting, time |

Interpreting “generalization” effects #

Empirically, small-batch SGD often finds solutions that generalize better than large-batch training for the same compute.

One intuitive story: noise biases the optimizer toward flatter minima (regions where small parameter changes don’t increase loss much). Flatter minima often correlate with better generalization.

This is not a universal theorem, but it’s a useful working intuition.

Connections to the nodes you unlock #

SGD is the workhorse behind many advanced methods:

A final conceptual anchor #

If you remember only one thing:

SGD is not “a different update rule.” It is gradient descent using a random estimate of the gradient.

Everything else—mini-batches, schedules, momentum, stability tricks—is about managing the consequences of that randomness.

Worked Examples (3) #

Single-example SGD on 1D linear regression (fully computed by hand) #

We fit a 1D linear model ŷ = w x (no bias) using squared error. Dataset has two points: (x₁=1, y₁=2), (x₂=2, y₂=0). Per-example loss: ℓᵢ(w) = (1/2)(w xᵢ − yᵢ)². Empirical loss: L(w) = (1/2)·(1/2)∑ᵢ (w xᵢ − yᵢ)² = (1/4)∑ᵢ (w xᵢ − yᵢ)². We will do SGD with learning rate η = 0.1, starting at w₀ = 0. We sample one example per step.

  1. Compute the per-example gradient.

    ℓᵢ(w) = (1/2)(w xᵢ − yᵢ)²

    dℓᵢ/dw = (1/2)·2(w xᵢ − yᵢ)·xᵢ

    = (w xᵢ − yᵢ) xᵢ

  2. Step t=0: sample example 1 (i₀=1).

    At w₀=0:

    ĝ₀ = dℓ₁/dw = (w₀ x₁ − y₁) x₁

    = (0·1 − 2)·1

    = −2

    Update:

    w₁ = w₀ − η ĝ₀ = 0 − 0.1(−2) = 0.2

  3. Step t=1: sample example 2 (i₁=2).

    At w₁=0.2:

    ĝ₁ = dℓ₂/dw = (w₁ x₂ − y₂) x₂

    = (0.2·2 − 0)·2

    = (0.4)·2

    = 0.8

    Update:

    w₂ = w₁ − η ĝ₁ = 0.2 − 0.1(0.8) = 0.12

  4. Step t=2: sample example 1 again (i₂=1).

    At w₂=0.12:

    ĝ₂ = (w₂·1 − 2)·1 = (0.12 − 2) = −1.88

    Update:

    w₃ = 0.12 − 0.1(−1.88) = 0.308

  5. Compare with the full gradient direction at w₂=0.12.

    Full gradient of empirical loss L(w) = (1/4)∑(w xᵢ − yᵢ)²:

    First compute d/dw of (w xᵢ − yᵢ)² = 2(w xᵢ − yᵢ)xᵢ.

    So:

    dL/dw = (1/4)∑ 2(w xᵢ − yᵢ)xᵢ

    = (1/2)∑ (w xᵢ − yᵢ)xᵢ

    At w=0.12:

    For i=1: (0.12·1 − 2)·1 = −1.88

    For i=2: (0.12·2 − 0)·2 = (0.24)·2 = 0.48

    Sum = −1.88 + 0.48 = −1.40

    So dL/dw = (1/2)(−1.40) = −0.70

    A full-batch GD step would move w upward by 0.1·0.70 = 0.07 (to 0.19).

    SGD’s step depended on which example we sampled; it moved more (to 0.308) when it saw the high-error example 1.

Insight: SGD makes progress using gradients from individual examples. Each step is cheap but noisy: the update direction can differ substantially from the full gradient. Over time, the randomness averages out, but the trajectory is jagged.
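The hand computation above can be replayed in a few lines. The sampled index sequence (example 1, example 2, example 1) is fixed to match the worked steps rather than drawn at random; variable names are illustrative.

```python
def grad(w, x, y):
    """Per-example gradient of 0.5 * (w * x - y)^2 with respect to w."""
    return (w * x - y) * x

data = [(1.0, 2.0), (2.0, 0.0)]  # the two points from the example
eta = 0.1

w = 0.0
trajectory = [w]
for idx in [0, 1, 0]:  # sampled examples at t = 0, 1, 2 (0-based indices)
    x, y = data[idx]
    w = w - eta * grad(w, x, y)
    trajectory.append(w)

# Full-batch comparison at w2, as in step 5.
w2 = trajectory[2]
full_grad_at_w2 = sum(grad(w2, x, y) for x, y in data) / len(data)
gd_step_from_w2 = w2 - eta * full_grad_at_w2

print(trajectory)        # approximately [0.0, 0.2, 0.12, 0.308]
print(full_grad_at_w2)   # approximately -0.70
print(gd_step_from_w2)   # approximately 0.19
```

Changing the index list to `[1, 1, 1]` or `[0, 0, 0]` shows how strongly the early trajectory depends on which examples happen to be sampled.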

Mini-batch SGD for logistic regression (vector form, showing the stochastic gradient estimator) #

Binary classification with logistic regression. For each example (xᵢ, yᵢ) with yᵢ ∈ {0,1}, model pᵢ = σ(wᵀxᵢ) where σ(z)=1/(1+e^(−z)). Per-example loss (cross-entropy): ℓᵢ(w) = −yᵢ log pᵢ − (1−yᵢ) log(1−pᵢ). We use a mini-batch Bₜ of size m to compute ĝₜ and update w.

  1. Compute gradient of per-example loss.

    Let zᵢ = wᵀxᵢ, pᵢ = σ(zᵢ).

    A standard result:

    ∇ℓᵢ(w) = (pᵢ − yᵢ) xᵢ

    Derivation sketch (showing the key chain rule steps):

    ∂ℓᵢ/∂zᵢ = pᵢ − yᵢ

    and ∂zᵢ/∂w = xᵢ

    So ∇ℓᵢ(w) = (∂ℓᵢ/∂zᵢ)(∂zᵢ/∂w) = (pᵢ − yᵢ) xᵢ

  2. Define the empirical objective:

    L(w) = (1/N) ∑ᵢ ℓᵢ(w)

    Full gradient:

    ∇L(w) = (1/N) ∑ᵢ (pᵢ − yᵢ) xᵢ

  3. Mini-batch stochastic gradient estimator at step t:

    Sample Bₜ uniformly, |Bₜ| = m.

    ĝₜ = (1/m) ∑_{i∈Bₜ} ∇ℓᵢ(wₜ)

    = (1/m) ∑_{i∈Bₜ} (pᵢ − yᵢ) xᵢ

  4. SGD update:

    wₜ₊₁ = wₜ − ηₜ ĝₜ

    = wₜ − ηₜ (1/m) ∑_{i∈Bₜ} (pᵢ − yᵢ) xᵢ

  5. Unbiasedness check (conceptual):

    If Bₜ is sampled uniformly, then

    E[ĝₜ | wₜ] = ∇L(wₜ)

    So the expected update is a full-gradient step, which is a descent direction for the full objective.

Insight: In deep learning, backprop computes ∇ℓᵢ(θ) for each sample in a mini-batch. Mini-batch SGD is simply the average of those per-sample gradients, used as ĝₜ in the same update rule.
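The mini-batch estimator and update from steps 3 and 4 can be sketched in plain Python for a small weight vector. The helper names and the two-example batch are made up for illustration; a real implementation would normally use a vectorized library and a numerically safer sigmoid.

```python
import math

def sigmoid(z):
    """Logistic function; fine for moderate z, may overflow for very negative z."""
    return 1.0 / (1.0 + math.exp(-z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def minibatch_gradient(w, batch):
    """Average of the per-example gradients (p_i - y_i) * x_i over the batch."""
    m = len(batch)
    g = [0.0] * len(w)
    for x, y in batch:
        p = sigmoid(dot(w, x))
        for j, xj in enumerate(x):
            g[j] += (p - y) * xj / m
    return g

def sgd_update(w, batch, lr):
    """One step: w <- w - lr * g_hat."""
    g_hat = minibatch_gradient(w, batch)
    return [wj - lr * gj for wj, gj in zip(w, g_hat)]

# Made-up mini-batch of (features, label) pairs.
batch = [([1.0, 2.0], 1), ([3.0, -1.0], 0)]
w0 = [0.0, 0.0]
g0 = minibatch_gradient(w0, batch)
w1 = sgd_update(w0, batch, lr=0.1)
print(g0, w1)
```

At w = 0 every pᵢ equals 0.5, so the batch gradient is the average of −0.5·[1, 2] and 0.5·[3, −1], which gives [0.5, −0.75].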

Seeing batch size reduce noise (a tiny numeric thought experiment) #

Assume at some θ, per-example gradients along one coordinate behave like a random variable G with mean μ and variance σ². We compare a single-example gradient estimate ĝ₁ = G to a mini-batch estimate ĝₘ = (1/m)∑_{k=1}^m Gₖ (independent samples).

  1. Compute expectation:

    E[ĝₘ] = E[(1/m)∑ Gₖ] = (1/m)∑ E[Gₖ] = (1/m)·m·μ = μ

  2. Compute variance:

    Var(ĝₘ) = Var((1/m)∑ Gₖ)

    = (1/m²) Var(∑ Gₖ)

    Assuming independence:

    Var(∑ Gₖ) = ∑ Var(Gₖ) = m σ²

    So:

    Var(ĝₘ) = (1/m²)(m σ²) = σ²/m

  3. Interpretation:

    If you quadruple batch size (m → 4m), the variance drops by a factor of four (and the standard deviation by a factor of two):

    σ²/(4m) = (1/4)(σ²/m).

    So the stochastic gradient concentrates around the true gradient as batch size increases.

Insight: Mini-batches are a statistical averaging device. They trade extra compute per step for a cleaner gradient direction, often allowing a larger learning rate or more stable optimization.

Key Takeaways #

Common Mistakes #

Practice #

easy

You have L(θ) = (1/N)∑ᵢ ℓᵢ(θ). You sample i uniformly and set ĝ = ∇ℓᵢ(θ). Show that E[ĝ] = ∇L(θ).

Hint: Write the expectation as a sum over i with probability 1/N.

Solution:

Because P(i=k)=1/N,

E[ĝ] = ∑_{k=1}^N (1/N) ∇ℓ_k(θ)

= (1/N)∑_{k=1}^N ∇ℓ_k(θ)

= ∇[(1/N)∑_{k=1}^N ℓ_k(θ)]

= ∇L(θ).

medium

Consider 1D linear regression ŷ = w x with per-example loss ℓᵢ(w) = (1/2)(w xᵢ − yᵢ)². For a mini-batch B of size m, write the mini-batch gradient estimator ĝ(w) and the SGD update. Then compute ĝ(w) explicitly for batch B = { (x=1,y=2), (x=3,y=1) } at w=0.

Hint: First compute dℓᵢ/dw = (w xᵢ − yᵢ)xᵢ, then average across the batch.

Solution:

Per-example gradient: dℓᵢ/dw = (w xᵢ − yᵢ)xᵢ.

Mini-batch estimator:

ĝ(w) = (1/m)∑_{i∈B} (w xᵢ − yᵢ)xᵢ.

Update: w⁺ = w − η ĝ(w).

Here m=2 and w=0.

For (1,2): (0·1−2)·1 = −2.

For (3,1): (0·3−1)·3 = −3.

Average: ĝ(0) = (1/2)(−2 + −3) = −2.5.

So w⁺ = 0 − η(−2.5) = 2.5η.

hard

Assume a scalar per-example gradient random variable G has Var(G)=σ². For a mini-batch average ĝₘ = (1/m)∑_{k=1}^m Gₖ with independent samples, derive Var(ĝₘ). If σ²=16, compare the standard deviation of ĝ₁, ĝ₄, and ĝ₁₆.

Hint: Use Var(aX)=a²Var(X) and Var(∑ independent)=∑Var.

Solution:

Var(ĝₘ) = Var((1/m)∑ Gₖ) = (1/m²) Var(∑ Gₖ).

Independence gives Var(∑ Gₖ)=∑Var(Gₖ)=mσ².

So Var(ĝₘ) = (1/m²)(mσ²)=σ²/m.

Standard deviation is √(σ²/m)=σ/√m.

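With σ²=16, σ=4, so the standard deviations are 4/√1 = 4 for ĝ₁, 4/√4 = 2 for ĝ₄, and 4/√16 = 1 for ĝ₁₆. Each fourfold increase in batch size halves the noise level.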

Connections #
