$\begingroup$

This is my network represented in matrices (a dot represents an arbitrary number):

[image]

Feed-forward pass (I omitted wrapping everything in an activation function for the sake of brevity):

[image]

Backpropagation:

[image]

The question

[image]

$\partial E/\partial b^3$ should be a $3 \times 1$ matrix so that it can be subtracted from the current $b^3$. With a batch size of 5, however, $\partial E/\partial b^3$ comes out as a $3 \times 5$ matrix, so the subtraction fails because of the dimension mismatch.

What went wrong? Am I supposed to take the average of each row of the $\partial E/\partial b^3$ matrix or perhaps just accumulate each row? Or is it something completely different?

My reflection

At the moment I think that what I have done so far is correct, and that I just need to either sum or average the numbers in each row of $\partial E/\partial b^3$. That way the matrix has size $3 \times 1$ as wanted, and it also makes intuitive sense: I am updating the bias based on an error calculated over an entire batch, so either summing or averaging would be reasonable. However, as I mentioned, I am not sure which one it is, or whether this is even the right approach.

Any help is highly appreciated!

$\endgroup$
  • What is your loss function? It is supposed to be a scalar function, so the derivative with respect to $b^3$ should be a $3 \times 1$ vector. Also, you never use the bias: are you sure that, for example, $a^1 = W^1 \cdot X$ rather than $a^1 = W^1 \cdot X + b^1$? Commented Oct 23, 2018 at 9:27
  • @user7573566 Would it be incorrect to simply average each row, so that the current $3\times 5$ matrix becomes a $3\times 1$ vector? (I am referring to $\partial E/\partial b^3$.) Commented Oct 23, 2018 at 13:05
  • @user7573566 My loss function is MSE (mean squared error), $\frac{1}{M}(a^3 - y)^2$, where $M$ is the total number of training examples. I left $1/M$ out of the above to simplify the example. What scalar function would you recommend I use instead, so that the derivative becomes a $3\times 1$ vector? Commented Oct 23, 2018 at 13:45

1 Answer

$\begingroup$

You use batched operations in your derivation, which makes it harder to follow. In batched form, $a^3 = W^3 a^2 + b^3$ is not a valid matrix sum: $W^3 a^2$ is a $3 \times 5$ matrix, whereas $b^3$ is a $3 \times 1$ vector.
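One common way to write the batched sum so that the dimensions match is to replicate the bias across the batch columns. This is only a sketch: the capital letters for batched matrices and the all-ones vector $\mathbf{1}_5$ are notation introduced here, not taken from the question.

$$ A^3 = W^3 A^2 + b^3 \mathbf{1}_5^\top, \qquad A^3 \in \mathbb{R}^{3 \times 5},\quad b^3 \in \mathbb{R}^{3 \times 1},\quad \mathbf{1}_5 \in \mathbb{R}^{5 \times 1} $$

Here $b^3 \mathbf{1}_5^\top$ is the $3 \times 5$ matrix whose every column equals $b^3$, which is what broadcasting does implicitly in most array libraries.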

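You cannot apply the usual chain rule

$$ \frac{\partial E}{\partial b^3} = \frac{\partial E}{\partial a^3}\frac{\partial a^3}{\partial b^3}$$

directly, because $\frac{\partial a^3}{\partial b^3}$ is the derivative of a matrix with respect to a vector. Instead, it is easier to apply the chain rule per sample:

$$ \frac{\partial E}{\partial b^3} = \sum_{i = 1}^{5} \frac{\partial E}{\partial a^3_i}\frac{\partial a^3_i}{\partial b^3}$$

$$ \frac{\partial E}{\partial b^3} = \sum_{i = 1}^{5} 2(a^3_i - y_i)$$

where $i$ is the batch index (running over the 5 samples in the batch), $a^3_i$ is the $i$-th column of $a^3$, and $y_i$ is the target for the $i$-th input. The second line follows because $\partial a^3_i/\partial b^3$ is the identity matrix and $E$ is taken to be the sum of squared errors.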
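The row-sum rule can be checked numerically. Below is a minimal NumPy sketch, assuming a linear output layer (no activation, as in the question), a loss equal to the sum of squared errors without the $1/M$ factor, and random data. The names `W3`, `b3`, `A2`, `Y` and the hidden size of 4 are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_hidden, n_out, batch = 4, 3, 5   # hidden size 4 is an arbitrary assumption

W3 = rng.normal(size=(n_out, n_hidden))
b3 = rng.normal(size=(n_out, 1))    # bias is a 3 x 1 column vector
A2 = rng.normal(size=(n_hidden, batch))
Y = rng.normal(size=(n_out, batch))

def loss(b):
    # Broadcasting adds the 3 x 1 bias to each of the 5 columns.
    A3 = W3 @ A2 + b
    return np.sum((A3 - Y) ** 2)

# Analytic gradient: per-sample gradients form a 3 x 5 matrix,
# and summing over the batch axis gives the 3 x 1 bias gradient.
A3 = W3 @ A2 + b3
dE_dA3 = 2.0 * (A3 - Y)                      # shape (3, 5)
dE_db3 = dE_dA3.sum(axis=1, keepdims=True)   # shape (3, 1)

# Central finite-difference check of each bias component.
eps = 1e-6
numeric = np.zeros_like(b3)
for k in range(n_out):
    step = np.zeros_like(b3)
    step[k, 0] = eps
    numeric[k, 0] = (loss(b3 + step) - loss(b3 - step)) / (2 * eps)

print(np.allclose(dE_db3, numeric, atol=1e-4))
```

If the loss carries a $1/M$ factor with $M$ equal to the batch size, the same reduction becomes an average (`dE_dA3.mean(axis=1, keepdims=True)`), so summing and averaging differ only by a constant that can be absorbed into the learning rate.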

$\endgroup$
  • So to clarify, you are saying that I should sum up each row, so that the $3 \times 5$ matrix is reduced to a $3 \times 1$ vector. Is that correct? Commented Oct 23, 2018 at 14:16
  • Yes, when you work through the chain rule, you end up summing each row. Commented Oct 23, 2018 at 14:20
  • Do you happen to have a source that states the same (that you have to sum each row)? I don't want to base my knowledge on a single comment (answer) from a stranger online. Commented Oct 23, 2018 at 14:30
  • You can look at the answer; it is simply the chain rule. Commented Oct 23, 2018 at 14:39
  • During the feed-forward phase, the way I added the bias is not right, as you said yourself: "$a^3=W^3a^2+b^3$ is an invalid matrix sum since b is a vector and not a matrix." But what is the correct way to add the bias then? i.sstatic.net/5YyKp.png Commented Oct 23, 2018 at 14:56
