From f8313f7856edaf533226c0f95d75b3b869af6fc3 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 20:48:08 +0800 Subject: [PATCH 01/45] this works --- core/trie/bitarray.go | 134 +++++++++++++++++++++++++++++++++++++ core/trie/bitarray_test.go | 98 +++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 core/trie/bitarray.go create mode 100644 core/trie/bitarray_test.go diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go new file mode 100644 index 000000000..286e38d7a --- /dev/null +++ b/core/trie/bitarray.go @@ -0,0 +1,134 @@ +package trie + +import ( + "encoding/binary" + + "github.com/NethermindEth/juno/core/felt" +) + +const ( + mask64 = uint64(1 << 63) +) + +type bitArray struct { + len uint8 + words [4]uint64 // Little endian (i.e. words[0] is the least significant) +} + +func (b *bitArray) Len() uint8 { + return b.len +} + +func (b *bitArray) Bytes() [32]byte { + var res [32]byte + + switch { + case b.len == 0: + return res + case b.len == 255: + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + case b.len >= 192: + rem := 256 - uint(b.len) + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + case b.len >= 128: + rem := 192 - b.len + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + case b.len >= 64: + rem := 128 - b.len + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + default: + rem := 64 - b.len + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) + } + + return res +} + +func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) + b.len = felt.Bits - 1 + return b +} + +func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { + if n >= b.len { + return b.clear() + } + + switch { + case n == 0: + return b.set(x) + case n >= 192: + b.rsh192(x) + n -= 192 + b.words[0] >>= n + b.len -= n + case n >= 128: + b.rsh128(x) + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + b.len -= n + case n >= 64: + b.rsh64(x) + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + b.len -= n + default: + b.set(x) + b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) + b.words[0] >>= n + b.len -= n + } + + return b +} + +func (b *bitArray) set(x *bitArray) *bitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +func (b *bitArray) rsh64(x *bitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *bitArray) rsh128(x *bitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *bitArray) rsh192(x *bitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *bitArray) clear() *bitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go new file mode 100644 index 000000000..6a5e974df --- /dev/null +++ b/core/trie/bitarray_test.go @@ -0,0 +1,98 @@ +package trie + +import ( + "bytes" + "encoding/binary" + "testing" +) + +var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} + +func TestBytes(t *testing.T) { + tests := []struct { + name string + bitArray bitArray + want [32]byte + }{ + // { + // name: "length == 0", + // bitArray: bitArray{len: 0, words: maxBitArray}, + // want: [32]byte{}, + // }, + // { + // name: "length < 64", + // bitArray: bitArray{len: 38, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + // return b + // }(), + // }, + // { + // name: "64 <= length < 128", + // bitArray: bitArray{len: 100, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + // { + // name: "128 <= length < 192", + // bitArray: bitArray{len: 130, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[8:16], 0x3) + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + { + name: "192 <= length < 255", + bitArray: bitArray{len: 201, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + // { + // name: "length == 254", + // bitArray: bitArray{len: 254, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + // { + // name: "length == 255", + // bitArray: bitArray{len: 255, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.bitArray.Bytes() + if !bytes.Equal(got[:], tt.want[:]) { + t.Errorf("bitArray.Bytes() = %v, want %v", got, tt.want) + } + }) + } +} From 47e95903b51a5f1ee0684805e9a5c4860b038d25 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:40:09 +0800 Subject: [PATCH 02/45] one failed but im getting closer --- core/trie/bitarray.go | 55 +++++++++-------- core/trie/bitarray_test.go | 122 ++++++++++++++++++------------------- 2 files changed, 89 insertions(+), 88 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 286e38d7a..6887140d4 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,46 +11,42 @@ const ( ) type bitArray struct { - len uint8 - words [4]uint64 // Little endian (i.e. words[0] is the least significant) -} - -func (b *bitArray) Len() uint8 { - return b.len + pos uint8 // position of the most significant bit + words [4]uint64 // little endian (i.e. words[0] is the least significant) } func (b *bitArray) Bytes() [32]byte { var res [32]byte switch { - case b.len == 0: + case b.pos == 0: return res - case b.len == 255: + case b.pos == 255: binary.BigEndian.PutUint64(res[0:8], b.words[3]) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 192: - rem := 256 - uint(b.len) - mask := uint64(1<<(64-rem)) - 1 + case b.pos >= 192: + rem := 255 - uint(b.pos) + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 128: - rem := 192 - b.len - mask := uint64(1<<(64-rem)) - 1 + case b.pos >= 128: + rem := 191 - b.pos + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 64: - rem := 128 - b.len - mask := uint64(1<<(64-rem)) - 1 + case b.pos >= 64: + rem := 127 - b.pos + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - rem := 64 - b.len - mask := uint64(1<<(64-rem)) - 1 + rem := 63 - b.pos + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } @@ -63,12 +59,17 @@ func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { b.words[2] = binary.BigEndian.Uint64(res[8:16]) b.words[1] = binary.BigEndian.Uint64(res[16:24]) b.words[0] = binary.BigEndian.Uint64(res[24:32]) - b.len = felt.Bits - 1 + b.pos = felt.Bits - 1 return b } +// Rsh shifts the bit array to the right by n bits. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { - if n >= b.len { + if b.pos == 0 { + return b + } + + if n >= b.pos { return b.clear() } @@ -79,13 +80,13 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.rsh192(x) n -= 192 b.words[0] >>= n - b.len -= n + b.pos -= n case n >= 128: b.rsh128(x) n -= 128 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] >>= n - b.len -= n + b.pos -= n case n >= 64: b.rsh64(x) n -= 64 @@ -93,21 +94,21 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) b.words[3] >>= n - b.len -= n + b.pos -= n default: b.set(x) b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) b.words[0] >>= n - b.len -= n + b.pos -= n } return b } func (b *bitArray) set(x *bitArray) *bitArray { - b.len = x.len + b.pos = x.pos b.words[0] = x.words[0] b.words[1] = x.words[1] b.words[2] = x.words[2] @@ -128,7 +129,7 @@ func (b *bitArray) rsh192(x *bitArray) { } func (b *bitArray) clear() *bitArray { - b.len = 0 + b.pos = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 6a5e974df..695acc881 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -14,77 +14,77 @@ func TestBytes(t *testing.T) { bitArray bitArray want [32]byte }{ - // { - // name: "length == 0", - // bitArray: bitArray{len: 0, words: maxBitArray}, - // want: [32]byte{}, - // }, - // { - // name: "length < 64", - // bitArray: bitArray{len: 38, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) - // return b - // }(), - // }, - // { - // name: "64 <= length < 128", - // bitArray: bitArray{len: 100, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, - // { - // name: "128 <= length < 192", - // bitArray: bitArray{len: 130, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[8:16], 0x3) - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, + { + name: "length == 0", + bitArray: bitArray{pos: 0, words: maxBitArray}, + want: [32]byte{}, + }, + { + name: "length < 64", + bitArray: bitArray{pos: 38, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x7FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + bitArray: bitArray{pos: 100, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0x7FFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + { + name: "128 <= length < 192", + bitArray: bitArray{pos: 130, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x7) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, { name: "192 <= length < 255", - bitArray: bitArray{len: 201, words: maxBitArray}, + bitArray: bitArray{pos: 201, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FF) + binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + { + name: "length == 254", + bitArray: bitArray{pos: 254, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + { + name: "length == 255", + bitArray: bitArray{pos: 255, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b }(), }, - // { - // name: "length == 254", - // bitArray: bitArray{len: 254, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, - // { - // name: "length == 255", - // bitArray: bitArray{len: 255, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, } for _, tt := range tests { From cae88a9489e0535ab564cc42726c335e7a9b6e41 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:52:42 +0800 Subject: [PATCH 03/45] this works --- core/trie/bitarray.go | 26 +++++++++++++++++--------- core/trie/bitarray_test.go | 4 +--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 6887140d4..da18c6cd7 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -10,8 +10,10 @@ const ( mask64 = uint64(1 << 63) ) +var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} + type bitArray struct { - pos uint8 // position of the most significant bit + pos uint8 // position of the current most significant bit (0-255) words [4]uint64 // little endian (i.e. words[0] is the least significant) } @@ -27,26 +29,32 @@ func (b *bitArray) Bytes() [32]byte { binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.pos >= 192: - rem := 255 - uint(b.pos) - mask := ^mask64 >> (rem - 1) + // For positions >= 192, we need to mask the most significant word (words[3]) + // to zero out bits beyond the current position. + // Example: if pos = 201, then rem = 255 - 201 = 54 + // mask = ^mask64 >> (54 - 1) = ^(1<<63) >> 53 + // This creates a mask like: 0000000000000000000000000000000000000000000000000000001111111111 + // When applied to words[3], it preserves only the 10 least significant bits + shift := 255 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.pos >= 128: - rem := 191 - b.pos - mask := ^mask64 >> (rem - 1) + shift := 191 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.pos >= 64: - rem := 127 - b.pos - mask := ^mask64 >> (rem - 1) + shift := 127 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - rem := 63 - b.pos - mask := ^mask64 >> (rem - 1) + shift := 63 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 695acc881..6ef2fa109 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -6,8 +6,6 @@ import ( "testing" ) -var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} - func TestBytes(t *testing.T) { tests := []struct { name string @@ -33,7 +31,7 @@ func TestBytes(t *testing.T) { bitArray: bitArray{pos: 100, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[16:24], 0x7FFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0x1FFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b }(), From 8018d7159103d73fba763d9d3679e4a1d8cdf2fa Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:58:24 +0800 Subject: [PATCH 04/45] add bytes benchmark --- core/trie/bitarray_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 6ef2fa109..fd747dd19 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -94,3 +94,40 @@ func TestBytes(t *testing.T) { }) } } + +func BenchmarkBitArrayBytes(b *testing.B) { + testCases := []struct { + name string + ba bitArray + }{ + { + name: "empty", + ba: bitArray{pos: 0, words: maxBitArray}, + }, + { + name: "pos_38", + ba: bitArray{pos: 38, words: maxBitArray}, + }, + { + name: "pos_100", + ba: bitArray{pos: 100, words: maxBitArray}, + }, + { + name: "pos_201", + ba: bitArray{pos: 201, words: maxBitArray}, + }, + { + name: "pos_255", + ba: bitArray{pos: 255, words: maxBitArray}, + }, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = tc.ba.Bytes() + } + }) + } +} From 989b2b10513840fe6aced22490bcd51518514cff Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 23:51:24 +0800 Subject: [PATCH 05/45] looks gud --- core/trie/bitarray.go | 65 ++++++++++++++++++----------------- core/trie/bitarray_test.go | 69 ++++++++++++-------------------------- 2 files changed, 53 insertions(+), 81 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index da18c6cd7..bdbd15898 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -12,49 +12,48 @@ const ( var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} +// bitArray is a structure that represents a bit array with a max length of 255 bits. +// The reason why 255 bits is the max length is because we only need up to 252 bits for the felt. +// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// Unlike normal bit arrays, it has a `len` field that represents the number of used bits. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. type bitArray struct { - pos uint8 // position of the current most significant bit (0-255) + len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) } +// Bytes returns the bytes representation of the bit array in big endian format. func (b *bitArray) Bytes() [32]byte { var res [32]byte switch { - case b.pos == 0: + case b.len == 0: return res - case b.pos == 255: - binary.BigEndian.PutUint64(res[0:8], b.words[3]) - binary.BigEndian.PutUint64(res[8:16], b.words[2]) - binary.BigEndian.PutUint64(res[16:24], b.words[1]) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.pos >= 192: - // For positions >= 192, we need to mask the most significant word (words[3]) - // to zero out bits beyond the current position. - // Example: if pos = 201, then rem = 255 - 201 = 54 - // mask = ^mask64 >> (54 - 1) = ^(1<<63) >> 53 - // This creates a mask like: 0000000000000000000000000000000000000000000000000000001111111111 - // When applied to words[3], it preserves only the 10 least significant bits - shift := 255 - b.pos - mask := ^mask64 >> (shift - 1) + case b.len >= 192: + // len is 0-based, so 255 (not 256) represents all bits used + // subtracting from 255 ensures correct mask when len=255 + // For example, when len is 255, it means all bits from index 0 + // to 254 are used (total of 255 bits). + // So when we create the mask, we shift 255 - 255 = 0 bits to the right. + // This creates a mask that covers all bits from index 0 to 254. + mask := ^mask64 >> (255 - b.len) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.pos >= 128: - shift := 191 - b.pos - mask := ^mask64 >> (shift - 1) + case b.len >= 128: + // Similar pattern for 191 boundary (3 words × 64 bits - 1) + mask := ^mask64 >> (191 - b.len) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.pos >= 64: - shift := 127 - b.pos - mask := ^mask64 >> (shift - 1) + case b.len >= 64: + mask := ^mask64 >> (127 - b.len) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - shift := 63 - b.pos - mask := ^mask64 >> (shift - 1) + mask := ^mask64 >> (63 - b.len) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } @@ -67,17 +66,17 @@ func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { b.words[2] = binary.BigEndian.Uint64(res[8:16]) b.words[1] = binary.BigEndian.Uint64(res[16:24]) b.words[0] = binary.BigEndian.Uint64(res[24:32]) - b.pos = felt.Bits - 1 + b.len = felt.Bits - 1 return b } // Rsh shifts the bit array to the right by n bits. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { - if b.pos == 0 { + if b.len == 0 { return b } - if n >= b.pos { + if n >= b.len { return b.clear() } @@ -88,13 +87,13 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.rsh192(x) n -= 192 b.words[0] >>= n - b.pos -= n + b.len -= n case n >= 128: b.rsh128(x) n -= 128 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] >>= n - b.pos -= n + b.len -= n case n >= 64: b.rsh64(x) n -= 64 @@ -102,21 +101,21 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) b.words[3] >>= n - b.pos -= n + b.len -= n default: b.set(x) b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) b.words[0] >>= n - b.pos -= n + b.len -= n } return b } func (b *bitArray) set(x *bitArray) *bitArray { - b.pos = x.pos + b.len = x.len b.words[0] = x.words[0] b.words[1] = x.words[1] b.words[2] = x.words[2] @@ -137,7 +136,7 @@ func (b *bitArray) rsh192(x *bitArray) { } func (b *bitArray) clear() *bitArray { - b.pos = 0 + b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index fd747dd19..cdd23b948 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "encoding/binary" + "math/bits" "testing" ) @@ -14,34 +15,34 @@ func TestBytes(t *testing.T) { }{ { name: "length == 0", - bitArray: bitArray{pos: 0, words: maxBitArray}, + bitArray: bitArray{len: 0, words: maxBitArray}, want: [32]byte{}, }, { name: "length < 64", - bitArray: bitArray{pos: 38, words: maxBitArray}, + bitArray: bitArray{len: 38, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[24:32], 0x7FFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) return b }(), }, { name: "64 <= length < 128", - bitArray: bitArray{pos: 100, words: maxBitArray}, + bitArray: bitArray{len: 100, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[16:24], 0x1FFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b }(), }, { name: "128 <= length < 192", - bitArray: bitArray{pos: 130, words: maxBitArray}, + bitArray: bitArray{len: 130, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[8:16], 0x7) + binary.BigEndian.PutUint64(b[8:16], 0x3) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b @@ -49,10 +50,10 @@ func TestBytes(t *testing.T) { }, { name: "192 <= length < 255", - bitArray: bitArray{pos: 201, words: maxBitArray}, + bitArray: bitArray{len: 201, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x3FF) + binary.BigEndian.PutUint64(b[0:8], 0x1FF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) @@ -61,10 +62,10 @@ func TestBytes(t *testing.T) { }, { name: "length == 254", - bitArray: bitArray{pos: 254, words: maxBitArray}, + bitArray: bitArray{len: 254, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) @@ -73,10 +74,10 @@ func TestBytes(t *testing.T) { }, { name: "length == 255", - bitArray: bitArray{pos: 255, words: maxBitArray}, + bitArray: bitArray{len: 255, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) @@ -91,42 +92,14 @@ func TestBytes(t *testing.T) { if !bytes.Equal(got[:], tt.want[:]) { t.Errorf("bitArray.Bytes() = %v, want %v", got, tt.want) } - }) - } -} -func BenchmarkBitArrayBytes(b *testing.B) { - testCases := []struct { - name string - ba bitArray - }{ - { - name: "empty", - ba: bitArray{pos: 0, words: maxBitArray}, - }, - { - name: "pos_38", - ba: bitArray{pos: 38, words: maxBitArray}, - }, - { - name: "pos_100", - ba: bitArray{pos: 100, words: maxBitArray}, - }, - { - name: "pos_201", - ba: bitArray{pos: 201, words: maxBitArray}, - }, - { - name: "pos_255", - ba: bitArray{pos: 255, words: maxBitArray}, - }, - } - - for _, tc := range testCases { - b.Run(tc.name, func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = tc.ba.Bytes() + // check if the received bytes has the same bit count as the bitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.bitArray.len) { + t.Errorf("bitArray.Bytes() bit count = %v, want %v", count, tt.bitArray.len) } }) } From 905f003782d7ae8170840837ce2cf01d97239104 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 14 Dec 2024 18:40:04 +0800 Subject: [PATCH 06/45] add Rsh test --- core/trie/bitarray.go | 100 +++++++++++++++++++++++------------ core/trie/bitarray_test.go | 103 +++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 33 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bdbd15898..bcb4c6949 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -2,82 +2,108 @@ package trie import ( "encoding/binary" + "math" "github.com/NethermindEth/juno/core/felt" ) const ( - mask64 = uint64(1 << 63) + maxUint64 = uint64(math.MaxUint64) ) -var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} +var maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} // bitArray is a structure that represents a bit array with a max length of 255 bits. -// The reason why 255 bits is the max length is because we only need up to 252 bits for the felt. -// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. // It uses a little endian representation to do bitwise operations of the words efficiently. // Unlike normal bit arrays, it has a `len` field that represents the number of used bits. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The reason why 255 bits is the max length is because we only need up to 251 bits for a given trie key. +// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. type bitArray struct { len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) } -// Bytes returns the bytes representation of the bit array in big endian format. +// Bytes returns the bytes representation of the bit array in big endian format func (b *bitArray) Bytes() [32]byte { var res [32]byte switch { case b.len == 0: + // all zeros return res case b.len >= 192: - // len is 0-based, so 255 (not 256) represents all bits used - // subtracting from 255 ensures correct mask when len=255 - // For example, when len is 255, it means all bits from index 0 - // to 254 are used (total of 255 bits). - // So when we create the mask, we shift 255 - 255 = 0 bits to the right. - // This creates a mask that covers all bits from index 0 to 254. - mask := ^mask64 >> (255 - b.len) + // Create mask for top word: keeps only valid bits above 192 + // e.g., if len=200, keeps lowest 8 bits (200-192) + mask := maxUint64 >> (256 - uint16(b.len)) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.len >= 128: - // Similar pattern for 191 boundary (3 words × 64 bits - 1) - mask := ^mask64 >> (191 - b.len) + // Mask for bits 128-191: keeps only valid bits above 128 + // e.g., if len=150, keeps lowest 22 bits (150-128) + mask := maxUint64 >> (192 - b.len) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.len >= 64: - mask := ^mask64 >> (127 - b.len) + // You get the idea + mask := maxUint64 >> (128 - b.len) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - mask := ^mask64 >> (63 - b.len) + mask := maxUint64 >> (64 - b.len) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } return res } -func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { +func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = length + return b +} + +func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = 251 + return b +} + +func (b *bitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) b.words[2] = binary.BigEndian.Uint64(res[8:16]) b.words[1] = binary.BigEndian.Uint64(res[16:24]) b.words[0] = binary.BigEndian.Uint64(res[24:32]) - b.len = felt.Bits - 1 - return b } -// Rsh shifts the bit array to the right by n bits. +func (b *bitArray) PrefixEqual(x *bitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + var long, short *bitArray + long, short = b, x + + if b.len < x.len { + long, short = x, b + } + + return long.Rsh(long, long.len-short.len).Equal(short) +} + +// Rsh sets b = x >> n and returns b. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { - if b.len == 0 { - return b + if x.len == 0 { + return b.set(x) } - if n >= b.len { - return b.clear() + if n >= x.len { + x.clear() + return b.set(x) } switch { @@ -85,35 +111,43 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { return b.set(x) case n >= 192: b.rsh192(x) + b.len = x.len - n n -= 192 b.words[0] >>= n - b.len -= n case n >= 128: b.rsh128(x) + b.len = x.len - n n -= 128 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] >>= n - b.len -= n case n >= 64: b.rsh64(x) + b.len = x.len - n n -= 64 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) - b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) - b.words[3] >>= n - b.len -= n + b.words[2] >>= n default: b.set(x) - b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) - b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) - b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) - b.words[0] >>= n b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n } return b } +// Eq checks if two bit arrays are equal +func (b *bitArray) Equal(x *bitArray) bool { + return b.len == x.len && + b.words[0] == x.words[0] && + b.words[1] == x.words[1] && + b.words[2] == x.words[2] && + b.words[3] == x.words[3] +} + func (b *bitArray) set(x *bitArray) *bitArray { b.len = x.len b.words[0] = x.words[0] diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index cdd23b948..b740c9314 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -104,3 +104,106 @@ func TestBytes(t *testing.T) { }) } } + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *bitArray + shiftBy uint8 + expected *bitArray + }{ + { + name: "zero length array", + initial: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + shiftBy: 0, + expected: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 65, + expected: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 32, + expected: &bitArray{ + len: 96, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 64, + expected: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + shiftBy: 128, + expected: &bitArray{ + len: 123, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + shiftBy: 192, + expected: &bitArray{ + len: 59, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(bitArray).Rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} From 323370c054f11a92ed70b74ceb63bcab1f8dba61 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 14 Dec 2024 19:20:55 +0800 Subject: [PATCH 07/45] add Truncate --- core/trie/bitarray.go | 98 +++++++++---- core/trie/bitarray_test.go | 281 +++++++++++++++++++++++++++++++++++++ 2 files changed, 354 insertions(+), 25 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bcb4c6949..3a7387f32 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -60,31 +60,15 @@ func (b *bitArray) Bytes() [32]byte { return res } -func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { - b.setFelt(f) - b.len = length - return b -} - -func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { - b.setFelt(f) - b.len = 251 - return b -} - -func (b *bitArray) setFelt(f *felt.Felt) { - res := f.Bytes() - b.words[3] = binary.BigEndian.Uint64(res[0:8]) - b.words[2] = binary.BigEndian.Uint64(res[8:16]) - b.words[1] = binary.BigEndian.Uint64(res[16:24]) - b.words[0] = binary.BigEndian.Uint64(res[24:32]) -} - func (b *bitArray) PrefixEqual(x *bitArray) bool { if b.len == x.len { return b.Equal(x) } + if b.len == 0 || x.len == 0 { + return true + } + var long, short *bitArray long, short = b, x @@ -95,20 +79,64 @@ func (b *bitArray) PrefixEqual(x *bitArray) bool { return long.Rsh(long, long.len-short.len).Equal(short) } +// Truncate sets b to the first 'length' bits of x and returns b. +// If length >= x.len, b is a copy of x. +// Any bits beyond the specified length are cleared to zero. +// For example: +// +// x = 11001011 (len=8) +// Truncate(x, 4) = 1011 (len=4) +// Truncate(x, 10) = 11001011 (len=8, original x) +// Truncate(x, 0) = 0 (len=0) +func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { + if length >= x.len { + return b.Set(x) + } + + b.Set(x) + b.len = length + + // Clear all words beyond what's needed + switch { + case length == 0: + b.words = [4]uint64{0, 0, 0, 0} + case length <= 64: + mask := maxUint64 >> (64 - length) + b.words[0] &= mask + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + case length <= 128: + mask := maxUint64 >> (128 - length) + b.words[1] &= mask + b.words[2] = 0 + b.words[3] = 0 + case length <= 192: + mask := maxUint64 >> (192 - length) + b.words[2] &= mask + b.words[3] = 0 + default: + mask := maxUint64 >> (256 - uint16(length)) + b.words[3] &= mask + } + + return b +} + // Rsh sets b = x >> n and returns b. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { if x.len == 0 { - return b.set(x) + return b.Set(x) } if n >= x.len { x.clear() - return b.set(x) + return b.Set(x) } switch { case n == 0: - return b.set(x) + return b.Set(x) case n >= 192: b.rsh192(x) b.len = x.len - n @@ -128,7 +156,7 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) b.words[2] >>= n default: - b.set(x) + b.Set(x) b.len -= n b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) @@ -148,7 +176,27 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } -func (b *bitArray) set(x *bitArray) *bitArray { +func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = length + return b +} + +func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = 251 + return b +} + +func (b *bitArray) setFelt(f *felt.Felt) { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) +} + +func (b *bitArray) Set(x *bitArray) *bitArray { b.len = x.len b.words[0] = x.words[0] b.words[1] = x.words[1] diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index b740c9314..0e031ff57 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -207,3 +207,284 @@ func TestRsh(t *testing.T) { }) } } + +func TestPrefixEqual(t *testing.T) { + tests := []struct { + name string + a *bitArray + b *bitArray + want bool + }{ + { + name: "equal lengths, equal values", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: true, + }, + { + name: "equal lengths, different values", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different lengths, a longer but same prefix", + a: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, b longer but same prefix", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, different prefix", + a: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "zero length arrays", + a: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + b: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "one zero length array", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "max length difference", + a: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + b: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.a.PrefixEqual(tt.b); got != tt.want { + t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) + } + // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) + if got := tt.b.PrefixEqual(tt.a); got != tt.want { + t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + initial bitArray + length uint8 + expected bitArray + }{ + { + name: "truncate to zero", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 0, + expected: bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "truncate within first word - 32 bits", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 32, + expected: bitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "truncate to single bit", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 1, + expected: bitArray{ + len: 1, + words: [4]uint64{0x0000000000000001, 0, 0, 0}, + }, + }, + { + name: "truncate across words - 100 bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 100, + expected: bitArray{ + len: 100, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "truncate at word boundary - 64 bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 64, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "truncate at word boundary - 128 bits", + initial: bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 128, + expected: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "truncate in third word - 150 bits", + initial: bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 150, + expected: bitArray{ + len: 150, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, + }, + }, + { + name: "truncate in fourth word - 220 bits", + initial: bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + length: 220, + expected: bitArray{ + len: 220, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, + }, + }, + { + name: "truncate max length - 251 bits", + initial: bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + length: 251, + expected: bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "truncate sparse bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: bitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 64, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 128, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(bitArray).Truncate(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} From 37c8c5e9693eacf5998d3446bca031c37aa54bb4 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 17:09:17 +0800 Subject: [PATCH 08/45] add CommonMSBs --- core/trie/bitarray.go | 132 +++++++++++++++- core/trie/bitarray_test.go | 314 ++++++++++++++++++++++++++++++++++++- 2 files changed, 441 insertions(+), 5 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 3a7387f32..c542c1e02 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -1,17 +1,23 @@ package trie import ( + "bytes" "encoding/binary" "math" + "math/bits" "github.com/NethermindEth/juno/core/felt" ) const ( maxUint64 = uint64(math.MaxUint64) + byteBits = 8 ) -var maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} +var ( + maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + emptyBitArray = &bitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} +) // bitArray is a structure that represents a bit array with a max length of 255 bits. // It uses a little endian representation to do bitwise operations of the words efficiently. @@ -60,7 +66,24 @@ func (b *bitArray) Bytes() [32]byte { return res } -func (b *bitArray) PrefixEqual(x *bitArray) bool { +// EqualMSBs checks if two bit arrays share the same most significant bits, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *bitArray) EqualMSBs(x *bitArray) bool { if b.len == x.len { return b.Equal(x) } @@ -70,8 +93,8 @@ func (b *bitArray) PrefixEqual(x *bitArray) bool { } var long, short *bitArray - long, short = b, x + long, short = b, x if b.len < x.len { long, short = x, b } @@ -123,6 +146,66 @@ func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { return b } +// CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. +// For example: +// +// x = 1101 0111 (len=8) +// y = 1101 0000 (len=8) +// CommonMSBs(x,y) = 1101 (len=4) +func (b *bitArray) CommonMSBs(x, y *bitArray) *bitArray { + if x.len == 0 || y.len == 0 { + return emptyBitArray + } + + long, short := x, y + if x.len < y.len { + long, short = y, x + } + + // Align arrays by right-shifting longer array and then XOR to find differences + // Example: + // short = 1101 (len=4) + // long = 1101 0111 (len=8) + // + // Step 1: Right shift longer array by 4 + // short = 1100 + // long = 1101 + // + // Step 2: XOR shows difference at last bit + // 1100 (short) + // 1101 (aligned long) + // ---- XOR + // 0001 (difference at last position) + // We can then use the position of the first set bit and right-shift to get the common MSBs + diff := long.len - short.len + b.Rsh(long, diff).Xor(b, short) + divergentBit := findFirstSetBit(b) + + return b.Rsh(short, divergentBit) +} + +// findFirstSetBit returns the position of the first '1' bit in the array, +// scanning from most significant to least significant bit. +// +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 2 // third bit from right is set +func findFirstSetBit(b *bitArray) uint8 { + if b.len == 0 { + return 0 + } + + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + return 0 +} + // Rsh sets b = x >> n and returns b. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { if x.len == 0 { @@ -167,6 +250,15 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { return b } +// Xor sets b = x ^ y and returns b. +func (b *bitArray) Xor(x, y *bitArray) *bitArray { + b.words[0] = x.words[0] ^ y.words[0] + b.words[1] = x.words[1] ^ y.words[1] + b.words[2] = x.words[2] ^ y.words[2] + b.words[3] = x.words[3] ^ y.words[3] + return b +} + // Eq checks if two bit arrays are equal func (b *bitArray) Equal(x *bitArray) bool { return b.len == x.len && @@ -176,6 +268,21 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } +// Write serializes the bitArray into a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// bitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { + if err := buf.WriteByte(b.len); err != nil { + return 0, err + } + + n, err := buf.Write(b.activeBytes()) + return n + 1, err +} + func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { b.setFelt(f) b.len = length @@ -205,6 +312,25 @@ func (b *bitArray) Set(x *bitArray) *bitArray { return b } +// byteCount returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *bitArray) byteCount() uint8 { + // Cast to uint16 to avoid overflow + return uint8((uint16(b.len) + uint16(byteBits-1)) / uint16(byteBits)) +} + +// activeBytes returns a slice containing only the bytes that are actually used +// by the bit array, excluding leading zero bytes. The returned slice is in +// big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *bitArray) activeBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + func (b *bitArray) rsh64(x *bitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 0e031ff57..825277e6e 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "math/bits" "testing" + + "github.com/stretchr/testify/assert" ) func TestBytes(t *testing.T) { @@ -172,6 +174,18 @@ func TestRsh(t *testing.T) { words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, + { + name: "shift by 127", + initial: &bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + }, + shiftBy: 127, + expected: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, { name: "shift by 128", initial: &bitArray{ @@ -315,11 +329,11 @@ func TestPrefixEqual(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.a.PrefixEqual(tt.b); got != tt.want { + if got := tt.a.EqualMSBs(tt.b); got != tt.want { t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) } // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) - if got := tt.b.PrefixEqual(tt.a); got != tt.want { + if got := tt.b.EqualMSBs(tt.a); got != tt.want { t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) } }) @@ -488,3 +502,299 @@ func TestTruncate(t *testing.T) { }) } } + +func TestWrite(t *testing.T) { + tests := []struct { + name string + bitArray bitArray + want []byte // Expected bytes after writing + }{ + { + name: "empty bit array", + bitArray: bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: []byte{0}, // Just the length byte + }, + { + name: "8 bits", + bitArray: bitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + want: []byte{8, 0xFF}, // length byte + 1 data byte + }, + { + name: "10 bits requiring 2 bytes", + bitArray: bitArray{ + len: 10, + words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary + }, + want: []byte{10, 0x3, 0xFF}, // length byte + 2 data bytes + }, + { + name: "64 bits", + bitArray: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: append( + []byte{64}, // length byte + []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}..., // 8 data bytes + ), + }, + { + name: "251 bits", + bitArray: bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + want: func() []byte { + b := make([]byte, 33) // 1 length byte + 32 data bytes + b[0] = 251 // length byte + // First byte is 0x07 (from the most significant bits) + b[1] = 0x07 + // Rest of the bytes are 0xFF + for i := 2; i < 33; i++ { + b[i] = 0xFF + } + return b + }(), + }, + { + name: "sparse bits", + bitArray: bitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary + }, + want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := new(bytes.Buffer) + gotN, err := tt.bitArray.Write(buf) + assert.NoError(t, err) + + // Check number of bytes written + if gotN != len(tt.want) { + t.Errorf("Write() wrote %d bytes, want %d", gotN, len(tt.want)) + } + + // Check written bytes + if got := buf.Bytes(); !bytes.Equal(got, tt.want) { + t.Errorf("Write() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCommonPrefix(t *testing.T) { + tests := []struct { + name string + x *bitArray + y *bitArray + want *bitArray + }{ + { + name: "empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "one empty array", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "identical arrays - single word", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "identical arrays - multiple words", + x: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + y: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + want: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + }, + { + name: "different lengths with common prefix - first word", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different lengths with common prefix - multiple words", + x: &bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + }, + y: &bitArray{ + len: 127, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + }, + want: &bitArray{ + len: 127, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "different at first bit", + x: &bitArray{ + len: 64, + words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "different in middle of first word", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in second word", + x: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF0FFFFFFF, 0, 0}, + }, + y: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + want: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in third word", + x: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + y: &bitArray{ + len: 192, + words: [4]uint64{0, 0, 0xFFFFFFFFFFFFFF0F, 0}, + }, + want: &bitArray{ + len: 56, + words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in last word", + x: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + y: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFF0FFFFFFF}, + }, + want: &bitArray{ + len: 27, + words: [4]uint64{0x7FFFFFF}, + }, + }, + { + name: "sparse bits with common prefix", + x: &bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}, + }, + y: &bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}, + }, + want: &bitArray{ + len: 52, + words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0}, + }, + }, + { + name: "max length difference", + x: &bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + }, + y: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(bitArray) + gotSymmetric := new(bitArray) + + got.CommonMSBs(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("CommonMSBs() = %v, want %v", got, tt.want) + } + + // Test symmetry: x.CommonMSBs(y) should equal y.CommonMSBs(x) + gotSymmetric.CommonMSBs(tt.y, tt.x) + if !gotSymmetric.Equal(tt.want) { + t.Errorf("CommonMSBs() symmetric test = %v, want %v", gotSymmetric, tt.want) + } + }) + } +} From 31b6884575dc768eadbf3fed3e76e10c5440767f Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 17:12:54 +0800 Subject: [PATCH 09/45] minor comments --- core/trie/bitarray.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index c542c1e02..a229c974b 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -102,8 +102,8 @@ func (b *bitArray) EqualMSBs(x *bitArray) bool { return long.Rsh(long, long.len-short.len).Equal(short) } -// Truncate sets b to the first 'length' bits of x and returns b. -// If length >= x.len, b is a copy of x. +// Truncate sets b to the first 'length' bits of x (starting from the least significant bit). +// If length >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: // From 93ade0cfe05c7e860bed414b88767ee68d82912c Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 19:05:56 +0800 Subject: [PATCH 10/45] add UnmarshalBinary --- core/trie/bitarray.go | 25 ++++++++++++++++++++++++- core/trie/bitarray_test.go | 13 +++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index a229c974b..0cca0ccf7 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -268,7 +268,7 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } -// Write serializes the bitArray into a bytes buffer in the following format: +// Write serialises the bitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -283,6 +283,21 @@ func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } +// UnmarshalBinary deserialises the bitArray from a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// [0x0A, 0x03, 0xFF] -> bitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *bitArray) UnmarshalBinary(data []byte) error { + b.len = data[0] + + var bs [32]byte + copy(bs[32-b.byteCount():], data[1:]) + b.SetBytes32(bs) + return nil +} + func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { b.setFelt(f) b.len = length @@ -295,6 +310,14 @@ func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { return b } +func (b *bitArray) SetBytes32(data [32]byte) *bitArray { + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) + return b +} + func (b *bitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 825277e6e..649002c1e 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -503,7 +503,7 @@ func TestTruncate(t *testing.T) { } } -func TestWrite(t *testing.T) { +func TestWriteAndUnmarshalBinary(t *testing.T) { tests := []struct { name string bitArray bitArray @@ -584,9 +584,18 @@ func TestWrite(t *testing.T) { } // Check written bytes - if got := buf.Bytes(); !bytes.Equal(got, tt.want) { + got := buf.Bytes() + if !bytes.Equal(got, tt.want) { t.Errorf("Write() = %v, want %v", got, tt.want) } + + gotBitArray := new(bitArray) + if err := gotBitArray.UnmarshalBinary(got); err != nil { + t.Errorf("UnmarshalBinary() = %v", err) + } + if !gotBitArray.Equal(&tt.bitArray) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.bitArray) + } }) } } From 987ad47b3e9cf9f4230efafd0af357e7e7dfd653 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 19:13:02 +0800 Subject: [PATCH 11/45] more bitArray public --- core/trie/bitarray.go | 67 +++++---- core/trie/bitarray_test.go | 278 ++++++++++++++++++------------------- 2 files changed, 177 insertions(+), 168 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 0cca0ccf7..57f3d2807 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -16,22 +16,31 @@ const ( var ( maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = &bitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} + emptyBitArray = &BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} ) -// bitArray is a structure that represents a bit array with a max length of 255 bits. +// BitArray is a structure that represents a bit array with a max length of 255 bits. // It uses a little endian representation to do bitwise operations of the words efficiently. // Unlike normal bit arrays, it has a `len` field that represents the number of used bits. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. // The reason why 255 bits is the max length is because we only need up to 251 bits for a given trie key. // Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. -type bitArray struct { +type BitArray struct { len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) } +func (b *BitArray) Felt() *felt.Felt { + bs := b.Bytes() + return new(felt.Felt).SetBytes(bs[:]) +} + +func (b *BitArray) Len() uint8 { + return b.len +} + // Bytes returns the bytes representation of the bit array in big endian format -func (b *bitArray) Bytes() [32]byte { +func (b *BitArray) Bytes() [32]byte { var res [32]byte switch { @@ -83,7 +92,7 @@ func (b *bitArray) Bytes() [32]byte { // a = 1100 (len=4) // b = [] (len=0) // a.EqualMSBs(b) = true // Zero length is always a prefix match -func (b *bitArray) EqualMSBs(x *bitArray) bool { +func (b *BitArray) EqualMSBs(x *BitArray) bool { if b.len == x.len { return b.Equal(x) } @@ -92,7 +101,7 @@ func (b *bitArray) EqualMSBs(x *bitArray) bool { return true } - var long, short *bitArray + var long, short *BitArray long, short = b, x if b.len < x.len { @@ -111,7 +120,7 @@ func (b *bitArray) EqualMSBs(x *bitArray) bool { // Truncate(x, 4) = 1011 (len=4) // Truncate(x, 10) = 11001011 (len=8, original x) // Truncate(x, 0) = 0 (len=0) -func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { +func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { if length >= x.len { return b.Set(x) } @@ -152,7 +161,7 @@ func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { // x = 1101 0111 (len=8) // y = 1101 0000 (len=8) // CommonMSBs(x,y) = 1101 (len=4) -func (b *bitArray) CommonMSBs(x, y *bitArray) *bitArray { +func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { if x.len == 0 || y.len == 0 { return emptyBitArray } @@ -192,7 +201,7 @@ func (b *bitArray) CommonMSBs(x, y *bitArray) *bitArray { // // array = 0000 0000 ... 0100 (len=251) // findFirstSetBit() = 2 // third bit from right is set -func findFirstSetBit(b *bitArray) uint8 { +func findFirstSetBit(b *BitArray) uint8 { if b.len == 0 { return 0 } @@ -207,7 +216,7 @@ func findFirstSetBit(b *bitArray) uint8 { } // Rsh sets b = x >> n and returns b. -func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { +func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { if x.len == 0 { return b.Set(x) } @@ -251,7 +260,7 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { } // Xor sets b = x ^ y and returns b. -func (b *bitArray) Xor(x, y *bitArray) *bitArray { +func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] b.words[2] = x.words[2] ^ y.words[2] @@ -260,7 +269,7 @@ func (b *bitArray) Xor(x, y *bitArray) *bitArray { } // Eq checks if two bit arrays are equal -func (b *bitArray) Equal(x *bitArray) bool { +func (b *BitArray) Equal(x *BitArray) bool { return b.len == x.len && b.words[0] == x.words[0] && b.words[1] == x.words[1] && @@ -268,13 +277,13 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } -// Write serialises the bitArray into a bytes buffer in the following format: +// Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: // -// bitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] -func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { +// BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { if err := buf.WriteByte(b.len); err != nil { return 0, err } @@ -283,13 +292,13 @@ func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } -// UnmarshalBinary deserialises the bitArray from a bytes buffer in the following format: +// UnmarshalBinary deserialises the BitArray from a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: // -// [0x0A, 0x03, 0xFF] -> bitArray{len: 10, words: [4]uint64{0x03FF}} -func (b *bitArray) UnmarshalBinary(data []byte) error { +// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *BitArray) UnmarshalBinary(data []byte) error { b.len = data[0] var bs [32]byte @@ -298,19 +307,19 @@ func (b *bitArray) UnmarshalBinary(data []byte) error { return nil } -func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { +func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } -func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { +func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } -func (b *bitArray) SetBytes32(data [32]byte) *bitArray { +func (b *BitArray) SetBytes32(data [32]byte) *BitArray { b.words[3] = binary.BigEndian.Uint64(data[0:8]) b.words[2] = binary.BigEndian.Uint64(data[8:16]) b.words[1] = binary.BigEndian.Uint64(data[16:24]) @@ -318,7 +327,7 @@ func (b *bitArray) SetBytes32(data [32]byte) *bitArray { return b } -func (b *bitArray) setFelt(f *felt.Felt) { +func (b *BitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) b.words[2] = binary.BigEndian.Uint64(res[8:16]) @@ -326,7 +335,7 @@ func (b *bitArray) setFelt(f *felt.Felt) { b.words[0] = binary.BigEndian.Uint64(res[24:32]) } -func (b *bitArray) Set(x *bitArray) *bitArray { +func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] b.words[1] = x.words[1] @@ -337,7 +346,7 @@ func (b *bitArray) Set(x *bitArray) *bitArray { // byteCount returns the minimum number of bytes needed to represent the bit array. // It rounds up to the nearest byte. -func (b *bitArray) byteCount() uint8 { +func (b *BitArray) byteCount() uint8 { // Cast to uint16 to avoid overflow return uint8((uint16(b.len) + uint16(byteBits-1)) / uint16(byteBits)) } @@ -349,24 +358,24 @@ func (b *bitArray) byteCount() uint8 { // Example: // // len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] -func (b *bitArray) activeBytes() []byte { +func (b *BitArray) activeBytes() []byte { wordsBytes := b.Bytes() return wordsBytes[32-b.byteCount():] } -func (b *bitArray) rsh64(x *bitArray) { +func (b *BitArray) rsh64(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] } -func (b *bitArray) rsh128(x *bitArray) { +func (b *BitArray) rsh128(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] } -func (b *bitArray) rsh192(x *bitArray) { +func (b *BitArray) rsh192(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] } -func (b *bitArray) clear() *bitArray { +func (b *BitArray) clear() *BitArray { b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 649002c1e..2c8265e09 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -11,18 +11,18 @@ import ( func TestBytes(t *testing.T) { tests := []struct { - name string - bitArray bitArray - want [32]byte + name string + ba BitArray + want [32]byte }{ { - name: "length == 0", - bitArray: bitArray{len: 0, words: maxBitArray}, - want: [32]byte{}, + name: "length == 0", + ba: BitArray{len: 0, words: maxBitArray}, + want: [32]byte{}, }, { - name: "length < 64", - bitArray: bitArray{len: 38, words: maxBitArray}, + name: "length < 64", + ba: BitArray{len: 38, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) @@ -30,8 +30,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "64 <= length < 128", - bitArray: bitArray{len: 100, words: maxBitArray}, + name: "64 <= length < 128", + ba: BitArray{len: 100, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) @@ -40,8 +40,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "128 <= length < 192", - bitArray: bitArray{len: 130, words: maxBitArray}, + name: "128 <= length < 192", + ba: BitArray{len: 130, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[8:16], 0x3) @@ -51,8 +51,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "192 <= length < 255", - bitArray: bitArray{len: 201, words: maxBitArray}, + name: "192 <= length < 255", + ba: BitArray{len: 201, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x1FF) @@ -63,8 +63,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "length == 254", - bitArray: bitArray{len: 254, words: maxBitArray}, + name: "length == 254", + ba: BitArray{len: 254, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) @@ -75,8 +75,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "length == 255", - bitArray: bitArray{len: 255, words: maxBitArray}, + name: "length == 255", + ba: BitArray{len: 255, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) @@ -90,18 +90,18 @@ func TestBytes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.bitArray.Bytes() + got := tt.ba.Bytes() if !bytes.Equal(got[:], tt.want[:]) { - t.Errorf("bitArray.Bytes() = %v, want %v", got, tt.want) + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) } - // check if the received bytes has the same bit count as the bitArray.len + // check if the received bytes has the same bit count as the BitArray.len count := 0 for _, b := range got { count += bits.OnesCount8(b) } - if count != int(tt.bitArray.len) { - t.Errorf("bitArray.Bytes() bit count = %v, want %v", count, tt.bitArray.len) + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = %v, want %v", count, tt.ba.len) } }) } @@ -110,102 +110,102 @@ func TestBytes(t *testing.T) { func TestRsh(t *testing.T) { tests := []struct { name string - initial *bitArray + initial *BitArray shiftBy uint8 - expected *bitArray + expected *BitArray }{ { name: "zero length array", - initial: &bitArray{ + initial: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, shiftBy: 5, - expected: &bitArray{ + expected: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "shift by 0", - initial: &bitArray{ + initial: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, shiftBy: 0, - expected: &bitArray{ + expected: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "shift by more than length", - initial: &bitArray{ + initial: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, shiftBy: 65, - expected: &bitArray{ + expected: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "shift by less than 64", - initial: &bitArray{ + initial: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, shiftBy: 32, - expected: &bitArray{ + expected: &BitArray{ len: 96, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, }, }, { name: "shift by exactly 64", - initial: &bitArray{ + initial: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, shiftBy: 64, - expected: &bitArray{ + expected: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "shift by 127", - initial: &bitArray{ + initial: &BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, }, shiftBy: 127, - expected: &bitArray{ + expected: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "shift by 128", - initial: &bitArray{ + initial: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, shiftBy: 128, - expected: &bitArray{ + expected: &BitArray{ len: 123, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "shift by 192", - initial: &bitArray{ + initial: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, shiftBy: 192, - expected: &bitArray{ + expected: &BitArray{ len: 59, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -214,7 +214,7 @@ func TestRsh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(bitArray).Rsh(tt.initial, tt.shiftBy) + result := new(BitArray).Rsh(tt.initial, tt.shiftBy) if !result.Equal(tt.expected) { t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) } @@ -225,17 +225,17 @@ func TestRsh(t *testing.T) { func TestPrefixEqual(t *testing.T) { tests := []struct { name string - a *bitArray - b *bitArray + a *BitArray + b *BitArray want bool }{ { name: "equal lengths, equal values", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -243,11 +243,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "equal lengths, different values", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, }, @@ -255,11 +255,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "different lengths, a longer but same prefix", - a: &bitArray{ + a: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -267,11 +267,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "different lengths, b longer but same prefix", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, @@ -279,11 +279,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "different lengths, different prefix", - a: &bitArray{ + a: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, }, @@ -291,11 +291,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "zero length arrays", - a: &bitArray{ + a: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, @@ -303,11 +303,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "one zero length array", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, @@ -315,11 +315,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "max length difference", - a: &bitArray{ + a: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, - b: &bitArray{ + b: &BitArray{ len: 1, words: [4]uint64{0x1, 0, 0, 0}, }, @@ -343,150 +343,150 @@ func TestPrefixEqual(t *testing.T) { func TestTruncate(t *testing.T) { tests := []struct { name string - initial bitArray + initial BitArray length uint8 - expected bitArray + expected BitArray }{ { name: "truncate to zero", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 0, - expected: bitArray{ + expected: BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "truncate within first word - 32 bits", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 32, - expected: bitArray{ + expected: BitArray{ len: 32, words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, }, }, { name: "truncate to single bit", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 1, - expected: bitArray{ + expected: BitArray{ len: 1, words: [4]uint64{0x0000000000000001, 0, 0, 0}, }, }, { name: "truncate across words - 100 bits", - initial: bitArray{ + initial: BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, length: 100, - expected: bitArray{ + expected: BitArray{ len: 100, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, }, }, { name: "truncate at word boundary - 64 bits", - initial: bitArray{ + initial: BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, length: 64, - expected: bitArray{ + expected: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "truncate at word boundary - 128 bits", - initial: bitArray{ + initial: BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, length: 128, - expected: bitArray{ + expected: BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "truncate in third word - 150 bits", - initial: bitArray{ + initial: BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, length: 150, - expected: bitArray{ + expected: BitArray{ len: 150, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, }, }, { name: "truncate in fourth word - 220 bits", - initial: bitArray{ + initial: BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, length: 220, - expected: bitArray{ + expected: BitArray{ len: 220, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, }, }, { name: "truncate max length - 251 bits", - initial: bitArray{ + initial: BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, length: 251, - expected: bitArray{ + expected: BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, }, { name: "truncate sparse bits", - initial: bitArray{ + initial: BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, }, length: 100, - expected: bitArray{ + expected: BitArray{ len: 100, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, }, }, { name: "no change when new length equals current length", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 64, - expected: bitArray{ + expected: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "no change when new length greater than current length", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 128, - expected: bitArray{ + expected: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -495,7 +495,7 @@ func TestTruncate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(bitArray).Truncate(&tt.initial, tt.length) + result := new(BitArray).Truncate(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -505,13 +505,13 @@ func TestTruncate(t *testing.T) { func TestWriteAndUnmarshalBinary(t *testing.T) { tests := []struct { - name string - bitArray bitArray - want []byte // Expected bytes after writing + name string + ba BitArray + want []byte // Expected bytes after writing }{ { name: "empty bit array", - bitArray: bitArray{ + ba: BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, @@ -519,7 +519,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "8 bits", - bitArray: bitArray{ + ba: BitArray{ len: 8, words: [4]uint64{0xFF, 0, 0, 0}, }, @@ -527,7 +527,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "10 bits requiring 2 bytes", - bitArray: bitArray{ + ba: BitArray{ len: 10, words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary }, @@ -535,7 +535,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "64 bits", - bitArray: bitArray{ + ba: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -546,7 +546,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "251 bits", - bitArray: bitArray{ + ba: BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, @@ -564,7 +564,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "sparse bits", - bitArray: bitArray{ + ba: BitArray{ len: 16, words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary }, @@ -575,7 +575,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { buf := new(bytes.Buffer) - gotN, err := tt.bitArray.Write(buf) + gotN, err := tt.ba.Write(buf) assert.NoError(t, err) // Check number of bytes written @@ -589,12 +589,12 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { t.Errorf("Write() = %v, want %v", got, tt.want) } - gotBitArray := new(bitArray) + gotBitArray := new(BitArray) if err := gotBitArray.UnmarshalBinary(got); err != nil { t.Errorf("UnmarshalBinary() = %v", err) } - if !gotBitArray.Equal(&tt.bitArray) { - t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.bitArray) + if !gotBitArray.Equal(&tt.ba) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) } }) } @@ -603,9 +603,9 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { func TestCommonPrefix(t *testing.T) { tests := []struct { name string - x *bitArray - y *bitArray - want *bitArray + x *BitArray + y *BitArray + want *BitArray }{ { name: "empty arrays", @@ -615,7 +615,7 @@ func TestCommonPrefix(t *testing.T) { }, { name: "one empty array", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -624,165 +624,165 @@ func TestCommonPrefix(t *testing.T) { }, { name: "identical arrays - single word", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "identical arrays - multiple words", - x: &bitArray{ + x: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, }, { name: "different lengths with common prefix - first word", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, }, { name: "different lengths with common prefix - multiple words", - x: &bitArray{ + x: &BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, }, - y: &bitArray{ + y: &BitArray{ len: 127, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 127, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "different at first bit", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "different in middle of first word", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, }, { name: "different in second word", - x: &bitArray{ + x: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF0FFFFFFF, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, }, { name: "different in third word", - x: &bitArray{ + x: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 192, words: [4]uint64{0, 0, 0xFFFFFFFFFFFFFF0F, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 56, words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "different in last word", - x: &bitArray{ + x: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, - y: &bitArray{ + y: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFF0FFFFFFF}, }, - want: &bitArray{ + want: &BitArray{ len: 27, words: [4]uint64{0x7FFFFFF}, }, }, { name: "sparse bits with common prefix", - x: &bitArray{ + x: &BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 52, words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0}, }, }, { name: "max length difference", - x: &bitArray{ + x: &BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, }, - y: &bitArray{ + y: &BitArray{ len: 1, words: [4]uint64{0x1, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 1, words: [4]uint64{0x1, 0, 0, 0}, }, @@ -791,8 +791,8 @@ func TestCommonPrefix(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := new(bitArray) - gotSymmetric := new(bitArray) + got := new(BitArray) + gotSymmetric := new(BitArray) got.CommonMSBs(tt.x, tt.y) if !got.Equal(tt.want) { From 031aa35b3804822e4ee37542e6025b6b9d05f1cc Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 22:20:09 +0800 Subject: [PATCH 12/45] add IsBitSet --- core/trie/bitarray.go | 10 ++++ core/trie/bitarray_test.go | 109 +++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 57f3d2807..43e5800d5 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -277,6 +277,16 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } +// IsBitSit returns true if bit n-th is set, where n = 0 is LSB. +// The n must be <= 255. +func (b *BitArray) IsBitSet(n uint8) bool { + if n >= b.len { + return false + } + + return (b.words[n/64] & (1 << (n % 64))) != 0 +} + // Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 2c8265e09..229479bdb 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -807,3 +807,112 @@ func TestCommonPrefix(t *testing.T) { }) } } + +func TestIsBitSet(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit set", + ba: BitArray{ + len: 64, + words: [4]uint64{1, 0, 0, 0}, + }, + pos: 0, + want: true, + }, + { + name: "last bit in first word", + ba: BitArray{ + len: 64, + words: [4]uint64{1 << 63, 0, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "first bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 64, + want: true, + }, + { + name: "bit beyond length", + ba: BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + pos: 65, + want: false, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 1, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 0, + want: false, + }, + { + name: "bit in last word", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 59}, + }, + pos: 251, + want: false, // position 251 is beyond the highest valid bit (250) + }, + { + name: "highest valid bit (255)", + ba: BitArray{ + len: 255, + words: [4]uint64{0, 0, 0, 1 << 62}, // bit 255 set + }, + pos: 254, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + pos: 100, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSet(tt.pos) + if got != tt.want { + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} From c23817f011a0b9385186f2973e5919e6b5c81e58 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 23:13:54 +0800 Subject: [PATCH 13/45] fix lint and comments --- core/trie/bitarray.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 43e5800d5..db6c702ac 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,20 +11,19 @@ import ( const ( maxUint64 = uint64(math.MaxUint64) - byteBits = 8 + bits8 = 8 ) var ( maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = &BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} + emptyBitArray = new(BitArray) ) -// BitArray is a structure that represents a bit array with a max length of 255 bits. +// BitArray is a structure that represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. -// Unlike normal bit arrays, it has a `len` field that represents the number of used bits. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. -// The reason why 255 bits is the max length is because we only need up to 251 bits for a given trie key. -// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. +// The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. type BitArray struct { len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) @@ -40,6 +39,8 @@ func (b *BitArray) Len() uint8 { } // Bytes returns the bytes representation of the bit array in big endian format +// +//nolint:mnd func (b *BitArray) Bytes() [32]byte { var res [32]byte @@ -120,6 +121,8 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { // Truncate(x, 4) = 1011 (len=4) // Truncate(x, 10) = 11001011 (len=8, original x) // Truncate(x, 0) = 0 (len=0) +// +//nolint:mnd func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { if length >= x.len { return b.Set(x) @@ -173,7 +176,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { // Align arrays by right-shifting longer array and then XOR to find differences // Example: - // short = 1101 (len=4) + // short = 1100 (len=4) // long = 1101 0111 (len=8) // // Step 1: Right shift longer array by 4 @@ -216,6 +219,8 @@ func findFirstSetBit(b *BitArray) uint8 { } // Rsh sets b = x >> n and returns b. +// +//nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { if x.len == 0 { return b.Set(x) @@ -358,7 +363,7 @@ func (b *BitArray) Set(x *BitArray) *BitArray { // It rounds up to the nearest byte. func (b *BitArray) byteCount() uint8 { // Cast to uint16 to avoid overflow - return uint8((uint16(b.len) + uint16(byteBits-1)) / uint16(byteBits)) + return uint8((uint16(b.len) + (bits8 - 1)) / uint16(bits8)) } // activeBytes returns a slice containing only the bytes that are actually used From df6a513fcca1dc8d1bafd2112ea96078f63b68e4 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 16 Dec 2024 22:09:50 +0800 Subject: [PATCH 14/45] Felt return value --- core/trie/bitarray.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index db6c702ac..c371bee99 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -29,9 +29,12 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func (b *BitArray) Felt() *felt.Felt { +func (b *BitArray) Felt() felt.Felt { bs := b.Bytes() - return new(felt.Felt).SetBytes(bs[:]) + + var f felt.Felt + f.SetBytes(bs[:]) + return f } func (b *BitArray) Len() uint8 { From ec1703c7a8853db7827d695690ed39fbaa0cf4ee Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 14:53:42 +0800 Subject: [PATCH 15/45] add felt tests --- core/trie/bitarray.go | 99 ++++++++++++++----------- core/trie/bitarray_test.go | 144 +++++++++++++++++++++++++++++++++++-- 2 files changed, 195 insertions(+), 48 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index c371bee99..03aa84f70 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -29,11 +29,13 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func (b *BitArray) Felt() felt.Felt { - bs := b.Bytes() +func NewBitArray(val uint64) *BitArray { + return new(BitArray).SetUint64(val) +} +func (b *BitArray) Felt() felt.Felt { var f felt.Felt - f.SetBytes(bs[:]) + f.SetBytes(b.Bytes()) return f } @@ -44,13 +46,13 @@ func (b *BitArray) Len() uint8 { // Bytes returns the bytes representation of the bit array in big endian format // //nolint:mnd -func (b *BitArray) Bytes() [32]byte { +func (b *BitArray) Bytes() []byte { var res [32]byte switch { case b.len == 0: // all zeros - return res + return res[:] case b.len >= 192: // Create mask for top word: keeps only valid bits above 192 // e.g., if len=200, keeps lowest 8 bits (200-192) @@ -76,7 +78,7 @@ func (b *BitArray) Bytes() [32]byte { binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } - return res + return res[:] } // EqualMSBs checks if two bit arrays share the same most significant bits, where the length of @@ -199,28 +201,6 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// findFirstSetBit returns the position of the first '1' bit in the array, -// scanning from most significant to least significant bit. -// -// The bit position is counted from the least significant bit, starting at 0. -// For example: -// -// array = 0000 0000 ... 0100 (len=251) -// findFirstSetBit() = 2 // third bit from right is set -func findFirstSetBit(b *BitArray) uint8 { - if b.len == 0 { - return 0 - } - - for i := 3; i >= 0; i-- { - if word := b.words[i]; word != 0 { - return uint8((i+1)*64 - bits.LeadingZeros64(word)) - } - } - - return 0 -} - // Rsh sets b = x >> n and returns b. // //nolint:mnd @@ -316,13 +296,21 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { // Example: // // [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} -func (b *BitArray) UnmarshalBinary(data []byte) error { +func (b *BitArray) UnmarshalBinary(data []byte) { b.len = data[0] var bs [32]byte copy(bs[32-b.byteCount():], data[1:]) - b.SetBytes32(bs) - return nil + b.setBytes32(bs[:]) +} + +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b } func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { @@ -337,11 +325,15 @@ func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { return b } -func (b *BitArray) SetBytes32(data [32]byte) *BitArray { - b.words[3] = binary.BigEndian.Uint64(data[0:8]) - b.words[2] = binary.BigEndian.Uint64(data[8:16]) - b.words[1] = binary.BigEndian.Uint64(data[16:24]) - b.words[0] = binary.BigEndian.Uint64(data[24:32]) +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + b.setBytes32(data) + b.len = length + return b +} + +func (b *BitArray) SetUint64(data uint64) *BitArray { + b.words[0] = data + b.len = 64 return b } @@ -353,13 +345,12 @@ func (b *BitArray) setFelt(f *felt.Felt) { b.words[0] = binary.BigEndian.Uint64(res[24:32]) } -func (b *BitArray) Set(x *BitArray) *BitArray { - b.len = x.len - b.words[0] = x.words[0] - b.words[1] = x.words[1] - b.words[2] = x.words[2] - b.words[3] = x.words[3] - return b +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) } // byteCount returns the minimum number of bytes needed to represent the bit array. @@ -398,3 +389,25 @@ func (b *BitArray) clear() *BitArray { b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b } + +// findFirstSetBit returns the position of the first '1' bit in the array, +// scanning from most significant to least significant bit. +// +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 2 // third bit from right is set +func findFirstSetBit(b *BitArray) uint8 { + if b.len == 0 { + return 0 + } + + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + return 0 +} diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 229479bdb..96af163f3 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -6,7 +6,9 @@ import ( "math/bits" "testing" + "github.com/NethermindEth/juno/core/felt" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBytes(t *testing.T) { @@ -91,7 +93,7 @@ func TestBytes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.ba.Bytes() - if !bytes.Equal(got[:], tt.want[:]) { + if !bytes.Equal(got, tt.want[:]) { t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) } @@ -589,10 +591,8 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { t.Errorf("Write() = %v, want %v", got, tt.want) } - gotBitArray := new(BitArray) - if err := gotBitArray.UnmarshalBinary(got); err != nil { - t.Errorf("UnmarshalBinary() = %v", err) - } + var gotBitArray BitArray + gotBitArray.UnmarshalBinary(got) if !gotBitArray.Equal(&tt.ba) { t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) } @@ -916,3 +916,137 @@ func TestIsBitSet(t *testing.T) { }) } } + +func TestFeltConversion(t *testing.T) { + tests := []struct { + name string + ba BitArray + length uint8 + want string // hex representation of felt + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + length: 0, + want: "0x0", + }, + { + name: "single word", + ba: BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 64, + want: "0xffffffffffffffff", + }, + { + name: "two words", + ba: BitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 128, + want: "0xffffffffffffffffffffffffffffffff", + }, + { + name: "three words", + ba: BitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 192, + want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "max length (251 bits)", + ba: BitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + length: 255, + want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "sparse bits", + ba: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 128, + want: "0x5555555555555555aaaaaaaaaaaaaaaa", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Felt() conversion + gotFelt := tt.ba.Felt() + assert.Equal(t, tt.want, gotFelt.String()) + + // Test SetFelt() conversion (round trip) + var newBA BitArray + newBA.SetFelt(tt.length, &gotFelt) + assert.Equal(t, tt.ba.len, newBA.len) + assert.Equal(t, tt.ba.words, newBA.words) + }) + } +} + +func TestSetFeltValidation(t *testing.T) { + tests := []struct { + name string + feltStr string + length uint8 + shouldMatch bool + }{ + { + name: "valid felt with matching length", + feltStr: "0xf", + length: 4, + shouldMatch: true, + }, + { + name: "felt larger than specified length", + feltStr: "0xff", + length: 4, + shouldMatch: false, + }, + { + name: "zero felt with non-zero length", + feltStr: "0x0", + length: 8, + shouldMatch: true, + }, + { + name: "max felt with max length", + feltStr: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + length: 251, + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f felt.Felt + _, err := f.SetString(tt.feltStr) + require.NoError(t, err) + + var ba BitArray + ba.SetFelt(tt.length, &f) + + // Convert back to felt and compare + roundTrip := ba.Felt() + if tt.shouldMatch { + assert.True(t, roundTrip.Equal(&f), + "expected %s, got %s", f.String(), roundTrip.String()) + } else { + assert.False(t, roundTrip.Equal(&f), + "values should not match: original %s, roundtrip %s", + f.String(), roundTrip.String()) + } + }) + } +} From c00326c26baccdf64df01f368b301e09ab70f7f5 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 15:11:01 +0800 Subject: [PATCH 16/45] use const for 0xFF...FF --- core/trie/bitarray.go | 2 +- core/trie/bitarray_test.go | 162 ++++++++++++++++++------------------- 2 files changed, 82 insertions(+), 82 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 03aa84f70..7b377ddf2 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -10,7 +10,7 @@ import ( ) const ( - maxUint64 = uint64(math.MaxUint64) + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF bits8 = 8 ) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 96af163f3..abfccc7a1 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -37,7 +37,7 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -47,8 +47,8 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[8:16], 0x3) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -58,9 +58,9 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x1FF) - binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -70,9 +70,9 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -82,9 +82,9 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -132,19 +132,19 @@ func TestRsh(t *testing.T) { name: "shift by 0", initial: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, shiftBy: 0, expected: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "shift by more than length", initial: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, shiftBy: 65, expected: &BitArray{ @@ -156,60 +156,60 @@ func TestRsh(t *testing.T) { name: "shift by less than 64", initial: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, shiftBy: 32, expected: &BitArray{ len: 96, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, }, }, { name: "shift by exactly 64", initial: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, shiftBy: 64, expected: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "shift by 127", initial: &BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, }, shiftBy: 127, expected: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, }, { name: "shift by 128", initial: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, shiftBy: 128, expected: &BitArray{ len: 123, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, }, { name: "shift by 192", initial: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, shiftBy: 192, expected: &BitArray{ len: 59, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, } @@ -235,11 +235,11 @@ func TestPrefixEqual(t *testing.T) { name: "equal lengths, equal values", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: true, }, @@ -247,7 +247,7 @@ func TestPrefixEqual(t *testing.T) { name: "equal lengths, different values", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 64, @@ -259,11 +259,11 @@ func TestPrefixEqual(t *testing.T) { name: "different lengths, a longer but same prefix", a: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, b: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: true, }, @@ -271,11 +271,11 @@ func TestPrefixEqual(t *testing.T) { name: "different lengths, b longer but same prefix", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, want: true, }, @@ -283,7 +283,7 @@ func TestPrefixEqual(t *testing.T) { name: "different lengths, different prefix", a: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, b: &BitArray{ len: 64, @@ -307,7 +307,7 @@ func TestPrefixEqual(t *testing.T) { name: "one zero length array", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 0, @@ -319,7 +319,7 @@ func TestPrefixEqual(t *testing.T) { name: "max length difference", a: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, b: &BitArray{ len: 1, @@ -353,7 +353,7 @@ func TestTruncate(t *testing.T) { name: "truncate to zero", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 0, expected: BitArray{ @@ -365,7 +365,7 @@ func TestTruncate(t *testing.T) { name: "truncate within first word - 32 bits", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 32, expected: BitArray{ @@ -377,7 +377,7 @@ func TestTruncate(t *testing.T) { name: "truncate to single bit", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 1, expected: BitArray{ @@ -389,72 +389,72 @@ func TestTruncate(t *testing.T) { name: "truncate across words - 100 bits", initial: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, length: 100, expected: BitArray{ len: 100, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, }, }, { name: "truncate at word boundary - 64 bits", initial: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, length: 64, expected: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "truncate at word boundary - 128 bits", initial: BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, length: 128, expected: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, }, { name: "truncate in third word - 150 bits", initial: BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, length: 150, expected: BitArray{ len: 150, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, }, }, { name: "truncate in fourth word - 220 bits", initial: BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, length: 220, expected: BitArray{ len: 220, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, }, }, { name: "truncate max length - 251 bits", initial: BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, length: 251, expected: BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, }, { @@ -473,24 +473,24 @@ func TestTruncate(t *testing.T) { name: "no change when new length equals current length", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 64, expected: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "no change when new length greater than current length", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 128, expected: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, } @@ -539,7 +539,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { name: "64 bits", ba: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: append( []byte{64}, // length byte @@ -550,7 +550,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { name: "251 bits", ba: BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, want: func() []byte { b := make([]byte, 33) // 1 length byte + 32 data bytes @@ -617,7 +617,7 @@ func TestCommonPrefix(t *testing.T) { name: "one empty array", x: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, y: emptyBitArray, want: emptyBitArray, @@ -626,37 +626,37 @@ func TestCommonPrefix(t *testing.T) { name: "identical arrays - single word", x: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, y: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "identical arrays - multiple words", x: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, y: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, want: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, }, { name: "different lengths with common prefix - first word", x: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, y: &BitArray{ len: 32, @@ -671,15 +671,15 @@ func TestCommonPrefix(t *testing.T) { name: "different lengths with common prefix - multiple words", x: &BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, }, y: &BitArray{ len: 127, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, want: &BitArray{ len: 127, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, }, { @@ -690,7 +690,7 @@ func TestCommonPrefix(t *testing.T) { }, y: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: &BitArray{ len: 0, @@ -705,7 +705,7 @@ func TestCommonPrefix(t *testing.T) { }, y: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: &BitArray{ len: 32, @@ -716,11 +716,11 @@ func TestCommonPrefix(t *testing.T) { name: "different in second word", x: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF0FFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0xFFFFFFFF0FFFFFFF, 0, 0}, }, y: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, want: &BitArray{ len: 32, @@ -731,7 +731,7 @@ func TestCommonPrefix(t *testing.T) { name: "different in third word", x: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, y: &BitArray{ len: 192, @@ -746,11 +746,11 @@ func TestCommonPrefix(t *testing.T) { name: "different in last word", x: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, y: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFF0FFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFF0FFFFFFF}, }, want: &BitArray{ len: 27, @@ -776,7 +776,7 @@ func TestCommonPrefix(t *testing.T) { name: "max length difference", x: &BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, }, y: &BitArray{ len: 1, @@ -855,7 +855,7 @@ func TestIsBitSet(t *testing.T) { name: "bit beyond length", ba: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, pos: 65, want: false, @@ -900,7 +900,7 @@ func TestIsBitSet(t *testing.T) { name: "position at length boundary", ba: BitArray{ len: 100, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, pos: 100, want: false, @@ -937,7 +937,7 @@ func TestFeltConversion(t *testing.T) { name: "single word", ba: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 64, want: "0xffffffffffffffff", @@ -946,7 +946,7 @@ func TestFeltConversion(t *testing.T) { name: "two words", ba: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, length: 128, want: "0xffffffffffffffffffffffffffffffff", @@ -955,7 +955,7 @@ func TestFeltConversion(t *testing.T) { name: "three words", ba: BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, length: 192, want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", @@ -964,7 +964,7 @@ func TestFeltConversion(t *testing.T) { name: "max length (251 bits)", ba: BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, length: 255, want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", From 1a23f0dead401c53948ea201853ae4972fcc7f7d Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 16:19:27 +0800 Subject: [PATCH 17/45] add MSBs --- core/trie/bitarray.go | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 7b377ddf2..bbcf15f41 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -3,6 +3,8 @@ package trie import ( "bytes" "encoding/binary" + "encoding/hex" + "fmt" "math" "math/bits" @@ -29,8 +31,8 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func NewBitArray(val uint64) *BitArray { - return new(BitArray).SetUint64(val) +func NewBitArray(length uint8, val uint64) *BitArray { + return new(BitArray).SetUint64(length, val) } func (b *BitArray) Felt() felt.Felt { @@ -275,6 +277,14 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + // Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order @@ -331,12 +341,26 @@ func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { return b } -func (b *BitArray) SetUint64(data uint64) *BitArray { +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data - b.len = 64 + b.len = length return b } +func (b *BitArray) EncodedLen() uint { + return b.byteCount() + 1 +} + +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +func (b *BitArray) String() string { + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) +} + func (b *BitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) @@ -355,9 +379,9 @@ func (b *BitArray) setBytes32(data []byte) { // byteCount returns the minimum number of bytes needed to represent the bit array. // It rounds up to the nearest byte. -func (b *BitArray) byteCount() uint8 { +func (b *BitArray) byteCount() uint { // Cast to uint16 to avoid overflow - return uint8((uint16(b.len) + (bits8 - 1)) / uint16(bits8)) + return (uint(b.len) + (bits8 - 1)) / uint(bits8) } // activeBytes returns a slice containing only the bytes that are actually used From fad615d79ef0f17235511a5235ad2870d45020e7 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 17:06:58 +0800 Subject: [PATCH 18/45] add MSBs() and rename Truncate to LSBs --- core/trie/bitarray.go | 39 ++++++---- core/trie/bitarray_test.go | 155 ++++++++++++++++++++++++++++++++----- 2 files changed, 160 insertions(+), 34 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bbcf15f41..7f5aa11f2 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -17,7 +17,7 @@ const ( ) var ( - maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} emptyBitArray = new(BitArray) ) @@ -119,18 +119,18 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return long.Rsh(long, long.len-short.len).Equal(short) } -// Truncate sets b to the first 'length' bits of x (starting from the least significant bit). -// If length >= x.len, b is an exact copy of x. +// LSBs sets b to the least significant 'n' bits of x. +// If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: // // x = 11001011 (len=8) -// Truncate(x, 4) = 1011 (len=4) -// Truncate(x, 10) = 11001011 (len=8, original x) -// Truncate(x, 0) = 0 (len=0) +// LSBs(x, 4) = 1011 (len=4) +// LSBs(x, 10) = 11001011 (len=8, original x) +// LSBs(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { +func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { if length >= x.len { return b.Set(x) } @@ -165,6 +165,23 @@ func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { return b } +// MSBs sets b to the most significant 'n' bits of x. +// If n >= x.len, b is an exact copy of x. +// Any bits beyond the specified length are cleared to zero. +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + // CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. // For example: // @@ -277,14 +294,6 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } -func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { - if n >= x.len { - return b.Set(x) - } - - return b.Rsh(x, x.len-n) -} - // Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index abfccc7a1..c90223ab6 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -11,6 +11,10 @@ import ( "github.com/stretchr/testify/require" ) +const ( + ones63 = 0x7FFFFFFFFFFFFFFF +) + func TestBytes(t *testing.T) { tests := []struct { name string @@ -19,12 +23,12 @@ func TestBytes(t *testing.T) { }{ { name: "length == 0", - ba: BitArray{len: 0, words: maxBitArray}, + ba: BitArray{len: 0, words: maxBits}, want: [32]byte{}, }, { name: "length < 64", - ba: BitArray{len: 38, words: maxBitArray}, + ba: BitArray{len: 38, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) @@ -33,7 +37,7 @@ func TestBytes(t *testing.T) { }, { name: "64 <= length < 128", - ba: BitArray{len: 100, words: maxBitArray}, + ba: BitArray{len: 100, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) @@ -43,7 +47,7 @@ func TestBytes(t *testing.T) { }, { name: "128 <= length < 192", - ba: BitArray{len: 130, words: maxBitArray}, + ba: BitArray{len: 130, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[8:16], 0x3) @@ -54,7 +58,7 @@ func TestBytes(t *testing.T) { }, { name: "192 <= length < 255", - ba: BitArray{len: 201, words: maxBitArray}, + ba: BitArray{len: 201, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x1FF) @@ -66,7 +70,7 @@ func TestBytes(t *testing.T) { }, { name: "length == 254", - ba: BitArray{len: 254, words: maxBitArray}, + ba: BitArray{len: 254, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) @@ -78,10 +82,10 @@ func TestBytes(t *testing.T) { }, { name: "length == 255", - ba: BitArray{len: 255, words: maxBitArray}, + ba: BitArray{len: 255, words: maxBits}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[0:8], ones63) binary.BigEndian.PutUint64(b[8:16], maxUint64) binary.BigEndian.PutUint64(b[16:24], maxUint64) binary.BigEndian.PutUint64(b[24:32], maxUint64) @@ -180,7 +184,7 @@ func TestRsh(t *testing.T) { name: "shift by 127", initial: &BitArray{ len: 255, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, }, shiftBy: 127, expected: &BitArray{ @@ -342,7 +346,7 @@ func TestPrefixEqual(t *testing.T) { } } -func TestTruncate(t *testing.T) { +func TestLSBs(t *testing.T) { tests := []struct { name string initial BitArray @@ -497,7 +501,7 @@ func TestTruncate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(BitArray).Truncate(&tt.initial, tt.length) + result := new(BitArray).LSBs(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -505,6 +509,119 @@ func TestTruncate(t *testing.T) { } } +func TestMSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 0, + want: emptyBitArray, + }, + { + name: "get all bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get more bits than available", + x: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get half of available bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 32, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0}, + }, + }, + { + name: "get MSBs across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + n: 100, + want: &BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0}, + }, + }, + { + name: "get MSBs from max length array", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get zero bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{0x5555555555555555, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).MSBs(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + + if got.len != tt.want.len { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + }) + } +} + func TestWriteAndUnmarshalBinary(t *testing.T) { tests := []struct { name string @@ -671,22 +788,22 @@ func TestCommonPrefix(t *testing.T) { name: "different lengths with common prefix - multiple words", x: &BitArray{ len: 255, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, }, y: &BitArray{ len: 127, - words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, ones63, 0, 0}, }, want: &BitArray{ len: 127, - words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, ones63, 0, 0}, }, }, { name: "different at first bit", x: &BitArray{ len: 64, - words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{ones63, 0, 0, 0}, }, y: &BitArray{ len: 64, @@ -776,7 +893,7 @@ func TestCommonPrefix(t *testing.T) { name: "max length difference", x: &BitArray{ len: 255, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, }, y: &BitArray{ len: 1, @@ -961,12 +1078,12 @@ func TestFeltConversion(t *testing.T) { want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", }, { - name: "max length (251 bits)", + name: "251 bits", ba: BitArray{ - len: 255, + len: 251, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, - length: 255, + length: 251, want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", }, { From 38da2f17cb4a2847804092a4bba19f2dc2ebec45 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 23:48:36 +0800 Subject: [PATCH 19/45] all tests passed --- core/state.go | 9 +- core/trie/bitarray.go | 21 +-- core/trie/key.go | 187 -------------------------- core/trie/key_test.go | 229 -------------------------------- core/trie/node.go | 38 +++--- core/trie/node_test.go | 4 +- core/trie/proof.go | 101 ++++++-------- core/trie/proof_test.go | 26 +++- core/trie/storage.go | 25 ++-- core/trie/storage_test.go | 20 +-- core/trie/trie.go | 91 +++++++------ core/trie/trie_pkg_test.go | 41 +++--- core/trie/trie_test.go | 76 ++++++++++- migration/migration.go | 16 +-- migration/migration_pkg_test.go | 15 ++- 15 files changed, 284 insertions(+), 615 deletions(-) delete mode 100644 core/trie/key.go delete mode 100644 core/trie/key_test.go diff --git a/core/state.go b/core/state.go index 378ba65be..c17ff13f3 100644 --- a/core/state.go +++ b/core/state.go @@ -139,10 +139,11 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr // fetch root key rootKeyDBKey := dbPrefix - var rootKey *trie.Key + var rootKey *trie.BitArray err := s.txn.Get(rootKeyDBKey, func(val []byte) error { - rootKey = new(trie.Key) - return rootKey.UnmarshalBinary(val) + rootKey = new(trie.BitArray) + rootKey.UnmarshalBinary(val) + return nil }) // if some error other than "not found" @@ -169,7 +170,7 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr if resultingRootKey != nil { var rootKeyBytes bytes.Buffer - _, marshalErr := resultingRootKey.WriteTo(&rootKeyBytes) + _, marshalErr := resultingRootKey.Write(&rootKeyBytes) if marshalErr != nil { return marshalErr } diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 7f5aa11f2..11a55edae 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -109,14 +109,13 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return true } - var long, short *BitArray - - long, short = b, x - if b.len < x.len { - long, short = x, b + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len } - return long.Rsh(long, long.len-short.len).Equal(short) + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) } // LSBs sets b to the least significant 'n' bits of x. @@ -229,8 +228,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { } if n >= x.len { - x.clear() - return b.Set(x) + return b.clear() } switch { @@ -277,6 +275,13 @@ func (b *BitArray) Xor(x, y *BitArray) *BitArray { // Eq checks if two bit arrays are equal func (b *BitArray) Equal(x *BitArray) bool { + // TODO(weiihann): this is really not a good thing to do... + if b == nil && x == nil { + return true + } else if b == nil || x == nil { + return false + } + return b.len == x.len && b.words[0] == x.words[0] && b.words[1] == x.words[1] && diff --git a/core/trie/key.go b/core/trie/key.go deleted file mode 100644 index 0d0ca7aa8..000000000 --- a/core/trie/key.go +++ /dev/null @@ -1,187 +0,0 @@ -package trie - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "math/big" - - "github.com/NethermindEth/juno/core/felt" -) - -var NilKey = &Key{len: 0, bitset: [32]byte{}} - -type Key struct { - len uint8 - bitset [32]byte -} - -func NewKey(length uint8, keyBytes []byte) Key { - k := Key{len: length} - if len(keyBytes) > len(k.bitset) { - panic("bytes does not fit in bitset") - } - copy(k.bitset[len(k.bitset)-len(keyBytes):], keyBytes) - return k -} - -func (k *Key) bytesNeeded() uint { - const byteBits = 8 - return (uint(k.len) + (byteBits - 1)) / byteBits -} - -func (k *Key) inUseBytes() []byte { - return k.bitset[len(k.bitset)-int(k.bytesNeeded()):] -} - -func (k *Key) unusedBytes() []byte { - return k.bitset[:len(k.bitset)-int(k.bytesNeeded())] -} - -func (k *Key) WriteTo(buf *bytes.Buffer) (int64, error) { - if err := buf.WriteByte(k.len); err != nil { - return 0, err - } - - n, err := buf.Write(k.inUseBytes()) - return int64(1 + n), err -} - -func (k *Key) UnmarshalBinary(data []byte) error { - k.len = data[0] - k.bitset = [32]byte{} - copy(k.inUseBytes(), data[1:1+k.bytesNeeded()]) - return nil -} - -func (k *Key) EncodedLen() uint { - return k.bytesNeeded() + 1 -} - -func (k *Key) Len() uint8 { - return k.len -} - -func (k *Key) Felt() felt.Felt { - var f felt.Felt - f.SetBytes(k.bitset[:]) - return f -} - -func (k *Key) Equal(other *Key) bool { - if k == nil && other == nil { - return true - } else if k == nil || other == nil { - return false - } - return k.len == other.len && k.bitset == other.bitset -} - -// IsBitSet returns whether the bit at the given position is 1. -// Position 0 represents the least significant (rightmost) bit. -func (k *Key) IsBitSet(position uint8) bool { - const LSB = uint8(0x1) - byteIdx := position / 8 - byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1] - bitIdx := position % 8 - return ((byteAtIdx >> bitIdx) & LSB) != 0 -} - -// shiftRight removes n least significant bits from the key by performing a right shift -// operation and reducing the key length. For example, if the key contains bits -// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4). -// -// The operation is destructive - it modifies the key in place. -func (k *Key) shiftRight(n uint8) { - if k.len < n { - panic("deleting more bits than there are") - } - - if n == 0 { - return - } - - var bigInt big.Int - bigInt.SetBytes(k.bitset[:]) - bigInt.Rsh(&bigInt, uint(n)) - bigInt.FillBytes(k.bitset[:]) - k.len -= n -} - -// MostSignificantBits returns a new key with the most significant n bits of the current key. -func (k *Key) MostSignificantBits(n uint8) (*Key, error) { - if n > k.len { - return nil, errors.New("cannot get more bits than the key length") - } - - keyCopy := k.Copy() - keyCopy.shiftRight(k.len - n) - return &keyCopy, nil -} - -// Truncate truncates key to `length` bits by clearing the remaining upper bits -func (k *Key) Truncate(length uint8) { - k.len = length - - unusedBytes := k.unusedBytes() - clear(unusedBytes) - - // clear upper bits on the last used byte - inUseBytes := k.inUseBytes() - unusedBitsCount := 8 - (k.len % 8) - if unusedBitsCount != 8 && len(inUseBytes) > 0 { - inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount - } -} - -func (k *Key) String() string { - return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) -} - -// Copy returns a deep copy of the key -func (k *Key) Copy() Key { - newKey := Key{len: k.len} - copy(newKey.bitset[:], k.bitset[:]) - return newKey -} - -func (k *Key) Bytes() [32]byte { - var result [32]byte - copy(result[:], k.bitset[:]) - return result -} - -// findCommonKey finds the set of common MSB bits in two key bitsets. -func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { - divergentBit := findDivergentBit(longerKey, shorterKey) - - if divergentBit == 0 { - return *NilKey, false - } - - commonKey := *shorterKey - commonKey.shiftRight(shorterKey.Len() - divergentBit + 1) - return commonKey, divergentBit == shorterKey.Len()+1 -} - -// findDivergentBit finds the first bit that is different between two keys, -// starting from the most significant bit of both keys. -func findDivergentBit(longerKey, shorterKey *Key) uint8 { - divergentBit := uint8(0) - for divergentBit <= shorterKey.Len() && - longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) { - divergentBit++ - } - return divergentBit -} - -func isSubset(longerKey, shorterKey *Key) bool { - divergentBit := findDivergentBit(longerKey, shorterKey) - return divergentBit == shorterKey.Len()+1 -} - -func FeltToKey(length uint8, key *felt.Felt) Key { - keyBytes := key.Bytes() - return NewKey(length, keyBytes[:]) -} diff --git a/core/trie/key_test.go b/core/trie/key_test.go deleted file mode 100644 index 3867678e6..000000000 --- a/core/trie/key_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package trie_test - -import ( - "bytes" - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestKeyEncoding(t *testing.T) { - tests := map[string]struct { - Len uint8 - Bytes []byte - }{ - "multiple of 8": { - Len: 4 * 8, - Bytes: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - }, - "0 len": { - Len: 0, - Bytes: []byte{}, - }, - "odd len": { - Len: 3, - Bytes: []byte{0x03}, - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - key := trie.NewKey(test.Len, test.Bytes) - - var keyBuffer bytes.Buffer - n, err := key.WriteTo(&keyBuffer) - require.NoError(t, err) - assert.Equal(t, len(test.Bytes)+1, int(n)) - - keyBytes := keyBuffer.Bytes() - require.Len(t, keyBytes, int(n)) - assert.Equal(t, test.Len, keyBytes[0]) - assert.Equal(t, test.Bytes, keyBytes[1:]) - - var decodedKey trie.Key - require.NoError(t, decodedKey.UnmarshalBinary(keyBytes)) - assert.Equal(t, key, decodedKey) - }) - } -} - -func BenchmarkKeyEncoding(b *testing.B) { - val, err := new(felt.Felt).SetRandom() - require.NoError(b, err) - valBytes := val.Bytes() - - key := trie.NewKey(felt.Bits, valBytes[:]) - buffer := bytes.Buffer{} - buffer.Grow(felt.Bytes + 1) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := key.WriteTo(&buffer) - require.NoError(b, err) - require.NoError(b, key.UnmarshalBinary(buffer.Bytes())) - buffer.Reset() - } -} - -func TestTruncate(t *testing.T) { - tests := map[string]struct { - key trie.Key - newLen uint8 - expectedKey trie.Key - }{ - "truncate to 12 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 12, - expectedKey: trie.NewKey(12, []byte{0x03, 0x14}), - }, - "truncate to 9 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 9, - expectedKey: trie.NewKey(9, []byte{0x01, 0x14}), - }, - "truncate to 3 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 3, - expectedKey: trie.NewKey(3, []byte{0x04}), - }, - "truncate to multiple of 8": { - key: trie.NewKey(251, []uint8{ - 0x7, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab, - 0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0, - 0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f, - }), - newLen: 248, - expectedKey: trie.NewKey(248, []uint8{ - 0x0, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab, - 0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0, - 0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f, - }), - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - copyKey := test.key - copyKey.Truncate(test.newLen) - assert.Equal(t, test.expectedKey, copyKey) - }) - } -} - -func TestKeyTest(t *testing.T) { - key := trie.NewKey(44, []byte{0x10, 0x02}) - for i := 0; i < int(key.Len()); i++ { - assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i) - } -} - -func TestIsBitSet(t *testing.T) { - tests := map[string]struct { - key trie.Key - position uint8 - expected bool - }{ - "single byte, LSB set": { - key: trie.NewKey(8, []byte{0x01}), - position: 0, - expected: true, - }, - "single byte, MSB set": { - key: trie.NewKey(8, []byte{0x80}), - position: 7, - expected: true, - }, - "single byte, middle bit set": { - key: trie.NewKey(8, []byte{0x10}), - position: 4, - expected: true, - }, - "single byte, bit not set": { - key: trie.NewKey(8, []byte{0xFE}), - position: 0, - expected: false, - }, - "multiple bytes, LSB set": { - key: trie.NewKey(16, []byte{0x00, 0x02}), - position: 1, - expected: true, - }, - "multiple bytes, MSB set": { - key: trie.NewKey(16, []byte{0x01, 0x00}), - position: 8, - expected: true, - }, - "multiple bytes, no bits set": { - key: trie.NewKey(16, []byte{0x00, 0x00}), - position: 7, - expected: false, - }, - "check all bits in pattern": { - key: trie.NewKey(8, []byte{0xA5}), // 10100101 - position: 0, - expected: true, - }, - } - - // Additional test for 0xA5 pattern - key := trie.NewKey(8, []byte{0xA5}) // 10100101 - expectedBits := []bool{true, false, true, false, false, true, false, true} - for i, expected := range expectedBits { - assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i) - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - result := tc.key.IsBitSet(tc.position) - assert.Equal(t, tc.expected, result) - }) - } -} - -func TestMostSignificantBits(t *testing.T) { - tests := []struct { - name string - key trie.Key - n uint8 - want trie.Key - expectErr bool - }{ - { - name: "Valid case", - key: trie.NewKey(8, []byte{0b11110000}), - n: 4, - want: trie.NewKey(4, []byte{0b00001111}), - expectErr: false, - }, - { - name: "Request more bits than available", - key: trie.NewKey(8, []byte{0b11110000}), - n: 10, - want: trie.Key{}, - expectErr: true, - }, - { - name: "Zero bits requested", - key: trie.NewKey(8, []byte{0b11110000}), - n: 0, - want: trie.NewKey(0, []byte{}), - expectErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.key.MostSignificantBits(tt.n) - if (err != nil) != tt.expectErr { - t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr) - return - } - if !tt.expectErr && !got.Equal(&tt.want) { - t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/core/trie/node.go b/core/trie/node.go index 172869cb1..c51a0130d 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -13,8 +13,8 @@ import ( // https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#trie_construction type Node struct { Value *felt.Felt - Left *Key - Right *Key + Left *BitArray + Right *BitArray LeftHash *felt.Felt RightHash *felt.Felt } @@ -40,27 +40,27 @@ func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFn crypto.HashFn) *fe return n.Hash(&path, hashFn) } -func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { +func (n *Node) WriteTo(buf *bytes.Buffer) (int, error) { if n.Value == nil { return 0, errors.New("cannot marshal node with nil value") } - totalBytes := int64(0) + var totalBytes int valueB := n.Value.Bytes() wrote, err := buf.Write(valueB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } if n.Left != nil { - wrote, errInner := n.Left.WriteTo(buf) + wrote, errInner := n.Left.Write(buf) totalBytes += wrote if errInner != nil { return totalBytes, errInner } - wrote, errInner = n.Right.WriteTo(buf) // n.Right is non-nil by design + wrote, errInner = n.Right.Write(buf) // n.Right is non-nil by design totalBytes += wrote if errInner != nil { return totalBytes, errInner @@ -75,14 +75,14 @@ func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { leftHashB := n.LeftHash.Bytes() wrote, err = buf.Write(leftHashB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } rightHashB := n.RightHash.Bytes() wrote, err = buf.Write(rightHashB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } @@ -110,17 +110,13 @@ func (n *Node) UnmarshalBinary(data []byte) error { } if n.Left == nil { - n.Left = new(Key) - n.Right = new(Key) + n.Left = new(BitArray) + n.Right = new(BitArray) } - if err := n.Left.UnmarshalBinary(data); err != nil { - return err - } + n.Left.UnmarshalBinary(data) data = data[n.Left.EncodedLen():] - if err := n.Right.UnmarshalBinary(data); err != nil { - return err - } + n.Right.UnmarshalBinary(data) data = data[n.Right.EncodedLen():] if n.LeftHash == nil { @@ -157,11 +153,11 @@ func (n *Node) Update(other *Node) error { return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value) } - if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) { + if n.Left != nil && other.Left != nil && !n.Left.Equal(emptyBitArray) && !other.Left.Equal(emptyBitArray) && !n.Left.Equal(other.Left) { return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) } - if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) { + if n.Right != nil && other.Right != nil && !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && !n.Right.Equal(other.Right) { return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) } @@ -177,10 +173,10 @@ func (n *Node) Update(other *Node) error { if other.Value != nil { n.Value = other.Value } - if other.Left != nil && !other.Left.Equal(NilKey) { + if other.Left != nil && !other.Left.Equal(emptyBitArray) { n.Left = other.Left } - if other.Right != nil && !other.Right.Equal(NilKey) { + if other.Right != nil && !other.Right.Equal(emptyBitArray) { n.Right = other.Right } if other.LeftHash != nil { diff --git a/core/trie/node_test.go b/core/trie/node_test.go index ccb52b3ea..b222732f4 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -22,7 +22,7 @@ func TestNodeHash(t *testing.T) { node := trie.Node{ Value: new(felt.Felt).SetBytes(valueBytes), } - path := trie.NewKey(6, []byte{42}) + path := trie.NewBitArray(6, 42) - assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") + assert.Equal(t, expected, node.Hash(path, crypto.Pedersen), "TestTrieNode_Hash failed") } diff --git a/core/trie/proof.go b/core/trie/proof.go index 9f1fd3ab1..f4c624705 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -40,14 +40,14 @@ func (b *Binary) String() string { type Edge struct { Child *felt.Felt // child hash - Path *Key // path from parent to child + Path *BitArray // path from parent to child } func (e *Edge) Hash(hash crypto.HashFn) *felt.Felt { length := make([]byte, len(e.Path.bitset)) length[len(e.Path.bitset)-1] = e.Path.len pathFelt := e.Path.Felt() - lengthFelt := new(felt.Felt).SetBytes(length) + lengthFelt := new(felt.Felt).SetBytes(length[:]) return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt) } @@ -71,7 +71,7 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { return err } - var parentKey *Key + var parentKey *BitArray for i, sNode := range nodesFromRoot { sNodeEdge, sNodeBinary, err := storageNodeToProofNode(t, parentKey, sNode) @@ -140,7 +140,6 @@ func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSe func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) (*felt.Felt, error) { key := FeltToKey(globalTrieHeight, keyFelt) expectedHash := root - keyLen := key.Len() var curPos uint8 for { @@ -156,17 +155,17 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash switch node := proofNode.(type) { case *Binary: // Binary nodes represent left/right choices - if key.Len() <= curPos { - return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", key.Len(), curPos) + if keyBits.Len() <= curPos { + return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", keyBits.Len(), curPos) } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if key.IsBitSet(keyLen - curPos - 1) { + if keyBits.IsBitSet(keyBits.Len() - curPos - 1) { expectedHash = node.RightHash } curPos++ case *Edge: // Edge nodes represent paths between binary nodes - if !verifyEdgePath(&key, node.Path, curPos) { + if !verifyEdgePath(keyBits, node.Path, curPos) { return &felt.Zero, nil } @@ -176,7 +175,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // We've consumed all bits in our path - if curPos >= keyLen { + if curPos >= keyBits.Len() { return expectedHash, nil } } @@ -235,18 +234,18 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * } nodes := NewStorageNodeSet() - firstKey := FeltToKey(globalTrieHeight, first) + firstKey := new(BitArray).SetFelt(globalTrieHeight, first) // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values // Empty range proof with more elements on the right is not accepted in this function. // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. if len(keys) == 0 { - rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + rootKey, val, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - if val != nil || hasRightElement(rootKey, &firstKey, nodes) { + if val != nil || hasRightElement(rootKey, firstKey, nodes) { return false, errors.New("more entries available") } @@ -254,17 +253,17 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * } last := keys[len(keys)-1] - lastKey := FeltToKey(globalTrieHeight, last) + lastKey := new(BitArray).SetFelt(globalTrieHeight, last) // Special case: there is only one element and two edge keys are the same - if len(keys) == 1 && firstKey.Equal(&lastKey) { - rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + if len(keys) == 1 && firstKey.Equal(lastKey) { + rootKey, val, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - elementKey := FeltToKey(globalTrieHeight, keys[0]) - if !firstKey.Equal(&elementKey) { + elementKey := new(BitArray).SetFelt(globalTrieHeight, keys[0]) + if !firstKey.Equal(elementKey) { return false, errors.New("correct proof but invalid key") } @@ -272,7 +271,7 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, errors.New("correct proof but invalid value") } - return hasRightElement(rootKey, &firstKey, nodes), nil + return hasRightElement(rootKey, firstKey, nodes), nil } // In all other cases, we require two edge paths available. @@ -281,12 +280,12 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, errors.New("last key is less than first key") } - rootKey, _, err := proofToPath(root, &firstKey, proof, nodes) + rootKey, _, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - lastRootKey, _, err := proofToPath(root, &lastKey, proof, nodes) + lastRootKey, _, err := proofToPath(root, lastKey, proof, nodes) if err != nil { return false, err } @@ -311,11 +310,11 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) } - return hasRightElement(rootKey, &lastKey, nodes), nil + return hasRightElement(rootKey, lastKey, nodes), nil } // isEdge checks if the storage node is an edge node. -func isEdge(parentKey *Key, sNode StorageNode) bool { +func isEdge(parentKey *BitArray, sNode StorageNode) bool { sNodeLen := sNode.key.len if parentKey == nil { // Root return sNodeLen != 0 @@ -326,7 +325,7 @@ func isEdge(parentKey *Key, sNode StorageNode) bool { // storageNodeToProofNode converts a StorageNode to the ProofNode(s). // Juno's Trie has nodes that are Binary AND Edge, whereas the protocol requires nodes that are Binary XOR Edge. // We need to convert the former to the latter for proof generation. -func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { +func storageNodeToProofNode(tri *Trie, parentKey *BitArray, sNode StorageNode) (*Edge, *Binary, error) { var edge *Edge if isEdge(parentKey, sNode) { edgePath := path(sNode.key, parentKey) @@ -375,8 +374,8 @@ func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge // proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining // as hashes. The given edge proof can be existent or non-existent. -func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageNodeSet) (*Key, *felt.Felt, error) { - rootKey, val, err := buildPath(root, key, 0, nil, proof, nodes) +func proofToPath(root *felt.Felt, keyBits *BitArray, proof *ProofNodeSet, nodes *StorageNodeSet) (*BitArray, *felt.Felt, error) { + rootKey, val, err := buildPath(root, keyBits, 0, nil, proof, nodes) if err != nil { return nil, nil, err } @@ -400,7 +399,7 @@ func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageN sn := NewPartialStorageNode(edge.Path, edge.Child) // Handle leaf edge case (single key trie) - if edge.Path.Len() == key.Len() { + if edge.Path.Len() == keyBits.Len() { if err := nodes.Put(*sn.key, sn); err != nil { return nil, nil, fmt.Errorf("failed to store leaf edge: %w", err) } @@ -433,12 +432,12 @@ func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageN // It returns the current node's key and any leaf value found along this path. func buildPath( nodeHash *felt.Felt, - key *Key, + key *BitArray, curPos uint8, curNode *StorageNode, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // We reached the leaf if curPos == key.Len() { leafKey := key.Copy() @@ -451,7 +450,7 @@ func buildPath( proofNode, ok := proof.Get(*nodeHash) if !ok { // non-existent proof node - return NilKey, nil, nil + return emptyBitArray, nil, nil } switch pn := proofNode.(type) { @@ -470,23 +469,19 @@ func buildPath( func handleBinaryNode( binary *Binary, nodeHash *felt.Felt, - key *Key, + key *BitArray, curPos uint8, curNode *StorageNode, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // If curNode is nil, it means that this current binary node is the root node. // Or, it's an internal binary node and the parent is also a binary node. // A standalone binary proof node always corresponds to a single storage node. // If curNode is not nil, it means that the parent node is an edge node. // In this case, the key of the storage node is based on the parent edge node. if curNode == nil { - nodeKey, err := key.MostSignificantBits(curPos) - if err != nil { - return nil, nil, err - } - curNode = NewPartialStorageNode(nodeKey, nodeHash) + curNode = NewPartialStorageNode(new(BitArray).MSBs(key, curPos), nodeHash) } curNode.node.LeftHash = binary.LeftHash curNode.node.RightHash = binary.RightHash @@ -523,23 +518,19 @@ func handleBinaryNode( // the current node's key and any leaf value found along this path. func handleEdgeNode( edge *Edge, - key *Key, + key *BitArray, curPos uint8, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // Verify the edge path matches the key path if !verifyEdgePath(key, edge.Path, curPos) { - return NilKey, nil, nil + return emptyBitArray, nil, nil } // The next node position is the end of the edge path nextPos := curPos + edge.Path.Len() - nodeKey, err := key.MostSignificantBits(nextPos) - if err != nil { - return nil, nil, fmt.Errorf("failed to get MSB for internal edge: %w", err) - } - curNode := NewPartialStorageNode(nodeKey, edge.Child) + curNode := NewPartialStorageNode(new(BitArray).MSBs(key, nextPos), edge.Child) // This is an edge leaf, stop traversing the trie if nextPos == key.Len() { @@ -562,24 +553,12 @@ func handleEdgeNode( } // verifyEdgePath checks if the edge path matches the key path at the current position. -func verifyEdgePath(key, edgePath *Key, curPos uint8) bool { - if key.Len() < curPos+edgePath.Len() { - return false - } - - // Ensure the bits between segment of the key and the node path match - start := key.Len() - curPos - edgePath.Len() - end := key.Len() - curPos - for i := start; i < end; i++ { - if key.IsBitSet(i) != edgePath.IsBitSet(i-start) { - return false // paths diverge - this proves non-membership - } - } - return true +func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { + return new(BitArray).LSBs(key, key.Len()-curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. -func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { +func buildTrie(height uint8, rootKey *BitArray, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { tr, err := NewTriePedersen(newMemStorage(), height) if err != nil { return nil, err @@ -607,9 +586,9 @@ func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values [] // hasRightElement checks if there is a right sibling for the given key in the trie. // This function assumes that the entire path has been resolved. -func hasRightElement(rootKey, key *Key, nodes *StorageNodeSet) bool { +func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { cur := rootKey - for cur != nil && !cur.Equal(NilKey) { + for cur != nil && !cur.Equal(emptyBitArray) { sn, ok := nodes.Get(*cur) if !ok { return false diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 94eaabc54..0f9c54543 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -13,6 +13,30 @@ import ( "github.com/stretchr/testify/require" ) +func TestFix(t *testing.T) { + numKeys := 1000 + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + records := make([]*keyValue, numKeys) + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) + require.NoError(t, err) + } + + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 + }) + + require.NoError(t, tempTrie.Commit()) +} + func TestProve(t *testing.T) { t.Parallel() @@ -360,7 +384,7 @@ func TestOneElementRangeProof(t *testing.T) { }) } -// TestAllElementsProof tests the range proof with all elements and nil proof. +// TestAllElementsRangeProof tests the range proof with all elements and nil proof. func TestAllElementsRangeProof(t *testing.T) { t.Parallel() diff --git a/core/trie/storage.go b/core/trie/storage.go index c4e5ae091..6fe994fe3 100644 --- a/core/trie/storage.go +++ b/core/trie/storage.go @@ -42,17 +42,17 @@ func NewStorage(txn db.Transaction, prefix []byte) *Storage { // dbKey creates a byte array to be used as a key to our KV store // it simply appends the given key to the configured prefix -func (t *Storage) dbKey(key *Key, buffer *bytes.Buffer) (int64, error) { +func (t *Storage) dbKey(key *BitArray, buffer *bytes.Buffer) (int, error) { _, err := buffer.Write(t.prefix) if err != nil { return 0, err } - keyLen, err := key.WriteTo(buffer) - return int64(len(t.prefix)) + keyLen, err + keyLen, err := key.Write(buffer) + return len(t.prefix) + keyLen, err } -func (t *Storage) Put(key *Key, value *Node) error { +func (t *Storage) Put(key *BitArray, value *Node) error { buffer := getBuffer() defer bufferPool.Put(buffer) keyLen, err := t.dbKey(key, buffer) @@ -69,7 +69,7 @@ func (t *Storage) Put(key *Key, value *Node) error { return t.txn.Set(encodedBytes[:keyLen], encodedBytes[keyLen:]) } -func (t *Storage) Get(key *Key) (*Node, error) { +func (t *Storage) Get(key *BitArray) (*Node, error) { buffer := getBuffer() defer bufferPool.Put(buffer) _, err := t.dbKey(key, buffer) @@ -87,7 +87,7 @@ func (t *Storage) Get(key *Key) (*Node, error) { return node, err } -func (t *Storage) Delete(key *Key) error { +func (t *Storage) Delete(key *BitArray) error { buffer := getBuffer() defer bufferPool.Put(buffer) _, err := t.dbKey(key, buffer) @@ -97,21 +97,22 @@ func (t *Storage) Delete(key *Key) error { return t.txn.Delete(buffer.Bytes()) } -func (t *Storage) RootKey() (*Key, error) { - var rootKey *Key +func (t *Storage) RootKey() (*BitArray, error) { + var rootKey *BitArray if err := t.txn.Get(t.prefix, func(val []byte) error { - rootKey = new(Key) - return rootKey.UnmarshalBinary(val) + rootKey = new(BitArray) + rootKey.UnmarshalBinary(val) + return nil }); err != nil { return nil, err } return rootKey, nil } -func (t *Storage) PutRootKey(newRootKey *Key) error { +func (t *Storage) PutRootKey(newRootKey *BitArray) error { buffer := getBuffer() defer bufferPool.Put(buffer) - _, err := newRootKey.WriteTo(buffer) + _, err := newRootKey.Write(buffer) if err != nil { return err } diff --git a/core/trie/storage_test.go b/core/trie/storage_test.go index 809ded479..37a4e8e44 100644 --- a/core/trie/storage_test.go +++ b/core/trie/storage_test.go @@ -15,7 +15,7 @@ import ( func TestStorage(t *testing.T) { testDB := pebble.NewMemTest(t) prefix := []byte{37, 44} - key := trie.NewKey(44, nil) + key := trie.NewBitArray(44, 0) value, err := new(felt.Felt).SetRandom() require.NoError(t, err) @@ -27,7 +27,7 @@ func TestStorage(t *testing.T) { t.Run("put a node", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Put(&key, node) + return tTxn.Put(key, node) })) }) @@ -35,7 +35,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(&key) + got, err = tTxn.Get(key) require.NoError(t, err) assert.Equal(t, node, got) return err @@ -46,7 +46,7 @@ func TestStorage(t *testing.T) { // Successfully delete a node and return an error to force a roll back. require.Error(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - err = tTxn.Delete(&key) + err = tTxn.Delete(key) require.NoError(t, err) return errors.New("should rollback") })) @@ -56,7 +56,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(&key) + got, err = tTxn.Get(key) assert.Equal(t, node, got) return err })) @@ -66,23 +66,23 @@ func TestStorage(t *testing.T) { // Delete a node. require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Delete(&key) + return tTxn.Delete(key) })) // Node should no longer exist in the database. require.EqualError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - _, err = tTxn.Get(&key) + _, err = tTxn.Get(key) return err }), db.ErrKeyNotFound.Error()) }) - rootKey := trie.NewKey(8, []byte{0x2}) + rootKey := trie.NewBitArray(8, 2) t.Run("put root key", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.PutRootKey(&rootKey) + return tTxn.PutRootKey(rootKey) })) }) @@ -91,7 +91,7 @@ func TestStorage(t *testing.T) { tTxn := trie.NewStorage(txn, prefix) gotRootKey, err := tTxn.RootKey() require.NoError(t, err) - assert.Equal(t, rootKey, *gotRootKey) + assert.Equal(t, rootKey, gotRootKey) return nil })) }) diff --git a/core/trie/trie.go b/core/trie/trie.go index 5f8a51d9c..c8f00d2e8 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -35,12 +35,12 @@ const globalTrieHeight = 251 // TODO(weiihann): this is declared in core also, s // [specification]: https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#merkle_patricia_trie type Trie struct { height uint8 - rootKey *Key + rootKey *BitArray maxKey *felt.Felt storage *Storage hash crypto.HashFn - dirtyNodes []*Key + dirtyNodes []*BitArray rootKeyIsDirty bool } @@ -94,32 +94,35 @@ func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error { return do(trie) } -// feltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], +// FeltToKey converts a key, given in felt, to a trie.Key which when followed on a [Trie], // leads to the corresponding [Node] -func (t *Trie) FeltToKey(k *felt.Felt) Key { - return FeltToKey(t.height, k) +func (t *Trie) FeltToKey(k *felt.Felt) BitArray { + var ba BitArray + ba.SetFelt(t.height, k) + return ba } // path returns the path as mentioned in the [specification] for commitment calculations. // path is suffix of key that diverges from parentKey. For example, // for a key 0b1011 and parentKey 0b10, this function would return the path object of 0b0. -func path(key, parentKey *Key) Key { - path := *key - // drop parent key, and one more MSB since left/right relation already encodes that information - if parentKey != nil { - path.Truncate(path.Len() - parentKey.Len() - 1) +func path(key, parentKey *BitArray) BitArray { + if parentKey == nil { + return key.Copy() } - return path + + var pathKey BitArray + pathKey.LSBs(key, key.Len()-parentKey.Len()-1) + return pathKey } // storageNode is the on-disk representation of a [Node], // where key is the storage key and node is the value. type StorageNode struct { - key *Key + key *BitArray node *Node } -func (sn *StorageNode) Key() *Key { +func (sn *StorageNode) Key() *BitArray { return sn.key } @@ -133,7 +136,7 @@ func (sn *StorageNode) String() string { func (sn *StorageNode) Update(other *StorageNode) error { // First validate all fields for conflicts - if sn.key != nil && other.key != nil && !sn.key.Equal(NilKey) && !other.key.Equal(NilKey) { + if sn.key != nil && other.key != nil && !sn.key.Equal(emptyBitArray) && !other.key.Equal(emptyBitArray) { if !sn.key.Equal(other.key) { return fmt.Errorf("keys do not match: %s != %s", sn.key, other.key) } @@ -147,47 +150,47 @@ func (sn *StorageNode) Update(other *StorageNode) error { } // After validation, perform update - if other.key != nil && !other.key.Equal(NilKey) { + if other.key != nil && !other.key.Equal(emptyBitArray) { sn.key = other.key } return nil } -func NewStorageNode(key *Key, node *Node) *StorageNode { +func NewStorageNode(key *BitArray, node *Node) *StorageNode { return &StorageNode{key: key, node: node} } // NewPartialStorageNode creates a new StorageNode with a given key and value, // where the right and left children are nil. -func NewPartialStorageNode(key *Key, value *felt.Felt) *StorageNode { +func NewPartialStorageNode(key *BitArray, value *felt.Felt) *StorageNode { return &StorageNode{ key: key, node: &Node{ Value: value, - Left: NilKey, - Right: NilKey, + Left: emptyBitArray, + Right: emptyBitArray, }, } } // StorageNodeSet wraps OrderedSet to provide specific functionality for StorageNodes type StorageNodeSet struct { - set *utils.OrderedSet[Key, *StorageNode] + set *utils.OrderedSet[BitArray, *StorageNode] } func NewStorageNodeSet() *StorageNodeSet { return &StorageNodeSet{ - set: utils.NewOrderedSet[Key, *StorageNode](), + set: utils.NewOrderedSet[BitArray, *StorageNode](), } } -func (s *StorageNodeSet) Get(key Key) (*StorageNode, bool) { +func (s *StorageNodeSet) Get(key BitArray) (*StorageNode, bool) { return s.set.Get(key) } // Put adds a new StorageNode or updates an existing one. -func (s *StorageNodeSet) Put(key Key, node *StorageNode) error { +func (s *StorageNodeSet) Put(key BitArray, node *StorageNode) error { if node == nil { return errors.New("cannot put nil node") } @@ -217,7 +220,7 @@ func (s *StorageNodeSet) Size() int { // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. -func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { +func (t *Trie) nodesFromRoot(key *BitArray) ([]StorageNode, error) { var nodes []StorageNode cur := t.rootKey for cur != nil { @@ -236,8 +239,7 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { node: node, }) - subset := isSubset(key, cur) - if cur.Len() >= key.Len() || !subset { + if cur.Len() >= key.Len() || !key.EqualMSBs(cur) { return nodes, nil } @@ -267,12 +269,12 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { } // GetNodeFromKey returns the node for a given key. -func (t *Trie) GetNodeFromKey(key *Key) (*Node, error) { +func (t *Trie) GetNodeFromKey(key *BitArray) (*Node, error) { return t.storage.Get(key) } // check if we are updating an existing leaf, if yes avoid traversing the trie -func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { +func (t *Trie) updateLeaf(nodeKey BitArray, node *Node, value *felt.Felt) (*felt.Felt, error) { // Check if we are updating an existing leaf if !value.IsZero() { if existingLeaf, err := t.storage.Get(&nodeKey); err == nil { @@ -289,7 +291,7 @@ func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt return nil, nil } -func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { +func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey BitArray, node *Node, value *felt.Felt) (*felt.Felt, error) { if value.IsZero() { return nil, nil // no-op } @@ -301,7 +303,7 @@ func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *fe return &old, nil } -func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []StorageNode) (*felt.Felt, error) { +func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey BitArray, nodes []StorageNode) (*felt.Felt, error) { if nodeKey.Equal(sibling.key) { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -314,7 +316,7 @@ func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []Stora return nil, nil } -func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent StorageNode) { +func (t *Trie) replaceLinkWithNewParent(key *BitArray, commonKey BitArray, siblingParent StorageNode) { if siblingParent.node.Left.Equal(key) { *siblingParent.node.Left = commonKey } else { @@ -323,8 +325,9 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent S } // TODO(weiihann): not a good idea to couple proof verification logic with trie logic -func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { - commonKey, _ := findCommonKey(nodeKey, sibling.key) +func (t *Trie) insertOrUpdateValue(nodeKey *BitArray, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { + var commonKey BitArray + commonKey.CommonMSBs(nodeKey, sibling.key) newParent := &Node{} var leftChild, rightChild *Node @@ -497,19 +500,19 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, proof []*StorageNode) (*felt. } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutInner(key *Key, node *Node) error { +func (t *Trie) PutInner(key *BitArray, node *Node) error { if err := t.storage.Put(key, node); err != nil { return err } return nil } -func (t *Trie) setRootKey(newRootKey *Key) { +func (t *Trie) setRootKey(newRootKey *BitArray) { t.rootKey = newRootKey t.rootKeyIsDirty = true } -func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo +func (t *Trie) updateValueIfDirty(key *BitArray) (*Node, error) { //nolint:gocyclo node, err := t.storage.Get(key) if err != nil { return nil, err @@ -523,7 +526,7 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo shouldUpdate := false for _, dirtyNode := range t.dirtyNodes { if key.Len() < dirtyNode.Len() { - shouldUpdate = isSubset(dirtyNode, key) + shouldUpdate = key.EqualMSBs(dirtyNode) if shouldUpdate { break } @@ -531,9 +534,9 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo } // Update inner proof nodes - if node.Left.Equal(NilKey) && node.Right.Equal(NilKey) { // leaf + if node.Left.Equal(emptyBitArray) && node.Right.Equal(emptyBitArray) { // leaf shouldUpdate = false - } else if node.Left.Equal(NilKey) || node.Right.Equal(NilKey) { // inner + } else if node.Left.Equal(emptyBitArray) || node.Right.Equal(emptyBitArray) { // inner shouldUpdate = true } if !shouldUpdate { @@ -542,11 +545,11 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo var leftIsProof, rightIsProof bool var leftHash, rightHash *felt.Felt - if node.Left.Equal(NilKey) { // key could be nil but hash cannot be + if node.Left.Equal(emptyBitArray) { // key could be nil but hash cannot be leftIsProof = true leftHash = node.LeftHash } - if node.Right.Equal(NilKey) { + if node.Right.Equal(emptyBitArray) { rightIsProof = true rightHash = node.RightHash } @@ -643,7 +646,7 @@ func (t *Trie) deleteLast(nodes []StorageNode) error { return err } - var siblingKey Key + var siblingKey BitArray if parent.node.Left.Equal(last.key) { siblingKey = *parent.node.Right } else { @@ -710,7 +713,7 @@ func (t *Trie) Commit() error { } // RootKey returns db key of the [Trie] root node -func (t *Trie) RootKey() *Key { +func (t *Trie) RootKey() *BitArray { return t.rootKey } @@ -732,7 +735,7 @@ The following can be printed: The spacing to represent the levels of the trie can remain the same. */ -func (t *Trie) dump(level int, parentP *Key) { +func (t *Trie) dump(level int, parentP *BitArray) { if t.rootKey == nil { fmt.Printf("%sEMPTY\n", strings.Repeat("\t", level)) return diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 5426cbcaf..533450a68 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -55,16 +55,17 @@ func TestTrieKeys(t *testing.T) { // Check parent and its left right children l := tempTrie.FeltToKey(leftKey) r := tempTrie.FeltToKey(rightKey) - commonKey, isSame := findCommonKey(&l, &r) - require.False(t, isSame) + var commonKey BitArray + commonKey.CommonMSBs(&l, &r) // Common key should be 0b100, length 251-2; - expectKey := NewKey(251-2, []byte{0x4}) + // expectKey := NewKey(251-2, []byte{0x4}) + expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, commonKey) + assert.Equal(t, expectKey, &commonKey) // Current rootKey should be the common key - assert.Equal(t, expectKey, *tempTrie.rootKey) + assert.Equal(t, expectKey, tempTrie.rootKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -98,12 +99,12 @@ func TestTrieKeys(t *testing.T) { // Check parent and its left right children l := tempTrie.FeltToKey(leftKey) r := tempTrie.FeltToKey(rightKey) - commonKey, isSame := findCommonKey(&l, &r) - require.False(t, isSame) + var commonKey BitArray + commonKey.CommonMSBs(&l, &r) - expectKey := NewKey(251-2, []byte{0x4}) + expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, commonKey) + assert.Equal(t, expectKey, &commonKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -139,8 +140,8 @@ func TestTrieKeys(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b101) _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(250, []byte{0x2}) - parentNode, pErr := tempTrie.storage.Get(&commonKey) + commonKey := NewBitArray(250, 2) + parentNode, pErr := tempTrie.storage.Get(commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) @@ -150,8 +151,8 @@ func TestTrieKeys(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b110) _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(250, []byte{0x3}) - parentNode, pErr := tempTrie.storage.Get(&commonKey) + commonKey := NewBitArray(250, 3) + parentNode, pErr := tempTrie.storage.Get(commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) @@ -166,15 +167,15 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(248, []byte{}) - parentNode, err := tempTrie.storage.Get(&commonKey) + commonKey := NewBitArray(248, 0) + parentNode, err := tempTrie.storage.Get(commonKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) - expectRightKey := NewKey(249, []byte{0x1}) + expectRightKey := NewBitArray(249, 1) - assert.Equal(t, expectRightKey, *parentNode.Right) + assert.Equal(t, expectRightKey, parentNode.Right) }) }) } @@ -239,11 +240,11 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { _, err = tempTrie.Put(test.deleteKey, zeroVal) require.NoError(t, err) - newRootKey := NewKey(251-2, []byte{0x1}) + newRootKey := NewBitArray(249, 1) - assert.Equal(t, newRootKey, *tempTrie.rootKey) + assert.Equal(t, newRootKey, tempTrie.rootKey) - rootNode, err := tempTrie.storage.Get(&newRootKey) + rootNode, err := tempTrie.storage.Get(newRootKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(rightKey), *rootNode.Right) diff --git a/core/trie/trie_test.go b/core/trie/trie_test.go index 51d589ab6..7384bf558 100644 --- a/core/trie/trie_test.go +++ b/core/trie/trie_test.go @@ -164,7 +164,81 @@ func TestPutZero(t *testing.T) { var keys []*felt.Felt // put random 64 keys and record roots - for range 64 { + for i := 0; i < 64; i++ { + key, value := new(felt.Felt), new(felt.Felt) + + _, err = key.SetRandom() + require.NoError(t, err) + + t.Logf("key: %s", key.String()) + + _, err = value.SetRandom() + require.NoError(t, err) + + t.Logf("value: %s", value.String()) + + _, err = tempTrie.Put(key, value) + require.NoError(t, err) + + keys = append(keys, key) + + var root *felt.Felt + root, err = tempTrie.Root() + require.NoError(t, err) + + roots = append(roots, root) + } + + t.Run("adding a zero value to a non-existent key should not change Trie", func(t *testing.T) { + var key, root *felt.Felt + key, err = new(felt.Felt).SetRandom() + require.NoError(t, err) + + _, err = tempTrie.Put(key, new(felt.Felt)) + require.NoError(t, err) + + root, err = tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, true, root.Equal(roots[len(roots)-1])) + }) + + t.Run("remove keys one by one, check roots", func(t *testing.T) { + var gotRoot *felt.Felt + // put zero in reverse order and check roots still match + for i := range 64 { + root := roots[len(roots)-1-i] + + gotRoot, err = tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, root, gotRoot) + + key := keys[len(keys)-1-i] + _, err = tempTrie.Put(key, new(felt.Felt)) + require.NoError(t, err) + } + }) + + t.Run("empty roots should match", func(t *testing.T) { + actualEmptyRoot, err := tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, true, actualEmptyRoot.Equal(emptyRoot)) + }) + return nil + })) +} + +func TestTrie(t *testing.T) { + require.NoError(t, trie.RunOnTempTriePedersen(251, func(tempTrie *trie.Trie) error { + emptyRoot, err := tempTrie.Root() + require.NoError(t, err) + var roots []*felt.Felt + var keys []*felt.Felt + + // put random 64 keys and record roots + for i := 0; i < 64; i++ { key, value := new(felt.Felt), new(felt.Felt) _, err = key.SetRandom() diff --git a/migration/migration.go b/migration/migration.go index 107bd40f1..97ce613f5 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -511,7 +511,7 @@ func calculateL1MsgHashes(txn db.Transaction, n *utils.Network) error { return processBlocks(txn, processBlockFunc) } -func bitset2Key(bs *bitset.BitSet) *trie.Key { +func bitset2BitArray(bs *bitset.BitSet) *trie.BitArray { bsWords := bs.Words() if len(bsWords) > felt.Limbs { panic("key too long to fit in Felt") @@ -524,9 +524,7 @@ func bitset2Key(bs *bitset.BitSet) *trie.Key { } f := new(felt.Felt).SetBytes(bsBytes[:]) - fBytes := f.Bytes() - k := trie.NewKey(uint8(bs.Len()), fBytes[:]) - return &k + return new(trie.BitArray).SetFelt(uint8(bs.Len()), f) } func migrateTrieRootKeysFromBitsetToTrieKeys(txn db.Transaction, key, value []byte, _ *utils.Network) error { @@ -535,8 +533,8 @@ func migrateTrieRootKeysFromBitsetToTrieKeys(txn db.Transaction, key, value []by if err := bs.UnmarshalBinary(value); err != nil { return err } - trieKey := bitset2Key(&bs) - _, err := trieKey.WriteTo(&tempBuf) + trieKey := bitset2BitArray(&bs) + _, err := trieKey.Write(&tempBuf) if err != nil { return err } @@ -574,8 +572,8 @@ func migrateTrieNodesFromBitsetToTrieKey(target db.Bucket) BucketMigratorDoFunc Value: n.Value, } if n.Left != nil { - trieNode.Left = bitset2Key(n.Left) - trieNode.Right = bitset2Key(n.Right) + trieNode.Left = bitset2BitArray(n.Left) + trieNode.Right = bitset2BitArray(n.Right) } if _, err := trieNode.WriteTo(&tempBuf); err != nil { @@ -594,7 +592,7 @@ func migrateTrieNodesFromBitsetToTrieKey(target db.Bucket) BucketMigratorDoFunc } var keyBuffer bytes.Buffer - if _, err := bitset2Key(&bs).WriteTo(&keyBuffer); err != nil { + if _, err := bitset2BitArray(&bs).Write(&keyBuffer); err != nil { return err } diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index e2d5613c4..688643386 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -260,8 +260,11 @@ func TestMigrateTrieRootKeysFromBitsetToTrieKeys(t *testing.T) { require.NoError(t, migrateTrieRootKeysFromBitsetToTrieKeys(memTxn, key, bsBytes, &utils.Mainnet)) - var trieKey trie.Key - err = memTxn.Get(key, trieKey.UnmarshalBinary) + var trieKey trie.BitArray + err = memTxn.Get(key, func(data []byte) error { + trieKey.UnmarshalBinary(data) + return nil + }) require.NoError(t, err) require.Equal(t, bs.Len(), uint(trieKey.Len())) require.Equal(t, felt.Zero, trieKey.Felt()) @@ -357,7 +360,7 @@ func TestMigrateCairo1CompiledClass(t *testing.T) { } } -func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { +func TestMigrateTrieNodesFromBitsetToBitArray(t *testing.T) { migrator := migrateTrieNodesFromBitsetToTrieKey(db.ClassesTrie) memTxn := db.NewMemTransaction() @@ -388,9 +391,9 @@ func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { require.ErrorIs(t, err, db.ErrKeyNotFound) var nodeKeyBuf bytes.Buffer - newNodeKey := bitset2Key(bs) - wrote, err = newNodeKey.WriteTo(&nodeKeyBuf) - require.True(t, wrote > 0) + newNodeKey := bitset2BitArray(bs) + bWrite, err := newNodeKey.Write(&nodeKeyBuf) + require.True(t, bWrite > 0) require.NoError(t, err) var trieNode trie.Node From 9f5f5819b31651b743a4c3f95d1b06eee2e9ca66 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 10:35:09 +0800 Subject: [PATCH 20/45] improve comments --- core/trie/bitarray.go | 107 ++++++++++++++++++++----------------- core/trie/bitarray_test.go | 3 ++ 2 files changed, 60 insertions(+), 50 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 11a55edae..a2ee47b05 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -16,12 +16,9 @@ const ( bits8 = 8 ) -var ( - maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = new(BitArray) -) +var emptyBitArray = new(BitArray) -// BitArray is a structure that represents a bit array with length representing the number of used bits. +// Represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. // The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. @@ -35,6 +32,7 @@ func NewBitArray(length uint8, val uint64) *BitArray { return new(BitArray).SetUint64(length, val) } +// Returns the felt representation of the bit array. func (b *BitArray) Felt() felt.Felt { var f felt.Felt f.SetBytes(b.Bytes()) @@ -45,7 +43,7 @@ func (b *BitArray) Len() uint8 { return b.len } -// Bytes returns the bytes representation of the bit array in big endian format +// Returns the bytes representation of the bit array in big endian format // //nolint:mnd func (b *BitArray) Bytes() []byte { @@ -83,42 +81,7 @@ func (b *BitArray) Bytes() []byte { return res[:] } -// EqualMSBs checks if two bit arrays share the same most significant bits, where the length of -// the check is determined by the shorter array. Returns true if either array has -// length 0, or if the first min(b.len, x.len) MSBs are identical. -// -// For example: -// -// a = 1101 (len=4) -// b = 11010111 (len=8) -// a.EqualMSBs(b) = true // First 4 MSBs match -// -// a = 1100 (len=4) -// b = 1101 (len=4) -// a.EqualMSBs(b) = false // All bits compared, not equal -// -// a = 1100 (len=4) -// b = [] (len=0) -// a.EqualMSBs(b) = true // Zero length is always a prefix match -func (b *BitArray) EqualMSBs(x *BitArray) bool { - if b.len == x.len { - return b.Equal(x) - } - - if b.len == 0 || x.len == 0 { - return true - } - - // Compare only the first min(b.len, x.len) bits - minLen := b.len - if x.len < minLen { - minLen = x.len - } - - return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) -} - -// LSBs sets b to the least significant 'n' bits of x. +// Sets b to the least significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: @@ -164,7 +127,42 @@ func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { return b } -// MSBs sets b to the most significant 'n' bits of x. +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets b to the most significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: @@ -181,7 +179,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { return b.Rsh(x, x.len-n) } -// CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. +// Sets b to the longest sequence of matching most significant bits between two bit arrays. // For example: // // x = 1101 0111 (len=8) @@ -219,7 +217,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// Rsh sets b = x >> n and returns b. +// Sets b = x >> n and returns b. // //nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { @@ -264,7 +262,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { return b } -// Xor sets b = x ^ y and returns b. +// Sets b = x ^ y and returns b. func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] @@ -273,7 +271,7 @@ func (b *BitArray) Xor(x, y *BitArray) *BitArray { return b } -// Eq checks if two bit arrays are equal +// Checks if two bit arrays are equal func (b *BitArray) Equal(x *BitArray) bool { // TODO(weiihann): this is really not a good thing to do... if b == nil && x == nil { @@ -289,7 +287,7 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } -// IsBitSit returns true if bit n-th is set, where n = 0 is LSB. +// Returns true if bit n-th is set, where n = 0 is LSB. // The n must be <= 255. func (b *BitArray) IsBitSet(n uint8) bool { if n >= b.len { @@ -299,7 +297,7 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } -// Write serialises the BitArray into a bytes buffer in the following format: +// Serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -314,7 +312,7 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } -// UnmarshalBinary deserialises the BitArray from a bytes buffer in the following format: +// Deserialises the BitArray from a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -328,6 +326,7 @@ func (b *BitArray) UnmarshalBinary(data []byte) { b.setBytes32(bs[:]) } +// Sets b to the same value as x. func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] @@ -337,40 +336,48 @@ func (b *BitArray) Set(x *BitArray) *BitArray { return b } +// Sets b to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } +// Sets b to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } +// Interprets the data as the big-endian bytes, sets b to that value and returns b. +// If the data is larger than 32 bytes, only the first 32 bytes are used. func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) b.len = length return b } +// Sets b to the uint64 representation of a bit array. func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length return b } +// Returns the length of the encoded bit array in bytes. func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 } +// Returns a deep copy of b. func (b *BitArray) Copy() BitArray { var res BitArray res.Set(b) return res } +// Returns a string representation of the bit array. func (b *BitArray) String() string { return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index c90223ab6..479df49fd 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "encoding/binary" + "math" "math/bits" "testing" @@ -11,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + const ( ones63 = 0x7FFFFFFFFFFFFFFF ) From dd8b290aa4e552511c248704ba419671fcf71727 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 10:49:33 +0800 Subject: [PATCH 21/45] fix lint --- core/trie/node.go | 4 +++- core/trie/trie.go | 8 +++++++- core/trie/trie_pkg_test.go | 2 -- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/core/trie/node.go b/core/trie/node.go index c51a0130d..ec20bd489 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -157,7 +157,9 @@ func (n *Node) Update(other *Node) error { return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) } - if n.Right != nil && other.Right != nil && !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && !n.Right.Equal(other.Right) { + if n.Right != nil && other.Right != nil && + !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && + !n.Right.Equal(other.Right) { return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) } diff --git a/core/trie/trie.go b/core/trie/trie.go index c8f00d2e8..d0b68fa77 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -325,7 +325,13 @@ func (t *Trie) replaceLinkWithNewParent(key *BitArray, commonKey BitArray, sibli } // TODO(weiihann): not a good idea to couple proof verification logic with trie logic -func (t *Trie) insertOrUpdateValue(nodeKey *BitArray, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { +func (t *Trie) insertOrUpdateValue( + nodeKey *BitArray, + node *Node, + nodes []StorageNode, + sibling StorageNode, + siblingIsParentProof bool, +) error { var commonKey BitArray commonKey.CommonMSBs(nodeKey, sibling.key) diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 533450a68..04037c4d7 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -135,7 +135,6 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) newVal := new(felt.Felt).SetUint64(12) - //nolint: dupl t.Run("Add to left branch", func(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b101) _, err = tempTrie.Put(newKey, newVal) @@ -146,7 +145,6 @@ func TestTrieKeys(t *testing.T) { assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) }) - //nolint: dupl t.Run("Add to right branch", func(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b110) _, err = tempTrie.Put(newKey, newVal) From f3226532f7da01793d4198dced516c4837082db1 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 11:04:28 +0800 Subject: [PATCH 22/45] minor chore --- core/trie/bitarray.go | 6 +++--- core/trie/bitarray_test.go | 26 +++++++++++++------------- core/trie/proof_test.go | 24 ------------------------ core/trie/trie.go | 1 + 4 files changed, 17 insertions(+), 40 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index a2ee47b05..3762d844c 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -83,7 +83,6 @@ func (b *BitArray) Bytes() []byte { // Sets b to the least significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. -// Any bits beyond the specified length are cleared to zero. // For example: // // x = 11001011 (len=8) @@ -164,7 +163,6 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { // Sets b to the most significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. -// Any bits beyond the specified length are cleared to zero. // For example: // // x = 11001011 (len=8) @@ -406,7 +404,7 @@ func (b *BitArray) byteCount() uint { } // activeBytes returns a slice containing only the bytes that are actually used -// by the bit array, excluding leading zero bytes. The returned slice is in +// by the bit array, as specified by the length. The returned slice is in // big-endian order. // // Example: @@ -448,11 +446,13 @@ func findFirstSetBit(b *BitArray) uint8 { return 0 } + // Start from the most significant and move towards the least significant for i := 3; i >= 0; i-- { if word := b.words[i]; word != 0 { return uint8((i+1)*64 - bits.LeadingZeros64(word)) } } + // All bits are zero, no set bit found return 0 } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 479df49fd..4c57794b0 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -15,7 +15,7 @@ import ( var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} const ( - ones63 = 0x7FFFFFFFFFFFFFFF + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 ) func TestBytes(t *testing.T) { @@ -231,7 +231,7 @@ func TestRsh(t *testing.T) { } } -func TestPrefixEqual(t *testing.T) { +func TestEqualMSBs(t *testing.T) { tests := []struct { name string a *BitArray @@ -357,7 +357,7 @@ func TestLSBs(t *testing.T) { expected BitArray }{ { - name: "truncate to zero", + name: "zero", initial: BitArray{ len: 64, words: [4]uint64{maxUint64, 0, 0, 0}, @@ -369,7 +369,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate within first word - 32 bits", + name: "get 32 LSBs", initial: BitArray{ len: 64, words: [4]uint64{maxUint64, 0, 0, 0}, @@ -381,7 +381,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate to single bit", + name: "get 1 LSB", initial: BitArray{ len: 64, words: [4]uint64{maxUint64, 0, 0, 0}, @@ -389,11 +389,11 @@ func TestLSBs(t *testing.T) { length: 1, expected: BitArray{ len: 1, - words: [4]uint64{0x0000000000000001, 0, 0, 0}, + words: [4]uint64{0x1, 0, 0, 0}, }, }, { - name: "truncate across words - 100 bits", + name: "get 100 LSBs across words", initial: BitArray{ len: 128, words: [4]uint64{maxUint64, maxUint64, 0, 0}, @@ -405,7 +405,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate at word boundary - 64 bits", + name: "get 64 LSBs at word boundary", initial: BitArray{ len: 128, words: [4]uint64{maxUint64, maxUint64, 0, 0}, @@ -417,7 +417,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate at word boundary - 128 bits", + name: "get 128 LSBs at word boundary", initial: BitArray{ len: 192, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, @@ -429,7 +429,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate in third word - 150 bits", + name: "get 150 LSBs in third word", initial: BitArray{ len: 192, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, @@ -441,7 +441,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate in fourth word - 220 bits", + name: "get 220 LSBs in fourth word", initial: BitArray{ len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, @@ -453,7 +453,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate max length - 251 bits", + name: "get 251 LSBs", initial: BitArray{ len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, @@ -465,7 +465,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate sparse bits", + name: "get 100 LSBs from sparse bits", initial: BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 0f9c54543..046b1b1bc 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -13,30 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestFix(t *testing.T) { - numKeys := 1000 - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) - - records := make([]*keyValue, numKeys) - for i := 1; i < numKeys+1; i++ { - key := new(felt.Felt).SetUint64(uint64(i)) - records[i-1] = &keyValue{key: key, value: key} - _, err := tempTrie.Put(key, key) - require.NoError(t, err) - } - - sort.Slice(records, func(i, j int) bool { - return records[i].key.Cmp(records[j].key) < 0 - }) - - require.NoError(t, tempTrie.Commit()) -} - func TestProve(t *testing.T) { t.Parallel() diff --git a/core/trie/trie.go b/core/trie/trie.go index d0b68fa77..0d809debc 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -106,6 +106,7 @@ func (t *Trie) FeltToKey(k *felt.Felt) BitArray { // path is suffix of key that diverges from parentKey. For example, // for a key 0b1011 and parentKey 0b10, this function would return the path object of 0b0. func path(key, parentKey *BitArray) BitArray { + // drop parent key, and one more MSB since left/right relation already encodes that information if parentKey == nil { return key.Copy() } From 979f949ab3a103b4b7950bf36bb13d27cc60c1ee Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 19 Dec 2024 12:53:03 +0800 Subject: [PATCH 23/45] improvements --- core/state.go | 2 +- core/trie/bitarray.go | 73 ++++++++++++++++++-------------------- core/trie/node_test.go | 2 +- core/trie/proof.go | 1 + core/trie/storage_test.go | 16 ++++----- core/trie/trie_pkg_test.go | 18 +++++----- 6 files changed, 55 insertions(+), 57 deletions(-) diff --git a/core/state.go b/core/state.go index c17ff13f3..27c20f057 100644 --- a/core/state.go +++ b/core/state.go @@ -139,7 +139,7 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr // fetch root key rootKeyDBKey := dbPrefix - var rootKey *trie.BitArray + var rootKey *trie.BitArray // TODO: use value instead of pointer err := s.txn.Get(rootKeyDBKey, func(val []byte) error { rootKey = new(trie.BitArray) rootKey.UnmarshalBinary(val) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 3762d844c..7f8d5481a 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,10 +11,7 @@ import ( "github.com/NethermindEth/juno/core/felt" ) -const ( - maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF - bits8 = 8 -) +const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF var emptyBitArray = new(BitArray) @@ -28,8 +25,10 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func NewBitArray(length uint8, val uint64) *BitArray { - return new(BitArray).SetUint64(length, val) +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b } // Returns the felt representation of the bit array. @@ -81,8 +80,8 @@ func (b *BitArray) Bytes() []byte { return res[:] } -// Sets b to the least significant 'n' bits of x. -// If n >= x.len, b is an exact copy of x. +// Sets the bit array to the least significant 'n' bits of x. +// If length >= x.len, the bit array is an exact copy of x. // For example: // // x = 11001011 (len=8) @@ -91,35 +90,35 @@ func (b *BitArray) Bytes() []byte { // LSBs(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { - if length >= x.len { +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { return b.Set(x) } b.Set(x) - b.len = length + b.len = n // Clear all words beyond what's needed switch { - case length == 0: + case n == 0: b.words = [4]uint64{0, 0, 0, 0} - case length <= 64: - mask := maxUint64 >> (64 - length) + case n <= 64: + mask := maxUint64 >> (64 - n) b.words[0] &= mask b.words[1] = 0 b.words[2] = 0 b.words[3] = 0 - case length <= 128: - mask := maxUint64 >> (128 - length) + case n <= 128: + mask := maxUint64 >> (128 - n) b.words[1] &= mask b.words[2] = 0 b.words[3] = 0 - case length <= 192: - mask := maxUint64 >> (192 - length) + case n <= 192: + mask := maxUint64 >> (192 - n) b.words[2] &= mask b.words[3] = 0 default: - mask := maxUint64 >> (256 - uint16(length)) + mask := maxUint64 >> (256 - uint16(n)) b.words[3] &= mask } @@ -161,8 +160,8 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) } -// Sets b to the most significant 'n' bits of x. -// If n >= x.len, b is an exact copy of x. +// Sets the bit array to the most significant 'n' bits of x. +// If n >= x.len, the bit array is an exact copy of x. // For example: // // x = 11001011 (len=8) @@ -177,7 +176,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { return b.Rsh(x, x.len-n) } -// Sets b to the longest sequence of matching most significant bits between two bit arrays. +// Sets the bit array to the longest sequence of matching most significant bits between two bit arrays. // For example: // // x = 1101 0111 (len=8) @@ -185,7 +184,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { // CommonMSBs(x,y) = 1101 (len=4) func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { if x.len == 0 || y.len == 0 { - return emptyBitArray + return b.clear() } long, short := x, y @@ -215,7 +214,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// Sets b = x >> n and returns b. +// Sets the bit array to x >> n and returns the bit array. // //nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { @@ -260,7 +259,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { return b } -// Sets b = x ^ y and returns b. +// Sets the bit array to x ^ y and returns the bit array. func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] @@ -324,7 +323,7 @@ func (b *BitArray) UnmarshalBinary(data []byte) { b.setBytes32(bs[:]) } -// Sets b to the same value as x. +// Sets the bit array to the same value as x. func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] @@ -334,21 +333,21 @@ func (b *BitArray) Set(x *BitArray) *BitArray { return b } -// Sets b to the bytes representation of a felt. +// Sets the bit array to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } -// Sets b to the bytes representation of a felt with length 251. +// Sets the bit array to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } -// Interprets the data as the big-endian bytes, sets b to that value and returns b. +// Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. // If the data is larger than 32 bytes, only the first 32 bytes are used. func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) @@ -356,7 +355,7 @@ func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { return b } -// Sets b to the uint64 representation of a bit array. +// Sets the bit array to the uint64 representation of a bit array. func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length @@ -368,7 +367,7 @@ func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 } -// Returns a deep copy of b. +// Returns a deep copy of the bit array. func (b *BitArray) Copy() BitArray { var res BitArray res.Set(b) @@ -396,16 +395,16 @@ func (b *BitArray) setBytes32(data []byte) { b.words[0] = binary.BigEndian.Uint64(data[24:32]) } -// byteCount returns the minimum number of bytes needed to represent the bit array. +// Returns the minimum number of bytes needed to represent the bit array. // It rounds up to the nearest byte. func (b *BitArray) byteCount() uint { + const bits8 = 8 // Cast to uint16 to avoid overflow return (uint(b.len) + (bits8 - 1)) / uint(bits8) } -// activeBytes returns a slice containing only the bytes that are actually used -// by the bit array, as specified by the length. The returned slice is in -// big-endian order. +// Returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. // // Example: // @@ -433,9 +432,7 @@ func (b *BitArray) clear() *BitArray { return b } -// findFirstSetBit returns the position of the first '1' bit in the array, -// scanning from most significant to least significant bit. -// +// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. // The bit position is counted from the least significant bit, starting at 0. // For example: // diff --git a/core/trie/node_test.go b/core/trie/node_test.go index b222732f4..cc1bb06ed 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -24,5 +24,5 @@ func TestNodeHash(t *testing.T) { } path := trie.NewBitArray(6, 42) - assert.Equal(t, expected, node.Hash(path, crypto.Pedersen), "TestTrieNode_Hash failed") + assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") } diff --git a/core/trie/proof.go b/core/trie/proof.go index f4c624705..d844717ae 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -48,6 +48,7 @@ func (e *Edge) Hash(hash crypto.HashFn) *felt.Felt { length[len(e.Path.bitset)-1] = e.Path.len pathFelt := e.Path.Felt() lengthFelt := new(felt.Felt).SetBytes(length[:]) + // TODO: no need to return reference, just return value to avoid heap allocation return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt) } diff --git a/core/trie/storage_test.go b/core/trie/storage_test.go index 37a4e8e44..21302f130 100644 --- a/core/trie/storage_test.go +++ b/core/trie/storage_test.go @@ -27,7 +27,7 @@ func TestStorage(t *testing.T) { t.Run("put a node", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Put(key, node) + return tTxn.Put(&key, node) })) }) @@ -35,7 +35,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(key) + got, err = tTxn.Get(&key) require.NoError(t, err) assert.Equal(t, node, got) return err @@ -46,7 +46,7 @@ func TestStorage(t *testing.T) { // Successfully delete a node and return an error to force a roll back. require.Error(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - err = tTxn.Delete(key) + err = tTxn.Delete(&key) require.NoError(t, err) return errors.New("should rollback") })) @@ -56,7 +56,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(key) + got, err = tTxn.Get(&key) assert.Equal(t, node, got) return err })) @@ -66,13 +66,13 @@ func TestStorage(t *testing.T) { // Delete a node. require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Delete(key) + return tTxn.Delete(&key) })) // Node should no longer exist in the database. require.EqualError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - _, err = tTxn.Get(key) + _, err = tTxn.Get(&key) return err }), db.ErrKeyNotFound.Error()) }) @@ -82,7 +82,7 @@ func TestStorage(t *testing.T) { t.Run("put root key", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.PutRootKey(rootKey) + return tTxn.PutRootKey(&rootKey) })) }) @@ -91,7 +91,7 @@ func TestStorage(t *testing.T) { tTxn := trie.NewStorage(txn, prefix) gotRootKey, err := tTxn.RootKey() require.NoError(t, err) - assert.Equal(t, rootKey, gotRootKey) + assert.Equal(t, &rootKey, gotRootKey) return nil })) }) diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 04037c4d7..d9d13b1e4 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -62,10 +62,10 @@ func TestTrieKeys(t *testing.T) { // expectKey := NewKey(251-2, []byte{0x4}) expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, &commonKey) + assert.Equal(t, expectKey, commonKey) // Current rootKey should be the common key - assert.Equal(t, expectKey, tempTrie.rootKey) + assert.Equal(t, &expectKey, tempTrie.rootKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -104,7 +104,7 @@ func TestTrieKeys(t *testing.T) { expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, &commonKey) + assert.Equal(t, &expectKey, &commonKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -140,7 +140,7 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) commonKey := NewBitArray(250, 2) - parentNode, pErr := tempTrie.storage.Get(commonKey) + parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) @@ -150,7 +150,7 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) commonKey := NewBitArray(250, 3) - parentNode, pErr := tempTrie.storage.Get(commonKey) + parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) @@ -166,14 +166,14 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) commonKey := NewBitArray(248, 0) - parentNode, err := tempTrie.storage.Get(commonKey) + parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) expectRightKey := NewBitArray(249, 1) - assert.Equal(t, expectRightKey, parentNode.Right) + assert.Equal(t, &expectRightKey, parentNode.Right) }) }) } @@ -240,9 +240,9 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { newRootKey := NewBitArray(249, 1) - assert.Equal(t, newRootKey, tempTrie.rootKey) + assert.Equal(t, &newRootKey, tempTrie.rootKey) - rootNode, err := tempTrie.storage.Get(newRootKey) + rootNode, err := tempTrie.storage.Get(&newRootKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(rightKey), *rootNode.Right) From ed8b0eaa9a9caa2da8e6b6d91f2bb2a1ad1ba5f7 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 20 Dec 2024 11:56:48 +0800 Subject: [PATCH 24/45] ensure unused bits are zero when setting bitarray --- core/trie/bitarray.go | 64 +++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 7f8d5481a..2abfc9be2 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -43,39 +43,14 @@ func (b *BitArray) Len() uint8 { } // Returns the bytes representation of the bit array in big endian format -// -//nolint:mnd func (b *BitArray) Bytes() []byte { var res [32]byte - switch { - case b.len == 0: - // all zeros - return res[:] - case b.len >= 192: - // Create mask for top word: keeps only valid bits above 192 - // e.g., if len=200, keeps lowest 8 bits (200-192) - mask := maxUint64 >> (256 - uint16(b.len)) - binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) - binary.BigEndian.PutUint64(res[8:16], b.words[2]) - binary.BigEndian.PutUint64(res[16:24], b.words[1]) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 128: - // Mask for bits 128-191: keeps only valid bits above 128 - // e.g., if len=150, keeps lowest 22 bits (150-128) - mask := maxUint64 >> (192 - b.len) - binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) - binary.BigEndian.PutUint64(res[16:24], b.words[1]) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 64: - // You get the idea - mask := maxUint64 >> (128 - b.len) - binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - default: - mask := maxUint64 >> (64 - b.len) - binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) - } + b.truncateToLength() + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) return res[:] } @@ -335,15 +310,17 @@ func (b *BitArray) Set(x *BitArray) *BitArray { // Sets the bit array to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { - b.setFelt(f) b.len = length + b.setFelt(f) + b.truncateToLength() return b } // Sets the bit array to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { - b.setFelt(f) b.len = 251 + b.setFelt(f) + b.truncateToLength() return b } @@ -352,6 +329,7 @@ func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) b.len = length + b.truncateToLength() return b } @@ -359,6 +337,7 @@ func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length + b.truncateToLength() return b } @@ -432,6 +411,27 @@ func (b *BitArray) clear() *BitArray { return b } +// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +//nolint:mnd +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + // Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. // The bit position is counted from the least significant bit, starting at 0. // For example: From 131f640206778fbafabfdb85e2b914b990b64bba Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 20 Dec 2024 12:01:49 +0800 Subject: [PATCH 25/45] update comment --- core/trie/bitarray.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 2abfc9be2..85aa4fcca 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -437,7 +437,7 @@ func (b *BitArray) truncateToLength() { // For example: // // array = 0000 0000 ... 0100 (len=251) -// findFirstSetBit() = 2 // third bit from right is set +// findFirstSetBit() = 3 // third bit from right is set func findFirstSetBit(b *BitArray) uint8 { if b.len == 0 { return 0 From bdca4e1526cbc674f4e9c1277e8c5227bd3586e3 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 26 Dec 2024 18:58:31 +0800 Subject: [PATCH 26/45] Add LSBsAtPos() --- core/trie/bitarray.go | 20 +++++++ core/trie/bitarray_test.go | 112 +++++++++++++++++++++++++++++++++++++ core/trie/proof.go | 2 +- core/trie/trie.go | 2 +- 4 files changed, 134 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 85aa4fcca..d56e7cb47 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -100,6 +100,26 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { return b } +// Returns the least significant bits of `x` with `pos` as the most significant bit. +// `pos` is counted from the most significant bit, starting at 0. +// For example: +// +// x = 11001011 (len=8) +// LSBsAtPos(x, 1) = 1001011 (len=7) +// LSBsAtPos(x, 10) = 0 (len=0) +// LSBsAtPos(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBsAtPos(x *BitArray, pos uint8) *BitArray { + if pos == 0 { + return b.Set(x) + } + + if pos > x.Len() { + return b.clear() + } + + return b.LSBs(x, x.Len()-pos) +} + // Checks if the current bit array share the same most significant bits with another, where the length of // the check is determined by the shorter array. Returns true if either array has // length 0, or if the first min(b.len, x.len) MSBs are identical. diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 4c57794b0..5584c1174 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1170,3 +1170,115 @@ func TestSetFeltValidation(t *testing.T) { }) } } + +func TestLSBsAtPos(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBsAtPos(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBsAtPos() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/trie/proof.go b/core/trie/proof.go index d844717ae..24fc577bb 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBs(key, key.Len()-curPos).EqualMSBs(edgePath) + return new(BitArray).LSBsAtPos(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. diff --git a/core/trie/trie.go b/core/trie/trie.go index 0d809debc..3fd4c8b57 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -112,7 +112,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBs(key, key.Len()-parentKey.Len()-1) + pathKey.LSBsAtPos(key, parentKey.Len()+1) return pathKey } From b53c94966ae08f2bc11b2b1cb2709d7774710ea1 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 26 Dec 2024 19:36:13 +0800 Subject: [PATCH 27/45] Add BitSet() --- core/trie/bitarray.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index d56e7cb47..0d820bf88 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -280,13 +280,22 @@ func (b *BitArray) Equal(x *BitArray) bool { } // Returns true if bit n-th is set, where n = 0 is LSB. -// The n must be <= 255. func (b *BitArray) IsBitSet(n uint8) bool { + return b.BitSet(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSet(n uint8) uint8 { if n >= b.len { - return false + return 0 } - return (b.words[n/64] & (1 << (n % 64))) != 0 + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 } // Serialises the BitArray into a bytes buffer in the following format: From 144914dd5e249e9815879688320f2d632ddec1da Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 27 Dec 2024 12:29:40 +0800 Subject: [PATCH 28/45] Add BitSetAtMSB() --- core/trie/bitarray.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 0d820bf88..b85de665a 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -298,6 +298,11 @@ func (b *BitArray) BitSet(n uint8) uint8 { return 0 } +// Returns the bit value at the most significant bit +func (b *BitArray) BitSetAtMSB() uint8 { + return b.BitSet(b.Len() - 1) +} + // Serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order From 05dc052946f3ccb7defa77c44f60329e0e16dc78 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 27 Dec 2024 13:55:17 +0800 Subject: [PATCH 29/45] add IsEmpty() --- core/trie/bitarray.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index b85de665a..41afa097c 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -303,6 +303,10 @@ func (b *BitArray) BitSetAtMSB() uint8 { return b.BitSet(b.Len() - 1) } +func (b *BitArray) IsEmpty() bool { + return b.len == 0 +} + // Serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order From 83d88f1f1d06aed64fcaa54fa38a349319b6e425 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 27 Dec 2024 19:22:59 +0800 Subject: [PATCH 30/45] Add Lsh, Or, Append --- core/trie/bitarray.go | 96 +++++++++++- core/trie/bitarray_test.go | 294 +++++++++++++++++++++++++++++++++++++ 2 files changed, 389 insertions(+), 1 deletion(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 41afa097c..035c89dce 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,7 +11,10 @@ import ( "github.com/NethermindEth/juno/core/felt" ) -const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) var emptyBitArray = new(BitArray) @@ -251,6 +254,85 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { b.words[3] >>= n } + b.truncateToLength() + return b +} + +// Lsh sets the bit array to x << n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { + b.Set(x) + + if x.len == 0 || n == 0 { + return b + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n == 0: + return b + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +// Sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + + // First copy x + b.Set(x) + + // Then shift left by y's length and OR with y + return b.Lsh(b, y.len).Or(b, y) +} + +// Sets the bit array to x | y and returns the bit array. +func (b *BitArray) Or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len return b } @@ -443,6 +525,18 @@ func (b *BitArray) rsh192(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] } +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + func (b *BitArray) clear() *BitArray { b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 5584c1174..973eb3ea7 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -231,6 +231,300 @@ func TestRsh(t *testing.T) { } } +func TestLsh(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 5, + want: emptyBitArray, + }, + { + name: "shift by 0", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 4, + want: &BitArray{ + len: 8, + words: [4]uint64{0xF0, 0, 0, 0}, // 11110000 + }, + }, + { + name: "shift across word boundary", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 62, + want: &BitArray{ + len: 66, + words: [4]uint64{0xC000000000000000, 0x3, 0, 0}, + }, + }, + { + name: "shift by 64 (full word)", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 64, + want: &BitArray{ + len: 72, + words: [4]uint64{0, 0xFF, 0, 0}, + }, + }, + { + name: "shift by 128", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 128, + want: &BitArray{ + len: 136, + words: [4]uint64{0, 0, 0xFF, 0}, + }, + }, + { + name: "shift by 192", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 192, + want: &BitArray{ + len: 200, + words: [4]uint64{0, 0, 0, 0xFF}, + }, + }, + { + name: "shift causing length overflow", + x: &BitArray{ + len: 200, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{ + 0xF000000000000000, + 0xF, + 0, + 0, + }, + }, + }, + { + name: "shift sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + n: 4, + want: &BitArray{ + len: 12, + words: [4]uint64{0xAA0, 0, 0, 0}, // 101010100000 + }, + }, + { + name: "shift partial word across boundary", + x: &BitArray{ + len: 100, + words: [4]uint64{0xFF, 0xFF, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 160, + words: [4]uint64{ + 0xF000000000000000, + 0xF00000000000000F, + 0xF, + 0, + }, + }, + }, + { + name: "near maximum length shift", + x: &BitArray{ + len: 251, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 4, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{0xFF0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Lsh(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("Lsh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "both empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "first array empty", + x: emptyBitArray, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "second array empty", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: emptyBitArray, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "different lengths within word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 2, + words: [4]uint64{0x3, 0, 0, 0}, // 11 + }, + want: &BitArray{ + len: 6, + words: [4]uint64{0x3F, 0, 0, 0}, // 111111 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 62, + words: [4]uint64{0x3FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 66, + words: [4]uint64{maxUint64, 0x3, 0, 0}, + }, + }, + { + name: "across multiple words", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + y: &BitArray{ + len: 8, + words: [4]uint64{0x55, 0, 0, 0}, // 01010101 + }, + want: &BitArray{ + len: 16, + words: [4]uint64{0xAA55, 0, 0, 0}, // 1010101001010101 + }, + }, + { + name: "result exactly at length limit", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + want: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Append(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("Append() = %v, want %v", got, tt.want) + } + }) + } +} + func TestEqualMSBs(t *testing.T) { tests := []struct { name string From 28905b3e5f6085fe5c45d424e06a31015beee508 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 27 Dec 2024 20:08:56 +0800 Subject: [PATCH 31/45] fix rebase --- core/trie/node.go | 4 ++-- core/trie/proof.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/trie/node.go b/core/trie/node.go index ec20bd489..51d5b7678 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -20,7 +20,7 @@ type Node struct { } // Hash calculates the hash of a [Node] -func (n *Node) Hash(path *Key, hashFn crypto.HashFn) *felt.Felt { +func (n *Node) Hash(path *BitArray, hashFn crypto.HashFn) *felt.Felt { if path.Len() == 0 { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -35,7 +35,7 @@ func (n *Node) Hash(path *Key, hashFn crypto.HashFn) *felt.Felt { } // Hash calculates the hash of a [Node] -func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFn crypto.HashFn) *felt.Felt { +func (n *Node) HashFromParent(parentKey, nodeKey *BitArray, hashFn crypto.HashFn) *felt.Felt { path := path(nodeKey, parentKey) return n.Hash(&path, hashFn) } diff --git a/core/trie/proof.go b/core/trie/proof.go index 24fc577bb..73f32b852 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -44,8 +44,8 @@ type Edge struct { } func (e *Edge) Hash(hash crypto.HashFn) *felt.Felt { - length := make([]byte, len(e.Path.bitset)) - length[len(e.Path.bitset)-1] = e.Path.len + var length [32]byte + length[31] = e.Path.len pathFelt := e.Path.Felt() lengthFelt := new(felt.Felt).SetBytes(length[:]) // TODO: no need to return reference, just return value to avoid heap allocation @@ -139,7 +139,7 @@ func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSe // - The path bits don't match the key bits // - The proof ends before processing all key bits func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) (*felt.Felt, error) { - key := FeltToKey(globalTrieHeight, keyFelt) + keyBits := new(BitArray).SetFelt(globalTrieHeight, keyFelt) expectedHash := root var curPos uint8 From f05d2d1008020c1f1147591fef1f36b84ee50019 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 17:24:24 +0800 Subject: [PATCH 32/45] Add BitSetFromMSB() --- core/trie/bitarray.go | 24 +++++++++++++++++------- core/trie/bitarray_test.go | 6 +++--- core/trie/proof.go | 2 +- core/trie/trie.go | 2 +- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 035c89dce..d3f99d322 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -59,6 +59,7 @@ func (b *BitArray) Bytes() []byte { } // Sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. // If length >= x.len, the bit array is an exact copy of x. // For example: // @@ -103,15 +104,14 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { return b } -// Returns the least significant bits of `x` with `pos` as the most significant bit. -// `pos` is counted from the most significant bit, starting at 0. +// Returns the least significant bits of `x` with `pos` counted from the most significant bit, starting at 0. // For example: // // x = 11001011 (len=8) -// LSBsAtPos(x, 1) = 1001011 (len=7) -// LSBsAtPos(x, 10) = 0 (len=0) -// LSBsAtPos(x, 0) = 11001011 (len=8, original x) -func (b *BitArray) LSBsAtPos(x *BitArray, pos uint8) *BitArray { +// LSBsFromMSB(x, 1) = 1001011 (len=7) +// LSBsFromMSB(x, 10) = 0 (len=0) +// LSBsFromMSB(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { if pos == 0 { return b.Set(x) } @@ -380,8 +380,18 @@ func (b *BitArray) BitSet(n uint8) uint8 { return 0 } +// Returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSetFromMSB(n uint8) uint8 { + if n >= b.Len() { + return 0 + } + + return b.BitSet(b.Len() - n - 1) +} + // Returns the bit value at the most significant bit -func (b *BitArray) BitSetAtMSB() uint8 { +func (b *BitArray) MSB() uint8 { return b.BitSet(b.Len() - 1) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 973eb3ea7..5fbec5f6e 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1465,7 +1465,7 @@ func TestSetFeltValidation(t *testing.T) { } } -func TestLSBsAtPos(t *testing.T) { +func TestLSBsFromMSB(t *testing.T) { tests := []struct { name string x *BitArray @@ -1569,9 +1569,9 @@ func TestLSBsAtPos(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := new(BitArray).LSBsAtPos(tt.x, tt.pos) + got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) if !got.Equal(tt.want) { - t.Errorf("LSBsAtPos() = %v, want %v", got, tt.want) + t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) } }) } diff --git a/core/trie/proof.go b/core/trie/proof.go index 73f32b852..2656615f9 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBsAtPos(key, curPos).EqualMSBs(edgePath) + return new(BitArray).LSBsFromMSB(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. diff --git a/core/trie/trie.go b/core/trie/trie.go index 3fd4c8b57..8bbbc2cd6 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -112,7 +112,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBsAtPos(key, parentKey.Len()+1) + pathKey.LSBsFromMSB(key, parentKey.Len()+1) return pathKey } From 8393471318bc67556a6e74d823fddf7a1402a662 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 17:47:16 +0800 Subject: [PATCH 33/45] Reverse methods between LSBs and MSBs --- core/trie/bitarray.go | 54 +++++---- core/trie/bitarray_test.go | 232 ++++++++++++++++++------------------- core/trie/proof.go | 9 +- core/trie/trie.go | 2 +- 4 files changed, 153 insertions(+), 144 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index d3f99d322..1dad274e1 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -64,12 +64,12 @@ func (b *BitArray) Bytes() []byte { // For example: // // x = 11001011 (len=8) -// LSBs(x, 4) = 1011 (len=4) -// LSBs(x, 10) = 11001011 (len=8, original x) -// LSBs(x, 0) = 0 (len=0) +// LSBsFromLSB(x, 4) = 1011 (len=4) +// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) +// LSBsFromLSB(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { +func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { if n >= x.len { return b.Set(x) } @@ -108,10 +108,10 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { // For example: // // x = 11001011 (len=8) -// LSBsFromMSB(x, 1) = 1001011 (len=7) -// LSBsFromMSB(x, 10) = 0 (len=0) -// LSBsFromMSB(x, 0) = 11001011 (len=8, original x) -func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { +// LSBs(x, 1) = 1001011 (len=7) +// LSBs(x, 10) = 0 (len=0) +// LSBs(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBs(x *BitArray, pos uint8) *BitArray { if pos == 0 { return b.Set(x) } @@ -120,7 +120,7 @@ func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { return b.clear() } - return b.LSBs(x, x.Len()-pos) + return b.LSBsFromLSB(x, x.Len()-pos) } // Checks if the current bit array share the same most significant bits with another, where the length of @@ -361,38 +361,48 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } -// Returns true if bit n-th is set, where n = 0 is LSB. +// Returns true if bit n-th is set, where n = 0 is MSB. func (b *BitArray) IsBitSet(n uint8) bool { return b.BitSet(n) == 1 } -// Returns the bit value at position n, where n = 0 is LSB. +// Returns true if bit n-th is set, where n = 0 is LSB. +func (b *BitArray) IsBitSetFromLSB(n uint8) bool { + return b.BitSetFromLSB(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is MSB. // If n is out of bounds, returns 0. func (b *BitArray) BitSet(n uint8) uint8 { - if n >= b.len { + if n >= b.Len() { return 0 } - if (b.words[n/64] & (1 << (n % 64))) != 0 { - return 1 - } - - return 0 + return b.BitSetFromLSB(b.Len() - n - 1) } -// Returns the bit value at position n, where n = 0 is MSB. +// Returns the bit value at position n, where n = 0 is LSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSetFromMSB(n uint8) uint8 { - if n >= b.Len() { +func (b *BitArray) BitSetFromLSB(n uint8) uint8 { + if n >= b.len { return 0 } - return b.BitSet(b.Len() - n - 1) + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 } // Returns the bit value at the most significant bit func (b *BitArray) MSB() uint8 { - return b.BitSet(b.Len() - 1) + return b.BitSet(0) +} + +// Returns the bit value at the least significant bit +func (b *BitArray) LSB() uint8 { + return b.BitSetFromLSB(0) } func (b *BitArray) IsEmpty() bool { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 5fbec5f6e..cf1a055fe 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -644,6 +644,118 @@ func TestEqualMSBs(t *testing.T) { } func TestLSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBs(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBsFromLSB(t *testing.T) { tests := []struct { name string initial BitArray @@ -798,7 +910,7 @@ func TestLSBs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(BitArray).LSBs(&tt.initial, tt.length) + result := new(BitArray).LSBsFromLSB(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -1222,7 +1334,7 @@ func TestCommonPrefix(t *testing.T) { } } -func TestIsBitSet(t *testing.T) { +func TestIsBitSetFromLSB(t *testing.T) { tests := []struct { name string ba BitArray @@ -1323,9 +1435,9 @@ func TestIsBitSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.ba.IsBitSet(tt.pos) + got := tt.ba.IsBitSetFromLSB(tt.pos) if got != tt.want { - t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + t.Errorf("IsBitSetFromLSB(%d) = %v, want %v", tt.pos, got, tt.want) } }) } @@ -1464,115 +1576,3 @@ func TestSetFeltValidation(t *testing.T) { }) } } - -func TestLSBsFromMSB(t *testing.T) { - tests := []struct { - name string - x *BitArray - pos uint8 - want *BitArray - }{ - { - name: "zero position", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 0, - want: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - }, - { - name: "position beyond length", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 65, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "get last 4 bits", - x: &BitArray{ - len: 8, - words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 - }, - pos: 4, - want: &BitArray{ - len: 4, - words: [4]uint64{0x0F, 0, 0, 0}, // 1111 - }, - }, - { - name: "get bits across word boundary", - x: &BitArray{ - len: 128, - words: [4]uint64{maxUint64, maxUint64, 0, 0}, - }, - pos: 64, - want: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - }, - { - name: "get bits from max length array", - x: &BitArray{ - len: 251, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, - }, - pos: 200, - want: &BitArray{ - len: 51, - words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "empty array", - x: emptyBitArray, - pos: 1, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "sparse bits", - x: &BitArray{ - len: 16, - words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 - }, - pos: 8, - want: &BitArray{ - len: 8, - words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 - }, - }, - { - name: "position equals length", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 64, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) - if !got.Equal(tt.want) { - t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/core/trie/proof.go b/core/trie/proof.go index 2656615f9..d6f5a2751 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -161,7 +161,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if keyBits.IsBitSet(keyBits.Len() - curPos - 1) { + if keyBits.IsBitSet(curPos + 1) { expectedHash = node.RightHash } curPos++ @@ -489,7 +489,7 @@ func handleBinaryNode( // Calculate next position and determine to take left or right path nextPos := curPos + 1 - isRightPath := key.IsBitSet(key.Len() - nextPos) + isRightPath := key.IsBitSet(nextPos) nextHash := binary.LeftHash if isRightPath { nextHash = binary.RightHash @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBsFromMSB(key, curPos).EqualMSBs(edgePath) + return new(BitArray).LSBs(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. @@ -602,8 +602,7 @@ func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { // If we're taking a left path and there's a right sibling, // then there are elements with larger values - bitPos := key.Len() - cur.Len() - 1 - isLeft := !key.IsBitSet(bitPos) + isLeft := !key.IsBitSet(cur.Len() + 1) if isLeft && sn.node.RightHash != nil { return true } diff --git a/core/trie/trie.go b/core/trie/trie.go index 8bbbc2cd6..a52353841 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -112,7 +112,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBsFromMSB(key, parentKey.Len()+1) + pathKey.LSBs(key, parentKey.Len()+1) return pathKey } From ef68aeb179dc5ced5fdedbd3f23d6415b324a251 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 17:50:33 +0800 Subject: [PATCH 34/45] Revert "Reverse methods between LSBs and MSBs" This reverts commit a06a8cea3269860cec14ac0a4a1f2afef3afdcfd. --- core/trie/bitarray.go | 54 ++++----- core/trie/bitarray_test.go | 232 ++++++++++++++++++------------------- core/trie/proof.go | 9 +- core/trie/trie.go | 2 +- 4 files changed, 144 insertions(+), 153 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 1dad274e1..d3f99d322 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -64,12 +64,12 @@ func (b *BitArray) Bytes() []byte { // For example: // // x = 11001011 (len=8) -// LSBsFromLSB(x, 4) = 1011 (len=4) -// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) -// LSBsFromLSB(x, 0) = 0 (len=0) +// LSBs(x, 4) = 1011 (len=4) +// LSBs(x, 10) = 11001011 (len=8, original x) +// LSBs(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { if n >= x.len { return b.Set(x) } @@ -108,10 +108,10 @@ func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { // For example: // // x = 11001011 (len=8) -// LSBs(x, 1) = 1001011 (len=7) -// LSBs(x, 10) = 0 (len=0) -// LSBs(x, 0) = 11001011 (len=8, original x) -func (b *BitArray) LSBs(x *BitArray, pos uint8) *BitArray { +// LSBsFromMSB(x, 1) = 1001011 (len=7) +// LSBsFromMSB(x, 10) = 0 (len=0) +// LSBsFromMSB(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { if pos == 0 { return b.Set(x) } @@ -120,7 +120,7 @@ func (b *BitArray) LSBs(x *BitArray, pos uint8) *BitArray { return b.clear() } - return b.LSBsFromLSB(x, x.Len()-pos) + return b.LSBs(x, x.Len()-pos) } // Checks if the current bit array share the same most significant bits with another, where the length of @@ -361,29 +361,14 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } -// Returns true if bit n-th is set, where n = 0 is MSB. +// Returns true if bit n-th is set, where n = 0 is LSB. func (b *BitArray) IsBitSet(n uint8) bool { return b.BitSet(n) == 1 } -// Returns true if bit n-th is set, where n = 0 is LSB. -func (b *BitArray) IsBitSetFromLSB(n uint8) bool { - return b.BitSetFromLSB(n) == 1 -} - -// Returns the bit value at position n, where n = 0 is MSB. -// If n is out of bounds, returns 0. -func (b *BitArray) BitSet(n uint8) uint8 { - if n >= b.Len() { - return 0 - } - - return b.BitSetFromLSB(b.Len() - n - 1) -} - // Returns the bit value at position n, where n = 0 is LSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSetFromLSB(n uint8) uint8 { +func (b *BitArray) BitSet(n uint8) uint8 { if n >= b.len { return 0 } @@ -395,14 +380,19 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 { return 0 } -// Returns the bit value at the most significant bit -func (b *BitArray) MSB() uint8 { - return b.BitSet(0) +// Returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSetFromMSB(n uint8) uint8 { + if n >= b.Len() { + return 0 + } + + return b.BitSet(b.Len() - n - 1) } -// Returns the bit value at the least significant bit -func (b *BitArray) LSB() uint8 { - return b.BitSetFromLSB(0) +// Returns the bit value at the most significant bit +func (b *BitArray) MSB() uint8 { + return b.BitSet(b.Len() - 1) } func (b *BitArray) IsEmpty() bool { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index cf1a055fe..5fbec5f6e 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -644,118 +644,6 @@ func TestEqualMSBs(t *testing.T) { } func TestLSBs(t *testing.T) { - tests := []struct { - name string - x *BitArray - pos uint8 - want *BitArray - }{ - { - name: "zero position", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 0, - want: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - }, - { - name: "position beyond length", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 65, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "get last 4 bits", - x: &BitArray{ - len: 8, - words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 - }, - pos: 4, - want: &BitArray{ - len: 4, - words: [4]uint64{0x0F, 0, 0, 0}, // 1111 - }, - }, - { - name: "get bits across word boundary", - x: &BitArray{ - len: 128, - words: [4]uint64{maxUint64, maxUint64, 0, 0}, - }, - pos: 64, - want: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - }, - { - name: "get bits from max length array", - x: &BitArray{ - len: 251, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, - }, - pos: 200, - want: &BitArray{ - len: 51, - words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "empty array", - x: emptyBitArray, - pos: 1, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "sparse bits", - x: &BitArray{ - len: 16, - words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 - }, - pos: 8, - want: &BitArray{ - len: 8, - words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 - }, - }, - { - name: "position equals length", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 64, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := new(BitArray).LSBs(tt.x, tt.pos) - if !got.Equal(tt.want) { - t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestLSBsFromLSB(t *testing.T) { tests := []struct { name string initial BitArray @@ -910,7 +798,7 @@ func TestLSBsFromLSB(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(BitArray).LSBsFromLSB(&tt.initial, tt.length) + result := new(BitArray).LSBs(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -1334,7 +1222,7 @@ func TestCommonPrefix(t *testing.T) { } } -func TestIsBitSetFromLSB(t *testing.T) { +func TestIsBitSet(t *testing.T) { tests := []struct { name string ba BitArray @@ -1435,9 +1323,9 @@ func TestIsBitSetFromLSB(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.ba.IsBitSetFromLSB(tt.pos) + got := tt.ba.IsBitSet(tt.pos) if got != tt.want { - t.Errorf("IsBitSetFromLSB(%d) = %v, want %v", tt.pos, got, tt.want) + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) } }) } @@ -1576,3 +1464,115 @@ func TestSetFeltValidation(t *testing.T) { }) } } + +func TestLSBsFromMSB(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/trie/proof.go b/core/trie/proof.go index d6f5a2751..2656615f9 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -161,7 +161,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if keyBits.IsBitSet(curPos + 1) { + if keyBits.IsBitSet(keyBits.Len() - curPos - 1) { expectedHash = node.RightHash } curPos++ @@ -489,7 +489,7 @@ func handleBinaryNode( // Calculate next position and determine to take left or right path nextPos := curPos + 1 - isRightPath := key.IsBitSet(nextPos) + isRightPath := key.IsBitSet(key.Len() - nextPos) nextHash := binary.LeftHash if isRightPath { nextHash = binary.RightHash @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBs(key, curPos).EqualMSBs(edgePath) + return new(BitArray).LSBsFromMSB(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. @@ -602,7 +602,8 @@ func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { // If we're taking a left path and there's a right sibling, // then there are elements with larger values - isLeft := !key.IsBitSet(cur.Len() + 1) + bitPos := key.Len() - cur.Len() - 1 + isLeft := !key.IsBitSet(bitPos) if isLeft && sn.node.RightHash != nil { return true } diff --git a/core/trie/trie.go b/core/trie/trie.go index a52353841..8bbbc2cd6 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -112,7 +112,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBs(key, parentKey.Len()+1) + pathKey.LSBsFromMSB(key, parentKey.Len()+1) return pathKey } From 07bef4b3def48333dacd0eb586620eeda06f8843 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 17:54:50 +0800 Subject: [PATCH 35/45] add direction --- core/trie/bitarray.go | 18 +++++++++++------- core/trie/bitarray_test.go | 10 +++++----- core/trie/proof.go | 6 +++--- core/trie/trie.go | 6 +++--- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index d3f99d322..4d41fa012 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -69,7 +69,7 @@ func (b *BitArray) Bytes() []byte { // LSBs(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { +func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { if n >= x.len { return b.Set(x) } @@ -120,7 +120,7 @@ func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { return b.clear() } - return b.LSBs(x, x.Len()-pos) + return b.LSBsFromLSB(x, x.Len()-pos) } // Checks if the current bit array share the same most significant bits with another, where the length of @@ -362,13 +362,13 @@ func (b *BitArray) Equal(x *BitArray) bool { } // Returns true if bit n-th is set, where n = 0 is LSB. -func (b *BitArray) IsBitSet(n uint8) bool { - return b.BitSet(n) == 1 +func (b *BitArray) IsBitSetFromLSB(n uint8) bool { + return b.BitSetFromLSB(n) == 1 } // Returns the bit value at position n, where n = 0 is LSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSet(n uint8) uint8 { +func (b *BitArray) BitSetFromLSB(n uint8) uint8 { if n >= b.len { return 0 } @@ -387,12 +387,16 @@ func (b *BitArray) BitSetFromMSB(n uint8) uint8 { return 0 } - return b.BitSet(b.Len() - n - 1) + return b.BitSetFromLSB(b.Len() - n - 1) } // Returns the bit value at the most significant bit func (b *BitArray) MSB() uint8 { - return b.BitSet(b.Len() - 1) + return b.BitSetFromMSB(0) +} + +func (b *BitArray) LSB() uint8 { + return b.BitSetFromLSB(0) } func (b *BitArray) IsEmpty() bool { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 5fbec5f6e..8a3d35f34 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -643,7 +643,7 @@ func TestEqualMSBs(t *testing.T) { } } -func TestLSBs(t *testing.T) { +func TestLSBsFromLSB(t *testing.T) { tests := []struct { name string initial BitArray @@ -798,7 +798,7 @@ func TestLSBs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(BitArray).LSBs(&tt.initial, tt.length) + result := new(BitArray).LSBsFromLSB(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -1222,7 +1222,7 @@ func TestCommonPrefix(t *testing.T) { } } -func TestIsBitSet(t *testing.T) { +func TestIsBitSetFromLSB(t *testing.T) { tests := []struct { name string ba BitArray @@ -1323,9 +1323,9 @@ func TestIsBitSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.ba.IsBitSet(tt.pos) + got := tt.ba.IsBitSetFromLSB(tt.pos) if got != tt.want { - t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + t.Errorf("IsBitSetFromLSB(%d) = %v, want %v", tt.pos, got, tt.want) } }) } diff --git a/core/trie/proof.go b/core/trie/proof.go index 2656615f9..ad615cd6d 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -161,7 +161,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if keyBits.IsBitSet(keyBits.Len() - curPos - 1) { + if keyBits.IsBitSetFromLSB(keyBits.Len() - curPos - 1) { expectedHash = node.RightHash } curPos++ @@ -489,7 +489,7 @@ func handleBinaryNode( // Calculate next position and determine to take left or right path nextPos := curPos + 1 - isRightPath := key.IsBitSet(key.Len() - nextPos) + isRightPath := key.IsBitSetFromLSB(key.Len() - nextPos) nextHash := binary.LeftHash if isRightPath { nextHash = binary.RightHash @@ -603,7 +603,7 @@ func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { // If we're taking a left path and there's a right sibling, // then there are elements with larger values bitPos := key.Len() - cur.Len() - 1 - isLeft := !key.IsBitSet(bitPos) + isLeft := !key.IsBitSetFromLSB(bitPos) if isLeft && sn.node.RightHash != nil { return true } diff --git a/core/trie/trie.go b/core/trie/trie.go index 8bbbc2cd6..d7ea67271 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -244,7 +244,7 @@ func (t *Trie) nodesFromRoot(key *BitArray) ([]StorageNode, error) { return nodes, nil } - if key.IsBitSet(key.Len() - cur.Len() - 1) { + if key.IsBitSetFromLSB(key.Len() - cur.Len() - 1) { cur = node.Right } else { cur = node.Left @@ -346,7 +346,7 @@ func (t *Trie) insertOrUpdateValue( if err != nil { return err } - if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSetFromLSB(nodeKey.Len() - commonKey.Len() - 1) { newParent.Right = nodeKey newParent.RightHash = node.Hash(nodeKey, t.hash) } else { @@ -358,7 +358,7 @@ func (t *Trie) insertOrUpdateValue( } t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSetFromLSB(nodeKey.Len() - commonKey.Len() - 1) { newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { From 743c5813f116140990ec0d5e08bd065f24b7d308 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:01:36 +0800 Subject: [PATCH 36/45] fix tests --- core/trie/bitarray_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 8a3d35f34..86560d392 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -204,7 +204,7 @@ func TestRsh(t *testing.T) { shiftBy: 128, expected: &BitArray{ len: 123, - words: [4]uint64{maxUint64, maxUint64, 0, 0}, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFF, 0, 0}, }, }, { @@ -216,7 +216,7 @@ func TestRsh(t *testing.T) { shiftBy: 192, expected: &BitArray{ len: 59, - words: [4]uint64{maxUint64, 0, 0, 0}, + words: [4]uint64{0x7FFFFFFFFFFFFFF, 0, 0, 0}, }, }, } From c5de9dfde99056ac12351713cf7232c51d744b9f Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:26:22 +0800 Subject: [PATCH 37/45] add more tests --- core/trie/bitarray.go | 4 + core/trie/bitarray_test.go | 361 +++++++++++++++++++++++++------------ 2 files changed, 253 insertions(+), 112 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 4d41fa012..647333107 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -380,6 +380,10 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 { return 0 } +func (b *BitArray) IsBitSetFromMSB(n uint8) bool { + return b.BitSetFromMSB(n) == 1 +} + // Returns the bit value at position n, where n = 0 is MSB. // If n is out of bounds, returns 0. func (b *BitArray) BitSetFromMSB(n uint8) uint8 { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 86560d392..a910b4a31 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -643,6 +643,118 @@ func TestEqualMSBs(t *testing.T) { } } +func TestLSBsFromMSB(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) + } + }) + } +} + func TestLSBsFromLSB(t *testing.T) { tests := []struct { name string @@ -1301,6 +1413,15 @@ func TestIsBitSetFromLSB(t *testing.T) { pos: 251, want: false, // position 251 is beyond the highest valid bit (250) }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 250, + want: true, + }, { name: "highest valid bit (255)", ba: BitArray{ @@ -1331,6 +1452,134 @@ func TestIsBitSetFromLSB(t *testing.T) { } } +func TestIsBitSetFromMSB(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit (MSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x80, 0, 0, 0}, // 10000000 + }, + pos: 0, + want: true, + }, + { + name: "last bit (LSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x01, 0, 0, 0}, // 00000001 + }, + pos: 7, + want: true, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 0, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 1, + want: false, + }, + { + name: "position beyond length", + ba: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + pos: 8, + want: false, + }, + { + name: "bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 0, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 99, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSetFromMSB(tt.pos) + if got != tt.want { + t.Errorf("IsBitSetFromMSB(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + +func TestDebug(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 63, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSetFromMSB(tt.pos) + if got != tt.want { + t.Errorf("IsBitSetFromMSB(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + func TestFeltConversion(t *testing.T) { tests := []struct { name string @@ -1464,115 +1713,3 @@ func TestSetFeltValidation(t *testing.T) { }) } } - -func TestLSBsFromMSB(t *testing.T) { - tests := []struct { - name string - x *BitArray - pos uint8 - want *BitArray - }{ - { - name: "zero position", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 0, - want: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - }, - { - name: "position beyond length", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 65, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "get last 4 bits", - x: &BitArray{ - len: 8, - words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 - }, - pos: 4, - want: &BitArray{ - len: 4, - words: [4]uint64{0x0F, 0, 0, 0}, // 1111 - }, - }, - { - name: "get bits across word boundary", - x: &BitArray{ - len: 128, - words: [4]uint64{maxUint64, maxUint64, 0, 0}, - }, - pos: 64, - want: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - }, - { - name: "get bits from max length array", - x: &BitArray{ - len: 251, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, - }, - pos: 200, - want: &BitArray{ - len: 51, - words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "empty array", - x: emptyBitArray, - pos: 1, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "sparse bits", - x: &BitArray{ - len: 16, - words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 - }, - pos: 8, - want: &BitArray{ - len: 8, - words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 - }, - }, - { - name: "position equals length", - x: &BitArray{ - len: 64, - words: [4]uint64{maxUint64, 0, 0, 0}, - }, - pos: 64, - want: &BitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) - if !got.Equal(tt.want) { - t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) - } - }) - } -} From fdd8ff3f4da92cf160c4c7944bcb4d9afaf2f64d Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:31:48 +0800 Subject: [PATCH 38/45] use IsBitSetFromMSB --- core/trie/proof.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/core/trie/proof.go b/core/trie/proof.go index ad615cd6d..efe2f21f3 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -161,7 +161,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if keyBits.IsBitSetFromLSB(keyBits.Len() - curPos - 1) { + if keyBits.IsBitSetFromMSB(curPos) { expectedHash = node.RightHash } curPos++ @@ -489,7 +489,7 @@ func handleBinaryNode( // Calculate next position and determine to take left or right path nextPos := curPos + 1 - isRightPath := key.IsBitSetFromLSB(key.Len() - nextPos) + isRightPath := key.IsBitSetFromMSB(curPos) nextHash := binary.LeftHash if isRightPath { nextHash = binary.RightHash @@ -602,8 +602,7 @@ func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { // If we're taking a left path and there's a right sibling, // then there are elements with larger values - bitPos := key.Len() - cur.Len() - 1 - isLeft := !key.IsBitSetFromLSB(bitPos) + isLeft := !key.IsBitSetFromMSB(cur.Len()) if isLeft && sn.node.RightHash != nil { return true } From da85dd288f8c2dbb2f845a626e41ea8854d6f772 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:34:54 +0800 Subject: [PATCH 39/45] use IsBitSetFromMSB --- core/trie/trie.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/trie/trie.go b/core/trie/trie.go index d7ea67271..da30fe01a 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -244,7 +244,7 @@ func (t *Trie) nodesFromRoot(key *BitArray) ([]StorageNode, error) { return nodes, nil } - if key.IsBitSetFromLSB(key.Len() - cur.Len() - 1) { + if key.IsBitSetFromMSB(cur.Len()) { cur = node.Right } else { cur = node.Left @@ -346,7 +346,7 @@ func (t *Trie) insertOrUpdateValue( if err != nil { return err } - if nodeKey.IsBitSetFromLSB(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSetFromMSB(commonKey.Len()) { newParent.Right = nodeKey newParent.RightHash = node.Hash(nodeKey, t.hash) } else { @@ -358,7 +358,7 @@ func (t *Trie) insertOrUpdateValue( } t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - if nodeKey.IsBitSetFromLSB(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSetFromMSB(commonKey.Len()) { newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { From d6d456d8bb2ebb97aedb8a3d3009cb111bc6458e Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:43:52 +0800 Subject: [PATCH 40/45] Replace LSBsFromMSB to LSBs --- core/trie/bitarray.go | 8 ++++---- core/trie/bitarray_test.go | 6 +++--- core/trie/proof.go | 2 +- core/trie/trie.go | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 647333107..bfe057a0f 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -108,10 +108,10 @@ func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { // For example: // // x = 11001011 (len=8) -// LSBsFromMSB(x, 1) = 1001011 (len=7) -// LSBsFromMSB(x, 10) = 0 (len=0) -// LSBsFromMSB(x, 0) = 11001011 (len=8, original x) -func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { +// LSBs(x, 1) = 1001011 (len=7) +// LSBs(x, 10) = 0 (len=0) +// LSBs(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBs(x *BitArray, pos uint8) *BitArray { if pos == 0 { return b.Set(x) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index a910b4a31..adafdc028 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -643,7 +643,7 @@ func TestEqualMSBs(t *testing.T) { } } -func TestLSBsFromMSB(t *testing.T) { +func TestLSBs(t *testing.T) { tests := []struct { name string x *BitArray @@ -747,9 +747,9 @@ func TestLSBsFromMSB(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) + got := new(BitArray).LSBs(tt.x, tt.pos) if !got.Equal(tt.want) { - t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) + t.Errorf("LSBs() = %v, want %v", got, tt.want) } }) } diff --git a/core/trie/proof.go b/core/trie/proof.go index efe2f21f3..b816d8a2a 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBsFromMSB(key, curPos).EqualMSBs(edgePath) + return new(BitArray).LSBs(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. diff --git a/core/trie/trie.go b/core/trie/trie.go index da30fe01a..83cf9ec65 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -112,7 +112,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBsFromMSB(key, parentKey.Len()+1) + pathKey.LSBs(key, parentKey.Len()+1) return pathKey } From e39ef5724f657c882f25905a232bdbfd73af3284 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:46:36 +0800 Subject: [PATCH 41/45] change IsBitSetFromMSB to IsBitSet --- core/trie/bitarray.go | 8 ++++---- core/trie/bitarray_test.go | 10 +++++----- core/trie/proof.go | 6 +++--- core/trie/trie.go | 6 +++--- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bfe057a0f..9ec33bbab 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -380,13 +380,13 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 { return 0 } -func (b *BitArray) IsBitSetFromMSB(n uint8) bool { - return b.BitSetFromMSB(n) == 1 +func (b *BitArray) IsBitSet(n uint8) bool { + return b.BitSet(n) == 1 } // Returns the bit value at position n, where n = 0 is MSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSetFromMSB(n uint8) uint8 { +func (b *BitArray) BitSet(n uint8) uint8 { if n >= b.Len() { return 0 } @@ -396,7 +396,7 @@ func (b *BitArray) BitSetFromMSB(n uint8) uint8 { // Returns the bit value at the most significant bit func (b *BitArray) MSB() uint8 { - return b.BitSetFromMSB(0) + return b.BitSet(0) } func (b *BitArray) LSB() uint8 { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index adafdc028..7cb59493a 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1452,7 +1452,7 @@ func TestIsBitSetFromLSB(t *testing.T) { } } -func TestIsBitSetFromMSB(t *testing.T) { +func TestIsBitSet(t *testing.T) { tests := []struct { name string ba BitArray @@ -1544,9 +1544,9 @@ func TestIsBitSetFromMSB(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.ba.IsBitSetFromMSB(tt.pos) + got := tt.ba.IsBitSet(tt.pos) if got != tt.want { - t.Errorf("IsBitSetFromMSB(%d) = %v, want %v", tt.pos, got, tt.want) + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) } }) } @@ -1572,9 +1572,9 @@ func TestDebug(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.ba.IsBitSetFromMSB(tt.pos) + got := tt.ba.IsBitSet(tt.pos) if got != tt.want { - t.Errorf("IsBitSetFromMSB(%d) = %v, want %v", tt.pos, got, tt.want) + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) } }) } diff --git a/core/trie/proof.go b/core/trie/proof.go index b816d8a2a..4afed36df 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -161,7 +161,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if keyBits.IsBitSetFromMSB(curPos) { + if keyBits.IsBitSet(curPos) { expectedHash = node.RightHash } curPos++ @@ -489,7 +489,7 @@ func handleBinaryNode( // Calculate next position and determine to take left or right path nextPos := curPos + 1 - isRightPath := key.IsBitSetFromMSB(curPos) + isRightPath := key.IsBitSet(curPos) nextHash := binary.LeftHash if isRightPath { nextHash = binary.RightHash @@ -602,7 +602,7 @@ func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { // If we're taking a left path and there's a right sibling, // then there are elements with larger values - isLeft := !key.IsBitSetFromMSB(cur.Len()) + isLeft := !key.IsBitSet(cur.Len()) if isLeft && sn.node.RightHash != nil { return true } diff --git a/core/trie/trie.go b/core/trie/trie.go index 83cf9ec65..bb2320d80 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -244,7 +244,7 @@ func (t *Trie) nodesFromRoot(key *BitArray) ([]StorageNode, error) { return nodes, nil } - if key.IsBitSetFromMSB(cur.Len()) { + if key.IsBitSet(cur.Len()) { cur = node.Right } else { cur = node.Left @@ -346,7 +346,7 @@ func (t *Trie) insertOrUpdateValue( if err != nil { return err } - if nodeKey.IsBitSetFromMSB(commonKey.Len()) { + if nodeKey.IsBitSet(commonKey.Len()) { newParent.Right = nodeKey newParent.RightHash = node.Hash(nodeKey, t.hash) } else { @@ -358,7 +358,7 @@ func (t *Trie) insertOrUpdateValue( } t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - if nodeKey.IsBitSetFromMSB(commonKey.Len()) { + if nodeKey.IsBitSet(commonKey.Len()) { newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { From 7f479c16ed915ff31f68ac3f6ed619f39040edbe Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 18:49:08 +0800 Subject: [PATCH 42/45] minor chore --- core/trie/bitarray.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 9ec33bbab..1e5de6408 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -64,9 +64,9 @@ func (b *BitArray) Bytes() []byte { // For example: // // x = 11001011 (len=8) -// LSBs(x, 4) = 1011 (len=4) -// LSBs(x, 10) = 11001011 (len=8, original x) -// LSBs(x, 0) = 0 (len=0) +// LSBsFromLSB(x, 4) = 1011 (len=4) +// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) +// LSBsFromLSB(x, 0) = 0 (len=0) // //nolint:mnd func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { @@ -104,23 +104,23 @@ func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { return b } -// Returns the least significant bits of `x` with `pos` counted from the most significant bit, starting at 0. +// Returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. // For example: // // x = 11001011 (len=8) // LSBs(x, 1) = 1001011 (len=7) // LSBs(x, 10) = 0 (len=0) // LSBs(x, 0) = 11001011 (len=8, original x) -func (b *BitArray) LSBs(x *BitArray, pos uint8) *BitArray { - if pos == 0 { +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n == 0 { return b.Set(x) } - if pos > x.Len() { + if n > x.Len() { return b.clear() } - return b.LSBsFromLSB(x, x.Len()-pos) + return b.LSBsFromLSB(x, x.Len()-n) } // Checks if the current bit array share the same most significant bits with another, where the length of From c6c0e85ea4f233579be8b5397d77e8b20f5d3ff8 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 31 Dec 2024 10:27:16 +0800 Subject: [PATCH 43/45] add SetBit() --- core/trie/bitarray.go | 12 ++++++++++++ core/trie/bitarray_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 1e5de6408..e466a4089 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -479,6 +479,18 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { return b } +// Sets the bit array to a single bit. +func (b *BitArray) SetBit(bit bool) *BitArray { + b.len = 1 + if bit { + b.words[0] = 1 + } else { + b.words[0] = 0 + } + b.truncateToLength() + return b +} + // Returns the length of the encoded bit array in bytes. func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 7cb59493a..933b6fdd4 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1713,3 +1713,37 @@ func TestSetFeltValidation(t *testing.T) { }) } } + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + bit bool + want BitArray + }{ + { + name: "set bit false", + bit: false, + want: BitArray{ + len: 1, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "set bit true", + bit: true, + want: BitArray{ + len: 1, + words: [4]uint64{1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBit(tt.bit) + if !got.Equal(&tt.want) { + t.Errorf("SetBit(%v) = %v, want %v", tt.bit, got, tt.want) + } + }) + } +} From c6c8183c2ff138e1be4c577e51331b5e44cf5f18 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 31 Dec 2024 12:55:48 +0800 Subject: [PATCH 44/45] rename methods --- core/trie/bitarray.go | 14 +++++++------- core/trie/bitarray_test.go | 28 ---------------------------- 2 files changed, 7 insertions(+), 35 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index e466a4089..2508d8d90 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -363,12 +363,12 @@ func (b *BitArray) Equal(x *BitArray) bool { // Returns true if bit n-th is set, where n = 0 is LSB. func (b *BitArray) IsBitSetFromLSB(n uint8) bool { - return b.BitSetFromLSB(n) == 1 + return b.BitFromLSB(n) == 1 } // Returns the bit value at position n, where n = 0 is LSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSetFromLSB(n uint8) uint8 { +func (b *BitArray) BitFromLSB(n uint8) uint8 { if n >= b.len { return 0 } @@ -381,26 +381,26 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 { } func (b *BitArray) IsBitSet(n uint8) bool { - return b.BitSet(n) == 1 + return b.Bit(n) == 1 } // Returns the bit value at position n, where n = 0 is MSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSet(n uint8) uint8 { +func (b *BitArray) Bit(n uint8) uint8 { if n >= b.Len() { return 0 } - return b.BitSetFromLSB(b.Len() - n - 1) + return b.BitFromLSB(b.Len() - n - 1) } // Returns the bit value at the most significant bit func (b *BitArray) MSB() uint8 { - return b.BitSet(0) + return b.Bit(0) } func (b *BitArray) LSB() uint8 { - return b.BitSetFromLSB(0) + return b.BitFromLSB(0) } func (b *BitArray) IsEmpty() bool { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 933b6fdd4..8eb084422 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1552,34 +1552,6 @@ func TestIsBitSet(t *testing.T) { } } -func TestDebug(t *testing.T) { - tests := []struct { - name string - ba BitArray - pos uint8 - want bool - }{ - { - name: "bit in second word", - ba: BitArray{ - len: 128, - words: [4]uint64{0, 1, 0, 0}, - }, - pos: 63, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.ba.IsBitSet(tt.pos) - if got != tt.want { - t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) - } - }) - } -} - func TestFeltConversion(t *testing.T) { tests := []struct { name string From 38d930b642344ac4be4fc33cd6d4d1a07a77257d Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 7 Jan 2025 14:35:45 +0800 Subject: [PATCH 45/45] Add Cmp() --- core/trie/bitarray.go | 54 ++++++++++++++++++---- core/trie/bitarray_test.go | 94 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 135 insertions(+), 13 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 2508d8d90..75d1ebb3e 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -158,7 +158,7 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) } -// Sets the bit array to the most significant 'n' bits of x. +// Sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). // If n >= x.len, the bit array is an exact copy of x. // For example: // @@ -480,14 +480,10 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { } // Sets the bit array to a single bit. -func (b *BitArray) SetBit(bit bool) *BitArray { +func (b *BitArray) SetBit(bit uint8) *BitArray { b.len = 1 - if bit { - b.words[0] = 1 - } else { - b.words[0] = 0 - } - b.truncateToLength() + b.words[0] = uint64(bit & 1) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 return b } @@ -503,7 +499,16 @@ func (b *BitArray) Copy() BitArray { return res } +// Returns the encoded string representation of the bit array. +func (b *BitArray) EncodedString() string { + var res []byte + res = append(res, b.len) + res = append(res, b.Bytes()...) + return string(res) +} + // Returns a string representation of the bit array. +// This is typically used for logging or debugging. func (b *BitArray) String() string { return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) } @@ -615,3 +620,36 @@ func findFirstSetBit(b *BitArray) uint8 { // All bits are zero, no set bit found return 0 } + +// Cmp compares two bit arrays lexicographically. +// The comparison is first done by length, then by content if lengths are equal. +// Returns: +// +// -1 if b < x +// 0 if b == x +// 1 if b > x +func (b *BitArray) Cmp(x *BitArray) int { + // First compare lengths + if b.len < x.len { + return -1 + } + if b.len > x.len { + return 1 + } + + // Lengths are equal, compare the actual bits + d0, carry := bits.Sub64(b.words[0], x.words[0], 0) + d1, carry := bits.Sub64(b.words[1], x.words[1], carry) + d2, carry := bits.Sub64(b.words[2], x.words[2], carry) + d3, carry := bits.Sub64(b.words[3], x.words[3], carry) + + if carry == 1 { + return -1 + } + + if d0|d1|d2|d3 == 0 { + return 0 + } + + return 1 +} diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 8eb084422..e3d7c795a 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1689,20 +1689,20 @@ func TestSetFeltValidation(t *testing.T) { func TestSetBit(t *testing.T) { tests := []struct { name string - bit bool + bit uint8 want BitArray }{ { - name: "set bit false", - bit: false, + name: "set bit 0", + bit: 0, want: BitArray{ len: 1, words: [4]uint64{0, 0, 0, 0}, }, }, { - name: "set bit true", - bit: true, + name: "set bit 1", + bit: 1, want: BitArray{ len: 1, words: [4]uint64{1, 0, 0, 0}, @@ -1719,3 +1719,87 @@ func TestSetBit(t *testing.T) { }) } } + +func TestCmp(t *testing.T) { + tests := []struct { + name string + x BitArray + y BitArray + want int + }{ + { + name: "equal empty arrays", + x: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + y: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: 0, + }, + { + name: "equal non-empty arrays", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 0, + }, + { + name: "different lengths - x shorter", + x: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "different lengths - x longer", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, x < y in first word", + x: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "same length, x > y in first word", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, difference in last word", + x: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFFF}}, + y: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFF0}}, + want: 1, + }, + { + name: "same length, sparse bits", + x: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}}, + y: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}}, + want: 1, + }, + { + name: "max length difference", + x: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + y: BitArray{len: 1, words: [4]uint64{0x1, 0, 0, 0}}, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.x.Cmp(&tt.y) + if got != tt.want { + t.Errorf("Cmp() = %v, want %v", got, tt.want) + } + + // Test anti-symmetry: if x.Cmp(y) = z then y.Cmp(x) = -z + gotReverse := tt.y.Cmp(&tt.x) + if gotReverse != -tt.want { + t.Errorf("Reverse Cmp() = %v, want %v", gotReverse, -tt.want) + } + + // Test transitivity with self: x.Cmp(x) should always be 0 + if tt.x.Cmp(&tt.x) != 0 { + t.Error("Self Cmp() != 0") + } + }) + } +}