Gradient descent is an iterative method for optimizing model weights, commonly used when the objective has no direct closed-form solution. It is also used, for example in online settings, when a closed-form solution exists but is too expensive to compute.

The observation at the core of gradient descent is that the gradient of our objective points in the direction of steepest ascent; that is, the gradient defines the "tangent" to our objective and is oriented so that moving "up" the tangent increases the objective the most.

In machine learning, our goal is usually to minimize the objective $J(\theta)$, so the update step subtracts the gradient:

$$\theta \leftarrow \theta - \alpha \nabla_\theta J(\theta)$$

where the learning rate $\alpha > 0$ is a hyperparameter chosen manually. Intuitively, this moves our parameters in the direction of "steepest descent."

Figure: a graphical representation with two weights $\theta_1$ and $\theta_2$; the innermost blue ring represents the global minimum $\theta^*$.
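To make the update rule concrete, here is a minimal sketch of plain gradient descent in NumPy. The objective $J(\theta) = (\theta_1 - 3)^2 + (\theta_2 + 1)^2$, its minimum at $(3, -1)$, and the helper names `gradient_descent` and `grad_J` are hypothetical choices for illustration only.

```python
import numpy as np

def gradient_descent(grad, theta0, alpha=0.1, num_steps=100):
    """Plain gradient descent: repeatedly apply theta <- theta - alpha * grad(theta)."""
    theta = np.asarray(theta0, dtype=float)
    for _ in range(num_steps):
        theta = theta - alpha * grad(theta)
    return theta

# Hypothetical objective J(theta) = (theta_1 - 3)^2 + (theta_2 + 1)^2,
# whose global minimum is theta* = (3, -1).
def grad_J(theta):
    return 2.0 * (theta - np.array([3.0, -1.0]))

theta_star = gradient_descent(grad_J, theta0=[0.0, 0.0], alpha=0.1, num_steps=200)
print(theta_star)  # very close to [3, -1]
```

With $\alpha = 0.1$ each step shrinks the distance to the minimum by a factor of $1 - 2\alpha = 0.8$; a learning rate above $1$ on this objective would overshoot and diverge, which is why $\alpha$ has to be chosen with care.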

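An alternative interpretation of gradient descent comes from constrained optimization: at each step we minimize a linear approximation of the objective within a neighborhood whose size is set by the step size, so the constraint controls how far we move along the gradient. Formally, with $\epsilon$ bounding the size of the step, our problem is

$$\Delta\theta^* = \arg\min_{\|\Delta\theta\|_2 \le \epsilon} \nabla_\theta J(\theta)^\top \Delta\theta$$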

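where the objective is the gradient applied to our change in parameters. That is, $\Delta\theta$ ranges over the region we are allowed to move in (bounded by the inequality constraint), $\nabla_\theta J(\theta)^\top \Delta\theta$ is the first-order estimate of how much $J$ changes for a given $\Delta\theta$, and we seek the $\Delta\theta$ that makes this change as negative as possible (the direction of steepest descent). The solution is $\Delta\theta^* = -\epsilon\, \nabla_\theta J(\theta) / \|\nabla_\theta J(\theta)\|_2$, and matching it to the gradient step $\Delta\theta = -\alpha \nabla_\theta J(\theta)$ shows that the learning rate $\alpha$ is related to our constraint by

$$\epsilon = \alpha \, \|\nabla_\theta J(\theta)\|_2$$

so the size of our constraint is indeed determined by $\alpha$.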
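A short derivation of this relationship, assuming the constraint uses the Euclidean norm and writing $g = \nabla_\theta J(\theta)$ with $g \neq 0$:

```latex
\begin{aligned}
g^\top \Delta\theta &\ge -\|g\|_2 \, \|\Delta\theta\|_2 && \text{(Cauchy--Schwarz)} \\
                    &\ge -\epsilon \, \|g\|_2          && \text{(since } \|\Delta\theta\|_2 \le \epsilon\text{)}
\end{aligned}
```

Both inequalities become equalities at $\Delta\theta^* = -\epsilon\, g / \|g\|_2$, so that step is the minimizer. Setting it equal to the gradient step $-\alpha g$ gives $\epsilon / \|g\|_2 = \alpha$, i.e. $\epsilon = \alpha \|g\|_2$: for a fixed $\alpha$, the allowed neighborhood grows with the gradient's magnitude.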

Update Timing

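A gradient can be calculated for each data point, since the objective is typically an average of per-example losses. Because computing it over the whole dataset is expensive, there are several choices for when to apply these gradients to update the weights $\theta$: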

  1. Batch gradient descent updates the weights once after going through the entire dataset
  2. Stochastic gradient descent updates the weights after computing the gradient for a single data point $x^{(i)}$; the updates are noisy and oscillate, but each one is cheap, so convergence often takes less time
  3. Mini-batch gradient descent balances the two, updating the weights after each mini-batch of $B$ data points
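The three timings differ only in how many data points feed each update, so one loop can sketch all of them. This is a minimal sketch assuming a least-squares linear regression objective; the function `minibatch_gd`, the synthetic data, and the hyperparameters are hypothetical choices for illustration. Setting `batch_size` to the dataset size gives batch gradient descent, `1` gives stochastic gradient descent, and anything in between is mini-batch.

```python
import numpy as np

def minibatch_gd(X, y, batch_size, alpha=0.05, num_epochs=200, seed=0):
    """Least-squares linear regression trained with (mini-)batch gradient descent.

    batch_size == len(X) -> batch GD, batch_size == 1 -> stochastic GD.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(num_epochs):
        order = rng.permutation(n)  # reshuffle the data every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            Xb, yb = X[idx], y[idx]
            # Gradient of the batch mean squared error (1/|b|) * sum_i (x_i^T w - y_i)^2
            grad = 2.0 * Xb.T @ (Xb @ w - yb) / len(idx)
            w -= alpha * grad  # one weight update per batch
    return w

# Hypothetical noiseless data generated from true weights (2, -3).
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 2))
y = X @ np.array([2.0, -3.0])

# Batch, stochastic, and mini-batch variants of the same loop.
results = {b: minibatch_gd(X, y, batch_size=b) for b in (len(X), 1, 16)}
for b, w in results.items():
    print(f"batch_size={b}: w={w}")
```

Per epoch, batch gradient descent makes one update, stochastic gradient descent makes 200, and the mini-batch run makes 13 (the last batch holds the 8 leftover points). Because this toy data has no noise, all three reach the true weights; on noisy data, stochastic updates with a fixed $\alpha$ would instead keep fluctuating around the minimum.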

Momentum

Convergence with standard gradient descent can be slow when the curvature is poorly scaled, as in a long, narrow valley: the gradient points mostly across the valley, so the updates bounce between its walls while making little progress along its floor. Momentum adds a term that remembers previous update steps; incorporating it into the algorithm dampens these oscillations and smooths out the updates.

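Mathematically, momentum is computed as an exponential moving average of past gradients with decay rate $\beta \in [0, 1)$, starting from $v_0 = 0$:

$$v_t = \beta v_{t-1} + (1 - \beta) \nabla_\theta J(\theta_{t-1})$$

Then, our update rule becomes

$$\theta_t = \theta_{t-1} - \alpha v_t$$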
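A minimal sketch of this exponential-moving-average form of momentum, compared with plain gradient descent on a poorly scaled "valley." The objective $J(\theta) = \tfrac{1}{2}(\theta_1^2 + 100\,\theta_2^2)$, the starting point, and the values of $\alpha$ and $\beta$ are hypothetical choices for illustration. Setting $\beta = 0$ recovers plain gradient descent. Because this form scales each gradient by $(1 - \beta)$, the momentum run uses $\alpha = 0.2$ with $\beta = 0.9$ (effective step $0.02$), while plain gradient descent uses $\alpha = 0.019$, just under its stability limit of $2/100$ on the steep axis.

```python
import numpy as np

def gd_momentum(grad, theta0, alpha, beta, num_steps=200):
    """Momentum as an EMA of gradients:
    v_t = beta * v_{t-1} + (1 - beta) * grad(theta_{t-1}),  theta_t = theta_{t-1} - alpha * v_t.
    beta = 0 reduces to plain gradient descent."""
    theta = np.asarray(theta0, dtype=float)
    v = np.zeros_like(theta)  # v_0 = 0
    for _ in range(num_steps):
        v = beta * v + (1.0 - beta) * grad(theta)  # exponential moving average
        theta = theta - alpha * v
    return theta

# Hypothetical valley J(theta) = 0.5 * (theta_1^2 + 100 * theta_2^2), minimum at (0, 0):
# shallow along theta_1, steep along theta_2.
curvatures = np.array([1.0, 100.0])

def grad_valley(theta):
    return curvatures * theta

start = [10.0, 1.0]
plain = gd_momentum(grad_valley, start, alpha=0.019, beta=0.0)
with_momentum = gd_momentum(grad_valley, start, alpha=0.2, beta=0.9)
print("plain GD distance to minimum:", np.linalg.norm(plain))
print("momentum distance to minimum:", np.linalg.norm(with_momentum))
```

After 200 steps the plain run is still crawling along the shallow $\theta_1$ axis, since its step size is capped by the steep $\theta_2$ axis, while the momentum run ends much closer to the minimum.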