Skip to content

Commit

Permalink
[field] Support sampling wrapped functions
Browse files Browse the repository at this point in the history
This includes gradients or jit-compiled functions.
  • Loading branch information
holl- committed Apr 20, 2023
1 parent ccfc2cd commit 1f0a2ee
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions phi/field/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,9 @@ def resolution_from_staggered_tensor(values: Tensor, extrapolation: Extrapolatio


def _sample_function(f, elements: Geometry):
import inspect
from phi.math._functional import get_function_parameters
try:
signature = inspect.signature(f)
params = dict(signature.parameters)
params = get_function_parameters(f)
dims = elements.shape.get_size('vector')
names_match = tuple(params.keys())[:dims] == elements.shape.get_item_names('vector')
num_positional = 0
Expand All @@ -531,10 +530,10 @@ def _sample_function(f, elements: Geometry):
num_positional += 1
if p.kind == 2: # _ParameterKind.VAR_POSITIONAL
has_varargs = True
assert num_positional <= dims, f"Cannot sample {f.__name__}{signature} on physical space {elements.shape.get_item_names('vector')}"
assert num_positional <= dims, f"Cannot sample {f.__name__}({', '.join(tuple(params))}) on physical space {elements.shape.get_item_names('vector')}"
pass_varargs = has_varargs or names_match or num_positional > 1 or num_positional == dims
if num_positional > 1 and not has_varargs:
assert names_match, f"Positional arguments of {f.__name__}{signature} should match physical space {elements.shape.get_item_names('vector')}"
assert names_match, f"Positional arguments of {f.__name__}({', '.join(tuple(params))}) should match physical space {elements.shape.get_item_names('vector')}"
except ValueError as err: # signature not available for all functions
pass_varargs = False
if pass_varargs:
Expand Down

0 comments on commit 1f0a2ee

Please sign in to comment.