Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
harborn committed Nov 7, 2023
1 parent e0f3352 commit 086568b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
8 changes: 4 additions & 4 deletions python/raydp/tf/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from tensorflow import DType, TensorShape
from tensorflow.keras.callbacks import Callback

from ray import train
from ray.train.tensorflow import TensorflowTrainer, TensorflowCheckpoint, prepare_dataset_shard
from ray.air import session
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
from ray.data import read_parquet
from ray.data.dataset import Dataset
Expand Down Expand Up @@ -161,15 +161,15 @@ def train_func(config):
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = TFEstimator.build_and_compile_model(config)

train_dataset = train.get_dataset_shard("train")
train_dataset = session.get_dataset_shard("train")
train_tf_dataset = train_dataset.to_tf(
feature_columns=config["feature_columns"],
label_columns=config["label_columns"],
batch_size=config["batch_size"],
drop_last=config["drop_last"]
)
if config["evaluate"]:
eval_dataset = train.get_dataset_shard("evaluate")
eval_dataset = session.get_dataset_shard("evaluate")
eval_tf_dataset = eval_dataset.to_tf(
feature_columns=config["feature_columns"],
label_columns=config["label_columns"],
Expand All @@ -184,7 +184,7 @@ def train_func(config):
if config["evaluate"]:
test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
results.append(test_history)
train.report({}, checkpoint=TensorflowCheckpoint.from_model(multi_worker_model))
session.report({}, checkpoint=TensorflowCheckpoint.from_model(multi_worker_model))

def fit(self,
train_ds: Dataset,
Expand Down
11 changes: 6 additions & 5 deletions python/raydp/torch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ray import train
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
from ray.air import session
from ray.data.dataset import Dataset
from ray.tune.search.sample import Domain

Expand Down Expand Up @@ -220,15 +221,15 @@ def train_func(config):
metrics = config["metrics"]

# create dataset
train_data_shard = train.get_dataset_shard("train")
train_data_shard = session.get_dataset_shard("train")
train_dataset = train_data_shard.to_torch(feature_columns=config["feature_columns"],
feature_column_dtypes=config["feature_types"],
label_column=config["label_column"],
label_column_dtype=config["label_type"],
batch_size=config["batch_size"],
drop_last=config["drop_last"])
if config["evaluate"]:
evaluate_data_shard = train.get_dataset_shard("evaluate")
evaluate_data_shard = session.get_dataset_shard("evaluate")
evaluate_dataset = evaluate_data_shard.to_torch(
feature_columns=config["feature_columns"],
label_column=config["label_column"],
Expand All @@ -242,18 +243,18 @@ def train_func(config):
for epoch in range(config["num_epochs"]):
train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss,
optimizer, metrics, lr_scheduler)
train.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss))
session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss))
if config["evaluate"]:
eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset,
model, loss, metrics)
train.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss))
session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss))
loss_results.append(evaluate_loss)
if hasattr(model, "module"):
states = model.module.state_dict()
else:
# if num_workers = 1, model is not wrapped
states = model.state_dict()
train.report({}, checkpoint=TorchCheckpoint.from_state_dict(states))
session.report({}, checkpoint=TorchCheckpoint.from_state_dict(states))

@staticmethod
def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
Expand Down

0 comments on commit 086568b

Please sign in to comment.