Skip to content

Commit

Permalink
Track inputs explicitly in Graph
Browse files Browse the repository at this point in the history
Allows NaryReduce to just take an Expression and operands
  • Loading branch information
wbernoudy committed Nov 20, 2024
1 parent fed7b7e commit 57aae02
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 59 deletions.
2 changes: 1 addition & 1 deletion dwave/optimization/expression/__init__.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 D-Wave Systems Inc.
# Copyright 2024 D-Wave Systems Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 2 additions & 0 deletions dwave/optimization/expression/expression.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ __all__ = ["Expression"]

cdef class Expression(_Model):
cdef readonly ArraySymbol output

cpdef Py_ssize_t num_inputs(self) noexcept
8 changes: 7 additions & 1 deletion dwave/optimization/expression/expression.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ _ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]]


class Expression(_Model):
def __init__(self): ...
def __init__(
self,
num_inputs: int = 0,
lower_bound: Optional[float] = None,
upper_bound: Optional[float] = None,
integral: Optional[bool] = None,
): ...

def input(self, lower_bound: float, upper_bound: float, integral: bool, shape: Optional[tuple] = None):

Expand Down
42 changes: 38 additions & 4 deletions dwave/optimization/expression/expression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,60 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
from typing import Optional

from libcpp cimport bool
from libcpp.cast cimport dynamic_cast

from dwave.optimization.libcpp.array cimport Array as cppArray
from dwave.optimization.libcpp.graph cimport Node as cppNode
from dwave.optimization.libcpp.nodes cimport InputNode as cppInputNode
from dwave.optimization.model cimport ArraySymbol, _Model, States
from dwave.optimization.symbols cimport symbol_from_ptr

ctypedef cppNode* cppNodePtr


__all__ = ["Expression"]


cdef class Expression(_Model):
def __init__(self):
pass
def __init__(
self,
num_inputs: int = 0,
# necessary to prevent Cython from rejecting an int
lower_bound: Optional[numbers.Real] = None,
upper_bound: Optional[numbers.Real] = None,
integral: Optional[bool] = None,
):
self.states = States(self)

self._data_sources = []

if num_inputs > 0:
if any(arg is None for arg in (lower_bound, upper_bound, integral)):
raise ValueError(
"`lower_bound`, `upper_bound` and `integral` must be provided "
"explicitly when initializing inputs"
)
for _ in range(num_inputs):
self.input(lower_bound, upper_bound, integral)

def input(self, lower_bound: float, upper_bound: float, bool integral, shape: Optional[tuple] = None):
def input(self, lower_bound: float, upper_bound: float, bool integral):
"""TODO"""
# avoid circular import
from dwave.optimization.symbols import Input
return Input(self, lower_bound, upper_bound, integral, shape=shape)
# Shape is always scalar for now
return Input(self, lower_bound, upper_bound, integral, shape=tuple())

def set_output(self, value: ArraySymbol):
self.output = value

cpdef Py_ssize_t num_inputs(self) noexcept:
return self._graph.num_inputs()

