Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace .at() interpolation in 2D callbacks #373

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions thetis/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,20 @@ def __init__(self, solver_obj,
self.field_names = field_names
self._name = name

# create VertexOnlyMesh and function spaces for interpolation
self.mesh = solver_obj.fields[field_names[0]].function_space().mesh()
cpjordan marked this conversation as resolved.
Show resolved Hide resolved
self.vom = VertexOnlyMesh(self.mesh, detector_locations, redundant=True)
self.function_spaces = {}
for field_name in field_names:
field = solver_obj.fields[field_name]
if isinstance(field.function_space().ufl_element(), VectorElement):
P0DG = VectorFunctionSpace(self.vom, "DG", 0)
P0DG_input_ordering = VectorFunctionSpace(self.vom.input_ordering, "DG", 0)
else:
P0DG = FunctionSpace(self.vom, "DG", 0)
P0DG_input_ordering = FunctionSpace(self.vom.input_ordering, "DG", 0)
self.function_spaces[field_name] = (P0DG, P0DG_input_ordering)
tkarna marked this conversation as resolved.
Show resolved Hide resolved

@property
def name(self):
return self._name
Expand All @@ -539,7 +553,8 @@ def variable_names(self):

def _values_per_field(self, values):
"""
Given all values evaulated in a detector location, return the values per field"""
Given all values evaulated in a detector location, return the values per field
"""
i = 0
result = []
for dim in self.field_dims:
Expand All @@ -554,7 +569,13 @@ def message_str(self, *args):
for name, values in zip(self.detector_names, args))

def _evaluate_field(self, field_name):
return self.solver_obj.fields[field_name](self.detector_locations)
field = self.solver_obj.fields[field_name]
P0DG, P0DG_input_ordering = self.function_spaces[field_name]
f_at_points = Function(P0DG)
f_at_input_points = Function(P0DG_input_ordering)
cpjordan marked this conversation as resolved.
Show resolved Hide resolved
f_at_points.interpolate(field)
f_at_input_points.interpolate(f_at_points)
return f_at_input_points.dat.data_ro

def __call__(self):
"""
Expand Down Expand Up @@ -673,6 +694,7 @@ def __init__(self, solver_obj, fieldnames, x, y,
self.tolerance = tolerance
self.eval_func = eval_func
self._initialized = False
self.bathymetry_val = None
cpjordan marked this conversation as resolved.
Show resolved Hide resolved

@PETSc.Log.EventDecorator("thetis.TimeSeriesCallback2D._initialize")
def _initialize(self):
Expand All @@ -684,10 +706,31 @@ def _initialize(self):
xyz = (self.x, self.y, self.z) if self.on_sphere else (self.x, self.y)
self.xyz = numpy.array([xyz])

# create VertexOnlyMesh and function spaces
self.mesh = self.solver_obj.fields[self.fieldnames[0]].function_space().mesh()
self.vom = VertexOnlyMesh(self.mesh, self.xyz, redundant=True)
self.function_spaces = {}
for field_name in self.fieldnames:
field = self.solver_obj.fields[field_name]
if isinstance(field.function_space().ufl_element(), VectorElement):
P0DG = VectorFunctionSpace(self.vom, "DG", 0)
P0DG_input_ordering = VectorFunctionSpace(self.vom.input_ordering, "DG", 0)
else:
P0DG = FunctionSpace(self.vom, "DG", 0)
P0DG_input_ordering = FunctionSpace(self.vom.input_ordering, "DG", 0)
self.function_spaces[field_name] = (P0DG, P0DG_input_ordering)

# test evaluation
try:
if self.eval_func is None:
self.solver_obj.fields.bathymetry_2d.at(self.xyz, tolerance=self.tolerance)
bathymetry_field = self.solver_obj.fields.bathymetry_2d
P0DG = FunctionSpace(self.vom, "DG", 0)
P0DG_input_ordering = FunctionSpace(self.vom.input_ordering, "DG", 0)
f_at_points = Function(P0DG)
f_at_input_points = Function(P0DG_input_ordering)
f_at_points.interpolate(bathymetry_field)
f_at_input_points.interpolate(f_at_points)
cpjordan marked this conversation as resolved.
Show resolved Hide resolved
self.bathymetry_val = f_at_input_points.dat.data_ro[:]
cpjordan marked this conversation as resolved.
Show resolved Hide resolved
else:
self.eval_func(self.solver_obj.fields.bathymetry_2d, self.xyz, tolerance=self.tolerance)
except PointNotInDomainError as e:
Expand All @@ -707,7 +750,12 @@ def __call__(self):
try:
field = self.solver_obj.fields[fieldname]
if self.eval_func is None:
val = field.at(self.xyz, tolerance=self.tolerance)
P0DG, P0DG_input_ordering = self.function_spaces[fieldname]
f_at_points = Function(P0DG)
f_at_input_points = Function(P0DG_input_ordering)
f_at_points.interpolate(field)
f_at_input_points.interpolate(f_at_points)
tkarna marked this conversation as resolved.
Show resolved Hide resolved
val = f_at_input_points.dat.data_ro[:]
else:
val = self.eval_func(field, self.xyz, tolerance=self.tolerance)
arr = numpy.array(val)
Expand Down
Loading