]> git.sesse.net Git - plocate/blobdiff - plocate.cpp
Hand-roll zeroing of destination docids for SSE2; takes us seemingly up from ~84...
[plocate] / plocate.cpp
index 5384ff0b6934b5df0951e626713916df272e3a3f..2c5514ee8b16bc5f3c8461a0c97821e6a2216987 100644 (file)
@@ -1,27 +1,79 @@
+#include "db.h"
+#include "io_uring_engine.h"
+#include "vp4.h"
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <assert.h>
+#include <chrono>
+#include <endian.h>
+#include <fcntl.h>
+#include <functional>
+#include <getopt.h>
+#include <limits.h>
+#include <memory>
 #include <stdio.h>
 #include <string.h>
-#include <algorithm>
-#include <unordered_map>
 #include <string>
-#include <vector>
-#include <chrono>
 #include <unistd.h>
-#include <fcntl.h>
-#include <sys/mman.h>
-#include <arpa/inet.h>
-#include <endian.h>
+#include <unordered_map>
+#include <vector>
 #include <zstd.h>
 
-#include "vp4.h"
-
-#define P4NENC_BOUND(n) ((n+127)/128+(n+32)*sizeof(uint32_t))
-
 using namespace std;
 using namespace std::chrono;
 
 #define dprintf(...)
 //#define dprintf(...) fprintf(stderr, __VA_ARGS__);
