]> git.sesse.net Git - plocate/blob - parse_trigrams.cpp
Proof-of-concept of using ICU for strength-zero searches.
[plocate] / parse_trigrams.cpp
1 #include "parse_trigrams.h"
2
3 #include "dprintf.h"
4 #include "unique_sort.h"
5
6 #include <assert.h>
7 #include <memory>
8 #include <string.h>
9 #include <wctype.h>
10
11 #include <unicode/coll.h>
12
13 using namespace std;
14
15 string print_td(const TrigramDisjunction &td)
16 {
17         if (td.read_trigrams.size() == 0) {
18                 // Before we've done hash lookups (or none matched), so print all alternatives.
19                 if (td.trigram_alternatives.size() == 1) {
20                         return print_trigram(td.trigram_alternatives[0]);
21                 } else {
22                         string ret;
23                         ret = "(";
24                         bool first = true;
25                         for (uint32_t trgm : td.trigram_alternatives) {
26                                 if (!first)
27                                         ret += " OR ";
28                                 ret += print_trigram(trgm);
29                                 first = false;
30                         }
31                         return ret + ")";
32                 }
33         } else {
34                 // Print only those that we actually have in the index.
35                 if (td.read_trigrams.size() == 1) {
36                         return print_trigram(td.read_trigrams[0].first.trgm);
37                 } else {
38                         string ret;
39                         ret = "(";
40                         bool first = true;
41                         for (auto &[trgmptr, len] : td.read_trigrams) {
42                                 if (!first)
43                                         ret += " OR ";
44                                 ret += print_trigram(trgmptr.trgm);
45                                 first = false;
46                         }
47                         return ret + ")";
48                 }
49         }
50 }
51
// Renders a packed trigram as a single-quoted string for debug output.
//
// The three bytes are taken from bits 0-7, 8-15 and 16-23 of `trgm`, in that
// order. A backslash is printed doubled. Printable ASCII (0x20..0x7e) is
// printed as-is. Any other byte is printed raw only if it starts a complete
// multibyte character (decoded with mbtowc() in the current LC_CTYPE locale)
// that fits in the remaining bytes and that iswprint() accepts. Everything
// else, including NUL, control characters and DEL, is printed as \x{hh}.
string print_trigram(uint32_t trgm)
{
	const char ch[3] = {
		char(trgm & 0xff), char((trgm >> 8) & 0xff), char((trgm >> 16) & 0xff)
	};

	string str = "'";
	for (unsigned i = 0; i < 3;) {
		// Compare as unsigned so the range checks don't depend on whether
		// plain char is signed.
		const unsigned char byte = static_cast<unsigned char>(ch[i]);
		if (byte == '\\') {
			str += "\\\\";
			++i;
		} else if (byte >= 32 && byte < 127) {
			// Printable ASCII; 127 (DEL) is a control character and falls through.
			str.push_back(ch[i]);
			++i;
		} else {
			// See if we have an entire multibyte character, and that it's printable.
			mbtowc(nullptr, 0, 0);
			wchar_t pwc;
			const int ret = mbtowc(&pwc, ch + i, 3 - i);
			if (ret >= 1 && iswprint(wint_t(pwc))) {
				str.append(ch + i, ret);
				i += ret;
			} else {
				char buf[16];
				snprintf(buf, sizeof(buf), "\\x{%02x}", byte);
				str += buf;
				++i;
			}
		}
	}
	str += "'";
	return str;
}
86
87 pair<uint32_t, size_t> read_unigram(const string &s, size_t start)
88 {
89         if (start >= s.size()) {
90                 return { PREMATURE_END_UNIGRAM, 0 };
91         }
92         if (s[start] == '\\') {
93                 // Escaped character.
94                 if (start + 1 >= s.size()) {
95                         return { PREMATURE_END_UNIGRAM, 1 };
96                 } else {
97                         return { (unsigned char)s[start + 1], 2 };
98                 }
99         }
100         if (s[start] == '*' || s[start] == '?') {
101                 // Wildcard.
102                 return { WILDCARD_UNIGRAM, 1 };
103         }
104         if (s[start] == '[') {
105                 // Character class; search to find the end.
106                 size_t len = 1;
107                 if (start + len >= s.size()) {
108                         return { PREMATURE_END_UNIGRAM, len };
109                 }
110                 if (s[start + len] == '!') {
111                         ++len;
112                 }
113                 if (start + len >= s.size()) {
114                         return { PREMATURE_END_UNIGRAM, len };
115                 }
116                 if (s[start + len] == ']') {
117                         ++len;
118                 }
119                 for (;;) {
120                         if (start + len >= s.size()) {
121                                 return { PREMATURE_END_UNIGRAM, len };
122                         }
123                         if (s[start + len] == ']') {
124                                 return { WILDCARD_UNIGRAM, len + 1 };
125                         }
126                         ++len;
127                 }
128         }
129
130         // Regular letter.
131         return { (unsigned char)s[start], 1 };
132 }
133
134 uint32_t read_trigram(const string &s, size_t start)
135 {
136         pair<uint32_t, size_t> u1 = read_unigram(s, start);
137         if (u1.first == WILDCARD_UNIGRAM || u1.first == PREMATURE_END_UNIGRAM) {
138                 return u1.first;
139         }
140         pair<uint32_t, size_t> u2 = read_unigram(s, start + u1.second);
141         if (u2.first == WILDCARD_UNIGRAM || u2.first == PREMATURE_END_UNIGRAM) {
142                 return u2.first;
143         }
144         pair<uint32_t, size_t> u3 = read_unigram(s, start + u1.second + u2.second);
145         if (u3.first == WILDCARD_UNIGRAM || u3.first == PREMATURE_END_UNIGRAM) {
146                 return u3.first;
147         }
148         return u1.first | (u2.first << 8) | (u3.first << 16);
149 }
150
// One partial expansion of a case-insensitive needle. `buffered` holds the
// bytes appended so far that have not yet been consumed as the start of a
// trigram. `next_codepoint` is the index of the next code point whose
// alternatives should be appended.
struct TrigramState {
	string buffered;
	unsigned next_codepoint;

