Skip to content

Commit

Permalink
feat: UTF-8 string validation (#3958)
Browse files Browse the repository at this point in the history
Previously, there was a function `opaque fromUTF8Unchecked : ByteArray
-> String` which converted an array of bytes into a string. As the name
implies, it did not validate that the bytes were UTF-8 before doing so,
and as a result it could produce unsound results in the compiler (because
the Lean model of `String` indirectly asserts UTF-8 validity). This PR
replaces that function by
```lean
opaque validateUTF8 (a : @& ByteArray) : Bool

opaque fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String
```
so that while the function is still "unchecked", we have a proof witness
that the string is valid. To recover the original, actually unchecked
version, use `lcProof` or other unsafe methods to produce the proof
witness.

Because this was the only `ByteArray -> String` conversion function, it
was used in several places in an unsound way (e.g. reading untrusted
input from IO and treating it as UTF-8). These have been replaced by
`fromUTF8?` or `fromUTF8!` as appropriate.
  • Loading branch information
digama0 authored Apr 20, 2024
1 parent 5eb274d commit 62cdb51
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 66 deletions.
25 changes: 18 additions & 7 deletions src/Init/Data/String/Extra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@ def toNat! (s : String) : Nat :=
else
panic! "Nat expected"

/--
Convert a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`.
The result is unspecified if `a` is not properly UTF-8 encoded.
-/
@[extern "lean_string_from_utf8_unchecked"]
opaque fromUTF8Unchecked (a : @& ByteArray) : String
/--
Returns `true` if the given byte array consists of valid
[UTF-8](https://en.wikipedia.org/wiki/UTF-8).

A proof that this function returns `true` for `a` is the argument `h` required by `fromUTF8`.
The check is implemented by the external function `lean_string_validate_utf8`.
-/
@[extern "lean_string_validate_utf8"]
opaque validateUTF8 (a : @& ByteArray) : Bool

/--
Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`.

The argument `h` is a proof that `validateUTF8 a` holds, so this function can only be applied to
byte arrays that have passed validation. Use `fromUTF8?` or `fromUTF8!` when no such proof is
at hand.
-/
@[extern "lean_string_from_utf8"]
opaque fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String

/-- Decodes a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` into a `String`.
Yields `none` when `a` does not pass `validateUTF8`. -/
@[inline] def fromUTF8? (a : ByteArray) : Option String :=
  match valid : validateUTF8 a with
  | true => some (fromUTF8 a valid)
  | false => none

/-- Decodes a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` into a `String`.
Panics when `a` does not pass `validateUTF8`. -/
@[inline] def fromUTF8! (a : ByteArray) : String :=
  match valid : validateUTF8 a with
  | true => fromUTF8 a valid
  | false => panic! "invalid UTF-8 string"

/-- Convert the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/
/-- Converts the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/
@[extern "lean_string_to_utf8"]
opaque toUTF8 (a : @& String) : ByteArray

Expand Down
18 changes: 11 additions & 7 deletions src/Init/System/IO.lean
Original file line number Diff line number Diff line change
Expand Up @@ -768,12 +768,16 @@ def ofBuffer (r : Ref Buffer) : Stream where
write := fun data => r.modify fun b =>
-- set `exact` to `false` so that repeatedly writing to the stream does not impose quadratic run time
{ b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size }
getLine := r.modifyGet fun b =>
let pos := match b.data.findIdx? (start := b.pos) fun u => u == 0 || u = '\n'.toNat.toUInt8 with
-- include '\n', but not '\0'
| some pos => if b.data.get! pos == 0 then pos else pos + 1
| none => b.data.size
(String.fromUTF8Unchecked <| b.data.extract b.pos pos, { b with pos := pos })
getLine := do
let buf ← r.modifyGet fun b =>
let pos := match b.data.findIdx? (start := b.pos) fun u => u == 0 || u = '\n'.toNat.toUInt8 with
-- include '\n', but not '\0'
| some pos => if b.data.get! pos == 0 then pos else pos + 1
| none => b.data.size
(b.data.extract b.pos pos, { b with pos := pos })
match String.fromUTF8? buf with
| some str => pure str
| none => throw (.userError "invalid UTF-8")
putStr := fun s => r.modify fun b =>
let data := s.toUTF8
{ b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size }
Expand All @@ -791,7 +795,7 @@ def withIsolatedStreams [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (x : m
(if isolateStderr then withStderr (Stream.ofBuffer bOut) else id) <|
x
let bOut ← liftM (m := BaseIO) bOut.get
let out := String.fromUTF8Unchecked bOut.data
let out := String.fromUTF8! bOut.data
pure (out, r)

end FS
Expand Down
2 changes: 1 addition & 1 deletion src/Init/System/Uri.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def decodeUri (uri : String) : String := Id.run do
((decoded.push c).push h1, i + 2)
else
(decoded.push c, i + 1)
return String.fromUTF8Unchecked decoded
return String.fromUTF8! decoded
where hexDigitToUInt8? (c : UInt8) : Option UInt8 :=
if zero ≤ c ∧ c ≤ nine then some (c - zero)
else if lettera ≤ c ∧ c ≤ letterf then some (c - lettera + 10)
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Data/Json/Stream.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ open IO
/-- Consumes `nBytes` bytes from the stream, interprets the bytes as a utf-8 string and the string as a valid JSON object. -/
def readJson (h : FS.Stream) (nBytes : Nat) : IO Json := do
let bytes ← h.read (USize.ofNat nBytes)
let s := String.fromUTF8Unchecked bytes
let some s := String.fromUTF8? bytes | throw (IO.userError "invalid UTF-8")
ofExcept (Json.parse s)

def writeJson (h : FS.Stream) (j : Json) : IO Unit := do
Expand Down
26 changes: 0 additions & 26 deletions src/lake/tests/toml/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,6 @@ inductive TomlOutcome where
| fail (log : MessageLog)
| error (e : IO.Error)

@[inline] def Fin.allM [Monad m] (n) (f : Fin n → m Bool) : m Bool :=
loop 0
where
loop (i : Nat) : m Bool := do
if h : i < n then
if (← f ⟨i, h⟩) then loop (i+1) else pure false
else
pure true
termination_by n - i

@[inline] def Fin.all (n) (f : Fin n → Bool) : Bool :=
Id.run <| allM n f

def bytesBEq (a b : ByteArray) : Bool :=
if h_size : a.size = b.size then
Fin.all a.size fun i => a[i] = b[i]'(h_size ▸ i.isLt)
else
false

def String.fromUTF8 (bytes : ByteArray) : String :=
String.fromUTF8Unchecked bytes |>.map id

@[inline] def String.fromUTF8? (bytes : ByteArray) : Option String :=
let s := String.fromUTF8 bytes
if bytesBEq s.toUTF8 bytes then some s else none

nonrec def loadToml (tomlFile : FilePath) : BaseIO TomlOutcome := do
let fileName := tomlFile.fileName.getD tomlFile.toString
let input ←
Expand Down
32 changes: 18 additions & 14 deletions src/runtime/object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1614,10 +1614,14 @@ extern "C" LEAN_EXPORT object * lean_mk_string(char const * s) {
return lean_mk_string_from_bytes(s, strlen(s));
}

extern "C" LEAN_EXPORT obj_res lean_string_from_utf8_unchecked(b_obj_arg a) {
extern "C" LEAN_EXPORT obj_res lean_string_from_utf8(b_obj_arg a) {
return lean_mk_string_from_bytes(reinterpret_cast<char *>(lean_sarray_cptr(a)), lean_sarray_size(a));
}

/* Exported entry point behind the Lean declaration `validateUTF8`: runs `validate_utf8`
   over the contents of the scalar array `a` and returns its verdict. */
extern "C" LEAN_EXPORT uint8 lean_string_validate_utf8(b_obj_arg a) {
    auto const bytes = lean_sarray_cptr(a);
    auto const len   = lean_sarray_size(a);
    return validate_utf8(bytes, len);
}

extern "C" LEAN_EXPORT obj_res lean_string_to_utf8(b_obj_arg s) {
size_t sz = lean_string_size(s) - 1;
obj_res r = lean_alloc_sarray(1, sz, sz);
Expand Down Expand Up @@ -1741,38 +1745,38 @@ extern "C" LEAN_EXPORT obj_res lean_string_data(obj_arg s) {

static bool lean_string_utf8_get_core(char const * str, usize size, usize i, uint32 & result) {
unsigned c = static_cast<unsigned char>(str[i]);
/* zero continuation (0 to 127) */
/* zero continuation (0 to 0x7F) */
if ((c & 0x80) == 0) {
result = c;
return true;
}

/* one continuation (128 to 2047) */
/* one continuation (0x80 to 0x7FF) */
if ((c & 0xe0) == 0xc0 && i + 1 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
result = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (result >= 128) {
if (result >= 0x80) {
return true;
}
}

/* two continuations (2048 to 55295 and 57344 to 65535) */
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if ((c & 0xf0) == 0xe0 && i + 2 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
result = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (result >= 2048 && (result < 55296 || result > 57343)) {
if (result >= 0x800 && (result < 0xD800 || result > 0xDFFF)) {
return true;
}
}

/* three continuations (65536 to 1114111) */
/* three continuations (0x10000 to 0x10FFFF) */
if ((c & 0xf8) == 0xf0 && i + 3 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned c3 = static_cast<unsigned char>(str[i+3]);
result = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (result >= 65536 && result <= 1114111) {
if (result >= 0x10000 && result <= 0x10FFFF) {
return true;
}
}
Expand Down Expand Up @@ -1810,32 +1814,32 @@ extern "C" LEAN_EXPORT uint32 lean_string_utf8_get(b_obj_arg s, b_obj_arg i0) {
}

extern "C" LEAN_EXPORT uint32_t lean_string_utf8_get_fast_cold(char const * str, size_t i, size_t size, unsigned char c) {
/* one continuation (128 to 2047) */
/* one continuation (0x80 to 0x7FF) */
if ((c & 0xe0) == 0xc0 && i + 1 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
uint32_t result = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (result >= 128) {
if (result >= 0x80) {
return result;
}
}

/* two continuations (2048 to 55295 and 57344 to 65535) */
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if ((c & 0xf0) == 0xe0 && i + 2 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
uint32_t result = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (result >= 2048 && (result < 55296 || result > 57343)) {
if (result >= 0x800 && (result < 0xD800 || result > 0xDFFF)) {
return result;
}
}

/* three continuations (65536 to 1114111) */
/* three continuations (0x10000 to 0x10FFFF) */
if ((c & 0xf8) == 0xf0 && i + 3 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned c3 = static_cast<unsigned char>(str[i+3]);
uint32_t result = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (result >= 65536 && result <= 1114111) {
if (result >= 0x10000 && result <= 0x10FFFF) {
return result;
}
}
Expand Down
66 changes: 58 additions & 8 deletions src/runtime/utf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ unsigned utf8_to_unicode(uchar const * begin, uchar const * end) {
auto it = begin;
unsigned c = *it;
++it;
if (c < 128)
if (c < 0x80)
return c;
unsigned mask = (1u << 6) -1;
unsigned hmask = mask;
Expand Down Expand Up @@ -164,40 +164,40 @@ optional<unsigned> get_utf8_first_byte_opt(unsigned char c) {

unsigned next_utf8(char const * str, size_t size, size_t & i) {
unsigned c = static_cast<unsigned char>(str[i]);
/* zero continuation (0 to 127) */
/* zero continuation (0 to 0x7F) */
if ((c & 0x80) == 0) {
i++;
return c;
}

/* one continuation (128 to 2047) */
/* one continuation (0x80 to 0x7FF) */
if ((c & 0xe0) == 0xc0 && i + 1 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (r >= 128) {
if (r >= 0x80) {
i += 2;
return r;
}
}

/* two continuations (2048 to 55295 and 57344 to 65535) */
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if ((c & 0xf0) == 0xe0 && i + 2 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (r >= 2048 && (r < 55296 || r > 57343)) {
if (r >= 0x800 && (r < 0xD800 || r > 0xDFFF)) {
i += 3;
return r;
}
}

/* three continuations (65536 to 1114111) */
/* three continuations (0x10000 to 0x10FFFF) */
if ((c & 0xf8) == 0xf0 && i + 3 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned c3 = static_cast<unsigned char>(str[i+3]);
unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (r >= 65536 && r <= 1114111) {
if (r >= 0x10000 && r <= 0x10FFFF) {
i += 4;
return r;
}
Expand All @@ -220,6 +220,56 @@ void utf8_decode(std::string const & str, std::vector<unsigned> & out) {
}
}

/* Returns true iff the `size` bytes starting at `str` form a well-formed UTF-8 sequence.

   Each encoded scalar value is checked for:
   - completeness (the sequence is not cut off by the end of the buffer),
   - continuation bytes of the form 10xxxxxx,
   - shortest-form encoding (overlong encodings are rejected),
   - a decoded value outside the surrogate range 0xD800..0xDFFF and not above 0x10FFFF.

   A lead byte that matches none of the four patterns (a stray continuation byte or
   anything >= 0xF8) makes the input invalid. An empty input (`size == 0`) is valid and
   `str` is not dereferenced in that case. */
bool validate_utf8(uint8_t const * str, size_t size) {
    size_t i = 0;
    while (i < size) {
        unsigned c = str[i];
        if ((c & 0x80) == 0) {
            /* zero continuation (0 to 0x7F) */
            i++;
        } else if ((c & 0xe0) == 0xc0) {
            /* one continuation (0x80 to 0x7FF) */
            if (i + 1 >= size) return false;

            unsigned c1 = str[i+1];
            if ((c1 & 0xc0) != 0x80) return false;

            unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f);
            /* reject overlong encodings of values that fit in one byte */
            if (r < 0x80) return false;

            i += 2;
        } else if ((c & 0xf0) == 0xe0) {
            /* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
            if (i + 2 >= size) return false;

            unsigned c1 = str[i+1];
            unsigned c2 = str[i+2];
            if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80) return false;

            unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
            /* reject overlong encodings and the whole surrogate block; the upper bound is
               inclusive, since 0xDFFF is itself a surrogate */
            if (r < 0x800 || (r >= 0xD800 && r <= 0xDFFF)) return false;

            i += 3;
        } else if ((c & 0xf8) == 0xf0) {
            /* three continuations (0x10000 to 0x10FFFF) */
            if (i + 3 >= size) return false;

            unsigned c1 = str[i+1];
            unsigned c2 = str[i+2];
            unsigned c3 = str[i+3];
            if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80 || (c3 & 0xc0) != 0x80) return false;

            unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
            /* reject overlong encodings and values beyond the Unicode code space */
            if (r < 0x10000 || r > 0x10FFFF) return false;

            i += 4;
        } else {
            /* stray continuation byte or invalid lead byte */
            return false;
        }
    }
    return true;
}

#define TAG_CONT static_cast<unsigned char>(0b10000000)
#define TAG_TWO_B static_cast<unsigned char>(0b11000000)
#define TAG_THREE_B static_cast<unsigned char>(0b11100000)
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/utf8.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ LEAN_EXPORT unsigned next_utf8(char const * str, size_t size, size_t & i);
/* Decode a UTF-8 encoded string `str` into unicode scalar values */
LEAN_EXPORT void utf8_decode(std::string const & str, std::vector<unsigned> & out);

/* Returns true if the `size` bytes starting at `str` are valid UTF-8.
   The buffer is described by an explicit length; no terminating NUL is required. */
LEAN_EXPORT bool validate_utf8(uint8_t const * str, size_t size);

/* Push a unicode scalar value into a utf-8 encoded string */
LEAN_EXPORT void push_unicode_scalar(std::string & s, unsigned code);

Expand Down
7 changes: 5 additions & 2 deletions tests/lean/run/utf8英語.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ def check_eq {α} [BEq α] [Repr α] (tag : String) (expected actual : α) : IO
s!"assertion failure \"{tag}\":\n expected: {repr expected}\n actual: {repr actual}"

def DecodeUTF8: IO Unit := do
let cs := String.toList "Hello, 英語!"
let str := "Hello, 英語!"
let cs := String.toList str
let ns := cs.map Char.toNat
IO.println cs
IO.println ns
check_eq "utf-8 chars" [72, 101, 108, 108, 111, 44, 32, 33521, 35486, 33] ns
check_eq "utf-8 bytes" #[72, 101, 108, 108, 111, 44, 32, 232, 139, 177, 232, 170, 158, 33] str.toUTF8.data
check_eq "string eq" (some str) (String.fromUTF8? str.toUTF8)

#eval DecodeUTF8
#eval DecodeUTF8

0 comments on commit 62cdb51

Please sign in to comment.