Merge branch 'compute-spans'
softwaredoug committed Jun 20, 2024
2 parents feb7ea9 + 464480a commit eb9ce33
Showing 9 changed files with 536 additions and 81 deletions.
92 changes: 71 additions & 21 deletions searcharray/phrase/spans.py
@@ -25,7 +25,7 @@
# MSB
# term A |10| | 12| 13| | 18 * Intersect A+B -> 10,12
# term B |10| | 12| | 14 15 17 | * RHS intersect -> 13 -> 14
# term C | 11 14
# term C | 11 | 14
# term D 10 | 15
# term E | 11 12
# ------
@@ -37,39 +37,87 @@
#
# ----
#
# Intersecting
# term A |10| | 12| 13 18
# term B |10| | 12| 14 15 17
# Adjacent
# term A 10 12| 13| 18
# term B 10 12| | 14 15 17 |
# Giving 9.5, 10.5, 11.5, 12.5, 13.5, 17.5
# ------
# Next pair, intersecting
# term B 10 12 | 14 | 15 17 * RHS intersect -> 13 -> 14
# term C 11 | 14 |
# Adjacent
# term B 10 12 14 15 17
# term C |11 14
# Giving 10.5, 13.5, 14.5
# w/ above 10.5, 13.5
# ------
#
#
# Finally: X X X X <- 4 words selected from each to find max span
#
# Does it work to compute any intersection, and then attach any adjacent ones that exist?
#
#

def hdr(arr):
return encoder.header(arr) >> (_64 - encoder.header_bits)


def _intersect_all(posns_encoded: List[np.ndarray]):
"""Intersect all encoded positions at roaringish MSBs."""
if len(posns_encoded) < 2:
raise ValueError("Need at least two positions to intersect")
last_lhs_headers = None
last_rhs_headers = None
curr = posns_encoded[0]
for term_idx, posns_next in enumerate(posns_encoded[1:]):
lhs_int_idx, rhs_int = intersect(curr, posns_next, mask=encoder.header_mask)
int_headers = encoder.header(curr[lhs_int_idx])

lhs_int = posns_encoded[0]
lhs_to_left = posns_encoded[0]
lhs_to_right = posns_encoded[0]
# What is adjacent on LHS / RHS of intersection
# Next to left
# 0 -> 2
# <- 1 (curr to right)
# 1 -> (keep
# 0 <- 2
curr_to_right, next_to_left = adjacent(curr, posns_next, mask=encoder.header_mask)
lhs_headers = merge(int_headers, posns_next[next_to_left])
rhs_headers = merge(int_headers, curr[curr_to_right])
next_to_right, curr_to_left = adjacent(posns_next, curr, mask=encoder.header_mask)
lhs_headers = merge(lhs_headers, curr[curr_to_left])
rhs_headers = merge(rhs_headers, posns_next[next_to_right])

for term_idx, posns_next in enumerate(posns_encoded[1:]):
_, rhs_int = intersect(lhs_int, posns_next, mask=encoder.header_mask)
_, rhs_ls = adjacent(lhs_to_left, posns_next, mask=encoder.header_mask)
rhs_rs, _ = adjacent(lhs_to_right, posns_next, mask=encoder.header_mask)
if last_lhs_headers is not None:
lhs, _ = intersect(last_lhs_headers, lhs_headers, mask=encoder.header_mask)
rhs, _ = intersect(last_rhs_headers, rhs_headers, mask=encoder.header_mask)
last_lhs_headers = last_lhs_headers[lhs]
last_rhs_headers = last_rhs_headers[rhs]
else:
last_lhs_headers = lhs_headers
last_rhs_headers = rhs_headers

# Update LHS to rhs_int + rhs_ls + rhs_rs indices
lhs_int = posns_next[rhs_int]
lhs_to_left = posns_next[rhs_ls]
lhs_to_right = posns_next[rhs_rs]
# Update by intersecting with current working lhs / rhs headers

assert last_rhs_headers is not None
assert last_lhs_headers is not None
to_rhs = last_rhs_headers + (_1 << (_64 - encoder.header_bits))
to_lhs = last_lhs_headers - (_1 << (_64 - encoder.header_bits))
all_headers = merge(to_rhs, to_lhs, drop_duplicates=True)
all_headers = merge(last_lhs_headers, all_headers, drop_duplicates=True)
all_headers = merge(last_rhs_headers, all_headers, drop_duplicates=True)
# Get active MSBs now
# Merge all the rest to grab them
int_header = encoder.header(lhs_int)
to_left_header = encoder.header(lhs_to_left)
to_right_header = encoder.header(lhs_to_right)
merged = merge(int_header, to_left_header, drop_duplicates=True)
merged = merge(merged, to_right_header, drop_duplicates=True)

# Slice only matches at MSBs
# Slice only matches at header MSBs
new_posns_encoded = posns_encoded.copy()
for i in range(len(posns_encoded)):
posns_encoded[i] = encoder.slice(posns_encoded[i], merged)
new_posns_encoded[i] = encoder.slice(posns_encoded[i], header=all_headers)
lengths = np.cumsum([0] + [len(posns) for posns in new_posns_encoded], dtype=np.uint64)
concatted = np.concatenate(new_posns_encoded, dtype=np.uint64)
return concatted, lengths
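
A rough sketch of the narrowing that _intersect_all performs, with plain integers standing in for the encoded MSB headers and set operations standing in for intersect/adjacent/merge (illustrative only, not the real roaringish encoding):

# Toy model: each number is the MSB "header" block a term's positions fall
# into, and adjacency is approximated as header +/- 1.
term_headers = [
    [10, 12, 13, 18],      # term A
    [10, 12, 14, 15, 17],  # term B
    [11, 14],              # term C
]


def candidate_headers(headers_per_term):
    survivors = set(headers_per_term[0])
    for headers in headers_per_term[1:]:
        # A block survives only if this term has positions in it or in a
        # directly adjacent block.
        reachable = set(headers) | {h - 1 for h in headers} | {h + 1 for h in headers}
        survivors &= reachable
    # Keep neighbours of the survivors too, so spans straddling a block
    # boundary are not cut off before the per-position span check runs.
    return sorted(survivors | {h - 1 for h in survivors} | {h + 1 for h in survivors})


print(candidate_headers(term_headers))  # -> [9, 10, 11, 12, 13, 14]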


# Picking up from intersections:
@@ -122,11 +170,13 @@ def span_search(posns_encoded: List[np.ndarray],
slop: int) -> np.ndarray:
"""Find span matches up to PAYLOAD_LSB bits span distance."""
# Find inner span candidates
_intersect_all(posns_encoded)
posns, lengths = _intersect_all(posns_encoded)

# Populate phrase freqs with matches of slop
r_span_search(posns_encoded, phrase_freqs, slop,
r_span_search(posns, lengths,
phrase_freqs, slop,
encoder.key_mask,
encoder.header_mask,
encoder.key_bits,
encoder.payload_lsb_bits)
return phrase_freqs
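
The posns and lengths pair handed to r_span_search uses the usual ragged-array layout: every term's sliced positions are concatenated into one uint64 array, with cumulative offsets so that term i occupies concatted[lengths[i]:lengths[i + 1]]. A small self-contained sketch of that layout (toy values, not real encoded positions):

import numpy as np

parts = [np.array([1, 2, 3], dtype=np.uint64),
         np.array([7], dtype=np.uint64),
         np.array([4, 5], dtype=np.uint64)]
lengths = np.cumsum([0] + [len(p) for p in parts], dtype=np.uint64)
concatted = np.concatenate(parts, dtype=np.uint64)

for i in range(len(parts)):
    assert np.array_equal(concatted[lengths[i]:lengths[i + 1]], parts[i])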
6 changes: 5 additions & 1 deletion searcharray/postings.py
@@ -619,6 +619,7 @@ def doclengths(self) -> np.ndarray:

def score(self, token: Union[str, List[str]],
similarity: Similarity = default_bm25,
slop: int = 0,
min_posn: Optional[int] = None,
max_posn: Optional[int] = None) -> np.ndarray:
"""Score each doc using a similarity function.
@@ -638,7 +639,8 @@ def score(self, token: Union[str, List[str]],
tokens_l = [token] if isinstance(token, str) else token
all_dfs = np.asarray([self.docfreq(token) for token in tokens_l])

tfs = self.termfreqs(token, min_posn=min_posn, max_posn=max_posn)
tfs = self.termfreqs(token, min_posn=min_posn, max_posn=max_posn,
slop=slop)
token = self._check_token_arg(token)
doc_lens = self.doclengths()

@@ -656,6 +658,8 @@ def _phrase_freq(self, tokens: List[str],
slop=0,
min_posn: Optional[int] = None,
max_posn: Optional[int] = None) -> np.ndarray:
if slop > 0:
logger.warning("!! Slop is experimental and may be slow, crash, or return inaccurate results")
try:
# Decide how/if we need to filter doc ids
term_ids = [self.term_dict.get_term_id(token) for token in tokens]
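A hedged usage sketch of the new slop pass-through on score; the SearchArray.index constructor used here to build the postings column is assumed from the project README and is not part of this diff:

import pandas as pd
from searcharray import SearchArray

df = pd.DataFrame({"text": ["the quick brown fox",
                            "quick then later brown",
                            "unrelated text"]})
df["text_idx"] = SearchArray.index(df["text"])

# Exact phrase scoring (slop=0, the default)
exact = df["text_idx"].array.score(["quick", "brown"])

# Allow some distance between the phrase terms (experimental, per the
# warning added to _phrase_freq above)
sloppy = df["text_idx"].array.score(["quick", "brown"], slop=2)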
2 changes: 1 addition & 1 deletion searcharray/roaringish/intersect.pyx
@@ -340,7 +340,7 @@ def adjacent(np.ndarray[DTYPE_t, ndim=1] lhs,
lhs.shape[0], rhs.shape[0],
&lhs_out[0], &rhs_out[0],
mask, delta)
return lhs_out[:amt_written], rhs_out[:amt_written]
return np.asarray(lhs_out[:amt_written]), np.asarray(rhs_out[:amt_written])


def intersect_with_adjacents(np.ndarray[DTYPE_t, ndim=1] lhs,
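The np.asarray wrapping matters because callers such as _intersect_all in spans.py use the returned index arrays for numpy fancy indexing (posns_next[next_to_left] and similar) and feed them to merge. np.asarray over a typed-memoryview slice is zero-copy and yields a real ndarray; a minimal sketch using a plain Python memoryview as a stand-in for the Cython typed memoryview:

import numpy as np

buf = np.arange(10, dtype=np.uint64)
view = memoryview(buf)[:5]          # stand-in for lhs_out[:amt_written]

as_array = np.asarray(view)         # zero-copy ndarray over the same buffer
assert isinstance(as_array, np.ndarray)
assert as_array.base is not None    # a view, not a copy

print(buf[as_array])                # ndarray fancy indexing works as expected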
12 changes: 12 additions & 0 deletions searcharray/roaringish/popcount.pyi
@@ -7,6 +7,18 @@ def popcount64(arr: NDArray[np.uint64]) -> NDArray[np.uint64]:
...


def msb_mask64(arr: NDArray[np.uint64]) -> NDArray[np.uint64]:
...


def ctz64(arr: NDArray[np.uint64]) -> NDArray[np.uint64]:
...


def clz64(arr: NDArray[np.uint64]) -> NDArray[np.uint64]:
...


def popcount_reduce_at(ids: NDArray[np.uint64],
payload: NDArray[np.uint64],
out: NDArray[np.float64]) -> NDArray[np.float64]:
30 changes: 30 additions & 0 deletions searcharray/roaringish/popcount.pyx
@@ -21,6 +21,9 @@ cdef extern from "stddef.h":
# and portability, though it's not directly used here.
int __builtin_popcountll(unsigned long long x)

int __builtin_ctzll(unsigned long long x)
int __builtin_clzll(unsigned long long x)


# Include mach performance timer
# cdef extern from "mach/mach_time.h":
@@ -40,6 +43,32 @@ cdef popcount64_arr(DTYPE_t[:] arr):
return result


cdef ctz_arr(DTYPE_t[:] arr):
cdef np.uint64_t[:] result = np.empty(arr.shape[0], dtype=np.uint64)
# cdef int i = 0
cdef DTYPE_t* result_ptr = &result[0]
cdef DTYPE_t* arr_ptr = &arr[0]

for _ in range(arr.shape[0]):
result_ptr[0] = __builtin_ctzll(arr_ptr[0])
result_ptr += 1
arr_ptr += 1
return result


cdef clz_arr(DTYPE_t[:] arr):
cdef np.uint64_t[:] result = np.empty(arr.shape[0], dtype=np.uint64)
# cdef int i = 0
cdef DTYPE_t* result_ptr = &result[0]
cdef DTYPE_t* arr_ptr = &arr[0]

for _ in range(arr.shape[0]):
result_ptr[0] = __builtin_clzll(arr_ptr[0])
result_ptr += 1
arr_ptr += 1
return result


cdef popcount64_arr_naive(DTYPE_t[:] arr):
cdef np.uint64_t[:] result = np.empty(arr.shape[0], dtype=np.uint64)
cdef int i = 0
@@ -50,6 +79,7 @@ cdef popcount64_arr_naive(DTYPE_t[:] arr):


def popcount64(np.ndarray[DTYPE_t, ndim=1] arr):
"""Count the number of set bits in a 64-bit integer."""
return np.array(popcount64_arr(arr))


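For reference, the semantics the new helpers expose, assuming ctz64 and clz64 Python wrappers matching the stubs added to popcount.pyi above (the def wrappers themselves sit outside the visible hunk). Note that __builtin_ctzll and __builtin_clzll are undefined for an input of 0, so zero words should be masked out before calling them:

import numpy as np
from searcharray.roaringish.popcount import popcount64, ctz64, clz64

arr = np.array([0b1000, 0b1, 1 << 63], dtype=np.uint64)

print(popcount64(arr))  # [1 1 1]    set bits per word
print(ctz64(arr))       # [ 3  0 63] trailing zeros (position of lowest set bit)
print(clz64(arr))       # [60 63  0] leading zeros (63 minus position of highest set bit)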
10 changes: 10 additions & 0 deletions searcharray/roaringish/roaringish.py
@@ -71,6 +71,7 @@ def __init__(self, key_bits: np.uint64 = DEFAULT_KEY_BITS):
assert self.key_bits.dtype == np.uint64
# key bits MSB of 64 bits
self.key_mask = n_msb_mask(key_bits)
self.header_bits = key_bits + self.payload_msb_bits
self.payload_msb_mask = n_msb_mask(np.uint64(self.payload_msb_bits + key_bits)) & ~self.key_mask
assert self.payload_msb_bits.dtype == np.uint64, f"MSB bits dtype was {self.payload_msb_bits.dtype}"
assert self.payload_msb_mask.dtype == np.uint64, f"MSB mask dtype was {self.payload_msb_mask.dtype}"
@@ -244,10 +245,19 @@ def key_partition(self,
def slice(self,
encoded: np.ndarray,
keys: Optional[np.ndarray] = None,
header: Optional[np.ndarray] = None,
max_payload: Optional[int] = None,
min_payload: Optional[int] = None) -> np.ndarray:
"""Get list of encoded that have values in keys."""
# encoded_keys = encoded.view(np.uint64) >> (_64 - self.key_bits)
if header is not None:
if keys is not None:
raise ValueError("Can't specify both keys and header")
encoded_header = self.header(encoded)
idx_docs, idx_enc = intersect(header.view(np.uint64),
encoded_header.view(np.uint64),
drop_duplicates=False)
encoded = encoded[idx_enc]
if keys is not None:
encoded_keys = self.keys(encoded)
idx_docs, idx_enc = intersect(keys.view(np.uint64),
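Semantically, the new header= branch of slice keeps only the encoded words whose key plus payload-MSB prefix appears in the supplied header array. A rough pure-numpy equivalent, ignoring duplicate handling and using an illustrative 18-bit header mask rather than the encoder's real layout:

import numpy as np

HEADER_MASK = np.uint64(0xFFFFC000_00000000)   # top 18 bits, illustrative only


def slice_by_header(encoded, header):
    # Keep encoded words whose (in-place) MSB header appears in `header`,
    # a stand-in for intersect(..., drop_duplicates=False) plus indexing.
    encoded_header = encoded & HEADER_MASK
    return encoded[np.isin(encoded_header, header)]


encoded = np.array([0x0001_0000_0000_00FF, 0x0002_0000_0000_000F], dtype=np.uint64)
wanted = (encoded & HEADER_MASK)[:1]           # keep only the first word's block
print(slice_by_header(encoded, wanted))        # only the first word survives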