]> git.sesse.net Git - plocate/blob - plocate.cpp
Generalize the sort+unique+erase pattern into unique_sort().
[plocate] / plocate.cpp
1 #include "db.h"
2 #include "io_uring_engine.h"
3 #include "unique_sort.h"
4
5 #include <algorithm>
6 #include <chrono>
7 #include <fcntl.h>
8 #include <functional>
9 #include <getopt.h>
10 #include <iosfwd>
11 #include <iterator>
12 #include <limits>
13 #include <memory>
14 #include <queue>
15 #include <stdint.h>
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <string.h>
19 #include <string>
20 #include <string_view>
21 #include <unistd.h>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include <zstd.h>
26
27 using namespace std;
28 using namespace std::chrono;
29
30 #define dprintf(...)
31 //#define dprintf(...) fprintf(stderr, __VA_ARGS__);
32
33 #include "turbopfor.h"
34
// Process-wide settings; main() overwrites them from the command line.

// Database file to search (replaced by -d/--database).
const char *dbpath{ "/var/lib/mlocate/plocate.db" };
// -c/--count: report only the number of matches instead of the matches.
bool only_count{ false };
// -0/--null: terminate each printed entry with NUL instead of newline.
bool print_nul{ false };
// -l/--limit/-n: matches still allowed; counted down as matches are found.
int64_t limit_matches{ numeric_limits<int64_t>::max() };
39
// Puts output from blocks that complete out of order back into sequence.
//
// Every block is tagged with a sequence number. The block whose number
// equals next_seq may print directly (and must then call release_current());
// any other block hands its lines to print_delayed(), and they are written
// out once all earlier sequence numbers have been released.
class Serializer {
public:
        // True if block number <seq> is the next one due for output.
        bool ready_to_print(int seq) const { return next_seq == seq; }

        // Queues the output lines of block <seq> until its turn comes.
        // Taken by value so that callers can move their vector in.
        void print_delayed(int seq, vector<string> msg);

        // Marks the current block as finished and flushes any queued blocks
        // that directly follow it.
        void release_current();

private:
        int next_seq = 0;
        struct Element {
                int seq;
                vector<string> msg;

