Understanding positional embeddings in transformer models

For state-of-the-art language models, context is important.

Polysemy, homography, and other complexities of real-world language mean that a static definition of a word is often not enough to understand its meaning; the ability to interpret a word according to its context is crucial for any high-performing language model.

For example, the meaning of the word "get" changes radically between the phrases "I'll get a book" (where get means fetch), "I'll get smarter" (where get means become), and "I get it" (where get means understand).

The meaning of words and phrases isn't just affected by the bag of words around them, but also by where they are positioned in relation to each other. For example, the individual words in these sentences are the same, but modifying their order changes the meaning of the sentence dramatically:

"She told him that she only loved him."

"She only told him that she loved him."

"Only she told him that she loved him."

Without the additional information about the position of each token, a model is limited to a bag-of-words representation of the target text. It's possible to learn something about a sentence from the words alone, but understanding all of the subtleties of language with a bag-of-words model is extremely difficult.

Encoding context in prior language models

In the generation of network architectures before transformers, language models like recurrent neural networks (RNNs) and long short-term memory networks (LSTMs) used a state-machine approach to encode positional context into each token. These models stored a representation of the previous tokens' aggregate meaning, and represented each word as a combination of a fixed word embedding and that context, updating the context vector(s) as they progressed through the text. Bidirectional variants of these models allowed context to be passed to a token from both before and after its position in the text.
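Roughly speaking, the recurrent idea looks like this (a heavily simplified sketch with made-up weight names, and none of the gating that LSTMs add):

import numpy as np

def rnn_encode(token_embeddings, W_hidden, W_input):
    # A single hidden state carries the aggregate context of everything read
    # so far, so each token's representation implicitly depends on its
    # position - but the tokens must be processed strictly in order.
    hidden = np.zeros(W_hidden.shape[0])
    contextual_states = []
    for x in token_embeddings:
        hidden = np.tanh(W_hidden @ hidden + W_input @ x)
        contextual_states.append(hidden)
    return contextual_states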

However, this state-machine approach is tricky to parallelise, and the length of the context window is limited as a result. State-machine approaches also exhibit a strong bias towards recent tokens and struggle to retain information over a long distance, even within the available context window. This meant that while state-of-the-art models were able to generate locally coherent text, the meaning of longer passages would quickly fall apart.

Transformers

The mechanism at the heart of transformer models like BERT and GPT is intended to solve the parallelism problem, and increase models' coherence over longer sequences.

Instead of trying to hold and update the meaning of a sequence for each token, these models split the representation of words and their position into two distinct inputs. In addition to providing a word embedding for each token, we also provide a positional embedding, which represents the position of each token in the sequence. Transformers join those inputs together with an attention matrix, which describes how dependent or related each token is to its neighbours. Removing the sequential dependence and making the dependence of each token on its neighbours explicit makes the linear algebra significantly easier to parallelise, and allows transformers to scale much further than LSTMs or RNNs.
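For intuition, here's a bare-bones numpy sketch of the attention operation at the heart of that mechanism (a simplification - real implementations add learned projections, masking, and multiple heads):

def scaled_dot_product_attention(Q, K, V):
    # Q, K and V are (sequence_length, d) matrices derived from the token inputs.
    # Every token is compared with every other token in a single matrix
    # multiplication - there is no sequential loop to get in the way of parallelism.
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # how related each token is to each other token
    scores -= scores.max(axis=-1, keepdims=True)  # for numerical stability in the softmax
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return weights @ V                            # each output row is a weighted mix of the values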

In this context, word embeddings work the same way as they have in many generations of language models, so I won't characterise them in much more detail in this post. Positional embeddings are a newer idea though, and warrant some more elaboration and intuition-building.

What should positional embeddings do?

Unlike word embeddings, the characteristics of positional embeddings aren't emergent from data - they can be deliberately constructed to have certain properties. Concisely, positional embeddings should:

  • Be reflective of their distances from one another in the sequence - we should be able to use a distance metric to compare vectors, with positionally similar vectors producing small distances, and positionally distant vectors producing large distances in the vector space we construct.
  • Have bounded values - we're using these as inputs to a neural network, so values constrained to a small, fixed range are useful.
  • Be invariant to sequence length - the values assigned to each position should be the same, whether the text is long or short.
  • Be deterministic - they shouldn't change from one initialisation of a network to another.

Our goal is to design a process which delivers embeddings with these characteristics. It's worth pausing at this point to imagine a few candidate approaches yourself, and checking whether they fulfil the criteria above.

Some initial ideas

As a super-simple first stab, let's just fill a vector with the position's value, and call it a positional embedding. Position 1 gets an array of 1s, position 2 gets an array of 2s, etc:

import numpy as np

def positional_embedding(position, dimensionality):
    # Fill every element of the embedding with the raw position value
    return np.full(
        shape=dimensionality,
        fill_value=position
    )

We can certainly find the distances between positional embeddings generated this way; all we need to do is subtract the vectors from one another to determine their relative positions. However, the values aren't bounded - they grow without limit as the position increases.
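For example, with a dimensionality of 4, subtracting two embeddings recovers the distance between their positions (though the values themselves can grow arbitrarily large):

positional_embedding(7, 4) - positional_embedding(2, 4)
>>> array([5, 5, 5, 5])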

Let's try dividing the values by the length of the sequence to produce each token's fractional position within it.

def positional_embedding(position, dimensionality, sequence_length):
    # Fill every element with the token's fractional position in the sequence
    return np.full(
        shape=dimensionality,
        fill_value=position / sequence_length
    )

Hm, still no good. Dividing the values by the sequence length gives us a set of values which are easily comparable and bounded between 0 and 1, but by their nature, they're not invariant to sequence length. The first token in a sequence of 30 words will have a different set of values to the first token in a sequence of only 5 words, and there's no way to know the lengths of the sequences we'll be modelling ahead of time. Ideally, we'd like this to scale to sequences of any length.
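To see the problem concretely, here's the same position embedded against two different sequence lengths:

positional_embedding(1, 4, sequence_length=5)
>>> array([0.2, 0.2, 0.2, 0.2])

positional_embedding(1, 4, sequence_length=30)
>>> array([0.03333333, 0.03333333, 0.03333333, 0.03333333])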

A trigonometric trick

The approach in the original transformers paper (“Attention is all you need”) is really fun and smart, and it fulfils all of the criteria we defined.

First, we specify a dimensionality for the embedding. Each dimension of our embeddings is assigned a corresponding sinusoidal wave, each with a different frequency - in the paper's formulation, the earliest dimensions are tied to the highest-frequency waves, and the frequencies decrease as we move towards the end of the embedding. We also alternate between sine and cosine waves, depending on the element's parity. To get a positional embedding, we plug our position value into that series of waves, and read off the resulting values.

The code looks something like this:

def positional_embedding(position: int, d_model: int = 1024):
    # Each pair of dimensions shares a frequency, which decreases as i grows
    i = np.arange(d_model)
    angles = position / np.power(10000, (2 * (i // 2) / d_model))
    # Even-indexed elements are read from a sine wave, odd-indexed from a cosine
    angles[0::2] = np.sin(angles[0::2])
    angles[1::2] = np.cos(angles[1::2])
    return angles

Why it works

This is the confusing bit - why does that process produce embeddings which fulfil our criteria?

First, using sine and cosine functions means that our values are bounded between -1 and 1, no matter how long the sequence gets. We can also see that the process is deterministic, and doesn't depend on any stochastic/external parameters. But what about the other criteria? How do these rules produce an embedding space with values which can neatly represent a continuous number line?

Let's illustrate the simplest possible example of the method. Here are two cosine waves, one with a low frequency, and one with a high frequency. Let's imagine that each one is tied to a dimension in a two-dimensional positional embedding.

Two overlapping cosine waves, one with a higher frequency than the other

Now let's select two nearby positions, and mark them in our space with a pair of vertical lines.

Two vertical lines have been added to the two cosine waves, close to one another

We can use our position values (aka vertical lines) to get the first elements in our positional embeddings. The height at which our low-frequency curve meets our vertical lines is what we'll use as the first element in our embedding. I've marked those heights with horizontal lines on the graph below.

The points where the vertical lines meet the low-frequency cosine wave are marked by a pair of corresponding horizontal lines, which are also close to one another

The resulting values are very close together, because we're at a slowly changing point on our low-frequency curve, with a very shallow gradient.

We can do the same thing with our second, higher-frequency curve to get the second elements of our 2D embeddings.

The points where the vertical lines meet the high-frequency wave have also been marked by two horizontal lines, and the vertical distance between them is noticeably larger than the distance between the two horizontal lines corresponding to the low-frequency wave

We're at a point on the curve with a much steeper gradient, which means that the values are much further apart! The higher-frequency wave also changes gradient much more rapidly, so we can expect more variability in the values we read from it - and at even higher frequencies, this effect would be stronger still.

Now let's consider a new line, representing a more distant position:

A new vertical line has been added, further to the right. It intersects the low- and high-frequency waves at low points, in contrast to the previous examples

Both values here are very different to our first two positions. Next, try to visualise the values for a position close to our new, third line. You should find that they're much more similar to the values for our third position than they are to our first two positions.
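To put rough numbers on this toy example, here's a minimal sketch with one low-frequency and one high-frequency cosine wave (the frequencies are arbitrary, chosen purely for illustration rather than taken from the formula above):

def toy_embedding(position):
    return np.array([
        np.cos(position / 20),  # low-frequency wave: changes slowly with position
        np.cos(position),       # high-frequency wave: changes quickly with position
    ])

toy_embedding(1)   # ≈ [0.999, 0.540]
toy_embedding(2)   # ≈ [0.995, -0.416] - close to position 1 in the slow dimension
toy_embedding(30)  # ≈ [0.071, 0.154] - very different in both dimensions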

Now, imagine extending our example to include many more curves, with many different frequencies, representing a positional embedding with a much higher dimensionality. As we accumulate more values, correlations between similar positions (whose values are close at slowly-changing points on our curves) will become more apparent, as will the lack of correlation between positions which are positionally further apart (which are unlikely to share slowly-changing regions of curves).

With increasing dimensionality, these correlated and uncorrelated dimensions start resembling a pretty neat vector space. We can use the cosine distance between points to get a measure of the distance between them, ticking off the remaining criterion!

It's worth considering what might happen if we were to use sine or cosine waves alone, and the strange results that might give us. At low frequencies, we would see regions where neighbouring waves would all be at their steepest, with the corresponding values changing a lot, and other regions where their values would all be very similar and hard to distinguish. By alternating between sine and cosine curves across the elements of our embedding, we ensure that in regions where the sine curves are changing a lot, the cosines won't be, and vice versa, thus stabilising our embeddings.
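A quick way to see that complementarity: near a peak of a sine wave, the values barely move, but a cosine wave at the same frequency is at its steepest there.

x = np.pi / 2                 # a peak of the sine wave
np.sin(x + 0.1) - np.sin(x)   # ≈ -0.005: the sine values barely change here
np.cos(x + 0.1) - np.cos(x)   # ≈ -0.0998: the cosine values change about twenty times as much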

Examining distances between positional embeddings

Given a large enough embedding, we can use the cosine distance to measure the distance between positions in our vector space.

from scipy.spatial.distance import cosine

cosine(positional_embedding(1), positional_embedding(2))
>>> 0.026488616022189992

Positions which are further apart should also generate larger distances between their embeddings:

cosine(positional_embedding(1), positional_embedding(3))
>>> 0.09339161307513

cosine(positional_embedding(1), positional_embedding(30))
>>> 0.4323030365719962

We should also see that positions which are far away from position 1 but close to each other produce similarly small distances from one another:

cosine(positional_embedding(30), positional_embedding(31))
>>> 0.02648861602218988

We can also visualise the distribution of these distances. Here are all of the positions up to 250, plotted against their distance from position 1 in embedding space.

A scatter plot of position against distance in embedding space from position 1. The distances increase monotonically with position in a roughly logarithmic shape.
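A plot like this can be produced with something along these lines (a sketch assuming matplotlib):

import matplotlib.pyplot as plt

# Cosine distance from position 1 to each of the first 250 positions
positions = np.arange(1, 251)
distances = [cosine(positional_embedding(1), positional_embedding(p)) for p in positions]

plt.scatter(positions, distances, s=5)
plt.xlabel("position")
plt.ylabel("cosine distance from position 1")
plt.show()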

Here's the distance-matrix version of that chart, showing the distances from all of the first 250 positions to one another. The data in the first row matches the values we've scattered on the plot above.

The heatmap shows values approaching 0 along the main diagonal of the matrix, increasing with the same shape as the indices get further from the diagonal.
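The matrix itself can be computed and displayed with something like this (again a sketch, using scipy's pairwise distances and matplotlib):

from scipy.spatial.distance import pdist, squareform

# Pairwise cosine distances between the first 250 positional embeddings
first_250 = np.array([positional_embedding(p) for p in range(1, 251)])
distance_matrix = squareform(pdist(first_250, metric="cosine"))

plt.imshow(distance_matrix)
plt.colorbar(label="cosine distance")
plt.show()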

We can also squeeze our positional embeddings down to two dimensions, so that we can plot all of them on a 2D plane with their relative distances approximately preserved.

from umap import UMAP

# Project the first 1000 positional embeddings down to two dimensions,
# using cosine distance to compare them
embeddings = np.array([positional_embedding(i) for i in range(1000)])
embedder = UMAP(n_components=2, metric="cosine")
embedder.fit(embeddings)
embeddings_2d = embedder.transform(embeddings)
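We can then scatter those 2D projections, colouring each point by its original position (another matplotlib sketch):

plt.scatter(
    embeddings_2d[:, 0],
    embeddings_2d[:, 1],
    c=np.arange(1000),  # colour follows the position in the sequence
    s=5,
)
plt.colorbar(label="position")
plt.show()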

We'd expect them to fall in a continuous line, with neighbouring positions closest to one another. Here they are, coloured according to their position:

a wiggling continuous line, with a smooth colour gradient along its length

Just what we expected!

Combining positional and word embeddings

The original paper uses the same dimensionality for both word and positional embeddings, and uses their sum as the input to the network. To be honest, I don't understand that decision. I would expect that concatenating the embeddings would inject more coherent information into the network. If you're interested in that experiment, let me know - I might write a follow-up post with the results!
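For reference, the two options look like this (a rough sketch, with a random vector standing in for a learned word embedding):

d_model = 1024
word_embedding = np.random.rand(d_model)               # stand-in for a learned word embedding
position_embedding = positional_embedding(1, d_model)

combined = word_embedding + position_embedding                       # shape (1024,) - the paper's approach
concatenated = np.concatenate([word_embedding, position_embedding])  # shape (2048,) - the alternative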

Positional embeddings in practice

Many networks use the positional embedding scheme described above, while some models like the RoFormer expand on the idea, more effectively leveraging the information that positional embeddings encode.

The authors of “Attention is all you need” note that they also experimented with fully learned positional embeddings, as well as the sinusoidally-generated ones we've explored here, with both approaches producing similar results. Many transformer implementations choose to forgo the sinusoidal scheme, using a standard learned embedding layer instead.
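In PyTorch, for example, that learned alternative might look something like this (the sizes here are purely illustrative):

import torch
from torch import nn

# One trainable vector per position, up to a fixed maximum sequence length
learned_positions = nn.Embedding(num_embeddings=512, embedding_dim=1024)

positions = torch.arange(10)                        # positions 0..9
position_embeddings = learned_positions(positions)  # shape (10, 1024), learned during training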

Personally, I still love the idea of an embedding layer which efficiently encodes information without having to rely on learning, and I'll be looking for areas in my own work where I can make use of those neat trigonometric tricks too!

All of the code used to generate the results and visualisations in this post is openly licensed and available in the corresponding GitHub repo. If you got something from this post, please give it a star! And if I've made any silly mistakes, open an issue and let me know! 😅