Reasons for uncertainty in machine learning models

I've often heard the saying:

madness is doing the same thing again and again, and expecting a different result

While I know that there's a valuable lesson there, I totally disagree with the statement itself!

As any scientist knows, producing the same results by doing the same thing again and again is surprisingly hard. In practice, I think it's more true to say:

madness is doing the same thing again and again, and expecting the same result

Understanding uncertainty

Consider a simple physics experiment, where the tester releases a ball down a slope, and measures how far the ball travels after meeting the flat ground. Despite keeping every variable constant (the angle of the incline, the weight of the ball, the release position, the smoothness of the track), there will inevitably be some variation in the final resting point of the ball after each run of the experiment. Unmeasurable hidden variables (imperfections in the roundness of the ball, or minute changes in air pressure) can all introduce a degree of stochasticity or randomness into the mechanics of the experiment, and therefore its results.

If we run the experiment again and again, the results typically will form a normal (or gaussian) distribution, with a measurable mean and standard deviation. When describing the results to other people, we can give them the mean as the measured value (μ), and use the standard deviation (σ) to provide a range of uncertainty around that value.

In this version of the ball-rolling experiment, the mean final position is 100cm, with a standard deviation of 10cm. We could communicate the full complexity of these results to someone else by writing 100±10cm.

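As a quick illustration, here's a minimal sketch of how we might summarise repeated runs like these as a mean and an uncertainty (the measurements below are made up for the example):

import numpy as np

# Hypothetical final resting positions (in cm) from repeated runs of the experiment
measurements = np.array([93.2, 101.5, 108.8, 96.4, 99.1, 104.7, 90.3, 105.9])

mu = measurements.mean()    # the value we report
sigma = measurements.std()  # the uncertainty around that value

print(f"{mu:.0f}±{sigma:.0f}cm")  # something like "100±6cm"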

A Galton board is another great demonstration of this principle. Small, random events can perturb the individual experimental runs, sending some of the balls way off into the left or right bins. But the overall shape of the results will approximate something with a neat, easily modelled distribution. Modelling these distributions themselves can tell us a lot about the underlying systems.

Side note

Francis Galton, after whom the Galton board is named, was the original eugenicist. Despite contributing lots of useful, influential ideas to statistics, his work was largely motivated by scientific racism. As scientists, we should be very careful when applying simple principles from statistics to systems as complicated as human biology and psychology. Check out Subhadra Das's speaking and writing for a deeper dive into Galton's skewed ideas than I can provide here.

Back to the main maths

Even in the simplest systems, it's valuable to provide every measurement with its associated uncertainty. It tells the reader how confident they can be in each value they're working with. By combining and compounding the values' uncertainties in any subsequent calculations, they can make informed decisions about when to stop trusting their own estimations.

In science, every measurement should be given with an associated uncertainty.
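
If we then use that measurement in a further calculation, the uncertainty should be carried along with it. As a minimal sketch, using the common rule of thumb that uncertainties on independent values add in quadrature (the numbers here are made up):

import numpy as np

# Two independent measurements, each with an associated uncertainty
a, sigma_a = 100.0, 10.0  # cm
b, sigma_b = 40.0, 5.0    # cm

# The uncertainty in the sum is the quadrature sum of the individual uncertainties
total = a + b
sigma_total = np.sqrt(sigma_a**2 + sigma_b**2)

print(f"{total:.0f}±{sigma_total:.0f}cm")  # 140±11cm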

Uncertainty in machine learning

As machine learning practitioners, we build a lot of powerful decision-making systems. However, I rarely see people calculating or communicating uncertainties in this field. This is a huge missed opportunity!

Measurements vs uncertainties

Consider a binary classifier, which produces values between 0 and 1.

A number line from 0 to 1, with a series of points clustered at both ends, and a sparsely populated centre

Naively, we could say that all observations on the left-hand side are 0s, and all observations on the right-hand side are 1s. We could also say that anything close to those polar values is a confident prediction, while anything close to the boundary at 0.5 is a more uncertain prediction.

A common route to improving the accuracy of the classifier would be to use active learning. By gathering examples which the model is most uncertain about and asking a human labeller (or oracle) for the correct value, we can efficiently guide the model in the right direction.
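
In code, that naive sampling scheme might look something like this sketch, assuming a hypothetical predictions array of raw classifier outputs and a budget of k examples to send to the oracle:

import numpy as np

def select_for_labelling(predictions: np.ndarray, k: int = 10) -> np.ndarray:
    # Naive uncertainty sampling: treat distance from the 0.5 boundary as confidence
    distance_from_boundary = np.abs(predictions - 0.5)
    # The k examples closest to the boundary are assumed to be the most uncertain
    return np.argsort(distance_from_boundary)[:k]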

A rough gaussian distribution representing uncertainty in the model's predictions, with a small amount of noise added.
This is what we assume the model's confidence in its predictions looks like when using a naive active learning scheme: the confidence is directly related to the output value.

While it might be true that some values fall at the classification boundary between 0 and 1 because the model is uncertain about which class the observation belongs to, it's not true that the model's confidence in its predictions is directly related to the value it outputs.

There will invariably be cases where the model is very confident in a prediction that sits between the two classes, and cases where it's very uncertain in a prediction that sits at the poles.

Pay attention to the observations at the boundary in the graphic below. The variance in uncertainty at the boundary is much higher than in the previous diagram. While some of the observations near the boundary also have a high associated uncertainty, the uncertainty for some others is just as low as it is for the observations at the poles. This class of high-confidence, non-polar values provides a very interesting set of data for further exploration.

A more nuanced representation of uncertainty in model predictions, where confidence in each prediction is more independent from the observed value.

In my experience, the illustration above is much closer to the reality than the one before it.

By taking a more sophisticated approach to uncertainty, and capturing this orthogonal representation of measurement and uncertainty, we can improve sampling for active learning and learn much more ourselves about how our model is performing.
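
With a separate uncertainty estimate attached to each prediction (for example, a standard deviation produced by one of the methods below), the sampling step is driven by the model's actual confidence rather than by distance from the boundary, and we can also pull out that interesting class of confident-but-non-polar observations. A rough sketch, with hypothetical predictions and uncertainties arrays and arbitrary thresholds:

import numpy as np

def split_by_confidence(predictions: np.ndarray, uncertainties: np.ndarray, k: int = 10):
    # The k most uncertain predictions: natural candidates for the oracle
    most_uncertain = np.argsort(uncertainties)[-k:]
    # Non-polar predictions that the model is nonetheless sure about: worth inspecting by hand
    near_boundary = np.abs(predictions - 0.5) < 0.2
    confident = uncertainties < np.median(uncertainties)
    confident_but_ambiguous = np.where(near_boundary & confident)[0]
    return most_uncertain, confident_but_ambiguous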

Methods for uncertainty estimation in neural networks

Despite being under-discussed, such methods do exist!

Laplace approximation and Monte Carlo dropout are two common methods for uncertainty estimation in neural networks. While each has its strengths, the two methods hinge on different principles.

Let's train a very simple MNIST classification model, with a couple of linear layers running over a flattened array of input pixels for each image.

import torch 
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import datasets, transforms

# Load the MNIST dataset
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

# Define the model
model = nn.Sequential(
    nn.Flatten(),  # accepts both (N, 1, 28, 28) image batches and pre-flattened (N, 784) inputs
    nn.Linear(784, 256),
    nn.Dropout(0.2),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.Dropout(0.2),
    nn.ReLU(),
    nn.Linear(64, 10),
)

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

# Train the model
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

n_epochs = 5
losses = []
eval_losses = []
eval_accuracies = []
for epoch in range(n_epochs):
    model.train()
    progress_bar = tqdm(train_loader, total=len(train_loader))
    for imgs, labels in progress_bar:
        imgs = imgs.view(imgs.shape[0], -1)
        optimizer.zero_grad()
        output = model(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        rolling_loss = torch.tensor(losses[-100:]).mean()
        progress_bar.set_description(f'Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}')

    model.eval()
    progress_bar = tqdm(test_loader, total=len(test_loader))
    # Evaluate on the test set without tracking gradients
    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs = imgs.view(imgs.shape[0], -1)
            output = model(imgs)
            loss = loss_fn(output, labels)
            acc = accuracy(output, labels)
            eval_losses.append(loss.item())
            eval_accuracies.append(acc.item())
            rolling_loss = torch.tensor(eval_losses[-100:]).mean()
            rolling_acc = torch.tensor(eval_accuracies[-100:]).mean()
            progress_bar.set_description(f'Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}, acc: {rolling_acc.item():.4f}')

Laplace approximation

Laplace approximation is a deterministic technique based on a second-order Taylor expansion around the mode of the posterior. It approximates the posterior distribution over a neural network's weights, under the assumption that the posterior can be modelled with a Gaussian distribution.
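
In its usual formulation (not specific to any one library), this means approximating the posterior over the network's weights θ as

p(θ | D) ≈ N(θ_MAP, H⁻¹)

where θ_MAP is the maximum a posteriori estimate of the weights and H is the Hessian of the negative log posterior evaluated at θ_MAP.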

Advantages

  • Computationally Efficient: Laplace approximation avoids the computational costs typically involved in full Bayesian approximations.
  • Better at handling simple models: This approach works well when the likelihood and the prior are approximately Gaussian, which is often the case in simpler models.

Implementation

Luckily for us, the laplace package provides a really easy way to get started with Laplace approximation in PyTorch.

from laplace import Laplace

la = Laplace(model, "classification", subset_of_weights="all", hessian_structure="kron")
la.fit(train_loader)
la.optimize_prior_precision(method="marglik")

Now that it's fit to the data, we can ask the model to produce samples from the distribution for each prediction, and find the mean and standard deviation of the underlying distribution.

input_data = test_dataset.data.view(-1, 784).float()[:100] / 255  # flatten and scale to [0, 1], matching the training transform
samples = la.predictive_samples(input_data, n_samples=100)

mean = samples.mean(dim=0)
std = samples.std(dim=0)
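
To inspect a single prediction, we can print the per-class mean and standard deviation (indexing the tensors above by image first, then by class):

# Per-class mean ± standard deviation for the first test image
for i, (m, s) in enumerate(zip(mean[0], std[0])):
    print(f"Class {i}: {m.item():.2f} ± {s.item():.2f}")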

For a confident prediction, those values might look like

Class 0: 0.00 ± 0.00
Class 1: 0.00 ± 0.00
Class 2: 0.00 ± 0.00
Class 3: 0.00 ± 0.00
Class 4: 0.00 ± 0.00
Class 5: 0.00 ± 0.00
Class 6: 0.00 ± 0.00
Class 7: 1.00 ± 0.01
Class 8: 0.00 ± 0.00
Class 9: 0.00 ± 0.00

While for a less confident prediction we might see

Class 0: 0.00 ± 0.00
Class 1: 0.00 ± 0.00
Class 2: 0.01 ± 0.10
Class 3: 0.00 ± 0.00
Class 4: 0.20 ± 0.40
Class 5: 0.33 ± 0.47
Class 6: 0.37 ± 0.48
Class 7: 0.00 ± 0.00
Class 8: 0.06 ± 0.24
Class 9: 0.03 ± 0.17

Monte Carlo dropout

Monte Carlo dropout is a Bayesian approximation technique. By using dropout at inference time, and taking multiple stochastic forward passes through the network, we can measure the variation in the outputs across our different passes, providing an estimate of uncertainty.

Advantages

  • Simplicity: It demands no modifications to the existing architecture (assuming dropout is already being used for regularisation during training), making it relatively straightforward to implement.
  • Works well with complex distributions: Unlike the Laplace approximation, MC Dropout can capture complicated, non-Gaussian posterior distributions! This makes it a better fit for more complex models.

Implementation

import numpy as np

X = test_dataset.data.float().view(-1, 784) / 255

# Make sure the model is in training mode, so that dropout is used
model.train()
y_mc = torch.stack([model(X) for _ in range(100)])

y_mean = y_mc.mean(dim=0).detach().numpy()
y_std = y_mc.std(dim=0).detach().numpy()

We can use the standard deviations of the predictions to calculate the entropy, giving us a sense of the model's confidence in each prediction.

entropy = -(y_std * np.log(y_std)).sum(axis=1)

n_examples = 5  # how many examples to inspect; any small number will do
most_confident_indices = entropy.argsort()[:n_examples]
least_confident_indices = entropy.argsort()[-n_examples:]
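
To produce the image grids below, we can plot the selected digits alongside their predicted classes. A sketch, assuming matplotlib is available:

import matplotlib.pyplot as plt

def show_examples(indices, title):
    fig, axes = plt.subplots(1, len(indices), figsize=(2 * len(indices), 2))
    for ax, idx in zip(axes, indices):
        ax.imshow(test_dataset.data[idx], cmap="gray")  # the raw 28x28 digit
        ax.set_title(f"pred: {y_mean[idx].argmax()}")   # the class with the highest mean output
        ax.axis("off")
    fig.suptitle(title)
    plt.show()

show_examples(most_confident_indices, "Most confident predictions")
show_examples(least_confident_indices, "Least confident predictions")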

The model's most confident predictions

The model's least confident predictions

The model's most confident predictions are big, heavy 0s, and they're all correct, while its less confident predictions are all much more lightweight lines, in occasionally odd orientations within the frame. As we'd expect, its predictions are often wrong in cases where its uncertainty is higher, though there are some low-confidence but correct predictions!

We can use our entropy data to find high confidence predictions which were incorrect.
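
A sketch of one way to do that, reusing the arrays from above and the test set labels:

# Compare the predicted class (highest mean output) against the true label
predicted_classes = y_mean.argmax(axis=1)
true_labels = test_dataset.targets.numpy()
incorrect = np.where(predicted_classes != true_labels)[0]

# Of the incorrect predictions, keep those with the lowest entropy (i.e. the highest confidence)
most_confident_incorrect = incorrect[entropy[incorrect].argsort()[:n_examples]]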

The model's most confident incorrect predictions

As you can imagine, this is an exceptionally useful technique when debugging predictions in neural networks, or feeding a network in an active learning process!

Other Methods for Uncertainty Estimation

  • Bootstrapping: Similar to Monte Carlo dropout, but applied to the dataset. By creating multiple versions of the original dataset through random sampling and training separate models on each, we can use the variance in predictions as a measure of uncertainty.
  • Deep ensembles: A higher level version of Monte Carlo dropout - deep ensembles involve training multiple neural network models independently on the same data and then aggregating their predictions. This approach provides a measure of uncertainty by examining the variance in predictions across the ensemble members, and has been shown to be effective in capturing epistemic uncertainty, particularly when combined with techniques like temperature scaling and Monte Carlo dropout. There's a minimal sketch of this idea after this list.
  • Gaussian Processes: These are a class of models that specify a prior directly on the space of functions. Instead of assuming a fixed function shape, Gaussian Processes allow for flexibility by considering a range of possible functions that could fit the data.
  • Stochastic Weight Averaging with Gaussians (SWAG): SWAG extends the idea of stochastic weight averaging by incorporating Gaussian approximations to the posterior distribution over the weights of neural network models.
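
As promised, here's a minimal sketch of the deep ensembles idea, reusing the architecture and data from earlier. The train_one_model function is a hypothetical stand-in for the training loop shown above:

# Train several independent copies of the same architecture
def make_model():
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256), nn.ReLU(),
        nn.Linear(256, 64), nn.ReLU(),
        nn.Linear(64, 10),
    )

ensemble = [train_one_model(make_model(), train_loader) for _ in range(5)]

# Aggregate predictions across the ensemble members
with torch.no_grad():
    preds = torch.stack([torch.softmax(member(X), dim=1) for member in ensemble])

ensemble_mean = preds.mean(dim=0)  # the ensemble's prediction
ensemble_std = preds.std(dim=0)    # disagreement between members, a proxy for epistemic uncertainty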

Conclusion

As a field, I believe we should start paying much more attention to the uncertainties in our predictions. Incorporating these techniques not only provides a measure of confidence in each prediction for the user/reader, but also offers valuable insights into how the model is performing. Regardless of the technique chosen, acknowledging and addressing uncertainty brings us one step closer to building more reliable, robust, and explainable systems.