From f76ed4e57e57ebfd0ee2c7c7bb7bc13320ab70d6 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 28 Oct 2024 20:42:21 +0700 Subject: [PATCH] feat: add test dpo lora multi-gpu --- tests/e2e/multigpu/test_llama.py | 69 ++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 957a6a9e36..5329b9f919 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -144,6 +144,75 @@ def test_lora_ddp_packed(self, temp_dir): ] ) + @with_temp_dir + def test_dpo_lora_ddp(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama_v1.1", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 2048, + "sample_packing": True, + "eval_sample_packing": False, + "pad_to_sequence_len": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.05, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + }, + ], + "num_epochs": 1, + "max_steps": 50, + "micro_batch_size": 4, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) + @with_temp_dir def test_fsdp(self, temp_dir): # pylint: disable=duplicate-code