Commit 0630baa

this should work?

SalmanMohammadi committed Jan 8, 2025
1 parent f81b174 commit 0630baa
Showing 3 changed files with 26 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/axolotl/datasets.py
@@ -52,6 +52,8 @@ def process(self, dataset):
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 100
+        import pdb
+        pdb.set_trace()
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
@@ -5,15 +5,16 @@

 from itertools import chain

-from typing import Dict, List, Optional, Union
+from typing import Dict, Generator, List, Optional, Union

-from transformers import BatchEncoding
+from transformers import BatchEncoding, PreTrainedTokenizer

 from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
 from axolotl.prompters import Prompter
+from axolotl.utils.dict import DictDefault


-class StepwiseSupervisedPromptTokenizingStrategy(PromptTokenizingStrategy):
+class StepwiseSupervisedPromptTokenizingStrategy:
     """
     Tokenizing strategy for supervised stepwise datasets, typically used for CoT reasoning.
     These datasets should include the following columns:
@@ -24,7 +25,6 @@ class StepwiseSupervisedPromptTokenizingStrategy(PromptTokenizingStrategy):

     def __init__(
         self,
-        prompter: Prompter,
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
@@ -33,7 +33,9 @@ def __init__(
         train_on_last_step_only: bool = False,
         is_eval: bool = False,
     ):
-        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
+        self.tokenizer = tokenizer
+        self.train_on_inputs = train_on_inputs
+        self.sequence_len = sequence_len
         self.step_separator = step_separator
         self.max_completion_length = max_completion_length
         self.train_on_last_step_only = train_on_last_step_only
@@ -42,6 +44,8 @@ def __init__(
     def tokenize_prompt(
         self, prompt: Dict[str, Union[str, List[str]]]
     ) -> BatchEncoding:
+        # Inspired by TRL's PRMTrainer:
+        # https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
         prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
             "input_ids"
         ]
@@ -98,3 +102,17 @@ def tokenize_prompt(
                 "attention_mask": [1] * len(input_ids),
             }
         )
+
+    @property
+    def supports_batched(self):
+        return False
+
+
+def load(
+    tokenizer: PreTrainedTokenizer, cfg: DictDefault
+) -> StepwiseSupervisedPromptTokenizingStrategy:
+    return StepwiseSupervisedPromptTokenizingStrategy(
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
9 changes: 1 addition & 8 deletions src/axolotl/utils/data/sft.py
@@ -6,9 +6,9 @@
 from typing import List, Tuple, Union

 from datasets import (
-    concatenate_datasets,
     Dataset,
     DatasetDict,
+    concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
@@ -456,13 +456,6 @@ def get_dataset_wrapper(
             dataset,
             **ds_kwargs,
         )
-    elif ds_strategy := config_dataset.type == "stepwise_supervised":
-        dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy,
-            dataset,
-            **ds_kwargs,
-        )
     elif ds_strategy := load(
         config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
     ):
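
For orientation, a minimal usage sketch of the strategy this commit wires up. Only load(tokenizer, cfg), the supports_batched property, and the prompt["prompt"] field are taken from the diff above; the module path, the tokenizer checkpoint, the cfg values, and the "completions"/"labels" column names (borrowed from the TRL PRM dataset format that the in-code comment references) are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch -- not part of this commit.
from transformers import AutoTokenizer

from axolotl.prompt_strategies.stepwise_supervised import load  # module path assumed
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer; checkpoint is illustrative
cfg = DictDefault({"train_on_inputs": False, "sequence_len": 2048})

# load() as added here takes (tokenizer, cfg) and forwards cfg.train_on_inputs
# and cfg.sequence_len positionally to the strategy.
strategy = load(tokenizer, cfg)
assert strategy.supports_batched is False  # property added in this commit

# One stepwise (process-reward style) record: a prompt plus per-step
# completions and per-step correctness labels (column names assumed).
record = {
    "prompt": "What is 13 * 7?",
    "completions": ["13 * 7 = (10 + 3) * 7 = 70 + 21.", "70 + 21 = 91."],
    "labels": [True, True],
}

encoded = strategy.tokenize_prompt(record)
print(list(encoded.keys()), sum(encoded["attention_mask"]))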
