Skip to content

Commit

Permalink
BUG: Support pandas IntegerArray
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Jan 16, 2024
1 parent fdb953d commit 5419ac5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
12 changes: 9 additions & 3 deletions geocube/rasterize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module contains tools for rasterizing vector data.
"""
from typing import Optional
from typing import Optional, Union

import geopandas
import numpy
Expand Down Expand Up @@ -62,7 +62,7 @@ def _minimize_dtype(dtype: numpy.dtype, fill: float) -> numpy.dtype:

def rasterize_image(
geometry_array: geopandas.GeoSeries,
data_values: NDArray,
data_values: Union[NDArray, pandas.arrays.IntegerArray],
geobox: odc.geo.geobox.GeoBox,
fill: float,
merge_alg: MergeAlg = MergeAlg.replace,
Expand All @@ -77,7 +77,7 @@ def rasterize_image(
-----------
geometry_array: geopandas.GeoSeries
A geometry array of points.
data_values: list
data_values: Union[NDArray, pandas.arrays.IntegerArray]
Data values associated with the list of geojson shapes
geobox: :obj:`odc.geo.geobox.GeoBox`
Transform of the resulting image.
Expand Down Expand Up @@ -107,6 +107,12 @@ def rasterize_image(
# only numbers can be rasterized
return None

if isinstance(data_values, pandas.arrays.IntegerArray):
data_values = data_values.to_numpy(
dtype=_minimize_dtype(data_values.dtype.numpy_dtype, fill),
na_value=fill,
)

if filter_nan:
data_values, geometry_array = _remove_missing_data(data_values, geometry_array)

Expand Down
27 changes: 27 additions & 0 deletions test/integration/api/test_core_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial

import geopandas
import numpy
import pandas
import pytest
import xarray
Expand Down Expand Up @@ -856,3 +857,29 @@ def test_rasterize__like_1d():
)
assert geom_array.rio.transform() == like.rio.transform()
assert geom_array.in_geom.shape == (2, 1)


@pytest.mark.parametrize(
"dtype, expected_dtype",
[
("Int32", "int32"),
("Int64", "float64"),
],
)
def test_make_geocube__pandas_integer_array(dtype, expected_dtype, tmpdir):
soil_data = geopandas.read_file(TEST_INPUT_DATA_DIR / "soil_data_flat.geojson")[
["geometry", "sandtotal_r", "om_r"]
]
soil_data["sandtotal_r"] = numpy.round(soil_data["sandtotal_r"] * 100).astype(dtype)
soil_data["sandtotal_r"].values[0] = pandas.NA

out_grid = make_geocube(
vector_data=soil_data,
output_crs=TEST_GARS_PROJ,
geom=json.dumps(mapping(TEST_GARS_POLY)),
resolution=[-10, 10],
fill=-1,
)
# test writing to netCDF
out_grid.to_netcdf(tmpdir.mkdir("make_geocube_soil") / "soil_grid_flat.nc")
assert out_grid.sandtotal_r.dtype.name == expected_dtype

0 comments on commit 5419ac5

Please sign in to comment.