]> git.sesse.net Git - plocate/blob - plocate.cpp
Support patterns shorter than 3 bytes.
[plocate] / plocate.cpp
1 #include "vp4.h"
2
3 #include <algorithm>
4 #include <arpa/inet.h>
5 #include <chrono>
6 #include <endian.h>
7 #include <fcntl.h>
8 #include <stdio.h>
9 #include <string.h>
10 #include <string>
11 #include <sys/mman.h>
12 #include <unistd.h>
13 #include <unordered_map>
14 #include <vector>
15 #include <zstd.h>
16
17 #define P4NENC_BOUND(n) ((n + 127) / 128 + (n + 32) * sizeof(uint32_t))
18
19 using namespace std;
20 using namespace std::chrono;
21
22 #define dprintf(...)
23 //#define dprintf(...) fprintf(stderr, __VA_ARGS__);
24
25 static inline uint32_t read_unigram(const string &s, size_t idx)
26 {
27         if (idx < s.size()) {
28                 return (unsigned char)s[idx];
29         } else {
30                 return 0;
31         }
32 }
33
34 static inline uint32_t read_trigram(const string &s, size_t start)
35 {
36         return read_unigram(s, start) | (read_unigram(s, start + 1) << 8) |
37                 (read_unigram(s, start + 2) << 16);
38 }
39
40 bool has_access(const char *filename,
41                 unordered_map<string, bool> *access_rx_cache)
42 {
43         const char *end = strchr(filename + 1, '/');
44         while (end != nullptr) {
45                 string parent_path(filename, end);
46                 auto it = access_rx_cache->find(parent_path);
47                 bool ok;
48                 if (it == access_rx_cache->end()) {
49                         ok = access(parent_path.c_str(), R_OK | X_OK) == 0;
50                         access_rx_cache->emplace(move(parent_path), ok);
51                 } else {
52                         ok = it->second;
53                 }
54                 if (!ok) {
55                         return false;
56                 }
57                 end = strchr(end + 1, '/');
58         }
59
60         return true;
61 }
62
63 struct Trigram {
64         uint32_t trgm;
65         uint32_t num_docids;
66         uint64_t offset;
67 };
68
69 class Corpus {
70 public:
71         Corpus(int fd);
72         ~Corpus();
73         const Trigram *find_trigram(uint32_t trgm) const;
74         const unsigned char *
75         get_compressed_posting_list(const Trigram *trigram) const;
76         string_view get_compressed_filename_block(uint32_t docid) const;
77         size_t get_num_filename_blocks() const;
78
79 private:
80         const int fd;
81         off_t len;
82         const char *data;
83         const uint64_t *filename_offsets;
84         const Trigram *trgm_begin, *trgm_end;
85 };
86
87 Corpus::Corpus(int fd)
88         : fd(fd)
89 {
90         len = lseek(fd, 0, SEEK_END);
91         if (len == -1) {
92                 perror("lseek");
93                 exit(1);
94         }
95         data = (char *)mmap(nullptr, len, PROT_READ, MAP_SHARED, fd, /*offset=*/0);
96         if (data == MAP_FAILED) {
97                 perror("mmap");
98                 exit(1);
99         }
100
101         uint64_t num_trigrams = *(const uint64_t *)data;
102         uint64_t filename_index_offset = *(const uint64_t *)(data + sizeof(uint64_t));
103         filename_offsets = (const uint64_t *)(data + filename_index_offset);
104
105         trgm_begin = (Trigram *)(data + sizeof(uint64_t) * 2);
106         trgm_end = trgm_begin + num_trigrams;
107 }
108
109 Corpus::~Corpus()
110 {
111         munmap((void *)data, len);
112         close(fd);
113 }
114
115 const Trigram *Corpus::find_trigram(uint32_t trgm) const
116 {
117         const Trigram *trgmptr = lower_bound(
118                 trgm_begin, trgm_end, trgm,
119                 [](const Trigram &trgm, uint32_t t) { return trgm.trgm < t; });
120         if (trgmptr == trgm_end || trgmptr->trgm != trgm) {
121                 return nullptr;
122         }
123         return trgmptr;
124 }
125
126 const unsigned char *
127 Corpus::get_compressed_posting_list(const Trigram *trgmptr) const
128 {
129         return reinterpret_cast<const unsigned char *>(data + trgmptr->offset);
130 }
131
132 string_view Corpus::get_compressed_filename_block(uint32_t docid) const
133 {
134         const char *compressed = (const char *)(data + filename_offsets[docid]);
135         size_t compressed_size =
136                 filename_offsets[docid + 1] -
137                 filename_offsets[docid];  // Allowed we have a sentinel block at the end.
138         return { compressed, compressed_size };
139 }
140
141 size_t Corpus::get_num_filename_blocks() const
142 {
143         // The beginning of the filename blocks is the end of the filename index blocks.
144         const uint64_t *filename_offsets_end = (const uint64_t *)(data + filename_offsets[0]);
145
146         // Subtract the sentinel block.
147         return filename_offsets_end - filename_offsets - 1;
148 }
149
150 size_t scan_docid(const string &needle, uint32_t docid, const Corpus &corpus,
151                   unordered_map<string, bool> *access_rx_cache)
152 {
153         string_view compressed = corpus.get_compressed_filename_block(docid);
154         size_t matched = 0;
155
156         string block;
157         block.resize(ZSTD_getFrameContentSize(compressed.data(), compressed.size()) +
158                      1);
159
160         ZSTD_decompress(&block[0], block.size(), compressed.data(),
161                         compressed.size());
162         block[block.size() - 1] = '\0';
163
164         for (const char *filename = block.data();
165              filename != block.data() + block.size();
166              filename += strlen(filename) + 1) {
167                 if (strstr(filename, needle.c_str()) == nullptr) {
168                         continue;
169                 }
170                 if (has_access(filename, access_rx_cache)) {
171                         ++matched;
172                         printf("%s\n", filename);
173                 }
174         }
175         return matched;
176 }
177
178 void do_search_file(const string &needle, const char *filename)
179 {
180         int fd = open(filename, O_RDONLY);
181         if (fd == -1) {
182                 perror(filename);
183                 exit(1);
184         }
185
186         // Drop privileges.
187         if (setgid(getgid()) != 0) {
188                 perror("setgid");
189                 exit(EXIT_FAILURE);
190         }
191
192         // steady_clock::time_point start = steady_clock::now();
193         if (access("/", R_OK | X_OK)) {
194                 // We can't find anything, no need to bother...
195                 return;
196         }
197
198         Corpus corpus(fd);
199
200         if (needle.size() < 3) {
201                 // Too short for trigram matching. Apply brute force.
202                 // (We could have searched through all trigrams that matched
203                 // the pattern and done a union of them, but that's a lot of
204                 // work for fairly unclear gain.)
205                 unordered_map<string, bool> access_rx_cache;
206                 uint32_t num_blocks = corpus.get_num_filename_blocks();
207                 for (uint32_t docid = 0; docid < num_blocks; ++docid) {
208                         scan_docid(needle, docid, corpus, &access_rx_cache);
209                 }
210                 return;
211         }
212
213         vector<const Trigram *> trigrams;
214         for (size_t i = 0; i < needle.size() - 2; ++i) {
215                 uint32_t trgm = read_trigram(needle, i);
216                 const Trigram *trgmptr = corpus.find_trigram(trgm);
217                 if (trgmptr == nullptr) {
218                         dprintf("trigram %06x isn't found, we abort the search\n", trgm);
219                         return;
220                 }
221                 trigrams.push_back(trgmptr);
222         }
223         sort(trigrams.begin(), trigrams.end());
224         {
225                 auto last = unique(trigrams.begin(), trigrams.end());
226                 trigrams.erase(last, trigrams.end());
227         }
228         sort(trigrams.begin(), trigrams.end(),
229              [&](const Trigram *a, const Trigram *b) {
230                      return a->num_docids < b->num_docids;
231              });
232
233         vector<uint32_t> in1, in2, out;
234         for (const Trigram *trgmptr : trigrams) {
235                 // uint32_t trgm = trgmptr->trgm;
236                 size_t num = trgmptr->num_docids;
237                 const unsigned char *pldata = corpus.get_compressed_posting_list(trgmptr);
238                 if (in1.empty()) {
239                         in1.resize(num + 128);
240                         p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in1[0]);
241                         in1.resize(num);
242                         dprintf("trigram '%c%c%c' decoded to %zu entries\n", trgm & 0xff,
243                                 (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
244                 } else {
245                         if (num > in1.size() * 100) {
246                                 dprintf("trigram '%c%c%c' has %zu entries, ignoring the rest (will "
247                                         "weed out false positives later)\n",
248                                         trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
249                                 break;
250                         }
251
252                         if (in2.size() < num + 128) {
253                                 in2.resize(num + 128);
254                         }
255                         p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in2[0]);
256
257                         out.clear();
258                         set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num,
259                                          back_inserter(out));
260                         swap(in1, out);
261                         dprintf("trigram '%c%c%c' decoded to %zu entries, %zu left\n",
262                                 trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num,
263                                 in1.size());
264                         if (in1.empty()) {
265                                 dprintf("no matches (intersection list is empty)\n");
266                                 break;
267                         }
268                 }
269         }
270         steady_clock::time_point end = steady_clock::now();
271
272         dprintf("Intersection took %.1f ms. Doing final verification and printing:\n",
273                 1e3 * duration<float>(end - start).count());
274
275         unordered_map<string, bool> access_rx_cache;
276
277         int matched = 0;
278         for (uint32_t docid : in1) {
279                 matched += scan_docid(needle, docid, corpus, &access_rx_cache);
280         }
281         end = steady_clock::now();
282         dprintf("Done in %.1f ms, found %d matches.\n",
283                 1e3 * duration<float>(end - start).count(), matched);
284 }
285
286 int main(int argc, char **argv)
287 {
288         do_search_file(argv[1], "/var/lib/mlocate/plocate.db");
289 }