Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rm partial / bounds checks in Array.qsort #3933

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions src/Init/Data/Array/QSort.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,58 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Array.Basic
import Init.Omega

namespace Array
-- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget

def qpartition (as : Array α) (lt : α → α → Bool) (lo hi : Nat) : Nat × Array α :=
if h : as.size = 0 then (0, as) else have : Inhabited α := ⟨as[0]'(by revert h; cases as.size <;> simp)⟩ -- TODO: remove
let mid := (lo + hi) / 2
let as := if lt (as.get! mid) (as.get! lo) then as.swap! lo mid else as
let as := if lt (as.get! hi) (as.get! lo) then as.swap! lo hi else as
let as := if lt (as.get! mid) (as.get! hi) then as.swap! mid hi else as
let pivot := as.get! hi
let rec loop (as : Array α) (i j : Nat) :=
if h : j < hi then
if lt (as.get! j) pivot then
let as := as.swap! i j
loop as (i+1) (j+1)
@[inline] def qpartition (as : {as : Array α // as.size = n})
(lt : α → α → Bool) (lo hi : Fin n) (hle : lo ≤ hi) :
{as : Array α // as.size = n} × {pivot : Fin n // lo ≤ pivot ∧ pivot ≤ hi} :=
let mid : Fin n := ⟨(lo.1 + hi) / 2, by omega⟩
let rec @[inline] maybeSwap (as : {as : Array α // as.size = n}) (lo hi : Fin n) : {as : Array α // as.size = n} :=
let hi := hi.cast as.2.symm
let lo := lo.cast as.2.symm
if lt (as.1.get hi) (as.1.get lo) then ⟨as.1.swap lo hi, (Array.size_swap ..).trans as.2⟩ else as
let as := maybeSwap as lo mid
let as := maybeSwap as lo hi
let as := maybeSwap as hi mid
let_fun pivot := as.1.get (hi.cast as.2.symm)
let rec loop
(as : {as : Array α // as.size = n}) (i j : Fin n) (H : lo ≤ i ∧ i ≤ j ∧ j ≤ hi) :
{as : Array α // as.size = n} × {pivot : Fin n // lo ≤ pivot ∧ pivot ≤ hi} :=
have ⟨loi, ij, jhi⟩ := H
if h : j < hi then by
-- FIXME: if we don't clear these variables, `omega` will revert/intro them
-- and as a result `loop` will spuriously depend on the extra `as` variables, breaking linearity
rename_i as₁ as₂ as₃ as₄; clear as₁ mid as₂ as₃ as₄
exact if lt (as.1.get (j.cast as.2.symm)) pivot then
let as := ⟨as.1.swap (i.cast as.2.symm) (j.cast as.2.symm), (Array.size_swap ..).trans as.2⟩
loop as ⟨i.1+1, by omega⟩ ⟨j.1+1, by omega⟩
⟨Nat.le_succ_of_le H.1, Nat.succ_le_succ ij, Nat.succ_le_of_lt h⟩
else
loop as i (j+1)
loop as i ⟨j.1+1, by omega⟩ ⟨loi, Nat.le_succ_of_le ij, Nat.succ_le_of_lt h⟩
else
let as := as.swap! i hi
(i, as)
termination_by hi - j
loop as lo lo
let as := as.1.swap (i.cast as.2.symm) (hi.cast as.2.symm), (Array.size_swap ..).trans as.2⟩
⟨as, i, loi, Nat.le_trans ij jhi⟩
termination_by hi.1 - j
loop as lo lo ⟨Nat.le_refl _, Nat.le_refl _, hle⟩

@[inline] partial def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort (as : Array α) (low high : Nat) :=
if low < high then
let p := qpartition as lt low high;
-- TODO: fix `partial` support in the equation compiler, it breaks if we use `let (mid, as) := partition as lt low high`
let mid := p.1
let as := p.2
if mid >= high then as
else
let as := sort as low mid
sort as (mid+1) high
@[inline] def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort {n} (as : {as : Array α // as.size = n})
(lo : Nat) (hi : Fin n) : {as : Array α // as.size = n} :=
if h : lo < hi.1 then
let ⟨as, mid, (_ : lo ≤ mid), _⟩ :=
qpartition as lt ⟨lo, Nat.lt_trans h hi.2⟩ hi (Nat.le_of_lt h)
let as := sort as lo ⟨mid - 1, by omega⟩
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily relevant for this PR, but one simple improvement to quicksort (as explained in the Sedgewick paper) is to recurse on the smallest subproblem first. This ensures a stack depth of O(log n) because the other call will be in tail position.

sort as (mid + 1) hi
else as
sort as low high
termination_by hi - lo
if low < high then
if h : high < as.size then
(sort ⟨as, rfl⟩ low ⟨high, h⟩).1
else
have := Inhabited.mk as
panic! "index out of bounds"
else as

end Array
Loading