Skip to content

Commit

Permalink
feat: faster, linear HashMap.alter and modify (#6573)
Browse files Browse the repository at this point in the history
This PR replaces the existing implementations of `(D)HashMap.alter` and
`(D)HashMap.modify` with primitive, more efficient ones and in
particular provides proofs that they yield well-formed hash maps (`WF`
typeclass).

---------

Co-authored-by: Paul Reichert <[email protected]>
  • Loading branch information
datokrat and datokrat authored Jan 14, 2025
1 parent 05aa256 commit 3243dc5
Show file tree
Hide file tree
Showing 9 changed files with 889 additions and 45 deletions.
31 changes: 12 additions & 19 deletions src/Std/Data/DHashMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -251,34 +251,27 @@ instance [BEq α] [Hashable α] : ForIn m (DHashMap α β) ((a : α) × β a) wh
Modifies in place the value associated with a given key.
This function ensures that the value is used linearly.
It is implemented as a primitive operation (`Raw₀.modify`) rather than in terms of
`get?`, `erase`, and `insert`.
-/
@[inline] def modify [LawfulBEq α] (m : DHashMap α β) (a : α) (f : β a → β a) : DHashMap α β :=
  -- Single primitive implementation: delegate to `Raw₀.modify` instead of the former
  -- `get?` / `erase` / `insert` round trip.
  -- `m.2.size_buckets_pos` is the proof component of the `Raw₀` argument, and
  -- `Raw.WF.modify₀ m.2` is the proof component of the resulting `DHashMap`.
  ⟨Raw₀.modify ⟨m.1, m.2.size_buckets_pos⟩ a f, Raw.WF.modify₀ m.2⟩

@[inline, inherit_doc DHashMap.modify] def Const.modify {β : Type v} (m : DHashMap α (fun _ => β))
    (a : α) (f : β → β) : DHashMap α (fun _ => β) :=
  -- Variant of `DHashMap.modify` for a constant value type `β`; it delegates to
  -- `Raw₀.Const.modify`. `m.2.size_buckets_pos` is the proof component of the `Raw₀`
  -- argument, and `Raw.WF.constModify₀ m.2` is the proof component of the result.
  ⟨Raw₀.Const.modify ⟨m.1, m.2.size_buckets_pos⟩ a f, Raw.WF.constModify₀ m.2⟩

/--
Modifies in place the value associated with a given key,
allowing creating new values and deleting values via an `Option` valued replacement function.
This function ensures that the value is used linearly.

The replacement function `f` is applied to `none` when `a` is not present and to `some b` when
`a` is currently mapped to `b`. If `f` returns `some b'`, then `a` is mapped to `b'` afterwards;
if it returns `none`, then `a` is not present afterwards.

It is implemented as a primitive operation (`Raw₀.alter`) rather than in terms of
`get?`, `erase`, and `insert`.
-/
@[inline] def alter [LawfulBEq α] (m : DHashMap α β)
    (a : α) (f : Option (β a) → Option (β a)) : DHashMap α β :=
  -- `m.2.size_buckets_pos` is the proof component of the `Raw₀` argument, and
  -- `Raw.WF.alter₀ m.2` is the proof component of the resulting `DHashMap`.
  ⟨Raw₀.alter ⟨m.1, m.2.size_buckets_pos⟩ a f, Raw.WF.alter₀ m.2⟩

@[inline, inherit_doc DHashMap.alter] def Const.alter {β : Type v}
    (m : DHashMap α (fun _ => β)) (a : α) (f : Option β → Option β) : DHashMap α (fun _ => β) :=
  -- Variant of `DHashMap.alter` for a constant value type `β`; it delegates to
  -- `Raw₀.Const.alter`. `m.2.size_buckets_pos` is the proof component of the `Raw₀`
  -- argument, and `Raw.WF.constAlter₀ m.2` is the proof component of the result.
  ⟨Raw₀.Const.alter ⟨m.1, m.2.size_buckets_pos⟩ a f, Raw.WF.constAlter₀ m.2⟩

@[inline, inherit_doc Raw.insertMany] def insertMany {ρ : Type w}
[ForIn Id ρ ((a : α) × β a)] (m : DHashMap α β) (l : ρ) : DHashMap α β :=
Expand Down
57 changes: 57 additions & 0 deletions src/Std/Data/DHashMap/Internal/AssocList/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,63 @@ def erase [BEq α] (a : α) : AssocList α β → AssocList α β
| nil => nil
| cons k v l => bif k == a then l else cons k v (l.erase a)

/--
Internal implementation detail of the hash map.

Walks the list up to the first entry whose key is `==` to `a`, replaces that entry by one with
key `a` and value `f` applied to the old value, and keeps the remainder of the list untouched.
A list without such an entry is rebuilt unchanged.
-/
def modify [BEq α] [LawfulBEq α] (a : α) (f : β a → β a) :
    AssocList α β → AssocList α β
  | nil => nil
  | cons key val rest =>
    if hBeq : key == a then
      -- `eq_of_beq` turns the boolean test into `key = a`, which is what allows the stored value
      -- to be cast from `β key` to `β a` before `f` is applied.
      have hEq : key = a := eq_of_beq hBeq
      let updated := f (cast (congrArg β hEq) val)
      cons a updated rest
    else
      cons key val (modify a f rest)

/--
Internal implementation detail of the hash map.

Searches for the first entry whose key is `==` to `a`.
* If one is found, `f` is applied to `some` of its value: `none` drops the entry, `some b`
  replaces it by an entry with key `a` and value `b`.
* If the end of the list is reached, `f none` decides whether an entry for `a` is appended
  (`some b`) or not (`none`).
-/
def alter [BEq α] [LawfulBEq α] (a : α) (f : Option (β a) → Option (β a)) :
    AssocList α β → AssocList α β
  | nil =>
    match f none with
    | none => nil
    | some fresh => cons a fresh nil
  | cons key val rest =>
    if hBeq : key == a then
      -- `key = a` is needed to cast the stored value from `β key` to `β a`.
      have hEq : key = a := eq_of_beq hBeq
      match f (some (cast (congrArg β hEq) val)) with
      | none => rest
      | some updated => cons a updated rest
    else
      let tail := alter a f rest
      cons key val tail

namespace Const

/--
Internal implementation detail of the hash map.

Variant of `modify` for a constant value type: the first entry whose key is `==` to `a` is
replaced by an entry with key `a` (not the stored key) and value `f` of the old value.
No cast is involved, so only `BEq α` is required.
-/
def modify [BEq α] {β : Type v} (a : α) (f : β → β) :
    AssocList α (fun _ => β) → AssocList α (fun _ => β)
  | nil => nil
  | cons key val rest =>
    if key == a then
      cons a (f val) rest
    else
      cons key val (modify a f rest)

/--
Internal implementation detail of the hash map.

Variant of `alter` for a constant value type. For the first entry whose key is `==` to `a`,
`f` is applied to `some` of the stored value: `none` drops the entry, `some b` replaces it by an
entry with key `a` and value `b`. If no entry matches, `f none` decides whether an entry for `a`
is appended at the end.
-/
def alter [BEq α] {β : Type v} (a : α) (f : Option β → Option β) :
    AssocList α (fun _ => β) → AssocList α (fun _ => β)
  | nil =>
    match f none with
    | none => nil
    | some fresh => AssocList.cons a fresh nil
  | cons key val rest =>
    if key == a then
      -- The existing value is handed to `f` wrapped in `some`.
      match f (some val) with
      | none => rest
      | some updated => cons a updated rest
    else
      let tail := alter a f rest
      cons key val tail

end Const

/-- Internal implementation detail of the hash map -/
@[inline] def filterMap (f : (a : α) → β a → Option (γ a)) :
AssocList α β → AssocList α γ :=
Expand Down
39 changes: 39 additions & 0 deletions src/Std/Data/DHashMap/Internal/AssocList/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,45 @@ theorem toList_filter {f : (a : α) → β a → Bool} {l : AssocList α β} :
· exact (ih _).trans (by simpa using perm_middle.symm)
· exact ih _

/--
The entries of `l.alter a f` are a permutation of `alterKey a f` applied to the entries of `l`.
-/
theorem toList_alter [BEq α] [LawfulBEq α] {a : α} {f : Option (β a) → Option (β a)}
{l : AssocList α β} :
Perm (l.alter a f).toList (alterKey a f l.toList) := by
induction l
-- `nil`: unfold both sides, then case on the `match f none` inside `alter`.
· simp only [alter, toList_nil, alterKey_nil]
split <;> simp_all
-- `cons`: reduce the right-hand side via `alterKey_cons_perm`, unfold `alter` once and
-- case on its branches.
· rw [toList]
refine Perm.trans ?_ alterKey_cons_perm.symm
rw [alter]
split <;> (try split) <;> simp_all

/--
`modify a f` agrees with `alter a` applied to the replacement function `Option.map f`.
-/
theorem modify_eq_alter [BEq α] [LawfulBEq α] {a : α} {f : β a → β a} {l : AssocList α β} :
modify a f l = alter a (·.map f) l := by
induction l
-- `nil`: both sides unfold to the same term.
· rfl
-- `cons`: unfold both definitions and rewrite the tail with the induction hypothesis `ih`.
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]

