]> git.sesse.net Git - plocate/blob - parse_trigrams.cpp
Support case-insensitive searches.
[plocate] / parse_trigrams.cpp
1 #include "parse_trigrams.h"
2
3 #include "unique_sort.h"
4
5 #include <string.h>
6 #include <wctype.h>
7
8 using namespace std;
9
10 string print_td(const TrigramDisjunction &td)
11 {
12         if (td.read_trigrams.size() == 0) {
13                 // Before we've done hash lookups (or none matched), so print all alternatives.
14                 if (td.trigram_alternatives.size() == 1) {
15                         return print_trigram(td.trigram_alternatives[0]);
16                 } else {
17                         string ret;
18                         ret = "(";
19                         bool first = true;
20                         for (uint32_t trgm : td.trigram_alternatives) {
21                                 if (!first)
22                                         ret += " OR ";
23                                 ret += print_trigram(trgm);
24                                 first = false;
25                         }
26                         return ret + ")";
27                 }
28         } else {
29                 // Print only those that we actually have in the index.
30                 if (td.read_trigrams.size() == 1) {
31                         return print_trigram(td.read_trigrams[0].first.trgm);
32                 } else {
33                         string ret;
34                         ret = "(";
35                         bool first = true;
36                         for (auto &[trgmptr, len] : td.read_trigrams) {
37                                 if (!first)
38                                         ret += " OR ";
39                                 ret += print_trigram(trgmptr.trgm);
40                                 first = false;
41                         }
42                         return ret + ")";
43                 }
44         }
45 }
46
// Renders a packed trigram (first byte in the low 8 bits) as a single-quoted,
// human-readable string for debug output.
//
// Rendering rules:
//  - Backslashes are escaped as "\\".
//  - Printable ASCII (32..126) is copied verbatim.
//  - A complete multibyte character (decoded with mbtowc() in the current
//    locale) that iswprint() accepts is copied verbatim.
//  - Every other byte, including DEL and other control characters, is shown
//    as \x{NN}.
std::string print_trigram(uint32_t trgm)
{
	const char ch[3] = {
		char(trgm & 0xff), char((trgm >> 8) & 0xff), char((trgm >> 16) & 0xff)
	};

	std::string str = "'";
	for (unsigned i = 0; i < 3;) {
		if (ch[i] == '\\') {
			str.push_back('\\');
			str.push_back(ch[i]);
			++i;
		} else if (int(ch[i]) >= 32 && int(ch[i]) < 127) {
			// Printable ASCII. The range test holds no matter whether char is
			// signed or unsigned; 127 (DEL) is a control character and is
			// deliberately excluded.
			str.push_back(ch[i]);
			++i;
		} else {
			// See if we have an entire multibyte (e.g. UTF-8) character, and
			// that it's actually printable. iswprint() also rejects DEL and the
			// C1 controls U+0080..U+009F, which a bare "pwc >= 32" let through.
			mbtowc(nullptr, 0, 0);  // Reset the conversion state.
			wchar_t pwc;
			int ret = mbtowc(&pwc, ch + i, 3 - i);
			if (ret >= 1 && iswprint(pwc)) {
				str.append(ch + i, ret);
				i += ret;
			} else {
				char buf[16];
				snprintf(buf, sizeof(buf), "\\x{%02x}", static_cast<unsigned char>(ch[i]));
				str += buf;
				++i;
			}
		}
	}
	str += "'";
	return str;
}
81
// Returns byte idx of s as an unsigned value (0..255), or 0 when idx lies
// past the end of s (so short strings pad with NUL bytes).
uint32_t read_unigram(const std::string &s, size_t idx)
{
	return idx >= s.size() ? 0 : static_cast<unsigned char>(s[idx]);
}
90
91 uint32_t read_trigram(const string &s, size_t start)
92 {
93         return read_unigram(s, start) | (read_unigram(s, start + 1) << 8) |
94                 (read_unigram(s, start + 2) << 16);
95 }
96
// One partially expanded case variant of the needle, used while generating
// case-insensitive trigrams.
struct TrigramState {
	// Bytes generated for this variant whose leading byte has not yet been
	// consumed as the start of an emitted trigram.
	std::string buffered;
	// Index of the next code point (into the per-code-point alternatives list)
	// to append to buffered.
	unsigned next_codepoint;

