]> git.sesse.net Git - plocate/commitdiff
Proof-of-concept of using ICU for strength-zero searches. icu-case-insensitive
authorSteinar H. Gunderson <steinar+git@gunderson.no>
Sun, 29 May 2022 09:10:57 +0000 (11:10 +0200)
committerSteinar H. Gunderson <steinar+git@gunderson.no>
Sun, 29 May 2022 09:10:57 +0000 (11:10 +0200)
database-builder.cpp
meson.build
needle.cpp
parse_trigrams.cpp
plocate.cpp

index 419e012ca94c2625471680de17b67f07716ecec8..e9c67a7fd4c315182248768b02bb2f2d512d98d3 100644 (file)
@@ -17,6 +17,8 @@
 #include <unistd.h>
 #include <zdict.h>
 #include <zstd.h>
+#include <unicode/coll.h>
+#include <unicode/locid.h>
 
 #define P4NENC_BOUND(n) ((n + 127) / 128 + (n + 32) * sizeof(uint32_t))
 
@@ -302,6 +304,52 @@ void EncodingCorpus::flush_block()
 
        uint32_t docid = num_blocks;
 
+       // Oh, ICU...
+       vector<uint8_t> sort_key;
+       sort_key.resize(32);
+        int32_t num_locales;
+        const icu::Locale* locales = icu::Collator::getAvailableLocales(num_locales);
+        for (int i = 0; i < num_locales; ++i) {
+                const icu::Locale &loc = locales[i];
+               if (strcmp(loc.getName(), "en_US_POSIX") == 0) {
+                       continue;  // Too weird.
+               }
+                UErrorCode status = U_ZERO_ERROR;
+               icu::Collator *coll = icu::Collator::createInstance(loc, status);
+               if (U_FAILURE(status)) {
+                       fprintf(stderr, "ERROR: Failed to create collator\n");
+                       exit(1);
+               }
+               coll->setStrength(icu::Collator::PRIMARY);
+               const char *ptr = current_block.c_str();
+               const char *end = ptr + current_block.size();
+               while (ptr < end) {
+                       size_t len = strlen(ptr);
+                       int32_t sortkey_len;
+                       for ( ;; ) {
+                               sortkey_len = coll->getSortKey(icu::UnicodeString::fromUTF8(icu::StringPiece(ptr, len)), sort_key.data(), sort_key.size());
+                               if (sortkey_len < sort_key.size()) {  // Note <, not <=; we need to keep a slop byte.
+                                       break;
+                               }
+                               sort_key.resize(sortkey_len * 3 / 2);
+                       }
+
+                       const uint8_t *keyptr = &sort_key[0];
+                       const uint8_t *keyend = keyptr + sortkey_len;
+                       while (keyptr < keyend - 3) {
+                               // NOTE: Will read one byte past the end of the trigram, but it's OK,
+                               // since we always call it from contexts where there's a terminating zero byte.
+                               uint32_t trgm;
+                               memcpy(&trgm, keyptr, sizeof(trgm));
+                               ++keyptr;
+                               trgm = le32toh(trgm);
+                               add_docid(trgm & 0xffffff, docid);
+                       }
+
+                       ptr += len + 1;
+               }
+        }
+#if 0
        // Create trigrams.
        const char *ptr = current_block.c_str();
        const char *end = ptr + current_block.size();
@@ -335,6 +383,7 @@ void EncodingCorpus::flush_block()
                        }
                }
        }
+#endif
 
        // Compress and add the filename block.
        filename_blocks.push_back(outfp_pos);
index 0cda2398f7b4007cb6154b8012f9d5213e9dba75..c3551054cde32f62dc37ced5abc66062b1e69149 100644 (file)
@@ -13,6 +13,7 @@ uringdep = dependency('liburing', required: false)
 zstddep = dependency('libzstd')
 threaddep = dependency('threads')
 atomicdep = cxx.find_library('atomic', required: false)
+icudep = dependency('icu-i18n')
 
 if not uringdep.found()
        add_project_arguments('-DWITHOUT_URING', language: 'cpp')
@@ -31,16 +32,16 @@ if cxx.compiles(code, name: 'function multiversioning')
 endif
 
 executable('plocate', ['plocate.cpp', 'io_uring_engine.cpp', 'turbopfor.cpp', 'parse_trigrams.cpp', 'serializer.cpp', 'access_rx_cache.cpp', 'needle.cpp', 'complete_pread.cpp'],
-       dependencies: [uringdep, zstddep, threaddep, atomicdep],
+       dependencies: [uringdep, zstddep, threaddep, atomicdep, icudep],
        install: true,
        install_mode: ['rwxr-sr-x', 'root', get_option('locategroup')])
 executable('plocate-build', ['plocate-build.cpp', 'database-builder.cpp'],
-       dependencies: [zstddep],
+       dependencies: [zstddep, icudep],
        install: true,
        install_dir: get_option('sbindir'))
 updatedb_progname = get_option('updatedb_progname')
 executable(updatedb_progname, ['updatedb.cpp', 'database-builder.cpp', 'conf.cpp', 'lib.cpp', 'bind-mount.cpp', 'complete_pread.cpp'],
-       dependencies: [zstddep, threaddep],
+       dependencies: [zstddep, threaddep, icudep],
        install: true,
        install_dir: get_option('sbindir'))
 
