I feel like I'm missing a key insight here. I understand the problem that regular softmax attention struggles to assign (near-)zero attention to irrelevant stuff. And I get that the subtraction formula makes it possible to assign exactly (or nearly) zero attention weight without needing crazy outlier activations. But it seems like it also makes it very easy to end up with negative attention weight (which is equivalent to putting positive attention weight on the negation of your value vectors). Intuitively, it just feels like a difficult balancing act to keep all the stuff you don't care about so close to zero.
But Figure 1 clearly shows that it works, so I don't doubt that it is in fact possible. I'm just struggling to build a picture of how exactly the network accomplishes this.
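For reference, here's the subtraction I mean, as I read Eq. 1 of the paper:

    DiffAttn(X) = (softmax(Q1 K1^T / sqrt(d)) - lambda * softmax(Q2 K2^T / sqrt(d))) V

So each row of the combined map sums to 1 - lambda, and any entry where the second map puts more mass than the first goes negative, which is exactly the balancing act I'm worried about.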
that silly softmax1 blog post is not worth the read. no one uses it in practice
if you think about it, the "escape hatch" is the design of the entire transformer dictionary. if the Key/Query attention misaligns with the Value weights, you get an attention head that effectively does not attend to anything...
Yep. From what I've seen, if the head wants to do nothing, it can attend to itself = no inter-token communication.
Still, differential attention is pretty interesting & the benchmarking looks good, so it seems worth a try! It's in the same vein as linear or non-softmax attention, which can also work.
Note that there is an error below Eq. 1: W^V should be of shape [d_model x d_model], not [d_model x 2*d_model] as for the Q and K matrices.
Idea: why not replace the lambda parameterization between softmax operations with something more general, like a matrix or MLP? E.g: Attention is the affine combination of N softmax attention operations (say, across heads). If the transformer learns an identity matrix here, then you know the original formulation was correct for the data; if it's sparse, these guys were right; if it's something else entirely then who knows...
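Something like this (my own sketch, not from the paper; the class name and the learned `mix` matrix are just illustrative): each output head applies a learned combination of all heads' softmax maps to its own values. With `mix` at identity you recover vanilla MHA; a row like [1, -lambda, 0, ...] applies (map_i - lambda * map_{i+1}) to head i's values, which is basically the paper's subtraction. (I'm not constraining rows to sum to 1 here, so it's a general linear mix rather than strictly affine.)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MixedSoftmaxAttention(nn.Module):
        # hypothetical sketch: mix the per-head softmax maps with a learned H x H matrix
        def __init__(self, d_model, n_heads):
            super().__init__()
            self.h, self.d = n_heads, d_model // n_heads
            self.qkv = nn.Linear(d_model, 3 * d_model)
            self.out = nn.Linear(d_model, d_model)
            self.mix = nn.Parameter(torch.eye(n_heads))  # init at identity = vanilla MHA

        def forward(self, x):
            B, T, _ = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            split = lambda t: t.view(B, T, self.h, self.d).transpose(1, 2)  # (B, H, T, d)
            q, k, v = split(q), split(k), split(v)
            maps = F.softmax(q @ k.transpose(-2, -1) / self.d ** 0.5, dim=-1)  # (B, H, T, T)
            mixed = torch.einsum('gh,bhts->bgts', self.mix, maps)  # combine maps across heads
            y = (mixed @ v).transpose(1, 2).reshape(B, T, -1)
            return self.out(y)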
I've tried that in a small transformer that I trained from scratch and it didn't really make any difference. I also made a version where this was trainable (probably by replacing the 1 with a constant associated with the layer), and that didn't make any difference either.
I didn't follow Miller's proposal exactly as he wrote it, though: I put the mechanism in all the layers rather than leaving it out at the end.
My test doesn't absolutely rule out usefulness; there are always different ways of applying something, but I saw no indication of it.
Are you referring to Miller's blogpost?[0] There's not an error in attention. Adding the +1 actually makes it not attention, because you no longer generate a probability distribution[1]. There's nothing really preventing attention from having a zero in any of the entries; the thing is that you probably won't get -inf (a very large negative number) inside the inner product, and you're going to have a difficult time updating those weights via gradient descent.
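For anyone who hasn't read it: the +1 just adds 1 to the softmax denominator (an implicit always-zero logit), so the weights can all shrink toward zero but no longer sum to 1. A quick numpy sketch (mine, not Miller's code):

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    def softmax1(x):
        # Miller's "+1": exp(x_i) / (1 + sum_j exp(x_j)), written with the usual
        # max-subtraction for stability (the "+1" becomes exp(-m))
        m = x.max()
        e = np.exp(x - m)
        return e / (np.exp(-m) + e.sum())

    scores = np.array([2.0, 1.0, -1.0])
    print(softmax(scores).sum())   # 1.0 -- a proper probability distribution
    print(softmax1(scores).sum())  # ~0.91 -- mass "leaks" to the implicit zero logit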
I've also tested it on many networks and different types of attention and I've yet to see a meaningful improvement (or even an improvement), even in generalization.
It really is the training method...
As to the paper, I'm also still at a loss, and honestly, if I were reviewing it I could not accept it. The results look good, but I can't tell why, and there's some "black magic" going on here.
- Figure 3 has "Transformer" and doesn't specify. Is this StableLM-3B-4E1T?
- What fucking dataset is this on? Stable has a WandB link[2] for that project and I don't see any experiment with similar (presumably entropy?) loss values (come on... this is fucking research... label your fucking graphs...)
- Where the fuck is the ablation? (Yes, I saw Fig 6 and Sec 3.8)
- How do I know (assuming this is Stable) that the difference isn't just hyperparameters? Or worse, GPUs! (yes, the number of GPUs can change results because sharding changes the statistics)
- How do I know it isn't down to 1k warmup steps instead of 5k?
- What about hidden size, layers, heads, or FFN size? Stable has 32/2560/32/? and this has 28/3072/12/8192 (these all will mess with sharding statistics too). Is the head dimension the same?
- How do I know it isn't down to the tokenizer?
- What is this magic? `0.8 - 0.6 * math.exp(-0.3 * depth)` (see the sketch after this list)
- Was this learned? Hand picked? This is a huge factor
- Any information about the learned parameters? Their final values? Trajectories?
- The code does not seem to be the same as what's in the algorithms...
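To be concrete about that `0.8 - 0.6 * math.exp(-0.3 * depth)` line: if `depth` is just the layer index (my assumption), it's a hand-tuned schedule for lambda_init that starts small and saturates at 0.8:

    import math

    # my reading: depth = layer index, giving a per-layer lambda_init
    for depth in (0, 1, 5, 13, 27):
        print(depth, round(0.8 - 0.6 * math.exp(-0.3 * depth), 3))
    # 0 0.2, 1 0.356, 5 0.666, 13 0.788, 27 0.8
    # i.e. early layers start out subtracting less, later layers more (at init, anyway)

Which just sharpens the question: why this particular schedule, and how sensitive are the results to it?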
Obviously they improved something, but nothing in the paper convinces me that it is the differential attention. There are too many parameters at play, so how am I supposed to know that the difference comes from the thing they are proposing? And more importantly, how much of the improvement comes from that specific thing and not from other things?
[0] https://www.evanmiller.org/attention-is-off-by-one.html
[1] This is a bit convoluted, but without this condition many "alternative forms" you see would be equivalent to other architectures like linear layers or gated units. The term is not well defined, but this really appears to be the only agreed-upon aspect, even if only implicitly stated. This is a much longer conversation though.
[2] https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo
[2.1] The config: https://github.com/Stability-AI/StableLM/blob/main/configs/stablelm-3b-4e1t.yml
I feel like that blogpost was almost just ragebait for AI researchers. It alternates between calling the omission of the +1 an error (which to me implies it would improve training losses, which it doesn't really: https://news.ycombinator.com/item?id=36854613) and suggesting it could possibly help with some types of quantization (which could very well be true but is a much weaker statement), and the author provides basically no evidence for either.
It's the stereotypical computer scientist who thinks they know something others don't and doesn't feel the need to prove the claim, especially when it disagrees with experts. And unsurprisingly, it's something others have already investigated and even written about. Definitely not all CS people, but it is a stereotype many other fields believe.
I know he's an economist, btw. I was also surprised he got a job at Anthropic a few months later. I wonder if the two are related.
Haven't gone through the paper fully, but just looking at the functional form of their attention, it seems more like a constraint on a standard MHA than an architectural discovery.
Take a vanilla MHA, tie the V projection between consecutive heads, make the output projection subtract consecutive heads with some fixed prefactor, and voila, you're most if not all of the way there.
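Roughly, in PyTorch (my own sketch, not the paper's code; IIRC the paper also normalizes each head's output and re-parameterizes a learned lambda, both of which this skips): two softmax maps per head, one shared V, outputs combined as map1 - lambda * map2.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DiffAttnAsConstrainedMHA(nn.Module):
        # sketch: pairs of standard softmax heads with tied V, combined by subtraction
        def __init__(self, d_model, n_pairs, lam=0.5):
            super().__init__()
            self.h, self.d = n_pairs, d_model // n_pairs
            self.q = nn.Linear(d_model, 2 * n_pairs * self.d)  # two query heads per pair
            self.k = nn.Linear(d_model, 2 * n_pairs * self.d)  # two key heads per pair
            self.v = nn.Linear(d_model, n_pairs * self.d)      # V tied within each pair
            self.out = nn.Linear(n_pairs * self.d, d_model)
            self.lam = lam                                     # fixed prefactor in this sketch

        def forward(self, x):
            B, T, _ = x.shape
            q = self.q(x).view(B, T, self.h, 2, self.d).permute(0, 2, 3, 1, 4)  # (B, h, 2, T, d)
            k = self.k(x).view(B, T, self.h, 2, self.d).permute(0, 2, 3, 1, 4)
            v = self.v(x).view(B, T, self.h, self.d).transpose(1, 2)            # (B, h, T, d)
            maps = F.softmax(q @ k.transpose(-2, -1) / self.d ** 0.5, dim=-1)   # (B, h, 2, T, T)
            diff = maps[:, :, 0] - self.lam * maps[:, :, 1]                     # the subtraction
            y = (diff @ v).transpose(1, 2).reshape(B, T, -1)
            return self.out(y)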
It does sound like we're hindering the model a bit by allowing negative weights to exist instead of sending them through, say, a ReLU. But, dealing with this might be an easier problem than you think for the model.
In the first diagram with the attention weights, there actually are some negative scores in the noise section. But, the attention to that section is very small anyway. All the second attention map needs to do is predict the noise in the first one -- a task that can be done very accurately, because it has full access to the input of the first.
To refer back to their real-world comparison, noise-canceling headphones have access to what your ear hears through a microphone, so they can output exactly the right cancellation signal. Similarly, the second attention map knows what's being input into the first one, so it can output a corresponding cancellation signal. It's not perfect -- just as noise-canceling headphones aren't perfect -- but it still gets you 99% of the way there, which is enough to boost performance.
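A toy illustration of that cancellation (made-up numbers, nothing from the paper): if both maps put similar small mass on the noise positions, the subtraction leaves near-zero (sometimes slightly negative) weights there while the signal position keeps most of its weight.

    import numpy as np

    # made-up example: position 0 is the "answer", positions 1-3 are noise
    map1 = np.array([0.70, 0.12, 0.10, 0.08])  # primary softmax map
    map2 = np.array([0.25, 0.28, 0.24, 0.23])  # second map, mostly tracking the noise
    print(map1 - 0.4 * map2)  # [ 0.6  0.008  0.004 -0.012]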
>I'm just struggling to build a picture of how exactly the network accomplishes this.
I mean, intuitively it would be trivial for the model to just optimise lambda to zero during training. Then you have essentially built a vanilla transformer with an overcomplicated parameter-pruning mechanism. Pruning is already pretty well established in the literature as something that works surprisingly well for reducing parameter counts by up to (hold on to your papers)... about 40%. In practice the model probably doesn't work exactly like that, but I wouldn't be surprised if it just approximates the normal transformer in the end anyway.
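Concretely, with the formula quoted near the top of the thread, at lambda = 0 it collapses to

    softmax(Q1 K1^T / sqrt(d)) V

i.e. plain softmax attention, with the second set of Q/K projections just sitting there unused.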