	// Orders by next_codepoint first, then by buffered. Together with
	// operator==, this is what unique_sort(&states) uses to deduplicate states.
	bool operator<(const TrigramState &other) const
	{
		return next_codepoint < other.next_codepoint ||
			(next_codepoint == other.next_codepoint && buffered < other.buffered);
	}
	bool operator==(const TrigramState &other) const
	{
		return buffered == other.buffered &&
			next_codepoint == other.next_codepoint;
	}
};
113
114 void parse_trigrams_ignore_case(const string &needle, vector<TrigramDisjunction> *trigram_groups)
115 {
116         vector<vector<string>> alternatives_for_cp;
117
118         // Parse the needle into Unicode code points, and do inverse case folding
119         // on each to find legal alternatives. This is far from perfect (e.g. ß
120         // will not become ss), but it's generally the best we can do without
121         // involving ICU or the likes.
122         mbtowc(nullptr, 0, 0);
123         const char *ptr = needle.c_str();
124         while (*ptr != '\0') {
125                 wchar_t ch;
126                 int ret = mbtowc(&ch, ptr, strlen(ptr));
127                 if (ret == -1) {
128                         perror(ptr);
129                         exit(1);
130                 }
131
132                 char buf[MB_CUR_MAX];
133                 vector<string> alt;
134                 alt.push_back(string(ptr, ret));
135                 ptr += ret;
136                 if (towlower(ch) != wint_t(ch)) {
137                         ret = wctomb(buf, towlower(ch));
138                         alt.push_back(string(buf, ret));
139                 }
140                 if (towupper(ch) != wint_t(ch) && towupper(ch) != towlower(ch)) {
141                         ret = wctomb(buf, towupper(ch));
142                         alt.push_back(string(buf, ret));
143                 }
144                 alternatives_for_cp.push_back(move(alt));
145         }
146
147         // Now generate all possible byte strings from those code points in order;
148         // e.g., from abc, we'd create a and A, then extend those to ab aB Ab AB,
149         // then abc abC aBc aBC and so on. Since we don't want to have 2^n
150         // (or even 3^n) strings, we only extend them far enough to cover at
151         // least three bytes; this will give us a set of candidate trigrams
152         // (the filename must have at least one of those), and then we can
153         // chop off the first byte, deduplicate states and continue extending
154         // and generating trigram sets.
155         //
156         // There are a few special cases, notably the dotted i (İ), where the
157         // UTF-8 versions of upper and lower case have different number of bytes.
158         // If this happens, we can have combinatorial explosion and get many more
159         // than the normal 8 states. We detect this and simply bomb out; it will
160         // never really happen in real strings, and stopping trigram generation
161         // really only means our pruning of candidates will be less effective.
162         vector<TrigramState> states;
163         states.push_back(TrigramState{ "", 0 });
164
165         for (;;) {
166                 // Extend every state so that it has buffered at least three bytes.
167                 // If this isn't possible, we are done with the string (can generate
168                 // no more trigrams).
169                 bool need_another_pass;
170                 do {
171                         need_another_pass = false;
172                         vector<TrigramState> new_states;
173                         for (const TrigramState &state : states) {
174                                 if (state.buffered.size() >= 3) {
175                                         // No need to extend this further.
176                                         new_states.push_back(state);
177                                         continue;
178                                 }
179                                 if (state.next_codepoint == alternatives_for_cp.size()) {
180                                         // We can't form a complete trigram from this alternative,
181                                         // so we're done.
182                                         return;
183                                 }
184                                 for (const string &rune : alternatives_for_cp[state.next_codepoint]) {
185                                         TrigramState new_state{ state.buffered + rune, state.next_codepoint + 1 };
186                                         if (new_state.buffered.size() < 3) {
187                                                 need_another_pass = true;
188                                         }
189                                         new_states.push_back(move(new_state));
190                                 }
191                         }
192                         states = move(new_states);
193                 } while (need_another_pass);
194
195                 // OK, so now we have a bunch of states, and all of them are at least
196                 // three bytes long. This means we have a complete set of trigrams,
197                 // and the destination filename must contain at least one of them.
198                 // Output those trigrams, cut out the first byte and then deduplicate
199                 // the states before we continue.
200                 vector<uint32_t> trigram_alternatives;
201                 for (TrigramState &state : states) {
202                         trigram_alternatives.push_back(read_trigram(state.buffered, 0));
203                         state.buffered.erase(0, 1);
204                 }
205                 unique_sort(&trigram_alternatives);  // Could have duplicates, although it's rare.
206                 unique_sort(&states);
207
208                 TrigramDisjunction new_pt;
209                 new_pt.remaining_trigrams_to_read = trigram_alternatives.size();
210                 new_pt.trigram_alternatives = move(trigram_alternatives);
211                 new_pt.max_num_docids = 0;
212                 trigram_groups->push_back(move(new_pt));
213
214                 if (states.size() > 100) {
215                         // A completely crazy pattern with lots of those special characters.
216                         // We just give up; this isn't a realistic scenario anyway.
217                         // We already have lots of trigrams that should reduce the amount of
218                         // candidates.
219                         return;
220                 }
221         }
222 }
223
224 void parse_trigrams(const string &needle, bool ignore_case, vector<TrigramDisjunction> *trigram_groups)
225 {
226         if (ignore_case) {
227                 parse_trigrams_ignore_case(needle, trigram_groups);
228                 return;
229         }
230
231         // The case-sensitive case is straightforward.
232         if (needle.size() >= 3) {
233                 for (size_t i = 0; i < needle.size() - 2; ++i) {
234                         uint32_t trgm = read_trigram(needle, i);
235                         TrigramDisjunction new_pt;
236                         new_pt.remaining_trigrams_to_read = 1;
237                         new_pt.trigram_alternatives.push_back(trgm);
238                         new_pt.max_num_docids = 0;
239                         trigram_groups->push_back(move(new_pt));
240                 }
241         }
242 }