diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 14803e43b..439b9fe44 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -1,9 +1,16 @@
 """CLI definition for various axolotl commands."""
 # pylint: disable=redefined-outer-name
+import logging
+import random
 import subprocess  # nosec B404
+import tempfile
+from copy import deepcopy
+from itertools import product
+from pathlib import Path
 from typing import Optional
 
 import click
+import yaml
 
 import axolotl
 from axolotl.cli.utils import (
@@ -17,6 +24,67 @@
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
 
+def generate_sweep_configs(base_config, sweeps_config):
+    """
+    Generates all possible configurations by applying sweeps to the base config.
+
+    Args:
+        base_config (dict): The original configuration dictionary
+        sweeps_config (dict): Dictionary where keys are parameters and values are either:
+            - lists of values to sweep independently
+            - or, for values that must vary together, a list of dicts under the '_' key
+
+    Returns:
+        list: List of all possible configuration dictionaries
+
+    Example:
+        sweeps_config = {
+            'learning_rate': [0.1, 0.01],
+            '_': [
+                {'load_in_8bit': True, 'adapter': 'lora'},
+                {'load_in_4bit': True, 'adapter': 'qlora'}
+            ]
+        }
+    """
+    # Separate paired values from regular sweeps
+    paired_values = sweeps_config.get("_", [])
+    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
+
+    # Process regular sweeps
+    param_names = list(regular_sweeps.keys())
+    param_values = list(regular_sweeps.values())
+
+    # Generate the cartesian product of the regular sweep values
+    regular_combinations = list(product(*param_values)) if param_values else [()]
+
+    # Combine regular sweeps with paired values
+    all_combinations = []
+    for reg_combo in regular_combinations:
+        if paired_values:
+            for paired_set in paired_values:
+                # Merge the regular parameters with each paired set
+                new_config = {**dict(zip(param_names, reg_combo)), **paired_set}
+                all_combinations.append(new_config)
+        else:
+            # No paired values, so just use the regular combinations
+            new_config = dict(zip(param_names, reg_combo))
+            all_combinations.append(new_config)
+
+    # Shuffle the trial order, seeded for reproducibility
+    random.seed(42)
+    random.shuffle(all_combinations)
+
+    # Apply each combination as overrides on top of a deep copy of the base config
+    result_configs = []
+    for combination in all_combinations:
+        new_config = deepcopy(base_config)
+        for param_name, param_value in combination.items():
+            new_config[param_name] = param_value
+        result_configs.append(new_config)
+
+    return result_configs
+
+
 @click.group()
 @click.version_option(version=axolotl.__version__, prog_name="axolotl")
 def cli():
@@ -43,25 +111,62 @@ def preprocess(config: str, **kwargs):
     default=True,
     help="Use accelerate launch for multi-GPU training",
 )
+@click.option(
+    "--sweep",
+    type=click.Path(exists=True, path_type=str),
+    help="YAML config for sweeping hyperparameters",
+)
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-def train(config: str, accelerate: bool, **kwargs):
+def train(config: str, accelerate: bool, sweep: Optional[str] = None, **kwargs):
     """Train or fine-tune a model."""
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
     # Enable expandable segments for cuda allocation to improve VRAM usage
     set_pytorch_cuda_alloc_conf()
 
-    if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
-        if config:
-            base_cmd.append(config)
-        cmd = build_command(base_cmd, kwargs)
-        subprocess.run(cmd, check=True)  # nosec B603
+    if sweep:
+        # Load the sweep and base configuration YAML files
+        with open(sweep, "r", encoding="utf-8") as fin:
+            sweep_config: dict[str, list] = yaml.safe_load(fin)
+        with open(config, "r", encoding="utf-8") as fin:
+            base_config: dict = yaml.safe_load(fin)
+
+        # Generate all possible configurations
+        permutations = generate_sweep_configs(base_config, sweep_config)
+
+        def iter_configs():
+            for perm in permutations:
+                # Write each generated config to a temporary directory
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    with open(
+                        Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
+                    ) as fout:
+                        yaml.dump(perm, fout)
+                    yield str(Path(temp_dir) / "config.yaml")
+
     else:
-        from axolotl.cli.train import do_cli
 
-        do_cli(config=config, **kwargs)
+        def iter_configs():
+            yield config
+
+    for cfg_file in iter_configs():
+        # Handle errors from the subprocess so the remaining sweeps can continue
+        try:
+            if accelerate:
+                base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
+                if cfg_file:
+                    base_cmd.append(cfg_file)
+                cmd = build_command(base_cmd, kwargs)
+                subprocess.run(cmd, check=True)  # nosec B603
+            else:
+                from axolotl.cli.train import do_cli
+
+                do_cli(config=cfg_file, **kwargs)
+        except subprocess.CalledProcessError as exc:
+            logging.error("Failed to train/fine-tune config '%s': %s", cfg_file, exc)
+            if not sweep:
+                raise
 
 
 @cli.command()
diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py
new file mode 100644
index 000000000..61c886e80
--- /dev/null
+++ b/tests/cli/test_cli_sweeps.py
@@ -0,0 +1,69 @@
+"""
+Unit tests for generating sweep configurations
+"""
+from axolotl.cli.main import generate_sweep_configs
+
+
+def test_generate_sweep_configs_no_pairs():
+    base_config = {
+        "learning_rate": 0.1,
+        "micro_batch_size": 1,
+        "sample_packing": True,
+    }
+
+    sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
+
+    configs = generate_sweep_configs(base_config, sweeps_config)
+
+    # 3 micro batch sizes x 2 weight decays
+    assert len(configs) == 6
+
+    cfg_1 = {
+        "learning_rate": 0.1,
+        "micro_batch_size": 2,
+        "weight_decay": 0.0,
+        "sample_packing": True,
+    }
+
+    assert any(cfg_1 == cfg for cfg in configs)
+
+
+def test_generate_sweep_configs_with_pairs():
+    base_config = {
+        "learning_rate": 0.1,
+        "micro_batch_size": 1,
+        "sample_packing": True,
+    }
+
+    sweeps_config = {
+        "_": [
+            {
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 8,
+            },
+            {
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 4,
+            },
+            {
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+            },
+            {
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+            },
+        ],
+        "weight_decay": [0.0, 0.1],
+    }
+
+    configs = generate_sweep_configs(base_config, sweeps_config)
+
+    # 4 paired settings x 2 weight decays
+    assert len(configs) == 8
+
+    # each paired setting keeps the effective batch size constant
+    assert all(
+        cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
+        for cfg in configs
+    )
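Usage note (a sketch, not part of the diff; file names and hyperparameter values below are illustrative): top-level keys in the sweep YAML are swept independently via a cartesian product, while the dicts listed under `_` vary together, so a hypothetical sweep file like

    # sweep.yaml -- example values, not from this PR
    learning_rate: [0.0001, 0.0002]
    _:
      - micro_batch_size: 1
        gradient_accumulation_steps: 8
      - micro_batch_size: 2
        gradient_accumulation_steps: 4

expands to 2 x 2 = 4 configs when invoked as `axolotl train config.yaml --sweep sweep.yaml`. Each generated config is dumped to a temporary config.yaml and launched as its own training run; with `--accelerate` (the default), a failing run raises `subprocess.CalledProcessError`, which is logged so the remaining sweep entries still execute.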