MLOps Community

Explainable AI: Visualizing Attention in Transformers

# Explainable AI
# Machine learning
# MLops

And logging the results in an experiment-tracking tool

Photo by Jeffery Ho on Unsplash, edited by the author

August 14, 2023
Abby Morgan

In this article, we explore one of the most popular tools for visualizing the core distinguishing feature of transformer architectures: the attention mechanism. Keep reading to learn more about BertViz and how you can incorporate this attention visualization tool into your NLP and MLOps workflow with Comet.

Feel free to follow along with the full-code tutorial here, or, if you can’t wait, check out the final project here.

Introduction

Transformers have been described as the single most important technological development to NLP in recent years, but their processes remain largely opaque. This is a problem because, as we continue to make major machine learning advancements, we can’t always explain how or why– which can lead to issues like undetected model bias, model collapse, and other ethical and reproducibility issues. Especially as models are more frequently deployed to sensitive areas like healthcare, law, finance, and security, model explainability is critical.

Horizontal bar chart showing gender and race projections for different professions, as predicted by a Word2Vec model (pre-transformer)

What is BertViz?

BertViz is an open-source tool that visualizes the attention mechanism of transformer models at multiple scales, including model level, attention head level, and neuron level. But BertViz isn’t new. In fact, early versions of BertViz have been around since as early as 2017.

So, why are we still talking about BertViz?

BertViz is an explainability tool in a field (NLP) that is otherwise notoriously opaque. And, despite its name, BertViz doesn’t only work on BERT. The BertViz API supports many transformer language models, including the GPT family of models, T5, and most HuggingFace models.

BertViz visualization in the Comet UI for two different types of transformer models: an encoder-only distilbert transformer for question-answering and a decoder-only gpt-2 transformer for text generation

As transformer architectures have increasingly dominated the machine learning landscape in recent years, they’ve also revived an old but important debate regarding interpretability and transparency in AI. So, while BertViz may not be new, its application as an explainability tool in the AI space is more relevant now than ever.

But first, transformers

To explain BertViz, it helps to have a basic understanding of transformers and self-attention. If you’re already familiar with these concepts, feel free to skip ahead to the section where we start coding.

We won’t go into the nitty gritty details of transformers here, as that’s a little beyond the scope of this article, but we will cover some of the basics. I also encourage you to check out the additional resources at the end of the article.

In the beginning (the prehistoric era of NLP)

So, how, exactly, does a computer “learn” natural language? In short, it can’t, at least not directly. Computers can only process numerical data, so the first step of NLP is to break sentences down into “tokens,” which are assigned numerical values. The question driving NLP then becomes “How can we accurately reduce language and communication processes to computations?”
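To make the tokenization step concrete, here is a minimal sketch that maps words to integer IDs. It is a toy: the whitespace split, the `[UNK]` placeholder, and the function names `build_vocab` and `encode` are assumptions made for illustration. Real transformer tokenizers use learned subword vocabularies instead.

```python
# Toy whitespace tokenizer. Real models use learned subword vocabularies
# (for example WordPiece or byte-pair encoding), not a simple split.
def build_vocab(sentences):
    """Assign each distinct lowercase word an integer ID, in order of first appearance."""
    vocab = {"[UNK]": 0}  # reserve ID 0 for words never seen while building the vocab
    for sentence in sentences:
        for word in sentence.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(sentence, vocab):
    """Map a sentence to a list of token IDs, falling back to [UNK] for unknown words."""
    return [vocab.get(word, vocab["[UNK]"]) for word in sentence.lower().split()]


vocab = build_vocab(["the animal crossed the street"])
print(encode("the animal crossed the street", vocab))  # [1, 2, 3, 1, 4]
print(encode("the dog crossed the street", vocab))     # [1, 0, 3, 1, 4]
```

Both occurrences of “the” map to the same ID, and the unseen word “dog” falls back to the unknown-token ID.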

Some of the first NLP models included feed-forward neural networks like the Multi-Layer Perceptron (MLP) and even CNNs, which are more popularly used today for computer vision. These models worked for some simple classification tasks (like sentiment analysis) but had a major drawback: their feed-forward nature meant that at each point in time, the network only saw one word as its input. Imagine trying to predict the word that follows “the” in a sentence. How many possibilities are there?

A visualization of the difficulty of next-word prediction for sequence models that don’t “remember” any context

To solve this problem, Recurrent Neural Networks (RNNs) and Long Short-Term Memory networks (LSTMs), which were often used in Seq2Seq architectures, allowed for feedback or cycles. This meant that each computation was informed by the previous computation, allowing for more context.

This context was still limited, however. If the input sequence was very long, the model would tend to forget the beginning of the sequence by the time it got to the end. Also, their sequential nature didn’t allow for parallelization, making them inefficient to train. RNNs also suffered notoriously from vanishing and exploding gradients.
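The recurrence and its “forgetting” can be sketched with a one-neuron RNN. This is an illustration under stated assumptions, not a trained model: the weights `w_x` and `w_h` are arbitrary values chosen for the example, and the function name is hypothetical.

```python
import math


def rnn_hidden_states(inputs, w_x=0.5, w_h=0.5):
    """Scalar RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1}), starting from h_0 = 0."""
    h = 0.0
    states = []
    for x in inputs:  # each step needs the previous hidden state, so steps cannot run in parallel
        h = math.tanh(w_x * x + w_h * h)
        states.append(h)
    return states


# Two sequences that differ only in their first element.
a = rnn_hidden_states([1.0] + [0.0] * 20)
b = rnn_hidden_states([-1.0] + [0.0] * 20)

print(abs(a[0] - b[0]))    # large: the first step clearly reflects the difference
print(abs(a[-1] - b[-1]))  # tiny: after 20 more steps the difference has almost faded
```

The loop also shows why this design resists parallelization: step t cannot start until step t-1 has finished.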

Introducing transformers

Transformers are sequence models that abandon the sequential structure of RNNs and LSTMs and adopt a fully attention-based approach. Transformers were initially developed for text processing, and are central to virtually all state-of-the-art NLP neural networks today, but they can also be used with image, video, audio, or virtually any other sequential data.

The key differentiating feature of transformers from previous NLP models was the attention mechanism, as popularized in the Attention Is All You Need paper. This allowed for parallelization, which meant faster training and optimized performance. Attention also allowed for much larger contexts than recurrence, meaning transformers could craft more coherent, relevant, and complex outputs.

The original transformer architecture, as visualized in the 2017 paper that made them famous, Attention Is All You Need.

Transformers are made up of encoders and decoders, and the tasks we can perform with them depend on whether we use either or both of these components. Some common transformer tasks for NLP include text classification, named entity recognition, question-answering, text summarization, fill-in-the-blanks, next-word prediction, translation, and text generation.

Chart showing the three different types of transformers: encoder-only, decoder-only, and encoder-decoder models. Chart also lists tasks specific to each type of transformer, as well as examples and alternative names.

How do transformers fit into the larger ecosystem of NLP models?

You’ve probably heard of Large Language Models (LLMs) like ChatGPT or LLaMA. The transformer architecture is a fundamental building block of LLMs, which use self-supervised learning on vast amounts of unlabelled data. These models are also sometimes referred to as “foundation models” because they tend to generalize well to a wide range of tasks, and in some cases are also available for more specific fine-tuning. BERT is an example of this category of model.

A graphic showing the relationship between transformer architectures, foundation models, and large language models. Graphic includes (as examples): ViT, BLOOM, BERT, Falcon, LLaMA, ChatGPT, and SAM (Segment Anything Model).

That’s a lot of information, but the important takeaway here is that the key differentiating feature of the transformer model (and by extension, all transformer-based foundational LLMs) is the concept of self-attention, which we’ll go over next.

Attention

Generally speaking, attention describes the ability of a model to pay attention to the important parts of a sentence (or image, or any other sequential input). It does this by assigning weights to input features based on their importance and their position in the sequence.

Remember that attention was the concept that let transformers improve on the performance of previous NLP models (like RNNs and LSTMs) by lending itself to parallelization. But attention isn’t just about optimization. It also plays a pivotal role in broadening the context a language model is able to consider while processing and generating language. This enables a model to produce contextually appropriate and coherent texts in much longer sequences.

A graphic showing the BertViz representation of the sentence “the animal didn’t cross the street because it was too scared.” The last word, scared, was predicted by the GPT-2 model. The graphic shows that GPT-2 correlates “it” to the animal.

If we break transformers down into a “communication” phase and a “computation” phase, attention would represent the “communication” phase. In another analogy, attention is a lot like a search-retrieval problem, where given a query, q, we want to find the set of keys, k, most similar to q, and return the corresponding values, v.

  1. Query: What are the things I am looking for?
  2. Key: What are the things that I have?
  3. Value: What are the things that I will communicate?

A visualization of how to calculate attention for transformers
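The query/key/value retrieval described above corresponds to the scaled dot-product attention from Attention Is All You Need, softmax(QKᵀ / √d_k)V. Below is a minimal pure-Python sketch of that formula. The vectors are made-up toy numbers, not values from any model, and the function names are chosen for this example.

```python
import math


def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(keys[0])
    outputs, weights = [], []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        w = softmax(scores)
        weights.append(w)
        # Output is the attention-weighted average of the value vectors.
        outputs.append(
            [sum(wj * v[i] for wj, v in zip(w, values)) for i in range(len(values[0]))]
        )
    return outputs, weights


# Toy 2-dimensional vectors for three tokens (made-up numbers, not from a model).
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
queries = [[4.0, 0.0]]  # "looks for" tokens whose key points along the first axis

outputs, weights = attention(queries, keys, values)
print([round(w, 3) for w in weights[0]])
```

The first and third keys match the query equally well, so they receive equal weight, while the second key, which is orthogonal to the query, receives the least.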

Types of attention

Self-attention refers to the fact that every node produces a key, query, and value from that individual node. Multi-headed attention is just self-attention that is applied multiple times in parallel with differently initialized weights. Cross-attention means that the queries are still produced from a given decoder node, but the keys and the values are produced as a function of the nodes in the encoder.
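A rough sketch of how these three variants relate: the same head computes self-attention or cross-attention depending only on where the keys and values come from, and multi-head attention runs several independently initialized heads side by side. Everything here is an assumption made for illustration (random weights, tiny dimensions, the `AttentionHead` class name), and the output projection that real transformers apply after concatenating heads is left out.

```python
import math
import random


def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, vec)) for row in matrix]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


class AttentionHead:
    """One head with its own query, key, and value projection matrices."""

    def __init__(self, d_model, d_head, rng):
        self.w_q = random_matrix(d_head, d_model, rng)
        self.w_k = random_matrix(d_head, d_model, rng)
        self.w_v = random_matrix(d_head, d_model, rng)

    def __call__(self, query_nodes, memory_nodes):
        """Queries come from query_nodes; keys and values come from memory_nodes."""
        keys = [matvec(self.w_k, n) for n in memory_nodes]
        values = [matvec(self.w_v, n) for n in memory_nodes]
        scale = math.sqrt(len(self.w_k))  # sqrt of the head dimension
        out = []
        for node in query_nodes:
            q = matvec(self.w_q, node)
            weights = softmax([sum(a * b for a, b in zip(q, k)) / scale for k in keys])
            out.append([sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(q))])
        return out


rng = random.Random(0)
d_model, d_head, n_heads = 4, 2, 2
heads = [AttentionHead(d_model, d_head, rng) for _ in range(n_heads)]

decoder_nodes = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(3)]
encoder_nodes = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(5)]

# Self-attention: queries, keys, and values all come from the same nodes.
self_out = [head(decoder_nodes, decoder_nodes) for head in heads]
# Cross-attention: queries from the decoder, keys and values from the encoder.
cross_out = [head(decoder_nodes, encoder_nodes) for head in heads]
# Multi-head: concatenate each node's per-head outputs.
multi_head = [sum((head_out[i] for head_out in self_out), []) for i in range(3)]
print(len(multi_head), len(multi_head[0]))  # 3 4
```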

This is an oversimplified summary of transformer architectures, and we’ve glossed over quite a few details (like positional encodings, segment encodings, and attention masks). For more information, check out the additional resources below.

Visualizing attention before BertViz

Transformers are not inherently interpretable, but there have been many attempts to contribute post-hoc explainability tools to attention-based models.

Previous attempts to visualize attention were often overly complicated and didn’t translate well to non-technical audiences. They could also vary greatly from project to project and use-case to use-case.

A compilation of some very confusing and complicated previous attempts to visualize attention

Some successful attempts to explain attention behavior included attention-matrix heat maps and bi-partite graph representations, both of which are still used today. But these methods also have some major limitations.

A graphic showing some methods of visualizing transformer attention other than BertViz

BertViz ultimately gained popularity for its ability to illustrate low-level, granular details of self-attention, while still remaining remarkably simple and intuitive to use.

GIF of BertViz Attention Head View, selecting transformer layer and attention format type, and selecting specific attention heads, as visualized in Comet ML

That’s a nice, clean visualization. But, what are we actually looking at?

How BertViz Breaks It All Down

BertViz visualizes the attention mechanism at multiple local scales: the neuron level, attention head level, and model level. Below we break down what that means, starting from the lowest, most granular level, and making our way up.

A graphic showing the model view, attention head view, and neuron view of a transformer model using BertViz

Visualizing BertViz With Comet

We’ll log our BertViz plots to Comet, an experiment tracking tool, so we can compare our results later on. To get started with Comet, create a free account here, grab your API key, and run the following code:

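import comet_ml

# Authenticate with the API key from your Comet account settings
comet_ml.init(api_key='<YOUR-API-KEY>')
# Start a new experiment to log the visualizations to
experiment = comet_ml.Experiment()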

Visualizing attention in Comet will help us interpret our models’ decisions by showing how they attend to different parts of the input. In this tutorial, we’ll use these visualizations to compare and dissect the performance of several pre-trained LLMs. But these visualizations can also be used during fine-tuning for debugging purposes.

To add BertViz to your dashboard, navigate to Comet’s public panels and select either ‘Transformers Model Viewer’ or ‘Transformers Attention Head Viewer.’

GIF showing how to add transformer model view of BertViz visualization to Comet UI dashboard.

We’ll define some functions to parse our models’ results and log the attention information to Comet. See the Colab tutorial to get the full code used. Then, we’ll run the following commands to start logging our data to Comet:

Text generation example

textgen_model_version = "gpt2"
text_gen_prompts = [
    "The animal didn't cross the street because it was too",
    "The dog didn't play at the park because it was too",
    "I went to the store. At the store I bought fresh",
    "At the store he bought flowers, candy, jewelry, and",
    "The dog ran up the street and barked too",
    "In 2016, the Young Mens' Christian Association (YMCA) was very",
    "The Doctor asked the Nurse a question. She",
    "The Doctor asked the Nurse a question. He",
]
# text_generation_viz is one of the helper functions defined in the Colab tutorial
text_generation_viz(
    text_gen_prompts,
    textgen_model_version,
)

Question-answering example

context = r"""A robot may not injure a human being or, through inaction, allow a human being to come to harm.
A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
A robot must protect its own existence as long as such protection does not conflict with the First or Second Laws.
"""
questions = [
    "Can a robot hurt a human?",
    "Can a robot injure a human?",
    "Should a robot obey orders from humans?",
    "Can a robot protect itself from a human?",
    "Can a robot love a human?",
]
qa_viz(context, questions, "distilbert-base-uncased-distilled-squad")

Sentiment analysis example

sa_prompts = [
    "Many people dislike Steve Jobs, while acknowledging his genius.",
    "The quick, brown fox jumps over the lazy dog.",
    "It was a beautiful day.",
    "It was a horrible day.",
    "I am confused.",
    "That movie was so sick but I wish it was longer.",
    "That movie was so awesome but I wish it was longer.",
    "That movie was so gross but I wish it was longer.",
    "That movie was so available but I wish it was longer.",
]
sa_model_version = "distilbert-base-uncased-finetuned-sst-2-english"
sentiment_viz(sa_prompts, sa_model_version)
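The helper functions called above (`text_generation_viz`, `qa_viz`, `sentiment_viz`) live in the Colab tutorial and are not reproduced here. As a rough idea of the kind of work such a helper does, the sketch below follows the pattern from the BertViz README: load a Hugging Face model with `output_attentions=True`, render the head view as HTML, and log it with Comet’s `Experiment.log_html`. The function names `head_view_html` and `log_head_view` are hypothetical, and this is not the tutorial’s implementation. In particular, the Comet panels mentioned earlier may expect the attention data in a different asset format than plain HTML.

```python
def head_view_html(model_name, text):
    """Return BertViz head-view HTML for `text` (hypothetical helper, not the tutorial's code)."""
    # Imported lazily so the function can be defined without the libraries installed.
    from bertviz import head_view
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)

    inputs = tokenizer.encode(text, return_tensors="pt")
    outputs = model(inputs)
    attention = outputs.attentions  # one tensor per layer: (batch, heads, seq_len, seq_len)
    tokens = tokenizer.convert_ids_to_tokens(inputs[0])

    # html_action="return" asks BertViz for an HTML object instead of rendering inline.
    return head_view(attention, tokens, html_action="return").data


def log_head_view(experiment, model_name, text):
    """Log the rendered head view to a Comet experiment as HTML (assumed approach)."""
    experiment.log_html(head_view_html(model_name, text))
```

With the `experiment` object created earlier, a call might look like `log_head_view(experiment, "gpt2", text_gen_prompts[0])`. Running it downloads the model weights, so it needs network access and the `transformers`, `torch`, and `bertviz` packages.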

Neuron View

At the lowest level, BertViz visualizes the query, key, and value embeddings used to compute attention in a neuron. Given a selected token, this view traces the computation of attention from that token to the other tokens in the sequence.

In the GIF below, positive values are colored blue, and negative values are colored orange, with color intensity reflecting the magnitude of the value. Connecting lines are weighted based on the attention score between respective words.

A short GIF demonstrating how to use BertViz to visualize the computations on a neuron-level of the attention layer for our transformer experiment in Comet ML.
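The computation the neuron view traces can be sketched in a few lines: the selected token’s query is multiplied element-wise with each key (one product per neuron), the products are summed into a score, and a softmax over the scores gives the attention weights. The vectors below are toy values chosen for illustration, and the function name is hypothetical. The sketch does not try to reproduce exactly which intermediate values BertViz displays.

```python
import math


def neuron_view_trace(query, keys):
    """Trace attention from one selected token to every token in the sequence."""
    d = len(query)
    # Element-wise products: one value per neuron, per key.
    products = [[q * k for q, k in zip(query, key)] for key in keys]
    # Summing each row gives the dot product, scaled here by sqrt(d).
    scores = [sum(row) / math.sqrt(d) for row in products]
    # Softmax across tokens turns the scores into attention weights.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return products, scores, weights


query = [1.0, -2.0, 0.5]                     # toy query vector for the selected token
keys = [[1.0, -1.0, 0.0], [-1.0, 1.0, 2.0]]  # toy key vectors for two tokens

products, scores, weights = neuron_view_trace(query, keys)
print(products)  # [[1.0, 2.0, 0.0], [-1.0, -2.0, 1.0]]
print(weights)   # the first token gets most of the attention
```

Positive products (blue in the GIF) push a token’s score up and negative products (orange) push it down, which is why the first key, with only non-negative products, ends up with the larger weight.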

Whereas the views in the following two sections will show what attention patterns the model learns, this neuron view shows how those patterns are learned. The neuron view is a bit more granular than we need to get for this particular tutorial, but for a deeper dive, we could use this view to link neurons to specific attention patterns and, more generally, to model behavior.

It’s important to note that it isn’t entirely clear what relationships exist between attention weights and model outputs. Some, like Jain and Wallace in Attention Is Not Explanation, claim that standard attention modules should not be treated as though they provide meaningful explanations for predictions. They propose no alternatives, however, and BertViz remains one of the most popular attention visualization tools today.

Head View

The attention-head view shows how attention flows between tokens within the same transformer layer by uncovering patterns between attention heads. In this view, the tokens on the left are attending to the tokens on the right and attention is represented as a line connecting each token pair. Colors correspond to attention heads and line thickness represents the attention weight.

In the drop-down menu, we can select the experiment we’d like to visualize, and if we logged more than one asset to our experiment, we can also select our asset. We can then choose which attention layer we’d like to visualize and, optionally, we can choose any combination of attention heads we’d like to see. Note that the color intensity of the lines connecting tokens corresponds to the attention weights between tokens.

BertViz interactive visualization, as plotted within the Comet UI. Select experiment, asset, transformer model layer, and attention format.

We can also specify how we’d like our tokens to be formatted. For the question-answering example below, we’ll select “Sentence A → Sentence B” so we can examine the attention between the question and answer:

A BertViz visualization of attention with different sentence structure comparisons

Attention head patterns

Attention heads do not share parameters, so each head learns a unique attention mechanism. In the graphic below, attention heads are examined across layers of the same model given one input. We can see that different attention heads seem to focus on quite distinct patterns.

On the top left, attention is strongest between identical words (note the crossover where the two instances of “the” intersect). In the top center, there’s a focus on the next word in the sentence. On the top right and bottom left, the attention heads are focusing on each of the delimiters ([SEP] and [CLS], respectively). The bottom center places emphasis on the comma. And the bottom right is almost a bag-of-words pattern.

BertViz shows that transformer attention captures various patterns in language, including positional patterns, delimiter patterns, and bag-of-words.

Attention heads also capture lexical patterns. In the following graphic, we can see examples of attention heads that focus on list items (left), verbs (center), and acronyms (right).

BertViz shows transformer attention heads capture lexical patterns like list items, verbs, and acronyms.

Attention head biases

One application of the head view is detecting model bias. If we provide our model (in this case GPT-2) with two inputs that are identical except for the final pronouns, we get very different generated outputs:

BertViz can help capture model bias in transformer attention mechanisms

The model is assuming that “he” refers to the doctor, and “she” to the nurse, which might suggest that the co-reference mechanism is encoding gender bias. We would hope that by identifying a source of bias, we can potentially work to counteract it (perhaps with additional training data).

Model View

The model view is a bird’s-eye perspective of attention across all layers and heads. Here we may notice attention patterns across layers, illustrating the evolution of attention patterns from input to output. Each row of figures represents an attention layer and each column represents individual attention heads. To enlarge the figure for any particular head, we can simply click on it. Note that you can find the same line pattern in the model view as in the head view.

A GIF showing how to enlarge the attention head view in the Comet UI using the model view.
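Scanning the layer-by-head grid can also be done programmatically. The sketch below assumes the attention weights are available as a nested list indexed `attention[layer][head][i][j]`, where `i` is the attending token and `j` the attended token. Hugging Face models return one tensor per layer with shape (batch, heads, seq_len, seq_len), so a real pipeline would need to convert first. The matrices are hand-built toy data and the function names are hypothetical.

```python
def offset_score(matrix, offset):
    """Average attention that token i pays to token i + offset (where that token exists)."""
    n = len(matrix)
    pairs = [(i, i + offset) for i in range(n) if 0 <= i + offset < n]
    return sum(matrix[i][j] for i, j in pairs) / len(pairs)


def best_head(attention, offset):
    """Scan the layer-by-head grid and return (score, layer, head) for the strongest match."""
    return max(
        (offset_score(matrix, offset), layer, head)
        for layer, heads in enumerate(attention)
        for head, matrix in enumerate(heads)
    )


# Toy grid: attention[layer][head] is a 3x3 matrix whose rows sum to 1.
uniform = [[1 / 3, 1 / 3, 1 / 3]] * 3
to_self = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
to_prev = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 0.9, 0.1]]
to_first = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
attention = [[uniform, to_self], [to_prev, to_first]]

# Which head attends most strongly to the previous token?
print(best_head(attention, offset=-1))  # (0.9, 1, 0)
```

A score like this only flags candidate heads to inspect in the model view. It does not explain why a head behaves that way.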

Model view applications

So, how might we use the model view? Firstly, because each layer is initialized with separate, independent weights, the layers that focus on specific patterns for one sentence may focus on different patterns for another sentence. So we can’t necessarily look at the same attention heads for the same patterns across experiment runs. With the model view, we can more generally identify which layers may be focusing on areas of interest for a given sentence. Note that this is a very inexact science and, as many have mentioned, “if you look for it, you will find it.” Nonetheless, this view does give us some interesting insight as to what the model may be focusing on.

In the image below, we use the same example from earlier in the tutorial (left). On the right is a slightly different version of the sentence. In both cases, GPT-2 generated the last word in the sentence. At first, it may seem silly to think the dog had too many plans to go to the park. But examining the attention heads shows us the model was probably referring to the “park” as “too busy.”

BertViz helps unravel how a transformer understands language.

Explainability in AI

A horizontal bar chart showing gender discrepancies in Amazon’s hiring practices

As AI becomes more advanced, model calculations can become nearly impossible to interpret, even for the engineers and researchers who create them. This can lead to a whole host of unintended consequences, including, but not limited to: perpetuation of bias and stereotypes, distrust in organizational decision-making, and even legal ramifications. Explainable Artificial Intelligence (XAI) is a set of processes used to describe a model’s expected impact and potential biases. A commitment to XAI helps:

  1. Organizations adopt a responsible approach to AI development
  2. Developers ensure a model is working as expected and meets regulatory requirements
  3. Researchers characterize accuracy, fairness, and transparency for decision-making
  4. Organizations build trust and confidence

So how can practitioners incorporate XAI practices into their workflows, when the most popular ML architectures today– transformers– are notoriously opaque? The answer to this question isn’t simple, and explainability must be approached from many different angles. But we hope this tutorial gives you one more tool in your XAI toolbox by helping you visualize attention in transformers.

Conclusion

Thanks for making it all the way to the end, and we hope you enjoyed this article. Feel free to connect with us on our Community Slack channel with any questions, comments, or suggestions!

Additional Resources

  1. The Illustrated Transformer
  2. The Illustrated BERT
  3. Deconstructing BERT part 1
  4. Deconstructing BERT part 2
  5. A Multi-scale Visualization of Attention in the Transformer Model
  6. BertViz: A Tool For Visualizing Multi-Head Self-Attention in the BERT Model
  7. Stanford’s CS25: Introduction to Transformers with Andrej Karpathy
  8. Stanford’s CS25: Transformers in Language with Mark Chen
  9. DeepLearning AI’s Natural Language Processing Specialization
  10. Natural Language Processing with PyTorch by Delip Rao, Brian McMahan

This story is published on Comet. Connect with us on LinkedIn to get the latest AI stories and insights right in your feed. Let’s shape the future of AI together!