Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Zhi Lin <[email protected]>
  • Loading branch information
kira-lin committed Apr 9, 2024
1 parent 32b6c9f commit 40d192c
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 13 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/ray_nightly_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [3.7, 3.8, 3.9]
spark-version: [3.1.3, 3.2.3, 3.3.2]
spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]

runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -115,8 +115,7 @@ jobs:
- name: Test with pytest
run: |
ray start --head --num-cpus 6
pytest python/raydp/tests/ -v -m"not error_on_custom_resource"
pytest python/raydp/tests/ -v -m"error_on_custom_resource"
pytest python/raydp/tests/ -v
ray stop --force
- name: Test Examples
run: |
Expand Down
8 changes: 2 additions & 6 deletions .github/workflows/raydp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,13 @@ jobs:
SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])')
if [ "$(uname -s)" == "Linux" ]
then
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
else
pip install torch
fi
pip install pyarrow==6.0.1 ray[default]==2.4.0 pytest koalas tensorflow tabulate grpcio-tools wget
pip install pyarrow==6.0.1 ray[default] pytest koalas tensorflow tabulate grpcio-tools wget
pip install "xgboost_ray[default]<=0.1.13"
pip install torchmetrics
HOROVOD_WITH_GLOO=1
HOROVOD_WITH_PYTORCH=1
pip install horovod[pytorch,ray]
- name: Cache Maven
uses: actions/cache@v2
with:
Expand Down Expand Up @@ -114,7 +111,6 @@ jobs:
ray stop
python examples/pytorch_nyctaxi.py
python examples/tensorflow_nyctaxi.py
python examples/horovod_nyctaxi.py
python examples/xgboost_ray_nyctaxi.py
# python examples/raytrain_nyctaxi.py
python examples/data_process.py
2 changes: 0 additions & 2 deletions examples/raytrain_nyctaxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def train_func(config):
for epoch in range(num_epochs):
train_mse, train_loss = train_epoch(train_dataset, model, criterion, optimizer)
test_mse, test_loss = test_epoch(test_dataset, model, criterion)
train.report(epoch = epoch, train_mse = train_mse, train_loss = train_loss)
train.report(epoch = epoch, test_mse = test_mse, test_loss=test_loss)
loss_results.append(test_loss)

trainer = Trainer(backend="torch", num_workers=num_executors)
Expand Down
5 changes: 3 additions & 2 deletions python/raydp/spark/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,9 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
locations = None) -> DataFrame:
locations = get_locations(blocks)
if not isinstance(arrow_schema, pa.lib.Schema):
if hasattr(arrow_schema, "base_schema") and \
not isinstance(arrow_schema.base_schema, pa.lib.Schema):
if hasattr(arrow_schema, "base_schema"):
arrow_schema = arrow_schema.base_schema
if isinstance(arrow_schema, pa.lib.Schema):
raise RuntimeError(f"Schema is {type(arrow_schema)}, required pyarrow.lib.Schema. \n" \
f"to_spark does not support converting non-arrow ray datasets.")
schema = StructType()
Expand Down

0 comments on commit 40d192c

Please sign in to comment.