feat: add support for data_files in pretraining (#2238)
NanoCode012 authored Jan 9, 2025
1 parent 7669a03 commit ed77e70
Showing 2 changed files with 7 additions and 1 deletion.
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (1 change: 1 addition & 0 deletions)
@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
     text_column: Optional[str] = "text"
     type: Optional[str] = "pretrain"
     trust_remote_code: Optional[bool] = False
+    data_files: Optional[str] = None


 class UserDefinedPrompterType(BaseModel):
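
For reference, a minimal sketch of what the extended config model now accepts. Only the fields visible in this hunk are reproduced; the real PretrainingDataset model in axolotl defines additional fields (dataset path, name, etc.), so this is an illustration rather than the actual class.

from typing import Optional

from pydantic import BaseModel


# Sketch only: reproduces just the fields visible in the hunk above; the real
# axolotl model has more fields that are omitted here.
class PretrainingDatasetSketch(BaseModel):
    text_column: Optional[str] = "text"
    type: Optional[str] = "pretrain"
    trust_remote_code: Optional[bool] = False
    data_files: Optional[str] = None  # new in this commit


# A pretraining dataset entry can now carry a data_files value; the filename
# below is a made-up example. Leaving the field unset keeps the old behaviour.
entry = PretrainingDatasetSketch(data_files="data/train-00000-of-00010.parquet")
print(entry.data_files)

Because the field is typed Optional[str], it carries a single file path or glob pattern rather than a list.
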
src/axolotl/utils/data/sft.py (7 changes: 6 additions & 1 deletion)
@@ -88,6 +88,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
         path = cfg.pretraining_dataset
         split = "train"
         name = None
+        data_files = None
         if isinstance(cfg.pretraining_dataset, list) and isinstance(
             cfg.pretraining_dataset[0], dict
         ):
@@ -96,6 +97,8 @@
             if "split" in cfg.pretraining_dataset[0]:
                 split = cfg.pretraining_dataset[0]["split"]

+            data_files = cfg.pretraining_dataset[0].get("data_files")
+
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
@@ -105,7 +108,9 @@
         )

         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split=split, name=name),
+            load_dataset(
+                path, streaming=True, split=split, name=name, data_files=data_files
+            ),
             tokenizer,
             cfg,
             ds_wrapper_partial,
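
Downstream, the new keyword is simply forwarded to the Hugging Face datasets loader. A hedged illustration of what that call does in streaming mode; the repository name and file pattern below are placeholders, not values from this commit:

from datasets import load_dataset

# Illustration only: "username/my-pretrain-corpus" and the glob are placeholders.
streamed = load_dataset(
    "username/my-pretrain-corpus",    # hypothetical Hub dataset repo
    split="train",
    streaming=True,
    data_files="data/train-*.jsonl",  # restrict streaming to matching shards
)

# Iterate a few examples without materialising the whole corpus.
for i, example in enumerate(streamed):
    print(example)
    if i == 2:
        break

Since data_files defaults to None both in the new config field and in load_dataset, configs that do not set it keep the previous behaviour of streaming the entire dataset.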
