diff --git a/src/instructlab/sdg/datamixing.py b/src/instructlab/sdg/datamixing.py index 1e9ab439..44e12741 100644 --- a/src/instructlab/sdg/datamixing.py +++ b/src/instructlab/sdg/datamixing.py @@ -262,41 +262,48 @@ def _add_extra_contexts_to_samples(ds: Dataset, p, num_doc_in_context=4): `keep_context_separate` equal to True. When this finishes, the `context` column is removed from the dataset and all context moved to the user messages. + + This is inspired by the concepts of Retrieval Augmented FineTuning (RAFT) + from https://arxiv.org/abs/2403.10131 """ - all_context = ds["context"] - all_context = [ - " ".join(e.split(" ")[: random.randint(100, 500)]) for e in all_context - ] - ds = ds.add_column("row_idx", range(ds.num_rows)) + all_context = list(set(ds["context"])) def __pick_documents(rec, p): - # Loop until we find enough other documents to add to the context - # for this document. Exit the loop early if we have fewer total - # documents than the number of documents we want in our context - # so that we don't end up looping forever. This handles edge - # cases where the number of generated instructions is very low, - # like in CI or user's testing small sizes. - while True: - selected_docs = random.choices(range(ds.num_rows), k=num_doc_in_context) - if ds.num_rows <= num_doc_in_context: - break - if rec["row_idx"] not in selected_docs: - break - if random.uniform(0, 1) < p: - docs = [ - all_context[idx] for idx in selected_docs[: num_doc_in_context - 1] - ] + [rec["context"]] - # rec['indicator'] ='golden' + answer_document = [rec["context"]] + selected_docs = [e for e in all_context if e != answer_document] + if len(selected_docs) > 0: + if len(selected_docs) < num_doc_in_context: + logger.debug( + f"Number of unique documents is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the expanded context." + ) + if random.uniform(0, 1) < p: + # golden/answer + distractor documents + docs = ( + random.sample(selected_docs, k=num_doc_in_context) + if len(selected_docs) >= num_doc_in_context + else selected_docs + [answer_document] + ) + else: + # distractor documents + docs = ( + random.sample(selected_docs, k=num_doc_in_context) + if len(selected_docs) >= num_doc_in_context + else selected_docs + ) else: - docs = [all_context[idx] for idx in selected_docs] - # rec['indicator'] = 'distractor' + logger.warning( + "Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results." + ) + docs = [answer_document] random.shuffle(docs) docs = "\n".join(([f"Document:\n{e}\n\n" for idx, e in enumerate(docs)])) - user_idx, user_msg = [ + user_idx_msgs = [ (idx, rec_msg) for idx, rec_msg in enumerate(rec["messages"]) if rec_msg["role"] == "user" - ][0] + ] + assert len(user_idx_msgs) > 0, "No user role found in dataset" + user_idx, user_msg = user_idx_msgs[0] user_inst = user_msg["content"] rec["messages"][user_idx]["content"] = f"{docs}\n\n{user_inst}" rec["messages"] = rec["messages"] diff --git a/tests/test_datamixing.py b/tests/test_datamixing.py new file mode 100644 index 00000000..48101390 --- /dev/null +++ b/tests/test_datamixing.py @@ -0,0 +1,62 @@ +""" +Unit tests for the top-level datamixing module. +""" + +# Third Party +from datasets import Dataset + +# First Party +from instructlab.sdg.datamixing import _add_extra_contexts_to_samples + + +def _fake_context(msg_id): + return { + "context": f"context {msg_id}", + "id": msg_id, + "messages": [{"role": "user", "content": f"user content {msg_id}"}], + "metadata": '{"dataset": []}', + } + + +def test_add_extra_contexts_to_samples_with_one_sample(): + """ + Test _add_extra_contexts_to_samples doesn't error out when + given only one sample + """ + samples = Dataset.from_list([_fake_context("abc123")]) + dataset = _add_extra_contexts_to_samples(samples, p=0.4) + assert len(dataset) == 1 + + +def test_add_extra_contexts_to_samples_with_two_samples(): + """ + Test _add_extra_contexts_to_samples doesn't error out when + given only two samples + """ + samples = Dataset.from_list( + [ + _fake_context("abc123"), + _fake_context("bcd234"), + ] + ) + dataset = _add_extra_contexts_to_samples(samples, p=0.4) + assert len(dataset) == 2 + + +def test_add_extra_contexts_to_samples_with_six_samples(): + """ + Test _add_extra_contexts_to_samples doesn't error out when + given more samples + """ + samples = Dataset.from_list( + [ + _fake_context("s1"), + _fake_context("s2"), + _fake_context("s3"), + _fake_context("s4"), + _fake_context("s5"), + _fake_context("s6"), + ] + ) + dataset = _add_extra_contexts_to_samples(samples, p=0.4) + assert len(dataset) == 6