def iter_inputs(self):
inputs = self._graph.inputs()
for i in range(self._graph.num_inputs()):
yield symbol_from_ptr(self, inputs[i])
10 changes: 10 additions & 0 deletions dwave/optimization/include/dwave-optimization/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace dwave::optimization {
class ArrayNode;
class Node;
class DecisionNode;
class InputNode;

// We don't want this interface to be opinionated about what type of rng we're using.
// So we create this class to do type erasure on RNGs.
Expand Down Expand Up @@ -139,6 +140,9 @@ class Graph {
// The number of constraints in the model.
ssize_t num_constraints() const noexcept { return constraints_.size(); }

// The number of input nodes in the model.
ssize_t num_inputs() const noexcept { return inputs_.size(); }

// Specify the objective node. Must be an array with a single element.
// To unset the objective provide nullptr.
void set_objective(ArrayNode* objective_ptr);
Expand All @@ -159,6 +163,9 @@ class Graph {
std::span<DecisionNode* const> decisions() noexcept { return decisions_; }
std::span<const DecisionNode* const> decisions() const noexcept { return decisions_; }

std::span<InputNode* const> inputs() noexcept { return inputs_; }
std::span<const InputNode* const> inputs() const noexcept { return inputs_; }

// Remove unused nodes from the graph.
//
// This method will reset the topological sort if there is one.
Expand All @@ -182,6 +189,7 @@ class Graph {
ArrayNode* objective_ptr_ = nullptr;
std::vector<ArrayNode*> constraints_;
std::vector<DecisionNode*> decisions_;
std::vector<InputNode*> inputs_;

// Track whether the model is currently topologically sorted
bool topologically_sorted_ = false;
Expand Down Expand Up @@ -332,6 +340,8 @@ NodeType* Graph::emplace_node(Args&&... args) {
static_assert(std::is_base_of_v<DecisionNode, NodeType>);
ptr->topological_index_ = decisions_.size();
decisions_.emplace_back(ptr);
} else if constexpr (std::is_base_of_v<InputNode, NodeType>) {
inputs_.emplace_back(ptr);
}

return ptr; // return the observing pointer
Expand Down
13 changes: 12 additions & 1 deletion dwave/optimization/libcpp/graph.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from dwave.optimization.libcpp cimport span
from dwave.optimization.libcpp.array cimport Array
from dwave.optimization.libcpp.state cimport State


cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil:
cdef cppclass Node:
struct SuccessorView:
Expand All @@ -38,22 +39,31 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization"
cdef cppclass DecisionNode(Node):
pass

cdef cppclass InputNode(Node, Array):
pass

# Sometimes Cython isn't able to reason about pointers as template inputs, so
# we make a few aliases for convenience
ctypedef Node* NodePtr
ctypedef ArrayNode* ArrayNodePtr
ctypedef DecisionNode* DecisionNodePtr

# This seems to be necessary to allow Cython to iterate over the returned
# span from `inputs()` directly. Otherwise it tries to cast it to a non-const
# version of span before iterating, which the C++ compiler will complain about.
ctypedef InputNode* const constInputNodePtr

cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil:
cdef cppclass Graph:
T* emplace_node[T](...) except+
void initialize_state(State&) except+
span[const unique_ptr[Node]] nodes() const
span[const ArrayNodePtr] constraints()
span[const DecisionNodePtr] decisions()
Py_ssize_t num_constraints()
Py_ssize_t num_nodes()
Py_ssize_t num_decisions()
Py_ssize_t num_constraints()
Py_ssize_t num_inputs()
@staticmethod
void recursive_initialize(State&, Node*) except+
@staticmethod
Expand All @@ -64,3 +74,4 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization"
void topological_sort()
bool topologically_sorted() const
Py_ssize_t remove_unused_nodes()
span[constInputNodePtr] inputs()
18 changes: 16 additions & 2 deletions dwave/optimization/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,17 @@ from libcpp.vector cimport vector
from dwave.optimization.libcpp.array cimport Array as cppArray
from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode
from dwave.optimization.symbols cimport symbol_from_ptr
from dwave.optimization.expression cimport Expression


__all__ = ["Model"]


ctypedef fused ExpressionOrModel:
Model
Expression


@contextlib.contextmanager
def locked(model):
"""Context manager that hold a locked model and unlocks it when the context is exited."""
Expand Down Expand Up @@ -1035,6 +1041,10 @@ cdef class Model(_Model):
return SetVariable(self, n, min_size, n if max_size is None else max_size)


def _States_init(States self, ExpressionOrModel model):
self._model_ref = weakref.ref(model)


cdef class States:
r"""States of a symbol in a model.
Expand Down Expand Up @@ -1087,8 +1097,12 @@ cdef class States:
>>> model.states.size()
0
"""
def __init__(self, Model model):
self._model_ref = weakref.ref(model)

# Cython doesn't seem to properly handle fused type arguments on __init__,
# so we have to use this awkward workaround
# See https://github.com/cython/cython/issues/3758
def __init__(self, model):
_States_init(self, model)

def __len__(self):
"""The number of model states."""
Expand Down
34 changes: 19 additions & 15 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ from dwave.optimization.libcpp.nodes cimport (
from dwave.optimization.model cimport ArraySymbol, _Model, Model, Symbol
from dwave.optimization.expression cimport Expression

ctypedef fused ExpressionOrModel:
Model
Expression


__all__ = [
"Absolute",
"Add",
Expand Down Expand Up @@ -179,6 +184,7 @@ cdef void _register(object cls, const type_info& typeinfo):
_cpp_type_to_python[type_index(typeinfo)] = <PyObject*>(cls)


# TODO: should this use ExpressionOrModel?
cdef object symbol_from_ptr(_Model model, cppNode* node_ptr):
"""Create a Python/Cython symbol from a C++ Node*."""

Expand Down Expand Up @@ -2271,39 +2277,36 @@ cdef class NaryReduce(ArraySymbol):

def __init__(
self,
input_symbols: Collection[Input],
ArraySymbol output_symbol,
# input_symbols: Collection[Input],
# ArraySymbol output_symbol,
expression: Expression,
operands: Collection[ArraySymbol],
initial_values: Optional[tuple[float]] = None,
):
if len(operands) == 0:
raise ValueError("must have at least one operand")

if len(input_symbols) != len(operands) + 1:
if expression.num_inputs() != len(operands) + 1:
raise ValueError("must have exactly one more input than number of operands")

if initial_values is None:
initial_values = (0,) * len(input_symbols)
initial_values = (0,) * expression.num_inputs()

if len(initial_values) != len(input_symbols):
if len(initial_values) != expression.num_inputs():
raise ValueError("must have same number of initial values as inputs")

cdef _Model expression = input_symbols[0].model
cdef _Model model = operands[0].model
cdef cppArrayNode* output = output_symbol.array_ptr
cdef Model model = operands[0].model
cdef cppArrayNode* output = expression.output.array_ptr
cdef vector[double] cppinitial_values
cdef cppInputNode* cppinput
cdef vector[cppInputNode*] cppinputs
cdef vector[cppArrayNode*] cppoperands

for val in initial_values:
cppinitial_values.push_back(val)

cdef Input inp
for node in input_symbols:
if node.model != expression:
raise ValueError("all inputs must belong to the expression model")
inp = <Input?>node
cppinputs.push_back(inp.ptr)
for cppinput in expression._graph.inputs():
cppinputs.push_back(cppinput)

cdef ArraySymbol array
for node in operands:
Expand All @@ -2318,11 +2321,12 @@ cdef class NaryReduce(ArraySymbol):
move(expression._graph), cppinputs, output, cppinitial_values, cppoperands
)
except ValueError as e:
expression.unlock()
raise self._handle_unsupported_expression_exception(expression, e)

self.initialize_arraynode(model, self.ptr)

def _handle_unsupported_expression_exception(self, _Model expression, exception: ValueError):
def _handle_unsupported_expression_exception(self, Expression expression, exception):
try:
info = json.loads(str(exception))
except json.JSONDecodeError:
Expand Down
Loading

0 comments on commit 57aae02

Please sign in to comment.