]> git.sesse.net Git - plocate/blob - plocate.cpp
Abstract out some details of reading the corpus into a class.
[plocate] / plocate.cpp
1 #include <stdio.h>
2 #include <string.h>
3 #include <algorithm>
4 #include <unordered_map>
5 #include <string>
6 #include <vector>
7 #include <chrono>
8 #include <unistd.h>
9 #include <fcntl.h>
10 #include <sys/mman.h>
11 #include <arpa/inet.h>
12 #include <endian.h>
13 #include <zstd.h>
14
15 #include "vp4.h"
16
17 #define P4NENC_BOUND(n) ((n+127)/128+(n+32)*sizeof(uint32_t))
18
19 using namespace std;
20 using namespace std::chrono;
21
22 #define dprintf(...)
23 //#define dprintf(...) fprintf(stderr, __VA_ARGS__);
24         
25 static inline uint32_t read_unigram(const string &s, size_t idx)
26 {
27         if (idx < s.size()) {
28                 return (unsigned char)s[idx];
29         } else {
30                 return 0;
31         }
32 }
33
34 static inline uint32_t read_trigram(const string &s, size_t start)
35 {
36         return read_unigram(s, start) |
37                 (read_unigram(s, start + 1) << 8) |
38                 (read_unigram(s, start + 2) << 16);
39 }
40
41 bool has_access(const char *filename, unordered_map<string, bool> *access_rx_cache)
42 {
43         const char *end = strchr(filename + 1, '/');
44         while (end != nullptr) {
45                 string parent_path(filename, end);
46                 auto it = access_rx_cache->find(parent_path);
47                 bool ok;
48                 if (it == access_rx_cache->end()) {
49                         ok = access(parent_path.c_str(), R_OK | X_OK) == 0;
50                         access_rx_cache->emplace(move(parent_path), ok);
51                 } else {
52                         ok = it->second;
53                 }
54                 if (!ok) {
55                         return false;
56                 }
57                 end = strchr(end + 1, '/');
58         }
59
60 #if 0
61         // Check for rx first in the cache; if that isn't true, check R_OK uncached.
62         // This is roughly the same thing as mlocate does.      
63         auto it = access_rx_cache->find(filename);
64         if (it != access_rx_cache->end() && it->second) {
65                 return true;
66         }
67
68         return access(filename, R_OK) == 0;
69 #endif
70         return true;
71 }
72
73 struct Trigram {
74         uint32_t trgm;
75         uint32_t num_docids;
76         uint64_t offset;
77 };
78
79 class Corpus {
80 public:
81         Corpus(int fd);
82         ~Corpus();
83         const Trigram *find_trigram(uint32_t trgm) const;
84         const unsigned char *get_compressed_posting_list(const Trigram *trigram) const;
85         string_view get_compressed_filename_block(uint32_t docid) const;
86
87 private:
88         const int fd;
89         off_t len;
90         const char *data;
91         const uint64_t *filename_offsets;
92         const Trigram *trgm_begin, *trgm_end;
93 };
94
95 Corpus::Corpus(int fd)
96         : fd(fd)
97 {
98         len = lseek(fd, 0, SEEK_END);
99         if (len == -1) {
100                 perror("lseek");
101                 exit(1);
102         }
103         data = (char *)mmap(nullptr, len, PROT_READ, MAP_SHARED, fd, /*offset=*/0);
104         if (data == MAP_FAILED) {
105                 perror("mmap");
106                 exit(1);
107         }
108
109         uint64_t num_trigrams = *(const uint64_t *)data;
110         uint64_t filename_index_offset = *(const uint64_t *)(data + sizeof(uint64_t));
111         filename_offsets = (const uint64_t *)(data + filename_index_offset);
112
113         trgm_begin = (Trigram *)(data + sizeof(uint64_t) * 2);
114         trgm_end = trgm_begin + num_trigrams;
115 }
116
117 Corpus::~Corpus()
118 {
119         munmap((void *)data, len);
120         close(fd);
121 }
122
123 const Trigram *Corpus::find_trigram(uint32_t trgm) const
124 {
125         const Trigram *trgmptr = lower_bound(trgm_begin, trgm_end, trgm, [](const Trigram &trgm, uint32_t t) {
126                 return trgm.trgm < t;
127         });
128         if (trgmptr == trgm_end || trgmptr->trgm != trgm) {
129                 return nullptr;
130         }
131         return trgmptr;
132 }
133
134 const unsigned char *Corpus::get_compressed_posting_list(const Trigram *trgmptr) const
135 {
136         return reinterpret_cast<const unsigned char *>(data + trgmptr->offset);
137 }
138
139 string_view Corpus::get_compressed_filename_block(uint32_t docid) const
140 {
141         const char *compressed = (const char *)(data + filename_offsets[docid]);
142         size_t compressed_size = filename_offsets[docid + 1] - filename_offsets[docid];  // Allowed we have a sentinel block at the end.
143         return {compressed, compressed_size};
144 }
145
146 size_t scan_docid(const string &needle, uint32_t docid, const Corpus &corpus, unordered_map<string, bool> *access_rx_cache)
147 {
148         string_view compressed = corpus.get_compressed_filename_block(docid);
149         size_t matched = 0;
150
151         string block;
152         block.resize(ZSTD_getFrameContentSize(compressed.data(), compressed.size()) + 1);
153
154         ZSTD_decompress(&block[0], block.size(), compressed.data(), compressed.size());
155         block[block.size() - 1] = '\0';
156
157         for (const char *filename = block.data();
158                         filename != block.data() + block.size();
159                         filename += strlen(filename) + 1) {
160                 if (strstr(filename, needle.c_str()) == nullptr) {
161                         continue;
162                 }
163                 if (has_access(filename, access_rx_cache)) {
164                         ++matched;
165                         printf("%s\n", filename);
166                 }
167         }
168         return matched;
169 }
170
171 void do_search_file(const string &needle, const char *filename)
172 {
173         int fd = open(filename, O_RDONLY);
174         if (fd == -1) {
175                 perror(filename);
176                 exit(1);
177         }
178
179         // Drop privileges.
180         if (setgid(getgid()) != 0) {
181                 perror("setgid");
182                 exit(EXIT_FAILURE);
183         }
184
185         //steady_clock::time_point start = steady_clock::now();
186         if (access("/", R_OK | X_OK)) {
187                 // We can't find anything, no need to bother...
188                 return;
189         }
190
191         Corpus corpus(fd);
192
193         vector<const Trigram *> trigrams;
194         for (size_t i = 0; i < needle.size() - 2; ++i) {
195                 uint32_t trgm = read_trigram(needle, i);
196                 const Trigram *trgmptr = corpus.find_trigram(trgm);
197                 if (trgmptr == nullptr) {
198                         dprintf("trigram %06x isn't found, we abort the search\n", trgm);
199                         return;
200                 }
201                 trigrams.push_back(trgmptr);
202         }
203         sort(trigrams.begin(), trigrams.end());
204         {
205                 auto last = unique(trigrams.begin(), trigrams.end());
206                 trigrams.erase(last, trigrams.end());
207         }
208         sort(trigrams.begin(), trigrams.end(), [&](const Trigram *a, const Trigram *b) {
209                 return a->num_docids < b->num_docids;
210         });
211
212         vector<uint32_t> in1, in2, out;
213         for (const Trigram *trgmptr : trigrams) {
214                 //uint32_t trgm = trgmptr->trgm;
215                 size_t num = trgmptr->num_docids;
216                 const unsigned char *pldata = corpus.get_compressed_posting_list(trgmptr);
217                 if (in1.empty()) {
218                         in1.resize(num + 128);
219                         p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in1[0]);
220                         in1.resize(num);
221                         dprintf("trigram '%c%c%c' decoded to %zu entries\n", trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
222                 } else {
223                         if (num > in1.size() * 100) {
224                                 dprintf("trigram '%c%c%c' has %zu entries, ignoring the rest (will weed out false positives later)\n",
225                                         trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num);
226                                 break;
227                         }
228
229                         if (in2.size() < num + 128) {
230                                 in2.resize(num + 128);
231                         }
232                         p4nd1dec128v32(const_cast<unsigned char *>(pldata), num, &in2[0]);
233
234                         out.clear();
235                         set_intersection(in1.begin(), in1.end(), in2.begin(), in2.begin() + num, back_inserter(out));
236                         swap(in1, out);
237                         dprintf("trigram '%c%c%c' decoded to %zu entries, %zu left\n", trgm & 0xff, (trgm >> 8) & 0xff, (trgm >> 16) & 0xff, num, in1.size());
238                         if (in1.empty()) {
239                                 dprintf("no matches (intersection list is empty)\n");
240                                 break;
241                         }
242                 }
243         }
244         steady_clock::time_point end = steady_clock::now();
245
246         dprintf("Intersection took %.1f ms. Doing final verification and printing:\n",
247                 1e3 * duration<float>(end - start).count());
248
249         unordered_map<string, bool> access_rx_cache;
250
251         int matched = 0;
252         for (uint32_t docid : in1) {
253                 matched += scan_docid(needle, docid, corpus, &access_rx_cache);
254         }
255         end = steady_clock::now();
256         dprintf("Done in %.1f ms, found %d matches.\n",
257                 1e3 * duration<float>(end - start).count(), matched);
258 }
259
260 int main(int argc, char **argv)
261 {
262         //do_search_file(argv[1], "all.trgm");
263         do_search_file(argv[1], "/var/lib/mlocate/plocate.db");
264 }