Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Zhi Lin <[email protected]>
  • Loading branch information
kira-lin committed Apr 10, 2024
1 parent 5363036 commit 6602572
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
7 changes: 4 additions & 3 deletions python/raydp/spark/ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ def _prepare_spark_configs(self):

raydp_agent_path = os.path.abspath(os.path.join(os.path.abspath(__file__),
"../../jars/raydp-agent*.jar"))
print(raydp_agent_path)
raydp_agent_jar = glob.glob(raydp_agent_path)[0]
self._configs[SPARK_JAVAAGENT] = raydp_agent_jar
print(os.listdir(raydp_cp))
raydp_agent_jars = glob.glob(raydp_agent_path)
if raydp_agent_jars:
self._configs[SPARK_JAVAAGENT] = raydp_agent_jars[0]
# for JVM running in ray
self._configs[SPARK_RAY_LOG4J_FACTORY_CLASS_KEY] = versions.RAY_LOG4J_VERSION

Expand Down
4 changes: 2 additions & 2 deletions python/raydp/tf/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self,
metrics: Union[List[keras.metrics.Metric], List[str]] = None,
feature_columns: Union[str, List[str]] = None,
label_columns: Union[str, List[str]] = None,
merge_feature_columns: bool = False,
merge_feature_columns: bool = True,
batch_size: int = 128,
drop_last: bool = False,
num_epochs: int = 1,
Expand Down Expand Up @@ -268,4 +268,4 @@ def fit_on_spark(self,

def get_model(self) -> Any:
assert self._trainer, "Trainer has not been created"
return TensorflowCheckpoint(self._results.checkpoint).get_model()
return TensorflowCheckpoint(self._results.checkpoint.to_directory()).get_model()
2 changes: 1 addition & 1 deletion python/raydp/torch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,4 +378,4 @@ def fit_on_spark(self,

def get_model(self):
assert self._trainer is not None, "Must call fit first"
return TorchCheckpoint(self._trained_results.checkpoint.as_directory()).get_model()
return TorchCheckpoint(self._trained_results.checkpoint.to_directory()).get_model()
2 changes: 1 addition & 1 deletion python/raydp/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@ def fit_on_spark(self,
train_ds, evaluate_ds, max_retries)

def get_model(self):
return XGBoostCheckpoint.from_checkpoint(self._results.checkpoint).get_model()
return XGBoostCheckpoint(self._results.checkpoint.to_directory()).get_model()

0 comments on commit 6602572

Please sign in to comment.