From bb83f46daf49af0525f45159e19b92812bd9142b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 10 Dec 2024 16:21:13 -0500 Subject: [PATCH 1/7] [feature] sweeps --- src/axolotl/cli/main.py | 88 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 10 deletions(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 14803e43b..9f92852f1 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,9 +1,15 @@ """CLI definition for various axolotl commands.""" # pylint: disable=redefined-outer-name +import logging import subprocess # nosec B404 +import tempfile +from copy import deepcopy +from itertools import product +from pathlib import Path from typing import Optional import click +import yaml import axolotl from axolotl.cli.utils import ( @@ -17,6 +23,33 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig +def generate_sweep_configs(base_config, sweeps_config): + """ + Recursively generates all possible configurations by applying sweeps to the base config. + + Args: + base_config (dict): The original configuration dictionary + sweeps_config (dict): Dictionary where keys are paths to parameters and values are lists of values to sweep + + Returns: + list: List of all possible configuration dictionaries + """ + # Get all parameter combinations + param_names = list(sweeps_config.keys()) + param_values = list(sweeps_config.values()) + all_combinations = list(product(*param_values)) + + # Generate a new config for each combination + result_configs = [] + for combination in all_combinations: + new_config = deepcopy(base_config) + for param_name, param_value in zip(param_names, combination): + new_config = new_config[param_name] = param_value + result_configs.append(new_config) + + return result_configs + + @click.group() @click.version_option(version=axolotl.__version__, prog_name="axolotl") def cli(): @@ -43,25 +76,60 @@ def preprocess(config: str, **kwargs): default=True, help="Use accelerate launch for multi-GPU training", ) +@click.option( + "--sweep", + type=click.Path(exists=True, path_type=str), + help="YAML config for sweeping hyperparameters", +) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def train(config: str, accelerate: bool, **kwargs): +def train(config: str, accelerate: bool, sweep: Optional[str] = None, **kwargs): """Train or fine-tune a model.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"] - if config: - base_cmd.append(config) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 + if sweep: + # load the sweep configuration yaml file + with open(sweep, "r", encoding="utf-8") as fin: + sweep_config: dict[str, list] = yaml.safe_load(fin) + with open(config, "r", encoding="utf-8") as fin: + base_config: dict[str, list] = yaml.safe_load(fin) + + # generate all possible configurations + permutations = generate_sweep_configs(base_config, sweep_config) + + def iter_configs(): + for perm in permutations: + # open temp directory for temporary configurations + with tempfile.TemporaryDirectory() as temp_dir: + with open( + Path(temp_dir) / "config.yaml", "w", encoding="utf-8" + ) as fout: + yaml.dump(perm, fout) + yield str(Path(temp_dir) / "config.yaml") else: - from axolotl.cli.train import do_cli - - do_cli(config=config, **kwargs) + def iter_configs(): + yield config + + for cfg_file in iter_configs(): + # handle errors from subprocess so we can continue rest of sweeps + try: + if accelerate: + base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"] + if cfg_file: + base_cmd.append(cfg_file) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + else: + from axolotl.cli.train import do_cli + + do_cli(config=cfg_file, **kwargs) + except subprocess.CalledProcessError as exc: + logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") + if not sweep: + raise exc @cli.command() From cb70be0c1a30cea8d73f18b8f50db862d57179fc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 10 Dec 2024 16:46:56 -0500 Subject: [PATCH 2/7] fix syntax --- src/axolotl/cli/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 9f92852f1..b1e6c3ba0 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -44,7 +44,7 @@ def generate_sweep_configs(base_config, sweeps_config): for combination in all_combinations: new_config = deepcopy(base_config) for param_name, param_value in zip(param_names, combination): - new_config = new_config[param_name] = param_value + new_config[param_name] = param_value result_configs.append(new_config) return result_configs From 5b7c9af9d48643910729ced14027044cdaa868f5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 10 Dec 2024 19:56:31 -0500 Subject: [PATCH 3/7] randomize the order of trials --- src/axolotl/cli/main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index b1e6c3ba0..64bbd655b 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,6 +1,7 @@ """CLI definition for various axolotl commands.""" # pylint: disable=redefined-outer-name import logging +import random import subprocess # nosec B404 import tempfile from copy import deepcopy @@ -39,6 +40,10 @@ def generate_sweep_configs(base_config, sweeps_config): param_values = list(sweeps_config.values()) all_combinations = list(product(*param_values)) + # randomize the order of trials + random.seed(42) + random.shuffle(all_combinations) + # Generate a new config for each combination result_configs = [] for combination in all_combinations: From bd6d5230e6e44fee4d2718ded83df0a9624763c2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 11 Dec 2024 08:15:05 -0500 Subject: [PATCH 4/7] add support for paired key sweeps --- src/axolotl/cli/main.py | 39 ++++++++++++++++++--- tests/cli/test_cli_sweeps.py | 68 ++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 tests/cli/test_cli_sweeps.py diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 64bbd655b..0851eecee 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -30,15 +30,44 @@ def generate_sweep_configs(base_config, sweeps_config): Args: base_config (dict): The original configuration dictionary - sweeps_config (dict): Dictionary where keys are paths to parameters and values are lists of values to sweep + sweeps_config (dict): Dictionary where keys are parameters and values are either: + - lists of values to sweep independently + - or for paired values, a list of dicts under the '_' key Returns: list: List of all possible configuration dictionaries + + Example: + sweeps_config = { + 'learning_rate': [0.1, 0.01], + '_': [ + {'load_in_8bit': True, 'adapter': 'lora'}, + {'load_in_4bit': True, 'adapter': 'qlora'} + ] + } """ - # Get all parameter combinations - param_names = list(sweeps_config.keys()) - param_values = list(sweeps_config.values()) - all_combinations = list(product(*param_values)) + # Separate paired values from regular sweeps + paired_values = sweeps_config.get("_", []) + regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"} + + # Process regular sweeps + param_names = list(regular_sweeps.keys()) + param_values = list(regular_sweeps.values()) + + # Generate combinations for regular sweeps + regular_combinations = list(product(*param_values)) if param_values else [()] + + # Combine regular sweeps with paired values + all_combinations = [] + for reg_combo in regular_combinations: + if paired_values: + for paired_set in paired_values: + # Combine regular parameters with paired parameters + full_combo = {**dict(zip(param_names, reg_combo)), **paired_set} + all_combinations.append(full_combo) + else: + # If no paired values, just use regular combinations + all_combinations.append(dict(zip(param_names, reg_combo))) # randomize the order of trials random.seed(42) diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py new file mode 100644 index 000000000..499973af6 --- /dev/null +++ b/tests/cli/test_cli_sweeps.py @@ -0,0 +1,68 @@ +""" +unit tests for generating sweep configurations +""" +from axolotl.cli.main import generate_sweep_configs + + +def test_generate_sweep_configs_no_pairs(): + base_config = { + "learning_rate": 0.1, + "micro_batch_size": 1, + "sample_packing": True, + } + + sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]} + + generate_sweep_configs(base_config, sweeps_config) + + assert len(generate_sweep_configs(base_config, sweeps_config)) == 6 + + cfg_1 = { + "learning_rate": 0.1, + "micro_batch_size": 2, + "weight_decay": 0.0, + "sample_packing": True, + } + + assert any( + cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config) + ) + + +def test_generate_sweep_configs_with_pairs(): + base_config = { + "learning_rate": 0.1, + "micro_batch_size": 1, + "sample_packing": True, + } + + sweeps_config = { + "_": [ + { + "micro_batch_size": 1, + "gradient_accumulation_Steps": 8, + }, + { + "micro_batch_size": 2, + "gradient_accumulation_Steps": 4, + }, + { + "micro_batch_size": 4, + "gradient_accumulation_Steps": 2, + }, + { + "micro_batch_size": 8, + "gradient_accumulation_Steps": 1, + }, + ], + "weight_decay": [0.0, 0.1], + } + + generate_sweep_configs(base_config, sweeps_config) + + assert len(generate_sweep_configs(base_config, sweeps_config)) == 8 + + assert all( + cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8 + for cfg in generate_sweep_configs(base_config, sweeps_config) + ) From d5695665b65cb3e8d273b603ec80e2a25c3b2088 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 11 Dec 2024 09:57:40 -0500 Subject: [PATCH 5/7] fix typos --- tests/cli/test_cli_sweeps.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py index 499973af6..61c886e80 100644 --- a/tests/cli/test_cli_sweeps.py +++ b/tests/cli/test_cli_sweeps.py @@ -40,19 +40,19 @@ def test_generate_sweep_configs_with_pairs(): "_": [ { "micro_batch_size": 1, - "gradient_accumulation_Steps": 8, + "gradient_accumulation_steps": 8, }, { "micro_batch_size": 2, - "gradient_accumulation_Steps": 4, + "gradient_accumulation_steps": 4, }, { "micro_batch_size": 4, - "gradient_accumulation_Steps": 2, + "gradient_accumulation_steps": 2, }, { "micro_batch_size": 8, - "gradient_accumulation_Steps": 1, + "gradient_accumulation_steps": 1, }, ], "weight_decay": [0.0, 0.1], From ecca813a5a851786147c474aea5cd8adac40a07c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 11 Dec 2024 10:50:43 -0500 Subject: [PATCH 6/7] fix the sweep logic --- src/axolotl/cli/main.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 0851eecee..881b119a7 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -62,12 +62,22 @@ def generate_sweep_configs(base_config, sweeps_config): for reg_combo in regular_combinations: if paired_values: for paired_set in paired_values: + new_config = {} + # new_config = deepcopy(base_config) # Combine regular parameters with paired parameters full_combo = {**dict(zip(param_names, reg_combo)), **paired_set} - all_combinations.append(full_combo) + for param_name, param_value in full_combo.items(): + new_config[param_name] = param_value + print(new_config) + all_combinations.append(new_config) else: # If no paired values, just use regular combinations - all_combinations.append(dict(zip(param_names, reg_combo))) + # new_config = deepcopy(base_config) + new_config = {} + for param_name, param_value in zip(param_names, reg_combo): + new_config[param_name] = param_value + print(new_config) + all_combinations.append(new_config) # randomize the order of trials random.seed(42) @@ -77,7 +87,7 @@ def generate_sweep_configs(base_config, sweeps_config): result_configs = [] for combination in all_combinations: new_config = deepcopy(base_config) - for param_name, param_value in zip(param_names, combination): + for param_name, param_value in combination.items(): new_config[param_name] = param_value result_configs.append(new_config) From cca87615ba2b8e5fbb1f7b14c258e3365c9cf6b5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Dec 2024 07:34:18 -0500 Subject: [PATCH 7/7] chore: lint (merge conflict) --- src/axolotl/cli/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 881b119a7..439b9fe44 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -153,7 +153,9 @@ def iter_configs(): ) as fout: yaml.dump(perm, fout) yield str(Path(temp_dir) / "config.yaml") + else: + def iter_configs(): yield config