Process extra feature file once instead of multiple times
iblacksand committed Sep 25, 2024
1 parent 998a418 commit f6bb066
Showing 3 changed files with 42 additions and 45 deletions.
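At a glance: the commit hoists the extra-feature TSV parsing out of extract_features and dataset_llr into a single process_extra_feature call in cli.py, then threads the resulting DataFrame through the existing keyword arguments. A minimal sketch of the before/after pattern (consumer_old and consumer_new are hypothetical stand-ins, not funmap functions):

import pandas as pd

# Before: every consumer re-read and re-normalized the TSV on each call.
def consumer_old(pairs: pd.DataFrame, extra_feature_file: str) -> pd.DataFrame:
    extra = pd.read_csv(extra_feature_file, sep="\t")  # repeated I/O and parsing
    return pairs.merge(extra, on=["P1", "P2"], how="left")

# After: parse and normalize once up front, then hand the DataFrame around.
def consumer_new(pairs: pd.DataFrame, extra_feature_df: pd.DataFrame) -> pd.DataFrame:
    return pairs.merge(extra_feature_df, on=["P1", "P2"], how="left")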
14 changes: 10 additions & 4 deletions funmap/cli.py
@@ -26,6 +26,7 @@
     check_extra_feature_file,
     check_gold_standard_file,
     cleanup_experiment,
+    process_extra_feature,
     setup_experiment,
 )

@@ -120,6 +121,11 @@ def run(config_file, force_rerun):
     setup_logging(config_file)
     cfg = setup_experiment(config_file)
     extra_feature_file = cfg["extra_feature_file"]
+    if extra_feature_file is not None:
+        log.info("Loading extra feature file into dataframe")
+        extra_feature_df = process_extra_feature(extra_feature_file)
+    else:
+        extra_feature_df = None
     # if (extra_feature_file is not None) and (not check_extra_feature_file(extra_feature_file)):
     # return
     gs_file = cfg["gs_file"]
@@ -192,7 +198,7 @@ def run(config_file, force_rerun):
"mr_dict": mr_dict,
"feature_type": feature_type,
"gs_file": gs_file,
"extra_feature_file": extra_feature_file,
"extra_feature_df": extra_feature_df,
"valid_id_list": all_valid_ids,
"test_size": test_size,
"seed": seed,
@@ -255,7 +261,7 @@ def run(config_file, force_rerun):
"ppi_feature": ppi_feature,
"cc_dict": cc_dict,
"mr_dict": mr_dict,
"extra_feature_file": extra_feature_file,
"extra_feature_df": extra_feature_df,
"prediction_dir": prediction_dir,
"output_file": predicted_all_pairs_file,
"n_jobs": n_jobs,
@@ -326,7 +332,7 @@ def run(config_file, force_rerun):
             max_num_edges,
             step_size,
             llr_dataset_file,
-            extra_feature_file,
+            extra_feature_df,
         )
         log.info("Done.")
     else:
@@ -356,7 +362,7 @@ def run(config_file, force_rerun):
"feature_type": "cc",
"gs_file": gs_file,
# no extra feature for plotting
"extra_feature_file": None,
"extra_feature_df": None,
"valid_id_list": all_valid_ids,
"test_size": test_size,
"seed": seed,
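With the guard above, run() resolves the extra features exactly once, and every later stage receives either a ready DataFrame or None. A condensed sketch of that startup logic (load_extra_features is a hypothetical wrapper, not part of cli.py):

from typing import Optional

import pandas as pd

from funmap.utils import process_extra_feature  # as imported in cli.py

def load_extra_features(path: Optional[str]) -> Optional[pd.DataFrame]:
    # Parse the TSV once at startup, or carry None through
    # when no extra feature file is configured.
    return process_extra_feature(path) if path is not None else None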
45 changes: 9 additions & 36 deletions funmap/funmap.py
@@ -23,7 +23,6 @@
 from funmap.data_urls import network_info
 from funmap.logger import setup_logger
 from funmap.utils import (
-    check_extra_feature_file,
     get_data_dict,
     is_url_scheme,
     read_csv_with_md5_check,
@@ -290,7 +289,7 @@ def get_1d_indices(i_array, j_array, n):


 def extract_features(
-    df, feature_type, cc_dict, ppi_feature=None, extra_feature=None, mr_dict=None
+    df, feature_type, cc_dict, ppi_feature=None, extra_feature_df=None, mr_dict=None
 ):
     """
     extract_features - creates the final feature `pandas` dataframe used by xgboost
@@ -326,20 +325,7 @@
     )

     # TODO: add extra features if provided
-    if extra_feature is not None:
-        extra_feature_df = pd.read_csv(extra_feature, sep="\t")
-        extra_feature_df.columns.values[0] = "P1"
-        extra_feature_df.columns.values[1] = "P2"
-        extra_feature_df[["P1", "P2"]] = extra_feature_df.apply(
-            lambda row: sorted([row["P1"], row["P2"]])
-            if row["P1"] > row["P2"]
-            else [row["P1"], row["P2"]],
-            axis=1,
-            result_type="expand",
-        )
-        extra_feature_df = extra_feature_df.drop_duplicates(
-            subset=["P1", "P2"], keep="last"
-        )
+    if extra_feature_df is not None:
         df.reset_index(drop=True, inplace=True)
         merged_df = pd.merge(
             df[["P1", "P2"]], extra_feature_df, on=["P1", "P2"], how="left"
@@ -506,7 +492,7 @@ def prepare_gs_data(**kwargs):
     gs_file = kwargs["gs_file"]
     gs_file_md5 = None
     feature_type = kwargs["feature_type"]
-    extra_feature_file = kwargs["extra_feature_file"]
+    extra_feature_df = kwargs["extra_feature_df"]
     valid_id_list = kwargs["valid_id_list"]
     test_size = kwargs["test_size"]
     seed = kwargs["seed"]
@@ -530,10 +516,10 @@
     else:
         ppi_feature = None
     gs_train_df = extract_features(
-        gs_train, feature_type, cc_dict, ppi_feature, extra_feature_file, mr_dict
+        gs_train, feature_type, cc_dict, ppi_feature, extra_feature_df, mr_dict
     )
     gs_test_df = extract_features(
-        gs_test, feature_type, cc_dict, ppi_feature, extra_feature_file, mr_dict
+        gs_test, feature_type, cc_dict, ppi_feature, extra_feature_df, mr_dict
     )

     # store both the ids with gs_test_df for later use
@@ -599,7 +585,7 @@ def dataset_llr(
     max_num_edge,
     step_size,
     llr_dataset_file,
-    extra_feature,
+    extra_feature_df,
 ):
     llr_ds = pd.DataFrame()
     all_ids_sorted = sorted(all_ids)
@@ -644,21 +630,8 @@
         llr_ds = pd.concat([llr_ds, cur_llr_res], axis=0, ignore_index=True)
     log.info("Calculating llr for all datasets average ... done")
     llr_ds.to_csv(llr_dataset_file, sep="\t", index=False)
-    if extra_feature is not None:
+    if extra_feature_df is not None:
         log.info("Calculating LLR for extra features")
-        extra_feature_df = pd.read_csv(extra_feature, sep="\t")
-        extra_feature_df.columns.values[0] = "P1"
-        extra_feature_df.columns.values[1] = "P2"
-        extra_feature_df[["P1", "P2"]] = extra_feature_df.apply(
-            lambda row: sorted([row["P1"], row["P2"]])
-            if row["P1"] > row["P2"]
-            else [row["P1"], row["P2"]],
-            axis=1,
-            result_type="expand",
-        )
-        extra_feature_df = extra_feature_df.drop_duplicates(
-            subset=["P1", "P2"], keep="last"
-        )
         extra_feature_df = extract_extra_features(
             all_pairs, extra_feature_df
         )  # filter out unused pairs
@@ -751,7 +724,7 @@ def predict_all_pairs(
     ppi_feature,
     cc_dict,
     mr_dict,
-    extra_feature_file,
+    extra_feature_df,
     prediction_dir,
     output_file,
     n_jobs=1,
@@ -784,7 +757,7 @@ def process_and_save_chunk(start_idx, chunk_id):
             feature_type,
             cc_dict,
             cur_ppi_feature,
-            extra_feature_file,
+            extra_feature_df,
             mr_dict,
         )
         predictions = model.predict_proba(feature_df)
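For reference, the join that extract_features now performs with the preprocessed DataFrame is a plain left merge on the normalized pair columns. A self-contained sketch with toy data (the "coexpr" column is invented):

import pandas as pd

# Toy pair table shaped like the one extract_features receives.
df = pd.DataFrame({"P1": ["A", "A", "B"], "P2": ["B", "C", "C"]})

# Preprocessed extra features, already normalized so that P1 <= P2.
extra_feature_df = pd.DataFrame(
    {"P1": ["A", "B"], "P2": ["B", "C"], "coexpr": [0.9, 0.2]}
)

# Left merge keeps every candidate pair; pairs without an extra feature get NaN.
merged_df = pd.merge(df[["P1", "P2"]], extra_feature_df, on=["P1", "P2"], how="left")
print(merged_df)
#   P1 P2  coexpr
# 0  A  B     0.9
# 1  A  C     NaN
# 2  B  C     0.2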
28 changes: 23 additions & 5 deletions funmap/utils.py
@@ -1,14 +1,16 @@
 import csv
-import yaml
-import hashlib
 import os
-import tarfile
 import re
+import hashlib
+import shutil
+import tarfile
 import urllib
-from urllib.parse import urlparse
 from pathlib import Path
+from urllib.parse import urlparse

 import pandas as pd
-import shutil
+import yaml

 from funmap.data_urls import misc_urls as urls
 from funmap.logger import setup_logger

@@ -553,3 +555,19 @@ def check_extra_feature_file(file_path, missing_value="NA"):
         return False

     return True
+
+
+def process_extra_feature(extra_feature_file) -> pd.DataFrame:
+    extra_feature_df = pd.read_csv(extra_feature_file, sep="\t")
+    extra_feature_df.columns.values[0] = "P1"
+    extra_feature_df.columns.values[1] = "P2"
+    # order each pair so that P1 <= P2, then keep the last row per pair
+    extra_feature_df[["P1", "P2"]] = extra_feature_df.apply(
+        lambda row: sorted([row["P1"], row["P2"]])
+        if row["P1"] > row["P2"]
+        else [row["P1"], row["P2"]],
+        axis=1,
+        result_type="expand",
+    )
+    extra_feature_df = extra_feature_df.drop_duplicates(
+        subset=["P1", "P2"], keep="last"
+    )
+    return extra_feature_df
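A minimal usage sketch of the new helper on a throwaway TSV (column names here are invented; assumes funmap is importable):

import os
import tempfile

from funmap.utils import process_extra_feature

# Write the same pair in both orientations; "score" is an invented column.
with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False) as f:
    f.write("protein_a\tprotein_b\tscore\n")
    f.write("B\tA\t0.1\n")  # flipped orientation, normalized to (A, B)
    f.write("A\tB\t0.7\n")
    path = f.name

df = process_extra_feature(path)
print(df)  # a single row: P1=A, P2=B, score=0.7 (keep="last" wins)
os.unlink(path)  # tidy up the temporary file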
