Thinking Like Transformers

Gail Weiss, Yoav Goldberg, Eran Yahav

Introduction/Motivation

Transformer models are famously opaque, made up of millions—sometimes billions—of uninterpretable parameters. It's natural to ask if there is a better representation of these parameters. Thinking Like Transformers argues that sometimes the answer is "yes," and that the better representation takes the form of a programming language called RASP, short for Restricted Access Sequence Processing Language.

RASP is a simple language designed to be “compiled” to transformer weights. Like transformers, RASP programs take sequences as input and produce different sequences as output. Again like transformers, the structure of the RASP language is three-fold:

First, every program starts with tokens (the input to a program, of the kind you might see in a context window) and indices (a sequence of integers between 0 and n-1, where n is the length of tokens). By analogy to the tokens and positional embeddings of a transformer, the tokens and indices of a RASP program combine to give sequential structure to user input.

The next kind of RASP program component is the element-wise operation, analogous to the MLP components of a trained transformer. These will be familiar: we can use + to add two sequences; pow(., .) to exponentiate each element in a sequence; <, >, <=, >=, and == to express predicates over two elements; and so on.

Most importantly, we also have an attention mechanism analogue: the select and aggregate operations. Although a fuller description can be found below, at a high level, select combines two sequences with a predicate, generating a matrix containing the result of applying the predicate to every possible combination of sequence elements. aggregate, subsequently, will combine that resulting matrix with a third sequence to produce a new sequence. If you squint, you can see how the key and query vectors of the transformer's attention mechanism are treated similarly to the inputs to select, and how the value vectors of the transformer are analogous to the sequence input of aggregate.

What can you actually do with RASP? A few things. First, having constructed a RASP program, you have also established an upper bound on the number of layers and attention heads needed to express your algorithm. Second, and more speculatively, you might think you've improved your intuitive sense of what a trained model might be doing internally when it performs the task expressed by your program. (The authors find that a trained model will sometimes, but not always, embody the algorithm implied by the equivalent RASP program.) Finally, RASP can be used as a "laboratory for interpretability." [1] When trained transformer models are interpreted, it is often unclear when or how full success is achieved. Using RASP programs as interpretation targets can serve as a ground truth for interpretability methods. [2]

Restricted Access Sequence Processing (RASP)


Analogues to components in a transformer block

Caption: Components of a Transformer block and their RASP analogues.

Sequence Operators (S-Ops)

S-Ops in RASP are operations that take a sequence as input and generate a sequence of the same length as output. The functions of the MLP and attention blocks in a transformer are emulated as S-Ops in RASP. RASP also provides two built-in S-Ops:

  tokens("hello") = [h, e, l, l, o]
  indices("hello") = [0, 1, 2, 3, 4]


Element-wise Operators

These operations are analogous to the element-wise transformations performed by an MLP.

An example of such an operation:

  3 * indices("hello") = [0, 3, 6, 9, 12]
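The built-in S-Ops and element-wise operators above can be mimicked in a few lines of plain Python. This is only a sketch for building intuition, not the RASP implementation: the helper name `elementwise` is our own, and sequences are modelled as ordinary lists.

```python
def tokens(text):
    """Built-in S-Op: the input as a list of characters."""
    return list(text)


def indices(text):
    """Built-in S-Op: the positions 0 .. n-1 of the input."""
    return list(range(len(text)))


def elementwise(fn, *seqs):
    """Hypothetical helper: apply fn position by position.

    Each output position only sees the values at its own position,
    which is what makes this analogous to an MLP sublayer.
    """
    return [fn(*vals) for vals in zip(*seqs)]


idx = indices("hello")
tripled = elementwise(lambda i: 3 * i, idx)
print(tripled)  # [0, 3, 6, 9, 12]

# A predicate is an element-wise operation too.
is_l = elementwise(lambda t: t == "l", tokens("hello"))
print(is_l)  # [False, False, True, True, False]
```

Note that no element-wise operation can move information between positions; that is the job of select and aggregate.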


Select and Aggregate Operations

RASP emulates the mechanism of a single attention head with a combination of select() and aggregate() operations.

select() takes two sequences, the key and the query, as input and generates a binary mask from them. In RASP, this binary mask is called a selector.

    select( key , query , binary op )



The selector is a square matrix of binary values that is analogous to the attention pattern of a single attention head.
The value in cell (i, j) of the selector is decided as:   selector(i, j) = query(i) op key(j)
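The rule for cell (i, j) translates almost directly into code. The sketch below is a plain-Python stand-in, assuming the argument order select(key, query, op) shown above and storing the selector as a list of 0/1 rows, one row per query position.

```python
import operator


def select(key, query, op):
    """Build a selector: selector[i][j] = op(query[i], key[j]) as 0 or 1."""
    return [[int(op(q, k)) for k in key] for q in query]


toks = list("hello")

# Each position selects every position holding the same token.
same_token = select(toks, toks, operator.eq)
for row in same_token:
    print(row)
# [1, 0, 0, 0, 0]
# [0, 1, 0, 0, 0]
# [0, 0, 1, 1, 0]
# [0, 0, 1, 1, 0]
# [0, 0, 0, 0, 1]
```

Row i of the result plays the role of the attention pattern of query position i.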

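To illustrate this, select( key = [1, 2, 3, 4, 5] , query = [0, 1, 2, 3, 4] , >= ) generates a 5 x 5 selector in which cell (i, j) is 1 exactly when query(i) >= key(j). Here that means i >= j + 1, so the selector is 1 strictly below the diagonal and 0 everywhere else.

In RASP, select() is always followed by an aggregate() operation, which takes a selector matrix and a value sequence as input: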

    aggregate( selector , value )


aggregate() then generates a sequence of the same length by aggregating the entries of the value sequence according to the mixture defined by each row of the selector matrix.
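To make the select - aggregate pair concrete, here is a minimal plain-Python sketch applied to the selector example above. The `reduce` parameter is our own addition so that both mean and sum aggregation can be compared, and returning 0 for a row that selects nothing is an assumption of this sketch.

```python
import operator


def select(key, query, op):
    """selector[i][j] = op(query[i], key[j]) as 0 or 1."""
    return [[int(op(q, k)) for k in key] for q in query]


def aggregate(selector, value, reduce="mean"):
    """Combine the selected values for every query position (row)."""
    out = []
    for row in selector:
        picked = [v for flag, v in zip(row, value) if flag]
        if reduce == "sum":
            out.append(sum(picked))
        else:
            # Assumption: an empty selection aggregates to 0.
            out.append(sum(picked) / len(picked) if picked else 0.0)
    return out


selector = select([1, 2, 3, 4, 5], [0, 1, 2, 3, 4], operator.ge)
value = [10, 20, 30, 40, 50]

print(aggregate(selector, value, reduce="sum"))  # [0, 10, 30, 60, 100]
print(aggregate(selector, value))  # [0.0, 10.0, 15.0, 20.0, 25.0]
```

With this selector, each position aggregates the values at all strictly earlier positions, so the sum variant is a shifted prefix sum.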


From here on, we will borrow the following concise diagram from RASPy [3] to illustrate how a single attention head is emulated using a pair of select - aggregate operations.


** A RASP selector is a square matrix of binary values, whereas an attention pattern has continuous values. Also, in an attention pattern, values in a row must sum to 1 due to the softmax applied on scaled attention. RASP doesn't impose this on a selector either.
** In RASP, values are aggregated as the mean. However, in this analysis we show aggregation as sum, similar to the implementation of RASPy [3].
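The gap between a binary selector and a softmax attention pattern can be illustrated numerically. In the sketch below (our own illustration, with a hypothetical `gap` parameter), selected positions receive a score `gap` higher than unselected ones; as the gap grows, the softmax weights approach the uniform weights that mean aggregation uses.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_row(selector_row, gap):
    """Attention weights when selected positions score `gap` above the rest."""
    return softmax([gap * flag for flag in selector_row])


row = [1, 1, 0, 1]
# Weights implied by mean aggregation over the binary selector row.
hard = [flag / sum(row) for flag in row]

for gap in (1.0, 5.0, 20.0):
    print(gap, [round(w, 3) for w in soft_row(row, gap)])
print("hard", [round(w, 3) for w in hard])
```

Every softmax row sums to 1 regardless of the gap, whereas a selector row may contain any number of ones, including none.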

Some Illustrative Examples


Example 1: Count the number of times a character c appears in a sequence S
Input Format: <sequence>-<c>
RASP Program:


The figure below illustrates each step for an input "abaa-a". This implementation will require 3 transformer layers. Each of the attention and MLP operations will add one set of new information to the residual streams. The answer will be written after Attn_3.
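As a rough illustration of how three attention steps can solve this task, here is one possible decomposition in plain Python. It is our own sketch and not necessarily the program from the original figure: the helper names are hypothetical, the input is assumed to contain exactly one "-", and selecting on a (token, index) pair stands in for combining two RASP selectors with a logical and.

```python
def select(key, query, op):
    """selector[i][j] = op(query[i], key[j]) as 0 or 1."""
    return [[int(op(q, k)) for k in key] for q in query]


def aggregate_sum(selector, value):
    """Sum the selected numeric values for every query position."""
    return [sum(v for flag, v in zip(row, value) if flag) for row in selector]


def aggregate_pick(selector, value):
    """Copy the first selected value into each position (None if none)."""
    out = []
    for row in selector:
        picked = [v for flag, v in zip(row, value) if flag]
        out.append(picked[0] if picked else None)
    return out


def count_char(text):
    toks = list(text)
    n = len(toks)
    idx = list(range(n))

    # Attn 1: every position looks up the index of the "-" separator.
    dash_sel = select(toks, ["-"] * n, lambda q, k: q == k)
    dash_idx = aggregate_sum(dash_sel, idx)

    # MLP 1: the character c sits right after the separator.
    c_idx = [d + 1 for d in dash_idx]

    # Attn 2: every position copies c from that index.
    c = aggregate_pick(select(idx, c_idx, lambda q, k: q == k), toks)

    # Attn 3: count the positions before the separator whose token is c.
    keys = list(zip(toks, idx))
    queries = list(zip(c, dash_idx))
    match = select(keys, queries, lambda q, k: k[0] == q[0] and k[1] < q[1])
    return aggregate_sum(match, [1] * n)


print(count_char("abaa-a"))  # [3, 3, 3, 3, 3, 3]
```

Every position ends up holding the answer, so it can be read off from any residual stream after the third attention step.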




Example 2: Find the minimum element in a sequence
Idea: Given a sequence S = [3, 2, 3, 5, 2], we will calculate 2 counts as follows -
  1. For each x_i, count the number of elements in S that are less than x_i
  2. Also, count the number of elements that appear before x_i and are equal to x_i
For the example sequence, S = [3, 2, 3, 5, 2], the counts will be

count_less = [2, 0, 2, 4, 0]
count_before_eq = [0, 0, 1, 0, 1]
And, count = count_less + count_before_eq = [2, 0, 3, 4, 1]

There will only be one zero in the count sequence, and it will be at the index where the minimum element first appears in S.
The RASP code with operations and updates at each step is illustrated below.
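The two counts can be reproduced with select and sum-style aggregation in plain Python. This is a sketch under our own assumptions rather than the program from the figure: the helper names are hypothetical, and the final lookup that copies the minimum to every position is our addition.

```python
import operator


def select(key, query, op):
    """selector[i][j] = op(query[i], key[j]) as 0 or 1."""
    return [[int(op(q, k)) for k in key] for q in query]


def aggregate_sum(selector, value):
    """Sum the selected values for every query position."""
    return [sum(v for flag, v in zip(row, value) if flag) for row in selector]


def min_counts(seq):
    n = len(seq)
    idx = list(range(n))
    ones = [1] * n

    # Count 1: elements strictly smaller than the element at each position.
    less = select(seq, seq, lambda q, k: k < q)
    count_less = aggregate_sum(less, ones)

    # Count 2: equal elements that appear earlier in the sequence.
    pairs = list(zip(seq, idx))
    before_eq = select(pairs, pairs, lambda q, k: k[0] == q[0] and k[1] < q[1])
    count_before_eq = aggregate_sum(before_eq, ones)

    count = [a + b for a, b in zip(count_less, count_before_eq)]
    return count_less, count_before_eq, count


S = [3, 2, 3, 5, 2]
count_less, count_before_eq, count = min_counts(S)
print(count_less)       # [2, 0, 2, 4, 0]
print(count_before_eq)  # [0, 0, 1, 0, 1]
print(count)            # [2, 0, 3, 4, 1]

# Assumed final step: every position attends to the single zero in count.
minimum = aggregate_sum(select(count, [0] * len(S), operator.eq), S)
print(minimum)  # [2, 2, 2, 2, 2]
```

Because exactly one position has count 0, summing over the selected position simply copies the minimum element.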




Compiling RASP into a transformer model

Once we have a RASP program that solves a certain task, it is possible to compile that RASP program into an actual transformer model. The longest path in the RASP program from input to output determines the number of transformer layers, and the number of independent select - aggregate operations performed in parallel determines the number of attention heads per layer.
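The layer and head counts can be read off a program's dependency graph. The sketch below uses a simplified model of our own, not the procedure of any particular compiler: each op is listed with a hypothetical kind and its dependencies, an aggregate needs one more layer than anything it depends on, and an element-wise op fits into the MLP of the layer that produced its inputs.

```python
from collections import Counter

# Hypothetical description of the minimum-element program:
# name -> (kind, dependencies)
MIN_PROGRAM = {
    "tokens": ("input", []),
    "indices": ("input", []),
    "count_less": ("aggregate", ["tokens"]),
    "count_before_eq": ("aggregate", ["tokens", "indices"]),
    "count": ("elementwise", ["count_less", "count_before_eq"]),
    "minimum": ("aggregate", ["count", "tokens"]),
}


def op_layers(program):
    """Assign each op the earliest layer in which it can be computed."""
    layers = {}

    def layer(name):
        if name not in layers:
            kind, deps = program[name]
            deepest = max((layer(d) for d in deps), default=0)
            # Only attention (aggregate) forces a move to a new layer.
            layers[name] = deepest + 1 if kind == "aggregate" else deepest
        return layers[name]

    for name in program:
        layer(name)
    return layers


def architecture(program):
    """Return (number of layers, attention heads per layer)."""
    layers = op_layers(program)
    heads = Counter(
        layers[name]
        for name, (kind, _) in program.items()
        if kind == "aggregate"
    )
    num_layers = max(1, max(layers.values()))
    heads_per_layer = max(heads.values(), default=1)
    return num_layers, heads_per_layer


print(architecture(MIN_PROGRAM))  # (2, 2)
```

Here the two counts are independent of each other, so they share a layer as two parallel heads, while the final lookup depends on both and needs a second layer.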

After we decide on a transformer architecture, the model can be trained with supervision from -

  1. A dataset of input-output pairs that demonstrates the task, as any supervised model is usually trained.
  2. The selector patterns of the RASP program, which the attention pattern of each head is trained to match.

Caption: (a) shows the RASP program that solves the double histogram problem. The program uses three attention operators. (b) shows the selector pattern for the 3 attention operations for an input sequence "§aaabbccdef". (c) shows the corresponding attention heatmaps for a transformer model trained to solve the task with the attention pattern specified by the selectors of the RASP program.

Two follow-up works build on this idea. Lindner et al., 2023 [1] present Tracr, a compiler that translates a RASP program into the weights of a transformer model, and Friedman et al., 2023 [2] train a constrained transformer from input-output demonstration pairs that can then be converted into a human-readable program.

Related Works

Prior to this work, expressing RNNs as automata was an active line of research, seemingly inaugurated by [9]. Certain problem formulations provided access to training examples, while others allowed access to model weights or activations [10], and still others constrained the extractor to use model outputs only [11] [12]. Similarly, using automata to characterize the expressive power of RNNs was tackled in [13], which focused on second-order RNNs, and in [14] and [15], both of which restricted their attention to “saturated” RNNs, or RNNs in which all the activation functions have been replaced with step functions.

At the time this paper was published, analogous work for transformers was more tightly scoped: [16] showed that transformers are expressive enough to approximate arbitrary continuous sequence-to-sequence functions; [17], by contrast, showed that fixed-size transformers are unable to model both periodic finite-state languages and hierarchical structure; finally, [18] showed that some transformer variants are Turing complete.

Applications


Testing Interpretability Tools

Interpretability is a budding field that seeks to understand how neural networks work by decompiling them into a human-readable algorithm or program. Weiss et al. take the opposite approach: they solve a particular problem in RASP, which only allows operations analogous to the components of a transformer. The RASP solution can then be compiled into a transformer model that is interpretable by design. These interpretable toy models may not be very useful in practice, but they can be used as test cases for tools or approaches aimed at understanding the functioning of a more intricate transformer model.

Caption: Toy models compiled to follow a known mechanism can be used as unit tests for interpretability tools.
** Figure taken from Lindner et al. 2023 (Figure 1). Tracr is their compiler to compile RASP into a transformer.

Customize Transformers for Specific Tasks

Although transformer architectures are used in many different situations, the ordering of their components is fairly consistent: attention and MLP sublayers are interleaved, with attention followed by an MLP in each transformer block, just as suggested by Vaswani et al. 2017. Press et al. 2019 found that re-ordering the MLP and attention modules can have significant implications for task performance. Specifically, they found that, provided there is some interleaving in the middle, (1) pushing MLPs earlier in the computation weakened the model's capability to learn language, while (2) pushing attention earlier resulted in a stronger language model. This insight suggests that transformers with different architectures may be better suited to learning different tasks, and RASP may help design such customized transformers and evaluate their differences. However, we will later discuss how RASP operations are not completely faithful to the components of a transformer block.


An Educational Tool

Solving different problems with RASP may help develop intuitions about how the different modules of a transformer work and what each is capable of. This may be useful as an educational tool for training researchers or industry practitioners interested in working with transformers.

Social Impact

As of December 4, 2023, this paper has only 40 citations. That is impressive considering that the paper was published only two years ago, but it has not been "very impactful" yet. Here is one scenario, admittedly very speculative, in which this paper may end up having some impact -

The impact of this paper may grow in the future, depending on developments in the field of interpretability. There might one day be an interpretability tool that claims to understand a category of phenomena in models. A simple toy task in that category may be solved with a RASP program and then compiled into a transformer model. That model may be used to evaluate the efficacy of the interpretability tool.
As artificial models get increasingly capable and get more integrated into different aspects of our daily lives, policy-makers may want to make sure that these models are safe for deployment. And, this paper may have laid the groundwork for testing a subset of the tools that will be used to ensure the safety of artificial models.

Follow-on Research

We suggest two possible avenues of further research:

When does gradient descent find a RASP program?

Strikingly, the authors observe an example of a model trained on the Histogram-BOS task ultimately matching the attention pattern of the equivalent RASP program. How often—and under which conditions—does this happen?

There are many ways to approach this, but one might involve generating random RASP programs as targets, and then training transformer architectures chosen to match the targets while varying the random seed. How often does SGD find an equivalent attention pattern? How does that vary with model size or program type? The extent to which RASP is a useful avenue of further research is closely related to its ability to mimic the inductive biases of real-world transformer models, and the hope is that this research agenda could shed light on that.

Can RASP inspire improvements to the transformer architecture?

One of the ways Weiss et al. demonstrate the value of RASP is by describing its ability to explain the findings of [5]. In that paper, the authors show that a transformer architecture which has more attention sublayers early in the model, and more MLP sublayers later, outperforms standard transformers on language modeling tasks. Weiss et al. explain this with reference to RASP:

"In RASP, there is little value to repeated elementwise operations before the first aggregate: each position has only its initial input, and cannot generate new information... In contrast, an architecture beginning with several attention sublayers—i.e., multiple select-aggregate pairs—will be able to gather a large amount of information into each position early in the computation, even if only by simple rules"

Although this kind of reasoning is prompted by existing architectures, the process could in principle run in reverse. Are there lessons to be taken from the structure of RASP programs that are relevant to designing better "transformer-esque" architectures?

An example of this process might take inspiration from the findings of [19], which identified tasks that a transformer could theoretically solve, but for which empirical training did not converge to a solution. RASP programs for those tasks could shed light on the kinds of architectures with the right inductive bias for those tasks.

About the Authors

The first author of the paper, Gail Weiss, was a PhD student at the Technion when the work was completed. Advised by the other two authors, she focuses mostly on the intersection of neural networks and formal language theory. (She is the founder, for example, of the Formal Languages and Neural Networks (FLaNN) Discord.) Today, she is a postdoc at EPFL, in Switzerland.

Yoav Goldberg (also an Israeli computer scientist, although in this case a professor at Bar Ilan University) was an early participant in NLP's embrace of neural networks. This led to a 2017 book, Neural Network Methods for Natural Language Processing, and a series of highly cited papers in the early-to-mid 2010s attempting to understand word embeddings. In addition to his academic position, he also serves as Research Director at the Israeli branch of the Allen Institute for AI. He received his PhD at Ben-Gurion University, went on to a postdoc at Google Research, and sells t-shirts on the internet, available right now.

The final author, Eran Yahav, is a professor at Technion and CTO at TabNine (an early LLM-based coding assistant). His previous life is a bit of a mystery, although he was, at some point, affiliated with Tel-Aviv University. In his current life, he has written many machine learning papers on program analysis, program synthesis, and programming languages.

Review

This paper proposes a restricted programming language, RASP, in which a user can only use operations that are roughly analogous to the workings of different modules of a transformer model. The authors argue that using RASP will allow people to "think like transformers" and provide experimental evidence with some toy problems that human intuition can even be compiled into transformer models.

The strengths of this paper have already been discussed in previous sections: RASP can be used as an educational tool to develop intuitions about how the different components of a transformer block work. It may even inform us on how to customize transformer architectures for specific problems. And there might be a future where this approach can be used to evaluate a subset of the tools used for ensuring that a transformer model is safe for deployment, although currently this is just speculation by the authors of this blog.

A weakness of this work is that RASP operations are not completely faithful to how different components of a transformer block work in practice. For example:

  1. RASP completely ignores the Layernorms (Ba et al. 2016). Layernorms are used in transformers before each of the attention and MLP operations, and also finally before decoding the latent back to vocabulary space. This ensures stability in training and improves the performance (Brody et al. 2023). Element-wise operations in RASP, which are analogous to MLP transformations, may be used to perform element-wise mathematical operations such as multiplications and divisions. But the effect of Layernorms may wash out these transformations in practice.
  2. RASP makes the assumption that information stored in intermediate variables will be stored as it is in the residual stream. In theory this may be achieved by ensuring mutually orthogonal subspaces for all the intermediate variables. However, in practice transformers only have a limited number of dimensions in the residual stream. Thus it is most likely that subsequent attention and MLP modules may rewrite or even delete some of those variables.
  3. RASP emulates the attention pattern with a binary selector, and as a result all the values attended to by a query get equal attention. In practice, however, the attention pattern has continuous values, and the attention paid to each value can be different.

Despite certain limitations, this paper can be lauded for the novelty of its main idea. Overall, the paper was well presented and easy to follow.

References

[1] Lindner, David, János Kramár, Matthew Rahtz, Thomas McGrath, and Vladimir Mikulik. "Tracr: Compiled transformers as a laboratory for interpretability." (2023)

[2] Friedman, Dan, Alexander Wettig, and Danqi Chen. "Learning Transformer Programs." (2023)

[3] Sasha Rush and Gail Weiss. "RASPy"

[4] Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." (2017)

[5] Press, Ofir, Noah A. Smith, and Omer Levy. "Improving transformer models by reordering their sublayers." (2019)

[6] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." (2016)

[7] Brody, Shaked, Uri Alon, and Eran Yahav. "On the Expressivity Role of LayerNorm in Transformers' Attention." (2023)

[8] Arthur Conmy, Augustine N. Mavor-Parker, Aengus Lynch, Stefan Heimersheim, and Adrià Garriga-Alonso. "Towards Automated Circuit Discovery for Mechanistic Interpretability." (2023)

[9] C. L. Giles, C. B. Miller, D. Chen, H. H. Chen, G. Z. Sun, and Y. C. Lee. "Learning and Extracting Finite State Automata with Second-Order Recurrent Neural Networks." (1992)

[10] Christian W. Omlin and C. Lee Giles. "Extraction of rules from discrete-time recurrent neural networks." (1996)

[11] Gail Weiss, Yoav Goldberg, and Eran Yahav. "Extracting automata from recurrent neural networks using queries and counterexamples." (2018)

[12] Stephane Ayache, Remi Eyraud, and Noe Goudian. "Explaining Black Boxes on Sequential Data using Weighted Automata." (2019)

[13] Guillaume Rabusseau, Tianyu Li, and Doina Precup. "Connecting Weighted Automata and Recurrent Neural Networks through Spectral Learning." (2019)

[14] William Merrill. "Sequential Neural Networks as Automata." (2019)

[15] William Merrill, Gail Weiss, Yoav Goldberg, Roy Schwartz, Noah A. Smith, and Eran Yahav. "A Formal Hierarchy of RNN Architectures." (2020)

[16] Chulhee Yun, Srinadh Bhojanapalli, Ankit Singh Rawat, Sashank J. Reddi, and Sanjiv Kumar. "Are Transformers universal approximators of sequence-to-sequence functions?." (2019)

[17] Michael Hahn. "Theoretical Limitations of Self-Attention in Neural Sequence Models." (2020)

[18] Jorge Pérez, Pablo Barceló, Javier Marinkovic. "Attention is Turing-Complete." (2021)

[19] Grégoire Delétang, Anian Ruoss, Jordi Grau-Moya, Tim Genewein, Li Kevin Wenliang, Elliot Catt, Chris Cundy, Marcus Hutter, Shane Legg, Joel Veness, and Pedro A. Ortega. "Neural Networks and the Chomsky Hierarchy." (2022)

Team Members