+/**
+ * Incrementally builds the encoded posting list (the set of docids) for one key.
+ *
+ * Docids are handed in one at a time through add_docid(). The first one is
+ * written as a header; every later one is buffered as a delta and emitted in
+ * blocks of 128. finish() must be called once at the end to emit any partial
+ * trailing block.
+ */
+class PostingListBuilder {
+public:
+	// Adds a docid; a docid equal to the previously added one is ignored.
+	inline void add_docid(uint32_t docid);
+
+	// Encodes whatever deltas are still pending as a final, partial block.
+	void finish();
+
+	// The encoded bytes produced so far (header followed by blocks).
+	string encoded;
+
+	// Number of docids accepted so far (duplicates excluded).
+	size_t num_docids = 0;
+
+private:
+	// Appends the encoding of the very first docid to 'encoded'.
+	void write_header(uint32_t docid);
+
+	// Appends one full block of 128 pending deltas to 'encoded'.
+	void append_block();
+
+	// Deltas (docid - previous docid - 1) not yet written to 'encoded'.
+	vector<uint32_t> pending_deltas;
+
+	// Docid at which the header or the most recent full block ended.
+	// Explicitly initialized so the member never holds an indeterminate value.
+	uint32_t last_block_end = 0;
+
+	// Most recently added docid; -1 (all bits set) until the first add_docid().
+	uint32_t last_docid = -1;
+};
+
+/**
+ * Adds one docid to the posting list.
+ *
+ * The first docid is written out immediately as the header. Every later docid
+ * is stored as (docid - previous docid - 1) and buffered; once 128 deltas are
+ * pending, they are emitted as one block. A docid equal to the previously
+ * added one is ignored.
+ *
+ * NOTE(review): the delta computation presumably relies on docids arriving in
+ * strictly increasing order -- confirm against callers.
+ */
+void PostingListBuilder::add_docid(uint32_t docid)
+{
+	if (num_docids == 0) {
+		// Very first docid. Handled before the duplicate check so that the
+		// initial sentinel value of last_docid can never be mistaken for a
+		// real, previously added docid.
+		write_header(docid);
+		++num_docids;
+		last_block_end = last_docid = docid;
+		return;
+	}
+
+	// Deduplicate against the last inserted value.
+	if (docid == last_docid) {
+		return;
+	}
+
+	pending_deltas.push_back(docid - last_docid - 1);
+	last_docid = docid;
+	if (pending_deltas.size() == 128) {
+		// A full block is ready; emit it and start collecting the next one.
+		append_block();
+		pending_deltas.clear();
+		last_block_end = docid;
+	}
+	++num_docids;
+}
+
+/**
+ * Emits the deltas that are still pending (fewer than a full block of 128)
+ * as a final block. Does nothing if no deltas are pending.
+ */
+void PostingListBuilder::finish()
+{
+	const size_t num_pending = pending_deltas.size();
+	if (num_pending == 0) {
+		return;
+	}
+
+	// write_header() should already have run, so there must be output already.
+	assert(!encoded.empty());
+
+	// Partial blocks are encoded without interleaving.
+	unsigned char buf[P4NENC_BOUND(128)];
+	unsigned char *end = encode_pfor_single_block<128>(pending_deltas.data(), num_pending, /*interleaved=*/false, buf);
+	encoded.append(reinterpret_cast<char *>(buf), end - buf);
+}
+
+/**
+ * Encodes exactly 128 pending deltas as one interleaved block and appends the
+ * result to 'encoded'. The caller is responsible for clearing pending_deltas.
+ */
+void PostingListBuilder::append_block()
+{
+	assert(pending_deltas.size() == 128);
+
+	unsigned char buf[P4NENC_BOUND(128)];
+	unsigned char *end = encode_pfor_single_block<128>(pending_deltas.data(), 128, /*interleaved=*/true, buf);
+	encoded.append(reinterpret_cast<char *>(buf), end - buf);
+}
+
+/**
+ * Encodes the given (first) docid with write_baseval() and appends the
+ * resulting bytes to 'encoded'.
+ */
+void PostingListBuilder::write_header(uint32_t docid)
+{
+	unsigned char buf[P4NENC_BOUND(1)];
+	unsigned char *const end = write_baseval(docid, buf);
+	const size_t num_bytes = end - buf;
+	encoded.append(reinterpret_cast<char *>(buf), num_bytes);
+}
+
+/**
+ * Abstract sink for filenames. Implementations receive filenames one at a
+ * time and are told when the block being accumulated should be closed.
+ */
+class DatabaseReceiver {
+public:
+	virtual ~DatabaseReceiver() = default;
+
+	// Hands one filename to the receiver.
+	virtual void add_file(string filename) = 0;
+
+	// Closes the block that is currently being accumulated.
+	virtual void flush_block() = 0;
+};
+
+/**
+ * Receiver that keeps a random sample of at most 'blocks_to_keep' filename
+ * blocks (each built from 'block_size' filenames) and can train a
+ * compression dictionary from that sample via train().
+ */
+class DictionaryBuilder : public DatabaseReceiver {
+public:
+	DictionaryBuilder(size_t blocks_to_keep, size_t block_size)
+		: blocks_to_keep(blocks_to_keep), block_size(block_size) {}
+
+	void add_file(string filename) override;
+	void flush_block() override;
+
+	// Trains a dictionary of at most buf_size bytes from the sampled blocks
+	// and returns it.
+	string train(size_t buf_size);
+
+private:
+	const size_t blocks_to_keep;
+	const size_t block_size;
+
+	// Filenames of the block being accumulated, separated by NUL bytes.
+	// Only filled in while keep_current_block is set.
+	string current_block;
+	uint64_t block_num = 0;
+	size_t num_files_in_block = 0;
+
+	std::mt19937 reservoir_rand{ 1234 }; // Fixed seed for reproducibility.
+
+	// Whether the block being accumulated will be stored when flushed, and
+	// if so, which sample slot it replaces (-1 means: append a new slot).
+	bool keep_current_block = true;
+	int64_t slot_for_current_block = -1;
+
+	// The sampled blocks and, at the same index, their sizes in bytes.
+	vector<string> sampled_blocks;
+	vector<size_t> lengths;
+};
+
+/**
+ * Counts one more file in the current block and, if the block was selected
+ * for sampling, appends the filename to it (NUL-separated). Flushes the
+ * block once it holds block_size files.
+ */
+void DictionaryBuilder::add_file(string filename)
+{
+	// Only bother saving the filenames if we're actually keeping the block.
+	if (keep_current_block) {
+		if (!current_block.empty()) {
+			current_block += '\0';
+		}
+		current_block.append(filename);
+	}
+
+	++num_files_in_block;
+	if (num_files_in_block == block_size) {
+		flush_block();
+	}
+}
+
+/**
+ * Closes the block currently being accumulated and decides whether the next
+ * block will be sampled.
+ *
+ * If the current block was selected, it is stored either in a fresh slot or
+ * over an existing one. The first blocks_to_keep blocks are always kept;
+ * after that, reservoir sampling picks the slot (if any) that the upcoming
+ * block will replace, so that every block is kept with equal probability.
+ *
+ * A flush with no files added since the previous flush is a no-op, matching
+ * Corpus::flush_block(). Otherwise it would store an empty sample (possibly
+ * overwriting a real one) and skew the block count used for sampling.
+ */
+void DictionaryBuilder::flush_block()
+{
+	if (num_files_in_block == 0) {
+		return;
+	}
+
+	if (keep_current_block) {
+		// Record the size before moving the string away.
+		const size_t len = current_block.size();
+		if (slot_for_current_block == -1) {
+			lengths.push_back(len);
+			sampled_blocks.push_back(move(current_block));
+		} else {
+			lengths[slot_for_current_block] = len;
+			sampled_blocks[slot_for_current_block] = move(current_block);
+		}
+	}
+	// Also resets the moved-from string to a known (empty) state.
+	current_block.clear();
+	num_files_in_block = 0;
+	++block_num;
+
+	if (block_num < blocks_to_keep) {
+		// The reservoir is not full yet; keep the next block unconditionally.
+		keep_current_block = true;
+		slot_for_current_block = -1;
+	} else {
+		// Keep every block with equal probability (reservoir sampling).
+		uint64_t idx = uniform_int_distribution<uint64_t>(0, block_num)(reservoir_rand);
+		keep_current_block = (idx < blocks_to_keep);
+		slot_for_current_block = idx;
+	}
+}
+
+/**
+ * Trains a compression dictionary from the sampled blocks.
+ *
+ * @param buf_size  Maximum size of the dictionary, in bytes.
+ * @return The trained dictionary (at most buf_size bytes).
+ *
+ * The stored samples are discarded afterwards. If training fails, an error
+ * message is printed and the process exits with status 1.
+ */
+string DictionaryBuilder::train(size_t buf_size)
+{
+	sort(sampled_blocks.begin(), sampled_blocks.end()); // Seemingly important for decompression speed.
+
+	// Concatenate the samples, and rebuild the per-sample sizes in the same
+	// (sorted) order. The sizes recorded while sampling are in insertion
+	// order and would no longer line up with the sample boundaries in the
+	// concatenated buffer.
+	string dictionary_buf;
+	lengths.clear();
+	lengths.reserve(sampled_blocks.size());
+	for (const string &block : sampled_blocks) {
+		dictionary_buf += block;
+		lengths.push_back(block.size());
+	}
+
+	string buf;
+	buf.resize(buf_size);
+	size_t ret = ZDICT_trainFromBuffer(&buf[0], buf_size, dictionary_buf.data(), lengths.data(), lengths.size());
+	if (ZDICT_isError(ret)) {
+		// On failure, the return value is an error code, not a size; it must
+		// not be used to resize the buffer.
+		fprintf(stderr, "ZDICT_trainFromBuffer() failed: %s\n", ZDICT_getErrorName(ret));
+		exit(1);
+	}
+	// NOTE(review): dprintf is presumably a project-level debug macro whose
+	// definition is not visible here; POSIX dprintf() takes an int file
+	// descriptor as its first argument, so confirm that passing stderr is
+	// what the macro expects.
+	dprintf(stderr, "Sampled %zu bytes in %zu blocks, built a dictionary of size %zu\n", dictionary_buf.size(), lengths.size(), ret);
+	buf.resize(ret);
+
+	sampled_blocks.clear();
+	lengths.clear();
+
+	return buf;
+}
+
+/**
+ * Receiver that groups filenames into blocks of 'block_size', writes each
+ * block compressed to 'outfp', and records in a per-trigram posting list
+ * which blocks every trigram occurs in.
+ */
+class Corpus : public DatabaseReceiver {
+public:
+	Corpus(FILE *outfp, size_t block_size, ZSTD_CDict *cdict)
+		// The trailing () value-initializes the array, so every slot starts as nullptr.
+		: invindex(new PostingListBuilder *[NUM_TRIGRAMS]()), outfp(outfp), block_size(block_size), cdict(cdict)
+	{
+	}
+	~Corpus() override
+	{
+		// The builders are owned by this object; unused slots are nullptr.
+		for (unsigned trgm = 0; trgm < NUM_TRIGRAMS; ++trgm) {
+			delete invindex[trgm];
+		}
+	}
+
+	void add_file(string filename) override;
+	void flush_block() override;
+
+	// File offset (as reported by ftell()) at which each written block starts.
+	vector<uint64_t> filename_blocks;
+	size_t num_files = 0;
+	size_t num_files_in_block = 0;
+	size_t num_blocks = 0;
+
+	// Returns whether a posting list builder exists for the given trigram.
+	bool seen_trigram(uint32_t trgm)
+	{
+		return invindex[trgm] != nullptr;
+	}
+
+	// Returns the posting list builder for the given trigram, creating it on
+	// first use.
+	PostingListBuilder &get_pl_builder(uint32_t trgm)
+	{
+		PostingListBuilder *&slot = invindex[trgm];
+		if (slot == nullptr) {
+			slot = new PostingListBuilder;
+		}
+		return *slot;
+	}
+
+private:
+	// One (lazily created) posting list builder per possible trigram.
+	unique_ptr<PostingListBuilder *[]> invindex;
+	FILE *outfp;
+	// Filenames of the block being accumulated, separated by NUL bytes.
+	string current_block;
+	// Scratch buffer handed to zstd_compress().
+	string tempbuf;
+	const size_t block_size;
+	ZSTD_CDict *cdict;
+};
+
+/**
+ * Appends the filename to the block being accumulated (NUL-separated from
+ * the previous one) and flushes the block once it holds block_size files.
+ */
+void Corpus::add_file(string filename)
+{
+	++num_files;
+
+	if (!current_block.empty()) {
+		current_block += '\0';
+	}
+	current_block.append(filename);
+
+	++num_files_in_block;
+	if (num_files_in_block == block_size) {
+		flush_block();
+	}
+}
+
+/**
+ * Finishes the block being accumulated: registers the block number in the
+ * posting list of every trigram found in its filenames, then compresses the
+ * block and writes it to the output file. Does nothing if the block is empty.
+ * Exits the process with status 1 if the write fails.
+ */
+void Corpus::flush_block()
+{
+	if (current_block.empty()) {
+		return;
+	}
+
+	// The block number doubles as the docid stored in the posting lists.
+	uint32_t docid = num_blocks;
+
+	// Create trigrams, one NUL-separated filename at a time.
+	const char *const block_end = current_block.c_str() + current_block.size();
+	for (const char *ptr = current_block.c_str(); ptr < block_end; ) {
+		string_view name(ptr);
+		// Names shorter than three bytes yield no trigrams; the loop
+		// condition is then false right away.
+		for (size_t pos = 0; pos + 2 < name.size(); ++pos) {
+			uint32_t trgm = read_trigram(name, pos);
+			get_pl_builder(trgm).add_docid(docid);
+		}
+		// Skip past the name and its terminating NUL.
+		ptr += name.size() + 1;
+	}
+
+	// Compress and add the filename block, remembering where it starts.
+	filename_blocks.push_back(ftell(outfp));
+	string compressed = zstd_compress(current_block, cdict, &tempbuf);
+	const size_t items_written = fwrite(compressed.data(), compressed.size(), 1, outfp);
+	if (items_written != 1) {
+		perror("fwrite()");
+		exit(1);
+	}
+
+	current_block.clear();
+	num_files_in_block = 0;
+	++num_blocks;
+}
+
+/**
+ * Reads a NUL-terminated string from fp.
+ *
+ * @param fp  Stream to read from.
+ * @return The bytes read up to, but not including, the terminating NUL.
+ *
+ * If the stream ends or fails before a NUL byte is seen, a message is
+ * printed to stderr and the process exits with status 1.
+ */
+string read_cstr(FILE *fp)
+{
+	string ret;
+	for (;;) {
+		int ch = getc(fp);
+		if (ch == EOF) {
+			// errno is only meaningful for a real read error; a plain
+			// end-of-file gets its own message instead of perror().
+			if (ferror(fp)) {
+				perror("getc");
+			} else {
+				fprintf(stderr, "getc: unexpected end of file\n");
+			}
+			exit(1);
+		}
+		if (ch == '\0') {
+			return ret;
+		}
+		ret.push_back(static_cast<char>(ch));
+	}
+}
+
+void handle_directory(FILE *fp, DatabaseReceiver *receiver)