Gradient Descent is perhaps the most intuitive of all optimization algorithms. Imagine you're standing on the side of a mountain and want to reach the bottom. You'd probably do something like this,
- Look around you and see which direction slopes most steeply downhill
- Take a step in that direction, then repeat
Well that's Gradient Descent!
How does it work?
So how do we frame Gradient Descent mathematically? As usual, we define our problem in terms of minimizing a function,

$$ \min_{x} f(x) $$

We assume that $f$ is differentiable, so that we can compute its gradient $\nabla f(x)$ at any point $x$. Given this, Gradient Descent is simply the following,
Input: initial iterate $x^{(0)}$

- For $t = 0, 1, 2, \ldots$
  - Compute the gradient of $f$ at $x^{(t)}$, $\nabla f(x^{(t)})$
  - if converged, return $x^{(t)}$
  - Take a step: $x^{(t+1)} = x^{(t)} - \alpha^{(t)} \nabla f(x^{(t)})$

The initial iterate $x^{(0)}$ can be chosen arbitrarily, and the step sizes $\alpha^{(t)}$ control how far each iteration moves along the negative gradient; the simplest choice is a constant, $\alpha^{(t)} = \alpha$.
A Small Example
Let's look at Gradient Descent in action. We'll use the objective function $f(x) = x^4$, whose gradient is $\nabla f(x) = 4 x^3$, with a constant step size of $0.05$ and an initial iterate of $x^{(0)} = 1$. After ten iterations, the iterates steadily approach the minimizer $x^{*} = 0$.
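The iterates for this example can be reproduced in a few lines; here's a minimal sketch in plain Python (no plotting), mirroring the settings above:

```python
# Gradient Descent on f(x) = x^4 with a fixed step size of 0.05,
# matching the example above.
f = lambda x: x ** 4
grad = lambda x: 4 * x ** 3

x = 1.0           # initial iterate x^(0)
step_size = 0.05
xs = [x]
for t in range(10):
    x = x - step_size * grad(x)  # x^(t+1) = x^(t) - alpha * grad f(x^(t))
    xs.append(x)

# xs now holds the eleven iterates x^(0), ..., x^(10); the objective
# value f(x^(t)) decreases monotonically toward the minimum at x = 0.
```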
Why does it work?
Gradient Descent works, but it isn't guaranteed to find the optimal solution
to our problem (that is, $x^{*} = \arg\min_{x} f(x)$) without a few assumptions. In particular,

1. $f$ is convex and finite for all $x$
2. a finite solution $x^{*}$ exists
3. $\nabla f(x)$ is Lipschitz continuous with constant $L$. If $f$ is twice differentiable, this means that the largest eigenvalue of the Hessian is bounded by $L$ ($\nabla^2 f(x) \preceq L I$). But more directly, there must be an $L$ such that,

$$ \| \nabla f(x) - \nabla f(y) \|_2 \le L \| x - y \|_2 \qquad \forall x, y $$
**Assumptions** So what do these assumptions give us? Assumption 1 tells us
that $f$ is lower bounded by its linear approximation,

$$ f(y) \ge f(x) + \nabla f(x)^T (y - x) $$

Assumption 3 also tells us that $f$ is upper bounded by a quadratic,

$$ f(y) \le f(x) + \nabla f(x)^T (y - x) + \frac{L}{2} \| y - x \|_2^2 $$
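These two bounds sandwich $f$ between a line and a parabola, and they are easy to check numerically. A quick sketch using $f(x) = x^2$, a convex function with $\nabla f(x) = 2x$ and Lipschitz constant $L = 2$ (the function and constant here are illustrative choices, not from the text):

```python
import random

# f(x) = x^2 is convex; its gradient 2x is Lipschitz with L = 2.
f = lambda x: x ** 2
grad = lambda x: 2 * x
L = 2.0

random.seed(0)
for _ in range(1000):
    x, y = random.uniform(-5, 5), random.uniform(-5, 5)
    lower = f(x) + grad(x) * (y - x)         # Assumption 1: linear lower bound
    upper = lower + (L / 2) * (y - x) ** 2   # Assumption 3: quadratic upper bound
    assert lower - 1e-9 <= f(y) <= upper + 1e-9
```

For this particular $f$ the quadratic upper bound is tight (it equals $f(y)$ exactly), which is why $L = 2$ is the smallest valid Lipschitz constant here.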
**Proof Outline** Now let's dive into the proof. Our plan of attack is as
follows. First, we upper bound the error $f(x^{(t+1)}) - f(x^{*})$ in terms of $\| x^{(t)} - x^{*} \|_2^2$ and $\| x^{(t+1)} - x^{*} \|_2^2$. We then sum this bound across iterations to obtain a telescoping series, leaving an upper bound on the error at iteration $t$ that depends only on $L$, $\| x^{(0)} - x^{*} \|_2^2$, and $t$.
**Step 1: upper bounding $f(x^{(t+1)})$.** Apply Assumption 3's quadratic upper bound with $y = x^{(t+1)} = x^{(t)} - \alpha \nabla f(x^{(t)})$,

$$ f(x^{(t+1)}) \le f(x^{(t)}) - \alpha \| \nabla f(x^{(t)}) \|_2^2 + \frac{L \alpha^2}{2} \| \nabla f(x^{(t)}) \|_2^2 $$

Choosing the constant step size $\alpha = \frac{1}{L}$ minimizes the right-hand side, giving

$$ f(x^{(t+1)}) \le f(x^{(t)}) - \frac{1}{2L} \| \nabla f(x^{(t)}) \|_2^2 $$
**Step 2: Upper bound $f(x^{(t+1)}) - f(x^{*})$.** Assumption 1's linear lower bound, evaluated at $y = x^{*}$, gives $f(x^{*}) \ge f(x^{(t)}) + \nabla f(x^{(t)})^T (x^{*} - x^{(t)})$, or equivalently,

$$ f(x^{(t)}) \le f(x^{*}) + \nabla f(x^{(t)})^T (x^{(t)} - x^{*}) $$

Substituting this into Step 1 and completing the square (using $x^{(t+1)} = x^{(t)} - \frac{1}{L} \nabla f(x^{(t)})$),

$$ \begin{aligned}
f(x^{(t+1)}) - f(x^{*})
& \le \nabla f(x^{(t)})^T (x^{(t)} - x^{*}) - \frac{1}{2L} \| \nabla f(x^{(t)}) \|_2^2 \\
& = \frac{L}{2} \left( \| x^{(t)} - x^{*} \|_2^2 - \| x^{(t+1)} - x^{*} \|_2^2 \right)
\end{aligned} $$
**Step 3: Upper bound the error at iteration $t$.** Summing the bound from Step 2 over iterations $0, \ldots, t-1$, the right-hand side telescopes,

$$ \sum_{i=0}^{t-1} \left( f(x^{(i+1)}) - f(x^{*}) \right) \le \frac{L}{2} \left( \| x^{(0)} - x^{*} \|_2^2 - \| x^{(t)} - x^{*} \|_2^2 \right) \le \frac{L}{2} \| x^{(0)} - x^{*} \|_2^2 $$

Step 1 shows that $f(x^{(i)})$ never increases, so the final error is at most the average of the summands,

$$ f(x^{(t)}) - f(x^{*}) \le \frac{1}{t} \sum_{i=0}^{t-1} \left( f(x^{(i+1)}) - f(x^{*}) \right) \le \frac{L \| x^{(0)} - x^{*} \|_2^2}{2 t} $$
Thus, we can conclude that if we want to reach an error tolerance $\epsilon$, that is, $f(x^{(t)}) - f(x^{*}) \le \epsilon$, then $t = O(1/\epsilon)$ iterations suffice.
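To sanity-check the rate, the sketch below runs Gradient Descent with $\alpha = \frac{1}{L}$ on a small quadratic and verifies the $\frac{L \| x^{(0)} - x^{*} \|_2^2}{2t}$ bound at every iteration. The objective $f(x_1, x_2) = x_1^2 + 10 x_2^2$ and its constants are my own illustrative choices, not from the text above:

```python
# f(x1, x2) = x1^2 + 10*x2^2 has Hessian diag(2, 20), so its gradient is
# Lipschitz with L = 20 (the largest eigenvalue). Minimizer: x* = (0, 0).
f = lambda x1, x2: x1 ** 2 + 10 * x2 ** 2
grad = lambda x1, x2: (2 * x1, 20 * x2)

L = 20.0
alpha = 1.0 / L            # the step size used in the proof
x1, x2 = 1.0, 1.0          # initial iterate x^(0)
r0_sq = x1 ** 2 + x2 ** 2  # ||x^(0) - x*||^2

for t in range(1, 101):
    g1, g2 = grad(x1, x2)
    x1, x2 = x1 - alpha * g1, x2 - alpha * g2
    bound = L * r0_sq / (2 * t)   # the error bound derived in Step 3
    assert f(x1, x2) <= bound     # f* = 0, so f(x^(t)) is the error
```

In practice the error on well-conditioned problems shrinks much faster than $O(1/t)$; the bound is a worst-case guarantee under the three assumptions only.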
When should I use it?
Because it's so easy to implement, Gradient Descent should be the first thing
to try if you need an optimization algorithm from scratch. So long as you
calculate the gradient correctly, it's practically impossible to make a mistake.
If you have access to an automatic differentiation library to do
the gradient computation for you, even better! In addition, Gradient Descent
requires a minimal memory footprint, making it ideal for problems where $x$ is very high dimensional.
As we'll see in later posts, Gradient Descent trades memory for speed. The
number of iterations required to reach a desired accuracy is actually quite
large if you want accuracy on the order of a small $\epsilon$: as shown above, $O(1/\epsilon)$ iterations may be needed.
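To make that cost concrete, the iteration bound from the proof, $t \ge \frac{L \| x^{(0)} - x^{*} \|_2^2}{2 \epsilon}$, can be evaluated for a few tolerances. The constants below ($L$ and the initial squared distance) are illustrative, not from the text:

```python
import math

# Iterations guaranteed sufficient by the O(1/eps) bound,
#   t >= L * ||x^(0) - x*||^2 / (2 * eps),
# for illustrative constants.
L = 10.0      # Lipschitz constant of the gradient (assumed)
r0_sq = 1.0   # squared distance from x^(0) to x* (assumed)

for eps in [1e-2, 1e-4, 1e-6]:
    t = math.ceil(L * r0_sq / (2 * eps))
    print('eps = %g -> %d iterations' % (eps, t))
```

Each factor of 100 in accuracy costs a factor of 100 in iterations, which is exactly the memory-for-speed tradeoff described above.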
Extensions
**Step Size** The proof above relies on a constant step size, but quicker
convergence can be obtained when using Line Search, wherein the step size is chosen adaptively at each iteration based on the current iterate and gradient.
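One common variant is Backtracking (Armijo) Line Search: start from a large trial step and shrink it until a sufficient-decrease condition holds. The sketch below is a generic 1-D version with illustrative parameter names and defaults, not necessarily the exact algorithm from the slides cited under References:

```python
def backtracking_line_search(f, gradient, x, t0=1.0, beta=0.5, c=0.5):
    """Shrink a trial step size until a sufficient-decrease condition holds."""
    g = gradient(x)
    t = t0
    # Armijo condition: f(x - t*g) <= f(x) - c * t * ||g||^2.
    # In this 1-D setting, g ** 2 plays the role of ||g||^2.
    while f(x - t * g) > f(x) - c * t * g ** 2:
        t *= beta  # shrink the trial step and try again
    return t

# Usage: pick a step size for f(x) = x^4 at x = 1.
f = lambda x: x ** 4
gradient = lambda x: 4 * x ** 3
step = backtracking_line_search(f, gradient, 1.0)
```

The loop always terminates for differentiable $f$, because the condition is satisfied for all sufficiently small $t$.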
**Checking Convergence** We have shown that the algorithm's error at iteration $t$ is $O(1/t)$, but this bound involves $x^{*}$ and $f(x^{*})$, which are unknown in practice. Instead, convergence is typically declared when a measurable quantity, such as the norm of the gradient or the change between successive iterates, falls below a threshold.
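A gradient-norm stopping rule can be sketched as follows; the helper name, tolerance, and iteration cap are illustrative choices:

```python
def gradient_descent_with_stopping(gradient, x0, alpha, tol=1e-4,
                                   max_iterations=10000):
    """Run gradient descent until the gradient magnitude drops below tol."""
    x = x0
    for t in range(max_iterations):
        g = gradient(x)
        if abs(g) < tol:   # 1-D problem; for vectors, use a norm instead
            return x, t    # converged: the gradient is nearly zero
        x = x - alpha * g
    return x, max_iterations  # iteration budget exhausted

# Usage on the post's example, f(x) = x^4.
gradient = lambda x: 4 * x ** 3
x, n_steps = gradient_descent_with_stopping(gradient, 1.0, 0.05)
```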
References
**Proof of Convergence** The proof of convergence for Gradient Descent is adapted from slide 1-18 of UCLA's EE236C lecture on Gradient Methods.

**Line Search** The algorithm for Backtracking Line Search, a smart method for choosing step sizes, can be found on slide 10-6 of UCLA's EE236b lecture on unconstrained optimization.
Reference Implementation
Here's a quick implementation of gradient descent,
```python
def gradient_descent(gradient, x0, alpha, n_iterations=100):
    """Gradient Descent

    Parameters
    ----------
    gradient : function
        Computes the gradient of the objective function at x
    x0 : array
        initial value for x
    alpha : function
        function of the iteration number t, computing step sizes
    n_iterations : int, optional
        number of iterations to perform

    Returns
    -------
    xs : list
        intermediate values for x
    """
    xs = [x0]
    for t in range(n_iterations):
        x = xs[-1]
        g = gradient(x)
        x_plus = x - alpha(t) * g
        xs.append(x_plus)
    return xs


# This generates the plots that appear above
if __name__ == '__main__':
    import os

    import numpy as np
    import pylab as pl
    import yannopt.plotting as plotting

    ### GRADIENT DESCENT ###

    # problem definition
    function = lambda x: x ** 4      # the function to minimize
    gradient = lambda x: 4 * x ** 3  # its gradient
    step_size = 0.05
    x0 = 1.0
    n_iterations = 10

    # run gradient descent
    iterates = gradient_descent(gradient, x0, lambda t: step_size,
                                n_iterations=n_iterations)

    ### PLOTTING ###

    plotting.plot_iterates_vs_function(iterates, function,
                                       path='figures/iterates.png', y_star=0.0)
    plotting.plot_iteration_vs_function(iterates, function,
                                        path='figures/convergence.png', y_star=0.0)

    # make animation
    try:
        os.makedirs('figures/animation')
    except OSError:
        pass

    for t in range(n_iterations):
        x = iterates[t]
        x_plus = iterates[t + 1]

        f = function
        g = gradient
        f_hat = lambda y: f(x) + g(x) * (y - x)  # linear approximation at x

        x_min = (0 - f(x)) / g(x) + x
        x_max = (1.1 - f(x)) / g(x) + x

        pl.figure()
        ys = np.linspace(0, 1.1, 100)
        pl.plot(ys, function(ys), alpha=0.2)
        pl.xlim([0, 1.1])
        pl.ylim([0, 1.1])
        pl.xlabel('x')
        pl.ylabel('f(x)')
        pl.plot([x_min, x_max], [f_hat(x_min), f_hat(x_max)], '--', alpha=0.2)
        pl.scatter([x, x_plus], [f(x), f(x_plus)], c=[0.8, 0.2])
        pl.savefig('figures/animation/%02d.png' % t)
        pl.close()
```