Skip to content

Commit

Permalink
BUG: Support pandas IntegerArray (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored Jan 16, 2024
1 parent d39758e commit 341b730
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
-------
- BUG: Update minimize dtype for int64 & int8 support (issue #139)
- BUG: Support pandas IntegerArray (pull #158)

0.4.2
-------
Expand Down
12 changes: 9 additions & 3 deletions geocube/rasterize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module contains tools for rasterizing vector data.
"""
from typing import Optional
from typing import Optional, Union

import geopandas
import numpy
Expand Down Expand Up @@ -65,7 +65,7 @@ def _minimize_dtype(dtype: numpy.dtype, fill: float) -> numpy.dtype:

def rasterize_image(
geometry_array: geopandas.GeoSeries,
data_values: NDArray,
data_values: Union[NDArray, pandas.arrays.IntegerArray],
geobox: odc.geo.geobox.GeoBox,
fill: float,
merge_alg: MergeAlg = MergeAlg.replace,
Expand All @@ -80,7 +80,7 @@ def rasterize_image(
-----------
geometry_array: geopandas.GeoSeries
A geometry array of points.
data_values: list
data_values: Union[NDArray, pandas.arrays.IntegerArray]
Data values associated with the list of geojson shapes
geobox: :obj:`odc.geo.geobox.GeoBox`
Transform of the resulting image.
Expand Down Expand Up @@ -110,6 +110,12 @@ def rasterize_image(
# only numbers can be rasterized
return None

if isinstance(data_values, pandas.arrays.IntegerArray):
data_values = data_values.to_numpy(
dtype=_minimize_dtype(data_values.dtype.numpy_dtype, fill),
na_value=fill,
)

if filter_nan:
data_values, geometry_array = _remove_missing_data(data_values, geometry_array)

Expand Down
27 changes: 27 additions & 0 deletions test/integration/api/test_core_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial

import geopandas
import numpy
import pandas
import pytest
import xarray
Expand Down Expand Up @@ -857,3 +858,29 @@ def test_rasterize__like_1d():
)
assert geom_array.rio.transform() == like.rio.transform()
assert geom_array.in_geom.shape == (2, 1)


@pytest.mark.parametrize(
"dtype, expected_dtype",
[
("Int32", "int32"),
("Int64", "int64"),
],
)
def test_make_geocube__pandas_integer_array(dtype, expected_dtype, tmpdir):
soil_data = geopandas.read_file(TEST_INPUT_DATA_DIR / "soil_data_flat.geojson")[
["geometry", "sandtotal_r", "om_r"]
]
soil_data["sandtotal_r"] = numpy.round(soil_data["sandtotal_r"] * 100).astype(dtype)
soil_data["sandtotal_r"].values[0] = pandas.NA

out_grid = make_geocube(
vector_data=soil_data,
output_crs=TEST_GARS_PROJ,
geom=json.dumps(mapping(TEST_GARS_POLY)),
resolution=[-10, 10],
fill=-1,
)
# test writing to netCDF
out_grid.to_netcdf(tmpdir.mkdir("make_geocube_soil") / "soil_grid_flat.nc")
assert out_grid.sandtotal_r.dtype.name == expected_dtype

0 comments on commit 341b730

Please sign in to comment.