]> git.sesse.net Git - plocate/blob - plocate.cpp
Run include-what-you-use.
[plocate] / plocate.cpp
1 #include <fcntl.h>
2 #include <getopt.h>
3 #include <stdio.h>
4 #include <string.h>
5 #include <unistd.h>
6 #include <zstd.h>
7 #include <getopt.h>
8 #include <stdlib.h>
9 #include <algorithm>
10 #include <chrono>
11 #include <functional>
12 #include <memory>
13 #include <string>
14 #include <unordered_map>
15 #include <vector>
16 #include <stdint.h>
17 #include <iosfwd>
18 #include <iterator>
19 #include <limits>
20 #include <queue>
21 #include <string_view>
22 #include <utility>
23
24 #include "db.h"
25 #include "io_uring_engine.h"
26
27 using namespace std;
28 using namespace std::chrono;
29
// Debug tracing. By default dprintf() expands to nothing, so its arguments
// are never evaluated; change the 0 below to 1 to send the trace to stderr.
#if 0
#define dprintf(...) fprintf(stderr, __VA_ARGS__);
#else
#define dprintf(...)
#endif
32
33 #include "turbopfor.h"
34
// Database to search; main() replaces it when -d/--database is given.
const char *dbpath{ "/var/lib/mlocate/plocate.db" };

// Set by -c/--count: report only the number of matches.
bool only_count{ false };

// Set by -0/--null: terminate each printed entry with NUL instead of newline.
bool print_nul{ false };

// Number of matches that may still be accepted (-l/--limit/-n);
// effectively unlimited by default.
int64_t limit_matches{ numeric_limits<int64_t>::max() };
39
// Keeps printed output in sequence-number order even though the blocks that
// produce it may be processed out of order. A block whose turn has come
// prints directly; any other block hands its lines to print_delayed(), and
// they are emitted by release_current() once every earlier block is done.
class Serializer {
public:
        // True if `seq` is the next sequence number due for output.
        bool ready_to_print(int seq) { return seq == next_seq; }

        // Queues the output lines `msg` belonging to sequence number `seq`.
        void print_delayed(int seq, const vector<string> msg);

        // Marks the current sequence number as finished and prints every
        // queued entry that has thereby become next in line.
        void release_current();

private:
        struct Element {
                int seq;
                vector<string> msg;

                // Deliberately inverted: priority_queue is a max-heap, and we
                // want the element with the lowest seq on top.
                bool operator<(const Element &other) const
                {
                        return other.seq < seq;
                }
        };

