Skip to content

Commit

Permalink
Sketch out spans algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Jun 17, 2024
1 parent 7fd75e6 commit eb1f08a
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 55 deletions.
15 changes: 0 additions & 15 deletions searcharray/roaringish/popcount.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,6 @@ def popcount64(np.ndarray[DTYPE_t, ndim=1] arr):
return np.array(popcount64_arr(arr))


def ctz64(np.ndarray[DTYPE_t, ndim=1] arr):
    """Count trailing zeros over a 1-D array of 64-bit integers.

    The counting is delegated to ``ctz_arr`` (defined elsewhere in this
    module); its result is wrapped in a new NumPy array before returning.
    Presumably one count is produced per input element -- confirm against
    ``ctz_arr``.
    """
    trailing_zero_counts = ctz_arr(arr)
    return np.array(trailing_zero_counts)


def clz64(np.ndarray[DTYPE_t, ndim=1] arr):
    """Count leading zeros over a 1-D array of 64-bit integers.

    The counting is delegated to ``clz_arr`` (defined elsewhere in this
    module); its result is wrapped in a new NumPy array before returning.
    Presumably one count is produced per input element -- confirm against
    ``clz_arr``.
    """
    leading_zero_counts = clz_arr(arr)
    return np.array(leading_zero_counts)


def msb_mask64(DTYPE_t value):
    """Get the mask of the most significant set bit of a 64-bit integer.

    Parameters
    ----------
    value : DTYPE_t
        The integer to inspect.

    Returns
    -------
    A value with only the highest set bit of ``value`` set, or 0 when
    ``value`` is 0 (no bit is set, so there is no most significant bit).
    """
    if value == 0:
        # __builtin_clzll(0) is undefined, so never hand it a zero.
        return 0
    # Cast the literal so the shift is carried out at DTYPE_t width
    # (assumed to be a 64-bit unsigned type -- confirm in snp_ops). A bare
    # ``1`` is a narrower/signed C integer, and shifting it by up to 63
    # bits can overflow.
    return (<DTYPE_t>1) << (63 - __builtin_clzll(value))


cdef _popcount_reduce_at(DTYPE_t[:] ids, DTYPE_t[:] payload, double[:] output):
cdef DTYPE_t idx = 1
cdef DTYPE_t popcount_sum = __builtin_popcountll(payload[0])
Expand Down
152 changes: 116 additions & 36 deletions searcharray/roaringish/spans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ import numpy as np

cimport searcharray.roaringish.snp_ops
from searcharray.roaringish.snp_ops cimport DTYPE_t
from searcharray.roaringish.snp_ops cimport ctz

cdef extern from "stddef.h":
# Trailing and leading zeros to trim the span mask
int __builtin_popcountll(unsigned long long x)
int __builtin_ctzll(unsigned long long x)
int __builtin_clzll(unsigned long long x)


cdef _count_spans_of_slop(DTYPE_t[:] posns, DTYPE_t slop):
Expand All @@ -20,52 +27,125 @@ cdef _count_spans_of_slop(DTYPE_t[:] posns, DTYPE_t slop):
return 1


cdef _span_search(DTYPE_t[:, :] posns_arr,
DTYPE_t[:] phrase_freqs,
DTYPE_t slop,
DTYPE_t key_mask,
DTYPE_t header_mask,
DTYPE_t lsb_bits):
cdef _get_adj_spans(DTYPE_t[:, :] posns_arr,
DTYPE_t[:] phrase_freqs,
DTYPE_t slop,
DTYPE_t key_mask,
DTYPE_t header_mask,
DTYPE_t lsb_bits):
"""Get unscored spans."""
pass


cdef _span_freqs(DTYPE_t[:, :] posns_arr,
DTYPE_t[:] phrase_freqs,
DTYPE_t slop,
DTYPE_t key_mask,
DTYPE_t header_mask,
DTYPE_t lsb_bits):
"""Get unscored spans, within 64 bits."""

cdef DTYPE_t i = 0
cdef DTYPE_t j = 0
cdef DTYPE_t k = 0
cdef DTYPE_t adj = 0
cdef DTYPE_t set_idx = 0
cdef DTYPE_t curr_msb = 0
cdef DTYPE_t adj_msb = 0
cdef DTYPE_t posn = 0
cdef DTYPE_t payload_mask = ~header_mask
cdef DTYPE_t popcount_arr = np.empty(posns_arr.shape[0], dtype=np.uint64)
cdef DTYPE_t[:] which_terms = -np.ones(lsb_bits * 2, dtype=np.uint8)
# Assuming no overlaps.
#
# Collect the term where each position is set
#
# term1: 010011010000 term1 & (term2 + 1)
# term2: 000000000001
# term3: 000000000010
#
# which_terms= F0FF00F0FF21
# (really which_terms [last_term] [this_term])
#
# Scan the which_terms to find spans within slop
# Remove them, increment the phrase_freqs for the doc, then continue