	// Orders by next_codepoint, then by buffered (used for sort + dedup).
	bool operator<(const TrigramState &other) const
	{
		return next_codepoint == other.next_codepoint
			? buffered < other.buffered
			: next_codepoint < other.next_codepoint;
	}
	bool operator==(const TrigramState &other) const
	{
		return buffered == other.buffered && next_codepoint == other.next_codepoint;
	}
};
167
168 void parse_trigrams_ignore_case(const string &needle, vector<TrigramDisjunction> *trigram_groups)
169 {
170         vector<vector<string>> alternatives_for_cp;
171
172         // Parse the needle into Unicode code points, and do inverse case folding
173         // on each to find legal alternatives. This is far from perfect (e.g. ß
174         // will not become ss), but it's generally the best we can do without
175         // involving ICU or the likes.
176         mbtowc(nullptr, 0, 0);
177         const char *ptr = needle.c_str();
178         unique_ptr<char[]> buf(new char[MB_CUR_MAX]);
179         while (*ptr != '\0') {
180                 wchar_t ch;
181                 int ret = mbtowc(&ch, ptr, strlen(ptr));
182                 if (ret == -1) {
183                         perror(ptr);
184                         exit(1);
185                 }
186
187                 vector<string> alt;
188                 alt.push_back(string(ptr, ret));
189                 ptr += ret;
190                 if (towlower(ch) != wint_t(ch)) {
191                         ret = wctomb(buf.get(), towlower(ch));
192                         alt.push_back(string(buf.get(), ret));
193                 }
194                 if (towupper(ch) != wint_t(ch) && towupper(ch) != towlower(ch)) {
195                         ret = wctomb(buf.get(), towupper(ch));
196                         alt.push_back(string(buf.get(), ret));
197                 }
198                 alternatives_for_cp.push_back(move(alt));
199         }
200
201         // Now generate all possible byte strings from those code points in order;
202         // e.g., from abc, we'd create a and A, then extend those to ab aB Ab AB,
203         // then abc abC aBc aBC and so on. Since we don't want to have 2^n
204         // (or even 3^n) strings, we only extend them far enough to cover at
205         // least three bytes; this will give us a set of candidate trigrams
206         // (the filename must have at least one of those), and then we can
207         // chop off the first byte, deduplicate states and continue extending
208         // and generating trigram sets.
209         //
210         // There are a few special cases, notably the dotted i (İ), where the
211         // UTF-8 versions of upper and lower case have different number of bytes.
212         // If this happens, we can have combinatorial explosion and get many more
213         // than the normal 8 states. We detect this and simply bomb out; it will
214         // never really happen in real strings, and stopping trigram generation
215         // really only means our pruning of candidates will be less effective.
216         vector<TrigramState> states;
217         states.push_back(TrigramState{ "", 0 });
218
219         for (;;) {
220                 // Extend every state so that it has buffered at least three bytes.
221                 // If this isn't possible, we are done with the string (can generate
222                 // no more trigrams).
223                 bool need_another_pass;
224                 do {
225                         need_another_pass = false;
226                         vector<TrigramState> new_states;
227                         for (const TrigramState &state : states) {
228                                 if (read_trigram(state.buffered, 0) != PREMATURE_END_UNIGRAM) {
229                                         // No need to extend this further.
230                                         new_states.push_back(state);
231                                         continue;
232                                 }
233                                 if (state.next_codepoint == alternatives_for_cp.size()) {
234                                         // We can't form a complete trigram from this alternative,
235                                         // so we're done.
236                                         return;
237                                 }
238                                 for (const string &rune : alternatives_for_cp[state.next_codepoint]) {
239                                         TrigramState new_state{ state.buffered + rune, state.next_codepoint + 1 };
240                                         if (read_trigram(state.buffered, 0) == PREMATURE_END_UNIGRAM) {
241                                                 need_another_pass = true;
242                                         }
243                                         new_states.push_back(move(new_state));
244                                 }
245                         }
246                         states = move(new_states);
247                 } while (need_another_pass);
248
249                 // OK, so now we have a bunch of states, and all of them are at least
250                 // three bytes long. This means we have a complete set of trigrams,
251                 // and the destination filename must contain at least one of them.
252                 // Output those trigrams, cut out the first byte and then deduplicate
253                 // the states before we continue.
254                 bool any_wildcard = false;
255                 vector<uint32_t> trigram_alternatives;
256                 for (TrigramState &state : states) {
257                         trigram_alternatives.push_back(read_trigram(state.buffered, 0));
258                         state.buffered.erase(0, read_unigram(state.buffered, 0).second);
259                         assert(trigram_alternatives.back() != PREMATURE_END_UNIGRAM);
260                         if (trigram_alternatives.back() == WILDCARD_UNIGRAM) {
261                                 // If any of the candidates are wildcards, we need to drop the entire OR group.
262                                 // (Most likely, all of them would be anyway.) We need to keep stripping out
263                                 // the first unigram from each state.
264                                 any_wildcard = true;
265                         }
266                 }
267                 unique_sort(&trigram_alternatives);  // Could have duplicates, although it's rare.
268                 unique_sort(&states);
269
270                 if (!any_wildcard) {
271                         TrigramDisjunction new_pt;
272                         new_pt.remaining_trigrams_to_read = trigram_alternatives.size();
273                         new_pt.trigram_alternatives = move(trigram_alternatives);
274                         new_pt.max_num_docids = 0;
275                         trigram_groups->push_back(move(new_pt));
276                 }
277
278                 if (states.size() > 100) {
279                         // A completely crazy pattern with lots of those special characters.
280                         // We just give up; this isn't a realistic scenario anyway.
281                         // We already have lots of trigrams that should reduce the amount of
282                         // candidates.
283                         return;
284                 }
285         }
286 }
287
288 void parse_trigrams(const string &needle, bool ignore_case, vector<TrigramDisjunction> *trigram_groups)
289 {
290         // ICU...
291         string needle2;
292         for (char ch : needle) {
293                 if (ch != '*') needle2.push_back(ch);
294         }
295
296         dprintf("posix locale = %s\n", setlocale(LC_CTYPE, NULL));
297         icu::Locale locale = icu::Locale::createCanonical(setlocale(LC_CTYPE, NULL));
298         dprintf("icu locale = %s\n", locale.getName());
299         UErrorCode status = U_ZERO_ERROR;
300         icu::Collator *coll = icu::Collator::createInstance(locale, status);
301         // FIXME check for failure
302         uint8_t needlebuf[1024];  // FIXME
303         coll->setStrength(icu::Collator::PRIMARY);
304         int len = coll->getSortKey(icu::UnicodeString::fromUTF8(needle2), needlebuf, sizeof(needlebuf));
305         dprintf("needlelen = %d (from ascii %zu, needle '%s')\n", len, needle2.size(), needle2.c_str());
306         for (size_t i = 0; i < len; ++i) {
307                 dprintf(" %02x", needlebuf[i]);
308         }
309         dprintf("\n");
310         for (size_t i = 0; i < len - 3; ++i) {
311                 uint32_t trgm = needlebuf[i] | (needlebuf[i + 1] << 8) | (needlebuf[i + 2] << 16);
312                 dprintf("trgm = %06x\n", trgm);
313
314                 TrigramDisjunction new_pt;
315                 new_pt.remaining_trigrams_to_read = 1;
316                 new_pt.trigram_alternatives.push_back(trgm);
317                 new_pt.max_num_docids = 0;
318                 trigram_groups->push_back(move(new_pt));
319         }
320         return;
321
322         if (ignore_case) {
323                 parse_trigrams_ignore_case(needle, trigram_groups);
324                 return;
325         }
326
327         // The case-sensitive case is straightforward.
328         for (size_t i = 0; i < needle.size(); i += read_unigram(needle, i).second) {
329                 uint32_t trgm = read_trigram(needle, i);
330                 if (trgm == WILDCARD_UNIGRAM || trgm == PREMATURE_END_UNIGRAM) {
331                         // Invalid trigram, so skip.
332                         continue;
333                 }
334
335                 TrigramDisjunction new_pt;
336                 new_pt.remaining_trigrams_to_read = 1;
337                 new_pt.trigram_alternatives.push_back(trgm);
338                 new_pt.max_num_docids = 0;
339                 trigram_groups->push_back(move(new_pt));
340         }
341 }