Skip to content

Commit

Permalink
feat: add model implementation for UTF8 enc/dec (#3961)
Browse files Browse the repository at this point in the history
- [x] Depends on: #3958 
- [x] Depends on: #3960

This makes the UTF-8 encode and decode functions have lean definitions,
so that we can prove properties about them downstream.
  • Loading branch information
digama0 authored Apr 22, 2024
1 parent 7c34b73 commit 70a2394
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 5 deletions.
93 changes: 89 additions & 4 deletions src/Init/Data/String/Extra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,69 @@ def toNat! (s : String) : Nat :=
else
panic! "Nat expected"

def utf8DecodeChar? (a : ByteArray) (i : Nat) : Option Char := do
let c ← a[i]?
if c &&& 0x80 == 0 then
some ⟨c.toUInt32, .inl (Nat.lt_trans c.1.2 (by decide))⟩
else if c &&& 0xe0 == 0xc0 then
let c1 ← a[i+1]?
guard (c1 &&& 0xc0 == 0x80)
let r := ((c &&& 0x1f).toUInt32 <<< 6) ||| (c1 &&& 0x3f).toUInt32
guard (0x80 ≤ r)
-- TODO: Prove h from the definition of r once we have the necessary lemmas
if h : r < 0xd800 then some ⟨r, .inl h⟩ else none
else if c &&& 0xf0 == 0xe0 then
let c1 ← a[i+1]?
let c2 ← a[i+2]?
guard (c1 &&& 0xc0 == 0x80 && c2 &&& 0xc0 == 0x80)
let r :=
((c &&& 0x0f).toUInt32 <<< 12) |||
((c1 &&& 0x3f).toUInt32 <<< 6) |||
(c2 &&& 0x3f).toUInt32
guard (0x800 ≤ r)
-- TODO: Prove `r < 0x110000` from the definition of r once we have the necessary lemmas
if h : r < 0xd8000xdfff < r ∧ r < 0x110000 then some ⟨r, h⟩ else none
else if c &&& 0xf8 == 0xf0 then
let c1 ← a[i+1]?
let c2 ← a[i+2]?
let c3 ← a[i+3]?
guard (c1 &&& 0xc0 == 0x80 && c2 &&& 0xc0 == 0x80 && c3 &&& 0xc0 == 0x80)
let r :=
((c &&& 0x07).toUInt32 <<< 18) |||
((c1 &&& 0x3f).toUInt32 <<< 12) |||
((c2 &&& 0x3f).toUInt32 <<< 6) |||
(c3 &&& 0x3f).toUInt32
if h : 0x10000 ≤ r ∧ r < 0x110000 then
some ⟨r, .inr ⟨Nat.lt_of_lt_of_le (by decide) h.1, h.2⟩⟩
else none
else
none

/-- Returns true if the given byte array consists of valid UTF-8. -/
@[extern "lean_string_validate_utf8"]
opaque validateUTF8 (a : @& ByteArray) : Bool
def validateUTF8 (a : @& ByteArray) : Bool :=
(loop 0).isSome
where
loop (i : Nat) : Option Unit := do
if i < a.size then
let c ← utf8DecodeChar? a i
loop (i + csize c)
else pure ()
termination_by a.size - i
decreasing_by exact Nat.sub_lt_sub_left ‹_› (Nat.lt_add_of_pos_right (one_le_csize c))

/-- Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`. -/
@[extern "lean_string_from_utf8"]
opaque fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String
def fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String :=
loop 0 ""
where
loop (i : Nat) (acc : String) : String :=
if i < a.size then
let c := (utf8DecodeChar? a i).getD default
loop (i + csize c) (acc.push c)
else acc
termination_by a.size - i
decreasing_by exact Nat.sub_lt_sub_left ‹_› (Nat.lt_add_of_pos_right (one_le_csize c))

/-- Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`,
or returns `none` if `a` is not properly UTF-8 encoded. -/
Expand All @@ -35,13 +91,42 @@ or panics if `a` is not properly UTF-8 encoded. -/
@[inline] def fromUTF8! (a : ByteArray) : String :=
if h : validateUTF8 a then fromUTF8 a h else panic! "invalid UTF-8 string"

def utf8EncodeChar (c : Char) : List UInt8 :=
let v := c.val
if v ≤ 0x7f then
[v.toUInt8]
else if v ≤ 0x7ff then
[(v >>> 6).toUInt8 &&& 0x1f ||| 0xc0,
v.toUInt8 &&& 0x3f ||| 0x80]
else if v ≤ 0xffff then
[(v >>> 12).toUInt8 &&& 0x0f ||| 0xe0,
(v >>> 6).toUInt8 &&& 0x3f ||| 0x80,
v.toUInt8 &&& 0x3f ||| 0x80]
else
[(v >>> 18).toUInt8 &&& 0x07 ||| 0xf0,
(v >>> 12).toUInt8 &&& 0x3f ||| 0x80,
(v >>> 6).toUInt8 &&& 0x3f ||| 0x80,
v.toUInt8 &&& 0x3f ||| 0x80]

@[simp] theorem length_utf8EncodeChar (c : Char) : (utf8EncodeChar c).length = csize c := by
simp [csize, utf8EncodeChar, Char.utf8Size]
cases Decidable.em (c.val ≤ 0x7f) <;> simp [*]
cases Decidable.em (c.val ≤ 0x7ff) <;> simp [*]
cases Decidable.em (c.val ≤ 0xffff) <;> simp [*]

/-- Converts the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/
@[extern "lean_string_to_utf8"]
opaque toUTF8 (a : @& String) : ByteArray
def toUTF8 (a : @& String) : ByteArray :=
⟨⟨a.data.bind utf8EncodeChar⟩⟩

@[simp] theorem size_toUTF8 (s : String) : s.toUTF8.size = s.utf8ByteSize := by
simp [toUTF8, ByteArray.size, Array.size, utf8ByteSize, List.bind]
induction s.data <;> simp [List.map, List.join, utf8ByteSize.go, Nat.add_comm, *]

/-- Accesses a byte in the UTF-8 encoding of the `String`. O(1) -/
@[extern "lean_string_get_byte_fast"]
opaque getUtf8Byte (s : @& String) (n : Nat) (h : n < s.utf8ByteSize) : UInt8
def getUtf8Byte (s : @& String) (n : Nat) (h : n < s.utf8ByteSize) : UInt8 :=
(toUTF8 s).get ⟨n, size_toUTF8 _ ▸ h⟩

theorem Iterator.sizeOf_next_lt_of_hasNext (i : String.Iterator) (h : i.hasNext) : sizeOf i.next < sizeOf i := by
cases i; rename_i s pos; simp [Iterator.next, Iterator.sizeOf_eq]; simp [Iterator.hasNext] at h
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/utf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ bool validate_utf8(uint8_t const * str, size_t size) {
if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80) return false;

unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (r < 0x800 || (r >= 0xD800 && r < 0xDFFF)) return false;
if (r < 0x800 || (r >= 0xD800 && r <= 0xDFFF)) return false;

i += 3;
} else if ((c & 0xf8) == 0xf0) {
Expand Down
44 changes: 44 additions & 0 deletions tests/lean/run/utf8英語.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,47 @@
import Lean.Util.TestExtern

instance : BEq ByteArray where
beq x y := x.data == y.data

test_extern String.toUTF8 ""
test_extern String.toUTF8 "\x00"
test_extern String.toUTF8 "$£€𐍈"

macro "test_extern'" t:term " => " v:term : command =>
`(test_extern $t
#guard $t == $v)

def checkGet (s : String) (arr : Array UInt8) :=
(List.range s.utf8ByteSize).all fun i =>
let c := if h : _ then s.getUtf8Byte i h else unreachable!
c == arr.get! i

macro "validate" arr:term " => ↯" : command =>
`(test_extern' String.validateUTF8 $arr => false)
macro "validate" arr:term " => " str:term : command =>
`(test_extern' String.validateUTF8 $arr => true
test_extern' String.fromUTF8 $arr (with_decl_name% _validate by native_decide) => $str
test_extern' String.toUTF8 $str => $arr
#guard checkGet $str ($arr : ByteArray).data)

validate ⟨#[]⟩ => ""
validate ⟨#[0]⟩ => "\x00"
validate ⟨#[0x80]⟩ => ↯
validate ⟨#[0x80, 0x1]⟩ => ↯
validate ⟨#[0xc0, 0x81]⟩ => ↯
validate ⟨#[0xc8, 0x81]⟩ => "ȁ"
validate ⟨#[0xc8, 0x81, 0xc8, 0x81]⟩ => "ȁȁ"
validate ⟨#[0xe0, 0x81]⟩ => ↯
validate ⟨#[0xe0, 0x81, 0x81]⟩ => ↯
validate ⟨#[0xe1, 0x81, 0x81]⟩ => "\u1041"
validate ⟨#[0xed, 0x9f, 0xbf]⟩ => "\ud7ff"
validate ⟨#[0xed, 0xa0, 0xb0]⟩ => ↯
validate ⟨#[0xed, 0xbf, 0xbf]⟩ => ↯
validate ⟨#[0xee, 0x80, 0x80]⟩ => "\ue000"
validate ⟨#[0xf1, 0x81, 0x81, 0x81]⟩ => "񁁁"
validate ⟨#[0xf8, 0x81, 0x81, 0x81, 0x81]⟩ => ↯
validate ⟨#[0x24, 0xc2, 0xa3, 0xe2, 0x82, 0xac, 0xf0, 0x90, 0x8d, 0x88]⟩ => "$£€𐍈"

def check_eq {α} [BEq α] [Repr α] (tag : String) (expected actual : α) : IO Unit :=
unless (expected == actual) do
throw $ IO.userError $
Expand Down

0 comments on commit 70a2394

Please sign in to comment.