diff --git a/src/translate/constraints.py b/src/translate/constraints.py index c5a80b59dc..b433078536 100644 --- a/src/translate/constraints.py +++ b/src/translate/constraints.py @@ -16,7 +16,8 @@ class EqualityConjunction: def __init__(self, equalities): self.equalities = tuple(equalities) # represents a conjunction of expressions x = y, where x,y are strings - # representing objects, variables or invariant parameters. + # representing objects or variables or ints, representing invariant + # parameters. self._consistent = None self._representative = None # dictionary @@ -47,13 +48,15 @@ def _compute_representatives(self): # the equivalence class contains a single object, the representative is # this object. (If it contains more than one object, the conjunction is # inconsistent and we don't store representatives.) - # (with objects being smaller than variables) + # (with objects being smaller than variables or invariant parameters) representative = {} for eq_class in self._eq_classes.values(): if next(iter(eq_class)) in representative: continue # we already processed this equivalence class - variables = [item for item in eq_class if item.startswith("?")] - constants = [item for item in eq_class if not item.startswith("?")] + variables = [item for item in eq_class if isinstance(item, int) or + item.startswith("?")] + constants = [item for item in eq_class if not isinstance(item, int) + and not item.startswith("?")] if len(constants) >= 2: self._consistent = False @@ -81,8 +84,8 @@ def get_representative(self): class ConstraintSystem: """A ConstraintSystem stores two parts, both talking about the equality or - inequality of strings (representing objects, variables or invariant - parameters): + inequality of strings and ints (strings representing objects or + variables, ints representing invariant parameters): - equality_DNFs is a list containing lists of EqualityConjunctions. Each EqualityConjunction represents an expression of the form (x1 = y1 and ... and xn = yn). A list of EqualityConjunctions can be @@ -163,7 +166,8 @@ def inequality_disjunction_ok(ineq_disj, representative): # inequality disjunction there is an inequality where the two terms # are in different equivalence classes. representative = combined.get_representative() - if any(representative.get(s, s)[0] != "?" + if any(not isinstance(representative.get(s, s), int) and + representative.get(s, s)[0] != "?" for s in self.not_constant): continue for ineq_disjunction in self.ineq_disjunctions: diff --git a/src/translate/invariant_finder.py b/src/translate/invariant_finder.py index d771233ac0..d016d2ec8a 100755 --- a/src/translate/invariant_finder.py +++ b/src/translate/invariant_finder.py @@ -78,14 +78,12 @@ def get_fluents(task): def get_initial_invariants(task): for predicate in get_fluents(task): - all_args = list(f"?@v{i}" for i in range(len(predicate.arguments))) - atom = pddl.Atom(predicate.name, all_args) - part = invariants.InvariantPart(atom, -1) + all_args = list(range(len(predicate.arguments))) + part = invariants.InvariantPart(predicate.name, all_args, -1) yield invariants.Invariant((part,)) for omitted in range(len(predicate.arguments)): - inv_args = all_args[0:omitted] + ["_"] + all_args[omitted:-1] - atom = pddl.Atom(predicate.name, inv_args) - part = invariants.InvariantPart(atom, omitted) + inv_args = all_args[0:omitted] + [-1] + all_args[omitted:-1] + part = invariants.InvariantPart(predicate.name, inv_args, omitted) yield invariants.Invariant((part,)) def find_invariants(task, reachable_action_params): @@ -124,13 +122,10 @@ def useful_groups(invariants, initial_facts): for invariant in predicate_to_invariants.get(atom.predicate, ()): parameters = invariant.get_parameters(atom) # we need to make the parameters dictionary hashable, so - # we store the values as a tuple in the order of the numbering of - # the invariant parameters. - inv_vars = [f"?@v{i}" for i in range(invariant.arity())] - parameters_tuple = tuple((var, parameters[var]) - for var in inv_vars) + # we store the values as a tuple + parameters_tuple = tuple(parameters[var] + for var in range(invariant.arity())) - parameters_tuple = tuple(sorted(x for x in parameters.items())) group_key = (invariant, parameters_tuple) if group_key not in nonempty_groups: nonempty_groups.add(group_key) diff --git a/src/translate/invariants.py b/src/translate/invariants.py index 47e53a0173..a7b605db4b 100644 --- a/src/translate/invariants.py +++ b/src/translate/invariants.py @@ -109,51 +109,52 @@ def ensure_inequality(system, literal1, literal2): class InvariantPart: - def __init__(self, atom, omitted_pos=-1): + def __init__(self, predicate, args, omitted_pos=-1): """There is one InvariantPart for every predicate mentioned in the - invariant. The atom of the invariant part has arguments of the form - "?@vi" for the invariant parameters and "_" at the omitted position. - If no position is omitted, omitted_pos is -1.""" - self.atom = atom + invariant. The arguments args contain numbers 0,1,... for the + invariant parameters and -1 at the omitted position. + If no position is omitted, omitted_pos is -1, otherwise it is the + index of -1 in args.""" + self.predicate = predicate + self.args = tuple(args) self.omitted_pos = omitted_pos def __eq__(self, other): # This implies equality of the omitted_pos component. - return self.atom == other.atom + return self.predicate == other.predicate and self.args == other.args def __ne__(self, other): - return self.atom != other.atom + return self.predicate != other.predicate or self.args != other.args def __le__(self, other): - return self.atom <= other.atom + return (self.predicate, self.args) <= (other.predicate, other.args) def __lt__(self, other): - return self.atom < other.atom + return (self.predicate, self.args) < (other.predicate, other.args) def __hash__(self): - return hash(self.atom) + return hash((self.predicate, self.args)) def __str__(self): - return f"{self.atom} [omitted_pos = {self.omitted_pos}]" + return f"{self.predicate}({self.args}) [omitted_pos = {self.omitted_pos}]" def arity(self): if self.omitted_pos == -1: - return len(self.atom.args) + return len(self.args) else: - return len(self.atom.args) - 1 + return len(self.args) - 1 def get_parameters(self, literal): """Returns a dictionary, mapping the invariant parameters to the corresponding values in the literal.""" return dict((arg, literal.args[pos]) - for pos, arg in enumerate(self.atom.args) + for pos, arg in enumerate(self.args) if pos != self.omitted_pos) def instantiate(self, parameters_tuple): - parameters = dict(parameters_tuple) - args = [parameters[arg] if arg != "_" else "?X" - for arg in self.atom.args] - return pddl.Atom(self.atom.predicate, args) + args = [parameters_tuple[arg] if arg != -1 else "?X" + for arg in self.args] + return pddl.Atom(self.predicate, args) def possible_mappings(self, own_literal, other_literal): """This method is used when an action had an unbalanced add effect @@ -244,18 +245,17 @@ def possible_matches(self, own_literal, other_literal): create a new Invariant Part Q(?@v1, ?@v2, _. ?@v0) with the third argument being counted. """ - assert self.atom.predicate == own_literal.predicate + assert self.predicate == own_literal.predicate result = [] for mapping in self.possible_mappings(own_literal, other_literal): - args = ["_"] * len(other_literal.args) + args = [-1] * len(other_literal.args) omitted = -1 for (other_pos, inv_var) in mapping: if inv_var == -1: omitted = other_pos else: args[other_pos] = inv_var - atom = pddl.Atom(other_literal.predicate, args) - result.append(InvariantPart(atom, omitted)) + result.append(InvariantPart(other_literal.predicate, args, omitted)) return result @@ -263,13 +263,14 @@ class Invariant: # An invariant is a logical expression of the type # forall ?@v1...?@vk: sum_(part in parts) weight(part, ?@v1, ..., ?@vk) <= 1. # k is called the arity of the invariant. - # A "part" is an atom that only contains arguments from {?@v1, ..., ?@vk, _}; - # the symbol _ may occur at most once. + # A "part" is an atom that only contains arguments from {?@v1, ..., ?@vk, -1} + # but instead of ?@vi, we store it as int i; + # the symbol -1 may occur at most once. def __init__(self, parts): self.parts = frozenset(parts) - self.predicates = {part.atom.predicate for part in parts} - self.predicate_to_part = {part.atom.predicate: part for part in parts} + self.predicate_to_part = {part.predicate: part for part in parts} + self.predicates = set(self.predicate_to_part.keys()) assert len(self.parts) == len(self.predicates) def __eq__(self, other): @@ -315,8 +316,8 @@ def _get_cover_equivalence_conjunction(self, literal): """ part = self.predicate_to_part[literal.predicate] equalities = [(inv_param, literal.args[pos]) - for pos, inv_param in enumerate(part.atom.args) - if pos != part.omitted_pos] # alternatively inv_param != "_" + for pos, inv_param in enumerate(part.args) + if inv_param != -1] # -1 if ommited return constraints.EqualityConjunction(equalities) # If there were more parts for the same predicate, we would have to # consider more than one assignment (disjunctively). @@ -326,7 +327,7 @@ def check_balance(self, balance_checker, enqueue_func): # Check balance for this hypothesis. actions_to_check = set() for part in self.parts: - actions_to_check |= balance_checker.get_threats(part.atom.predicate) + actions_to_check |= balance_checker.get_threats(part.predicate) for action in actions_to_check: heavy_action = balance_checker.get_heavy_action(action) if self._operator_too_heavy(heavy_action): @@ -410,7 +411,8 @@ def _add_effect_unbalanced(self, action, add_effect, del_effects, # add_cover. If the equivalence class contains an object, the # representative is an object. for param in params: - if representative.get(param, param)[0] == "?": + r = representative.get(param, param) + if isinstance(r, int) or r[0] == "?": # for the add effect being a threat to the invariant, param # does not need to be a specific constant. So we may not bind # it to a constant when balancing the add effect. We store this