Feature: Low overhead case insensitive find #195

Open
3 tasks done
e253 opened this issue Nov 15, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@e253

e253 commented Nov 15, 2024

Describe what you are looking for

The sz_tolower function requires copying to another buffer. Within the find and find_byte routines, the lowercasing step can instead be applied to each loaded vector, at a cost of a few extra cycles. I'm happy to add the feature, just gauging interest here.

For AVX-512:

SZ_INTERNAL __m512i sz_lower_avx512(__m512i in)
{
    __m512i A = _mm512_set1_epi8('A');
    __m512i Z = _mm512_set1_epi8('Z');
    __m512i to_lower = _mm512_set1_epi8('a' - 'A');
    __mmask64 ge_A = _mm512_cmpge_epi8_mask(in, A);
    __mmask64 le_Z = _mm512_cmple_epi8_mask(in, Z);
    __mmask64 is_upper = _kand_mask64(ge_A, le_Z);
    return _mm512_mask_add_epi8(in, is_upper, in, to_lower);
}

SZ_PUBLIC sz_cptr_t sz_find_byte_case_insensitive_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n)
{
    __mmask64 mask;
    sz_u512_vec_t h_vec, n_vec;
    /// PATCH!!!
    n_vec.zmm = _mm512_set1_epi8(sz_u8_tolower(n[0]));
    /// PATCH!!!

    while (h_length >= 64) {
        /// PATCH!!!
        h_vec.zmm = sz_lower_avx512(_mm512_loadu_si512(h));
        /// PATCH!!!
        mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm);
        if (mask)
            return h + sz_u64_ctz(mask);
        h += 64, h_length -= 64;
    }

    if (h_length) {
        mask = _sz_u64_mask_until(h_length);
        /// PATCH!!!
        h_vec.zmm = sz_lower_avx512(_mm512_maskz_loadu_epi8(mask, h));
        /// PATCH!!!
        // Reuse the same `mask` variable to find the first byte that matches
        mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm);
        if (mask)
            return h + sz_u64_ctz(mask);
    }

    return SZ_NULL_CHAR;
}
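
For comparison, here is a minimal scalar sketch of the same idea, which could serve as a reference when testing a vectorized version. The names ascii_tolower_byte and find_byte_case_insensitive_serial are hypothetical and not part of the library; like the snippet above, it only folds the ASCII range 'A'..'Z'.

```c
#include <stddef.h>

/* Lowercase one byte if it is an ASCII uppercase letter; all other bytes pass through. */
static unsigned char ascii_tolower_byte(unsigned char c) {
    /* Unsigned wrap-around turns the two-sided range check into a single comparison. */
    int is_upper = (unsigned char)(c - 'A') <= (unsigned char)('Z' - 'A');
    return (unsigned char)(c + (is_upper ? ('a' - 'A') : 0));
}

/* Hypothetical scalar reference: returns the first byte in h[0..h_length) that equals
 * n[0] when ASCII case is ignored, or NULL if there is none. */
static const char *find_byte_case_insensitive_serial(const char *h, size_t h_length, const char *n) {
    unsigned char needle = ascii_tolower_byte((unsigned char)n[0]);
    for (size_t i = 0; i != h_length; ++i)
        if (ascii_tolower_byte((unsigned char)h[i]) == needle) return h + i;
    return NULL;
}
```

The signature mirrors the three-argument shape of the AVX-512 function above, so the two could be compared on the same inputs.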

Can you contribute to the implementation?

  • I can contribute

Is your feature request specific to a certain interface?

It applies to everything

Contact Details

[email protected]

Is there an existing issue for this?

  • I have searched the existing issues

Code of Conduct

  • I agree to follow this project's Code of Conduct
e253 added the enhancement (New feature or request) label on Nov 15, 2024
e253 changed the title from "Feature: Case Insensitive find" to "Feature: Low overhead case insensitive find" on Nov 15, 2024
@ashvardanian
Owner

I've considered this before, and it raises questions about character set encodings. The suggested lowering function limits the applicability to ASCII content. There is probably a better way to future-proof the API for UTF-8 as well. Feel free to open a PR, and I'll modify/integrate it down the road when we get to it.
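
To illustrate the ASCII limitation, here is a small sketch with hypothetical helper names (ascii_lower and equal_ignoring_ascii_case are not library functions). Every byte of a multi-byte UTF-8 sequence is 0x80 or above, so a lowering step that only touches 'A'..'Z' leaves it unchanged. As a result, "É" (bytes C3 89) and "é" (bytes C3 A9) still compare as different.

```c
#include <stddef.h>
#include <string.h>

/* ASCII-only lowering: bytes outside 'A'..'Z' are returned unchanged. */
static unsigned char ascii_lower(unsigned char c) {
    return (c >= 'A' && c <= 'Z') ? (unsigned char)(c + ('a' - 'A')) : c;
}

/* Returns 1 if the two NUL-terminated strings are byte-wise equal after
 * ASCII-only lowering, 0 otherwise. No Unicode case folding is attempted. */
static int equal_ignoring_ascii_case(const char *a, const char *b) {
    size_t length = strlen(a);
    if (length != strlen(b)) return 0;
    for (size_t i = 0; i != length; ++i)
        if (ascii_lower((unsigned char)a[i]) != ascii_lower((unsigned char)b[i])) return 0;
    return 1;
}
```

With these helpers, "ECOLE" and "ecole" compare equal, while the UTF-8 encodings of "ÉCOLE" and "école" do not, because only the first code point differs and it lies outside the ASCII range.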

@e253
Author

e253 commented Nov 16, 2024

The internal function can branch to a fallback if non-ASCII characters are present. The public call signature stays the same as for find / find_byte.

SZ_INTERNAL __m512i sz_lower_avx512(__m512i in)
{
   if (sz_is_ascii_avx512(in)) {
       __m512i A = _mm512_set1_epi8('A');
       __m512i Z = _mm512_set1_epi8('Z');
       __m512i to_lower = _mm512_set1_epi8('a' - 'A');
       __mmask64 ge_A = _mm512_cmpge_epi8_mask(in, A);
       __mmask64 le_Z = _mm512_cmple_epi8_mask(in, Z);
       __mmask64 is_upper = _kand_mask64(ge_A, le_Z);
       return _mm512_mask_add_epi8(in, is_upper, in, to_lower);
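   } else {
       // UTF-8 fallback (not written yet). Returning the input unchanged here is only
       // a placeholder so that every path returns a value.
       return in;
   }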
}
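
The sz_is_ascii_avx512 check used above is not defined in this thread. As a rough portable sketch of what such a check could look like, the hypothetical function below tests the high bit of every byte, eight bytes at a time. An AVX-512 variant would presumably do the same across all 64 lanes, for example by checking that _mm512_movepi8_mask(in) is zero, but that is an assumption and not the library's code.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: returns 1 when every byte in text[0..length) is below 0x80. */
static int is_ascii_swar(const char *text, size_t length) {
    const uint64_t high_bits = 0x8080808080808080ull;
    size_t i = 0;
    /* Check eight bytes per step; any set high bit means a non-ASCII byte. */
    for (; i + 8 <= length; i += 8) {
        uint64_t word;
        memcpy(&word, text + i, sizeof(word)); /* memcpy sidesteps unaligned loads */
        if (word & high_bits) return 0;
    }
    /* Handle the remaining tail of fewer than eight bytes one at a time. */
    for (; i != length; ++i)
        if ((unsigned char)text[i] & 0x80) return 0;
    return 1;
}
```

Because the check only ORs together sign bits, its result does not depend on byte order within the word.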
