Skip to content

Commit

Permalink
BUG: Update minimize dtype for int64 & int8 support (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored Jan 16, 2024
1 parent fdb953d commit d39758e
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ History

Latest
-------
- BUG: Update minimize dtype for int64 & int8 support (issue #139)

0.4.2
-------
Expand Down
13 changes: 8 additions & 5 deletions geocube/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
import numpy
import odc.geo.geobox
import pandas
import rasterio
import rasterio.features
from numpy.typing import NDArray
from packaging import version
from rasterio.enums import MergeAlg
from scipy.interpolate import Rbf, griddata

# int8 output is only usable when both GDAL (>= 3.7.0) and rasterio (>= 1.3.7)
# are new enough; older versions require falling back to a wider dtype.
_INT8_SUPPORTED = (
    version.parse(rasterio.__gdal_version__) >= version.parse("3.7.0")
    and version.parse(rasterio.__version__) >= version.parse("1.3.7")
)


def _is_numeric(data_values: NDArray) -> bool:
"""
Expand Down Expand Up @@ -44,12 +50,9 @@ def _minimize_dtype(dtype: numpy.dtype, fill: float) -> numpy.dtype:
Attempt to convert to float32 if fill is NaN and dtype is integer.
"""
if numpy.issubdtype(dtype, numpy.integer):
if dtype.name == "int8":
# GDAL/rasterio doesn't support int8
if not _INT8_SUPPORTED and dtype.name == "int8":
# GDAL<3.7/rasterio<1.3.7 doesn't support int8
dtype = numpy.dtype("int16")
if dtype.name == "int64":
# GDAL/rasterio doesn't support int64
dtype = numpy.dtype("float64")
if numpy.isnan(fill):
dtype = (
numpy.dtype("float64") if dtype.itemsize > 2 else numpy.dtype("float32") # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ install_requires =
click>=6.0
geopandas>=0.7
odc_geo
rasterio
rasterio>=1.3
rioxarray>=0.4
scipy
xarray>=0.17
Expand Down
7 changes: 4 additions & 3 deletions test/integration/api/test_core_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from geocube.api.core import make_geocube
from geocube.exceptions import VectorDataError
from geocube.rasterize import (
_INT8_SUPPORTED,
rasterize_image,
rasterize_points_griddata,
rasterize_points_radial,
Expand Down Expand Up @@ -101,7 +102,7 @@ def test_make_geocube__categorical(input_geodata, tmpdir):
categorical_enums={"soil_type": ("sand", "silt", "clay")},
fill=-9999.0,
)
assert out_grid.soil_type.dtype.name == "int16"
# The conditional must be parenthesized: unparenthesized, the statement parses
# as `assert (name == "int8") if _INT8_SUPPORTED else "int16"`, which asserts
# the always-truthy string "int16" when int8 is unsupported and checks nothing.
assert out_grid.soil_type.dtype.name == ("int8" if _INT8_SUPPORTED else "int16")

# test writing to netCDF
out_grid.to_netcdf(
Expand Down Expand Up @@ -451,7 +452,7 @@ def test_make_geocube__group_by__categorical(input_geodata, tmpdir):
fill=-9999.0,
)

assert out_grid.soil_type.dtype.name == "int16"
# The conditional must be parenthesized: unparenthesized, the statement parses
# as `assert (name == "int8") if _INT8_SUPPORTED else "int16"`, which asserts
# the always-truthy string "int16" when int8 is unsupported and checks nothing.
assert out_grid.soil_type.dtype.name == ("int8" if _INT8_SUPPORTED else "int16")
# test writing to netCDF
out_grid.to_netcdf(
tmpdir.mkdir("make_geocube_soil") / "soil_grid_group_categorical.nc"
Expand Down Expand Up @@ -816,7 +817,7 @@ def test_make_geocube__custom_rasterize_function__filter_null(
("uint16", float("NaN"), "float32"),
("int32", 0, "int32"),
("int32", float("NaN"), "float64"),
("int64", 0, "float64"),
("int64", 0, "int64"),
("int64", float("NaN"), "float64"),
],
)
Expand Down

0 comments on commit d39758e

Please sign in to comment.