Skip to content

Commit

Permalink
Merge pull request instructlab#215 from bbrowning/knowledge-contexts-…
Browse files Browse the repository at this point in the history
…fixes

Incorporate knowledge generation context selection improvements
  • Loading branch information
markmc authored Jul 25, 2024
2 parents 25018dc + 415d1a5 commit afadfd5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 26 deletions.
59 changes: 33 additions & 26 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,41 +262,48 @@ def _add_extra_contexts_to_samples(ds: Dataset, p, num_doc_in_context=4):
`keep_context_separate` equal to True. When this finishes, the `context`
column is removed from the dataset and all context moved to the user
messages.
This is inspired by the concepts of Retrieval Augmented FineTuning (RAFT)
from https://arxiv.org/abs/2403.10131
"""
all_context = ds["context"]
all_context = [
" ".join(e.split(" ")[: random.randint(100, 500)]) for e in all_context
]
ds = ds.add_column("row_idx", range(ds.num_rows))
all_context = list(set(ds["context"]))

def __pick_documents(rec, p):
# Loop until we find enough other documents to add to the context
# for this document. Exit the loop early if we have fewer total
# documents than the number of documents we want in our context
# so that we don't end up looping forever. This handles edge
# cases where the number of generated instructions is very low,
# like in CI or user's testing small sizes.
while True:
selected_docs = random.choices(range(ds.num_rows), k=num_doc_in_context)
if ds.num_rows <= num_doc_in_context:
break
if rec["row_idx"] not in selected_docs:
break
if random.uniform(0, 1) < p:
docs = [
all_context[idx] for idx in selected_docs[: num_doc_in_context - 1]
] + [rec["context"]]
# rec['indicator'] ='golden'
answer_document = [rec["context"]]
selected_docs = [e for e in all_context if e != answer_document]
if len(selected_docs) > 0:
if len(selected_docs) < num_doc_in_context:
logger.debug(
f"Number of unique documents is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the expanded context."
)
if random.uniform(0, 1) < p:
# golden/answer + distractor documents
docs = (
random.sample(selected_docs, k=num_doc_in_context)
if len(selected_docs) >= num_doc_in_context
else selected_docs + [answer_document]
)
else:
# distractor documents
docs = (
random.sample(selected_docs, k=num_doc_in_context)
if len(selected_docs) >= num_doc_in_context
else selected_docs
)
else:
docs = [all_context[idx] for idx in selected_docs]
# rec['indicator'] = 'distractor'
logger.warning(
"Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results."
)
docs = [answer_document]
random.shuffle(docs)
docs = "\n".join(([f"Document:\n{e}\n\n" for idx, e in enumerate(docs)]))
user_idx, user_msg = [
user_idx_msgs = [
(idx, rec_msg)
for idx, rec_msg in enumerate(rec["messages"])
if rec_msg["role"] == "user"
][0]
]
assert len(user_idx_msgs) > 0, "No user role found in dataset"
user_idx, user_msg = user_idx_msgs[0]
user_inst = user_msg["content"]
rec["messages"][user_idx]["content"] = f"{docs}\n\n{user_inst}"
rec["messages"] = rec["messages"]
Expand Down
62 changes: 62 additions & 0 deletions tests/test_datamixing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Unit tests for the top-level datamixing module.
"""

# Third Party
from datasets import Dataset

# First Party
from instructlab.sdg.datamixing import _add_extra_contexts_to_samples


def _fake_context(msg_id):
return {
"context": f"context {msg_id}",
"id": msg_id,
"messages": [{"role": "user", "content": f"user content {msg_id}"}],
"metadata": '{"dataset": []}',
}


def test_add_extra_contexts_to_samples_with_one_sample():
"""
Test _add_extra_contexts_to_samples doesn't error out when
given only one sample
"""
samples = Dataset.from_list([_fake_context("abc123")])
dataset = _add_extra_contexts_to_samples(samples, p=0.4)
assert len(dataset) == 1


def test_add_extra_contexts_to_samples_with_two_samples():
"""
Test _add_extra_contexts_to_samples doesn't error out when
given only two samples
"""
samples = Dataset.from_list(
[
_fake_context("abc123"),
_fake_context("bcd234"),
]
)
dataset = _add_extra_contexts_to_samples(samples, p=0.4)
assert len(dataset) == 2


def test_add_extra_contexts_to_samples_with_six_samples():
"""
Test _add_extra_contexts_to_samples doesn't error out when
given more samples
"""
samples = Dataset.from_list(
[
_fake_context("s1"),
_fake_context("s2"),
_fake_context("s3"),
_fake_context("s4"),
_fake_context("s5"),
_fake_context("s6"),
]
)
dataset = _add_extra_contexts_to_samples(samples, p=0.4)
assert len(dataset) == 6

0 comments on commit afadfd5

Please sign in to comment.