Skip to content

Commit

Permalink
make bitarray functional and memory efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
barrust committed Jan 1, 2024
1 parent edabbe5 commit 92ad74d
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 49 deletions.
65 changes: 23 additions & 42 deletions probables/quotientfilter/quotientfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,23 @@
from array import array

from probables.hashes import KeyT, fnv_1a_32
from probables.utilities import Bitarray


def get_hash(x, m):
def get_hash(x: KeyT, m: int):
return fnv_1a_32(x, 0) & ((1 << m) - 1)


class bitarray:
# NOTE: NOT SPACE EFFICIENT FOR NOW
def __init__(self, size: int):
self.bitarray = array("B", [0]) * size
self.size = size

def __getitem__(self, idx: int):
return self.bitarray[idx]

def __setitem__(self, idx: int, val: int):
if val < 0 or val > 1:
raise ValueError("Invalid bit setting; must be 0 or 1")
self.bitarray[idx] = val

def set_bit(self, idx: int):
self.bitarray[idx] = 1

def clear_bit(self, idx: int):
self.bitarray[idx] = 0


class QuotientFilter:
def __init__(self): # needs to be parameterized
self.q = 24
self.r = 8
self.m = 1 << self.q # the size of the array (2**q)
self._q = 24
self._r = 8
self._size = 1 << self._q # same as 2**q

self.is_occupied_arr = bitarray(self.m)
self.is_continuation_arr = bitarray(self.m)
self.is_shifted_arr = bitarray(self.m)
self._filter = array("I", [0]) * self.m
self.is_occupied_arr = Bitarray(self._size)
self.is_continuation_arr = Bitarray(self._size)
self.is_shifted_arr = Bitarray(self._size)
self._filter = array("I", [0]) * self._size

def shift_insert(self, k, v, start, j, flag):
if self.is_occupied_arr[j] == 0 and self.is_continuation_arr[j] == 0 and self.is_shifted_arr[j] == 0:
Expand All @@ -49,7 +29,8 @@ def shift_insert(self, k, v, start, j, flag):
self.is_shifted_arr[j] = 1 if j != k else 0

else:
i = (j + 1) & (self.m - 1)
# print("using shift insert")
i = (j + 1) & (self._size - 1)

while True:
f = self.is_occupied_arr[i] + self.is_continuation_arr[i] + self.is_shifted_arr[i]
Expand All @@ -67,15 +48,15 @@ def shift_insert(self, k, v, start, j, flag):
if f == 0:
break

i = (i + 1) & (self.m - 1)
i = (i + 1) & (self._size - 1)

self._filter[j] = v
self.is_occupied_arr[k] = 1
self.is_continuation_arr[j] = 1 if j != start else 0
self.is_shifted_arr[j] = 1 if j != k else 0

if flag == 1:
self.is_continuation_arr[(j + 1) & (self.m - 1)] = 1
self.is_continuation_arr[(j + 1) & (self._size - 1)] = 1

def get_start_index(self, k):
j = k
Expand All @@ -86,7 +67,7 @@ def get_start_index(self, k):
cnts += 1

if self.is_shifted_arr[j] == 1:
j = (j - 1) & (self.m - 1)
j = (j - 1) & (self._size - 1)
else:
break

Expand All @@ -96,15 +77,15 @@ def get_start_index(self, k):
break
cnts -= 1

j = (j + 1) & (self.m - 1)
j = (j + 1) & (self._size - 1)

return j

def add(self, key: KeyT):
if self.contains(key) is False:
_hash = get_hash(key, self.q + self.r)
key_quotient = _hash >> self.r
key_remainder = _hash & ((1 << self.r) - 1)
_hash = get_hash(key, self._q + self._r)
key_quotient = _hash >> self._r
key_remainder = _hash & ((1 << self._r) - 1)

