more checks using helper
winglian committed Jan 10, 2025
1 parent 1d9d237 commit f80abe4
Showing 8 changed files with 19 additions and 27 deletions.
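The diffs below replace per-test assertions like assert (Path(temp_dir) / "pytorch_model.bin").exists() with a shared check_model_output_exists helper, so individual tests no longer hard-code which weights file a training run should produce. The helper itself is not part of this commit; the following is a minimal sketch of what a config-aware version could look like, assuming it lives in tests/e2e/utils.py — the candidate file names are assumptions based on common Hugging Face save formats, not the repository's actual code:

# Hypothetical sketch of the shared helper -- the real implementation in
# tests/e2e/utils.py may differ. Candidate file names are assumptions
# based on common Hugging Face save formats.
from pathlib import Path


def check_model_output_exists(output_dir, cfg):
    """Assert that training wrote model weights into output_dir."""
    output_dir = Path(output_dir)
    if cfg.adapter:  # LoRA/QLoRA runs save an adapter rather than full weights
        candidates = ["adapter_model.safetensors", "adapter_model.bin"]
    else:  # full fine-tunes save the whole model
        candidates = ["model.safetensors", "pytorch_model.bin"]
    assert any(
        (output_dir / name).exists() for name in candidates
    ), f"no model output found in {output_dir}"

Centralizing the check means that when the save format changes (for example, safetensors replacing pytorch_model.bin as the default), only the helper has to be updated instead of every end-to-end test.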
5 changes: 2 additions & 3 deletions tests/e2e/patched/test_fused_llama.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import unittest
-from pathlib import Path
 
 import pytest
 from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -73,4 +72,4 @@ def test_fft_packing(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
3 changes: 1 addition & 2 deletions tests/e2e/patched/test_phi_multipack.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import unittest
-from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -69,7 +68,7 @@ def test_ft_packed(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
 
     @with_temp_dir
     def test_qlora_packed(self, temp_dir):
16 changes: 8 additions & 8 deletions tests/e2e/test_dpo.py
@@ -15,7 +15,7 @@
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +68,7 @@ def test_dpo_lora(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
 
     @with_temp_dir
     def test_dpo_nll_lora(self, temp_dir):
@@ -113,7 +113,7 @@ def test_dpo_nll_lora(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
 
     @with_temp_dir
     def test_dpo_use_weighting(self, temp_dir):
@@ -158,7 +158,7 @@ def test_dpo_use_weighting(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
 
     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
@@ -203,7 +203,7 @@ def test_kto_pair_lora(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
 
     @with_temp_dir
     def test_ipo_lora(self, temp_dir):
@@ -247,7 +247,7 @@ def test_ipo_lora(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
 
     @with_temp_dir
     def test_orpo_lora(self, temp_dir):
@@ -294,7 +294,7 @@ def test_orpo_lora(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
 
     @pytest.mark.skip(reason="Fix the implementation")
     @with_temp_dir
@@ -358,4 +358,4 @@ def test_kto_lora(self, temp_dir):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
3 changes: 1 addition & 2 deletions tests/e2e/test_falcon.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import unittest
-from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -163,4 +162,4 @@ def test_ft(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
5 changes: 2 additions & 3 deletions tests/e2e/test_mamba.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import unittest
-from pathlib import Path
 
 import pytest
 
@@ -15,7 +14,7 @@
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ def test_fft(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
3 changes: 1 addition & 2 deletions tests/e2e/test_mistral.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import unittest
-from pathlib import Path
 
 from transformers.utils import is_torch_bf16_gpu_available
 
@@ -112,4 +111,4 @@ def test_ft(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
3 changes: 1 addition & 2 deletions tests/e2e/test_phi.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import unittest
-from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -67,7 +66,7 @@ def test_phi_ft(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
 
     @with_temp_dir
     def test_phi_qlora(self, temp_dir):
8 changes: 3 additions & 5 deletions tests/e2e/test_relora_llama.py
@@ -13,7 +13,7 @@
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -78,10 +78,8 @@ def test_relora(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (
-            Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
-        ).exists()
-        assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-100/relora", cfg)
 
         check_tensorboard(
             temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
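As the hunks above show, full fine-tunes point the helper at the run's output root, adapter-style runs point it at a step-numbered checkpoint directory, and the ReLoRA test checks both its adapter and full-weight subdirectories. In terms of the hypothetical sketch above, the two call patterns are:

# Full fine-tune: weights are written to the output root.
check_model_output_exists(temp_dir, cfg)

# LoRA-style run: weights land in a step-numbered checkpoint directory.
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)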
