
Neural Networks learn Bloom Filters

Overview:

We train a tiny ReLU network to output sparse top-k distributions over a vocabulary much larger than its residual dimension. The trained network seems to converge to a mechanism closely resembling a Bloom filter: tokens are assigned sparse binary hashes, the hidden layer computes an approximate union indicator, and the output logits are read linearly from this union.

Here’s what a small network trained on a toy version of the sparse top-k distribution task learns to use:

[Figure: Weight matrix of a 1-layer ReLU network trained via gradient descent on the toy k-sparse distribution task below, truncated to a subset of tokens for visualisation purposes.]

[Figure: Plot of the range of values of the weight matrix; they form a bimodal distribution.]

That’s the input weight matrix of the trained network. Every entry falls into one of two clusters of values. The network has effectively encoded a binary hash for each token, and, as we’ll show, this seems to enable the network to approximately simulate a Bloom filter and so output the correct set of top-k tokens with high probability.

We provide a theoretical construction showing how to set the weights to implement a Bloom filter exactly. The real network seems to learn something similar and behaves like a Bloom filter, but although we provide a fair bit of mechanistic evidence, we don’t yet have a complete mechanistic explanation of the trained network.

Additionally, this is just a toy network, so it doesn’t directly tell us what larger models might do. But I found the discreteness of the learned algorithm very interesting.

It seems to learn a probabilistic solution whose width scales like $\log(1/\epsilon)$, where $\epsilon$ is the MSE loss. In contrast, a standard JL-style superposition solution would have an error that decreases only polynomially with width.

The Task:

As a result of reading The Softmax Bottleneck, I have become quite interested in what mechanisms toy MLP networks learn in order to output sparse probability distributions.
I have tried to devise a toy problem to address this question.

We sample $k$ token indices $t_1, \dots, t_k$ uniformly at random without replacement from $\{1, \dots, V\}$, and draw a positive weight $\alpha_i$ uniformly from a fixed range for each $i = 1, \dots, k$. We give $x = \sum_{i=1}^{k} \alpha_i e_{t_i}$ as input to a 1-layer ReLU neural network, where $e_t$ is an embedding of the $t$-th token that we are allowed to choose. Let $p$ be the probability distribution over $\{1, \dots, V\}$ that assigns probability mass $1/k$ to each of the tokens $t_1, \dots, t_k$ and probability mass $0$ to the remaining tokens. The task is to minimise the expected KL divergence between $p$ and the distribution $\mathrm{softmax}(W_U r)$, where $r$ is the final residual stream and $W_U$ is the unembedding matrix. The expectation is taken over random samples of the token indices $t_i$ and token weights $\alpha_i$.

We vary the $\alpha_i$ because we are trying to simulate a scenario where the uniform distribution is being produced in parallel with a distribution-specific solution that supplies logit adjustments. It turns out that once we have a solution that can produce the uniform distribution with no error, we can scale up the uniform distribution an outrageous amount and then make the logit adjustments using a width that doesn’t depend on the vocabulary size $V$. These assumptions can be relaxed if we want. Doing so just makes the network less clean to discuss, and there are weird edge cases, which would never actually happen, that would have to be discussed, so we avoid it.

Construction:

The construction below is my hypothesis for what the 1-layer ReLU network I trained on the task above is approximately doing. I had Claude prepare an animation to visualise the construction. The initial residual stream is a weighted sum of the hashes of the top-k tokens. The MLP learns to turn this residual stream into a discrete indicator of the union of the hashes. The output logits fall into discrete categories, depending on the overlap of each token’s hash with the union of the hashes of the top-k tokens.
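The sampling procedure and loss defined under The Task can be sketched in plain Python, together with the first step of the construction, where the input is a weighted sum of token hashes. This is a minimal sketch under stated assumptions, not the code used for the experiments. The sizes `V`, `d`, `h`, `k`, the weight range `[1, 2]`, the random-subset hashes and all function names are hypothetical choices made for illustration. The KL is taken in the direction KL(p ‖ softmax(logits)).

```python
import math
import random

# Hypothetical toy sizes (not the values used in the post's experiments).
V, d, h, k = 2000, 256, 12, 4

def make_hashes(V, d, h, rng):
    """Give each token a random size-h subset of the d residual dimensions."""
    return [frozenset(rng.sample(range(d), h)) for _ in range(V)]

def sample_example(V, k, rng, lo=1.0, hi=2.0):
    """Sample k distinct token indices and a positive weight for each.

    The weight range [lo, hi] is an assumption made for this sketch.
    """
    tokens = rng.sample(range(V), k)
    alphas = [rng.uniform(lo, hi) for _ in range(k)]
    return tokens, alphas

def initial_residual(tokens, alphas, hashes, d):
    """x = sum_i alpha_i * e_{t_i}, with e_t the 0/1 indicator vector of hash H_t."""
    x = [0.0] * d
    for t, a in zip(tokens, alphas):
        for j in hashes[t]:
            x[j] += a
    return x

def target_distribution(tokens, V):
    """Uniform mass 1/k on the selected tokens, zero elsewhere."""
    p = [0.0] * V
    for t in tokens:
        p[t] = 1.0 / len(tokens)
    return p

def kl_to_softmax(p, logits):
    """KL(p || softmax(logits)); terms with p_i = 0 contribute nothing."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return sum(pi * (math.log(pi) - (z - log_z))
               for pi, z in zip(p, logits) if pi > 0)

rng = random.Random(0)
hashes = make_hashes(V, d, h, rng)
tokens, alphas = sample_example(V, k, rng)
x = initial_residual(tokens, alphas, hashes, d)
p = target_distribution(tokens, V)
```

Because every weight is positive, the support of `x` is exactly the union of the selected tokens’ hashes. That union is the object the MLP is hypothesised to turn into an indicator.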
The discreteness of the logits allows exactly uniform probability to be assigned to the top-k tokens. False positives are rare: they occur when the hash of a non-top-k token happens to lie inside the union of the hashes of the top-k tokens.

Formal construction:

We use a 1-layer neural network with ReLU activation, , and a residual connection. Fix with , with , and d .

For each , sample a uniformly random subset of size of . From the Bloom filter perspective, represents the independent hashes of token .

For each set if , else .

Set , . Add a uniform negative bias to each hidden layer neuron.

Set .

Analysis of a single forward pass:

The residual stream is initially given by an input vector , which has support (here we use the assumption that ). By construction, the th component of the output from the MLP layer to the residual stream is given by . This is given by if , and otherwise.

Then the final residual stream is given by if , and otherwise. is set such that the th logit is given by , with equality if and only if . The indices therefore each attain the logit . If no other tokens attain this logit, we will have assigned uniform probability to the top-k tokens, and lower probability to the remaining tokens.

So, with high probability over the uniformly random choice of hashes, we want a token’s hash to lie inside the union of the top-k hashes if and only if the token is one of the top-k tokens. This is precisely the situation of a Bloom filter. There is a YouTube video that explains them better than I could, and see also the Wikipedia article.

The standard Bloom filter analysis applies, and we get scaling of for the residual stream width, i.e. to get success with probability, and failure (a false-positive Bloom filter match) otherwise.

Interestingly, the false-positive case just looks like another token being included in the top-k set, so it’s quite a robust solution.
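One concrete reading of the formal construction can be sketched as follows, under explicit assumptions:

- Embeddings are 0/1 hash indicators and the read-in is the identity.
- Every hidden neuron has bias -1 and the read-out is minus the identity.
- All token weights are at least 1.

With these choices, the residual connection plus MLP computes min(x_j, 1). The unembedding reuses the hash indicators, so a token’s logit is the size of the overlap between its hash and the union. The false-positive helper uses the textbook approximation, which treats hash positions as independent. All names and sizes are hypothetical, and this is a sketch rather than the author’s parameterisation.

```python
import math
import random

def make_hashes(V, d, h, rng):
    """Hash H_t of each token: a uniformly random size-h subset of range(d)."""
    return [frozenset(rng.sample(range(d), h)) for _ in range(V)]

def bloom_forward(tokens, alphas, hashes, d):
    """Logits of an idealised Bloom-filter network (assumed parameterisation).

    Embedding e_t = 0/1 indicator of H_t; identity read-in; bias -1 on every
    hidden neuron; read-out -identity; residual connection. With every
    alpha >= 1, coordinate j becomes x_j - relu(x_j - 1) = min(x_j, 1), i.e.
    the indicator of the union U of the selected hashes. Unembedding row t is
    the indicator of H_t, so logit_t = |H_t ∩ U|, which equals h iff H_t ⊆ U.
    """
    x = [0.0] * d
    for t, a in zip(tokens, alphas):
        for j in hashes[t]:
            x[j] += a
    r = [xj - max(0.0, xj - 1.0) for xj in x]  # residual + MLP output
    return [sum(r[j] for j in H) for H in hashes]

def approx_false_positive_rate(d, h, k):
    """Textbook estimate (1 - exp(-h*k/d))**h for one unselected token.

    Treats hash positions as independent. Avoiding false positives among all
    V tokens needs a roughly V-times smaller per-token rate (union bound).
    """
    return (1.0 - math.exp(-h * k / d)) ** h

def width_for_rate(k, delta):
    """Choose h = ceil(log2(1/delta)) and d = ceil(h*k/ln 2), so the estimate is <= delta."""
    h = math.ceil(math.log2(1.0 / delta))
    return h, math.ceil(h * k / math.log(2))

rng = random.Random(0)
V, d, h, k = 5000, 128, 8, 4  # hypothetical sizes
hashes = make_hashes(V, d, h, rng)
tokens = rng.sample(range(V), k)
alphas = [rng.uniform(1.0, 2.0) for _ in tokens]  # assumed weight range, >= 1
logits = bloom_forward(tokens, alphas, hashes, d)
top_tier = [t for t, z in enumerate(logits) if z == h]
```

`top_tier` always contains the selected tokens, so there are no false negatives. Any extra entries are Bloom-filter false positives. With the width chosen by `width_for_rate`, h·k/d is at most ln 2, so each factor 1 - exp(-h·k/d) is at most 1/2 and the estimate is at most 2^(-h) ≤ δ.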
Bloom filters never have false negatives, so if a token is in the top-k set, it will definitely be among the tokens given the maximum tier of logit.

Training:

I trained a 1-layer residual ReLU network on the top-k distribution task using online random sampling of token sets. The model used vocabulary size , residual width , hidden width , and . For each training example, a uniformly random subset of k distinct tokens was sampled together with random positive coefficients. The target distribution assigned probability mass uniformly across the selected tokens and zero elsewhere. The network was trained for optimisation steps with batch size using AdamW and a cosine learning-rate schedule beginning at . Training was performed in bfloat16 precision. The final model achieved total variation distance on held-out random samples, assigning nearly all probability mass to the target top-k set while maintaining low leakage onto non-target tokens.

Behavioural analysis of the trained network:

We draw random sets of k tokens uniformly without replacement from the vocabulary and feed them to the trained network as test top-k token sets.

[Figure: Probability mass assigned to the top-k tokens across random token sets drawn without replacement. Note the two discrete clusters, separated by a clear gap in probability.]

[Figure: PDF of the probability assigned to the top-k tokens across samples.]

So there is a distinct, rarely occurring cluster in which the top-k probability mass falls well below that of the typical cluster. A typical output distribution for a token set in the typical region looks as follows:

[Figure: Distribution of the largest output probabilities on a typical top-k token set. Nearly uniform probability is assigned to the top-k tokens.]

A small number of the sampled token sets had false positives. A false positive is defined as a top-k probability mass below the boundary of the lower cluster in the plot above.
Below is a random sample from that set of false positives:

[Figure: Distribution of the largest output probabilities when there is a single false positive, leading to k+1 tokens sharing probability approximately uniformly. Green tokens are in the ground-truth top-k token set; red tokens are not.]

You can see that the network handles the false positive fairly robustly: it doesn’t significantly disrupt the top-k token set. If the false positive weren’t highlighted, we wouldn’t be able to tell which of the top tokens it was, which is plausibly due to the discrete nature of the logits.

Interestingly, we can also provide more than k tokens in the input, and the model will handle it and give uniform probability among the given tokens. So k is more like a soft limit on the number of tokens the hashes are designed to handle without breaking down.

Conditioning on there being even a single false positive, we expect the union of hashes to be large, which means we should expect other tokens to have a large intersection with the union. We see this in the figure above, where significant probability mass goes to several other tokens as well.

[Figure: The hash of the false-positive token (defined by the bimodal entries of the input weight matrix) is covered by the union of the hashes of the top-k token set, so it triggers a false positive.]

[Figure: The hash of another token is nearly covered by the hashes of the top-k token set, with a single element of the hash not covered. We see above that it gets a significant share of the probability mass as a result.]

[Figure: Ditto for another token.]

[Figure: As a baseline, the same plot for another token on the same token set.]

In a run over many token sets that I performed, none of them assigned probability above to a token whose hash was not fully covered by the union of the hashes of the top-k tokens.
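The near-miss picture above can be illustrated on the idealised construction (not on the trained network). We group every token by how many of its hash positions fall outside the union of the top-k hashes. This is a sketch under assumed sizes and random-subset hashes, and `scale` is a hypothetical logit multiplier. `match_probability` gives the exact chance that a fresh random hash of size h lies inside a union of u dimensions. That chance grows with u, which is one way to see why false positives go together with large unions.

```python
import math
import random

def coverage_tiers(tokens, hashes):
    """Group all tokens by how many of their hash positions lie outside the union U.

    Tier 0: the hash is fully covered (a Bloom-filter match; a false positive
    if the token was not selected). Tier 1: exactly one uncovered position.
    """
    union = set().union(*(hashes[t] for t in tokens))
    tiers = {}
    for t, H in enumerate(hashes):
        tiers.setdefault(len(H - union), []).append(t)
    return union, tiers

def tier_probabilities(tiers, scale):
    """Per-token softmax probability in each tier for logits scale * |H_t ∩ U|.

    Since |H_t ∩ U| = h - m for a token in tier m, the constant h cancels and
    every token in a tier gets the same probability.
    """
    z = sum(len(ts) * math.exp(-scale * m) for m, ts in tiers.items())
    return {m: math.exp(-scale * m) / z for m in tiers}

def match_probability(u, d, h):
    """Exact chance that a fresh uniformly random size-h hash lies inside a size-u union."""
    return math.comb(u, h) / math.comb(d, h)

rng = random.Random(1)
V, d, h, k = 5000, 128, 8, 4  # hypothetical sizes
hashes = [frozenset(rng.sample(range(d), h)) for _ in range(V)]
tokens = rng.sample(range(V), k)
union, tiers = coverage_tiers(tokens, hashes)
probs = tier_probabilities(tiers, scale=5.0)
false_positives = [t for t in tiers[0] if t not in tokens]
near_misses = tiers.get(1, [])
```

In this idealised version every token in a tier has exactly the same probability. The similar probabilities of the trained network’s near misses are consistent with a comparable tiered structure, though they do not prove it.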
It’s quite interesting how similar in probability the near misses above are. This suggests that a discrete probability is assigned to each intersection size with the union of the hashes (indeed, on other samples we get similar probabilities for near misses).

Mechanistic analysis of the trained network:

Now that we have seen that the model seems to behave like a Bloom filter, we give a partial mechanistic analysis of how the trained weights implement the filter.

[Figure: Learned matrix (truncated to a subset of tokens) when training via gradient descent on the top-k task. It looks completely random, with a bimodal distribution of values.]

[Figure: Distribution of values in this matrix in the trained network.]

The positive values vary in magnitude, whereas our construction uses a constant value for them, but this doesn’t change the story very much. Remember that we are primarily interested in extracting an indicator of a union of token hashes. As long as each hidden layer neuron fires if and only if it is contained in the union of the token hashes, the story holds.

[Figure: A second learned matrix (truncated to a subset of tokens) when training on the top-k task. The mask of its negative values is the transpose of the mask of positive values of the first matrix (the raw values, however, are not transposes of each other).]

[Figure: Distribution of values in the second matrix in the trained network.]

Since both matrices seem to learn bimodal sets of values, we can compare the two masks.

[Figure: Perfect match between the transpose of one mask and the other mask.]

So the second matrix learns a mask that is the transpose of the first matrix’s mask, analogously to our construction. To connect with our mechanistic story further, we linearly transform the residual stream of the network, aiming to make it look like the one in Claude’s animation.

If the trained network were acting as our construction expects, the learned weights would provide the change-of-basis matrix. This matrix maps the privileged hidden layer basis (Claude’s basis :)) to the network’s internal residual stream basis.
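The mask comparison described above amounts to thresholding each bimodal matrix between its two clusters and checking whether one binary mask is the transpose of the other. The sketch below does this on synthetic stand-ins: hypothetical matrices `A` and `B` with made-up value ranges, not the trained weights. The thresholds 0.25 and -0.25 are assumptions that happen to sit between the clusters of these synthetic matrices.

```python
import random

def binary_mask(M, threshold, above=True):
    """Boolean mask of entries above (or, with above=False, below) a threshold."""
    return [[(v > threshold) if above else (v < threshold) for v in row] for row in M]

def transpose(M):
    """Transpose a list-of-lists matrix."""
    return [list(col) for col in zip(*M)]

def mask_agreement(A, B):
    """Fraction of positions at which two equally shaped boolean masks agree."""
    pairs = [(a, b) for ra, rb in zip(A, B) for a, b in zip(ra, rb)]
    return sum(a == b for a, b in pairs) / len(pairs)

# Synthetic stand-ins for two learned matrices (hypothetical values):
# A is token-by-dimension with varying positive entries on each token's hash;
# B is dimension-by-token with varying negative entries on the same positions.
rng = random.Random(0)
V, d, h = 50, 32, 4
hashes = [set(rng.sample(range(d), h)) for _ in range(V)]
A = [[rng.uniform(0.5, 1.5) if j in hashes[t] else rng.uniform(-0.05, 0.05)
      for j in range(d)] for t in range(V)]
B = [[rng.uniform(-1.5, -0.5) if j in hashes[t] else rng.uniform(-0.05, 0.05)
      for t in range(V)] for j in range(d)]

mask_A = binary_mask(A, 0.25)                # positive cluster of A
mask_B = binary_mask(B, -0.25, above=False)  # negative cluster of B
agreement = mask_agreement(transpose(mask_A), mask_B)
```

An agreement of 1.0 corresponds to the perfect match in the figure. The raw values of `A` and `B` differ, as in the trained network, but their masks line up.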
So we express the residual stream in the privileged hidden layer basis, both before and after the hidden layer, for a sample input. We arrange the resulting vector in a grid, but the rows and columns have no significance.

[Figure: The transformed initial residual stream for a sample input, and the distribution of its values.]

Here the initial residual stream is the state before the MLP layer. In the transformed basis it corresponds to the first residual stream in Claude’s animation.

For this input, a value is dark blue if and only if it doesn’t belong to the union of the “hashes” (as given by the learned mask) of the input tokens. This has been the case for all other inputs I’ve tried, though I have not rigorously proven it for the trained network. Note that, just as in Claude’s animation, the dimensions contained in the union of the hashes take on a wide range of values. This is because hashes overlap at certain positions and so combine values, and also because of the variation in the token weights.

There are discrete clusters in the distribution of values, which seem to correspond to the number of intersecting hashes at each position.

The job of the hidden layer seems to be to collapse this range of values to produce an indicator of the union of hashes.

[Figure: The transformed final residual stream on the same input as above, and the distribution of its values.]

This corresponds to the second residual stream in Claude’s animation. The final residual stream is the one directly before applying the unembedding matrix. The hash values that were significantly above the others in the initial residual stream have been clamped, and the union of the hashes now lies in a fairly narrow range.

To get from the final residual stream to the logits, we apply the unembedding matrix. So each token’s logit is approximately the dot product of the clamped union of hashes with that token’s row of the hash-containing matrix.
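The clamp-then-read-out step can be sketched on a six-dimensional toy example. The sketch assumes the hidden layer computes x_j - relu(x_j - 1), as in one reading of the construction. It uses hypothetical unembedding rows whose positive weights differ in magnitude. A token whose hash is fully covered receives the sum of its hash weights as its logit, so a covered token with a smaller row sum can tie with a near miss.

```python
def clamp_union(x):
    """Residual plus MLP output x_j - relu(x_j - 1): values >= 1 become 1, zeros stay 0."""
    return [v - max(0.0, v - 1.0) for v in x]

def readout(r, W):
    """Logit of token t = sum_j W[t][j] * r_j (W is token-by-dimension)."""
    return [sum(w * rj for w, rj in zip(row, r)) for row in W]

# Initial residual with overlapping hashes (values >= 1 on the union {1, 2, 3, 5}).
x = [0.0, 1.5, 2.75, 1.0, 0.0, 4.25]
r = clamp_union(x)

# Hypothetical unembedding rows with unequal positive weights on each hash.
W = [
    [0.0, 0.75, 1.25, 0.0, 0.0, 0.0],  # hash {1, 2}: covered, row sum 2.0
    [0.0, 0.0, 0.0, 1.0, 0.0, 1.0],    # hash {3, 5}: covered, row sum 2.0
    [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],    # hash {0, 1}: one position uncovered
    [0.0, 0.0, 0.5, 0.0, 0.0, 0.5],    # hash {2, 5}: covered, but row sum 1.0
]
logits = readout(r, W)  # [2.0, 2.0, 1.0, 1.0]
```

The first two tokens tie despite different individual weights, because their row sums match. The last token is fully covered but only reaches the near miss’s logit.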
There can be a wide variance in the raw values of the hash-containing matrix, subject to one condition. For each row (each corresponding to a token), take the sum of its hash entries weighted by the expected corresponding entries of the final residual stream. That sum must be about the same across tokens.

[Figure: Values of this weighted sum for each token, where the hash of a token is defined by the bimodal weight matrix. We use PyTorch indexing notation here, with the hash viewed as a tensor of indices.]

The above is the dot product of the large positive terms in each token’s row with the values we expect at those positions in the final residual stream, supposing that the token is among the top-k. I don’t think this is exactly the right quantity to look at, because there are also noise terms from entries outside the hash. It should still be the leading term, and it tells us that the top-k tokens end up with about the same logit contribution. However, the remaining spread of this contribution across tokens is fairly large. Taken at face value, it would underestimate the performance of the model, which suggests that something is missing from our mechanistic analysis.

Nonetheless, there is a clear logit gap between a hash that matches completely and one that fails to match in just a single position (see the value plot above). So our mechanistic analysis is fine-grained enough to argue why, say, false positives are unlikely, even if it isn’t precise enough to prove that the network is a Bloom filter. Unfortunately, it isn’t fine-grained enough to explain why we get so close to a uniform distribution over the top-k tokens :(. So our mechanistic explanation is incomplete here.

[Figure: The transformed residual stream on a pathological token set, again with the dimensions arranged in a grid with no significance to the rows or columns.]

What’s interesting about the above figure is that it’s pathological because the union of hashes is large. A large union makes it more likely that the hash of a non-top-k token happens to lie inside it. But the way you get a large union of hashes is by having top-k hashes that don’t intersect.
In other words, on any given forward pass we want the top-k hashes to intersect as much as they can (though globally we don’t want this). This is the opposite of the usual intuition, where we want to minimise interference between features!

Conclusion / Reflections:

I find multiple things interesting about the solution to the task:

Other approaches have been discussed for neural networks that handle many more features than they have residual stream dimensions. In contrast to those, this solution does not rely on having a bounded amount of error with 100% probability. Instead, it has negligible error with high probability and significant error with a small probability. Bloom filters are fundamentally probabilistic structures, and with some small probability they produce false positives (which translates to an erroneous token being included in the top-k token set).

The scaling is a lot better behaved with this probabilistic approach. We get MSE loss $\epsilon$ with a network width that scales like $\log(1/\epsilon)$, whereas traditional JL-style superposition would need a width that scales polynomially in $1/\epsilon$.

Because of the form of our solution, the harder task of specifying an arbitrary conditional distribution over the top-k tokens can, with high probability, be reduced to the uniform case. The size of the reduction does not depend on the vocabulary size. So the uniform distribution task tracks the residual stream width required to output sparse probability distributions over large vocabularies.
Indeed, training a 1-layer network on the harder task of specifying an arbitrary conditional distribution empirically results in a hidden layer identical to that for the uniform task.

It’s always been a bit of a mystery to me how LLMs like GPT-2, whose residual stream is much narrower than its vocabulary, can learn such accurate probability distributions. This Bloom filter approach significantly reduces the mystery for me. (I’m not saying that LLMs use Bloom filters, and there is no evidence for this; I’m saying I’m no longer surprised that it is feasible.)

More broadly, I’m working on probabilistic methods for performing computation in superposition. This was a fun toy case I found for showcasing the advantages of a probabilistic approach. I view the result on 1-layer networks as an initial piece of evidence that it’s not absurd to hypothesise mechanisms that rely on probabilistic approaches as opposed to purely continuous ones.

Related work:

Prior work, The Softmax Bottleneck Does Not Limit the Probabilities of the Most Likely Tokens, gives the following result. Take a random Gaussian matrix $W$ and a uniformly randomly drawn subset $S$ of size $k$. Then, with high probability and under suitable conditions on the dimensions, you can find a vector $v$ such that $Wv$ is constant on $S$ and $(Wv)_i > (Wv)_j$ whenever $i \in S$ and $j \notin S$. They find $v$ simply as the unique solution to $W_S v = \mathbf{1}$ lying in the row space of $W_S$ (the rows of $W$ indexed by $S$), which is given in closed form. The substance of the analysis is that with high probability this solution has higher logits on $S$ than on any of the remaining tokens. Their proof uses some simplifying assumptions, but empirically it works, so it’s fine. By scaling $v$, we can get arbitrarily close to assigning probability $1/k$ to each of the tokens in $S$.

I don’t think it’s mechanistically plausible that trained networks actually learn to find the $v$ used in that paper’s proof (and the paper isn’t claiming that they do). It involves taking the exact inverse of an input-dependent matrix.
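The closed-form step described above can be sketched as follows. The solution of W_S v = 1 lying in the row space of W_S is v = W_S^T (W_S W_S^T)^{-1} 1, the minimum-norm solution. This pure-Python sketch uses hypothetical sizes, and a hand-written Gaussian elimination stands in for a linear-algebra library. It checks only the constancy on S, which holds by construction. Whether `margin` is positive, i.e. whether S actually receives the top logits, is the subject of the paper’s probabilistic analysis and is not guaranteed for these made-up sizes.

```python
import random

def solve(A, b):
    """Solve the square system A y = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda i: abs(M[i][col]))
        M[col], M[piv] = M[piv], M[col]
        for i in range(col + 1, n):
            f = M[i][col] / M[col][col]
            for c in range(col, n + 1):
                M[i][c] -= f * M[col][c]
    y = [0.0] * n
    for i in range(n - 1, -1, -1):
        y[i] = (M[i][n] - sum(M[i][c] * y[c] for c in range(i + 1, n))) / M[i][i]
    return y

def row_space_solution(W, S):
    """v = W_S^T (W_S W_S^T)^{-1} 1: the solution of W_S v = 1 in the row space of W_S."""
    WS = [W[s] for s in S]
    gram = [[sum(a * b for a, b in zip(ri, rj)) for rj in WS] for ri in WS]
    y = solve(gram, [1.0] * len(S))
    return [sum(yi * row[j] for yi, row in zip(y, WS)) for j in range(len(W[0]))]

rng = random.Random(0)
V, d, k = 1000, 64, 8  # hypothetical sizes
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(V)]
S = rng.sample(range(V), k)
v = row_space_solution(W, S)
logits = [sum(a * b for a, b in zip(row, v)) for row in W]
# Positive margin means every token in S outscores every token outside S.
margin = min(logits[s] for s in S) - max(z for t, z in enumerate(logits) if t not in S)
```

The inverse of the k-by-k Gram matrix depends on which tokens are in S. This is the input-dependent inversion that makes the approach look mechanistically implausible for a trained network.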
However, it did provide some inspiration for the construction given in this post.

Further work:

Is it possible to fully understand the mechanism behind the 1-layer model presented here? I’m personally satisfied that the network is implementing a Bloom filter, but can we prove it? Intuitively it should be easier than a lot of Mech Interp, because the hidden layer is doing quite a simple operation.

What is the scaling for the density of the hashes that the model learns? How does this relate to the theoretical Bloom filter density?

What mechanisms do real LLMs learn in order to output distributions? Modern LLMs increasingly have residual widths of a similar size to their vocabularies, so in theory they could use a single dimension per logit. But there are still incentives for them to use as few dimensions as they can to produce the output distribution.

Neural Networks learn Bloom Filters

​Overview:We train a tiny ReLU network to output sparse top- distributions over a vocabulary much larger than its residual dimension. The trained network seems to converge to a mechanism closely resembling a Bloom filter: tokens are assigned sparse binary hashes, the hidden layer computes an approximate union indicator, and the output logits are linearly read from this union.Here’s what a small network trained on a toy version of the sparse top- distribution task learns to use:Weight matrix of a 1-layer ReLU network trained via gradient descent on the toy -sparse distribution task below, for , , . Truncated at first tokens for visualisation purposes.Plot of the range of values of , it forms a bimodal distribution.That’s the input weight matrix of the trained network. Every entry is either or . The network has effectively encoded a binary hash for each token – and as we’ll show, this seems to enable the network to approximately simulate a Bloom filter, and so output the correct set of top- tokens with high probability.We provide a theoretical construction showing how to set the weights to exactly implement a Bloom filter. The real network seems to learn to do something similar, and seems to behaviourally act like a Bloom filter, but while we provide a fair bit of mechanistic evidence, we don’t yet have a complete mechanistic explanation of the trained network.Additionally, this is just a toy network, so it doesn’t directly tell us about what larger models might do. But I found the discreteness of the learned algorithm to be very interesting.It seems to learn a probabilistic solution which scales like in width, where is the MSE loss. In contrast, a standard JL style superposition solution would have error scaling polynomially in with respect to width.The Task:As a result of reading The Softmax Bottleneck, I have gotten quite interested in what mechanisms toy MLP networks learn to output sparse probability distributions. 
I have tried to devise a toy problem to address this question.We sample token indices uniformly at random without replacement in , and uniformly from for . We give as input to a -layer ReLU neural network (where represents an embedding of the th token, which we are allowed to choose). Let be the probability distribution over which assigns probability mass to each of the tokens , and probability mass to the remaining tokens. The task is to minimise the expected KL divergence between and the distribution , where is the final residual stream, and . Where the expectation is taken over random samples of token indices and token weights .We vary because we are trying to simulate a scenario where the uniform distribution is being produced in parallel with a distribution specific solution giving logit adjustments . It turns out that once we’ve got a solution that can produce the uniform distribution with no error, we can scale up the uniform distribution an outrageous amount, and then we can make the logit adjustments using a width that doesn’t depend on . can be relaxed to if we want, it just makes it less clean to discuss the network and there are weird edge cases that would never actually happen that have to be discussed, so we avoid it.Construction:The construction below is my hypothesis for what the -layer ReLU network I trained on the task above is approximately doing. I had Claude prepare an animation to visualise the construction. The initial residual stream is given as a weighted sum of the hashes of the top- tokens . The MLP learns to turn this residual stream into a discrete indicator set for the union of the hashes. The output logits fall into discrete categories, depending on the overlap of their hash with the union of the hashes of the top- tokens. 
The discreteness of the logits allows exactly uniform probability to be assigned to the top- tokens, with rare Bloom filter false positives when the hash of a non top- token happens to lie inside the union of the hashes of the top- tokens.Formal construction:We use a -layer neural network with ReLU activation, , and a residual connection.Fix with , with , and d .For each , sample a uniformly random subset of size of . From the Bloom filter perspective, represents the independent hashes of token .For each set if , else .Set , . Add a uniform negative bias to each hidden layer neuron.Set .Analysis of a single forward pass:The residual stream initially is given by an input vector , which has support (here we use the assumption that ). By construction, the th component of the output from the MLP layer to the residual stream is given by . This is given by if , and otherwise.Then the final residual stream is given by if , and otherwise. is set such that the th logit is given by , with equality if and only if . The indices therefore each attain the logit . If no other tokens attain this logit, we will have assigned uniform probability to the top- tokens, and lower probability to the remaining tokens.So we want to have with high probability over uniform random choice of that ” if and only if .”This is precisely the same situation as that of a Bloom filter (Youtube video that explains them better than I could), also see the wikipedia article.The standard bloom filter analysis applies and we get scaling of for the residual stream width, i.e: to get success with probability, and then failure (false positive bloom filter match) otherwise.Interestingly the false positive case just looks like another token being included in the top- set. So it’s a quite robust solution. 
Bloom filters never have false negatives, so if a token is in the top- then it will definitely be among the tokens given the maximum tier of logit.Training:I trained a -layer residual ReLU network on the top- distribution task using online random sampling of token sets. The model used vocabulary size , residual width , hidden width , and . For each training example, a uniformly random subset of distinct tokens was sampled together with random positive coefficients . The target distribution assigned probability mass uniformly across the selected tokens and zero elsewhere. The network was trained for optimisation steps with batch size using AdamW and a cosine learning-rate schedule beginning at . Training was performed in bfloat precision. The final model achieved total variation distance on held-out random samples, assigning nearly all probability mass to the target top- set while maintaining low leakage onto non-target tokens.Behavioural analysis of the trained network:We draw sets of tokens uniformly without replacement from , and use these as our test top- token sets as input to our trained network.Probability mass assigned to the top- tokens across random token sets drawn from without replacement. Note the discreteness of the two clusters separated at probability.PDF of probability assigned to the top- tokens across samples.So we have a distinct rarely occurring cluster where the top- probability mass falls below . A typical top- output distribution for a token set in the typical region looks as follows:Distribution of top- output probabilities on a typical set of top- tokens. Nearly uniform probability assigned to the top- tokens.After samples of random token sets, had false positives, as defined by having probability mass below on the top- tokens (boundary of the bottom probability mass cluster above). 
Below is a random sample from that set of false positives:Distribution of top- output probabilities when there is a single false positive leading to tokens sharing probability approximately uniformly. Green tokens are those in the ground truth top- token set, and red are those not in the top- token set.You can see how the network fairly robustly handles the false positive. It doesn’t significantly disrupt the top- token set. If the false positive weren’t highlighted, we wouldn’t be able to tell which token among the top- it was, which is plausibly due to the discrete nature of the logits.Interestingly, we can also provide tokens in the input, and the model will handle it and give uniform probability among the set of given tokens. So is more like a soft limit to the number of tokens the hashes are designed to be able to handle without breaking down. Conditioning on there being even a single false positive, we expect the union of hashes to be large, which means we should expect other tokens to have large intersection with the union. We see this in the above figure, with significant probability mass going to tokens , , and as well.The hash of token (defined by the bimodal entries of ) is covered by the union of the hashes of the top- token set, so it triggers a false positive.The hash of token is nearly covered by the hashes of the top- token set, with a single element of the hash not covered. We see in the above that it gets about % of the probability mass as a result.Ditto for token Token As a baseline, here is Token on the same token set:In a token set iteration I performed, none of them had probability above assigned to a token where that token was not fully covered by the union of the hashes of the top- tokens. 
It’s quite interesting how similar in probability the near misses above are, suggesting that there is a discrete probability assigned at each intersection size with the union of the hashes (indeed on other samples we get similar probabilities for near misses.)Mechanistic analysis of the trained network:Now we have seen that the model seems to behave like a Bloom filter, we give a partial mechanistic analysis of how the trained weights help to implement the filter.Learned matrix (truncated at first tokens), when training via gradient descent on the top-k task. It is seemingly completely random, with a bimodal distribution of values.Distribution of values in in the trained network.The fact that the positive values vary in magnitude (in contrast to our construction, which uses a constant value for the positive values) doesn’t change the story very much. Remember we are primarily interested in extracting an indicator set of a union of token hashes, and so as long as each hidden layer neuron fires if and only if it is contained in the union of the token hashes, this story is ok.Learned matrix (truncated at first tokens), when training on the top-k task. The mask of negative values is the transpose of the mask of positive values for the matrix (the raw values however are not transposes of each other.)Distribution of values in in the trained network.Since both and seem to learn bimodal sets of values, we can compare the two masks.Perfect match between the transpose of one mask and the other mask.So learns a mask which is the transpose of the mask of , analogously to our construction (Where .) To connect with our mechanistic story further, we linearly transform the residual stream of the network, aiming to look like the one in Claude’s animation.If the trained network were acting as our construction expects it to, then would provide the change of basis matrix from the privileged hidden layer basis (Claude’s basis :)) to the network’s internal residual stream basis. 
So we look at , where is the residual stream, both before and after the hidden layer, for a sample input.. We arrange the vector of dimension in a grid, but the rows and columns have no significance.Distribution of values of is defined as the initial state of the residual stream. corresponds to the first residual stream in Claude’s animation.For this input, a value is dark blue iff it doesn’t belong to the union of the “hashes” (as given by the mask of ) of the input tokens. This has been the case in all other inputs I’ve tried, though I have not rigorously proven this about the trained network. Note that, just as in Claude’s animation, among the dimensions contained in the union of the hashes, there is a wide range of possible values that are taken on. We have values ranging from up to among those contained in the union of the hashes. This is because hashes overlap at certain positions and so combine values, as well as due to variation in .There are discrete clusters in the distribution of values, which seems to correspond to the number of intersecting hashes at each position.The job of the hidden layer seems to be to collapse this range in values to produce an indicator of the union of hashes. on same input as above.This corresponds to the second residual stream in Claude’s animation. is the residual stream directly before applying the unembedding matrix .Distribution of values of The hash values which in the initial residual stream were significantly above the others have been clamped, and the union of the hashes now lies in a fairly narrow range.To get from to the final logits, we apply , which is given as before. So the th logit is approximately a dot product of the clamped union of hashes with (where is determined by ). 
There can be a wide variance in the raw values of the hash containing matrix as long as the sum of the hashes along each row (each corresponding to a token) weighted by the expected corresponding entries of is around the same across tokens.Values of , where of token is defined as the hash determined by the weight matrix . We use Pytorch notation here, with being viewed as a tensor of indices.The above is the dot product of the large positive terms in the th row of with the terms we expect at at the final residual stream, supposing that the th token is among the top-. This is not the exact correct thing to be looking at I don’t think, because there is also noise terms from the terms outside of , but it should be the leading term, and it tells us that the top- tokens end up with about the same logit contribution . But logits is pretty big and it seems to underestimate the performance of the model, suggesting there is something missing from our mechanistic analysis. Nonetheless, since there is a gap of about logits between a hash that completely matches, and a hash that fails to match in just a single position (see the value plot of above), our mechanistic analysis is fine grained enough to be able to argue why say false positives don’t occur with high probability, even if we can’t get precise enough to argue it’s provably a Bloom filter. We aren’t fine grained enough to argue why we get so close to a uniform distribution over the top- tokens, unfortunately :(. So our mechanistic explanation is incomplete here. on a pathological token set. Again with the dimensions arranged in a grid with no significance to the rows or columns.What’s interesting about the above figure is that it’s pathological because the union of hashes is large. The union of the hashes being large means that it’s more likely for the hashes of non top- tokens to happen to lie inside it. But the way that you get a large union of hashes is by having top- hashes that don’t intersect. 
In other words, on any given forward pass, we want the top- hashes to intersect as much as possible (though globally we don't want this). This is the opposite of the usual intuition, where we want to minimise interference between features!

Conclusion / Reflections:

I find multiple things interesting about the solution to the task:

- In contrast to other approaches that have been discussed for neural networks handling many more features than they have dimensions in the residual stream, the solution does not rely on having, with 100% probability, a bounded amount of error . Instead, it relies on having error with probability, and significant error with probability. Bloom filters are fundamentally probabilistic structures, and with probability they produce false positives (here, an erroneous token included in the top token set).
- The scaling is much better behaved with this probabilistic approach. We get MSE loss by scaling the width of the network by , whereas traditional JL-style superposition would need the width to scale polynomially in .
- Because of the form of our solution, the harder task of specifying an arbitrary conditional distribution over the top- tokens (say as ) with probability can be reduced to the uniform case, with the size of the reduction not depending on . So the uniform-distribution task tracks the residual stream width required to output sparse probability distributions over large vocabularies.
Indeed, training a 1-layer network on the harder task of specifying an arbitrary conditional distribution empirically results in a hidden layer identical to the one learned on the uniform task.

It has always been a bit of a mystery to me how LLMs like GPT-2 with can learn such accurate probability distributions, and this Bloom filter approach makes it much less mysterious to me. (I'm not saying that LLMs use Bloom filters; there is no evidence for this. I'm saying I'm no longer surprised that it is feasible.)

More broadly, I'm working on probabilistic methods for performing computation in superposition, and this was a fun toy case for showcasing the advantages of a probabilistic approach. I view the result on 1-layer networks as an initial piece of evidence that it's not absurd to hypothesise mechanisms which rely on probabilistic approaches as opposed to purely continuous ones.

Related work:

Prior work ("The Softmax Bottleneck Does Not Limit the Probabilities of the Most Likely Tokens") shows that with probability when , given a random Gaussian matrix and a uniformly randomly drawn subset of size , you can find a vector such that is constant on , and such that if and . They find just by finding the unique solution to lying inside the row space of (which is given in closed form). The substance of the analysis is that with high probability this solution has higher logits on than on any of the other tokens in (their proof uses some simplifying assumptions, but empirically the construction works, so this is fine). By scaling , we can get arbitrarily close to giving probability to each of the tokens in .

I don't think it's mechanistically plausible that trained networks actually learn to find the used in that paper's proof (and the paper doesn't claim they do), since it involves taking the precise inverse of an input-dependent matrix.
However, the paper did provide some inspiration for the construction given in this post.

Further work:

- Is it possible to fully understand the mechanism behind the 1-layer model presented here? I'm personally satisfied that the network is implementing a Bloom filter, but can we prove it? Intuitively it should be easier than a lot of mech interp, because the hidden layer is doing quite a simple operation.
- What is the scaling for the density of the hashes that the model learns? How does this relate to theoretical Bloom filter density?
- What mechanisms do real LLMs learn for outputting distributions? Modern LLMs increasingly have of similar size to , so in theory they could use a single dimension per logit. But there are still incentives for them to use as few dimensions as they can for producing the output distribution.
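For the density question, one rough reference point is classical Bloom-filter theory. The sketch below uses the standard approximation that a random non-input token's hash lies entirely inside the union of `k` hashes with probability (1 - e^(-hk/d))^h, where h is the number of ones per hash and d is the width. Minimising over h gives h ≈ (d/k) ln 2, at which point about half of the width lies in the union. The function names are hypothetical, and `eps` here is the expected number of false positives, a proxy we chose for illustration rather than the MSE loss used in this post.

```python
import math


def false_positive_rate(width, k, hash_bits):
    """Standard Bloom-filter approximation: probability that one random
    non-input token's hash lies entirely inside the union of k hashes."""
    return (1.0 - math.exp(-hash_bits * k / width)) ** hash_bits


def optimal_hash_bits(width, k):
    """Number of ones per hash that minimises the approximation above,
    (width / k) * ln 2, rounded to an integer (at least 1)."""
    return max(1, round(width / k * math.log(2)))


def width_for_error(k, n_vocab, eps):
    """Smallest width (using the optimal hash density) at which the expected
    number of false positives among the other n_vocab - k tokens is <= eps."""
    width = k
    while True:
        p = false_positive_rate(width, k, optimal_hash_bits(width, k))
        if (n_vocab - k) * p <= eps:
            return width
        width += 1


# Made-up example: k = 4 input tokens, vocabulary of 50,000.
widths = [width_for_error(4, 50_000, eps) for eps in (1e-2, 1e-4, 1e-6)]
```

In this idealised model, each factor of 100 in `eps` adds roughly the same number of dimensions (about 1.44 · k · log2(100) ≈ 38 for k = 4), so for a fixed vocabulary the width grows linearly in log(1/eps). Comparing the learned `hash_bits / width` against ln 2 / k would be one way to start on the density question.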