# It may seem we can scan which_terms, but we can just get the minimum spans
#
# which_terms= F0FF00F0FF21
# (really which_terms [last_term] [this_term])
#
# Then diffs:
# which_terms= F0FF00F0FF21
# dist 001201201245
# coll? * <- when to collect, every prev seen unique num
#
# which_terms = F12F00F0FF21
# dist 011230101223
# coll? * *
# This one is tricky because we should NOT collect the first time we encounter all
# terms, but rather the min span in between
#
# which_terms = F2110120FF21
# dist 0234500
# coll? * *
#
# which_terms = F2110120FF21
# spans ---- 1 (posn_first_term, posn_last_term, terms_enc, span_score)
# ----- 2 (posn_first_term, posn_last_term, terms_enc, span_score)
# ----
# ---
# -----
#
# We have to track all active spans
# when popcount terms_enc = num_terms
# ... we collect the span
# if overlaps and size smaller than existing collected span
# remove the existing span
#
# span score is the current span slop
#
# Now we have spans
#
# which_terms = F2110120FF21
# ---
# -----
#
#
#
# Now we score the span to see if it's < slop
#
# curr_posns are current bits analyzed for slop
cdef np.uint64_t[:] curr_posns = np.empty(posns_arr.shape[0], dtype=np.uint64)
for i in range(posns_arr.shape[1]):

# First get self + adj into a single 64 bit number
# per term
# i: adj:
# 10 11 14
# termA 0011 0001 0010
# termB 0100 1000 0001
#
# curr_posns now:
#
# termA: [00110001,
# termB: 00100000]
#
# Now we can check for minspans in each terms words
#
doc_id = posns_arr[0, i] & key_mask
# Each term
for j in range(posns_arr.shape[0]):
# Each msb
# Later optimization - could we do this without storing which_terms?
term = posns_arr[j, i] & payload_mask
set_idx = __builtin_ctzll(term)
posns_arr[j, i] &= ~(1 << set_idx)
which_terms[set_idx] = j

# Gather and score min spans
for posn in which_terms:
if posn == 0xFF:
continue
dist


# Shift the which_terms up by num_payload_bits
for j in range(64 - lsb_bits):
which_terms[j + lsb_bits] = which_terms[j]



# The min popcount is the upper bound of phrase freq
popcount_xored_min = 128
for j in range(posns_arr.shape[0]):
curr_posns[j] = posns_arr[j, i] << lsb_bits

adj = i + 1
if adj < posns_arr.shape[1]:
adj_msb = posns_arr[i, 0] & header_mask
adj_msb += (1 << lsb_bits)
curr_msb = posns_arr[adj, 0] & header_mask

# If my neighbor is actually a neighbor
if curr_msb == adj_msb:
for j in range(posns_arr.shape[0]):
curr_posns[j] |= posns_arr[j, adj]
# Find a min span
phrase_freqs[doc_id] = _count_spans_of_slop(curr_posns, slop)
popcount_xored = __builtin_popcountll(posns_arr[j, i] ^ max_span_mask)
if popcount_xored < popcount_xored_min:
popcount_xored_min = popcount_xored


def span_search(np.ndarray[DTYPE_t, ndim=2] posns_arr,
Expand Down
47 changes: 43 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from searcharray.roaringish import popcount64
from searcharray.roaringish import popcount64, msb_mask64
#
# streak each some given slop amount
# 001100 self >> 2 | self >> 1 | self | self << 1 | self << 2 | .. | self << n
Expand Down Expand Up @@ -50,6 +50,47 @@
# Expected:
# 111111

#
# How would you adapt this to span across multiple values?
#
# 100 100000000
# 010 000000001
# 001 000000010
# 000 100000000
# 000 010000001
# 000 001000000 (streak phrase size bits + slop)
# Expected:
# 111 111
#
# 1001 00000000
# 0100 00000001
# 0010 00000010
# 0001 00000000
# 0000 10000001
# 0000 01000000 (streak phrase size bits + slop)
# Expected:
# 1111 11
# If there are N
#
#
# One option, just take the candidate mask and concat from each side
#
# mask = (lhs_mask, adj_bits)
#
# Get a mask of adj_bits num bits:
# mask_of_adj_bit_len = (1 << adj_bits) - 1
#
# Shift to upper 64 bits:
# rhs_mask = mask_of_adj_bit_len << 64 - adj_bits
# values = (lhs & lhs_mask << adj_bits) | (rhs & rhs_mask)
#
# Run the normal algorithm on values
# 1. all bits must be set
# 2. shrink span based on leading / trailing zeros
#
#
# lhs &


_1 = np.uint64(1)

Expand All @@ -69,7 +110,6 @@ def i64(v):
# 00001111 <-- this feels closer!

# (Pdb) dump( (u64(~mask >> u64(3)) & mask) & (u64(~mask >> u64(4)) & mask) & (u64(~mask >> u64
(2)) & mask) )

def dump(v):
if isinstance(v, np.ndarray):
Expand All @@ -88,7 +128,7 @@ def lsb(val):


def msb(val):
return np.uint64((~val >> _1) & val)
return msb_mask64(np.asarray([val], dtype=np.uint64))[0]


def spans_within_mask(bit_vals, mask):
Expand All @@ -107,7 +147,6 @@ def naive_max_span(bit_vals):
"""Closest sets of bits for each bit val."""
# Create a mask of len(bit_vals) + slop
mask = np.bitwise_xor.reduce(bit_vals)
import pdb; pdb.set_trace()
# For every such mask, can we find smallest place where each has bits set
# This can be a one step ctz / clz
while True:
Expand Down

0 comments on commit eb1f08a

Please sign in to comment.