Skip to content

Commit

Permalink
• Correction for output of the global dual bound: for that we moved i…
Browse files Browse the repository at this point in the history
…ts calculation from `TreeSearch.children` function to `TreeSearch.stop` function (#1122)

• At the same time, we simplified the `TreeSearch` interface
  • Loading branch information
rrsadykov authored Jan 19, 2024
1 parent 1268124 commit 3db73e6
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 32 deletions.
6 changes: 3 additions & 3 deletions src/Algorithm/treesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ The input can be built from information contained in a search space and a node.
Methods to perform operations before the tree search algorithm evaluates a node (`current`).
This is useful to restore the state of the formulation for instance.
"""
@mustimplement "ColunaSearchSpace" node_change!(previous::TreeSearch.AbstractNode, current::TreeSearch.AbstractNode, space::AbstractColunaSearchSpace, untreated_nodes) = nothing
@mustimplement "ColunaSearchSpace" node_change!(previous::TreeSearch.AbstractNode, current::TreeSearch.AbstractNode, space::AbstractColunaSearchSpace) = nothing

"""
Methods to perform operations after the conquer algorithms.
Expand Down Expand Up @@ -151,11 +151,11 @@ Performs operations after the divide algorithm when the current node is finally
@mustimplement "ColunaSearchSpace" node_is_pruned(sp::AbstractColunaSearchSpace, current) = nothing

# Implementation of the `children` method for the `AbstractColunaSearchSpace` algorithm.
function TreeSearch.children(space::AbstractColunaSearchSpace, current::TreeSearch.AbstractNode, env, untreated_nodes)
function TreeSearch.children(space::AbstractColunaSearchSpace, current::TreeSearch.AbstractNode, env)
# restore state of the formulation for the current node.
previous = get_previous(space)
if !isnothing(previous)
node_change!(previous, current, space, untreated_nodes)
node_change!(previous, current, space)
end
set_previous!(space, current)
# We should avoid the whole exploration of a node if its local dual bound inherited from its parent is worst than a primal bound found elsewhere on the tree.
Expand Down
17 changes: 9 additions & 8 deletions src/Algorithm/treesearch/branch_and_bound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ mutable struct BaBSearchSpace <: AbstractColunaSearchSpace
optstate::OptimizationState # from TreeSearchRuntimeData

nb_nodes_treated::Int
nb_untreated_nodes::Int
leaves_status::LeavesStatus
inc_primal_manager::GlobalPrimalBoundHandler # stores the global primal bound (shared with all child algorithms).
end
Expand All @@ -152,7 +153,9 @@ set_previous!(sp::BaBSearchSpace, previous::TreeSearch.AbstractNode) = sp.previo
# Tree search implementation
############################################################################################
function TreeSearch.stop(space::BaBSearchSpace, untreated_nodes)
return space.nb_nodes_treated >= space.max_num_nodes || length(untreated_nodes) > space.open_nodes_limit
_update_global_dual_bound!(space, space.reformulation, untreated_nodes) # this method needs to be reimplemented.
space.nb_untreated_nodes = length(untreated_nodes)
return space.nb_nodes_treated >= space.max_num_nodes || space.nb_untreated_nodes > space.open_nodes_limit
end

function TreeSearch.search_space_type(alg::TreeSearchAlgorithm)
Expand Down Expand Up @@ -193,6 +196,7 @@ function TreeSearch.new_space(
nothing,
optstate,
0,
0,
LeavesStatus(reform),
GlobalPrimalBoundHandler(reform; ip_primal_bound = get_ip_primal_bound(input))
)
Expand Down Expand Up @@ -337,7 +341,7 @@ function _update_global_dual_bound!(space, reform::Reformulation, untreated_node
DualBound(getmaster(reform))
end
else
# Otherwise, we use the wost dual bound at the leaves.
# Otherwise, we use the worst dual bound at the leaves.
leaves_worst_dual_bound
end

Expand All @@ -353,24 +357,21 @@ function _update_global_dual_bound!(space, reform::Reformulation, untreated_node
return
end

function node_change!(previous::Node, current::Node, space::BaBSearchSpace, untreated_nodes)
_update_global_dual_bound!(space, space.reformulation, untreated_nodes) # this method needs to be reimplemented.

function node_change!(previous::Node, current::Node, space::BaBSearchSpace)
# We restore the reformulation in the state it was after the creation of the current node (e.g. creation
# of the branching constraint) or its partial evaluation (e.g. strong branching).
# TODO: We don't need to restore if the formulation has been fully evaluated.
restore_from_records!(space.conquer_units_to_restore, current.records)
end

function TreeSearch.tree_search_output(space::BaBSearchSpace, untreated_nodes)
_update_global_dual_bound!(space, space.reformulation, untreated_nodes)
function TreeSearch.tree_search_output(space::BaBSearchSpace)
all_leaves_infeasible = space.leaves_status.infeasible

if !isnothing(get_global_primal_sol(space.inc_primal_manager))
add_ip_primal_sol!(space.optstate, get_global_primal_sol(space.inc_primal_manager))
end

if all_leaves_infeasible && length(untreated_nodes) == 0
if all_leaves_infeasible && space.nb_untreated_nodes == 0
setterminationstatus!(space.optstate, INFEASIBLE)
elseif ip_gap_closed(space.optstate, rtol = space.opt_rtol, atol = space.opt_atol)
setterminationstatus!(space.optstate, OPTIMAL)
Expand Down
14 changes: 8 additions & 6 deletions src/Algorithm/treesearch/printer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ end

TreeSearch.get_priority(explore::TreeSearch.AbstractExploreStrategy, n::PrintedNode) = TreeSearch.get_priority(explore, n.inner)

function TreeSearch.tree_search_output(sp::PrinterSearchSpace, untreated_nodes)
function TreeSearch.tree_search_output(sp::PrinterSearchSpace)
close_tree_search_file!(sp.file_printer)
return TreeSearch.tree_search_output(sp.inner, Iterators.map(n -> n.inner, untreated_nodes))
return TreeSearch.tree_search_output(sp.inner)
end

function TreeSearch.new_space(
Expand Down Expand Up @@ -83,17 +83,19 @@ end
_inner_node(n::PrintedNode) = n.inner # `untreated_node` is a stack.
_inner_node(n::Pair{<:PrintedNode, Float64}) = first(n).inner # `untreated_node` is a priority queue.

function TreeSearch.children(sp::PrinterSearchSpace, current, env, untreated_nodes)
print_log(sp.log_printer, sp, current, env, length(untreated_nodes))
children = TreeSearch.children(sp.inner, current.inner, env, Iterators.map(_inner_node, untreated_nodes))
function TreeSearch.children(sp::PrinterSearchSpace, current, env)
print_log(sp.log_printer, sp, current, env, sp.inner.nb_untreated_nodes)
children = TreeSearch.children(sp.inner, current.inner, env)
# We print node information in the file after the node has been evaluated.
print_node_in_tree_search_file!(sp.file_printer, current, sp, env)
return map(children) do child
return PrintedNode(sp.current_tree_order_id += 1, current.tree_order_id, child)
end
end

TreeSearch.stop(sp::PrinterSearchSpace, untreated_nodes) = TreeSearch.stop(sp.inner, untreated_nodes)
function TreeSearch.stop(sp::PrinterSearchSpace, untreated_nodes)
return TreeSearch.stop(sp.inner, Iterators.map(_inner_node, untreated_nodes))
end

############################################################################################
# Default file printers.
Expand Down
6 changes: 2 additions & 4 deletions src/TreeSearch/TreeSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,14 @@ get_conquer_output(::AbstractNode) = nothing
"Returns `true` is the node is root; `false` otherwise."
@mustimplement "Node" isroot(::AbstractNode) = nothing # BaB implementation

# TODO: remove untreated nodes.
"Evaluate and generate children. This method has a specific implementation for Coluna."
@mustimplement "TreeSearch" children(sp, n, env, untreated_nodes) = nothing
@mustimplement "TreeSearch" children(sp, n, env) = nothing

"Returns true if stopping criteria are met; false otherwise."
@mustimplement "TreeSearch" stop(::AbstractSearchSpace, untreated_nodes) = nothing

# TODO: remove untreated nodes.
"Returns the output of the tree search algorithm."
@mustimplement "TreeSearch" tree_search_output(::AbstractSearchSpace, untreated_nodes) = nothing
@mustimplement "TreeSearch" tree_search_output(::AbstractSearchSpace) = nothing

# Default implementations for some explore strategies.
include("explore.jl")
Expand Down
14 changes: 8 additions & 6 deletions src/TreeSearch/explore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,26 @@ function tree_search(::DepthFirstStrategy, space, env, input)
root_node = new_root(space, input)
stack = Stack{typeof(root_node)}()
push!(stack, root_node)
while !isempty(stack) && !stop(space, stack)
# it is important to call `stop()` function first, as it may update `space`
while !stop(space, stack) && !isempty(stack)
current = pop!(stack)
for child in children(space, current, env, stack)
for child in children(space, current, env)
push!(stack, child)
end
end
return TreeSearch.tree_search_output(space, stack)
return TreeSearch.tree_search_output(space)
end

function tree_search(strategy::AbstractBestFirstSearch, space, env, input)
root_node = new_root(space, input)
pq = PriorityQueue{typeof(root_node), Float64}()
enqueue!(pq, root_node, get_priority(strategy, root_node))
while !isempty(pq) && !stop(space, pq)
# it is important to call `stop()` function first, as it may update `space`
while !stop(space, pq) && !isempty(pq)
current = dequeue!(pq)
for child in children(space, current, env, pq)
for child in children(space, current, env)
enqueue!(pq, child, get_priority(strategy, child))
end
end
return TreeSearch.tree_search_output(space, pq)
return TreeSearch.tree_search_output(space)
end
4 changes: 2 additions & 2 deletions test/unit/Algorithm/explore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct CustomBestFirstSearch <: Coluna.TreeSearch.AbstractBestFirstSearch end

Coluna.TreeSearch.get_priority(::CustomBestFirstSearch, node::NodeAe1) = -node.id

function Coluna.TreeSearch.children(space::CustomSearchSpaceAe1, current, _, _)
function Coluna.TreeSearch.children(space::CustomSearchSpaceAe1, current, _)
children = NodeAe1[]
push!(space.visit_order, current.id)
if current.depth != space.max_depth &&
Expand All @@ -47,7 +47,7 @@ function Coluna.TreeSearch.children(space::CustomSearchSpaceAe1, current, _, _)
return children
end

Coluna.TreeSearch.tree_search_output(space::CustomSearchSpaceAe1, _) = space.visit_order
Coluna.TreeSearch.tree_search_output(space::CustomSearchSpaceAe1) = space.visit_order

function test_dfs()
search_space = CustomSearchSpaceAe1(2, 3, 11)
Expand Down
9 changes: 6 additions & 3 deletions test/unit/TreeSearch/treesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ function Coluna.TreeSearch.new_root(space::TestBaBSearchSpace, input)
return TestBaBNode(inner, 1) ## root id is set to 1 by default
end

Coluna.TreeSearch.stop(space::TestBaBSearchSpace, untreated_nodes) = Coluna.TreeSearch.stop(space.inner, untreated_nodes)
Coluna.TreeSearch.tree_search_output(space::TestBaBSearchSpace, untreated_nodes) = Coluna.TreeSearch.tree_search_output(space.inner, map(n -> n.inner, untreated_nodes))
function Coluna.TreeSearch.stop(space::TestBaBSearchSpace, untreated_nodes)
inner_untreated_nodes = map(node->node.inner, untreated_nodes)
return Coluna.TreeSearch.stop(space.inner, inner_untreated_nodes)
end
Coluna.TreeSearch.tree_search_output(space::TestBaBSearchSpace) = Coluna.TreeSearch.tree_search_output(space.inner)

# methods called by native method children (in branch_and_bound.jl)
Coluna.Algorithm.get_previous(space::TestBaBSearchSpace) = Coluna.Algorithm.get_previous(space.inner)
Expand Down Expand Up @@ -150,7 +153,7 @@ function Coluna.Algorithm.new_children(space::TestBaBSearchSpace, branches::Colu
return children
end

Coluna.Algorithm.node_change!(previous::Coluna.Algorithm.Node, current::TestBaBNode, space::TestBaBSearchSpace, untreated_nodes) = Coluna.Algorithm.node_change!(previous, current.inner, space.inner, map(n -> n.inner, untreated_nodes))
Coluna.Algorithm.node_change!(previous::Coluna.Algorithm.Node, current::TestBaBNode, space::TestBaBSearchSpace) = Coluna.Algorithm.node_change!(previous, current.inner, space.inner)

# end of the interface's redefinition

Expand Down

0 comments on commit 3db73e6

Please sign in to comment.