Training a Transformer to Compose One Step Per Layer (and Proving It)
I’m working on an experiment comparing the internal representations of two architectures when solving a sequential algorithm, but training models to use a sequential algorithm is surprisingly hard. The optimization landscape makes it easier for models to learn parallel algorithms or memorize lookup tables, so I needed to make some specific architectural and training decisions to get models to actually learn the sequential algorithm. Even with all of these tricks, the results are seed-dependent, and I needed to inspect the resulting models to prove that they did or didn’t learn the expected algorithm.

In this post, I’ll document what did and didn’t work, and the techniques I used to prove whether or not the model learned a sequential algorithm.

Note: Everything I tried here was very sensitive to task selection and architecture, so I think this is mostly useful as “here are some things that seem to help” and less as “this will definitely work for your problem/architecture”. Hopefully it will help someone else trying to interpret toy models to more quickly train models to do something interesting, though.

Table of Contents

- Task Selection
  - Natural Language ❌
  - Tasks With Shortcuts ❌
  - Non-commutative Algorithm ✅
- Loss
  - Full-sequence Loss ❌
  - Answer-only Loss ✅
  - Answer-only → Full-sequence Curriculum ✅
- Other Decisions
  - Data Weighting by Difficulty ✅
  - Weight Sharing ✅
- Proof
  - Logit Lens Staircase
  - Sequential Ablation
- Lessons and Limitations

Task Selection

Natural Language ❌

Frontier LLMs are able to learn multi-step reasoning by pretraining on natural text, so one obvious approach is to just train a GPT-style LLM on a dataset like FineWeb-Edu. Unfortunately, the algorithm I’m looking for requires grokking, so I can’t just train NanoGPT for an hour and have something useful. Frontier LLMs are proof that this will eventually work, but I don’t have frontier-LLM amounts of money or patience, so this option was dropped in the brainstorming phase.

Tasks With Shortcuts ❌

I searched for a sequential task that models could learn without a slow grokking phase. After some false starts with surprisingly difficult tasks, I decided to copy the binary composition task from Unveiling Transformers with LEGO (Zhang et al., 2023). The original used BERT, and I modified the task slightly to work in an autoregressive decoder-only setup:

a = +1; b = −a; e = +b; d = −f; c = +d; f = +e; f ?

Where a = +1, b = −1, e = −1, d = +1, c = +1, and f = −1.

I did confirm that the model learns this very rapidly, but unfortunately (as the LEGO paper predicts), it learns a shortcut: it’s easier to just count the number of − signs. Since the “count-the-minus-signs” algorithm can be learned in parallel, models strongly favor that solution, and in practice this task can be solved trivially with one layer.

Non-commutative Algorithm ✅

Luckily, LEGO also mentioned another option: non-commutative algorithms. The problem with the binary task is that the order of operations doesn’t matter, but there’s a whole field of math for operations that are not commutative. So, continuing to follow advice from the LEGO paper, the next task was to compute compositions of the dihedral group D3. The details of the math aren’t important; what matters is that D3 compositions are non-commutative (the order of operations matters).

This task looks like:

((r0 ⋅ r1) ⋅ s0) ⋅ s2 = r2

Where the intermediate states are: r0 ⋅ r1 = r1, r1 ⋅ s0 = s1, and s1 ⋅ s2 = r2. Note that we’re evaluating D3 compositions left-to-right to allow the model to compute intermediate states in order.

In practice, I tokenized this like:

<start>e0<op>e1<op>e2<op>e3<predict>e4

where e4 is the answer. This algorithm requires a 6×6 lookup table for each operation, but because the order matters, a composed lookup table for k steps requires 6^k memorized values. My testing was on k=6, which means the model either needs to learn a sequential algorithm with small 36-value lookup tables, or a single memorized 46,656-value lookup table.

This worked! With this task, plus the other design decisions I’ll describe in the next sections, I was able to train models to usually learn the sequential algorithm.
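To make the task concrete, here’s a minimal sketch (not my actual training code) of D3 composition and example generation. The permutation encoding and element names below are just one convenient convention, chosen so that it reproduces the worked example above:

```python
# D3 elements as permutations of the triangle's vertices, composed left-to-right.
import random

ROTATIONS   = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]  # r0 (identity), r1, r2
REFLECTIONS = [(0, 2, 1), (2, 1, 0), (1, 0, 2)]  # s0, s1, s2
ELEMENTS = ROTATIONS + REFLECTIONS
NAMES = ["r0", "r1", "r2", "s0", "s1", "s2"]

def compose(a: int, b: int) -> int:
    """Apply element a, then element b (left-to-right evaluation)."""
    pa, pb = ELEMENTS[a], ELEMENTS[b]
    return ELEMENTS.index(tuple(pb[pa[i]] for i in range(3)))

def make_example(k: int):
    """Sample a k-step composition; return the tokens and the intermediate states."""
    elems = [random.randrange(6) for _ in range(k + 1)]  # e0 plus k operands
    states, acc = [], elems[0]
    for e in elems[1:]:
        acc = compose(acc, e)
        states.append(acc)  # the intermediate state after each operation
    tokens = ["<start>", NAMES[elems[0]]]
    for e in elems[1:]:
        tokens += ["<op>", NAMES[e]]
    tokens += ["<predict>", NAMES[states[-1]]]  # answer is the final token
    return tokens, [NAMES[s] for s in states]

# Sanity check against the worked example: ((r0 . r1) . s0) . s2 = r2
assert compose(compose(compose(0, 1), 3), 5) == 2
```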
Loss

Full-sequence Loss ❌

The standard way to train language models is to calculate loss on every position, meaning the model learns to predict every token. In the chosen D3 task:

<start>e0<op>e1<op>e2<op>e3<predict>e4

the model would be trained to predict, after <start>, one of the 6 elements of D3, followed by <op> or <predict>, followed by one of the 6 elements of D3, and so on.

In my experiments, standard transformers trained this way always memorized the ~47k-value lookup table rather than a sequential algorithm. I haven’t done a detailed investigation, but I suspect something about the loss landscape makes memorization easier. I later found that the sequential algorithm is a stable basin, so I hypothesize that the model would eventually grok the sequential algorithm with enough data, but I didn’t test this due to time and cost constraints.

Answer-only Loss ✅

A non-standard way to train a transformer — but closer to the BERT method in the LEGO paper — is to calculate loss only on the answer token. Given the same example above:

<start>e0<op>e1<op>e2<op>e3<predict>e4

the model sees the full sequence, but we only calculate loss on the final token (e4 in this example). This leaves the outputs at all other positions completely unconstrained and seems to give the model space to write intermediate values.

In my experiments, weight-shared transformers usually learned the sequential algorithm under answer-only loss, and standard transformers learned a partial sequential algorithm (compressed into the final 3 layers), although this was flaky and seed-dependent.
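In code, the difference between the two losses is just a target mask. A minimal sketch, assuming a standard next-token setup where `tokens` is (batch, seq) and the answer is the final token:

```python
import torch
import torch.nn.functional as F

def answer_only_loss(logits, tokens):
    """Cross-entropy on the answer token only.

    logits: (batch, seq, vocab) model outputs; tokens: (batch, seq) input ids.
    Position seq-2 (the <predict> token) is the one predicting the answer at seq-1.
    """
    targets = tokens[:, 1:].clone()    # next-token targets
    targets[:, :-1] = -100             # ignore everything except the answer
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )

def full_sequence_loss(logits, tokens):
    """Standard next-token loss on every position."""
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        tokens[:, 1:].reshape(-1),
    )
```

The curriculum described in the next section just switches from the first function to the second once training accuracy hits 100%.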
Answer-only → Full-sequence Curriculum ✅

Answer-only loss was effective, but also unrealistic. To make the results of my experiment somewhat applicable to transformers trained in the standard way, I first trained a model under answer-only loss until it reached 100% accuracy, then trained it under full-sequence loss until convergence. When a model learned the sequential algorithm in the first phase (again, flaky and seed-dependent), it consistently maintained the sequential algorithm in the second phase, while learning to output the correct statistical distribution for the non-answer tokens.

Other Decisions

Data Weighting by Difficulty ✅

Wang et al. 2024, “Grokked Transformers are Implicit Reasoners: A Mechanistic Journey to the Edge of Generalization” (NeurIPS 2024), found that the ratio of composed and non-composed examples in a dataset affects whether a model groks or not. I used this as inspiration and weighted the dataset by difficulty: for each composition length k, I weighted examples of that length proportionally to k², so 1:4:9:16:25:36 for k up to 6. Using this weighting, plus the answer-only → full-sequence loss curriculum, I was able to successfully train a standard transformer to learn the sequential algorithm (although still inconsistently).

Weight Sharing ✅

In Universal Transformers (Dehghani et al., 2018), the authors find that a transformer that shares weights between layers (or, equivalently, loops a single layer) is better at learning some sequential tasks than standard transformers. I tried training a universal transformer[1] and found that it was much easier to train it to learn the sequential algorithm: it learned it in the answer-only case even without data weighting. Weight-shared models did not learn the sequential algorithm under full-sequence loss in my experiments.

Proof

In this section I’ll discuss how I used some standard mechanistic interpretability techniques to prove that the model learned the sequential algorithm. I’ll be comparing two standard 8-layer[2] transformers composing 6 elements of D3, both trained with data weighting. Model A was trained with the answer-only → full-sequence curriculum, and model B was trained only with full-sequence loss. Both models have dim=96 and 6 heads.

Model  Layers  Dim  Heads  Loss      Data Weight
A      8       96   6      AO → FSL  ✅
B      8       96   6      FSL       ✅

Logit Lens Staircase

The first piece of evidence is to just look at the residuals through the logit lens. We run a number of examples through the model and, at every <op> and <predict> position, look at the softmax probability mass assigned to the expected intermediate state at every layer.

Since a transformer can only attend to residuals from a previous layer, and the sequential algorithm requires 6 steps, we know that if the model learned the sequential algorithm, the first intermediate state must exist by layer 2. I used the logit lens to look for intermediate states at their corresponding <op> or <predict> positions: for the sequential algorithm, we expect the first intermediate state at the <op> position following the first operation, the second intermediate at the <op> after that, and so on[3].

In model A, there’s a clear staircase where the expected intermediate states appear by layer 2; in model B, the intermediate states never rise above chance, and the final state appears by layer 3, where it’s causally disconnected from any possible 6-step sequential algorithm. In model A, the intermediate values appear at L1, L1[4], L3, L4, L4, and L5, so we suspect it learned the sequential algorithm for the first 4 steps, a combined lookup table at step 5, and then another sequential step at step 6. Model B shows no intermediates above chance until the final position, and starts to recognizably converge on the correct answer at L3, where it’s causally impossible for the sequential algorithm to have run for 6 steps.

Sequential Ablation

The sequential algorithm requires intermediate states to exist by a specific layer, but it also never needs them again after that layer. As additional proof of the sequential algorithm, I zeroed all residuals at these positions after the critical layer. In both models, the final answer remains at ~100% accuracy, confirming that the residual stream at these positions after the critical layer (causally disconnected from the sequential algorithm) is not needed by either model. I then zeroed the critical layer itself, and found that this destroys the accuracy of model A (sequential) but not model B.
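To make the logit-lens probe concrete, here’s a minimal sketch. The model attribute names (embed, layers, ln_f, unembed) are hypothetical stand-ins for whatever your implementation exposes:

```python
import torch

@torch.no_grad()
def logit_lens_staircase(model, tokens, op_positions, expected):
    """staircase[layer][step] = mean softmax mass on the expected intermediate
    state, decoded from the residual stream at each <op>/<predict> position.

    tokens: (batch, seq) input ids; op_positions: step positions in the sequence;
    expected: (batch, n_steps) token ids of each example's intermediate states.
    """
    resid = model.embed(tokens)  # token (+ positional) embeddings
    staircase = []
    for layer in model.layers:
        resid = layer(resid)
        probs = model.unembed(model.ln_f(resid)).softmax(-1)  # decode early
        staircase.append([
            probs[torch.arange(len(tokens)), pos, expected[:, i]].mean().item()
            for i, pos in enumerate(op_positions)
        ])
    return staircase  # if sequential, step i should light up by roughly layer i+1
```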
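Here’s a similar sketch of the sequential ablation: zero the residual stream at the probed positions from a chosen layer onward, then check whether the final answer survives. Again, the model API is a hypothetical stand-in:

```python
import torch

@torch.no_grad()
def run_with_ablation(model, tokens, positions, zero_from_layer):
    """Zero the residual stream at `positions` in every layer >= zero_from_layer,
    then return the logits at the <predict> position (which predicts the answer)."""
    resid = model.embed(tokens)
    for i, layer in enumerate(model.layers):
        resid = layer(resid)
        if i >= zero_from_layer:
            resid[:, positions, :] = 0.0  # knock these positions out entirely
    return model.unembed(model.ln_f(resid))[:, -2]

# e.g. accuracy = (run_with_ablation(...).argmax(-1) == answer_ids).float().mean()
```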
Looking at the per-step ablation results, we can also see from the higher accuracy after ablating the critical layer for t1 that the model is able to recover, and that the model seems to be using a different algorithm for t5. By proving that model A does need the critical layer but doesn’t need anything after it, we can be confident that it’s using an algorithm which uses each position sequentially.

Lessons and Limitations

- Neural network optimization pressure strongly disfavors sequential algorithms.
- Universal transformers are somewhat better at learning sequential algorithms.
- It’s much easier to make a model learn a sequential algorithm if there are no shortcuts, although even then it might just memorize a composed lookup table.
- Data distribution does affect how likely a model is to learn a sequential algorithm.

Footnotes

[1] Comparing standard and universal transformers happens to be the subject of the experiment I was trying to run, so this was actually the first thing I tried, just by coincidence.

[2] The sequential algorithm for k=6 can be learned reliably in as few as 7 layers, but I had better data at hand for 8 layers, so that’s what I’m using here.

[3] I also looked at the position before the <op>, but consistently saw no signal. For whatever reason, all of the models I tested did their intermediate calculations at <op> positions.

[4] Sort of: 26% is only barely above chance (17%).
