Skip to content

Commit

Permalink
Debug GPU testing
Browse files Browse the repository at this point in the history
  • Loading branch information
bfhealy committed Feb 23, 2024
1 parent a707f0d commit 44dcb09
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions scope/scope_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2614,8 +2614,17 @@ def test(self, doGPU=False):
with status("Test training"):
print()

period_suffix_config = self.config["features"]["info"]["period_suffix"]
period_suffix_2 = "LS"
period_suffix_config = (
self.config.get("features").get("info").get("period_suffix")
)
if doGPU & (
period_suffix_config not in ["ELS", "ECE", "EAOV", "ELS_ECE_EAOV"]
):
period_suffix_test = "ELS_ECE_EAOV"
if (not doGPU) & (
period_suffix_config not in ["LS", "CE", "AOV", "LS_CE_AOV"]
):
period_suffix_test = "LS"

if not path_mock.exists():
path_mock.mkdir(parents=True, exist_ok=True)
Expand All @@ -2639,13 +2648,13 @@ def test(self, doGPU=False):
feature_names_new[j] = f"{name}_{period_suffix_config}"

feature_names = feature_names_orig.copy()
if not ((period_suffix_2 is None) | (period_suffix_2 == "None")):
if not ((period_suffix_test is None) | (period_suffix_test == "None")):
periodic_bool = [
all_feature_names[x]["periodic"] for x in feature_names
]
for j, name in enumerate(feature_names):
if periodic_bool[j]:
feature_names[j] = f"{name}_{period_suffix_2}"
feature_names[j] = f"{name}_{period_suffix_test}"

class_names = [
self.config["training"]["classes"][class_name]["label"]
Expand Down Expand Up @@ -2738,7 +2747,7 @@ def test(self, doGPU=False):
test=True,
algorithm=algorithm,
skip_cv=True,
period_suffix=period_suffix_2,
period_suffix=period_suffix_test,
group=group_mock,
)
path_model = (
Expand Down Expand Up @@ -2784,7 +2793,7 @@ def test(self, doGPU=False):
trainingSet=df_mock,
feature_directory=test_feature_directory,
feature_file_prefix=test_feature_filename,
period_suffix=period_suffix_2,
period_suffix=period_suffix_test,
no_write_metadata=True,
)
print()
Expand All @@ -2798,7 +2807,7 @@ def test(self, doGPU=False):
xgb_model=True,
feature_directory=test_feature_directory,
feature_file_prefix=test_feature_filename,
period_suffix=period_suffix_2,
period_suffix=period_suffix_test,
no_write_metadata=True,
)

Expand Down

0 comments on commit 44dcb09

Please sign in to comment.