index 60f169872ac6123c8b06b44477fbe82015b6bb0c..401dbda62005b95564a19deebb71cbc2bb084baa 100644 (file)
 #include <string.h>
 #include <utility>
 
+#include <unicode/coll.h>
+#include <unicode/stsearch.h>
+
 using namespace std;
 
 bool matches(const Needle &needle, const char *haystack)
 {
-       if (needle.type == Needle::STRSTR) {
-               return strstr(haystack, needle.str.c_str()) != nullptr;
-       } else if (needle.type == Needle::GLOB) {
-               int flags = ignore_case ? FNM_CASEFOLD : 0;
-               return fnmatch(needle.str.c_str(), haystack, flags) == 0;
-       } else {
-               assert(needle.type == Needle::REGEX);
-               return regexec(&needle.re, haystack, /*nmatch=*/0, /*pmatch=*/nullptr, /*flags=*/0) == 0;
+       UErrorCode status = U_ZERO_ERROR;
+       icu::UnicodeString target(haystack);  // fromUTF8?
+       icu::UnicodeString pattern(needle.str.c_str());
+       icu::Locale locale = icu::Locale::createCanonical(setlocale(LC_CTYPE, NULL));
+       icu::StringSearch search(pattern, target, locale, nullptr, status);
+       search.getCollator()->setStrength(icu::Collator::PRIMARY);
+       //search.setStrength(icu::Collator::PRIMARY);
+
+       int pos = search.first(status);
+       if (U_FAILURE(status)) {
+               fprintf(stderr, "Could not create a StringSearch object.\n");
+               exit(1);
        }
+       return pos != USEARCH_DONE;
+
+//     if (needle.type == Needle::STRSTR) {
+//             return strstr(haystack, needle.str.c_str()) != nullptr;
+//     } else if (needle.type == Needle::GLOB) {
+//             int flags = ignore_case ? FNM_CASEFOLD : 0;
+//             return fnmatch(needle.str.c_str(), haystack, flags) == 0;
+//     } else {
+//             assert(needle.type == Needle::REGEX);
+//             return regexec(&needle.re, haystack, /*nmatch=*/0, /*pmatch=*/nullptr, /*flags=*/0) == 0;
+//     }
 }
 
 string unescape_glob_to_plain_string(const string &needle)
index 9e72aac9bcfbb71f085b3553d81ef80278c058d3..97734e18cbd81656eb2c439741b29d390e521ed7 100644 (file)
@@ -1,5 +1,6 @@
 #include "parse_trigrams.h"
 
+#include "dprintf.h"
 #include "unique_sort.h"
 
 #include <assert.h>
@@ -7,6 +8,8 @@
 #include <string.h>
 #include <wctype.h>
 
+#include <unicode/coll.h>
+
 using namespace std;
 
 string print_td(const TrigramDisjunction &td)
@@ -284,6 +287,38 @@ void parse_trigrams_ignore_case(const string &needle, vector<TrigramDisjunction>
 
 void parse_trigrams(const string &needle, bool ignore_case, vector<TrigramDisjunction> *trigram_groups)
 {
+       // ICU...
+       string needle2;
+       for (char ch : needle) {
+               if (ch != '*') needle2.push_back(ch);
+       }
+
+       dprintf("posix locale = %s\n", setlocale(LC_CTYPE, NULL));
+       icu::Locale locale = icu::Locale::createCanonical(setlocale(LC_CTYPE, NULL));
+       dprintf("icu locale = %s\n", locale.getName());
+       UErrorCode status = U_ZERO_ERROR;
+       icu::Collator *coll = icu::Collator::createInstance(locale, status);
+       // FIXME check for failure
+       uint8_t needlebuf[1024];  // FIXME
+       coll->setStrength(icu::Collator::PRIMARY);
+       int len = coll->getSortKey(icu::UnicodeString::fromUTF8(needle2), needlebuf, sizeof(needlebuf));
+       dprintf("needlelen = %d (from ascii %zu, needle '%s')\n", len, needle2.size(), needle2.c_str());
+       for (size_t i = 0; i < len; ++i) {
+               dprintf(" %02x", needlebuf[i]);
+       }
+       dprintf("\n");
+       for (size_t i = 0; i < len - 3; ++i) {
+               uint32_t trgm = needlebuf[i] | (needlebuf[i + 1] << 8) | (needlebuf[i + 2] << 16);
+               dprintf("trgm = %06x\n", trgm);
+
+               TrigramDisjunction new_pt;
+               new_pt.remaining_trigrams_to_read = 1;
+               new_pt.trigram_alternatives.push_back(trgm);
+               new_pt.max_num_docids = 0;
+               trigram_groups->push_back(move(new_pt));
+       }
+       return;
+
        if (ignore_case) {
                parse_trigrams_ignore_case(needle, trigram_groups);
                return;
index a1cd97a4a4645c08880b014d4799f9d4c33182f6..64d356e28a551a04e982c0ec38e6083e37091f08 100644 (file)
@@ -958,7 +958,7 @@ int main(int argc, char **argv)
                        needle.re = compile_regex(needle.str);
                } else if (any_wildcard) {
                        needle.type = Needle::GLOB;
-               } else if (ignore_case) {
+               } else if (ignore_case && false) {
                        // strcasestr() doesn't handle locales correctly (even though LSB
                        // claims it should), but somehow, fnmatch() does, and it's about
                        // the same speed as using a regex.