Skip to content

Commit

Permalink
feat: align List/Array/Vector.flatMap (#6660)
Browse files Browse the repository at this point in the history
This PR defines `Vector.flatMap`, changes the order of arguments in
`List.flatMap` for consistency, and aligns the lemmas for
`List`/`Array`/`Vector` `flatMap`.
  • Loading branch information
kim-em authored Jan 16, 2025
1 parent 3a6c5cf commit 80ddbf4
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 15 deletions.
114 changes: 110 additions & 4 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,11 @@ theorem filterMap_eq_push_iff {f : α → Option β} {l : Array α} {l' : Array
cases bs
simp

theorem toArray_append {xs : List α} {ys : Array α} :
xs.toArray ++ ys = (xs ++ ys.toList).toArray := by
rcases ys with ⟨ys⟩
simp

@[simp] theorem toArray_eq_append_iff {xs : List α} {as bs : Array α} :
xs.toArray = as ++ bs ↔ xs = as.toList ++ bs.toList := by
cases as
Expand Down Expand Up @@ -1871,6 +1876,11 @@ theorem append_eq_map_iff {f : α → β} :
rw [← flatten_map_toArray]
simp

theorem flatten_toArray (l : List (Array α)) :
l.toArray.flatten = (l.map Array.toList).flatten.toArray := by
apply ext'
simp

@[simp] theorem size_flatten (L : Array (Array α)) : L.flatten.size = (L.map size).sum := by
cases L using array₂_induction
simp [Function.comp_def]
Expand All @@ -1886,14 +1896,14 @@ theorem mem_flatten : ∀ {L : Array (Array α)}, a ∈ L.flatten ↔ ∃ l, l
· rintro ⟨s, h₁, h₂⟩
refine ⟨s.toList, ⟨⟨s, h₁, rfl⟩, h₂⟩⟩

@[simp] theorem flatten_eq_nil_iff {L : Array (Array α)} : L.flatten = #[] ↔ ∀ l ∈ L, l = #[] := by
@[simp] theorem flatten_eq_empty_iff {L : Array (Array α)} : L.flatten = #[] ↔ ∀ l ∈ L, l = #[] := by
induction L using array₂_induction
simp

@[simp] theorem nil_eq_flatten_iff {L : Array (Array α)} : #[] = L.flatten ↔ ∀ l ∈ L, l = #[] := by
rw [eq_comm, flatten_eq_nil_iff]
@[simp] theorem empty_eq_flatten_iff {L : Array (Array α)} : #[] = L.flatten ↔ ∀ l ∈ L, l = #[] := by
rw [eq_comm, flatten_eq_empty_iff]

theorem flatten_ne_nil_iff {xs : Array (Array α)} : xs.flatten ≠ #[] ↔ ∃ x, x ∈ xs ∧ x ≠ #[] := by
theorem flatten_ne_empty_iff {xs : Array (Array α)} : xs.flatten ≠ #[] ↔ ∃ x, x ∈ xs ∧ x ≠ #[] := by
simp

theorem exists_of_mem_flatten : a ∈ flatten L → ∃ l, l ∈ L ∧ a ∈ l := mem_flatten.1
Expand Down Expand Up @@ -2029,6 +2039,102 @@ theorem eq_iff_flatten_eq {L L' : Array (Array α)} :
rw [List.map_inj_right]
simp +contextual

/-! ### flatMap -/

theorem flatMap_def (l : Array α) (f : α → Array β) : l.flatMap f = flatten (map f l) := by
rcases l with ⟨l⟩
simp [flatten_toArray, Function.comp_def, List.flatMap_def]

theorem flatMap_toList (l : Array α) (f : α → List β) :
l.toList.flatMap f = (l.flatMap (fun a => (f a).toArray)).toList := by
rcases l with ⟨l⟩
simp

@[simp] theorem flatMap_id (l : Array (Array α)) : l.flatMap id = l.flatten := by simp [flatMap_def]

@[simp] theorem flatMap_id' (l : Array (Array α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]

@[simp]
theorem size_flatMap (l : Array α) (f : α → Array β) :
(l.flatMap f).size = sum (map (fun a => (f a).size) l) := by
rcases l with ⟨l⟩
simp [Function.comp_def]

@[simp] theorem mem_flatMap {f : α → Array β} {b} {l : Array α} : b ∈ l.flatMap f ↔ ∃ a, a ∈ l ∧ b ∈ f a := by
simp [flatMap_def, mem_flatten]
exact ⟨fun ⟨_, ⟨a, h₁, rfl⟩, h₂⟩ => ⟨a, h₁, h₂⟩, fun ⟨a, h₁, h₂⟩ => ⟨_, ⟨a, h₁, rfl⟩, h₂⟩⟩

theorem exists_of_mem_flatMap {b : β} {l : Array α} {f : α → Array β} :
b ∈ l.flatMap f → ∃ a, a ∈ l ∧ b ∈ f a := mem_flatMap.1

theorem mem_flatMap_of_mem {b : β} {l : Array α} {f : α → Array β} {a} (al : a ∈ l) (h : b ∈ f a) :
b ∈ l.flatMap f := mem_flatMap.2 ⟨a, al, h⟩

@[simp]
theorem flatMap_eq_empty_iff {l : Array α} {f : α → Array β} : l.flatMap f = #[] ↔ ∀ x ∈ l, f x = #[] := by
rw [flatMap_def, flatten_eq_empty_iff]
simp

theorem forall_mem_flatMap {p : β → Prop} {l : Array α} {f : α → Array β} :
(∀ (x) (_ : x ∈ l.flatMap f), p x) ↔ ∀ (a) (_ : a ∈ l) (b) (_ : b ∈ f a), p b := by
simp only [mem_flatMap, forall_exists_index, and_imp]
constructor <;> (intros; solve_by_elim)

theorem flatMap_singleton (f : α → Array β) (x : α) : #[x].flatMap f = f x := by
simp

@[simp] theorem flatMap_singleton' (l : Array α) : (l.flatMap fun x => #[x]) = l := by
rcases l with ⟨l⟩
simp

@[simp] theorem flatMap_append (xs ys : Array α) (f : α → Array β) :
(xs ++ ys).flatMap f = xs.flatMap f ++ ys.flatMap f := by
rcases xs with ⟨xs⟩
rcases ys with ⟨ys⟩
simp

theorem flatMap_assoc {α β} (l : Array α) (f : α → Array β) (g : β → Array γ) :
(l.flatMap f).flatMap g = l.flatMap fun x => (f x).flatMap g := by
rcases l with ⟨l⟩
simp [List.flatMap_assoc, flatMap_toList]

theorem map_flatMap (f : β → γ) (g : α → Array β) (l : Array α) :
(l.flatMap g).map f = l.flatMap fun a => (g a).map f := by
rcases l with ⟨l⟩
simp [List.map_flatMap]

theorem flatMap_map (f : α → β) (g : β → Array γ) (l : Array α) :
(map f l).flatMap g = l.flatMap (fun a => g (f a)) := by
rcases l with ⟨l⟩
simp [List.flatMap_map]

theorem map_eq_flatMap {α β} (f : α → β) (l : Array α) : map f l = l.flatMap fun x => #[f x] := by
simp only [← map_singleton]
rw [← flatMap_singleton' l, map_flatMap, flatMap_singleton']

theorem filterMap_flatMap {β γ} (l : Array α) (g : α → Array β) (f : β → Option γ) :
(l.flatMap g).filterMap f = l.flatMap fun a => (g a).filterMap f := by
rcases l with ⟨l⟩
simp [List.filterMap_flatMap]

theorem filter_flatMap (l : Array α) (g : α → Array β) (f : β → Bool) :
(l.flatMap g).filter f = l.flatMap fun a => (g a).filter f := by
rcases l with ⟨l⟩
simp [List.filter_flatMap]

theorem flatMap_eq_foldl (f : α → Array β) (l : Array α) :
l.flatMap f = l.foldl (fun acc a => acc ++ f a) #[] := by
rcases l with ⟨l⟩
simp only [List.flatMap_toArray, List.flatMap_eq_foldl, size_toArray, List.foldl_toArray']
suffices ∀ l', (List.foldl (fun acc a => acc ++ (f a).toList) l' l).toArray =
List.foldl (fun acc a => acc ++ f a) l'.toArray l by
simpa using this []
induction l with
| nil => simp
| cons a l ih =>
intro l'
simp [ih ((l' ++ (f a).toList)), toArray_append]

/-! Content below this point has not yet been aligned with `List`. -/

-- This is a duplicate of `List.toArray_toList`.
Expand Down
6 changes: 3 additions & 3 deletions src/Init/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -606,11 +606,11 @@ set_option linter.missingDocs false in
to get a list of lists, and then concatenates them all together.
* `[2, 3, 2].bind range = [0, 1, 0, 1, 2, 0, 1]`
-/
@[inline] def flatMap {α : Type u} {β : Type v} (a : List α) (b : α → List β) : List β := flatten (map b a)
@[inline] def flatMap {α : Type u} {β : Type v} (b : α → List β) (a : List α) : List β := flatten (map b a)

@[simp] theorem flatMap_nil (f : α → List β) : List.flatMap [] f = [] := by simp [flatten, List.flatMap]
@[simp] theorem flatMap_nil (f : α → List β) : List.flatMap f [] = [] := by simp [flatten, List.flatMap]
@[simp] theorem flatMap_cons x xs (f : α → List β) :
List.flatMap (x :: xs) f = f x ++ List.flatMap xs f := by simp [flatten, List.flatMap]
List.flatMap f (x :: xs) = f x ++ List.flatMap f xs := by simp [flatten, List.flatMap]

set_option linter.missingDocs false in
@[deprecated flatMap (since := "2024-10-16")] abbrev bind := @flatMap
Expand Down
6 changes: 3 additions & 3 deletions src/Init/Data/List/Impl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ The following operations are given `@[csimp]` replacements below:
/-! ### flatMap -/

/-- Tail recursive version of `List.flatMap`. -/
@[inline] def flatMapTR (as : List α) (f : α → List β) : List β := go as #[] where
@[inline] def flatMapTR (f : α → List β) (as : List α) : List β := go as #[] where
/-- Auxiliary for `flatMap`: `flatMap.go f as = acc.toList ++ bind f as` -/
@[specialize] go : List α → Array β → List β
| [], acc => acc.toList
| x::xs, acc => go xs (acc ++ f x)

@[csimp] theorem flatMap_eq_flatMapTR : @List.flatMap = @flatMapTR := by
funext α β as f
funext α β f as
let rec go : ∀ as acc, flatMapTR.go f as acc = acc.toList ++ as.flatMap f
| [], acc => by simp [flatMapTR.go, flatMap]
| x::xs, acc => by simp [flatMapTR.go, flatMap, go xs]
Expand All @@ -112,7 +112,7 @@ The following operations are given `@[csimp]` replacements below:
/-! ### flatten -/

/-- Tail recursive version of `List.flatten`. -/
@[inline] def flattenTR (l : List (List α)) : List α := flatMapTR l id
@[inline] def flattenTR (l : List (List α)) : List α := l.flatMapTR id

@[csimp] theorem flatten_eq_flattenTR : @flatten = @flattenTR := by
funext α l; rw [← List.flatMap_id, List.flatMap_eq_flatMapTR]; rfl
Expand Down
10 changes: 5 additions & 5 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2070,14 +2070,14 @@ theorem eq_iff_flatten_eq : ∀ {L L' : List (List α)},

theorem flatMap_def (l : List α) (f : α → List β) : l.flatMap f = flatten (map f l) := by rfl

@[simp] theorem flatMap_id (l : List (List α)) : List.flatMap l id = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id (l : List (List α)) : l.flatMap id = l.flatten := by simp [flatMap_def]

@[simp] theorem flatMap_id' (l : List (List α)) : List.flatMap l (fun a => a) = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id' (l : List (List α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]

@[simp]
theorem length_flatMap (l : List α) (f : α → List β) :
length (l.flatMap f) = sum (map (length ∘ f) l) := by
rw [List.flatMap, length_flatten, map_map]
length (l.flatMap f) = sum (map (fun a => (f a).length) l) := by
rw [List.flatMap, length_flatten, map_map, Function.comp_def]

@[simp] theorem mem_flatMap {f : α → List β} {b} {l : List α} : b ∈ l.flatMap f ↔ ∃ a, a ∈ l ∧ b ∈ f a := by
simp [flatMap_def, mem_flatten]
Expand All @@ -2090,7 +2090,7 @@ theorem mem_flatMap_of_mem {b : β} {l : List α} {f : α → List β} {a} (al :
b ∈ l.flatMap f := mem_flatMap.2 ⟨a, al, h⟩

@[simp]
theorem flatMap_eq_nil_iff {l : List α} {f : α → List β} : List.flatMap l f = [] ↔ ∀ x ∈ l, f x = [] :=
theorem flatMap_eq_nil_iff {l : List α} {f : α → List β} : l.flatMap f = [] ↔ ∀ x ∈ l, f x = [] :=
flatten_eq_nil_iff.trans <| by
simp only [mem_map, forall_exists_index, and_imp, forall_apply_eq_imp_iff₂]

Expand Down
3 changes: 3 additions & 0 deletions src/Init/Data/Vector/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ result is empty. If `stop` is greater than the size of the vector, the size is u
⟨(v.toArray.map Vector.toArray).flatten,
by rcases v; simp_all [Function.comp_def, Array.map_const']⟩

@[inline] def flatMap (v : Vector α n) (f : α → Vector β m) : Vector β (n * m) :=
⟨v.toArray.flatMap fun a => (f a).toArray, by simp [Array.map_const']⟩

/-- Maps corresponding elements of two vectors of equal size using the function `f`. -/
@[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α → β → φ) : Vector φ n :=
⟨Array.zipWith a.toArray b.toArray f, by simp⟩
Expand Down
69 changes: 69 additions & 0 deletions src/Init/Data/Vector/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,75 @@ theorem eq_iff_flatten_eq {L L' : Vector (Vector α n) m} :
subst this
rfl


/-! ### flatMap -/

@[simp] theorem flatMap_mk (l : Array α) (h : l.size = m) (f : α → Vector β n) :
(mk l h).flatMap f =
mk (l.flatMap (fun a => (f a).toArray)) (by simp [Array.map_const', h]) := by
simp [flatMap]

@[simp] theorem flatMap_toArray (l : Vector α n) (f : α → Vector β m) :
l.toArray.flatMap (fun a => (f a).toArray) = (l.flatMap f).toArray := by
rcases l with ⟨l, rfl⟩
simp

theorem flatMap_def (l : Vector α n) (f : α → Vector β m) : l.flatMap f = flatten (map f l) := by
rcases l with ⟨l, rfl⟩
simp [Array.flatMap_def, Function.comp_def]

@[simp] theorem flatMap_id (l : Vector (Vector α m) n) : l.flatMap id = l.flatten := by simp [flatMap_def]

@[simp] theorem flatMap_id' (l : Vector (Vector α m) n) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]

@[simp] theorem mem_flatMap {f : α → Vector β m} {b} {l : Vector α n} : b ∈ l.flatMap f ↔ ∃ a, a ∈ l ∧ b ∈ f a := by
simp [flatMap_def, mem_flatten]
exact ⟨fun ⟨_, ⟨a, h₁, rfl⟩, h₂⟩ => ⟨a, h₁, h₂⟩, fun ⟨a, h₁, h₂⟩ => ⟨_, ⟨a, h₁, rfl⟩, h₂⟩⟩

theorem exists_of_mem_flatMap {b : β} {l : Vector α n} {f : α → Vector β m} :
b ∈ l.flatMap f → ∃ a, a ∈ l ∧ b ∈ f a := mem_flatMap.1

theorem mem_flatMap_of_mem {b : β} {l : Vector α n} {f : α → Vector β m} {a} (al : a ∈ l) (h : b ∈ f a) :
b ∈ l.flatMap f := mem_flatMap.2 ⟨a, al, h⟩

theorem forall_mem_flatMap {p : β → Prop} {l : Vector α n} {f : α → Vector β m} :
(∀ (x) (_ : x ∈ l.flatMap f), p x) ↔ ∀ (a) (_ : a ∈ l) (b) (_ : b ∈ f a), p b := by
simp only [mem_flatMap, forall_exists_index, and_imp]
constructor <;> (intros; solve_by_elim)

theorem flatMap_singleton (f : α → Vector β m) (x : α) : #v[x].flatMap f = (f x).cast (by simp) := by
simp [flatMap_def]

@[simp] theorem flatMap_singleton' (l : Vector α n) : (l.flatMap fun x => #v[x]) = l.cast (by simp) := by
rcases l with ⟨l, rfl⟩
simp

@[simp] theorem flatMap_append (xs ys : Vector α n) (f : α → Vector β m) :
(xs ++ ys).flatMap f = (xs.flatMap f ++ ys.flatMap f).cast (by simp [Nat.add_mul]) := by
rcases xs with ⟨xs⟩
rcases ys with ⟨ys⟩
simp [flatMap_def, flatten_append]

theorem flatMap_assoc {α β} (l : Vector α n) (f : α → Vector β m) (g : β → Vector γ k) :
(l.flatMap f).flatMap g = (l.flatMap fun x => (f x).flatMap g).cast (by simp [Nat.mul_assoc]) := by
rcases l with ⟨l, rfl⟩
simp [Array.flatMap_assoc]

theorem map_flatMap (f : β → γ) (g : α → Vector β m) (l : Vector α n) :
(l.flatMap g).map f = l.flatMap fun a => (g a).map f := by
rcases l with ⟨l, rfl⟩
simp [Array.map_flatMap]

theorem flatMap_map (f : α → β) (g : β → Vector γ k) (l : Vector α n) :
(map f l).flatMap g = l.flatMap (fun a => g (f a)) := by
rcases l with ⟨l, rfl⟩
simp [Array.flatMap_map]

theorem map_eq_flatMap {α β} (f : α → β) (l : Vector α n) :
map f l = (l.flatMap fun x => #v[f x]).cast (by simp) := by
rcases l with ⟨l, rfl⟩
simp [Array.map_eq_flatMap]

/-! Content below this point has not yet been aligned with `List` and `Array`. -/

@[simp] theorem getElem_ofFn {α n} (f : Fin n → α) (i : Nat) (h : i < n) :
Expand Down

0 comments on commit 80ddbf4

Please sign in to comment.