        int next_seq = 0;
        priority_queue<Element> pending;
};
59
60 void Serializer::print_delayed(int seq, const vector<string> msg)
61 {
62         pending.push(Element{ seq, move(msg) });
63 }
64
65 void Serializer::release_current()
66 {
67         ++next_seq;
68
69         // See if any delayed prints can now be dealt with.
70         while (!pending.empty() && pending.top().seq == next_seq) {
71                 if (limit_matches-- <= 0)
72                         return;
73                 for (const string &msg : pending.top().msg) {
74                         if (print_nul) {
75                                 printf("%s%c", msg.c_str(), 0);
76                         } else {
77                                 printf("%s\n", msg.c_str());
78                         }
79                 }
80                 pending.pop();
81                 ++next_seq;
82         }
83 }
84
// Returns the byte of `s` at position `idx` as an unsigned value,
// or 0 if `idx` is at or past the end of the string.
static inline uint32_t read_unigram(const string &s, size_t idx)
{
        return idx < s.size() ? static_cast<unsigned char>(s[idx]) : 0;
}
93
94 static inline uint32_t read_trigram(const string &s, size_t start)
95 {
96         return read_unigram(s, start) | (read_unigram(s, start + 1) << 8) |
97                 (read_unigram(s, start + 2) << 16);
98 }
99
// Returns true if every parent directory of `filename` is readable and
// searchable (R_OK | X_OK) according to access().
//
// Each prefix of `filename` that ends just before a '/' is checked, except
// for the empty prefix in front of a leading slash. Results are memoized in
// `access_rx_cache`, keyed by directory path, so that each directory is only
// probed once; the check stops at the first inaccessible directory.
//
// An empty filename has no parent directories and is reported as accessible.
// (It needs an explicit check, since looking for a slash from the second
// character on would read past the string's terminator.)
bool has_access(const char *filename,
                unordered_map<string, bool> *access_rx_cache)
{
        if (filename[0] == '\0') {
                return true;
        }

        // Start at the second character, so that a leading '/' does not
        // produce an empty parent path.
        for (const char *end = strchr(filename + 1, '/');
             end != nullptr;
             end = strchr(end + 1, '/')) {
                string parent_path(filename, end);
                auto it = access_rx_cache->find(parent_path);
                bool ok;
                if (it == access_rx_cache->end()) {
                        ok = access(parent_path.c_str(), R_OK | X_OK) == 0;
                        access_rx_cache->emplace(move(parent_path), ok);
                } else {
                        ok = it->second;
                }
                if (!ok) {
                        return false;
                }
        }

        return true;
}
122
123 class Corpus {
124 public:
125         Corpus(int fd, IOUringEngine *engine);
126         ~Corpus();
127         void find_trigram(uint32_t trgm, function<void(const Trigram *trgmptr, size_t len)> cb);
128         void get_compressed_filename_block(uint32_t docid, function<void(string_view)> cb) const;
129         size_t get_num_filename_blocks() const;
130         off_t offset_for_block(uint32_t docid) const
131         {
132                 return hdr.filename_index_offset_bytes + docid * sizeof(uint64_t);
133         }
134
135 public:
136         const int fd;
137         IOUringEngine *const engine;
138
139         Header hdr;
140 };
141
142 Corpus::Corpus(int fd, IOUringEngine *engine)
143         : fd(fd), engine(engine)
144 {
145         // Enable to test cold-cache behavior (except for access()).
146         if (false) {
147                 off_t len = lseek(fd, 0, SEEK_END);
148                 if (len == -1) {
149                         perror("lseek");
150                         exit(1);
151                 }
152                 posix_fadvise(fd, 0, len, POSIX_FADV_DONTNEED);
153         }
154
155         complete_pread(fd, &hdr, sizeof(hdr), /*offset=*/0);
156         if (memcmp(hdr.magic, "\0plocate", 8) != 0) {
157                 fprintf(stderr, "plocate.db is corrupt or an old version; please rebuild it.\n");
158                 exit(1);
159         }
160         if (hdr.version != 0) {
161                 fprintf(stderr, "plocate.db has version %u, expected 0; please rebuild it.\n", hdr.version);
162                 exit(1);
163         }
164 }
165
166 Corpus::~Corpus()
167 {
168         close(fd);
169 }
170
171 void Corpus::find_trigram(uint32_t trgm, function<void(const Trigram *trgmptr, size_t len)> cb)
172 {
173         uint32_t bucket = hash_trigram(trgm, hdr.hashtable_size);
174         engine->submit_read(fd, sizeof(Trigram) * (hdr.extra_ht_slots + 2), hdr.hash_table_offset_bytes + sizeof(Trigram) * bucket, [this, trgm, cb{ move(cb) }](string_view s) {
175                 const Trigram *trgmptr = reinterpret_cast<const Trigram *>(s.data());
176                 for (unsigned i = 0; i < hdr.extra_ht_slots + 1; ++i) {
177                         if (trgmptr[i].trgm == trgm) {
178                                 cb(trgmptr + i, trgmptr[i + 1].offset - trgmptr[i].offset);
179                                 return;
180                         }
181                 }
182
183                 // Not found.
184                 cb(nullptr, 0);
185         });
186 }
187
188 void Corpus::get_compressed_filename_block(uint32_t docid, function<void(string_view)> cb) const
189 {
190         // Read the file offset from this docid and the next one.
191         // This is always allowed, since we have a sentinel block at the end.
192         engine->submit_read(fd, sizeof(uint64_t) * 2, offset_for_block(docid), [this, cb{ move(cb) }](string_view s) {
193                 const uint64_t *ptr = reinterpret_cast<const uint64_t *>(s.data());
194                 off_t offset = ptr[0];
195                 size_t len = ptr[1] - ptr[0];
196                 engine->submit_read(fd, len, offset, cb);
197         });
198 }
199
200 size_t Corpus::get_num_filename_blocks() const
201 {
202         return hdr.num_docids;
203 }
204
205 uint64_t scan_file_block(const vector<string> &needles, string_view compressed,
206                          unordered_map<string, bool> *access_rx_cache, int seq,
207                          Serializer *serializer)
208 {
209         uint64_t matched = 0;
210
211         unsigned long long uncompressed_len = ZSTD_getFrameContentSize(compressed.data(), compressed.size());
212         if (uncompressed_len == ZSTD_CONTENTSIZE_UNKNOWN || uncompressed_len == ZSTD_CONTENTSIZE_ERROR) {
213                 fprintf(stderr, "ZSTD_getFrameContentSize() failed\n");
214                 exit(1);
215         }
216
217         string block;
218         block.resize(uncompressed_len + 1);
219
220         size_t err = ZSTD_decompress(&block[0], block.size(), compressed.data(),
221                                      compressed.size());
222         if (ZSTD_isError(err)) {
223                 fprintf(stderr, "ZSTD_decompress(): %s\n", ZSTD_getErrorName(err));
224                 exit(1);
225         }
226         block[block.size() - 1] = '\0';
227
228         bool immediate_print = (serializer == nullptr || serializer->ready_to_print(seq));
229         vector<string> delayed;
230
231         for (const char *filename = block.data();
232              filename != block.data() + block.size();
233              filename += strlen(filename) + 1) {
234                 bool found = true;
235                 for (const string &needle : needles) {
236                         if (strstr(filename, needle.c_str()) == nullptr) {
237                                 found = false;
238                                 break;
239                         }
240                 }
241                 if (found && has_access(filename, access_rx_cache)) {
242                         if (limit_matches-- <= 0)
243                                 break;
244                         ++matched;
245                         if (only_count)
246                                 continue;
247                         if (immediate_print) {
248                                 if (print_nul) {
249                                         printf("%s%c", filename, 0);
250                                 } else {
251                                         printf("%s\n", filename);
252                                 }
253                         } else {
254                                 delayed.push_back(filename);
255                         }
256                 }
257         }
258         if (serializer != nullptr && !only_count) {
259                 if (immediate_print) {
260                         serializer->release_current();
261                 } else {
262                         serializer->print_delayed(seq, move(delayed));
263                 }
264         }
265         return matched;
266 }
267
268 size_t scan_docids(const vector<string> &needles, const vector<uint32_t> &docids, const Corpus &corpus, IOUringEngine *engine)
269 {
270         Serializer docids_in_order;
271         unordered_map<string, bool> access_rx_cache;
272         uint64_t matched = 0;
273         for (size_t i = 0; i < docids.size(); ++i) {
274                 uint32_t docid = docids[i];
275                 corpus.get_compressed_filename_block(docid, [i, &matched, &needles, &access_rx_cache, &docids_in_order](string_view compressed) {
276                         matched += scan_file_block(needles, compressed, &access_rx_cache, i, &docids_in_order);
277                 });
278         }
279         engine->finish();
280         return matched;
281 }
282
283 // We do this sequentially, as it's faster than scattering
284 // a lot of I/O through io_uring and hoping the kernel will
285 // coalesce it plus readahead for us.
286 uint64_t scan_all_docids(const vector<string> &needles, int fd, const Corpus &corpus, IOUringEngine *engine)
287 {
288         unordered_map<string, bool> access_rx_cache;
289         uint32_t num_blocks = corpus.get_num_filename_blocks();
290         unique_ptr<uint64_t[]> offsets(new uint64_t[num_blocks + 1]);
291         complete_pread(fd, offsets.get(), (num_blocks + 1) * sizeof(uint64_t), corpus.offset_for_block(0));
292         string compressed;
293         uint64_t matched = 0;
294         for (uint32_t io_docid = 0; io_docid < num_blocks; io_docid += 32) {
295                 uint32_t last_docid = std::min(io_docid + 32, num_blocks);
296                 size_t io_len = offsets[last_docid] - offsets[io_docid];
297                 if (compressed.size() < io_len) {
298                         compressed.resize(io_len);
299                 }
300                 complete_pread(fd, &compressed[0], io_len, offsets[io_docid]);
301
302                 for (uint32_t docid = io_docid; docid < last_docid; ++docid) {
303                         size_t relative_offset = offsets[docid] - offsets[io_docid];
304                         size_t len = offsets[docid + 1] - offsets[docid];
305                         matched += scan_file_block(needles, { &compressed[relative_offset], len }, &access_rx_cache, 0, nullptr);
306                         if (limit_matches <= 0)
307                                 return matched;
308                 }
309         }
310         return matched;
311 }
312
313 void do_search_file(const vector<string> &needles, const char *filename)
314 {
315         int fd = open(filename, O_RDONLY);
316         if (fd == -1) {
317                 perror(filename);
318                 exit(1);
319         }
320
321         // Drop privileges.
322         if (setgid(getgid()) != 0) {
323                 perror("setgid");
324                 exit(EXIT_FAILURE);
325         }
326
327         steady_clock::time_point start __attribute__((unused)) = steady_clock::now();
328         if (access("/", R_OK | X_OK)) {
329                 // We can't find anything, no need to bother...
330                 return;
331         }
332
333         IOUringEngine engine(/*slop_bytes=*/16);  // 16 slop bytes as described in turbopfor.h.
334         Corpus corpus(fd, &engine);
335         dprintf("Corpus init done after %.1f ms.\n", 1e3 * duration<float>(steady_clock::now() - start).count());
336
337         vector<pair<Trigram, size_t>> trigrams;
338         uint64_t shortest_so_far = numeric_limits<uint32_t>::max();
339         for (const string &needle : needles) {
340                 if (needle.size() < 3)
341                         continue;
342                 for (size_t i = 0; i < needle.size() - 2; ++i) {
343                         uint32_t trgm = read_trigram(needle, i);
344                         corpus.find_trigram(trgm, [trgm, &trigrams, &shortest_so_far](const Trigram *trgmptr, size_t len) {
345                                 if (trgmptr == nullptr) {
346                                         dprintf("trigram '%c%c%c' isn't found, we abort the search\n",
347                                                 trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff);
348                                         if (only_count) {
349                                                 printf("0\n");
350                                         }
351                                         exit(0);
352                                 }
353                                 if (trgmptr->num_docids > shortest_so_far * 100) {
354                                         dprintf("not loading trigram '%c%c%c' with %u docids, it would be ignored later anyway\n",
355                                                 trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff,
356                                                 trgmptr->num_docids);
357                                 } else {
358                                         trigrams.emplace_back(*trgmptr, len);
359                                         shortest_so_far = std::min<uint64_t>(shortest_so_far, trgmptr->num_docids);
360                                 }
361                         });
362                 }
363         }
364         engine.finish();
365         dprintf("Hashtable lookups done after %.1f ms.\n", 1e3 * duration<float>(steady_clock::now() - start).count());
366
367         if (trigrams.empty()) {
368                 // Too short for trigram matching. Apply brute force.
369                 // (We could have searched through all trigrams that matched
370                 // the pattern and done a union of them, but that's a lot of
371                 // work for fairly unclear gain.)
372                 uint64_t matched = scan_all_docids(needles, fd, corpus, &engine);
373                 printf("%zu\n", matched);
374                 return;
375         }
376         sort(trigrams.begin(), trigrams.end());
377         {
378                 auto last = unique(trigrams.begin(), trigrams.end());
379                 trigrams.erase(last, trigrams.end());
380         }
381         sort(trigrams.begin(), trigrams.end(),
382              [&](const pair<Trigram, size_t> &a, const pair<Trigram, size_t> &b) {
383                      return a.first.num_docids < b.first.num_docids;
384              });
385
386         vector<uint32_t> in1, in2, out;
387         bool done = false;
388         for (auto [trgmptr, len] : trigrams) {
389                 if (!in1.empty() && trgmptr.num_docids > in1.size() * 100) {
390                         uint32_t trgm __attribute__((unused)) = trgmptr.trgm;
391                         dprintf("trigram '%c%c%c' (%zu bytes) has %u entries, ignoring the rest (will "
392                                 "weed out false positives later)\n",
393                                 trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff,
394                                 len, trgmptr.num_docids);
395                         break;
396                 }
397
398                 // Only stay a certain amount ahead, so that we don't spend I/O
399                 // on reading the latter, large posting lists. We are unlikely
400                 // to need them anyway, even if they should come in first.
401                 if (engine.get_waiting_reads() >= 5) {
402                         engine.finish();
403                         if (done)
404                                 break;
405                 }
406                 engine.submit_read(fd, len, trgmptr.offset, [trgmptr{ trgmptr }, len{ len }, &done, &in1, &in2, &out](string_view s) {
407                         if (done)
408                                 return;
409                         uint32_t trgm __attribute__((unused)) = trgmptr.trgm;
410                         size_t num = trgmptr.num_docids;
411                         const unsigned char *pldata = reinterpret_cast<const unsigned char *>(s.data());
412                         if (in1.empty()) {
413                                 in1.resize(num + 128);
414                                 decode_pfor_delta1_128(pldata, num, /*interleaved=*/true, &in1[0]);
415                                 in1.resize(num);
416                                 dprintf("trigram '%c%c%c' (%zu bytes) decoded to %zu entries\n", trgm & 0xff,
417                                         (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, len, num);
418                         } else {
419                                 if (in2.size() < num + 128) {
420                                         in2.resize(num + 128);
421                                 }
422                                 decode_pfor_delta1_128(pldata, num, /*interleaved=*/true, &in2[0]);
423
424                                 out.clear();
425                                 set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num,
426                                                  back_inserter(out));
427                                 swap(in1, out);
428                                 dprintf("trigram '%c%c%c' (%zu bytes) decoded to %zu entries, %zu left\n",
429                                         trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff,
430                                         len, num, in1.size());
431                                 if (in1.empty()) {
432                                         dprintf("no matches (intersection list is empty)\n");
433                                         done = true;
434                                 }
435                         }
436                 });
437         }
438         engine.finish();
439         if (done) {
440                 return;
441         }
442         dprintf("Intersection done after %.1f ms. Doing final verification and printing:\n",
443                 1e3 * duration<float>(steady_clock::now() - start).count());
444
445         uint64_t matched = scan_docids(needles, in1, corpus, &engine);
446         dprintf("Done in %.1f ms, found %zu matches.\n",
447                 1e3 * duration<float>(steady_clock::now() - start).count(), matched);
448
449         if (only_count) {
450                 printf("%zu\n", matched);
451         }
452 }
453
454 void usage()
455 {
456         // The help text comes from mlocate.
457         printf("Usage: plocate [OPTION]... PATTERN...\n");
458         printf("\n");
459         printf("  -c, --count            only print number of found entries\n");
460         printf("  -d, --database DBPATH  use DBPATH instead of default database (which is\n");
461         printf("                         %s)\n", dbpath);
462         printf("  -h, --help             print this help\n");
463         printf("  -l, --limit, -n LIMIT  limit output (or counting) to LIMIT entries\n");
464         printf("  -0, --null             separate entries with NUL on output\n");
465 }
466
467 int main(int argc, char **argv)
468 {
469         static const struct option long_options[] = {
470                 { "help", no_argument, 0, 'h' },
471                 { "count", no_argument, 0, 'c' },
472                 { "database", required_argument, 0, 'd' },
473                 { "limit", required_argument, 0, 'l' },
474                 { nullptr, required_argument, 0, 'n' },
475                 { "null", no_argument, 0, '0' },
476                 { 0, 0, 0, 0 }
477         };
478
479         for (;;) {
480                 int option_index = 0;
481                 int c = getopt_long(argc, argv, "cd:hl:n:0", long_options, &option_index);
482                 if (c == -1) {
483                         break;
484                 }
485                 switch (c) {
486                 case 'c':
487                         only_count = true;
488                         break;
489                 case 'd':
490                         dbpath = strdup(optarg);
491                         break;
492                 case 'h':
493                         usage();
494                         exit(0);
495                 case 'l':
496                 case 'n':
497                         limit_matches = atoll(optarg);
498                         break;
499                 case '0':
500                         print_nul = true;
501                         break;
502                 default:
503                         exit(1);
504                 }
505         }
506
507         vector<string> needles;
508         for (int i = optind; i < argc; ++i) {
509                 needles.push_back(argv[i]);
510         }
511         if (needles.empty()) {
512                 fprintf(stderr, "plocate: no pattern to search for specified\n");
513                 exit(0);
514         }
515         do_search_file(needles, dbpath);
516 }