Skip to content

Commit

Permalink
ensure unused bits are zero when setting bitarray
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Dec 24, 2024
1 parent 1f69b43 commit d50aacf
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions core/trie/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,39 +43,14 @@ func (b *BitArray) Len() uint8 {
}

// Returns the bytes representation of the bit array in big endian format
//
//nolint:mnd
func (b *BitArray) Bytes() []byte {
var res [32]byte

switch {
case b.len == 0:
// all zeros
return res[:]
case b.len >= 192:
// Create mask for top word: keeps only valid bits above 192
// e.g., if len=200, keeps lowest 8 bits (200-192)
mask := maxUint64 >> (256 - uint16(b.len))
binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask)
binary.BigEndian.PutUint64(res[8:16], b.words[2])
binary.BigEndian.PutUint64(res[16:24], b.words[1])
binary.BigEndian.PutUint64(res[24:32], b.words[0])
case b.len >= 128:
// Mask for bits 128-191: keeps only valid bits above 128
// e.g., if len=150, keeps lowest 22 bits (150-128)
mask := maxUint64 >> (192 - b.len)
binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask)
binary.BigEndian.PutUint64(res[16:24], b.words[1])
binary.BigEndian.PutUint64(res[24:32], b.words[0])
case b.len >= 64:
// You get the idea
mask := maxUint64 >> (128 - b.len)
binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask)
binary.BigEndian.PutUint64(res[24:32], b.words[0])
default:
mask := maxUint64 >> (64 - b.len)
binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask)
}
b.truncateToLength()
binary.BigEndian.PutUint64(res[0:8], b.words[3])
binary.BigEndian.PutUint64(res[8:16], b.words[2])
binary.BigEndian.PutUint64(res[16:24], b.words[1])
binary.BigEndian.PutUint64(res[24:32], b.words[0])

return res[:]
}
Expand Down Expand Up @@ -335,15 +310,17 @@ func (b *BitArray) Set(x *BitArray) *BitArray {

// Sets the bit array to the bytes representation of a felt.
func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray {
b.setFelt(f)
b.len = length
b.setFelt(f)
b.truncateToLength()
return b
}

// Sets the bit array to the bytes representation of a felt with length 251.
func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray {
b.setFelt(f)
b.len = 251
b.setFelt(f)
b.truncateToLength()
return b
}

Expand All @@ -352,13 +329,15 @@ func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray {
func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray {
b.setBytes32(data)
b.len = length
b.truncateToLength()
return b
}

// Sets the bit array to the uint64 representation of a bit array.
func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray {
b.words[0] = data
b.len = length
b.truncateToLength()
return b
}

Expand Down Expand Up @@ -432,6 +411,27 @@ func (b *BitArray) clear() *BitArray {
return b
}

// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros.
//
//nolint:mnd
func (b *BitArray) truncateToLength() {
switch {
case b.len == 0:
b.words = [4]uint64{0, 0, 0, 0}
case b.len <= 64:
b.words[0] &= maxUint64 >> (64 - b.len)
b.words[1], b.words[2], b.words[3] = 0, 0, 0
case b.len <= 128:
b.words[1] &= maxUint64 >> (128 - b.len)
b.words[2], b.words[3] = 0, 0
case b.len <= 192:
b.words[2] &= maxUint64 >> (192 - b.len)
b.words[3] = 0
default:
b.words[3] &= maxUint64 >> (256 - uint16(b.len))
}
}

// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit.
// The bit position is counted from the least significant bit, starting at 0.
// For example:
Expand Down

0 comments on commit d50aacf

Please sign in to comment.