Mathematics behind Exploding and Vanishing Gradients
Have you ever noticed that deep learning codebases often initialize weights by sampling from a normal distribution with mean zero and variance one, and then dividing each weight by the square root of the input embedding dimension? This is a simplified, fan-in-only form of what is formally called Xavier initialization. Why do we do this?
What is the issue with plain Normal(0, 1) initialization? The problem is that it causes values to explode or vanish across multiple layers, both in the forward pass and during backpropagation.
We are going to show, statistically, how dividing by the square root of the embedding dimension helps.
Let’s consider a single neuron in one layer.
Its output y is a linear combination of its inputs x and weights w.
y = w1*x1 + w2*x2 + ... + wn*xn
Here, n is the number of input dimensions (let’s call it n_in).
We will make three simple assumptions about our inputs and weights.
The inputs x are normalized: mean(x) = 0 and variance(x) = 1.
The weights w are initialized with mean(w) = 0 and variance(w) = 1. This is the Normal(0, 1) case we want to analyze.
Inputs xi and weights wi are all independent of each other.
Our goal is to find the variance of the output y. A stable network should have outputs with a variance close to 1. If the variance keeps increasing layer by layer, the outputs will explode to huge values.
Let’s calculate the variance of y.
Var(y) = Var(w1*x1 + w2*x2 + ... + wn*xn)
A property of variance is that for independent variables, the variance of a sum is the sum of their variances. By our third assumption, the terms wi*xi are independent of each other, so we can apply it here.
Var(y) = Var(w1*x1) + Var(w2*x2) + ... + Var(wn*xn)
Another property of variance for two independent variables A and B with zero mean is: Var(A*B) = Var(A) * Var(B).
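The product rule is easy to check numerically. The sketch below is only an illustration: the sample size, seed, and standard deviations (2 and 3) are arbitrary choices, and the `variance` helper is a hypothetical name, not something from a library. It also shows why the zero-mean condition matters. For independent variables the general identity is Var(A*B) = Var(A)*Var(B) + Var(A)*E[B]^2 + Var(B)*E[A]^2, so with both means equal to 1 and both variances equal to 1 the result should be close to 3, not 1.

```python
import random

def variance(values):
    """Population variance of a list of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

rng = random.Random(0)  # fixed seed so the run is repeatable
N = 200_000             # arbitrary sample size

# Zero-mean case: A ~ Normal(0, 2^2), B ~ Normal(0, 3^2), drawn independently.
a = [rng.gauss(0.0, 2.0) for _ in range(N)]
b = [rng.gauss(0.0, 3.0) for _ in range(N)]
var_zero_mean = variance([ai * bi for ai, bi in zip(a, b)])

# Shifted case: both means are 1 and both variances are 1.
c = [rng.gauss(1.0, 1.0) for _ in range(N)]
d = [rng.gauss(1.0, 1.0) for _ in range(N)]
var_shifted = variance([ci * di for ci, di in zip(c, d)])

print(f"zero mean: Var(A*B) ~ {var_zero_mean:.2f}  (Var(A)*Var(B) = 36)")
print(f"mean 1:    Var(C*D) ~ {var_shifted:.2f}  (Var(C)*Var(D) = 1, general identity gives 3)")
```

The first number should land near 36 and the second near 3, which is why the derivation needs the zero-mean assumptions on both x and w.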
Applying this:
Var(y) = (Var(w1)*Var(x1)) + (Var(w2)*Var(x2)) + ... + (Var(wn)*Var(xn))
We assumed Var(wi) = 1 and Var(xi) = 1 for all i.
Var(y) = (1 * 1) + (1 * 1) + ... + (1 * 1)
Var(y) = 1 + 1 + ... + 1 (n_in times)
Var(y) = n_in
This is the problem.
The variance of the output of the layer is equal to the number of input dimensions.
If the input dimension n_in is 512, the variance of the output is 512. The standard deviation is sqrt(512), which is about 22.6.
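We can check the result Var(y) = n_in with a small simulation. This is a rough sketch under the same three assumptions as the derivation: every weight and every input is drawn independently from Normal(0, 1). The number of trials and the seed are arbitrary, and `variance` is a hypothetical helper.

```python
import math
import random

def variance(values):
    """Population variance of a list of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

rng = random.Random(0)  # fixed seed so the run is repeatable
n_in = 512
n_trials = 2000         # number of independent (w, x) draws, an arbitrary choice

outputs = []
for _ in range(n_trials):
    # Fresh weights and inputs on every trial, all i.i.d. Normal(0, 1).
    y = sum(rng.gauss(0.0, 1.0) * rng.gauss(0.0, 1.0) for _ in range(n_in))
    outputs.append(y)

var_y = variance(outputs)
print(f"Var(y) ~ {var_y:.1f}   (theory: {n_in})")
print(f"std(y) ~ {math.sqrt(var_y):.1f}   (theory: {math.sqrt(n_in):.1f})")
```

With only 2000 trials the estimate is noisy, so expect a value within a few percent of 512 rather than an exact match.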
To understand what this means, we need to know what standard deviation represents. Standard deviation is a measure of how spread out numbers are from their mean. Our mean is zero. For a normal distribution, about 68% of all values lie within one standard deviation of the mean. About 95% lie within two standard deviations.
Our inputs had a mean of 0 and a standard deviation of 1. So most input values were between -2 and 2.
Our outputs now have a mean of 0 and a standard deviation of about 22.6. So most output values will be between roughly -45 and 45.
The typical magnitude of an output value is now about 22.6 times larger than the typical magnitude of an input value.
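The 68% and 95% figures, and the ranges they imply, can be reproduced with Python's built-in `statistics.NormalDist`. One assumption to flag: the output y is a sum of products, so it is only approximately normal, and the sketch treats it as exactly normal. The helper name `fraction_within` is made up for this example.

```python
import math
from statistics import NormalDist

def fraction_within(dist, k):
    """Probability that a draw from dist lies within k standard deviations of its mean."""
    return dist.cdf(dist.mean + k * dist.stdev) - dist.cdf(dist.mean - k * dist.stdev)

inputs = NormalDist(mu=0.0, sigma=1.0)
# Treating y as normal is an approximation; its std is sqrt(n_in) with n_in = 512.
outputs = NormalDist(mu=0.0, sigma=math.sqrt(512))

for name, dist in [("inputs", inputs), ("outputs", outputs)]:
    lo, hi = -2 * dist.stdev, 2 * dist.stdev
    print(f"{name}: {fraction_within(dist, 1):.1%} within 1 std, "
          f"{fraction_within(dist, 2):.1%} within 2 std, i.e. in [{lo:.1f}, {hi:.1f}]")
```

The percentages are the same for both distributions. What changes is the width of the interval they describe.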
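Now pass this through a few layers and see how the variance explodes. For simplicity, assume every layer has the same width n_in = 512, and ignore the activation functions for the moment.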
Layer 1:
Input to Layer 1: x_l1. Var(x_l1) = 1.
Output of Layer 1: y_l1.
As we proved, Var(y_l1) = n_in * Var(w_l1) * Var(x_l1) = 512 * 1 * 1 = 512.
Layer 2:
The input to Layer 2 is the output of Layer 1. So, x_l2 = y_l1.
The variance of the input to Layer 2 is Var(x_l2) = Var(y_l1) = 512.
The output of Layer 2 is y_l2. We use the same formula for variance.
Var(y_l2) = n_in * Var(w_l2) * Var(x_l2)
The weights of Layer 2 are also initialized from Normal(0, 1), so Var(w_l2) = 1.
Var(y_l2) = 512 * 1 * 512 = 262,144.
The variance of the output from Layer 2 is (n_in)^2. The standard deviation is n_in = 512.
After just two layers, the typical magnitude of the output values is 512 times larger than the original input values. After a third layer, the variance would be (n_in)^3 and the standard deviation (n_in)^1.5, which is about 11,585.
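The layer-by-layer growth can be watched directly. The sketch below stacks purely linear layers (no biases, no activations) with Normal(0, 1) weights, which matches the simplified setting of the derivation. To keep pure Python fast it uses a width of 64 instead of 512; that width, the batch size, and the helper names are all assumptions made for illustration. The measured variance after layer k should be close to width^k.

```python
import random

def variance(values):
    """Population variance of a list of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

def linear_layer(weights, x):
    """Return weights @ x for a list-of-rows matrix and a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

rng = random.Random(0)  # fixed seed so the run is repeatable
width = 64              # smaller than 512 so pure Python stays fast (an assumption)
depth = 3
n_samples = 300

# One Normal(0, 1) weight matrix per layer.
layers = [[[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(width)]
          for _ in range(depth)]

# A batch of inputs with mean 0 and variance 1.
batch = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(n_samples)]

measured = []
for k, weights in enumerate(layers, start=1):
    batch = [linear_layer(weights, x) for x in batch]
    var_k = variance([v for x in batch for v in x])
    measured.append(var_k)
    print(f"layer {k}: Var ~ {var_k:.3e}   (theory: {width ** k:.3e})")
```

Because each layer uses one fixed weight matrix, the measured values wobble around the theoretical ones, but the order of magnitude grows by a factor of about 64 per layer.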
The numbers explode. These large outputs are then passed to an activation function (like tanh or sigmoid). These functions saturate for large inputs: the output of tanh gets close to 1 or -1, and the output of sigmoid gets close to 1 or 0.
When the activation function is saturated, its gradient is almost zero. During backpropagation, these near-zero gradients are multiplied back through the network, so the weight updates become negligible and the network does not learn. This is how exploding activations lead to vanishing gradients.
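To see how small the gradients get, we can evaluate the derivatives of tanh and sigmoid at the magnitudes computed above. This is a minimal sketch: the function names are made up for the example, and the test points (0, 1, sqrt(512), and 512) are simply a typical input and one standard deviation after layers 1 and 2.

```python
import math

def tanh_grad(z):
    """Derivative of tanh at z: 1 - tanh(z)^2."""
    return 1.0 - math.tanh(z) ** 2

def sigmoid_grad(z):
    """Derivative of the logistic sigmoid at z: s * (1 - s).

    Written for z >= 0; math.exp(-z) would overflow for very negative z.
    """
    s = 1.0 / (1.0 + math.exp(-z))
    return s * (1.0 - s)

# A typical input, then one standard deviation after layer 1 and after layer 2.
for z in [0.0, 1.0, math.sqrt(512), 512.0]:
    print(f"z = {z:7.2f}   tanh'(z) = {tanh_grad(z):.3e}   sigmoid'(z) = {sigmoid_grad(z):.3e}")
```

At z = 0 the derivatives are 1 and 0.25. At the larger magnitudes they are so small that floating-point arithmetic rounds some of them to exactly zero.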
Now, let’s fix it.
Our goal is to make the output variance Var(y) equal to 1.
We saw that: Var(y) = n_in * Var(w) * Var(x)
Assuming Var(x) = 1, we have: Var(y) = n_in * Var(w)
We want Var(y) = 1.
So, 1 = n_in * Var(w)
This means we must choose our weights w to have a variance of:
Var(w) = 1 / n_in
How do we get a random variable with Var(w) = 1 / n_in?
We start with a variable W_standard from a standard normal distribution Normal(0, 1). It has Var(W_standard) = 1.
If we scale this variable by a constant c, the new variance is Var(c * W_standard) = c^2 * Var(W_standard) = c^2.
We want the new variance to be 1 / n_in.
So, c^2 = 1 / n_in
c = 1 / sqrt(n_in)
This is the solution. We initialize weights not from Normal(0, 1), but by taking a sample from Normal(0, 1) and then dividing it by sqrt(n_in).
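In code, the scaling step is a single division. The sketch below draws standard normal samples, divides them by sqrt(n_in), and compares the measured variance with 1 / n_in. The sample count, seed, and `variance` helper are arbitrary choices for illustration.

```python
import math
import random

def variance(values):
    """Population variance of a list of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

rng = random.Random(0)  # fixed seed so the run is repeatable
n_in = 512
n_weights = 100_000     # arbitrary number of samples

# Sample from Normal(0, 1), then scale by c = 1 / sqrt(n_in).
standard = [rng.gauss(0.0, 1.0) for _ in range(n_weights)]
scaled = [w / math.sqrt(n_in) for w in standard]

var_standard = variance(standard)
var_scaled = variance(scaled)
print(f"Var(W_standard)        ~ {var_standard:.4f}   (theory: 1)")
print(f"Var(W_standard / sqrt) ~ {var_scaled:.6f}   (theory: {1 / n_in:.6f})")
```

Sampling directly from a normal distribution with standard deviation 1 / sqrt(n_in) gives the same distribution. Dividing afterwards just mirrors the derivation step by step.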
This makes the variance of the weights equal to 1 / n_in. Let’s re-calculate the output variance with this new weight initialization.
Var(y) = n_in * Var(w)
Var(y) = n_in * (1 / n_in)
Var(y) = 1
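The same formula can be applied layer by layer to compare the two initializations side by side. The sketch below is plain arithmetic on Var(y) = n_in * Var(w) * Var(x), using exact fractions so no rounding is involved. The function name `output_variances` is made up for this example, and it assumes every layer has the same width.

```python
from fractions import Fraction

def output_variances(n_in, var_w, var_x, depth):
    """Apply Var(y) = n_in * Var(w) * Var(x) layer by layer.

    Returns a list with the output variance of each layer, assuming every
    layer has n_in inputs and the same weight variance.
    """
    variances = []
    var = Fraction(var_x)
    for _ in range(depth):
        var = n_in * Fraction(var_w) * var
        variances.append(var)
    return variances

n_in = 512
unscaled = output_variances(n_in, var_w=1, var_x=1, depth=3)
scaled = output_variances(n_in, var_w=Fraction(1, n_in), var_x=1, depth=3)

print("Var(w) = 1:      ", [int(v) for v in unscaled])
print("Var(w) = 1/n_in: ", [int(v) for v in scaled])
```

The first line prints [512, 262144, 134217728] and the second prints [1, 1, 1].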
Now, the output of the layer has the same variance as the input. The variance does not explode or vanish from layer to layer. The inputs to activation functions stay in a range where their gradients are non-zero, and the network can learn effectively. The same logic applies to the gradients during the backward pass, except that the relevant count there is the number of outputs of the layer (n_out) rather than n_in. The full Xavier initialization balances the two by using Var(w) = 2 / (n_in + n_out).
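To tie the pieces together, here is a rough experiment with an actual nonlinearity. It pushes a random batch through a few tanh layers and reports the average of tanh'(z) at each layer, once with Normal(0, 1) weights and once with weights scaled by 1 / sqrt(width). Everything specific here is an assumption made for the sketch: the width of 64 (chosen so pure Python stays fast), the depth, the batch size, and the function name. The variance derivation above is only exact for linear layers, so treat the numbers as an illustration.

```python
import math
import random

def mean_tanh_grad_per_layer(width, depth, weight_std, n_samples, seed=0):
    """Push a random batch through `depth` tanh layers and return, for each
    layer, the average of tanh'(z) = 1 - tanh(z)^2 over all units and samples."""
    rng = random.Random(seed)
    batch = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(n_samples)]
    means = []
    for _ in range(depth):
        # Fresh zero-mean weights for this layer with the requested std.
        weights = [[rng.gauss(0.0, weight_std) for _ in range(width)]
                   for _ in range(width)]
        grads = []
        next_batch = []
        for x in batch:
            z = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
            a = [math.tanh(zi) for zi in z]
            grads.extend(1.0 - ai * ai for ai in a)
            next_batch.append(a)
        batch = next_batch
        means.append(sum(grads) / len(grads))
    return means

width = 64  # assumption: small width so pure Python stays fast
unscaled = mean_tanh_grad_per_layer(width, depth=3, weight_std=1.0, n_samples=150)
scaled = mean_tanh_grad_per_layer(width, depth=3,
                                  weight_std=1.0 / math.sqrt(width), n_samples=150)

for k, (u, s) in enumerate(zip(unscaled, scaled), start=1):
    print(f"layer {k}: mean tanh' with Normal(0, 1) = {u:.3f}, "
          f"with 1/sqrt(n_in) scaling = {s:.3f}")
```

With unscaled weights most units sit in the flat part of tanh, so the average derivative is small. With the scaled weights it stays well away from zero at every layer.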

