convert-diff-transformer CLI command / codepath #2197
I thought the differential transformer required both a model architecture change and a modeling code change? Does this somehow automatically implement a modeling.py for the model?
Good question. I've implemented a monkeypatch in … As for the architecture change, we have …
The monkey patch only works in the context of Axolotl. We will need a modeling.py to make inference work properly in the wild (transformers, TGI, vllm, etc.), right? (If I understand correctly.)
I have another question; don't be offended, it's just a question. Does this really belong as a feature of axolotl, as opposed to perhaps mergekit, or even a standalone project? axolotl is pretty focused on training, not on modifying model architecture, whereas mergekit focuses on modifying model architecture.
I updated the PR to include a …
Thanks for the question, it definitely does not offend! I think the value-add of this PR is more in the conversion logic than in the attention implementation itself (especially given the aforementioned HF PR), in that we can see users leveraging it to improve the performance of existing models sans differential attention, or to do some experimentation with models along these lines. @winglian Care to weigh in here?
* basic evaluate CLI command / codepath
* tests for evaluate CLI command
* fixes and cleanup
* review comments; slightly DRYing up things
---------
Co-authored-by: Dan Saunders <[email protected]>
Description
This PR implements the differential attention layer from the Differential Transformer paper.
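For readers who have not seen the paper, the layer is commonly summarized by the formula below. This is a simplified statement, not taken from this PR: it shows a single head, omits the per-head normalization and output scaling, and treats λ as a learnable scalar.

```latex
\[
[Q_1;\,Q_2] = X W^{Q}, \qquad [K_1;\,K_2] = X W^{K}, \qquad V = X W^{V}
\]
\[
\mathrm{DiffAttn}(X) =
\left(
  \operatorname{softmax}\!\left(\frac{Q_1 K_1^{\top}}{\sqrt{d}}\right)
  - \lambda \, \operatorname{softmax}\!\left(\frac{Q_2 K_2^{\top}}{\sqrt{d}}\right)
\right) V
\]
```

Here $Q_1, K_1$ play the role of the "positive" component and $Q_2, K_2$ the "negative" component referred to in the rest of this description.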
Motivation and Context
We wanted to add this attention implementation to `axolotl` so users can swap out the existing attention layers in their models for this more performant version. We matched the official implementation details as closely as possible, while adapting it to play nicely with the `transformers` attention implementations.
Since we were focused on converting existing LLMs to use these differential attention layers, we wanted a way to do so without degrading the performance of the (possibly pre-trained) LLM.
To this end, the conversion process doubles the dimensionality of the query and key projections (since differential attention requires both a positive and a negative component of the attention) and (optionally; pass `--zero-init`) initializes the weights of the negative component to zero, while copying over the weights from the original attention modules to the positive components.
When doing this, the converted network computes the same function as the original (pass `--debug` to confirm this), but may suffer from a vanishing gradient problem. The default behavior is thus to initialize the weights of the negative components of the differential attention layers to zero-centered, normally distributed values with a small variance.
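The conversion idea described above can be illustrated with a small NumPy sketch. This is not the code from this PR: it assumes a single head, no causal mask, no normalization layer, and a plain scalar `lam` for the mixing coefficient; `convert_projection` and the other function names are hypothetical. One detail worth noting: zeroed negative projections make the negative attention map uniform rather than zero, so in this sketch exact equivalence with the original attention also relies on `lam` being 0, which is an assumption of the sketch and is not stated in the description.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax with the usual max-subtraction for stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def standard_attention(x, w_q, w_k, w_v):
    """Plain single-head attention (no mask), used as the reference."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v


def convert_projection(w, zero_init, rng, std=1e-3):
    """Hypothetical helper: double the output width of a Q or K projection.

    The original weights become the positive component; the negative
    component is either all zeros or small zero-centered Gaussian noise.
    """
    negative = np.zeros_like(w) if zero_init else rng.normal(0.0, std, size=w.shape)
    return np.concatenate([w, negative], axis=1)


def differential_attention(x, w_q, w_k, w_v, lam):
    """Simplified differential attention: (A_pos - lam * A_neg) @ V."""
    d = w_q.shape[1] // 2  # width of each component
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    a_pos = softmax(q[:, :d] @ k[:, :d].T / np.sqrt(d))
    a_neg = softmax(q[:, d:] @ k[:, d:].T / np.sqrt(d))
    return (a_pos - lam * a_neg) @ v


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 tokens, hidden size 8
w_q, w_k, w_v = (rng.normal(size=(8, 4)) for _ in range(3))

reference = standard_attention(x, w_q, w_k, w_v)
w_q2 = convert_projection(w_q, zero_init=True, rng=rng)
w_k2 = convert_projection(w_k, zero_init=True, rng=rng)
converted = differential_attention(x, w_q2, w_k2, w_v, lam=0.0)

# With zeroed negative weights and lam == 0 the two outputs should agree.
print("max abs difference:", np.abs(reference - converted).max())
```

With a nonzero `lam` and zeroed negative weights, the same sketch subtracts `lam` times the mean of the value vectors from every token's output, which is why the mixing coefficient matters for exact equivalence here.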
How has this been tested?
SmolLM2-135m on A40 Runpod instance on this feature branch. Workflow was:
* `--zero-init` and `--debug` flags for sanity checking exact model conversion (completions, logits, losses)
* `axolotl evaluate` command on the small `mhenrichsen/alpaca_2k_test` dataset with both the original and converted model, checking that their evaluation metrics match
For example:
Types of changes
* `axolotl.integrations.diff_transformer` module, which implements the differential attention layers for the Llama LLM architecture and for various attention implementations (eager, SDPA, Flash Attention 2), and
* `axolotl.cli.integrations.convert_diff_transformer` module (and updates to `axolotl.cli.main`), which implements the `convert-diff-transformer` CLI command, and
* `axolotl.cli.integrations.convert_diff_transformer.patches` (to be moved) for updating the `LLAMA_ATTENTION_CLASSES` constant in `transformers.models.llama.modeling_llama`.
TODO