Skip to content

Commit

Permalink
Merge pull request #3376 from jsiirola/expr-cnwld-fix
Browse files Browse the repository at this point in the history
Resolve bugs in create_node_with_local_data
  • Loading branch information
jsiirola authored Oct 9, 2024
2 parents ef61460 + 9c287c1 commit 23fb726
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pyomo/core/base/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@ def __call__(self, exception=True):
return arg
return arg(exception=exception)

def create_node_with_local_data(self, values):
def create_node_with_local_data(self, values, classtype=None):
"""
Construct a simple expression after constructing the
contained expression.
This class provides a consistent interface for constructing a
node, which is used in tree visitor scripts.
"""
obj = self.__class__()
if classtype is None:
classtype = self.parent_component()._ComponentDataClass
obj = classtype()
obj._args_ = values
return obj

Expand Down
2 changes: 1 addition & 1 deletion pyomo/core/expr/numeric_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def create_node_with_local_data(self, args, classtype=None):
# types, the simplest / fastest thing to do is just defer to
# the operator dispatcher.
return operator.mul(*args)
return self.__class__(args)
return classtype(args)


class DivisionExpression(NumericExpression):
Expand Down
31 changes: 31 additions & 0 deletions pyomo/core/tests/unit/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
sum_product,
)
from pyomo.core.base.expression import ExpressionData
from pyomo.core.base.objective import ObjectiveData
from pyomo.core.expr.compare import compare_expressions, assertExpressionsEqual
from pyomo.common.tee import capture_output

Expand Down Expand Up @@ -290,6 +291,36 @@ def obj_rule(model):
self.assertEqual(inst.obj.expr(), 3.0)
self.assertEqual(id(inst.obj.expr.arg(1)), id(inst.ec))

def test_create_node_with_local_data(self):
m = ConcreteModel()
m.x = Var()

m.e = Expression(expr=m.x)
ee = m.e.create_node_with_local_data([5])
self.assertIsNot(m.e, ee)
self.assertIs(type(ee), ExpressionData)
self.assertEqual(ee._args_, [5])

m.f = Expression([0], rule=lambda m, i: m.x)
ff = m.f[0].create_node_with_local_data([5])
self.assertIsNot(m.f, ff)
self.assertIsNot(m.f[0], ff)
self.assertIs(type(ff), ExpressionData)
self.assertEqual(ff._args_, [5])

m.g = Objective(expr=m.x)
gg = m.g.create_node_with_local_data([5])
self.assertIsNot(m.g, gg)
self.assertIs(type(gg), ObjectiveData)
self.assertEqual(gg._args_, [5])

m.h = Objective([0], rule=lambda m, i: m.x)
hh = m.h[0].create_node_with_local_data([5])
self.assertIsNot(m.h, hh)
self.assertIsNot(m.h[0], hh)
self.assertIs(type(hh), ObjectiveData)
self.assertEqual(hh._args_, [5])


class TestExpression(unittest.TestCase):
def setUp(self):
Expand Down
12 changes: 12 additions & 0 deletions pyomo/core/tests/unit/test_numeric_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4313,6 +4313,18 @@ def test_sin(self):
total = counter.count - start
self.assertEqual(total, 1)

def test_create_node_with_local_data(self):
e = self.m.p * self.m.a
self.assertIs(type(e), MonomialTermExpression)

f = e.create_node_with_local_data([self.m.b, self.m.p])
self.assertIs(type(f), MonomialTermExpression)
self.assertStructuredAlmostEqual(f._args_, [self.m.p, self.m.b])

g = e.create_node_with_local_data([self.m.b, self.m.p], ProductExpression)
self.assertIs(type(g), ProductExpression)
self.assertStructuredAlmostEqual(g._args_, [self.m.b, self.m.p])


#
# Fixed - Expr has a fixed value
Expand Down

0 comments on commit 23fb726

Please sign in to comment.