What are neural networks and how do they work
"Neural networks" sounds like a complex and mysterious term for something that somehow models how our brain works. In reality, a neural network is just a large mathematical expression that takes some inputs and produces meaningful outputs.
They take some input, run it through a series of calculations, and produce an output.
A simple mathematical expression does the same thing:
In the case above, the expression takes a one-digit input and produces a one-digit output.
Neural networks are really just that, but with millions or even billions of parameters.
Here's how a neural network can be used. Say we are trying to get a computer to recognise the digits 0 to 9.
As input, we feed it an image of a handwritten digit, and we want it to tell us which digit it is. We can represent the output as an array of 10 numbers. For example:
Each position represents the network's confidence for a digit.
So index 0 represents confidence that the output is 0.
Index 1 represents confidence that the output is 1 and so on.
In the output above, the neural network predicted that the input image is digit 2.
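Reading the network's answer off this array amounts to finding the index with the highest confidence. A minimal Python sketch (the function name `predict_digit` and the example values are made up for illustration):

```python
def predict_digit(confidences):
    """Return the digit whose confidence is highest, i.e. the index of the largest value."""
    if len(confidences) != 10:
        raise ValueError("expected 10 confidence values, one per digit 0-9")
    # Compare indices by the confidence stored at each index; ties go to the first one.
    return max(range(len(confidences)), key=lambda i: confidences[i])

# Hypothetical network output: the highest confidence sits at index 2.
example_output = [0.01, 0.05, 0.80, 0.02, 0.03, 0.01, 0.02, 0.03, 0.02, 0.01]
print(predict_digit(example_output))  # prints 2
```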
Let's get back to the input.
Say we have a 60 × 60 pixel image.
That means the image contains 3600 pixels (60 × 60 = 3600).
We can represent the image as a list of 3600 numbers, where each number represents one pixel.
The neural network takes this input and produces an output with 10 numbers:
Input: 3600 pixels
Output: 10 confidence values
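Before an image can go into the network, its grid of pixels has to become one flat list of numbers. A minimal sketch, assuming a 60 × 60 grayscale image stored as a list of rows with values from 0.0 (white) to 1.0 (black); that encoding, the function name `flatten_image`, and the drawn stroke are all hypothetical:

```python
def flatten_image(image):
    """Turn a 2-D grid of pixel values (a list of rows) into one flat list, row by row."""
    return [pixel for row in image for pixel in row]

# Hypothetical 60 x 60 grayscale image: 0.0 is white, 1.0 is black (assumed encoding).
WIDTH, HEIGHT = 60, 60
image = [[0.0] * WIDTH for _ in range(HEIGHT)]
# Draw a crude vertical stroke down the middle, like a handwritten "1".
for row in range(10, 50):
    image[row][30] = 1.0

x = flatten_image(image)
print(len(x))  # prints 3600
print(sum(x))  # prints 40.0
```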
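So conceptually, we want to build a mathematical function that does this:
$$f(x_1, x_2, \dots, x_{3600}) = (y_0, y_1, \dots, y_9)$$
where each $x_i$ is one pixel value and each $y_j$ is the network's confidence for digit $j$.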
For example, we give the network an image of a handwritten 1.
Ideally, we want the output to look something like this:
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
This means:
The correct answer is digit 1.
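This ideal output is often called a one-hot vector: 1 at the correct digit's index and 0 everywhere else. A minimal sketch of building it for any label (the function name `one_hot` is an illustrative choice):

```python
def one_hot(label, num_classes=10):
    """Ideal network output for `label`: 1.0 at the correct index, 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be between 0 and {num_classes - 1}")
    target = [0.0] * num_classes
    target[label] = 1.0
    return target

print(one_hot(1))  # prints [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```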
But at the start, the neural network has no idea what it is doing.
Its internal values are basically random. So when we first give it an image of 1, it might output something like this:
This is garbage to us.
The neural network does not know what the image means yet. So, the question becomes:
How do we adjust the mathematical expression that is a neural network so that its output becomes closer to the correct answer?
This is really what training a neural network means.
Training a neural network means adjusting the internal values of the mathematical expression so that it produces better outputs.
Those internal values belong to neurons, which are numerical stand-ins for human neurons (just very simple versions). We won't get into how neurons work here; for now, let's just remember that a network has a whole bunch of neurons, and their behaviour is set by numbers called weights and biases.
At the start, the weights and biases are set to random numbers. When initialised, the neural network knows nothing.
So the network gives bad predictions: it simply calculates whatever its random initial values happen to produce.
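To make "random weights and biases" concrete, here is a minimal sketch of the simplest possible network: a single layer that connects all 3600 pixels directly to the 10 outputs. The layer shape, the uniform(-0.01, 0.01) initialisation range, and the names `init_layer` and `forward` are illustrative assumptions; real networks usually stack several layers with non-linear functions between them, and often convert the 10 raw scores into confidences (for example with softmax), which this sketch leaves out.

```python
import random

NUM_INPUTS = 3600   # one input per pixel
NUM_OUTPUTS = 10    # one output per digit

def init_layer(num_inputs, num_outputs, seed=0):
    """Return random weights (one row of num_inputs per output) and random biases."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-0.01, 0.01) for _ in range(num_inputs)]
               for _ in range(num_outputs)]
    biases = [rng.uniform(-0.01, 0.01) for _ in range(num_outputs)]
    return weights, biases

def forward(weights, biases, x):
    """Each output is a weighted sum of every input plus that output's bias."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, biases)]

weights, biases = init_layer(NUM_INPUTS, NUM_OUTPUTS)
num_params = NUM_OUTPUTS * NUM_INPUTS + NUM_OUTPUTS
print(num_params)  # prints 36010

# Stand-in for a flattened image: 3600 random pixel values between 0 and 1.
pixel_rng = random.Random(1)
x = [pixel_rng.random() for _ in range(NUM_INPUTS)]
scores = forward(weights, biases, x)
print(len(scores))  # prints 10; the values themselves mean nothing before training
```

Even this one-layer toy already has 36,010 adjustable numbers, which is why nobody tunes them by hand.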
The good news is that by giving it a lot of examples, we can slowly adjust the weights and biases until the network starts giving us useful predictions.
For example, imagine we have a huge dataset:
Each training example contains two things:
Input = image
Output = correct label
So we show the neural network an image, let it make a prediction, compare that prediction to the correct answer, and then adjust the network.
Over time, the network becomes a mathematical expression that maps inputs to meaningful outputs.
That's the whole idea of training a neural net.
A trained neural network is a mathematical function shaped by the data we gave it.
It's a function $y = f(x)$, where $y$ is our output array and $x$ is the image. However, instead of a simple expression with just 3 terms, a neural network has anywhere from thousands to millions of parameters (we decide how many when we set up and initialise the network).
So now we know that we need to initialise the network, what data we need to train it on, and what has to happen for it to be trained. But how do we measure whether the network is right or wrong?
Say again we are given an image of 1.
The correct output should be:
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
But the network outputs:
That is not good enough. The network gave only 0.2 confidence to the correct answer and 0.3 confidence to digit 2. It's clearly wrong, but we need to turn that wrongness into a number we can use to adjust our weights and biases and bring the output closer to the correct answer.
That measurement is called the loss.
The loss function tells us:
How far was the network's prediction from the correct answer?
A high loss means the network was very wrong.
A low loss means the network was close.
Measuring error this way is useful because it gives training a concrete goal: to make the network more accurate, we just need to minimise the loss function.
So, how do we actually calculate the loss?
A simple way is to go position by position through the output, calculate the difference between what the network gave us and what the correct answer was, square that difference, and add all of those up.
Squaring the differences does two things. It makes sure that positive and negative errors don't cancel each other out, and it penalises bigger errors more than smaller ones.
Let's calculate it for our example.
For each position, we take the difference and square it:
Adding all of those up, we get a loss of .
That single number now represents how wrong the network was for this image. If the network were perfect, the loss would be 0. The further the prediction is from the correct answer, the bigger this number becomes.
This particular loss is called squared error, and if we divide it by the number of outputs we get the mean squared error (MSE), which is one of the simplest loss functions out there.
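The calculation above translates directly into code. This sketch uses a made-up prediction for an image of a 1 (half the confidence on digit 1, half on digit 2) rather than the network's output from the example, so its numbers are illustrative only:

```python
def squared_error(prediction, target):
    """Sum over all positions of (prediction - target) squared."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(prediction, target))

def mean_squared_error(prediction, target):
    """Squared error divided by the number of outputs."""
    return squared_error(prediction, target) / len(target)

target = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # correct answer: digit 1
perfect = list(target)                                     # a perfect prediction
guess = [0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # hypothetical prediction

print(squared_error(perfect, target))     # prints 0.0
print(squared_error(guess, target))       # prints 0.5
print(mean_squared_error(guess, target))  # prints 0.05
```

The guess is off by 0.5 at two positions, so the squared error is 0.25 + 0.25 = 0.5, and dividing by the 10 outputs gives an MSE of 0.05.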
So now we have a meaningful measure of how accurate our network is. But how do we use it? We could inspect each weight by hand and essentially "make some bigger" and "make some smaller" until the loss drops.
With thousands to millions of weights, that would simply be impossible.
The neural network needs a system.
It needs an efficient way to answer:
Which weights should I change, and by how much, so that the loss goes down?
This might sound familiar: this is where derivatives become important.
A derivative tells us how sensitive a function's output is to a tiny change in one of its inputs. In the context of neural networks, we want to know how changing each weight affects the loss function, our measure of how accurate the network is.
This is exactly what the derivative tells us.
If you want the intuition behind what a derivative actually is, and how to compute one from first principles, see Intuitively understanding derivatives. The rest of this post assumes the basic idea: a derivative measures how sensitive a function's output is to a small change in its input.
For every weight in the network, we want to know:
If this weight increases slightly, what happens to the loss function?
If increasing a weight makes the loss go up, we should make that weight smaller; if it makes the loss go down, we should make it bigger.
The derivative gives us the direction, and its size tells us how strongly each weight is affecting the loss.
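A derivative can also be estimated numerically by nudging a weight slightly up and down and seeing how the loss changes. Networks don't compute derivatives this way in practice (it would be far too slow), but it is a handy way to see the idea. This sketch uses a hypothetical one-weight model, prediction = w * x + b, with made-up values x = 2, b = 0.5 and a target of 3:

```python
def loss_for_weight(w, x=2.0, b=0.5, target=3.0):
    """Squared error of a hypothetical one-weight model: prediction = w * x + b."""
    prediction = w * x + b
    return (prediction - target) ** 2

def numerical_derivative(f, w, h=1e-6):
    """Estimate df/dw by nudging w slightly in both directions (central difference)."""
    return (f(w + h) - f(w - h)) / (2 * h)

w = 0.25
slope = numerical_derivative(loss_for_weight, w)
# Exact derivative for comparison: d/dw (w*x + b - target)^2 = 2 * (w*x + b - target) * x
exact = 2 * (w * 2.0 + 0.5 - 3.0) * 2.0
print(round(slope, 4), exact)  # prints -8.0 -8.0
```

The slope is negative, so increasing w lowers the loss: with these values, w should be made bigger.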
Gradient descent is the method we use to reduce the loss. Simply put, it's an iterative process that searches for a minimum of a function by repeatedly adjusting the function's parameters. In the context of neural networks, it means:
- Make a prediction
- Compare prediction to the correct answer
- Calculate the loss
- Use derivatives to figure out how each weight affected the loss function
- Adjust the weights in the direction that reduces the loss
- Repeat
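The steps above can be sketched on a deliberately tiny problem. This is an illustration only, not how real networks are trained: the "network" is a single weight `w` and bias `b` computing w * x + b, the dataset is made up (it follows y = 3x + 1), and the derivatives of the mean squared error are worked out by hand. Real networks compute the same kind of derivatives for every weight automatically, using backpropagation.

```python
# Hypothetical toy dataset of (input, correct answer) pairs following y = 3x + 1.
data = [(0.0, 1.0), (1.0, 4.0), (2.0, 7.0), (3.0, 10.0)]

def train(data, learning_rate=0.05, steps=2000):
    """Fit prediction = w * x + b by gradient descent on mean squared error."""
    w, b = 0.0, 0.0          # arbitrary starting values; real networks start random
    n = len(data)
    loss = 0.0
    for _ in range(steps):   # 6. repeat
        loss = grad_w = grad_b = 0.0
        for x, y in data:
            prediction = w * x + b       # 1. make a prediction
            error = prediction - y       # 2. compare it to the correct answer
            loss += error ** 2 / n       # 3. accumulate the mean squared error
            grad_w += 2 * error * x / n  # 4. derivative of the loss w.r.t. w
            grad_b += 2 * error / n      #    and w.r.t. b
        w -= learning_rate * grad_w      # 5. adjust against the gradient
        b -= learning_rate * grad_b
    return w, b, loss                    # loss is the value measured on the last pass

w, b, loss = train(data)
print(round(w, 3), round(b, 3))  # prints 3.0 1.0
```

Because the toy data follows y = 3x + 1 exactly, the loop settles very close to w = 3 and b = 1, and the loss shrinks towards 0.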
This is how a neural network learns.
It's not thinking or reasoning; it's simply repeatedly adjusting a huge mathematical expression so that its outputs become more useful.
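The gradient is the collection of all these derivatives, and it points in the direction in which the loss increases most steeply.
So gradient descent moves in the opposite direction, because we want to minimise the loss.
For every weight $w$, we take a small step against its derivative:
$$w_{\text{new}} = w_{\text{old}} - \text{learning rate} \times \frac{\partial \, \text{loss}}{\partial w}$$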
That small step is controlled by something called the learning rate.
If the learning rate is too big, the network may jump around and fail to learn.
If the learning rate is too low, the network may learn too slowly.
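The effect of the learning rate is easy to see on a one-weight toy loss, loss(w) = w², whose minimum is at w = 0 and whose derivative is 2w. This sketch (the function name `descend`, the starting point w = 5, and the three learning rates are arbitrary choices for illustration) runs 20 gradient descent steps with each learning rate:

```python
def descend(learning_rate, steps=20, w=5.0):
    """Minimise the toy loss w**2 (minimum at w = 0) using its derivative 2 * w."""
    for _ in range(steps):
        w -= learning_rate * 2 * w   # step against the derivative
    return w

for lr in (0.001, 0.1, 1.1):
    print(lr, descend(lr))
```

With 0.001 the weight barely moves away from 5, with 0.1 it ends up close to 0, and with 1.1 every step overshoots the minimum by more than the last, so w grows instead of shrinking.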
So, over many examples and many updates, the neural network eventually becomes better.
It starts off as a random mathematical expression. And through training, it becomes a mathematical expression that captures patterns in the data.
For handwritten digits, a trained network often appears to respond to patterns like loops, strokes, and curves.
Not because someone manually programmed those rules, but because the network adjusted itself based on examples.
So the core idea of a neural network is that it is simply a mathematical function with adjustable internal values.
Training the neural network means adjusting those values so that the network maps inputs to useful outputs.
The loss function measures how wrong the neural network's predictions are on that data.
Derivatives tell us how each weight affects the loss.
Gradient descent uses those derivatives to update the weights in the direction that reduces the loss.
So neural networks learn by repeatedly answering:
How was I wrong?
Which part of me caused that error?
How should I change those parts so that I am less wrong next time?
This is the whole idea behind neural networks.