                // Deliberately inverted: priority_queue puts the largest
                // element on top, and we want the lowest seq there.
                bool operator<(const Element &other) const
                {
                        return seq > other.seq;
                }
        };
        priority_queue<Element> pending;
};
59
60 void Serializer::print_delayed(int seq, const vector<string> msg)
61 {
62         pending.push(Element{ seq, move(msg) });
63 }
64
65 void Serializer::release_current()
66 {
67         ++next_seq;
68
69         // See if any delayed prints can now be dealt with.
70         while (!pending.empty() && pending.top().seq == next_seq) {
71                 if (limit_matches-- <= 0)
72                         return;
73                 for (const string &msg : pending.top().msg) {
74                         if (print_nul) {
75                                 printf("%s%c", msg.c_str(), 0);
76                         } else {
77                                 printf("%s\n", msg.c_str());
78                         }
79                 }
80                 pending.pop();
81                 ++next_seq;
82         }
83 }
84
// Returns the byte of <s> at position <idx> as an unsigned value (0..255),
// or 0 if <idx> lies at or beyond the end of the string.
static inline uint32_t read_unigram(const string &s, size_t idx)
{
        if (idx >= s.size()) {
                return 0;
        }
        return static_cast<unsigned char>(s[idx]);
}
93
94 static inline uint32_t read_trigram(const string &s, size_t start)
95 {
96         return read_unigram(s, start) | (read_unigram(s, start + 1) << 8) |
97                 (read_unigram(s, start + 2) << 16);
98 }
99
// Returns true if every directory leading up to <filename> is readable and
// searchable (access() with R_OK | X_OK) by the calling user.
//
// Each prefix of <filename> that ends just before a '/' is checked; the
// search for '/' starts at the second character, so a leading slash on its
// own is never tested, and neither is the final path component.
// Results are memoized per directory in <access_rx_cache>, so repeated
// lookups under the same directory do not call access() again.
bool has_access(const char *filename,
                unordered_map<string, bool> *access_rx_cache)
{
        // An empty name has no parent directories; bail out before looking
        // at filename + 1, which would be past the terminator.
        if (filename[0] == '\0') {
                return true;
        }

        for (const char *end = strchr(filename + 1, '/'); end != nullptr;
             end = strchr(end + 1, '/')) {
                string parent_path(filename, end);
                auto it = access_rx_cache->find(parent_path);
                if (it == access_rx_cache->end()) {
                        const bool ok = access(parent_path.c_str(), R_OK | X_OK) == 0;
                        it = access_rx_cache->emplace(move(parent_path), ok).first;
                }
                if (!it->second) {
                        return false;
                }
        }

        return true;
}
122
123 class Corpus {
124 public:
125         Corpus(int fd, IOUringEngine *engine);
126         ~Corpus();
127         void find_trigram(uint32_t trgm, function<void(const Trigram *trgmptr, size_t len)> cb);
128         void get_compressed_filename_block(uint32_t docid, function<void(string_view)> cb) const;
129         size_t get_num_filename_blocks() const;
130         off_t offset_for_block(uint32_t docid) const
131         {
132                 return hdr.filename_index_offset_bytes + docid * sizeof(uint64_t);
133         }
134
135 public:
136         const int fd;
137         IOUringEngine *const engine;
138
139         Header hdr;
140 };
141
142 Corpus::Corpus(int fd, IOUringEngine *engine)
143         : fd(fd), engine(engine)
144 {
145         // Enable to test cold-cache behavior (except for access()).
146         if (false) {
147                 off_t len = lseek(fd, 0, SEEK_END);
148                 if (len == -1) {
149                         perror("lseek");
150                         exit(1);
151                 }
152                 posix_fadvise(fd, 0, len, POSIX_FADV_DONTNEED);
153         }
154
155         complete_pread(fd, &hdr, sizeof(hdr), /*offset=*/0);
156         if (memcmp(hdr.magic, "\0plocate", 8) != 0) {
157                 fprintf(stderr, "plocate.db is corrupt or an old version; please rebuild it.\n");
158                 exit(1);
159         }
160         if (hdr.version != 0) {
161                 fprintf(stderr, "plocate.db has version %u, expected 0; please rebuild it.\n", hdr.version);
162                 exit(1);
163         }
164 }
165
166 Corpus::~Corpus()
167 {
168         close(fd);
169 }
170
171 void Corpus::find_trigram(uint32_t trgm, function<void(const Trigram *trgmptr, size_t len)> cb)
172 {
173         uint32_t bucket = hash_trigram(trgm, hdr.hashtable_size);
174         engine->submit_read(fd, sizeof(Trigram) * (hdr.extra_ht_slots + 2), hdr.hash_table_offset_bytes + sizeof(Trigram) * bucket, [this, trgm, cb{ move(cb) }](string_view s) {
175                 const Trigram *trgmptr = reinterpret_cast<const Trigram *>(s.data());
176                 for (unsigned i = 0; i < hdr.extra_ht_slots + 1; ++i) {
177                         if (trgmptr[i].trgm == trgm) {
178                                 cb(trgmptr + i, trgmptr[i + 1].offset - trgmptr[i].offset);
179                                 return;
180                         }
181                 }
182
183                 // Not found.
184                 cb(nullptr, 0);
185         });
186 }
187
188 void Corpus::get_compressed_filename_block(uint32_t docid, function<void(string_view)> cb) const
189 {
190         // Read the file offset from this docid and the next one.
191         // This is always allowed, since we have a sentinel block at the end.
192         engine->submit_read(fd, sizeof(uint64_t) * 2, offset_for_block(docid), [this, cb{ move(cb) }](string_view s) {
193                 const uint64_t *ptr = reinterpret_cast<const uint64_t *>(s.data());
194                 off_t offset = ptr[0];
195                 size_t len = ptr[1] - ptr[0];
196                 engine->submit_read(fd, len, offset, cb);
197         });
198 }
199
200 size_t Corpus::get_num_filename_blocks() const
201 {
202         return hdr.num_docids;
203 }
204
205 uint64_t scan_file_block(const vector<string> &needles, string_view compressed,
206                          unordered_map<string, bool> *access_rx_cache, int seq,
207                          Serializer *serializer)
208 {
209         uint64_t matched = 0;
210
211         unsigned long long uncompressed_len = ZSTD_getFrameContentSize(compressed.data(), compressed.size());
212         if (uncompressed_len == ZSTD_CONTENTSIZE_UNKNOWN || uncompressed_len == ZSTD_CONTENTSIZE_ERROR) {
213                 fprintf(stderr, "ZSTD_getFrameContentSize() failed\n");
214                 exit(1);
215         }
216
217         string block;
218         block.resize(uncompressed_len + 1);
219
220         size_t err = ZSTD_decompress(&block[0], block.size(), compressed.data(),
221                                      compressed.size());
222         if (ZSTD_isError(err)) {
223                 fprintf(stderr, "ZSTD_decompress(): %s\n", ZSTD_getErrorName(err));
224                 exit(1);
225         }
226         block[block.size() - 1] = '\0';
227
228         bool immediate_print = (serializer == nullptr || serializer->ready_to_print(seq));
229         vector<string> delayed;
230
231         for (const char *filename = block.data();
232              filename != block.data() + block.size();
233              filename += strlen(filename) + 1) {
234                 bool found = true;
235                 for (const string &needle : needles) {
236                         if (strstr(filename, needle.c_str()) == nullptr) {
237                                 found = false;
238                                 break;
239                         }
240                 }
241                 if (found && has_access(filename, access_rx_cache)) {
242                         if (limit_matches-- <= 0)
243                                 break;
244                         ++matched;
245                         if (only_count)
246                                 continue;
247                         if (immediate_print) {
248                                 if (print_nul) {
249                                         printf("%s%c", filename, 0);
250                                 } else {
251                                         printf("%s\n", filename);
252                                 }
253                         } else {
254                                 delayed.push_back(filename);
255                         }
256                 }
257         }
258         if (serializer != nullptr && !only_count) {
259                 if (immediate_print) {
260                         serializer->release_current();
261                 } else {
262                         serializer->print_delayed(seq, move(delayed));
263                 }
264         }
265         return matched;
266 }
267
268 size_t scan_docids(const vector<string> &needles, const vector<uint32_t> &docids, const Corpus &corpus, IOUringEngine *engine)
269 {
270         Serializer docids_in_order;
271         unordered_map<string, bool> access_rx_cache;
272         uint64_t matched = 0;
273         for (size_t i = 0; i < docids.size(); ++i) {
274                 uint32_t docid = docids[i];
275                 corpus.get_compressed_filename_block(docid, [i, &matched, &needles, &access_rx_cache, &docids_in_order](string_view compressed) {
276                         matched += scan_file_block(needles, compressed, &access_rx_cache, i, &docids_in_order);
277                 });
278         }
279         engine->finish();
280         return matched;
281 }
282
283 // We do this sequentially, as it's faster than scattering
284 // a lot of I/O through io_uring and hoping the kernel will
285 // coalesce it plus readahead for us.
286 uint64_t scan_all_docids(const vector<string> &needles, int fd, const Corpus &corpus, IOUringEngine *engine)
287 {
288         unordered_map<string, bool> access_rx_cache;
289         uint32_t num_blocks = corpus.get_num_filename_blocks();
290         unique_ptr<uint64_t[]> offsets(new uint64_t[num_blocks + 1]);
291         complete_pread(fd, offsets.get(), (num_blocks + 1) * sizeof(uint64_t), corpus.offset_for_block(0));
292         string compressed;
293         uint64_t matched = 0;
294         for (uint32_t io_docid = 0; io_docid < num_blocks; io_docid += 32) {
295                 uint32_t last_docid = std::min(io_docid + 32, num_blocks);
296                 size_t io_len = offsets[last_docid] - offsets[io_docid];
297                 if (compressed.size() < io_len) {
298                         compressed.resize(io_len);
299                 }
300                 complete_pread(fd, &compressed[0], io_len, offsets[io_docid]);
301
302                 for (uint32_t docid = io_docid; docid < last_docid; ++docid) {
303                         size_t relative_offset = offsets[docid] - offsets[io_docid];
304                         size_t len = offsets[docid + 1] - offsets[docid];
305                         matched += scan_file_block(needles, { &compressed[relative_offset], len }, &access_rx_cache, 0, nullptr);
306                         if (limit_matches <= 0)
307                                 return matched;
308                 }
309         }
310         return matched;
311 }
312
// For debugging: renders the three bytes packed into <trgm> (lowest byte
// first) as a quoted string that is safe to write to a terminal.
//
// Backslashes are doubled, printable ASCII (32..126) is copied as-is, and
// multibyte sequences that mbtowc() accepts in the current locale are
// copied whole. Everything else, including control characters, NUL and
// DEL (127), is written as \x{hh}.
string print_trigram(uint32_t trgm)
{
        char ch[3] = {
                char(trgm & 0xff), char((trgm >> 8) & 0xff), char((trgm >> 16) & 0xff)
        };

        string str = "'";
        for (unsigned i = 0; i < 3;) {
                if (ch[i] == '\\') {
                        str.push_back('\\');
                        str.push_back(ch[i]);
                        ++i;
                } else if (int(ch[i]) >= 32 && int(ch[i]) < 127) {  // Holds no matter whether char is signed or unsigned.
                        str.push_back(ch[i]);
                        ++i;
                } else {
                        // See if we have an entire multibyte character, and that it's reasonably printable.
                        mbtowc(nullptr, 0, 0);  // Reset the conversion state.
                        wchar_t pwc;
                        int ret = mbtowc(&pwc, ch + i, 3 - i);
                        if (ret >= 1 && pwc >= 32 && pwc != 127) {
                                str.append(ch + i, ret);
                                i += ret;
                        } else {
                                // Invalid, incomplete, NUL or a control character.
                                char buf[16];
                                snprintf(buf, sizeof(buf), "\\x{%02x}", (unsigned char)ch[i]);
                                str += buf;
                                ++i;
                        }
                }
        }
        str += "'";
        return str;
}
348
349 void do_search_file(const vector<string> &needles, const char *filename)
350 {
351         int fd = open(filename, O_RDONLY);
352         if (fd == -1) {
353                 perror(filename);
354                 exit(1);
355         }
356
357         // Drop privileges.
358         if (setgid(getgid()) != 0) {
359                 perror("setgid");
360                 exit(EXIT_FAILURE);
361         }
362
363         steady_clock::time_point start __attribute__((unused)) = steady_clock::now();
364         if (access("/", R_OK | X_OK)) {
365                 // We can't find anything, no need to bother...
366                 return;
367         }
368
369         IOUringEngine engine(/*slop_bytes=*/16);  // 16 slop bytes as described in turbopfor.h.
370         Corpus corpus(fd, &engine);
371         dprintf("Corpus init done after %.1f ms.\n", 1e3 * duration<float>(steady_clock::now() - start).count());
372
373         vector<pair<Trigram, size_t>> trigrams;
374         for (const string &needle : needles) {
375                 if (needle.size() < 3)
376                         continue;
377                 for (size_t i = 0; i < needle.size() - 2; ++i) {
378                         uint32_t trgm = read_trigram(needle, i);
379                         corpus.find_trigram(trgm, [trgm, &trigrams](const Trigram *trgmptr, size_t len) {
380                                 if (trgmptr == nullptr) {
381                                         dprintf("trigram %s isn't found, we abort the search\n", print_trigram(trgm).c_str());
382                                         if (only_count) {
383                                                 printf("0\n");
384                                         }
385                                         exit(0);
386                                 }
387                                 trigrams.emplace_back(*trgmptr, len);
388                         });
389                 }
390         }
391         engine.finish();
392         dprintf("Hashtable lookups done after %.1f ms.\n", 1e3 * duration<float>(steady_clock::now() - start).count());
393
394         if (trigrams.empty()) {
395                 // Too short for trigram matching. Apply brute force.
396                 // (We could have searched through all trigrams that matched
397                 // the pattern and done a union of them, but that's a lot of
398                 // work for fairly unclear gain.)
399                 uint64_t matched = scan_all_docids(needles, fd, corpus, &engine);
400                 if (only_count) {
401                         printf("%zu\n", matched);
402                 }
403                 return;
404         }
405         unique_sort(&trigrams);
406         sort(trigrams.begin(), trigrams.end(),
407              [&](const pair<Trigram, size_t> &a, const pair<Trigram, size_t> &b) {
408                      return a.first.num_docids < b.first.num_docids;
409              });
410
411         vector<uint32_t> in1, in2, out;
412         bool done = false;
413         for (auto [trgmptr, len] : trigrams) {
414                 if (!in1.empty() && trgmptr.num_docids > in1.size() * 100) {
415                         uint32_t trgm __attribute__((unused)) = trgmptr.trgm;
416                         dprintf("trigram %s (%zu bytes) has %u entries, ignoring the rest (will "
417                                 "weed out false positives later)\n",
418                                 print_trigram(trgm).c_str(), len, trgmptr.num_docids);
419                         break;
420                 }
421
422                 // Only stay a certain amount ahead, so that we don't spend I/O
423                 // on reading the latter, large posting lists. We are unlikely
424                 // to need them anyway, even if they should come in first.
425                 if (engine.get_waiting_reads() >= 5) {
426                         engine.finish();
427                         if (done)
428                                 break;
429                 }
430                 engine.submit_read(fd, len, trgmptr.offset, [trgmptr{ trgmptr }, len{ len }, &done, &in1, &in2, &out](string_view s) {
431                         if (done)
432                                 return;
433                         uint32_t trgm __attribute__((unused)) = trgmptr.trgm;
434                         size_t num = trgmptr.num_docids;
435                         const unsigned char *pldata = reinterpret_cast<const unsigned char *>(s.data());
436                         if (in1.empty()) {
437                                 in1.resize(num + 128);
438                                 decode_pfor_delta1_128(pldata, num, /*interleaved=*/true, &in1[0]);
439                                 in1.resize(num);
440                                 dprintf("trigram %s (%zu bytes) decoded to %zu entries\n",
441                                         print_trigram(trgm).c_str(), len, num);
442                         } else {
443                                 if (in2.size() < num + 128) {
444                                         in2.resize(num + 128);
445                                 }
446                                 decode_pfor_delta1_128(pldata, num, /*interleaved=*/true, &in2[0]);
447
448                                 out.clear();
449                                 set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num,
450                                                  back_inserter(out));
451                                 swap(in1, out);
452                                 dprintf("trigram %s (%zu bytes) decoded to %zu entries, %zu left\n",
453                                         print_trigram(trgm).c_str(), len, num, in1.size());
454                                 if (in1.empty()) {
455                                         dprintf("no matches (intersection list is empty)\n");
456                                         done = true;
457                                 }
458                         }
459                 });
460         }
461         engine.finish();
462         if (done) {
463                 return;
464         }
465         dprintf("Intersection done after %.1f ms. Doing final verification and printing:\n",
466                 1e3 * duration<float>(steady_clock::now() - start).count());
467
468         uint64_t matched = scan_docids(needles, in1, corpus, &engine);
469         dprintf("Done in %.1f ms, found %zu matches.\n",
470                 1e3 * duration<float>(steady_clock::now() - start).count(), matched);
471
472         if (only_count) {
473                 printf("%zu\n", matched);
474         }
475 }
476
477 void usage()
478 {
479         // The help text comes from mlocate.
480         printf("Usage: plocate [OPTION]... PATTERN...\n");
481         printf("\n");
482         printf("  -c, --count            only print number of found entries\n");
483         printf("  -d, --database DBPATH  use DBPATH instead of default database (which is\n");
484         printf("                         %s)\n", dbpath);
485         printf("  -h, --help             print this help\n");
486         printf("  -l, --limit, -n LIMIT  limit output (or counting) to LIMIT entries\n");
487         printf("  -0, --null             separate entries with NUL on output\n");
488 }
489
490 int main(int argc, char **argv)
491 {
492         static const struct option long_options[] = {
493                 { "help", no_argument, 0, 'h' },
494                 { "count", no_argument, 0, 'c' },
495                 { "database", required_argument, 0, 'd' },
496                 { "limit", required_argument, 0, 'l' },
497                 { nullptr, required_argument, 0, 'n' },
498                 { "null", no_argument, 0, '0' },
499                 { 0, 0, 0, 0 }
500         };
501
502         setlocale(LC_ALL, "");
503         for (;;) {
504                 int option_index = 0;
505                 int c = getopt_long(argc, argv, "cd:hl:n:0", long_options, &option_index);
506                 if (c == -1) {
507                         break;
508                 }
509                 switch (c) {
510                 case 'c':
511                         only_count = true;
512                         break;
513                 case 'd':
514                         dbpath = strdup(optarg);
515                         break;
516                 case 'h':
517                         usage();
518                         exit(0);
519                 case 'l':
520                 case 'n':
521                         limit_matches = atoll(optarg);
522                         break;
523                 case '0':
524                         print_nul = true;
525                         break;
526                 default:
527                         exit(1);
528                 }
529         }
530
531         vector<string> needles;
532         for (int i = optind; i < argc; ++i) {
533                 needles.push_back(argv[i]);
534         }
535         if (needles.empty()) {
536                 fprintf(stderr, "plocate: no pattern to search for specified\n");
537                 exit(0);
538         }
539         do_search_file(needles, dbpath);
540 }