Skip to content

Commit

Permalink
Fixed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingTil committed Nov 16, 2023
1 parent d365970 commit 7d274d5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
7 changes: 7 additions & 0 deletions py_css/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,19 @@ def gen_context_docs(context: Context) -> Generator[Document, None, None]:

# If df contains multiple rows with the same (qid, docno) pair, keep only the one with the highest score. For each row removed, add a row with the EMPTY_PLACEHOLDER_DOC.
rank_size_per_qid: int = df.groupby("qid").size().max()
print(f"Rank size per qid: {rank_size_per_qid}")
df = df.sort_values(["qid", "docno", "score"], ascending=[True, True, False])
total_size = df.shape[0]
df = df.drop_duplicates(subset=["qid", "docno"], keep="first")
dropped_any: bool = total_size != df.shape[0]
print(f"Dropped any: {dropped_any}")
df = df.reset_index(drop=True)
df = self.pad_empty_documents(
df, df["qid"].unique(), rank_size_per_qid, df[["qid", "query"]]
)
print(f"Number of max rank size per qid: {df.groupby('qid').size().max()}")
df = df.reset_index(drop=True)
df = df.sort_values(["qid", "rank"], ascending=[True, True])

for query, context in context_list:
# Check whether df has a row with "qid" == query.query_id and "docno" == EMPTY_PLACEHOLDER_DOC.docno
Expand Down
1 change: 1 addition & 0 deletions py_css/models/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
self.top_docs = ((bm25 % bm25_docs).compile(), bm25_docs)
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
# self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE, model="castorini/monot5-large-msmarco"), mono_t5_docs)
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)

def transform_input(
Expand Down
1 change: 1 addition & 0 deletions py_css/models/baseline_prf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
rm3 = pt.rewrite.RM3(index, fb_docs=rm3_fb_docs, fb_terms=rm3_fb_terms)
self.top_docs = ((bm25 >> rm3 >> bm25) % bm25_docs, bm25_docs)
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
# self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE, model="castorini/monot5-large-msmarco"), mono_t5_docs)
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)

def transform_input(
Expand Down

0 comments on commit 7d274d5

Please sign in to comment.