]> git.sesse.net Git - plocate/blob - plocate.cpp
Remove an unused macro.
[plocate] / plocate.cpp
1 #include "vp4.h"
2
3 #include <algorithm>
4 #include <arpa/inet.h>
5 #include <chrono>
6 #include <endian.h>
7 #include <fcntl.h>
8 #include <stdio.h>
9 #include <string.h>
10 #include <string>
11 #include <sys/mman.h>
12 #include <unistd.h>
13 #include <unordered_map>
14 #include <vector>
15 #include <zstd.h>
16
17 using namespace std;
18 using namespace std::chrono;
19
20 #define dprintf(...)
21 //#define dprintf(...) fprintf(stderr, __VA_ARGS__);
22
23 static inline uint32_t read_unigram(const string &s, size_t idx)
24 {
25         if (idx < s.size()) {
26                 return (unsigned char)s[idx];
27         } else {
28                 return 0;
29         }
30 }
31
32 static inline uint32_t read_trigram(const string &s, size_t start)
33 {
34         return read_unigram(s, start) | (read_unigram(s, start + 1) << 8) |
35                 (read_unigram(s, start + 2) << 16);
36 }
37
38 bool has_access(const char *filename,
39                 unordered_map<string, bool> *access_rx_cache)
40 {
41         const char *end = strchr(filename + 1, '/');
42         while (end != nullptr) {
43                 string parent_path(filename, end);
44                 auto it = access_rx_cache->find(parent_path);
45                 bool ok;
46                 if (it == access_rx_cache->end()) {
47                         ok = access(parent_path.c_str(), R_OK | X_OK) == 0;
48                         access_rx_cache->emplace(move(parent_path), ok);
49                 } else {
50                         ok = it->second;
51                 }
52                 if (!ok) {
53                         return false;
54                 }
55                 end = strchr(end + 1, '/');
56         }
57
58         return true;
59 }
60
61 struct Trigram {
62         uint32_t trgm;
63         uint32_t num_docids;
64         uint64_t offset;
65 };
66
67 class Corpus {
68 public:
69         Corpus(int fd);
70         ~Corpus();
71         const Trigram *find_trigram(uint32_t trgm) const;
72         const unsigned char *
73         get_compressed_posting_list(const Trigram *trigram) const;
74         string_view get_compressed_filename_block(uint32_t docid) const;
75         size_t get_num_filename_blocks() const;
76
77 private:
78         const int fd;
79         off_t len;
80         const char *data;
81         const uint64_t *filename_offsets;
82         const Trigram *trgm_begin, *trgm_end;
83 };
84
85 Corpus::Corpus(int fd)
86         : fd(fd)
87 {
88         len = lseek(fd, 0, SEEK_END);
89         if (len == -1) {
90                 perror("lseek");
91                 exit(1);
92         }
93         data = (char *)mmap(nullptr, len, PROT_READ, MAP_SHARED, fd, /*offset=*/0);
94         if (data == MAP_FAILED) {
95                 perror("mmap");
96                 exit(1);
97         }
98
99         uint64_t num_trigrams = *(const uint64_t *)data;
100         uint64_t filename_index_offset = *(const uint64_t *)(data + sizeof(uint64_t));
101         filename_offsets = (const uint64_t *)(data + filename_index_offset);
102
103         trgm_begin = (Trigram *)(data + sizeof(uint64_t) * 2);
104         trgm_end = trgm_begin + num_trigrams;
105 }
106
107 Corpus::~Corpus()
108 {
109         munmap((void *)data, len);
110         close(fd);
111 }
112
113 const Trigram *Corpus::find_trigram(uint32_t trgm) const
114 {
115         const Trigram *trgmptr = lower_bound(
116                 trgm_begin, trgm_end, trgm,
117                 [](const Trigram &trgm, uint32_t t) { return trgm.trgm < t; });
118         if (trgmptr == trgm_end || trgmptr->trgm != trgm) {
119                 return nullptr;
120         }
121         return trgmptr;
122 }
123
124 const unsigned char *
125 Corpus::get_compressed_posting_list(const Trigram *trgmptr) const
126 {
127         return reinterpret_cast<const unsigned char *>(data + trgmptr->offset);
128 }
129
130 string_view Corpus::get_compressed_filename_block(uint32_t docid) const
131 {
132         const char *compressed = (const char *)(data + filename_offsets[docid]);
133         size_t compressed_size =
134                 filename_offsets[docid + 1] -
135                 filename_offsets[docid];  // Allowed we have a sentinel block at the end.
136         return { compressed, compressed_size };
137 }
138
139 size_t Corpus::get_num_filename_blocks() const
140 {
141         // The beginning of the filename blocks is the end of the filename index blocks.
142         const uint64_t *filename_offsets_end = (const uint64_t *)(data + filename_offsets[0]);
143
144         // Subtract the sentinel block.
145         return filename_offsets_end - filename_offsets - 1;
146 }
147
148 size_t scan_docid(const string &needle, uint32_t docid, const Corpus &corpus,
149                   unordered_map<string, bool> *access_rx_cache)
150 {
151         string_view compressed = corpus.get_compressed_filename_block(docid);
152         size_t matched = 0;
153
154         string block;
155         block.resize(ZSTD_getFrameContentSize(compressed.data(), compressed.size()) +
156                      1);
157
158         ZSTD_decompress(&block[0], block.size(), compressed.data(),
159                         compressed.size());
160         block[block.size() - 1] = '\0';
161
162         for (const char *filename = block.data();
163              filename != block.data() + block.size();
164              filename += strlen(filename) + 1) {
165                 if (strstr(filename, needle.c_str()) == nullptr) {
166                         continue;
167                 }
168                 if (has_access(filename, access_rx_cache)) {
169                         ++matched;
170                         printf("%s\n", filename);
171                 }
172         }
173         return matched;
174 }
175
176 void do_search_file(const string &needle, const char *filename)
177 {
178         int fd = open(filename, O_RDONLY);
179         if (fd == -1) {
180                 perror(filename);
181                 exit(1);
182         }
183
184         // Drop privileges.
185         if (setgid(getgid()) != 0) {
186                 perror("setgid");
187                 exit(EXIT_FAILURE);
188         }
189
190         // steady_clock::time_point start = steady_clock::now();
191         if (access("/", R_OK | X_OK)) {
192                 // We can't find anything, no need to bother...
193                 return;
194         }
195
196         Corpus corpus(fd);
197
198         if (needle.size() < 3) {
199                 // Too short for trigram matching. Apply brute force.
200                 // (We could have searched through all trigrams that matched
201                 // the pattern and done a union of them, but that's a lot of
202                 // work for fairly unclear gain.)
203                 unordered_map<string, bool> access_rx_cache;
204                 uint32_t num_blocks = corpus.get_num_filename_blocks();
205                 for (uint32_t docid = 0; docid < num_blocks; ++docid) {
206                         scan_docid(needle, docid, corpus, &access_rx_cache);
207                 }
208                 return;
209         }
210
211         vector<const Trigram *> trigrams;
212         for (size_t i = 0; i < needle.size() - 2; ++i) {
213                 uint32_t trgm = read_trigram(needle, i);
214                 const Trigram *trgmptr = corpus.find_trigram(trgm);
215                 if (trgmptr == nullptr) {
216                         dprintf("trigram %06x isn't found, we abort the search\n", trgm);
217                         return;
218                 }
219                 trigrams.push_back(trgmptr);
220         }
221         sort(trigrams.begin(), trigrams.end());
222         {
223                 auto last = unique(trigrams.begin(), trigrams.end());
224                 trigrams.erase(last, trigrams.end());
225         }
226         sort(trigrams.begin(), trigrams.end(),
227              [&](const Trigram *a, const Trigram *b) {
228                      return a->num_docids < b->num_docids;
229              });
230
231         vector<uint32_t> in1, in2, out;
232         for (const Trigram *trgmptr : trigrams) {
233                 // uint32_t trgm = trgmptr->trgm;
234                 size_t num = trgmptr->num_docids;
235                 const unsigned char *pldata = corpus.get_compressed_posting_list(trgmptr);
236                 if (in1.empty()) {
237                         in1.resize(num + 128);
238                         p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in1[0]);
239                         in1.resize(num);
240                         dprintf("trigram '%c%c%c' decoded to %zu entries\n", trgm & 0xff,
241                                 (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
242                 } else {
243                         if (num > in1.size() * 100) {
244                                 dprintf("trigram '%c%c%c' has %zu entries, ignoring the rest (will "
245                                         "weed out false positives later)\n",
246                                         trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
247                                 break;
248                         }
249
250                         if (in2.size() < num + 128) {
251                                 in2.resize(num + 128);
252                         }
253                         p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in2[0]);
254
255                         out.clear();
256                         set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num,
257                                          back_inserter(out));
258                         swap(in1, out);
259                         dprintf("trigram '%c%c%c' decoded to %zu entries, %zu left\n",
260                                 trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num,
261                                 in1.size());
262                         if (in1.empty()) {
263                                 dprintf("no matches (intersection list is empty)\n");
264                                 break;
265                         }
266                 }
267         }
268         steady_clock::time_point end = steady_clock::now();
269
270         dprintf("Intersection took %.1f ms. Doing final verification and printing:\n",
271                 1e3 * duration<float>(end - start).count());
272
273         unordered_map<string, bool> access_rx_cache;
274
275         int matched = 0;
276         for (uint32_t docid : in1) {
277                 matched += scan_docid(needle, docid, corpus, &access_rx_cache);
278         }
279         end = steady_clock::now();
280         dprintf("Done in %.1f ms, found %d matches.\n",
281                 1e3 * duration<float>(end - start).count(), matched);
282 }
283
284 int main(int argc, char **argv)
285 {
286         do_search_file(argv[1], "/var/lib/mlocate/plocate.db");
287 }