namespace Const

variable {β : Type v}

/--
Constant-value-type version of `toList_alter`: the entries of `Const.alter a f l` are a
permutation of `Const.alterKey a f` applied to the entries of `l`.
-/
theorem toList_alter [BEq α] [EquivBEq α] {a : α} {f : Option β → Option β}
{l : AssocList α (fun _ => β)} : Perm (alter a f l).toList (Const.alterKey a f l.toList) := by
induction l
-- `nil`: unfold both sides, then case on the `match f none` inside `alter`.
· simp only [alter, toList_nil, alterKey_nil]
split <;> simp_all
-- `cons`: reduce the right-hand side via `Const.alterKey_cons_perm`, unfold `alter` once and
-- case on its branches.
· rw [toList]
refine Perm.trans ?_ Const.alterKey_cons_perm.symm
rw [alter]
split <;> (try split) <;> simp_all

/--
Constant-value-type version of `modify_eq_alter`: `Const.modify a f` agrees with `Const.alter a`
applied to the replacement function `Option.map f`.
-/
-- NOTE(review): `beq_iff_eq` presumably needs `LawfulBEq α`, which is not assumed here, so it
-- likely does not fire in the `simp only` below; `[EquivBEq α]` may also be unnecessary — confirm.
theorem modify_eq_alter [BEq α] [EquivBEq α] {a : α} {f : β → β} {l : AssocList α (fun _ => β)} :
modify a f l = alter a (·.map f) l := by
induction l
-- `nil`: both sides unfold to the same term.
· rfl
-- `cons`: unfold both definitions and rewrite the tail with the induction hypothesis `ih`.
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]

