Using Computation Graph to Understand and Implement Backpropagation

Recently, while hacking on Assignment 2 of the awesome open course cs231n, I finally stopped banging my head against the wall every time I came across backpropagation. Now I even rather enjoy hacking on the backpropagation process!
In this post, I will show how to use a computation graph to implement both the forward and backward passes of Batch Normalization, Convolution and Pooling.

Batch Normalization

Batch Normalization is a recently proposed method that eases the pain of training neural networks, especially the special care one needs when initializing weights. I won’t go into the details of the method itself here; they can be found in the paper [3]. As a quick summary, batch normalization does two things:

  • make the activations of a fully connected or convolutional layer unit Gaussian (zero mean, unit variance) by normalizing them
  • scale and shift the normalized activations so that the network can “choose” to undo the step above

The algorithm can be summarized like this:
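As a rough sketch in the notation of the code further down (a mini-batch x_1, ..., x_N, learnable parameters \gamma and \beta, and a small constant \epsilon for numerical stability), the training-time transform can be written as follows. The symbols \mu_B and \sigma_B^2 are names chosen here for the batch mean and batch variance.

```latex
\begin{aligned}
\mu_B &= \frac{1}{N}\sum_{i=1}^{N} x_i \\
\sigma_B^2 &= \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned}
```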



Computation Graph

Building a computation graph is easy. The key is to split the computation into basic operations such as addition, multiplication or sqrt, but keep in mind not to split it too finely: the larger the computation graph, the messier backpropagation becomes.
The computation graph of Batch Normalization is like this:
The black lines and equations are for the forward pass, while the red ones are for the backward pass. Since each gate performs only a simple operation, both the forward and backward passes are quite straightforward. However, the summation gate deserves special attention; it confused me for a while. We can look at this gate from an equation perspective and from a matrix (coding) perspective:


You see, in the forward pass you sum the rows element-wise into a single row, so each element “contributes” equally to the error. Thus, in the backward pass, you only need to copy the incoming gradient to every row, like the red parts in the graph.
Another perspective is to regard this summation as a simple addition, out = row_1+row_2+...+row_m, so each row_i gets 1 as its local gradient, and you only need to multiply this local gradient by the gradient received by the gate to get row_i's true gradient.
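The summation gate can be sketched in a few lines of NumPy. The shapes (4 rows, 3 columns) and the random values below are made up for illustration; only the pattern matters: sum over axis 0 going forward, broadcast the incoming gradient back to every row going backward.

```python
import numpy as np

# Toy shapes chosen for illustration: m = 4 rows, D = 3 columns.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))

# Forward: sum the rows element-wise into a single row of shape (3,).
out = np.sum(x, axis=0)

# Backward: pretend this is the gradient arriving at the gate's output.
dout = rng.standard_normal(3)

# Each row has local gradient 1, so every row receives a copy of dout.
dx = np.ones_like(x) * dout  # broadcasts (3,) up to (4, 3)
```

Every row of dx is identical to dout, which is exactly the "copy to every row" picture described above.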

Feedforward Pass Code

With the computation graph above, the code is quite easy to write:

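import numpy as np

# x has shape (N, D); eps, momentum, running_mean and running_var are assumed to be defined already
N, D = x.shape

# Per-feature statistics of the mini-batch
batch_mean = np.sum(x,axis=0,keepdims=True) / N
batch_var = np.sum((x - batch_mean) ** 2,axis=0,keepdims=True) / N

# Normalize, then scale and shift
x_minus_mu = x - batch_mean
ivar = 1.0 / np.sqrt(batch_var+eps)
x_hat = x_minus_mu * ivar
out = gamma * x_hat + beta

# Exponential moving averages, used at test time
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
running_var = momentum * running_var + (1 - momentum) * batch_var
cache = (gamma,x,batch_var,batch_mean,eps,x_minus_mu,ivar,x_hat)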

Above is the training-time forward pass. The test-time forward pass is easier, since you only need to normalize with the running_mean and running_var accumulated during training. To avoid getting lost in the details of Batch Normalization, I will leave that part out for now.
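For readers who want to see it anyway, a minimal sketch of the test-time pass might look like this. The function name batchnorm_forward_test and its signature are hypothetical, not taken from the assignment, and the batch statistics stand in for real running averages in the usage lines.

```python
import numpy as np

def batchnorm_forward_test(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Hypothetical helper: normalize with stored statistics instead of batch ones."""
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta

# Usage with made-up data; here the batch's own statistics play the role
# of the running averages, so the output is approximately standardized.
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))
gamma, beta = np.ones(3), np.zeros(3)
out = batchnorm_forward_test(x, gamma, beta, x.mean(axis=0), x.var(axis=0))
```

Nothing is cached here because no backward pass is needed at test time.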

Backward Pass Code

With the computation graph, the backward pass becomes a step-by-step implementation of the equations in the graph. One useful trick: when you don’t know the order of a multiplication, deduce it from the dimensions. For instance, when implementing d\hat{x} = d(\hat{x}\gamma) * \gamma, I know that the dimensions of d\hat{x}, d(\hat{x}\gamma) and \gamma are (N,D), (N,D) and (D,), so the code can be written as:

dx_hat = dout * gamma
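The same dimension reasoning can be checked mechanically. The sketch below uses made-up sizes N = 4 and D = 3 and random values; the rule it relies on is simply that a gradient always has the shape of the variable it belongs to.

```python
import numpy as np

N, D = 4, 3  # toy sizes, chosen for illustration
rng = np.random.default_rng(0)
dout = rng.standard_normal((N, D))
x_hat = rng.standard_normal((N, D))
gamma = rng.standard_normal(D)

# gamma has shape (D,), so it broadcasts across the N rows of dout.
dx_hat = dout * gamma
assert dx_hat.shape == x_hat.shape

# dgamma must have gamma's shape (D,), so the N axis has to be summed away.
dgamma = np.sum(x_hat * dout, axis=0)
assert dgamma.shape == gamma.shape
```

If a candidate expression produces the wrong shape, the order of the operands or the summation axis is wrong.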

The full code for backpropagation is like this:

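gamma,x,batch_var,batch_mean,eps,x_minus_mu,ivar,x_hat = cache
N, D = x.shape

# out = gamma * x_hat + beta
dgamma = np.sum(x_hat * dout,axis=0)
dbeta = np.sum(dout,axis=0)
dx_hat = dout * gamma

# x_hat = x_minus_mu * ivar, with ivar = (batch_var + eps) ** -0.5
dinv_sqrt_var = np.sum(dx_hat * x_minus_mu,axis=0)
dvar = -0.5*((batch_var+eps)**(-1.5)) * dinv_sqrt_var
# batch_var = sum(x_minus_mu ** 2) / N: the summation gate copies dvar to every row
dx_minus_mu_sqrt = (1.0 / N) * np.ones((N,D)) * dvar
# x_minus_mu feeds both the squaring gate and x_hat, so its two gradients add up
dx_minus_mu = 2 * x_minus_mu * dx_minus_mu_sqrt + dx_hat * ivar
# x_minus_mu = x - batch_mean, with batch_mean = sum(x) / N
dmu = -np.sum(dx_minus_mu,axis=0)
dx = dx_minus_mu + (1.0 / N) * np.ones((N,D))*dmu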

Each line is just an equation in the computation graph.
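A cheap way to gain confidence in such a derivation is a numerical gradient check. The sketch below repackages the forward and backward equations into two functions (bn_forward and bn_backward are names chosen here, with the running averages left out) and compares the analytic gradients with centered finite differences. The helper numeric_grad, the toy sizes and the step h are assumptions made for the example.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    N = x.shape[0]
    mu = np.sum(x, axis=0, keepdims=True) / N
    xmu = x - mu
    var = np.sum(xmu ** 2, axis=0, keepdims=True) / N
    ivar = 1.0 / np.sqrt(var + eps)
    x_hat = xmu * ivar
    return gamma * x_hat + beta, (gamma, var, eps, xmu, ivar, x_hat)

def bn_backward(dout, cache):
    gamma, var, eps, xmu, ivar, x_hat = cache
    N, D = dout.shape
    dgamma = np.sum(x_hat * dout, axis=0)
    dbeta = np.sum(dout, axis=0)
    dx_hat = dout * gamma
    divar = np.sum(dx_hat * xmu, axis=0)
    dvar = -0.5 * (var + eps) ** (-1.5) * divar
    dxmu = 2 * xmu * (dvar / N) + dx_hat * ivar
    dmu = -np.sum(dxmu, axis=0)
    dx = dxmu + dmu / N
    return dx, dgamma, dbeta

def numeric_grad(f, v, dout, h=1e-5):
    """Centered finite differences of sum(f(v) * dout) with respect to v."""
    grad = np.zeros_like(v)
    for idx in np.ndindex(*v.shape):
        old = v[idx]
        v[idx] = old + h
        pos = np.sum(f(v) * dout)
        v[idx] = old - h
        neg = np.sum(f(v) * dout)
        v[idx] = old  # restore the original value
        grad[idx] = (pos - neg) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5))
gamma = rng.standard_normal(5)
beta = rng.standard_normal(5)
dout = rng.standard_normal((4, 5))

_, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(dout, cache)
dx_num = numeric_grad(lambda v: bn_forward(v, gamma, beta)[0], x, dout)
dgamma_num = numeric_grad(lambda v: bn_forward(x, v, beta)[0], gamma, dout)
print("max dx error:", np.max(np.abs(dx - dx_num)))
print("max dgamma error:", np.max(np.abs(dgamma - dgamma_num)))
```

Both errors should be tiny compared with the gradient values themselves; a large error usually points to a gate whose gradient was dropped or summed over the wrong axis.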

Convolution layers

Convolution layers play an important role in computer vision. By sliding a filter across an image, a convolution layer captures certain features at each spatial position. Details and intuitions behind convolution layers can be found here. Below I will focus on the implementation of convolutional layers.

There are many more efficient ways to implement convolution, but for the sake of time I use an approach similar to the convolution layers in Caffe. Assume we have filters with dimension (F,C,HH,WW) and an image with dimension (C,H,W). Then the height and width of the output feature map are:

H^\prime = (H - HH) / stride + 1

W^\prime = (W - WW) / stride + 1
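As a quick worked example of these formulas, here is a small helper. The name conv_output_size is hypothetical; the pad argument anticipates the zero-padding used in the full forward function later in the post, and the floor division mirrors what that code does.

```python
def conv_output_size(size, filter_size, stride=1, pad=0):
    """Hypothetical helper: output size along one spatial dimension."""
    return (size + 2 * pad - filter_size) // stride + 1

# A 32-pixel side with a 5-pixel filter and stride 1: (32 - 5) / 1 + 1
h_out = conv_output_size(32, 5)

# A 3-pixel filter with one pixel of zero-padding keeps the size: (32 + 2 - 3) / 1 + 1
h_out_padded = conv_output_size(32, 3, stride=1, pad=1)

print(h_out, h_out_padded)  # prints "28 32"
```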
The main steps of implementing convolution are as follows:

  1. Lay out each (C,HH,WW) patch of the image as a row and stack all rows into a single (H^\prime*W^\prime, HH*WW*C) matrix.
  2. Lay out each (C,HH,WW) filter as a column and concatenate all columns into a single (HH*WW*C, F) matrix.
  3. Multiply the two matrices. You get a (H^\prime*W^\prime, F) matrix, each column of which is a feature map.
  4. Add the bias to the feature map matrix.
  5. Reshape the feature map matrix into a (F,H^\prime,W^\prime) array, which gives our final F convolutional feature maps!

The steps above are much clearer when shown in a computation graph. However, the computation graph of convolution is slightly different from that of fully connected layers or batch normalization, since it often deals with 4D arrays and includes a lot of reshape operations. I found it helpful to draw the matrices themselves instead of nodes in the computation graph.
The computation graph of convolution is like this; each of the 5 steps above is shown in the graph:


Forward Pass Code

Now I will show step by step how to implement the forward pass:

  • STEP1 The operation in step 1 is typically referred to as im2col. It lays out the overlapping patches of an image as the rows of a matrix. Here I use a naive for loop to implement this operation:
def im2col(x,hh,ww,stride):
    """
    Args:
      x: image matrix to be translated into columns, (C,H,W)
      hh: filter height
      ww: filter width
      stride: stride
    Returns:
      col: (new_h*new_w,hh*ww*C) matrix, each row is a flattened (C,hh,ww) patch that will be convolved with a filter
           new_h = (H-hh) // stride + 1, new_w = (W-ww) // stride + 1
    """
    c,h,w = x.shape
    new_h = (h-hh) // stride + 1
    new_w = (w-ww) // stride + 1
    col = np.zeros([new_h*new_w,c*hh*ww])

    for i in range(new_h):
        for j in range(new_w):
            # patch covers all channels at output position (i, j)
            patch = x[...,i*stride:i*stride+hh,j*stride:j*stride+ww]
            col[i*new_w+j,:] = np.reshape(patch,-1)
    return col
  • STEP2 The operation in step 2 is a plain numpy reshape, which puts one flattened filter in each row (the transpose in the next step turns these rows into columns):
filter_col = np.reshape(w,(F,-1))
  • STEP3&4 Step 3 and step 4 are just a matrix multiplication and an addition:
mul = im_col.dot(filter_col.T) + b
  • STEP 5 Step 5 is typically called col2im; it rearranges a matrix back into blocks.
def col2im(mul,h_prime,w_prime,C):
    """
    Args:
      mul: (C*h_prime*w_prime,F) matrix, each column is reshaped to (C,h_prime,w_prime) when C > 1, or to (h_prime,w_prime) when C == 1
      h_prime: reshaped filter height
      w_prime: reshaped filter width
      C: reshaped filter channel, if 1, reshape the filter to 2D, otherwise reshape it to 3D
    Returns:
      if C == 1: (F,h_prime,w_prime) matrix
      otherwise: (F,C,h_prime,w_prime) matrix
    """
    F = mul.shape[1]
    if(C == 1):
        out = np.zeros([F,h_prime,w_prime])
        for i in range(F):
            col = mul[:,i]
            out[i,:,:] = np.reshape(col,(h_prime,w_prime))
    else:
        out = np.zeros([F,C,h_prime,w_prime])
        for i in range(F):
            col = mul[:,i]
            out[i,:,:] = np.reshape(col,(C,h_prime,w_prime))
    return out

Those are all the pieces of the forward pass of convolution. Combining them, we get our final forward function:

def conv_forward_naive(x, w, b, conv_param):
    """
    A naive implementation of the forward pass for a convolutional layer.

    The input consists of N data points, each with C channels, height H and width
    W. We convolve each input with F different filters, where each filter spans
    all C channels and has height HH and width WW.

    - x: Input data of shape (N, C, H, W)
    - w: Filter weights of shape (F, C, HH, WW)
    - b: Biases, of shape (F,)
    - conv_param: A dictionary with the following keys:
      - 'stride': The number of pixels between adjacent receptive fields in the
        horizontal and vertical directions.
      - 'pad': The number of pixels that will be used to zero-pad the input.

    Returns a tuple of:
    - out: Output data, of shape (N, F, H', W') where H' and W' are given by
      H' = 1 + (H + 2 * pad - HH) / stride
      W' = 1 + (W + 2 * pad - WW) / stride
    - cache: (x, w, b, conv_param)
    """
    out = None
    pad_num = conv_param['pad']
    stride = conv_param['stride']
    N,C,H,W = x.shape
    F,C,HH,WW = w.shape
    H_prime = (H+2*pad_num-HH) // stride + 1
    W_prime = (W+2*pad_num-WW) // stride + 1
    out = np.zeros([N,F,H_prime,W_prime])
    for im_num in range(N):
        im = x[im_num,:,:,:]
        im_pad = np.pad(im,((0,0),(pad_num,pad_num),(pad_num,pad_num)),'constant')
        im_col = im2col(im_pad,HH,WW,stride)
        filter_col = np.reshape(w,(F,-1))
        # (H'*W', C*HH*WW) x (C*HH*WW, F) -> (H'*W', F)
        mul = im_col.dot(filter_col.T) + b
        out[im_num,:,:,:] = col2im(mul,H_prime,W_prime,1)
    cache = (x, w, b, conv_param)
    return out, cache
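One way to sanity-check the im2col route is to compare it with a direct sliding-window convolution on a single image. The sketch below is self-contained and makes a few simplifications of its own: no padding, and step 5 done with a transpose plus reshape instead of a col2im loop (for this layout the two give the same result). The function names conv_single_im2col and conv_single_direct, the toy sizes and the random data are all chosen for illustration.

```python
import numpy as np

def im2col(x, hh, ww, stride):
    c, h, w = x.shape
    new_h = (h - hh) // stride + 1
    new_w = (w - ww) // stride + 1
    col = np.zeros((new_h * new_w, c * hh * ww))
    for i in range(new_h):
        for j in range(new_w):
            patch = x[:, i*stride:i*stride+hh, j*stride:j*stride+ww]
            col[i * new_w + j, :] = patch.reshape(-1)
    return col

def conv_single_im2col(im, w, b, stride):
    """One (C,H,W) image through steps 1-5, without padding."""
    F, C, HH, WW = w.shape
    h_out = (im.shape[1] - HH) // stride + 1
    w_out = (im.shape[2] - WW) // stride + 1
    mul = im2col(im, HH, WW, stride).dot(w.reshape(F, -1).T) + b  # (h_out*w_out, F)
    return mul.T.reshape(F, h_out, w_out)

def conv_single_direct(im, w, b, stride):
    """Reference: slide each filter explicitly and take dot products."""
    F, C, HH, WW = w.shape
    h_out = (im.shape[1] - HH) // stride + 1
    w_out = (im.shape[2] - WW) // stride + 1
    out = np.zeros((F, h_out, w_out))
    for f in range(F):
        for i in range(h_out):
            for j in range(w_out):
                patch = im[:, i*stride:i*stride+HH, j*stride:j*stride+WW]
                out[f, i, j] = np.sum(patch * w[f]) + b[f]
    return out

rng = np.random.default_rng(0)
im = rng.standard_normal((3, 7, 7))
w = rng.standard_normal((4, 3, 3, 3))
b = rng.standard_normal(4)
fast = conv_single_im2col(im, w, b, stride=2)
slow = conv_single_direct(im, w, b, stride=2)
max_diff = np.max(np.abs(fast - slow))
print(fast.shape, max_diff)
```

The two outputs should agree up to floating-point rounding, which confirms that flattening patches and filters in the same (C,HH,WW) order turns convolution into one matrix product.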

Backward Pass Code

Below I will show step by step how to implement the backward pass, going from step 5 back down to step 1.

  • STEP1 Step 5 is a reshape operation, so in the backward pass we need to assign each gradient to its corresponding element in the input matrix. This can be easily implemented using a numpy reshape, followed by a transpose so that the result has the same (H^\prime*W^\prime, F) layout as mul:
dbias_sum = np.reshape(dout_i,(F,-1)).T
  • STEP2 Step 4 is an addition gate, so it passes the gradient unchanged to both of its inputs. Since b is broadcast over all spatial positions, its gradient is summed over them:
#bias_sum = mul + b
db += np.sum(dbias_sum,axis=0)
dmul = dbias_sum
  • STEP3 Step 3 is a multiplication gate: the gradient of each input is the incoming gradient multiplied by “the other factor”. Here filter_col is the (HH*WW*C, F) matrix from step 2.
#mul = im_col * filter_col
dfilter_col = (im_col.T).dot(dmul)
dim_col = dmul.dot(filter_col.T)
  • STEP4 Step 2 is another reshape operation, so we use a numpy reshape to bring the gradient back to the original shape. Note that since the weights are convolved with a whole batch of images, their gradients must be accumulated across all images.
dw += np.reshape(dfilter_col.T,(F,C,HH,WW))
  • STEP5 Step 1 is the im2col operation; its backpropagation is a little trickier. You reshape each row to a (C,HH,WW) patch, but these patches might overlap, and where they do, you need to accumulate (add) the gradients of the overlapping parts. I wrote a col2im_back function to do this:
def col2im_back(dim_col,h_prime,w_prime,stride,hh,ww,c):
    """
    Args:
      dim_col: gradients for im_col,(h_prime*w_prime,hh*ww*c)
      h_prime,w_prime: height and width for the feature map
      stride: stride
      hh,ww,c: size of the filters
    Returns:
      dx: Gradients for x, (C,H,W)
    """
    H = (h_prime - 1) * stride + hh
    W = (w_prime - 1) * stride + ww
    dx = np.zeros([c,H,W])
    for i in range(h_prime*w_prime):
        row = dim_col[i,:]
        # integer division, so the indices stay ints on Python 3
        h_start = (i // w_prime) * stride
        w_start = (i % w_prime) * stride
        # overlapping patches accumulate their gradients
        dx[:,h_start:h_start+hh,w_start:w_start+ww] += np.reshape(row,(c,hh,ww))
    return dx
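The accumulation over overlapping patches is easiest to see on a tiny case. The sketch below restates col2im_back in self-contained form (using integer division for the row index so that it runs on Python 3) and feeds it a gradient of all ones for a made-up setting: one channel, 2x2 patches, stride 1 and a 2x2 feature map, which corresponds to a 3x3 input.

```python
import numpy as np

def col2im_back(dim_col, h_prime, w_prime, stride, hh, ww, c):
    H = (h_prime - 1) * stride + hh
    W = (w_prime - 1) * stride + ww
    dx = np.zeros([c, H, W])
    for i in range(h_prime * w_prime):
        h_start = (i // w_prime) * stride
        w_start = (i % w_prime) * stride
        # Add this patch's gradient onto the region of the image it came from.
        dx[:, h_start:h_start+hh, w_start:w_start+ww] += dim_col[i].reshape(c, hh, ww)
    return dx

# Four patches (one per output position), each with 1 * 2 * 2 = 4 entries.
dim_col = np.ones((4, 4))
dx = col2im_back(dim_col, h_prime=2, w_prime=2, stride=1, hh=2, ww=2, c=1)
print(dx[0])
# [[1. 2. 1.]
#  [2. 4. 2.]
#  [1. 2. 1.]]
```

Each entry counts how many patches cover that pixel: the corners belong to one patch, the edges to two, and the center to all four.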

Combining all 5 steps above, we get our final backward function:

def conv_backward_naive(dout, cache):
    """
    A naive implementation of the backward pass for a convolutional layer.

    - dout: Upstream derivatives.
    - cache: A tuple of (x, w, b, conv_param) as in conv_forward_naive

    Returns a tuple of:
    - dx: Gradient with respect to x
    - dw: Gradient with respect to w
    - db: Gradient with respect to b
    """
    dx, dw, db = None, None, None

    x, w, b, conv_param = cache
    pad_num = conv_param['pad']
    stride = conv_param['stride']
    N,C,H,W = x.shape
    F,C,HH,WW = w.shape
    H_prime = (H+2*pad_num-HH) // stride + 1
    W_prime = (W+2*pad_num-WW) // stride + 1

    dw = np.zeros(w.shape)
    dx = np.zeros(x.shape)
    db = np.zeros(b.shape)

    for i in range(N):
        im = x[i,:,:,:]
        im_pad = np.pad(im,((0,0),(pad_num,pad_num),(pad_num,pad_num)),'constant')
        im_col = im2col(im_pad,HH,WW,stride)
        filter_col = np.reshape(w,(F,-1)).T

        dout_i = dout[i,:,:,:]
        dbias_sum = np.reshape(dout_i,(F,-1))
        dbias_sum = dbias_sum.T

        #bias_sum = mul + b
        db += np.sum(dbias_sum,axis=0)
        dmul = dbias_sum

        #mul = im_col * filter_col
        dfilter_col = (im_col.T).dot(dmul)
        dim_col = dmul.dot(filter_col.T)

        # undo im2col on the padded image, then crop the padding away
        dx_padded = col2im_back(dim_col,H_prime,W_prime,stride,HH,WW,C)
        dx[i,:,:,:] = dx_padded[:,pad_num:H+pad_num,pad_num:W+pad_num]
        dw += np.reshape(dfilter_col.T,(F,C,HH,WW))
    return dx, dw, db
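As with batch normalization, a numerical gradient check is a good way to test this backward pass. The sketch below is a condensed, self-contained restatement of the same five steps (conv_forward, conv_backward and numeric_grad are names chosen here, stride and pad are passed directly instead of through conv_param, and step 5 uses a transpose plus reshape). The sizes are made up and picked so that the stride divides H + 2 * pad - HH evenly.

```python
import numpy as np

def im2col(x, hh, ww, stride):
    c, h, w = x.shape
    new_h, new_w = (h - hh) // stride + 1, (w - ww) // stride + 1
    col = np.zeros((new_h * new_w, c * hh * ww))
    for i in range(new_h):
        for j in range(new_w):
            patch = x[:, i*stride:i*stride+hh, j*stride:j*stride+ww]
            col[i * new_w + j] = patch.reshape(-1)
    return col

def col2im_back(dim_col, h_prime, w_prime, stride, hh, ww, c):
    dx = np.zeros((c, (h_prime - 1) * stride + hh, (w_prime - 1) * stride + ww))
    for i in range(h_prime * w_prime):
        hs, ws = (i // w_prime) * stride, (i % w_prime) * stride
        dx[:, hs:hs+hh, ws:ws+ww] += dim_col[i].reshape(c, hh, ww)
    return dx

def conv_forward(x, w, b, stride, pad):
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    h_out, w_out = (H + 2*pad - HH) // stride + 1, (W + 2*pad - WW) // stride + 1
    out = np.zeros((N, F, h_out, w_out))
    for n in range(N):
        im_pad = np.pad(x[n], ((0, 0), (pad, pad), (pad, pad)), 'constant')
        mul = im2col(im_pad, HH, WW, stride).dot(w.reshape(F, -1).T) + b
        out[n] = mul.T.reshape(F, h_out, w_out)
    return out

def conv_backward(dout, x, w, b, stride, pad):
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    h_out, w_out = dout.shape[2:]
    dx, dw, db = np.zeros_like(x), np.zeros_like(w), np.zeros_like(b)
    for n in range(N):
        im_pad = np.pad(x[n], ((0, 0), (pad, pad), (pad, pad)), 'constant')
        im_col = im2col(im_pad, HH, WW, stride)
        dmul = dout[n].reshape(F, -1).T              # (h_out*w_out, F)
        db += dmul.sum(axis=0)
        dw += im_col.T.dot(dmul).T.reshape(F, C, HH, WW)
        dim_col = dmul.dot(w.reshape(F, -1))         # (h_out*w_out, C*HH*WW)
        dx_pad = col2im_back(dim_col, h_out, w_out, stride, HH, WW, C)
        dx[n] = dx_pad[:, pad:pad+H, pad:pad+W]      # crop the padding away
    return dx, dw, db

def numeric_grad(f, v, dout, h=1e-5):
    """Centered finite differences of sum(f(v) * dout) with respect to v."""
    grad = np.zeros_like(v)
    for idx in np.ndindex(*v.shape):
        old = v[idx]
        v[idx] = old + h
        pos = np.sum(f(v) * dout)
        v[idx] = old - h
        neg = np.sum(f(v) * dout)
        v[idx] = old
        grad[idx] = (pos - neg) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 5))
w = rng.standard_normal((2, 3, 3, 3))
b = rng.standard_normal(2)
stride, pad = 2, 1
dout = rng.standard_normal(conv_forward(x, w, b, stride, pad).shape)

dx, dw, db = conv_backward(dout, x, w, b, stride, pad)
dx_num = numeric_grad(lambda v: conv_forward(v, w, b, stride, pad), x, dout)
dw_num = numeric_grad(lambda v: conv_forward(x, v, b, stride, pad), w, dout)
db_num = numeric_grad(lambda v: conv_forward(x, w, v, stride, pad), b, dout)
print("max errors:", np.max(np.abs(dx - dx_num)),
      np.max(np.abs(dw - dw_num)), np.max(np.abs(db - db_num)))
```

All three errors should be many orders of magnitude smaller than the gradients themselves. If the stride does not divide H + 2 * pad - HH, col2im_back returns an array smaller than the padded image and the crop no longer fits, which is why the sizes above were chosen to divide evenly.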


Phew! We have finally completed all the backward passes. A key intuition when implementing backpropagation is to find the “error source” and assign the gradients to it. Imagine you are a manager, and all the workers (parameters) in your network contribute some work to the final output. They might have done something wrong that caused a wrong result in the end, so in backpropagation your job is to find each such worker (parameter), tell it how wrong it was, and make it correct its error!




[3] Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015.