if (
self.is_occupied_arr[key_quotient] == 0
Expand All @@ -126,7 +107,7 @@ def add(self, key: KeyT):
f = self.is_occupied_arr[j] + self.is_continuation_arr[j] + self.is_shifted_arr[j]

while starts == 0 and f != 0 and key_remainder > self._filter[j]:
j = (j + 1) & (self.m - 1)
j = (j + 1) & (self._size - 1)

if self.is_continuation_arr[j] == 0:
starts += 1
Expand All @@ -139,9 +120,9 @@ def add(self, key: KeyT):
self.shift_insert(key_quotient, key_remainder, u, j, 1)

def contains(self, key: KeyT):
_hash = get_hash(key, self.q + self.r)
key_quotient = _hash >> self.r
key_remainder = _hash & ((1 << self.r) - 1)
_hash = get_hash(key, self._q + self._r)
key_quotient = _hash >> self._r
key_remainder = _hash & ((1 << self._r) - 1)

if self.is_occupied_arr[key_quotient] == 0:
return False
Expand All @@ -162,7 +143,7 @@ def contains(self, key: KeyT):
if self._filter[j] == key_remainder:
return True

j = (j + 1) & (self.m - 1)
j = (j + 1) & (self._size - 1)
f = self.is_occupied_arr[j] + self.is_continuation_arr[j] + self.is_shifted_arr[j]

return False
94 changes: 94 additions & 0 deletions probables/utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
""" Utility Functions """

import math
import mmap
import string
from array import array
from pathlib import Path
from typing import Union

Expand Down Expand Up @@ -83,3 +85,95 @@ def seek(self, pos: int, whence: int) -> None:
def read(self, n: int = -1) -> bytes:
"""Implement a method to read from the file on top of the MMap class"""
return self.__m.read(n)


class Bitarray:
    """Simplified, pure python bitarray implementation using as little memory as possible

    Bits are packed eight to a byte in an ``array("B")``: bit ``idx`` is stored in
    byte ``idx // 8`` at bit position ``idx % 8`` (least-significant bit first).

    Args:
        size (int): The number of bits to hold; must be at least 1
    Raises:
        ValueError: If `size` is zero or negative"""

    def __init__(self, size: int):
        if size <= 0:
            # the guard accepts size == 1, so the message must say "larger than 0"
            raise ValueError(f"Bitarray size must be larger than 0; {size} was provided")
        # integer ceiling division; exact for any int, whereas math.ceil(size / 8)
        # goes through float division and loses precision for very large sizes
        self._size_bytes = (size + 7) // 8
        self._bitarray = array("B", [0]) * self._size_bytes
        self._size = size

    @property
    def size_bytes(self) -> int:
        """The size of the bitarray in bytes"""
        return self._size_bytes

    @property
    def size(self) -> int:
        """The number of bits in the bitarray"""
        return self._size

    @property
    def bitarray(self) -> array:
        """The underlying byte array (the live object, not a copy)"""
        return self._bitarray

    def __getitem__(self, idx: int) -> int:
        return self.check_bit(idx)

    def __setitem__(self, idx: int, val: int) -> None:
        # the value is validated before the index, so an invalid value always
        # raises ValueError even when the index is out of range as well
        if val not in (0, 1):
            raise ValueError("Invalid bit setting; must be 0 or 1")
        if idx < 0 or idx >= self._size:
            raise IndexError(f"Bitarray index outside of range; index {idx} was provided")
        b = idx // 8
        if val == 1:
            self._bitarray[b] = self._bitarray[b] | (1 << (idx % 8))
        else:
            self._bitarray[b] = self._bitarray[b] & ~(1 << (idx % 8))

    def check_bit(self, idx: int) -> int:
        """Check if the bit idx is set

        Args:
            idx (int): The index to check
        Returns:
            int: The status of the bit, either 0 or 1
        Raises:
            IndexError: If `idx` is negative or not less than the number of bits"""
        if idx < 0 or idx >= self._size:
            raise IndexError(f"Bitarray index outside of range; index {idx} was provided")
        return 0 if (self._bitarray[idx // 8] & (1 << (idx % 8))) == 0 else 1

    def is_bit_set(self, idx: int) -> bool:
        """Check if the bit idx is set

        Args:
            idx (int): The index to check
        Returns:
            bool: True if the bit is 1, False if it is 0
        Raises:
            IndexError: If `idx` is negative or not less than the number of bits"""
        return bool(self.check_bit(idx))

    def set_bit(self, idx: int) -> None:
        """Set the bit at idx to 1

        Args:
            idx (int): The index to set
        Raises:
            IndexError: If `idx` is negative or not less than the number of bits"""
        if idx < 0 or idx >= self._size:
            raise IndexError(f"Bitarray index outside of range; index {idx} was provided")
        b = idx // 8
        self._bitarray[b] = self._bitarray[b] | (1 << (idx % 8))

    def clear_bit(self, idx: int) -> None:
        """Set the bit at idx to 0

        Args:
            idx (int): The index to clear
        Raises:
            IndexError: If `idx` is negative or not less than the number of bits"""
        if idx < 0 or idx >= self._size:
            raise IndexError(f"Bitarray index outside of range; index {idx} was provided")
        b = idx // 8
        self._bitarray[b] = self._bitarray[b] & ~(1 << (idx % 8))

    def clear(self) -> None:
        """Clear all bits in the bitarray"""
        # same-length slice assignment zeroes the buffer in place, so references
        # obtained through the `bitarray` property keep pointing at live data
        self._bitarray[:] = array("B", [0]) * self._size_bytes

    def as_string(self) -> str:
        """String representation of the bitarray

        Returns:
            str: One "0" or "1" character per bit, in index order"""
        return "".join(str(self.check_bit(x)) for x in range(self._size))
72 changes: 65 additions & 7 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@

from utilities import different_hash

from probables.utilities import (
MMap,
get_x_bits,
is_hex_string,
is_valid_file,
resolve_path,
)
from probables.utilities import Bitarray, MMap, get_x_bits, is_hex_string, is_valid_file, resolve_path

DELETE_TEMP_FILES = True

Expand Down Expand Up @@ -115,6 +109,70 @@ def test_resolve_path(self):
p2 = resolve_path("./{}".format(fobj.name))
self.assertTrue(p2.is_absolute())

def test_bitarray(self):
    """set, clear and index bits of a 100-bit Bitarray and compare string dumps"""
    # expected dumps: every third bit set (0, 3, ..., 96) and nothing set
    every_third = "1001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001000"
    all_clear = "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"

    bits = Bitarray(100)
    for step in range(33):
        bits.set_bit(step * 3)
    self.assertEqual(bits.as_string(), every_third)

    self.assertTrue(bits.is_bit_set(3))
    self.assertFalse(bits.is_bit_set(4))
    self.assertEqual(bits[0], 1)
    self.assertEqual(bits[1], 0)

    # clearing the same positions must return the array to all zeros
    for step in range(33):
        bits.clear_bit(step * 3)
    self.assertEqual(bits.as_string(), all_clear)

    # setting again after a clear reproduces the first pattern
    for step in range(33):
        bits.set_bit(step * 3)
    self.assertEqual(bits.as_string(), every_third)

    # item assignment toggles a single bit
    self.assertEqual(bits[2], 0)
    bits[2] = 1
    self.assertEqual(bits[2], 1)
    bits[2] = 0
    self.assertEqual(bits[2], 0)

    bits.clear()
    self.assertEqual(bits.as_string(), all_clear)

def test_bitarray_invalid_idx(self):
    """out-of-range indices and invalid bit values raise on a 10-bit Bitarray"""
    bits = Bitarray(10)

    # method-based access
    with self.assertRaises(IndexError):
        bits.set_bit(12)
    with self.assertRaises(IndexError):
        bits.set_bit(-1)
    with self.assertRaises(IndexError):
        bits.check_bit(-1)
    with self.assertRaises(IndexError):
        bits.check_bit(12)
    with self.assertRaises(IndexError):
        bits.clear_bit(-1)
    with self.assertRaises(IndexError):
        bits.clear_bit(12)

    # subscript reads
    with self.assertRaises(IndexError):
        bits[-1]
    with self.assertRaises(IndexError):
        bits[12]

    # subscript writes with a valid value but a bad index
    with self.assertRaises(IndexError):
        bits[-1] = 0
    with self.assertRaises(IndexError):
        bits[12] = 0

    # set as non-valid bit value; the second case also has a bad index and
    # still expects ValueError
    with self.assertRaises(ValueError):
        bits[1] = 5
    with self.assertRaises(ValueError):
        bits[12] = -1


if __name__ == "__main__":
unittest.main()

0 comments on commit 92ad74d

Please sign in to comment.