end Const

theorem foldl_apply {l : AssocList α β} {acc : List δ} (f : (a : α) → β a → δ) :
l.foldl (fun acc k v => f k v :: acc) acc =
(l.toList.map (fun p => f p.1 p.2)).reverse ++ acc := by
Expand Down
66 changes: 66 additions & 0 deletions src/Std/Data/DHashMap/Internal/Defs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,72 @@ where
let buckets' := buckets.uset i (AssocList.cons a b bkt) h
expandIfNecessary ⟨⟨size', buckets'⟩, by simpa [buckets']⟩

/--
Internal implementation detail of the hash map.

Selects the bucket at the index computed by `mkIdx` from `hash a`. If that bucket contains `a`,
the bucket is rewritten with `AssocList.modify a f` and stored back at the same index; otherwise
`m` is returned as is. The stored `size` is carried over unchanged in both cases.
-/
@[inline] def modify [BEq α] [Hashable α] [LawfulBEq α] (m : Raw₀ α β) (a : α) (f : β a → β a) :
    Raw₀ α β :=
  let ⟨⟨size, buckets⟩, hm⟩ := m
  let ⟨i, hi⟩ := mkIdx buckets.size hm (hash a)
  let bucket := buckets[i]
  if bucket.contains a then
    -- Slot `i` is overwritten with `.nil` before the bucket is updated and written back;
    -- presumably this keeps the array from sharing the bucket so it can be updated in place.
    let buckets := buckets.uset i .nil hi
    let bucket := bucket.modify a f
    ⟨⟨size, buckets.uset i bucket (by simpa [buckets])⟩, (by simpa [buckets])⟩
  else
    m

/--
Internal implementation detail of the hash map.

Constant-value-type version of `Raw₀.modify`. Selects the bucket at the index computed by `mkIdx`
from `hash a`. If that bucket contains `a`, the bucket is rewritten with
`AssocList.Const.modify a f` and stored back at the same index; otherwise `m` is returned as is.
The stored `size` is carried over unchanged in both cases.
-/
@[inline] def Const.modify [BEq α] {β : Type v} [Hashable α] (m : Raw₀ α (fun _ => β)) (a : α)
    (f : β → β) : Raw₀ α (fun _ => β) :=
  let ⟨⟨size, buckets⟩, hm⟩ := m
  let ⟨i, hi⟩ := mkIdx buckets.size hm (hash a)
  let bucket := buckets[i]
  if bucket.contains a then
    -- Slot `i` is overwritten with `.nil` before the bucket is updated and written back;
    -- presumably this keeps the array from sharing the bucket so it can be updated in place.
    let buckets := buckets.uset i .nil hi
    let bucket := AssocList.Const.modify a f bucket
    ⟨⟨size, buckets.uset i bucket (by simpa [buckets])⟩, (by simpa [buckets])⟩
  else
    m

/--
Internal implementation detail of the hash map.

Selects the bucket at the index computed by `mkIdx` from `hash a` and distinguishes two cases:
* The bucket contains `a`: the bucket is rewritten with `AssocList.alter a f`. The size is kept if
  the rewritten bucket still contains `a` and decremented by one otherwise.
* The bucket does not contain `a`: `f none` is consulted. On `none` the map is returned as is; on
  `some b` the entry is prepended to the bucket, the size is incremented and the result is passed
  through `expandIfNecessary`.
-/
@[inline] def alter [BEq α] [Hashable α] [LawfulBEq α] (m : Raw₀ α β) (a : α)
    (f : Option (β a) → Option (β a)) : Raw₀ α β :=
  let ⟨⟨size, buckets⟩, hm⟩ := m
  let ⟨idx, hIdx⟩ := mkIdx buckets.size hm (hash a)
  let bucket := buckets[idx]
  if bucket.contains a then
    -- Slot `idx` is overwritten with `.nil` before the altered bucket is written back.
    let cleared := buckets.uset idx .nil hIdx
    let altered := bucket.alter a f
    let newSize := if altered.contains a then size else size - 1
    ⟨⟨newSize, cleared.uset idx altered (by simpa [cleared])⟩, by simpa [cleared]⟩
  else
    match f none with
    | none => m
    | some b =>
      let newSize := size + 1
      let grown := buckets.uset idx (.cons a b bucket) hIdx
      expandIfNecessary ⟨⟨newSize, grown⟩, by simpa [grown]⟩

/--
Internal implementation detail of the hash map.

Constant-value-type version of `Raw₀.alter`. Selects the bucket at the index computed by `mkIdx`
from `hash a` and distinguishes two cases:
* The bucket contains `a`: the bucket is rewritten with `AssocList.Const.alter a f`. The size is
  kept if the rewritten bucket still contains `a` and decremented by one otherwise.
* The bucket does not contain `a`: `f none` is consulted. On `none` the map is returned as is; on
  `some b` the entry is prepended to the bucket, the size is incremented and the result is passed
  through `expandIfNecessary`.
-/
@[inline] def Const.alter [BEq α] [Hashable α] {β : Type v} (m : Raw₀ α (fun _ => β)) (a : α)
    (f : Option β → Option β) : Raw₀ α (fun _ => β) :=
  let ⟨⟨size, buckets⟩, hm⟩ := m
  let ⟨idx, hIdx⟩ := mkIdx buckets.size hm (hash a)
  let bucket := buckets[idx]
  if bucket.contains a then
    -- Slot `idx` is overwritten with `.nil` before the altered bucket is written back.
    let cleared := buckets.uset idx .nil hIdx
    let altered := AssocList.Const.alter a f bucket
    let newSize := if altered.contains a then size else size - 1
    ⟨⟨newSize, cleared.uset idx altered (by simpa [cleared])⟩, by simpa [cleared]⟩
  else
    match f none with
    | none => m
    | some b =>
      let newSize := size + 1
      let grown := buckets.uset idx (.cons a b bucket) hIdx
      expandIfNecessary ⟨⟨newSize, grown⟩, by simpa [grown]⟩

/-- Internal implementation detail of the hash map -/
@[inline] def containsThenInsert [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (b : β a) :
Bool × Raw₀ α β :=
Expand Down
Loading

0 comments on commit 3243dc5

Please sign in to comment.