diff --git a/dwave/optimization/expression/__init__.pxd b/dwave/optimization/expression/__init__.pxd index a3a49812..baeb2740 100644 --- a/dwave/optimization/expression/__init__.pxd +++ b/dwave/optimization/expression/__init__.pxd @@ -1,4 +1,4 @@ -# Copyright 2023 D-Wave Systems Inc. +# Copyright 2024 D-Wave Systems Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dwave/optimization/expression/expression.pxd b/dwave/optimization/expression/expression.pxd index 3b7d5df2..534437a5 100644 --- a/dwave/optimization/expression/expression.pxd +++ b/dwave/optimization/expression/expression.pxd @@ -22,3 +22,5 @@ __all__ = ["Expression"] cdef class Expression(_Model): cdef readonly ArraySymbol output + + cpdef Py_ssize_t num_inputs(self) noexcept diff --git a/dwave/optimization/expression/expression.pyi b/dwave/optimization/expression/expression.pyi index 32b262d4..5ad7a07e 100644 --- a/dwave/optimization/expression/expression.pyi +++ b/dwave/optimization/expression/expression.pyi @@ -21,7 +21,13 @@ _ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]] class Expression(_Model): - def __init__(self): ... + def __init__( + self, + num_inputs: int = 0, + lower_bound: Optional[float] = None, + upper_bound: Optional[float] = None, + integral: Optional[bool] = None, + ): ... def input(self, lower_bound: float, upper_bound: float, integral: bool, shape: Optional[tuple] = None): diff --git a/dwave/optimization/expression/expression.pyx b/dwave/optimization/expression/expression.pyx index 07991335..0ebae6ab 100644 --- a/dwave/optimization/expression/expression.pyx +++ b/dwave/optimization/expression/expression.pyx @@ -12,26 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numbers from typing import Optional from libcpp cimport bool +from libcpp.cast cimport dynamic_cast from dwave.optimization.libcpp.array cimport Array as cppArray +from dwave.optimization.libcpp.graph cimport Node as cppNode +from dwave.optimization.libcpp.nodes cimport InputNode as cppInputNode +from dwave.optimization.model cimport ArraySymbol, _Model, States from dwave.optimization.symbols cimport symbol_from_ptr +ctypedef cppNode* cppNodePtr + __all__ = ["Expression"] cdef class Expression(_Model): - def __init__(self): - pass + def __init__( + self, + num_inputs: int = 0, + # necessary to prevent Cython from rejecting an int + lower_bound: Optional[numbers.Real] = None, + upper_bound: Optional[numbers.Real] = None, + integral: Optional[bool] = None, + ): + self.states = States(self) + + self._data_sources = [] + + if num_inputs > 0: + if any(arg is None for arg in (lower_bound, upper_bound, integral)): + raise ValueError( + "`lower_bound`, `upper_bound` and `integral` must be provided " + "explicitly when initializing inputs" + ) + for _ in range(num_inputs): + self.input(lower_bound, upper_bound, integral) - def input(self, lower_bound: float, upper_bound: float, bool integral, shape: Optional[tuple] = None): + def input(self, lower_bound: float, upper_bound: float, bool integral): """TODO""" # avoid circular import from dwave.optimization.symbols import Input - return Input(self, lower_bound, upper_bound, integral, shape=shape) + # Shape is always scalar for now + return Input(self, lower_bound, upper_bound, integral, shape=tuple()) def set_output(self, value: ArraySymbol): self.output = value + + cpdef Py_ssize_t num_inputs(self) noexcept: + return self._graph.num_inputs() + + def iter_inputs(self): + inputs = self._graph.inputs() + for i in range(self._graph.num_inputs()): + yield symbol_from_ptr(self, inputs[i]) diff --git a/dwave/optimization/include/dwave-optimization/graph.hpp b/dwave/optimization/include/dwave-optimization/graph.hpp index f9f33ce8..83ae1827 100644 --- a/dwave/optimization/include/dwave-optimization/graph.hpp +++ b/dwave/optimization/include/dwave-optimization/graph.hpp @@ -34,6 +34,7 @@ namespace dwave::optimization { class ArrayNode; class Node; class DecisionNode; +class InputNode; // We don't want this interface to be opinionated about what type of rng we're using. // So we create this class to do type erasure on RNGs. @@ -139,6 +140,9 @@ class Graph { // The number of constraints in the model. ssize_t num_constraints() const noexcept { return constraints_.size(); } + // The number of input nodes in the model. + ssize_t num_inputs() const noexcept { return inputs_.size(); } + // Specify the objective node. Must be an array with a single element. // To unset the objective provide nullptr. void set_objective(ArrayNode* objective_ptr); @@ -159,6 +163,9 @@ class Graph { std::span decisions() noexcept { return decisions_; } std::span decisions() const noexcept { return decisions_; } + std::span inputs() noexcept { return inputs_; } + std::span inputs() const noexcept { return inputs_; } + // Remove unused nodes from the graph. // // This method will reset the topological sort if there is one. @@ -182,6 +189,7 @@ class Graph { ArrayNode* objective_ptr_ = nullptr; std::vector constraints_; std::vector decisions_; + std::vector inputs_; // Track whether the model is currently topologically sorted bool topologically_sorted_ = false; @@ -332,6 +340,8 @@ NodeType* Graph::emplace_node(Args&&... args) { static_assert(std::is_base_of_v); ptr->topological_index_ = decisions_.size(); decisions_.emplace_back(ptr); + } else if constexpr (std::is_base_of_v) { + inputs_.emplace_back(ptr); } return ptr; // return the observing pointer diff --git a/dwave/optimization/libcpp/graph.pxd b/dwave/optimization/libcpp/graph.pxd index 6c9d561f..0e6639a1 100644 --- a/dwave/optimization/libcpp/graph.pxd +++ b/dwave/optimization/libcpp/graph.pxd @@ -23,6 +23,7 @@ from dwave.optimization.libcpp cimport span from dwave.optimization.libcpp.array cimport Array from dwave.optimization.libcpp.state cimport State + cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil: cdef cppclass Node: struct SuccessorView: @@ -38,12 +39,20 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" cdef cppclass DecisionNode(Node): pass + cdef cppclass InputNode(Node, Array): + pass + # Sometimes Cython isn't able to reason about pointers as template inputs, so # we make a few aliases for convenience ctypedef Node* NodePtr ctypedef ArrayNode* ArrayNodePtr ctypedef DecisionNode* DecisionNodePtr +# This seems to be necessary to allow Cython to iterate over the returned +# span from `inputs()` directly. Otherwise it tries to cast it to a non-const +# version of span before iterating, which the C++ compiler will complain about. +ctypedef InputNode* const constInputNodePtr + cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil: cdef cppclass Graph: T* emplace_node[T](...) except+ @@ -51,9 +60,10 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" span[const unique_ptr[Node]] nodes() const span[const ArrayNodePtr] constraints() span[const DecisionNodePtr] decisions() + Py_ssize_t num_constraints() Py_ssize_t num_nodes() Py_ssize_t num_decisions() - Py_ssize_t num_constraints() + Py_ssize_t num_inputs() @staticmethod void recursive_initialize(State&, Node*) except+ @staticmethod @@ -64,3 +74,4 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" void topological_sort() bool topologically_sorted() const Py_ssize_t remove_unused_nodes() + span[constInputNodePtr] inputs() diff --git a/dwave/optimization/model.pyx b/dwave/optimization/model.pyx index c2e1eb69..0f48f2a2 100644 --- a/dwave/optimization/model.pyx +++ b/dwave/optimization/model.pyx @@ -47,11 +47,17 @@ from libcpp.vector cimport vector from dwave.optimization.libcpp.array cimport Array as cppArray from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode from dwave.optimization.symbols cimport symbol_from_ptr +from dwave.optimization.expression cimport Expression __all__ = ["Model"] +ctypedef fused ExpressionOrModel: + Model + Expression + + @contextlib.contextmanager def locked(model): """Context manager that hold a locked model and unlocks it when the context is exited.""" @@ -1035,6 +1041,10 @@ cdef class Model(_Model): return SetVariable(self, n, min_size, n if max_size is None else max_size) +def _States_init(States self, ExpressionOrModel model): + self._model_ref = weakref.ref(model) + + cdef class States: r"""States of a symbol in a model. @@ -1087,8 +1097,12 @@ cdef class States: >>> model.states.size() 0 """ - def __init__(self, Model model): - self._model_ref = weakref.ref(model) + + # Cython doesn't seem to properly handle fused type arguments on __init__, + # so we have to use this awkward workaround + # See https://github.com/cython/cython/issues/3758 + def __init__(self, model): + _States_init(self, model) def __len__(self): """The number of model states.""" diff --git a/dwave/optimization/symbols.pyx b/dwave/optimization/symbols.pyx index 08b53b5c..c07a95ec 100644 --- a/dwave/optimization/symbols.pyx +++ b/dwave/optimization/symbols.pyx @@ -102,6 +102,11 @@ from dwave.optimization.libcpp.nodes cimport ( from dwave.optimization.model cimport ArraySymbol, _Model, Model, Symbol from dwave.optimization.expression cimport Expression +ctypedef fused ExpressionOrModel: + Model + Expression + + __all__ = [ "Absolute", "Add", @@ -179,6 +184,7 @@ cdef void _register(object cls, const type_info& typeinfo): _cpp_type_to_python[type_index(typeinfo)] = (cls) +# TODO: should this use ExpressionOrModel? cdef object symbol_from_ptr(_Model model, cppNode* node_ptr): """Create a Python/Cython symbol from a C++ Node*.""" @@ -2271,39 +2277,36 @@ cdef class NaryReduce(ArraySymbol): def __init__( self, - input_symbols: Collection[Input], - ArraySymbol output_symbol, + # input_symbols: Collection[Input], + # ArraySymbol output_symbol, + expression: Expression, operands: Collection[ArraySymbol], initial_values: Optional[tuple[float]] = None, ): if len(operands) == 0: raise ValueError("must have at least one operand") - if len(input_symbols) != len(operands) + 1: + if expression.num_inputs() != len(operands) + 1: raise ValueError("must have exactly one more input than number of operands") if initial_values is None: - initial_values = (0,) * len(input_symbols) + initial_values = (0,) * expression.num_inputs() - if len(initial_values) != len(input_symbols): + if len(initial_values) != expression.num_inputs(): raise ValueError("must have same number of initial values as inputs") - cdef _Model expression = input_symbols[0].model - cdef _Model model = operands[0].model - cdef cppArrayNode* output = output_symbol.array_ptr + cdef Model model = operands[0].model + cdef cppArrayNode* output = expression.output.array_ptr cdef vector[double] cppinitial_values + cdef cppInputNode* cppinput cdef vector[cppInputNode*] cppinputs cdef vector[cppArrayNode*] cppoperands for val in initial_values: cppinitial_values.push_back(val) - cdef Input inp - for node in input_symbols: - if node.model != expression: - raise ValueError("all inputs must belong to the expression model") - inp = node - cppinputs.push_back(inp.ptr) + for cppinput in expression._graph.inputs(): + cppinputs.push_back(cppinput) cdef ArraySymbol array for node in operands: @@ -2318,11 +2321,12 @@ cdef class NaryReduce(ArraySymbol): move(expression._graph), cppinputs, output, cppinitial_values, cppoperands ) except ValueError as e: + expression.unlock() raise self._handle_unsupported_expression_exception(expression, e) self.initialize_arraynode(model, self.ptr) - def _handle_unsupported_expression_exception(self, _Model expression, exception: ValueError): + def _handle_unsupported_expression_exception(self, Expression expression, exception): try: info = json.loads(str(exception)) except json.JSONDecodeError: diff --git a/tests/test_expression.py b/tests/test_expression.py new file mode 100644 index 00000000..8cb219ff --- /dev/null +++ b/tests/test_expression.py @@ -0,0 +1,89 @@ +# Copyright 2024 D-Wave Systems Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import dwave.optimization.symbols +from dwave.optimization.expression import Expression + + +class TestExpression(unittest.TestCase): + def test(self): + Expression() + + def test_initial_inputs(self): + exp = Expression(num_inputs=10, lower_bound=-5, upper_bound=3.7, integral=False) + self.assertEqual(exp.num_inputs(), 10) + + # Test that all arguments must be provided if starting with initial inputs + with self.assertRaises(ValueError): + Expression(num_inputs=10, upper_bound=3.7, integral=False) + with self.assertRaises(ValueError): + Expression(num_inputs=10, lower_bound=-5, integral=False) + with self.assertRaises(ValueError): + Expression(num_inputs=10, lower_bound=-5, upper_bound=3.7) + + def test_unsupported_symbols(self): + # Can't add decisions to an Expression, even manually + exp = Expression() + with self.assertRaises(TypeError): + dwave.optimization.symbols.IntegerVariable(exp) + + # Can't add other symbols, e.g. Reshape + exp = Expression() + inp = exp.input(0, 1, False) + with self.assertRaises(TypeError): + dwave.optimization.symbols.Reshape(inp, (1, 1, 1)) + + def test_num_inputs(self): + exp = Expression() + self.assertEqual(exp.num_inputs(), 0) + + inp0 = exp.input(-1, 1, True) + self.assertEqual(exp.num_inputs(), 1) + + inp1 = exp.input(-1, 1, True) + self.assertEqual(exp.num_inputs(), 2) + + inp0 + inp1 + self.assertEqual(exp.num_inputs(), 2) + self.assertEqual(exp.num_nodes(), 3) + + exp.input(-1, 1, True) + self.assertEqual(exp.num_inputs(), 3) + self.assertEqual(exp.num_nodes(), 4) + + def test_iter_inputs(self): + exp = Expression() + self.assertListEqual(list(exp.iter_inputs()), []) + + inp0 = exp.input(-1, 1, True) + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 1) + self.assertTrue(inp0.equals(symbols[0])) + + inp1 = exp.input(-1, 1, True) + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 2) + self.assertTrue(inp0.equals(symbols[0])) + self.assertTrue(inp1.equals(symbols[1])) + + inp0 + inp1 + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 2) + + inp2 = exp.input(-1, 1, True) + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 3) + self.assertTrue(inp2.equals(symbols[2])) diff --git a/tests/test_symbols.py b/tests/test_symbols.py index 8b92154f..5d07369d 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -1771,9 +1771,9 @@ def generate_symbols(self): exp = Expression() inputs = [exp.input(-10, 10, False) for _ in range(3)] - sum_ = inputs[0] + inputs[1] + inputs[2] + exp.set_output(inputs[0] + inputs[1] + inputs[2]) - acc = dwave.optimization.symbols.NaryReduce(inputs, sum_, (c0, c1)) + acc = dwave.optimization.symbols.NaryReduce(exp, (c0, c1)) model.lock() yield acc @@ -1785,51 +1785,25 @@ def test_mismatched_inputs(self): exp = Expression() inputs = [exp.input(-10, 10, False) for _ in range(3)] - sum_ = inputs[0] + inputs[1] + inputs[2] + exp.set_output(inputs[0] + inputs[1] + inputs[2]) with self.assertRaises(ValueError): - dwave.optimization.symbols.NaryReduce(inputs, sum_, (c0,)) + dwave.optimization.symbols.NaryReduce(exp, (c0,)) with self.assertRaises(ValueError): - dwave.optimization.symbols.NaryReduce(inputs[:1], sum_, (c0, c1)) - - with self.assertRaises(ValueError): - dwave.optimization.symbols.NaryReduce(inputs, sum_, (c0, c1), initial_values=(0,)) + dwave.optimization.symbols.NaryReduce(exp, (c0, c1), initial_values=(0,)) def test_invalid_expressions(self): model = Model() c0 = model.constant([0, 0]) - # exp = Expression() - # inputs = [exp.input(-10, 10, False) for _ in range(2)] - # i = exp.integer() - # sum_ = inputs[0] + inputs[1] - # - # try: - # dwave.optimization.symbols.NaryReduce(inputs, sum_, (c0,)) - # self.assertTrue(False, "should raise exception") - # except Exception as e: - # self.assertIsInstance(e, dwave.optimization.symbols.UnsupportedNaryReduceExpression) - # self.assertRegex(str(e), "decision") - # self.assertTrue(i.equals(e.symbol)) - # - # exp = Expression() - # inputs = [exp.input(-10, 10, False) for _ in range(2)] - # reshape = inputs[0].reshape((1, 1, 1)) - # - # try: - # dwave.optimization.symbols.NaryReduce(inputs, reshape, (c0,)) - # self.assertTrue(False, "should raise exception") - # except Exception as e: - # self.assertIsInstance(e, dwave.optimization.symbols.UnsupportedNaryReduceExpression) - # self.assertRegex(str(e), "unsupported node") - # self.assertTrue(reshape.equals(e.symbol)) - + # Can't use an Expression that uses a non-scalar input exp = Expression() inp1 = exp.input(-10, 10, False) - inp5 = exp.input(-10, 10, False, (5,)) + inp5 = dwave.optimization.symbols.Input(exp, -10, 10, False, (5,)) + exp.set_output(inp1) try: - dwave.optimization.symbols.NaryReduce((inp1, inp5), inp1, (c0,)) + dwave.optimization.symbols.NaryReduce(exp, (c0,)) self.assertTrue(False, "should raise exception") except Exception as e: self.assertIsInstance(e, dwave.optimization.symbols.UnsupportedNaryReduceExpression)