From 80ddbf45eb92794b1df2cb625e9432753d41601b Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Thu, 16 Jan 2025 16:19:28 +1100 Subject: [PATCH] feat: align List/Array/Vector.flatMap (#6660) This PR defines `Vector.flatMap`, changes the order of arguments in `List.flatMap` for consistency, and aligns the lemmas for `List`/`Array`/`Vector` `flatMap`. --- src/Init/Data/Array/Lemmas.lean | 114 +++++++++++++++++++++++++++++-- src/Init/Data/List/Basic.lean | 6 +- src/Init/Data/List/Impl.lean | 6 +- src/Init/Data/List/Lemmas.lean | 10 +-- src/Init/Data/Vector/Basic.lean | 3 + src/Init/Data/Vector/Lemmas.lean | 69 +++++++++++++++++++ 6 files changed, 193 insertions(+), 15 deletions(-) diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 1acf464ca378..b78fe308f0a5 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -1560,6 +1560,11 @@ theorem filterMap_eq_push_iff {f : α → Option β} {l : Array α} {l' : Array cases bs simp +theorem toArray_append {xs : List α} {ys : Array α} : + xs.toArray ++ ys = (xs ++ ys.toList).toArray := by + rcases ys with ⟨ys⟩ + simp + @[simp] theorem toArray_eq_append_iff {xs : List α} {as bs : Array α} : xs.toArray = as ++ bs ↔ xs = as.toList ++ bs.toList := by cases as @@ -1871,6 +1876,11 @@ theorem append_eq_map_iff {f : α → β} : rw [← flatten_map_toArray] simp +theorem flatten_toArray (l : List (Array α)) : + l.toArray.flatten = (l.map Array.toList).flatten.toArray := by + apply ext' + simp + @[simp] theorem size_flatten (L : Array (Array α)) : L.flatten.size = (L.map size).sum := by cases L using array₂_induction simp [Function.comp_def] @@ -1886,14 +1896,14 @@ theorem mem_flatten : ∀ {L : Array (Array α)}, a ∈ L.flatten ↔ ∃ l, l · rintro ⟨s, h₁, h₂⟩ refine ⟨s.toList, ⟨⟨s, h₁, rfl⟩, h₂⟩⟩ -@[simp] theorem flatten_eq_nil_iff {L : Array (Array α)} : L.flatten = #[] ↔ ∀ l ∈ L, l = #[] := by +@[simp] theorem flatten_eq_empty_iff {L : Array (Array α)} : L.flatten = #[] ↔ ∀ l ∈ L, l = #[] := by induction L using array₂_induction simp -@[simp] theorem nil_eq_flatten_iff {L : Array (Array α)} : #[] = L.flatten ↔ ∀ l ∈ L, l = #[] := by - rw [eq_comm, flatten_eq_nil_iff] +@[simp] theorem empty_eq_flatten_iff {L : Array (Array α)} : #[] = L.flatten ↔ ∀ l ∈ L, l = #[] := by + rw [eq_comm, flatten_eq_empty_iff] -theorem flatten_ne_nil_iff {xs : Array (Array α)} : xs.flatten ≠ #[] ↔ ∃ x, x ∈ xs ∧ x ≠ #[] := by +theorem flatten_ne_empty_iff {xs : Array (Array α)} : xs.flatten ≠ #[] ↔ ∃ x, x ∈ xs ∧ x ≠ #[] := by simp theorem exists_of_mem_flatten : a ∈ flatten L → ∃ l, l ∈ L ∧ a ∈ l := mem_flatten.1 @@ -2029,6 +2039,102 @@ theorem eq_iff_flatten_eq {L L' : Array (Array α)} : rw [List.map_inj_right] simp +contextual +/-! ### flatMap -/ + +theorem flatMap_def (l : Array α) (f : α → Array β) : l.flatMap f = flatten (map f l) := by + rcases l with ⟨l⟩ + simp [flatten_toArray, Function.comp_def, List.flatMap_def] + +theorem flatMap_toList (l : Array α) (f : α → List β) : + l.toList.flatMap f = (l.flatMap (fun a => (f a).toArray)).toList := by + rcases l with ⟨l⟩ + simp + +@[simp] theorem flatMap_id (l : Array (Array α)) : l.flatMap id = l.flatten := by simp [flatMap_def] + +@[simp] theorem flatMap_id' (l : Array (Array α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def] + +@[simp] +theorem size_flatMap (l : Array α) (f : α → Array β) : + (l.flatMap f).size = sum (map (fun a => (f a).size) l) := by + rcases l with ⟨l⟩ + simp [Function.comp_def] + +@[simp] theorem mem_flatMap {f : α → Array β} {b} {l : Array α} : b ∈ l.flatMap f ↔ ∃ a, a ∈ l ∧ b ∈ f a := by + simp [flatMap_def, mem_flatten] + exact ⟨fun ⟨_, ⟨a, h₁, rfl⟩, h₂⟩ => ⟨a, h₁, h₂⟩, fun ⟨a, h₁, h₂⟩ => ⟨_, ⟨a, h₁, rfl⟩, h₂⟩⟩ + +theorem exists_of_mem_flatMap {b : β} {l : Array α} {f : α → Array β} : + b ∈ l.flatMap f → ∃ a, a ∈ l ∧ b ∈ f a := mem_flatMap.1 + +theorem mem_flatMap_of_mem {b : β} {l : Array α} {f : α → Array β} {a} (al : a ∈ l) (h : b ∈ f a) : + b ∈ l.flatMap f := mem_flatMap.2 ⟨a, al, h⟩ + +@[simp] +theorem flatMap_eq_empty_iff {l : Array α} {f : α → Array β} : l.flatMap f = #[] ↔ ∀ x ∈ l, f x = #[] := by + rw [flatMap_def, flatten_eq_empty_iff] + simp + +theorem forall_mem_flatMap {p : β → Prop} {l : Array α} {f : α → Array β} : + (∀ (x) (_ : x ∈ l.flatMap f), p x) ↔ ∀ (a) (_ : a ∈ l) (b) (_ : b ∈ f a), p b := by + simp only [mem_flatMap, forall_exists_index, and_imp] + constructor <;> (intros; solve_by_elim) + +theorem flatMap_singleton (f : α → Array β) (x : α) : #[x].flatMap f = f x := by + simp + +@[simp] theorem flatMap_singleton' (l : Array α) : (l.flatMap fun x => #[x]) = l := by + rcases l with ⟨l⟩ + simp + +@[simp] theorem flatMap_append (xs ys : Array α) (f : α → Array β) : + (xs ++ ys).flatMap f = xs.flatMap f ++ ys.flatMap f := by + rcases xs with ⟨xs⟩ + rcases ys with ⟨ys⟩ + simp + +theorem flatMap_assoc {α β} (l : Array α) (f : α → Array β) (g : β → Array γ) : + (l.flatMap f).flatMap g = l.flatMap fun x => (f x).flatMap g := by + rcases l with ⟨l⟩ + simp [List.flatMap_assoc, flatMap_toList] + +theorem map_flatMap (f : β → γ) (g : α → Array β) (l : Array α) : + (l.flatMap g).map f = l.flatMap fun a => (g a).map f := by + rcases l with ⟨l⟩ + simp [List.map_flatMap] + +theorem flatMap_map (f : α → β) (g : β → Array γ) (l : Array α) : + (map f l).flatMap g = l.flatMap (fun a => g (f a)) := by + rcases l with ⟨l⟩ + simp [List.flatMap_map] + +theorem map_eq_flatMap {α β} (f : α → β) (l : Array α) : map f l = l.flatMap fun x => #[f x] := by + simp only [← map_singleton] + rw [← flatMap_singleton' l, map_flatMap, flatMap_singleton'] + +theorem filterMap_flatMap {β γ} (l : Array α) (g : α → Array β) (f : β → Option γ) : + (l.flatMap g).filterMap f = l.flatMap fun a => (g a).filterMap f := by + rcases l with ⟨l⟩ + simp [List.filterMap_flatMap] + +theorem filter_flatMap (l : Array α) (g : α → Array β) (f : β → Bool) : + (l.flatMap g).filter f = l.flatMap fun a => (g a).filter f := by + rcases l with ⟨l⟩ + simp [List.filter_flatMap] + +theorem flatMap_eq_foldl (f : α → Array β) (l : Array α) : + l.flatMap f = l.foldl (fun acc a => acc ++ f a) #[] := by + rcases l with ⟨l⟩ + simp only [List.flatMap_toArray, List.flatMap_eq_foldl, size_toArray, List.foldl_toArray'] + suffices ∀ l', (List.foldl (fun acc a => acc ++ (f a).toList) l' l).toArray = + List.foldl (fun acc a => acc ++ f a) l'.toArray l by + simpa using this [] + induction l with + | nil => simp + | cons a l ih => + intro l' + simp [ih ((l' ++ (f a).toList)), toArray_append] + /-! Content below this point has not yet been aligned with `List`. -/ -- This is a duplicate of `List.toArray_toList`. diff --git a/src/Init/Data/List/Basic.lean b/src/Init/Data/List/Basic.lean index 3352b90f44d5..3b9b7efb09a0 100644 --- a/src/Init/Data/List/Basic.lean +++ b/src/Init/Data/List/Basic.lean @@ -606,11 +606,11 @@ set_option linter.missingDocs false in to get a list of lists, and then concatenates them all together. * `[2, 3, 2].bind range = [0, 1, 0, 1, 2, 0, 1]` -/ -@[inline] def flatMap {α : Type u} {β : Type v} (a : List α) (b : α → List β) : List β := flatten (map b a) +@[inline] def flatMap {α : Type u} {β : Type v} (b : α → List β) (a : List α) : List β := flatten (map b a) -@[simp] theorem flatMap_nil (f : α → List β) : List.flatMap [] f = [] := by simp [flatten, List.flatMap] +@[simp] theorem flatMap_nil (f : α → List β) : List.flatMap f [] = [] := by simp [flatten, List.flatMap] @[simp] theorem flatMap_cons x xs (f : α → List β) : - List.flatMap (x :: xs) f = f x ++ List.flatMap xs f := by simp [flatten, List.flatMap] + List.flatMap f (x :: xs) = f x ++ List.flatMap f xs := by simp [flatten, List.flatMap] set_option linter.missingDocs false in @[deprecated flatMap (since := "2024-10-16")] abbrev bind := @flatMap diff --git a/src/Init/Data/List/Impl.lean b/src/Init/Data/List/Impl.lean index 5fb32bf0b59c..20f324dba419 100644 --- a/src/Init/Data/List/Impl.lean +++ b/src/Init/Data/List/Impl.lean @@ -96,14 +96,14 @@ The following operations are given `@[csimp]` replacements below: /-! ### flatMap -/ /-- Tail recursive version of `List.flatMap`. -/ -@[inline] def flatMapTR (as : List α) (f : α → List β) : List β := go as #[] where +@[inline] def flatMapTR (f : α → List β) (as : List α) : List β := go as #[] where /-- Auxiliary for `flatMap`: `flatMap.go f as = acc.toList ++ bind f as` -/ @[specialize] go : List α → Array β → List β | [], acc => acc.toList | x::xs, acc => go xs (acc ++ f x) @[csimp] theorem flatMap_eq_flatMapTR : @List.flatMap = @flatMapTR := by - funext α β as f + funext α β f as let rec go : ∀ as acc, flatMapTR.go f as acc = acc.toList ++ as.flatMap f | [], acc => by simp [flatMapTR.go, flatMap] | x::xs, acc => by simp [flatMapTR.go, flatMap, go xs] @@ -112,7 +112,7 @@ The following operations are given `@[csimp]` replacements below: /-! ### flatten -/ /-- Tail recursive version of `List.flatten`. -/ -@[inline] def flattenTR (l : List (List α)) : List α := flatMapTR l id +@[inline] def flattenTR (l : List (List α)) : List α := l.flatMapTR id @[csimp] theorem flatten_eq_flattenTR : @flatten = @flattenTR := by funext α l; rw [← List.flatMap_id, List.flatMap_eq_flatMapTR]; rfl diff --git a/src/Init/Data/List/Lemmas.lean b/src/Init/Data/List/Lemmas.lean index f6f11494a15d..a0f8aa260b64 100644 --- a/src/Init/Data/List/Lemmas.lean +++ b/src/Init/Data/List/Lemmas.lean @@ -2070,14 +2070,14 @@ theorem eq_iff_flatten_eq : ∀ {L L' : List (List α)}, theorem flatMap_def (l : List α) (f : α → List β) : l.flatMap f = flatten (map f l) := by rfl -@[simp] theorem flatMap_id (l : List (List α)) : List.flatMap l id = l.flatten := by simp [flatMap_def] +@[simp] theorem flatMap_id (l : List (List α)) : l.flatMap id = l.flatten := by simp [flatMap_def] -@[simp] theorem flatMap_id' (l : List (List α)) : List.flatMap l (fun a => a) = l.flatten := by simp [flatMap_def] +@[simp] theorem flatMap_id' (l : List (List α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def] @[simp] theorem length_flatMap (l : List α) (f : α → List β) : - length (l.flatMap f) = sum (map (length ∘ f) l) := by - rw [List.flatMap, length_flatten, map_map] + length (l.flatMap f) = sum (map (fun a => (f a).length) l) := by + rw [List.flatMap, length_flatten, map_map, Function.comp_def] @[simp] theorem mem_flatMap {f : α → List β} {b} {l : List α} : b ∈ l.flatMap f ↔ ∃ a, a ∈ l ∧ b ∈ f a := by simp [flatMap_def, mem_flatten] @@ -2090,7 +2090,7 @@ theorem mem_flatMap_of_mem {b : β} {l : List α} {f : α → List β} {a} (al : b ∈ l.flatMap f := mem_flatMap.2 ⟨a, al, h⟩ @[simp] -theorem flatMap_eq_nil_iff {l : List α} {f : α → List β} : List.flatMap l f = [] ↔ ∀ x ∈ l, f x = [] := +theorem flatMap_eq_nil_iff {l : List α} {f : α → List β} : l.flatMap f = [] ↔ ∀ x ∈ l, f x = [] := flatten_eq_nil_iff.trans <| by simp only [mem_map, forall_exists_index, and_imp, forall_apply_eq_imp_iff₂] diff --git a/src/Init/Data/Vector/Basic.lean b/src/Init/Data/Vector/Basic.lean index 269a6577e3d4..32b336c4801d 100644 --- a/src/Init/Data/Vector/Basic.lean +++ b/src/Init/Data/Vector/Basic.lean @@ -174,6 +174,9 @@ result is empty. If `stop` is greater than the size of the vector, the size is u ⟨(v.toArray.map Vector.toArray).flatten, by rcases v; simp_all [Function.comp_def, Array.map_const']⟩ +@[inline] def flatMap (v : Vector α n) (f : α → Vector β m) : Vector β (n * m) := + ⟨v.toArray.flatMap fun a => (f a).toArray, by simp [Array.map_const']⟩ + /-- Maps corresponding elements of two vectors of equal size using the function `f`. -/ @[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α → β → φ) : Vector φ n := ⟨Array.zipWith a.toArray b.toArray f, by simp⟩ diff --git a/src/Init/Data/Vector/Lemmas.lean b/src/Init/Data/Vector/Lemmas.lean index 00ea12d7cee3..17964a63efd9 100644 --- a/src/Init/Data/Vector/Lemmas.lean +++ b/src/Init/Data/Vector/Lemmas.lean @@ -1525,6 +1525,75 @@ theorem eq_iff_flatten_eq {L L' : Vector (Vector α n) m} : subst this rfl + +/-! ### flatMap -/ + +@[simp] theorem flatMap_mk (l : Array α) (h : l.size = m) (f : α → Vector β n) : + (mk l h).flatMap f = + mk (l.flatMap (fun a => (f a).toArray)) (by simp [Array.map_const', h]) := by + simp [flatMap] + +@[simp] theorem flatMap_toArray (l : Vector α n) (f : α → Vector β m) : + l.toArray.flatMap (fun a => (f a).toArray) = (l.flatMap f).toArray := by + rcases l with ⟨l, rfl⟩ + simp + +theorem flatMap_def (l : Vector α n) (f : α → Vector β m) : l.flatMap f = flatten (map f l) := by + rcases l with ⟨l, rfl⟩ + simp [Array.flatMap_def, Function.comp_def] + +@[simp] theorem flatMap_id (l : Vector (Vector α m) n) : l.flatMap id = l.flatten := by simp [flatMap_def] + +@[simp] theorem flatMap_id' (l : Vector (Vector α m) n) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def] + +@[simp] theorem mem_flatMap {f : α → Vector β m} {b} {l : Vector α n} : b ∈ l.flatMap f ↔ ∃ a, a ∈ l ∧ b ∈ f a := by + simp [flatMap_def, mem_flatten] + exact ⟨fun ⟨_, ⟨a, h₁, rfl⟩, h₂⟩ => ⟨a, h₁, h₂⟩, fun ⟨a, h₁, h₂⟩ => ⟨_, ⟨a, h₁, rfl⟩, h₂⟩⟩ + +theorem exists_of_mem_flatMap {b : β} {l : Vector α n} {f : α → Vector β m} : + b ∈ l.flatMap f → ∃ a, a ∈ l ∧ b ∈ f a := mem_flatMap.1 + +theorem mem_flatMap_of_mem {b : β} {l : Vector α n} {f : α → Vector β m} {a} (al : a ∈ l) (h : b ∈ f a) : + b ∈ l.flatMap f := mem_flatMap.2 ⟨a, al, h⟩ + +theorem forall_mem_flatMap {p : β → Prop} {l : Vector α n} {f : α → Vector β m} : + (∀ (x) (_ : x ∈ l.flatMap f), p x) ↔ ∀ (a) (_ : a ∈ l) (b) (_ : b ∈ f a), p b := by + simp only [mem_flatMap, forall_exists_index, and_imp] + constructor <;> (intros; solve_by_elim) + +theorem flatMap_singleton (f : α → Vector β m) (x : α) : #v[x].flatMap f = (f x).cast (by simp) := by + simp [flatMap_def] + +@[simp] theorem flatMap_singleton' (l : Vector α n) : (l.flatMap fun x => #v[x]) = l.cast (by simp) := by + rcases l with ⟨l, rfl⟩ + simp + +@[simp] theorem flatMap_append (xs ys : Vector α n) (f : α → Vector β m) : + (xs ++ ys).flatMap f = (xs.flatMap f ++ ys.flatMap f).cast (by simp [Nat.add_mul]) := by + rcases xs with ⟨xs⟩ + rcases ys with ⟨ys⟩ + simp [flatMap_def, flatten_append] + +theorem flatMap_assoc {α β} (l : Vector α n) (f : α → Vector β m) (g : β → Vector γ k) : + (l.flatMap f).flatMap g = (l.flatMap fun x => (f x).flatMap g).cast (by simp [Nat.mul_assoc]) := by + rcases l with ⟨l, rfl⟩ + simp [Array.flatMap_assoc] + +theorem map_flatMap (f : β → γ) (g : α → Vector β m) (l : Vector α n) : + (l.flatMap g).map f = l.flatMap fun a => (g a).map f := by + rcases l with ⟨l, rfl⟩ + simp [Array.map_flatMap] + +theorem flatMap_map (f : α → β) (g : β → Vector γ k) (l : Vector α n) : + (map f l).flatMap g = l.flatMap (fun a => g (f a)) := by + rcases l with ⟨l, rfl⟩ + simp [Array.flatMap_map] + +theorem map_eq_flatMap {α β} (f : α → β) (l : Vector α n) : + map f l = (l.flatMap fun x => #v[f x]).cast (by simp) := by + rcases l with ⟨l, rfl⟩ + simp [Array.map_eq_flatMap] + /-! Content below this point has not yet been aligned with `List` and `Array`. -/ @[simp] theorem getElem_ofFn {α n} (f : Fin n → α) (i : Nat) (h : i < n) :