
Mistral Nemo LoRA training has super high grad_norm #2095

Open
Nero10578 opened this issue Nov 21, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@Nero10578
Contributor

Nero10578 commented Nov 21, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Before the recent gradient accumulation fixes and the related changes in transformers, the grad_norm when training Mistral Nemo 12B stayed below 1.0 as normal. Could also be because of changes to using chat_templates?

This was using the same config with previous versions of axolotl and transformers:
[W&B grad_norm chart]

Current behaviour

The gradient norm is now around 5 when training:
[W&B grad_norm chart]

Steps to reproduce

Train Mistral Nemo 12B Instruct with LoRA. I used the same config as I did back when this worked fine.

The only difference is that I am now using chat_templates, replacing the chat template in the Mistral tokenizer_config.json with the one shown here so that it can accept repeated roles.

{%- for message in messages %}
    {%- if message['role'] == 'system' -%}
        {{- message['content'] -}}
    {%- else -%}
        {%- if message['role'] == 'user' -%}
            {{-'[INST] ' + message['content'].rstrip() + ' [/INST]'-}}
        {%- else -%}
            {{-'' + message['content'] + '</s>' -}}
        {%- endif -%}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{-''-}}
{%- endif -%}

I did this by changing the chat_template in the Mistral Nemo tokenizer config to this:

  "chat_template": "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- message['content'] -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'[INST] ' + message['content'].rstrip() + ' [/INST]'-}}{%- else -%}{{-'' + message['content'] + '</s>' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-''-}}{%- endif -%}",

If this is the wrong way to do it, that might be causing the high grad_norm? That seems unlikely, though, since the dataset appears to be tokenized properly when I run preprocess --debug.
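
For reference, one way to sanity-check the template outside of axolotl (a minimal sketch, not what I actually ran; the toy conversation below is made up and the model path is the one from my config):

# Sketch: render the modified chat template directly with transformers to
# confirm that repeated user turns are accepted (toy conversation, illustrative only).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/home/user/models/Mistral-Nemo-Instruct-2407")

# Same template string as in tokenizer_config.json above.
tokenizer.chat_template = (
    "{%- for message in messages %}{%- if message['role'] == 'system' -%}"
    "{{- message['content'] -}}{%- else -%}{%- if message['role'] == 'user' -%}"
    "{{-'[INST] ' + message['content'].rstrip() + ' [/INST]'-}}{%- else -%}"
    "{{-'' + message['content'] + '</s>' -}}{%- endif -%}{%- endif -%}"
    "{%- endfor -%}{%- if add_generation_prompt -%}{{-''-}}{%- endif -%}"
)

messages = [
    {"role": "user", "content": "first user turn"},
    {"role": "user", "content": "second user turn with the same role repeated"},
    {"role": "assistant", "content": "assistant reply"},
]

# tokenize=False returns the rendered string so it can be eyeballed.
print(tokenizer.apply_chat_template(messages, tokenize=False))

The rendered string should show both [INST] blocks back to back without raising an error.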

Config yaml

base_model: /home/user/models/Mistral-Nemo-Instruct-2407
model_type: AutoModelForCausalLM

train_on_inputs: false
group_by_length: false
load_in_8bit:
load_in_4bit: false
strict: false
sequence_len: 8192
bf16: auto
flash_attention: true

shuffle_merged_datasets: true

# Data
datasets:
  - path: /home/user/datasets/conversations-escaped.jsonl
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value

warmup_steps: 10
dataset_prepared_path: ./lora_last_run_prepared

# Iterations
num_epochs: 1
saves_per_epoch: 8
saves_total_limit: 8

# Evaluation
val_set_size: 0.0025
eval_max_new_tokens: 128
eval_sample_packing: false
evals_per_epoch: 8
eval_table_size:

# LoRA
output_dir: ./lora_out1
adapter: lora
lora_model_dir:
lora_r: 64
lora_alpha: 128
lora_dropout: 0.05
lora_target_linear: true
loraplus_lr_ratio: 16
save_safetensors: true

# Sampling
sample_packing: true
pad_to_sequence_len: true

# Batching
gradient_accumulation_steps: 16
micro_batch_size: 1
gradient_checkpointing: unsloth

# wandb
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: mistral-nemo
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: nemo-v1.3-8192
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

# Optimizer
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

# Misc
auto_resume_from_checkpoints: true
logging_steps: 1
weight_decay: 0.0

special_tokens:
  pad_token: <pad>

# Multi-GPU
deepspeed:
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.11

axolotl branch-commit

main/db51a9e4

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
Nero10578 added the bug label Nov 21, 2024
@NanoCode012
Collaborator

Hey, thanks for the report.

Could also be because of changes to using chat_templates?

Are you able to test one with chat_template before the GA patch (making sure you install the old requirements)?

@winglian
Collaborator

@Nero10578 btw, I see you're using FSDP, how many GPUs? thanks

@winglian
Collaborator

It could also be that the grad norm is higher because the model doesn't expect the repeated roles

@Nero10578
Contributor Author

Hey, thanks for the report.

Could also be because of changes to using chat_templates?

Are you able to test one with chat_template before the GA patch (making sure you install the old requirements)?

Will have to test, but I suspect this is the source of the issue.

@Nero10578 btw, I see you're using FSDP, how many GPUs? thanks

I am using 2x3090Ti with FSDP with CPU offloading.
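
For reference, the effective batch size that setup gives per optimizer step (just arithmetic over the config values above, written out as a sketch):

# Effective batch size per optimizer step with the config above:
# micro_batch_size * gradient_accumulation_steps * number of GPUs.
micro_batch_size = 1
gradient_accumulation_steps = 16
num_gpus = 2  # 2x RTX 3090 Ti

print(micro_batch_size * gradient_accumulation_steps * num_gpus)  # 32 packed sequences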

It could also be that the grad norm is higher because the model doesn't expect the repeated roles

I don't think so because the dataset didn't change and it was fine before.

@Nero10578
Contributor Author

Nero10578 commented Nov 26, 2024

So I tested it out with the commit before the GA patch (718cfb2) and with the latest commit (724b660). These are the differences in grad_norm:

Before GA patch:
[W&B grad_norm chart]

After GA Patch:
[W&B grad_norm chart]

The thing is, now that I have let both run, the GA patch does make the loss better, as can be seen in how the first eval shows a lower loss after the GA patch.

Before GA patch:
[W&B eval loss chart]

After GA Patch:
[W&B eval loss chart]
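
As I understand it, the GA patch changes how the accumulated loss is normalized: the old behaviour averaged each micro-batch over its own tokens and divided by the number of accumulation steps, while the patched behaviour normalizes once by the total token count across the whole accumulation window. A rough sketch of the difference (illustrative only, not the actual axolotl/transformers code; function names and numbers are made up):

# Rough sketch of the two normalizations being compared (illustrative only).
# Each inner list holds the per-token losses of one micro-batch inside a
# single accumulation window.

def loss_before_ga_patch(micro_batch_token_losses, grad_accum_steps):
    # Old behaviour: mean over each micro-batch's own tokens, then divide by
    # the number of accumulation steps. Short micro-batches get over-weighted.
    total = 0.0
    for token_losses in micro_batch_token_losses:
        total += (sum(token_losses) / len(token_losses)) / grad_accum_steps
    return total

def loss_after_ga_patch(micro_batch_token_losses):
    # Patched behaviour: sum the token losses across the whole window and
    # divide once by the total token count, so every token weighs the same.
    total_loss = sum(sum(t) for t in micro_batch_token_losses)
    total_tokens = sum(len(t) for t in micro_batch_token_losses)
    return total_loss / total_tokens

# Toy example: a short micro-batch with higher per-token loss gets
# over-weighted by the old scheme.
batches = [[2.0] * 10, [1.0] * 1000]
print(loss_before_ga_patch(batches, grad_accum_steps=2))  # 1.5
print(loss_after_ga_patch(batches))                       # ~1.01

When micro-batch token counts differ, the two schemes weight tokens differently, which also changes the scale of the gradients and therefore the reported grad_norm.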

I am using a different config with Liger kernels enabled:

base_model: /home/user/models/Mistral-Nemo-Instruct-2407
model_type: AutoModelForCausalLM

train_on_inputs: false
group_by_length: false
load_in_8bit:
load_in_4bit: false
strict: false
sequence_len: 8192
bf16: auto
flash_attention: true

shuffle_merged_datasets: true

# Data
datasets:
  - path: /home/user/datasets/conversations-escaped.jsonl
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value

warmup_steps: 10
dataset_prepared_path: ./lora_last_run_prepared

# Iterations
num_epochs: 1
saves_per_epoch: 8
saves_total_limit: 8

# Evaluation
val_set_size: 0.0025
eval_max_new_tokens: 128
eval_sample_packing: false
evals_per_epoch: 8
eval_table_size:

# LoRA
output_dir: ./lora_out3
adapter: lora
lora_model_dir:
lora_r: 64
lora_alpha: 128
lora_dropout: 0.05
lora_target_linear: true
peft_use_rslora: false
loraplus_lr_ratio: 16
save_safetensors: true

# Sampling
sample_packing: true
pad_to_sequence_len: true

# Batching
gradient_accumulation_steps: 16
micro_batch_size: 1
gradient_checkpointing: unsloth

# wandb
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: mistral-nemo-v1
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: loraplus-nemo--8192
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

# Optimizer
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

# Misc
auto_resume_from_checkpoints: true
logging_steps: 1
weight_decay: 0.0

special_tokens:
  pad_token: <pad>

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

# Multi-GPU
deepspeed:
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD

@fblgit

fblgit commented Dec 12, 2024

The high grad_norm is because there is a big mismatch between whatever the model was originally fitted on and what is being presented to it now.
This could be a wrong pad/eos/bos, or it could simply be the entire template change that OP is performing 😄
I guess that changing the template with train_on_inputs: false may not be ideal, right?
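
One quick way to rule out the pad/eos/bos angle (sketch only; the model path is copied from OP's config):

# Sketch: print what the tokenizer actually resolves for the special tokens,
# to rule out a pad/eos/bos mismatch (model path taken from the config above).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/home/user/models/Mistral-Nemo-Instruct-2407")
print("bos:", tok.bos_token, tok.bos_token_id)
print("eos:", tok.eos_token, tok.eos_token_id)
print("pad:", tok.pad_token, tok.pad_token_id)
print("vocab size:", len(tok))

# If the <pad> from `special_tokens` is not already in the vocab, it gets added
# as a brand-new token, which matters when the embedding rows stay frozen in a
# LoRA run.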

@fizzAI

fizzAI commented Dec 27, 2024

[W&B chart]

FWIW, a Mistral Nemo LoRA trained fine for me on 1x GPU, but I'm still seeing people occasionally complain about this being a bug. Possibly a multi-GPU issue? It seems to persist across FSDP and DeepSpeed though.