-       
+
+#include "turbopfor.h"
+
+const char *dbpath = "/var/lib/mlocate/plocate.db";
+bool print_nul = false;
+
+class Serializer {
+public:
+       bool ready_to_print(int seq) { return next_seq == seq; }
+       void print_delayed(int seq, const vector<string> msg);
+       void release_current();
+
+private:
+       int next_seq = 0;
+       struct Element {
+               int seq;
+               vector<string> msg;
+
+               bool operator<(const Element &other) const
+               {
+                       return seq > other.seq;
+               }
+       };
+       priority_queue<Element> pending;
+};
+
+void Serializer::print_delayed(int seq, const vector<string> msg)
+{
+       pending.push(Element{ seq, move(msg) });
+}
+
+void Serializer::release_current()
+{
+       ++next_seq;
+
+       // See if any delayed prints can now be dealt with.
+       while (!pending.empty() && pending.top().seq == next_seq) {
+               for (const string &msg : pending.top().msg) {
+                       if (print_nul) {
+                               printf("%s%c", msg.c_str(), 0);
+                       } else {
+                               printf("%s\n", msg.c_str());
+                       }
+               }
+               pending.pop();
+               ++next_seq;
+       }
+}
+
 static inline uint32_t read_unigram(const string &s, size_t idx)
 {
        if (idx < s.size()) {
@@ -33,12 +85,12 @@ static inline uint32_t read_unigram(const string &s, size_t idx)
 
 static inline uint32_t read_trigram(const string &s, size_t start)
 {
-       return read_unigram(s, start) |
-               (read_unigram(s, start + 1) << 8) |
+       return read_unigram(s, start) | (read_unigram(s, start + 1) << 8) |
                (read_unigram(s, start + 2) << 16);
 }
 
-bool has_access(const char *filename, unordered_map<string, bool> *access_rx_cache)
+bool has_access(const char *filename,
+                unordered_map<string, bool> *access_rx_cache)
 {
        const char *end = strchr(filename + 1, '/');
        while (end != nullptr) {
@@ -57,118 +109,192 @@ bool has_access(const char *filename, unordered_map<string, bool> *access_rx_cac
                end = strchr(end + 1, '/');
        }
 
-#if 0
-       // Check for rx first in the cache; if that isn't true, check R_OK uncached.
-       // This is roughly the same thing as mlocate does.      
-       auto it = access_rx_cache->find(filename);
-       if (it != access_rx_cache->end() && it->second) {
-               return true;
-       }
-
-       return access(filename, R_OK) == 0;
-#endif
        return true;
 }
 
-struct Trigram {
-       uint32_t trgm;
-       uint32_t num_docids;
-       uint64_t offset;
-};
-
 class Corpus {
 public:
-       Corpus(int fd);
+       Corpus(int fd, IOUringEngine *engine);
        ~Corpus();
-       const Trigram *find_trigram(uint32_t trgm) const;
-       const unsigned char *get_compressed_posting_list(const Trigram *trigram) const;
-       string_view get_compressed_filename_block(uint32_t docid) const;
+       void find_trigram(uint32_t trgm, function<void(const Trigram *trgmptr, size_t len)> cb);
+       void get_compressed_filename_block(uint32_t docid, function<void(string)> cb) const;
+       size_t get_num_filename_blocks() const;
+       off_t offset_for_block(uint32_t docid) const
+       {
+               return hdr.filename_index_offset_bytes + docid * sizeof(uint64_t);
+       }
 
-private:
+public:
        const int fd;
-       off_t len;
-       const char *data;
-       const uint64_t *filename_offsets;
-       const Trigram *trgm_begin, *trgm_end;
+       IOUringEngine *const engine;
+
+       Header hdr;
 };
 
-Corpus::Corpus(int fd)
-       : fd(fd)
+Corpus::Corpus(int fd, IOUringEngine *engine)
+       : fd(fd), engine(engine)
 {
-       len = lseek(fd, 0, SEEK_END);
-       if (len == -1) {
-               perror("lseek");
+       // Enable to test cold-cache behavior (except for access()).
+       if (false) {
+               off_t len = lseek(fd, 0, SEEK_END);
+               if (len == -1) {
+                       perror("lseek");
+                       exit(1);
+               }
+               posix_fadvise(fd, 0, len, POSIX_FADV_DONTNEED);
+       }
+
+       complete_pread(fd, &hdr, sizeof(hdr), /*offset=*/0);
+       if (memcmp(hdr.magic, "\0plocate", 8) != 0) {
+               fprintf(stderr, "plocate.db is corrupt or an old version; please rebuild it.\n");
                exit(1);
        }
-       data = (char *)mmap(nullptr, len, PROT_READ, MAP_SHARED, fd, /*offset=*/0);
-       if (data == MAP_FAILED) {
-               perror("mmap");
+       if (hdr.version != 0) {
+               fprintf(stderr, "plocate.db has version %u, expected 0; please rebuild it.\n", hdr.version);
                exit(1);
        }
-
-       uint64_t num_trigrams = *(const uint64_t *)data;
-       uint64_t filename_index_offset = *(const uint64_t *)(data + sizeof(uint64_t));
-       filename_offsets = (const uint64_t *)(data + filename_index_offset);
-
-       trgm_begin = (Trigram *)(data + sizeof(uint64_t) * 2);
-       trgm_end = trgm_begin + num_trigrams;
 }
 
 Corpus::~Corpus()
 {
-       munmap((void *)data, len);
        close(fd);
 }
 
-const Trigram *Corpus::find_trigram(uint32_t trgm) const
+void Corpus::find_trigram(uint32_t trgm, function<void(const Trigram *trgmptr, size_t len)> cb)
 {
-       const Trigram *trgmptr = lower_bound(trgm_begin, trgm_end, trgm, [](const Trigram &trgm, uint32_t t) {
-               return trgm.trgm < t;
+       uint32_t bucket = hash_trigram(trgm, hdr.hashtable_size);
+       engine->submit_read(fd, sizeof(Trigram) * (hdr.extra_ht_slots + 2), hdr.hash_table_offset_bytes + sizeof(Trigram) * bucket, [this, trgm, bucket, cb{ move(cb) }](string s) {
+               const Trigram *trgmptr = reinterpret_cast<const Trigram *>(s.data());
+               for (unsigned i = 0; i < hdr.extra_ht_slots + 1; ++i) {
+                       if (trgmptr[i].trgm == trgm) {
+                               cb(trgmptr + i, trgmptr[i + 1].offset - trgmptr[i].offset);
+                               return;
+                       }
+               }
+
+               // Not found.
+               cb(nullptr, 0);
        });
-       if (trgmptr == trgm_end || trgmptr->trgm != trgm) {
-               return nullptr;
-       }
-       return trgmptr;
 }
 
-const unsigned char *Corpus::get_compressed_posting_list(const Trigram *trgmptr) const
+void Corpus::get_compressed_filename_block(uint32_t docid, function<void(string)> cb) const
 {
-       return reinterpret_cast<const unsigned char *>(data + trgmptr->offset);
+       // Read the file offset from this docid and the next one.
+       // This is always allowed, since we have a sentinel block at the end.
+       engine->submit_read(fd, sizeof(uint64_t) * 2, offset_for_block(docid), [this, cb{ move(cb) }](string s) {
+               const uint64_t *ptr = reinterpret_cast<const uint64_t *>(s.data());
+               off_t offset = ptr[0];
+               size_t len = ptr[1] - ptr[0];
+               engine->submit_read(fd, len, offset, cb);
+       });
 }
 
-string_view Corpus::get_compressed_filename_block(uint32_t docid) const
+size_t Corpus::get_num_filename_blocks() const
 {
-       const char *compressed = (const char *)(data + filename_offsets[docid]);
-       size_t compressed_size = filename_offsets[docid + 1] - filename_offsets[docid];  // Allowed we have a sentinel block at the end.
-       return {compressed, compressed_size};
+       return hdr.num_docids;
 }
 
-size_t scan_docid(const string &needle, uint32_t docid, const Corpus &corpus, unordered_map<string, bool> *access_rx_cache)
+size_t scan_file_block(const vector<string> &needles, string_view compressed,
+                       unordered_map<string, bool> *access_rx_cache, int seq,
+                       Serializer *serializer)
 {
-       string_view compressed = corpus.get_compressed_filename_block(docid);
        size_t matched = 0;
 
+       unsigned long long uncompressed_len = ZSTD_getFrameContentSize(compressed.data(), compressed.size());
+       if (uncompressed_len == ZSTD_CONTENTSIZE_UNKNOWN || uncompressed_len == ZSTD_CONTENTSIZE_ERROR) {
+               fprintf(stderr, "ZSTD_getFrameContentSize() failed\n");
+               exit(1);
+       }
+
        string block;
-       block.resize(ZSTD_getFrameContentSize(compressed.data(), compressed.size()) + 1);
+       block.resize(uncompressed_len + 1);
 
-       ZSTD_decompress(&block[0], block.size(), compressed.data(), compressed.size());
+       size_t err = ZSTD_decompress(&block[0], block.size(), compressed.data(),
+                                    compressed.size());
+       if (ZSTD_isError(err)) {
+               fprintf(stderr, "ZSTD_decompress(): %s\n", ZSTD_getErrorName(err));
+               exit(1);
+       }
        block[block.size() - 1] = '\0';
 
+       bool immediate_print = (serializer == nullptr || serializer->ready_to_print(seq));
+       vector<string> delayed;
+
        for (const char *filename = block.data();
-                       filename != block.data() + block.size();
-                       filename += strlen(filename) + 1) {
-               if (strstr(filename, needle.c_str()) == nullptr) {
-                       continue;
+            filename != block.data() + block.size();
+            filename += strlen(filename) + 1) {
+               bool found = true;
+               for (const string &needle : needles) {
+                       if (strstr(filename, needle.c_str()) == nullptr) {
+                               found = false;
+                               break;
+                       }
                }
-               if (has_access(filename, access_rx_cache)) {
+               if (found && has_access(filename, access_rx_cache)) {
                        ++matched;
-                       printf("%s\n", filename);
+                       if (immediate_print) {
+                               if (print_nul) {
+                                       printf("%s%c", filename, 0);
+                               } else {
+                                       printf("%s\n", filename);
+                               }
+                       } else {
+                               delayed.push_back(filename);
+                       }
                }
        }
+       if (serializer != nullptr) {
+               if (immediate_print) {
+                       serializer->release_current();
+               } else {
+                       serializer->print_delayed(seq, move(delayed));
+               }
+       }
+       return matched;
+}
+
+size_t scan_docids(const vector<string> &needles, const vector<uint32_t> &docids, const Corpus &corpus, IOUringEngine *engine)
+{
+       Serializer docids_in_order;
+       unordered_map<string, bool> access_rx_cache;
+       size_t matched = 0;
+       for (size_t i = 0; i < docids.size(); ++i) {
+               uint32_t docid = docids[i];
+               corpus.get_compressed_filename_block(docid, [i, &matched, &needles, &access_rx_cache, &docids_in_order](string compressed) {
+                       matched += scan_file_block(needles, compressed, &access_rx_cache, i, &docids_in_order);
+               });
+       }
+       engine->finish();
        return matched;
 }
 
-void do_search_file(const string &needle, const char *filename)
+// We do this sequentially, as it's faster than scattering
+// a lot of I/O through io_uring and hoping the kernel will
+// coalesce it plus readahead for us.
+void scan_all_docids(const vector<string> &needles, int fd, const Corpus &corpus, IOUringEngine *engine)
+{
+       unordered_map<string, bool> access_rx_cache;
+       uint32_t num_blocks = corpus.get_num_filename_blocks();
+       unique_ptr<uint64_t[]> offsets(new uint64_t[num_blocks + 1]);
+       complete_pread(fd, offsets.get(), (num_blocks + 1) * sizeof(uint64_t), corpus.offset_for_block(0));
+       string compressed;
+       for (uint32_t io_docid = 0; io_docid < num_blocks; io_docid += 32) {
+               uint32_t last_docid = std::min(io_docid + 32, num_blocks);
+               size_t io_len = offsets[last_docid] - offsets[io_docid];
+               if (compressed.size() < io_len) {
+                       compressed.resize(io_len);
+               }
+               complete_pread(fd, &compressed[0], io_len, offsets[io_docid]);
+
+               for (uint32_t docid = io_docid; docid < last_docid; ++docid) {
+                       size_t relative_offset = offsets[docid] - offsets[io_docid];
+                       size_t len = offsets[docid + 1] - offsets[docid];
+                       scan_file_block(needles, { &compressed[relative_offset], len }, &access_rx_cache, 0, nullptr);
+               }
+       }
+}
+
+void do_search_file(const vector<string> &needles, const char *filename)
 {
        int fd = open(filename, O_RDONLY);
        if (fd == -1) {
@@ -182,83 +308,173 @@ void do_search_file(const string &needle, const char *filename)
                exit(EXIT_FAILURE);
        }
 
-       //steady_clock::time_point start = steady_clock::now();
+       steady_clock::time_point start __attribute__((unused)) = steady_clock::now();
        if (access("/", R_OK | X_OK)) {
                // We can't find anything, no need to bother...
                return;
        }
 
-       Corpus corpus(fd);
+       IOUringEngine engine;
+       Corpus corpus(fd, &engine);
+       dprintf("Corpus init done after %.1f ms.\n", 1e3 * duration<float>(steady_clock::now() - start).count());
 
-       vector<const Trigram *> trigrams;
-       for (size_t i = 0; i < needle.size() - 2; ++i) {
-               uint32_t trgm = read_trigram(needle, i);
-               const Trigram *trgmptr = corpus.find_trigram(trgm);
-               if (trgmptr == nullptr) {
-                       dprintf("trigram %06x isn't found, we abort the search\n", trgm);
-                       return;
+       vector<pair<Trigram, size_t>> trigrams;
+       uint64_t shortest_so_far = numeric_limits<uint32_t>::max();
+       for (const string &needle : needles) {
+               if (needle.size() < 3)
+                       continue;
+               for (size_t i = 0; i < needle.size() - 2; ++i) {
+                       uint32_t trgm = read_trigram(needle, i);
+                       corpus.find_trigram(trgm, [trgm, &trigrams, &shortest_so_far](const Trigram *trgmptr, size_t len) {
+                               if (trgmptr == nullptr) {
+                                       dprintf("trigram '%c%c%c' isn't found, we abort the search\n",
+                                               trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff);
+                                       exit(0);
+                               }
+                               if (trgmptr->num_docids > shortest_so_far * 100) {
+                                       dprintf("not loading trigram '%c%c%c' with %u docids, it would be ignored later anyway\n",
+                                               trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff,
+                                               trgmptr->num_docids);
+                               } else {
+                                       trigrams.emplace_back(*trgmptr, len);
+                                       shortest_so_far = std::min<uint64_t>(shortest_so_far, trgmptr->num_docids);
+                               }
+                       });
                }
-               trigrams.push_back(trgmptr);
+       }
+       engine.finish();
+       dprintf("Hashtable lookups done after %.1f ms.\n", 1e3 * duration<float>(steady_clock::now() - start).count());
+
+       if (trigrams.empty()) {
+               // Too short for trigram matching. Apply brute force.
+               // (We could have searched through all trigrams that matched
+               // the pattern and done a union of them, but that's a lot of
+               // work for fairly unclear gain.)
+               scan_all_docids(needles, fd, corpus, &engine);
+               return;
        }
        sort(trigrams.begin(), trigrams.end());
        {
                auto last = unique(trigrams.begin(), trigrams.end());
                trigrams.erase(last, trigrams.end());
        }
-       sort(trigrams.begin(), trigrams.end(), [&](const Trigram *a, const Trigram *b) {
-               return a->num_docids < b->num_docids;
-       });
+       sort(trigrams.begin(), trigrams.end(),
+            [&](const pair<Trigram, size_t> &a, const pair<Trigram, size_t> &b) {
+                    return a.first.num_docids < b.first.num_docids;
+            });
 
        vector<uint32_t> in1, in2, out;
-       for (const Trigram *trgmptr : trigrams) {
-               //uint32_t trgm = trgmptr->trgm;
-               size_t num = trgmptr->num_docids;
-               const unsigned char *pldata = corpus.get_compressed_posting_list(trgmptr);
-               if (in1.empty()) {
-                       in1.resize(num + 128);
-                       p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in1[0]);
-                       in1.resize(num);
-                       dprintf("trigram '%c%c%c' decoded to %zu entries\n", trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
-               } else {
-                       if (num > in1.size() * 100) {
-                               dprintf("trigram '%c%c%c' has %zu entries, ignoring the rest (will weed out false positives later)\n",
-                                       trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
-                               break;
-                       }
-
-                       if (in2.size() < num + 128) {
-                               in2.resize(num + 128);
-                       }
-                       p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in2[0]);
+       bool done = false;
+       for (auto [trgmptr, len] : trigrams) {
+               if (!in1.empty() && trgmptr.num_docids > in1.size() * 100) {
+                       uint32_t trgm __attribute__((unused)) = trgmptr.trgm;
+                       dprintf("trigram '%c%c%c' (%zu bytes) has %u entries, ignoring the rest (will "
+                               "weed out false positives later)\n",
+                               trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff,
+                               len, trgmptr.num_docids);
+                       break;
+               }
 
-                       out.clear();
-                       set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num, back_inserter(out));
-                       swap(in1, out);
-                       dprintf("trigram '%c%c%c' decoded to %zu entries, %zu left\n", trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num, in1.size());
-                       if (in1.empty()) {
-                               dprintf("no matches (intersection list is empty)\n");
+               // Only stay a certain amount ahead, so that we don't spend I/O
+               // on reading the latter, large posting lists. We are unlikely
+               // to need them anyway, even if they should come in first.
+               if (engine.get_waiting_reads() >= 5) {
+                       engine.finish();
+                       if (done)
                                break;
-                       }
                }
+               engine.submit_read(fd, len, trgmptr.offset, [trgmptr, len, &done, &in1, &in2, &out](string s) {
+                       if (done)
+                               return;
+                       uint32_t trgm __attribute__((unused)) = trgmptr.trgm;
+                       size_t num = trgmptr.num_docids;
+                       unsigned char *pldata = reinterpret_cast<unsigned char *>(s.data());
+                       if (in1.empty()) {
+                               in1.resize(num + 128);
+                               decode_pfor_delta1<128>(pldata, num, /*interleaved=*/true, &in1[0]);
+                               in1.resize(num);
+                               dprintf("trigram '%c%c%c' (%zu bytes) decoded to %zu entries\n", trgm & 0xff,
+                                       (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, len, num);
+                       } else {
+                               if (in2.size() < num + 128) {
+                                       in2.resize(num + 128);
+                               }
+                               decode_pfor_delta1<128>(pldata, num, /*interleaved=*/true, &in2[0]);
+
+                               out.clear();
+                               set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num,
+                                                back_inserter(out));
+                               swap(in1, out);
+                               dprintf("trigram '%c%c%c' (%zu bytes) decoded to %zu entries, %zu left\n",
+                                       trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff,
+                                       len, num, in1.size());
+                               if (in1.empty()) {
+                                       dprintf("no matches (intersection list is empty)\n");
+                                       done = true;
+                               }
+                       }
+               });
        }
-       steady_clock::time_point end = steady_clock::now();
-
-       dprintf("Intersection took %.1f ms. Doing final verification and printing:\n",
-               1e3 * duration<float>(end - start).count());
+       engine.finish();
+       if (done) {
+               return;
+       }
+       dprintf("Intersection done after %.1f ms. Doing final verification and printing:\n",
+               1e3 * duration<float>(steady_clock::now() - start).count());
 
-       unordered_map<string, bool> access_rx_cache;
+       size_t matched __attribute__((unused)) = scan_docids(needles, in1, corpus, &engine);
+       dprintf("Done in %.1f ms, found %zu matches.\n",
+               1e3 * duration<float>(steady_clock::now() - start).count(), matched);
+}
 
-       int matched = 0;
-       for (uint32_t docid : in1) {
-               matched += scan_docid(needle, docid, corpus, &access_rx_cache);
-       }
-       end = steady_clock::now();
-       dprintf("Done in %.1f ms, found %d matches.\n",
-               1e3 * duration<float>(end - start).count(), matched);
+void usage()
+{
+       // The help text comes from mlocate.
+       printf("Usage: plocate [OPTION]... PATTERN...\n");
+       printf("\n");
+       printf("  -d, --database DBPATH  use DBPATH instead of default database (which is\n");
+       printf("                         %s)\n", dbpath);
+       printf("  -h, --help             print this help\n");
+       printf("  -0, --null             separate entries with NUL on output\n");
 }
 
 int main(int argc, char **argv)
 {
-       //do_search_file(argv[1], "all.trgm");
-       do_search_file(argv[1], "/var/lib/mlocate/plocate.db");
+       static const struct option long_options[] = {
+               { "help", no_argument, 0, 'h' },
+               { "database", required_argument, 0, 'd' },
+               { "null", no_argument, 0, '0' },
+               { 0, 0, 0, 0 }
+       };
+
+       for (;;) {
+               int option_index = 0;
+               int c = getopt_long(argc, argv, "d:h0", long_options, &option_index);
+               if (c == -1) {
+                       break;
+               }
+               switch (c) {
+               case 'd':
+                       dbpath = strdup(optarg);
+                       break;
+               case 'h':
+                       usage();
+                       exit(0);
+               case '0':
+                       print_nul = true;
+                       break;
+               default:
+                       exit(1);
+               }
+       }
+
+       vector<string> needles;
+       for (int i = optind; i < argc; ++i) {
+               needles.push_back(argv[i]);
+       }
+       if (needles.empty()) {
+               fprintf(stderr, "plocate: no pattern to search for specified\n");
+               exit(0);
+       }
+       do_search_file(needles, dbpath);
 }