Reasons for uncertainty in machine learning models
I've often heard the saying:
madness is doing the same thing again and again, and expecting a different result
While I know that there's a valuable lesson there, I totally disagree with the statement itself!
As any scientist knows, producing the same results by doing the same thing again and again is surprisingly hard. In practice, I think it's more true to say:
madness is doing the same thing again and again, and expecting the same result
Understanding uncertainty
Consider a simple physics experiment, where the tester releases a ball down a slope, and measures how far the ball travels after meeting the flat ground. Despite keeping every variable constant (the angle of the incline, the weight of the ball, the release position, the smoothness of the track), there will inevitably be some variation in the final resting point of the ball after each run of the experiment. Unmeasurable hidden variables (imperfections in the roundness of the ball, or minute changes in air pressure) can all introduce a degree of stochasticity or randomness into the mechanics of the experiment, and therefore its results.
If we run the experiment again and again, the results will typically form a normal (or Gaussian) distribution, with a measurable mean and standard deviation. When describing the results to other people, we can give them the mean as the measured value (μ), and use the standard deviation (σ) to provide a range of uncertainty around that value.
In this version of the ball-rolling experiment, the mean final position is 100cm, with a standard deviation of 10cm. We could communicate the full complexity of these results to someone else by writing 100±10cm.
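We can simulate this experiment with a quick sketch (a minimal example, with the noise level chosen to match the hypothetical figures above):

import numpy as np

# Simulate 1000 runs of the ball-rolling experiment, with the hidden
# variables modelled as Gaussian noise around the true resting point
rng = np.random.default_rng(42)
final_positions = rng.normal(loc=100, scale=10, size=1000)

print(f"{final_positions.mean():.0f} ± {final_positions.std():.0f} cm")  # ≈ 100 ± 10 cm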
A Galton board is another great demonstration of this principle. Small, random events can perturb the individual experimental runs, sending some of the balls way off into the left or right bins. But the overall shape of the results will approximate something with a neat, easily modelled distribution. Modelling these distributions themselves can tell us a lot about the underlying systems.
Side note
Francis Galton, after whom the Galton board is named, was the original eugenicist. Despite contributing lots of useful, influential ideas to statistics, his work was largely motivated by scientific racism. As scientists, we should be very careful when applying simple principles from statistics to systems as complicated as human biology and psychology. Check out Subhadra Das's speaking and writing for a deeper dive into Galton's skewed ideas than I can provide here.
Back to the main maths
Even in the simplest systems, it's valuable to provide every measurement with its associated uncertainty. It tells the reader how confident they can be in each value they're working with. By combining and compounding the values' uncertainties in any subsequent calculations, they can make informed decisions about when to stop trusting their own estimations.
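For example, when adding two independent measurements together, their uncertainties combine in quadrature, i.e. as the square root of the sum of their squares. A quick worked example:

import numpy as np

# Adding two independent measurements: 100 ± 10 cm and 50 ± 5 cm
total = 100 + 50
total_uncertainty = np.sqrt(10**2 + 5**2)
print(f"{total} ± {total_uncertainty:.1f} cm")  # 150 ± 11.2 cm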
In science, every measurement should be given with an associated uncertainty.
Uncertainty in machine learning
As machine learning practitioners, we build a lot of powerful decision-making systems. However, I rarely see people calculating or communicating uncertainties in this field. This is a huge missed opportunity!
Measurements vs uncertainties
Consider a binary classifier, which produces values between 0 and 1.
Naively, we could say that all observations on the left hand side are 0s, and all observations on the right hand side are 1s. We could also say that anything close to those polar values is a confident prediction, while anything close to the boundary at 0.5 is a more uncertain prediction.
A common route to improving the accuracy of the classifier would be to use active learning. By gathering examples which the model is most uncertain about and asking a human labeller (or oracle) for the correct value, we can efficiently guide the model in the right direction.
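As a minimal sketch of that sampling step (using hypothetical classifier scores), the naive approach queries the unlabelled examples whose outputs sit closest to the 0.5 boundary:

import numpy as np

# Hypothetical scores from a binary classifier over an unlabelled pool
rng = np.random.default_rng(0)
scores = rng.random(1000)

# Naive uncertainty sampling: query the examples nearest the decision boundary
query_indices = np.abs(scores - 0.5).argsort()[:10]

As the following paragraphs show, though, treating distance from the boundary as the model's uncertainty is an oversimplification.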
While it might be true that some values fall at the classification boundary between 0 and 1 because the model is uncertain about which class the observation belongs to, it's not true that the model's confidence in its predictions is directly related to the value it outputs.
There will invariably be cases where the model is very confident in a prediction that sits between the two classes, and cases where it's very uncertain in a prediction that sits at the poles.
Pay attention to the observations at the boundary in the graphic below. The variance in uncertainty at the boundary is much higher than in the previous diagram. While some of the observations near the boundary also have a high associated uncertainty, the uncertainty for some others is just as low as it is for the observations at the poles. This class of high-confidence, non-polar values provides a very interesting set of data for further exploration.
In my experience, the illustration above is much closer to the reality than the one before it.
By taking a more sophisticated approach to uncertainty, and capturing this orthogonal representation of measurement and uncertainty, we can improve sampling for active learning and learn much more ourselves about how our model is performing.
Methods for uncertainty estimation in neural networks
Despite being under-discussed, such methods do exist!
Laplace approximation and Monte Carlo dropout are two common methods used in uncertainty estimation in neural networks. While they each have their strengths, the methods hinge on different principles.
Let's train a very simple MNIST classification model, with a couple of linear layers running over a flattened array of input pixels for each image.
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import datasets, transforms
# Load the MNIST dataset
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())
# Define the model
model = nn.Sequential(
nn.Linear(784, 256),
nn.Dropout(0.2),
nn.ReLU(),
nn.Linear(256, 64),
nn.Dropout(0.2),
nn.ReLU(),
nn.Linear(64, 10),
)
def accuracy(outputs, labels):
_, preds = torch.max(outputs, dim=1)
return torch.tensor(torch.sum(preds == labels).item() / len(preds))
# Train the model
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
n_epochs = 5
losses = []
eval_losses = []
eval_accuracies = []
for epoch in range(n_epochs):
model.train()
progress_bar = tqdm(train_loader, total=len(train_loader))
for imgs, labels in progress_bar:
imgs = imgs.view(imgs.shape[0], -1)
optimizer.zero_grad()
output = model(imgs)
loss = loss_fn(output, labels)
loss.backward()
optimizer.step()
losses.append(loss.item())
rolling_loss = torch.tensor(losses[-100:]).mean()
progress_bar.set_description(f'Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}')
    model.eval()
    progress_bar = tqdm(test_loader, total=len(test_loader))
    # Disable gradient tracking during evaluation
    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs = imgs.view(imgs.shape[0], -1)
            output = model(imgs)
            loss = loss_fn(output, labels)
            acc = accuracy(output, labels)
            eval_losses.append(loss.item())
            eval_accuracies.append(acc.item())
            rolling_loss = torch.tensor(eval_losses[-100:]).mean()
            rolling_acc = torch.tensor(eval_accuracies[-100:]).mean()
            progress_bar.set_description(f'Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}, acc: {rolling_acc.item():.4f}')
Laplace approximation
The Laplace approximation is a deterministic technique based on a second-order Taylor expansion around the mode of the posterior. It approximates the posterior distribution over a neural network's weights, under the assumption that the uncertainty can be modelled with a Gaussian distribution.
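In symbols: if θ_MAP is the mode of the posterior and H is the Hessian of the negative log posterior evaluated at that mode, the approximation replaces the true posterior over the weights θ with p(θ | D) ≈ N(θ_MAP, H⁻¹). The curvature at the mode sets the width of the Gaussian: flat directions in the loss landscape translate into high uncertainty in the corresponding weights.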
Advantages
- Computationally efficient: The Laplace approximation avoids most of the computational cost involved in full Bayesian inference.
- Well suited to simple models: This approach works well when the likelihood and the prior are approximately Gaussian, which is often the case in simpler models.
Implementation
Luckily for us, the laplace package provides a really easy way to get started with Laplace approximation in PyTorch.
from laplace import Laplace
la = Laplace(model, "classification", subset_of_weights="all", hessian_structure="kron")
la.fit(train_loader)
la.optimize_prior_precision(method="marglik")
Now that it's fit to the data, we can ask the model to produce samples from the distribution for each prediction, and find the mean and standard deviation of the underlying distribution.
# Scale the raw pixel values to [0, 1], matching the ToTensor transform used in training
input_data = test_dataset.data.view(-1, 784).float()[0:100] / 255
samples = la.predictive_samples(input_data, n_samples=100)
mean = samples.mean(dim=0)
std = samples.std(dim=0)
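To display these as per-class μ ± σ values for a single image, a loop like this works (a minimal sketch, looking at the first example in the batch):

for c in range(10):
    print(f"Class {c}: {mean[0, c]:.2f} ± {std[0, c]:.2f}")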
For a confident prediction, those values might look like
Class 0: 0.00 ± 0.00
Class 1: 0.00 ± 0.00
Class 2: 0.00 ± 0.00
Class 3: 0.00 ± 0.00
Class 4: 0.00 ± 0.00
Class 5: 0.00 ± 0.00
Class 6: 0.00 ± 0.00
Class 7: 1.00 ± 0.01
Class 8: 0.00 ± 0.00
Class 9: 0.00 ± 0.00
While for a less confident prediction we might see
Class 0: 0.00 ± 0.00
Class 1: 0.00 ± 0.00
Class 2: 0.01 ± 0.10
Class 3: 0.00 ± 0.00
Class 4: 0.20 ± 0.40
Class 5: 0.33 ± 0.47
Class 6: 0.37 ± 0.48
Class 7: 0.00 ± 0.00
Class 8: 0.06 ± 0.24
Class 9: 0.03 ± 0.17
Monte Carlo dropout
Monte Carlo dropout is a Bayesian approximation technique. By using dropout at inference time, and taking multiple stochastic forward passes through the network, we can measure the variation in the outputs across our different passes, providing an estimate of uncertainty.
Advantages
- Simplicity: It demands no modifications to the existing architecture (assuming dropout is already being used for regularisation during training), making it relatively straightforward to implement.
- Works well with complex distributions: Unlike the Laplace approximation, MC Dropout can capture complicated, non-Gaussian posterior distributions! This makes it a better fit for more complex models.
Implementation
import numpy as np
X = test_dataset.data.float().view(-1, 784) / 255
# Make sure the model is in training mode, so that dropout is used
model.train()
# Take 100 stochastic forward passes through the network, with dropout active,
# applying softmax to turn the raw logits into class probabilities
with torch.no_grad():
    y_mc = torch.stack([torch.softmax(model(X), dim=1) for _ in range(100)])
y_mean = y_mc.mean(dim=0).numpy()
y_std = y_mc.std(dim=0).numpy()
We can use the standard deviations of the predictions to calculate the entropy, giving us a sense of their confidence.
n_examples = 10  # how many examples to inspect at each extreme
# A small epsilon keeps the log finite when a standard deviation is zero
entropy = -(y_std * np.log(y_std + 1e-12)).sum(axis=1)
most_confident_indices = entropy.argsort()[:n_examples]
least_confident_indices = entropy.argsort()[-n_examples:]
The model's most confident predictions are big, heavy 0s, and they're all correct, while its less confident predictions are all much more lightweight lines, in occasionally odd orientations within the frame. As we'd expect, its predictions are often wrong in cases where its uncertainty is higher, though there are some low-confidence but correct predictions!
We can use our entropy data to find high confidence predictions which were incorrect.
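A rough sketch of that lookup, reusing y_mean, entropy, and n_examples from above:

preds = y_mean.argmax(axis=1)
labels = test_dataset.targets.numpy()

# Among the misclassified examples, find those with the lowest entropy,
# i.e. the model's most confident mistakes
wrong_indices = np.where(preds != labels)[0]
confident_mistakes = wrong_indices[entropy[wrong_indices].argsort()[:n_examples]]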
As you can imagine, this is an exceptionally useful technique when debugging a neural network's predictions, or when selecting examples to label in an active learning process!
Other Methods for Uncertainty Estimation
- Bootstrapping: Similar to Monte Carlo dropout, but applied to the dataset. By creating multiple versions of the original dataset through random sampling and training separate models on each, we can use the variance in predictions as a measure of uncertainty.
- Deep ensembles: A higher level version of Monte Carlo dropout - deep ensembles involve training multiple neural network models independently on the same data and then aggregating their predictions (see the sketch after this list). This approach provides a measure of uncertainty by examining the variance in predictions across the ensemble members. It has been shown to be effective in capturing epistemic uncertainty, particularly when combined with techniques like temperature scaling and Monte Carlo dropout.
- Gaussian Processes: These are a class of models that specify a prior directly on the space of functions. Instead of assuming a fixed function shape, Gaussian Processes allow for flexibility by considering a range of possible functions that could fit the data.
- Stochastic Weight Averaging with Gaussians (SWAG): SWAG extends the idea of stochastic weight averaging by incorporating Gaussian approximations to the posterior distribution over the weights of neural network models.
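To make the deep ensembles idea concrete, here's a minimal sketch reusing the architecture and data from earlier. make_model is a hypothetical helper returning a freshly initialised copy of the network; in practice each member would be trained with the same loop as above before being used for inference:

def make_model():
    return nn.Sequential(
        nn.Linear(784, 256), nn.Dropout(0.2), nn.ReLU(),
        nn.Linear(256, 64), nn.Dropout(0.2), nn.ReLU(),
        nn.Linear(64, 10),
    )

# Train each member independently (omitted here), then measure the spread
# in their predictions: disagreement between members indicates uncertainty
ensemble = [make_model() for _ in range(5)]

with torch.no_grad():
    for m in ensemble:
        m.eval()
    preds = torch.stack([torch.softmax(m(X), dim=1) for m in ensemble])

ensemble_mean = preds.mean(dim=0)
ensemble_std = preds.std(dim=0)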
Conclusion
As a field, I believe we should start paying much more attention to the uncertainties in our predictions. Incorporating these techniques not only provides a measure of confidence in each prediction for the user/reader, but also offers valuable insights into how the model is performing. Regardless of the technique chosen, acknowledging and addressing uncertainty brings us one step closer to building more reliable, robust, and explainable systems.