From 824a7fecba3db57fc8c10932369511a3c8506709 Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Wed, 17 Apr 2024 15:55:00 -0400 Subject: [PATCH] feat: rm partial / bounds checks in Array.qsort --- src/Init/Data/Array/QSort.lean | 76 ++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/src/Init/Data/Array/QSort.lean b/src/Init/Data/Array/QSort.lean index d0974c96c804..55e67ef1a353 100644 --- a/src/Init/Data/Array/QSort.lean +++ b/src/Init/Data/Array/QSort.lean @@ -5,42 +5,58 @@ Authors: Leonardo de Moura -/ prelude import Init.Data.Array.Basic +import Init.Omega namespace Array --- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget -def qpartition (as : Array α) (lt : α → α → Bool) (lo hi : Nat) : Nat × Array α := - if h : as.size = 0 then (0, as) else have : Inhabited α := ⟨as[0]'(by revert h; cases as.size <;> simp)⟩ -- TODO: remove - let mid := (lo + hi) / 2 - let as := if lt (as.get! mid) (as.get! lo) then as.swap! lo mid else as - let as := if lt (as.get! hi) (as.get! lo) then as.swap! lo hi else as - let as := if lt (as.get! mid) (as.get! hi) then as.swap! mid hi else as - let pivot := as.get! hi - let rec loop (as : Array α) (i j : Nat) := - if h : j < hi then - if lt (as.get! j) pivot then - let as := as.swap! i j - loop as (i+1) (j+1) +@[inline] def qpartition (as : {as : Array α // as.size = n}) + (lt : α → α → Bool) (lo hi : Fin n) (hle : lo ≤ hi) : + {as : Array α // as.size = n} × {pivot : Fin n // lo ≤ pivot ∧ pivot ≤ hi} := + let mid : Fin n := ⟨(lo.1 + hi) / 2, by omega⟩ + let rec @[inline] maybeSwap (as : {as : Array α // as.size = n}) (lo hi : Fin n) : {as : Array α // as.size = n} := + let hi := hi.cast as.2.symm + let lo := lo.cast as.2.symm + if lt (as.1.get hi) (as.1.get lo) then ⟨as.1.swap lo hi, (Array.size_swap ..).trans as.2⟩ else as + let as := maybeSwap as lo mid + let as := maybeSwap as lo hi + let as := maybeSwap as hi mid + let_fun pivot := as.1.get (hi.cast as.2.symm) + let rec loop + (as : {as : Array α // as.size = n}) (i j : Fin n) (H : lo ≤ i ∧ i ≤ j ∧ j ≤ hi) : + {as : Array α // as.size = n} × {pivot : Fin n // lo ≤ pivot ∧ pivot ≤ hi} := + have ⟨loi, ij, jhi⟩ := H + if h : j < hi then by + -- FIXME: if we don't clear these variables, `omega` will revert/intro them + -- and as a result `loop` will spuriously depend on the extra `as` variables, breaking linearity + rename_i as₁ as₂ as₃ as₄; clear as₁ mid as₂ as₃ as₄ + exact if lt (as.1.get (j.cast as.2.symm)) pivot then + let as := ⟨as.1.swap (i.cast as.2.symm) (j.cast as.2.symm), (Array.size_swap ..).trans as.2⟩ + loop as ⟨i.1+1, by omega⟩ ⟨j.1+1, by omega⟩ + ⟨Nat.le_succ_of_le H.1, Nat.succ_le_succ ij, Nat.succ_le_of_lt h⟩ else - loop as i (j+1) + loop as i ⟨j.1+1, by omega⟩ ⟨loi, Nat.le_succ_of_le ij, Nat.succ_le_of_lt h⟩ else - let as := as.swap! i hi - (i, as) - termination_by hi - j - loop as lo lo + let as := ⟨as.1.swap (i.cast as.2.symm) (hi.cast as.2.symm), (Array.size_swap ..).trans as.2⟩ + ⟨as, i, loi, Nat.le_trans ij jhi⟩ + termination_by hi.1 - j + loop as lo lo ⟨Nat.le_refl _, Nat.le_refl _, hle⟩ -@[inline] partial def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α := - let rec @[specialize] sort (as : Array α) (low high : Nat) := - if low < high then - let p := qpartition as lt low high; - -- TODO: fix `partial` support in the equation compiler, it breaks if we use `let (mid, as) := partition as lt low high` - let mid := p.1 - let as := p.2 - if mid >= high then as - else - let as := sort as low mid - sort as (mid+1) high +@[inline] def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α := + let rec @[specialize] sort {n} (as : {as : Array α // as.size = n}) + (lo : Nat) (hi : Fin n) : {as : Array α // as.size = n} := + if h : lo < hi.1 then + let ⟨as, mid, (_ : lo ≤ mid), _⟩ := + qpartition as lt ⟨lo, Nat.lt_trans h hi.2⟩ hi (Nat.le_of_lt h) + let as := sort as lo ⟨mid - 1, by omega⟩ + sort as (mid + 1) hi else as - sort as low high + termination_by hi - lo + if low < high then + if h : high < as.size then + (sort ⟨as, rfl⟩ low ⟨high, h⟩).1 + else + have := Inhabited.mk as + panic! "index out of bounds" + else as end Array