Skip to content

Commit

Permalink
feat: rm partial / bounds checks in Array.qsort
Browse files Browse the repository at this point in the history
  • Loading branch information
digama0 committed Apr 17, 2024
1 parent 88ee503 commit 4b95a5b
Showing 1 changed file with 42 additions and 29 deletions.
71 changes: 42 additions & 29 deletions src/Init/Data/Array/QSort.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,55 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Array.Basic
import Init.Omega

namespace Array
-- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget

def qpartition (as : Array α) (lt : α → α → Bool) (lo hi : Nat) : Nat × Array α :=
if h : as.size = 0 then (0, as) else have : Inhabited α := ⟨as[0]'(by revert h; cases as.size <;> simp)⟩ -- TODO: remove
let mid := (lo + hi) / 2
let as := if lt (as.get! mid) (as.get! lo) then as.swap! lo mid else as
let as := if lt (as.get! hi) (as.get! lo) then as.swap! lo hi else as
let as := if lt (as.get! mid) (as.get! hi) then as.swap! mid hi else as
let pivot := as.get! hi
let rec loop (as : Array α) (i j : Nat) :=
@[inline] def qpartition (as : {as : Array α // as.size = n})
(lt : α → α → Bool) (lo hi : Fin n) (hle : lo ≤ hi) :
{as : Array α // as.size = n} × {pivot : Fin n // lo ≤ pivot ∧ pivot ≤ hi} :=
let mid : Fin n := ⟨(lo.1 + hi) / 2, by omega⟩
let rec @[inline] maybeSwap (as : {as : Array α // as.size = n}) (lo hi : Fin n) : {as : Array α // as.size = n} :=
let hi := hi.cast as.2.symm
let lo := lo.cast as.2.symm
if lt (as.1.get hi) (as.1.get lo) then ⟨as.1.swap lo hi, (Array.size_swap ..).trans as.2else as
let as := maybeSwap as lo mid
let as := maybeSwap as lo hi
let as := maybeSwap as mid hi
let pivot := as.1.get (hi.cast as.2.symm)
let rec loop
(as : {as : Array α // as.size = n}) (i j : Fin n) (H : lo ≤ i ∧ i ≤ j ∧ j ≤ hi) :
{as : Array α // as.size = n} × {pivot : Fin n // lo ≤ pivot ∧ pivot ≤ hi} :=
have ⟨loi, ij, jhi⟩ := H
if h : j < hi then
if lt (as.get! j) pivot then
let as := as.swap! i j
loop as (i+1) (j+1)
if lt (as.1.get (j.cast as.2.symm)) pivot then
let as := ⟨as.1.swap (i.cast as.2.symm) (j.cast as.2.symm), (Array.size_swap ..).trans as.2
loop as ⟨i.1+1, by omega⟩ ⟨j.1+1, by omega⟩
⟨Nat.le_succ_of_le H.1, Nat.succ_le_succ ij, Nat.succ_le_of_lt h⟩
else
loop as i (j+1)
loop as i ⟨j.1+1, by omega⟩ ⟨loi, Nat.le_succ_of_le ij, Nat.succ_le_of_lt h⟩
else
let as := as.swap! i hi
(i, as)
termination_by hi - j
loop as lo lo
let as := as.1.swap (i.cast as.2.symm) (hi.cast as.2.symm), (Array.size_swap ..).trans as.2
⟨as, i, loi, Nat.le_trans ij jhi⟩
termination_by hi.1 - j
loop as lo lo ⟨Nat.le_refl _, Nat.le_refl _, hle⟩

@[inline] partial def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort (as : Array α) (low high : Nat) :=
if low < high then
let p := qpartition as lt low high;
-- TODO: fix `partial` support in the equation compiler, it breaks if we use `let (mid, as) := partition as lt low high`
let mid := p.1
let as := p.2
if mid >= high then as
else
let as := sort as low mid
sort as (mid+1) high
@[inline] def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort {n} (as : {as : Array α // as.size = n})
(lo : Nat) (hi : Fin n) : {as : Array α // as.size = n} :=
if h : lo < hi.1 then
let ⟨as, mid, (_ : lo ≤ mid), _⟩ :=
qpartition as lt ⟨lo, Nat.lt_trans h hi.2⟩ hi (Nat.le_of_lt h)
if mid < hi.1 then
let as := sort as lo mid
sort as (mid+1) hi
else as
else as
sort as low high
termination_by hi - lo
if low < high then
if h : high < as.size then
(sort ⟨as, rfl⟩ low ⟨high, h⟩).1
else panic! "index out of bounds"
else as

end Array

0 comments on commit 4b95a5b

Please sign in to comment.