+ unsigned char *end = write_baseval(docid, buf);
+ encoded.append(reinterpret_cast<char *>(buf), end - buf);
+}
+
+class DatabaseReceiver {
+public:
+ virtual ~DatabaseReceiver() = default;
+ virtual void add_file(string filename) = 0;
+ virtual void flush_block() = 0;
+};
+
+class DictionaryBuilder : public DatabaseReceiver {
+public:
+ DictionaryBuilder(size_t blocks_to_keep, size_t block_size)
+ : blocks_to_keep(blocks_to_keep), block_size(block_size) {}
+ void add_file(string filename) override;
+ void flush_block() override;
+ string train(size_t buf_size);
+
+private:
+ const size_t blocks_to_keep, block_size;
+ string current_block;
+ uint64_t block_num = 0;
+ size_t num_files_in_block = 0;
+
+ std::mt19937 reservoir_rand{ 1234 }; // Fixed seed for reproducibility.
+ bool keep_current_block = true;
+ int64_t slot_for_current_block = -1;
+
+ vector<string> sampled_blocks;
+ vector<size_t> lengths;
+};
+
+void DictionaryBuilder::add_file(string filename)
+{
+ if (keep_current_block) { // Only bother saving the filenames if we're actually keeping the block.
+ if (!current_block.empty()) {
+ current_block.push_back('\0');
+ }
+ current_block += filename;
+ }
+ if (++num_files_in_block == block_size) {
+ flush_block();
+ }
+}
+
+void DictionaryBuilder::flush_block()
+{
+ if (keep_current_block) {
+ if (slot_for_current_block == -1) {
+ lengths.push_back(current_block.size());
+ sampled_blocks.push_back(move(current_block));
+ } else {
+ lengths[slot_for_current_block] = current_block.size();
+ sampled_blocks[slot_for_current_block] = move(current_block);
+ }
+ }
+ current_block.clear();
+ num_files_in_block = 0;
+ ++block_num;
+
+ if (block_num < blocks_to_keep) {
+ keep_current_block = true;
+ slot_for_current_block = -1;
+ } else {
+ // Keep every block with equal probability (reservoir sampling).
+ uint64_t idx = uniform_int_distribution<uint64_t>(0, block_num)(reservoir_rand);
+ keep_current_block = (idx < blocks_to_keep);
+ slot_for_current_block = idx;
+ }
+}
+
+string DictionaryBuilder::train(size_t buf_size)
+{
+ string dictionary_buf;
+ sort(sampled_blocks.begin(), sampled_blocks.end()); // Seemingly important for decompression speed.
+ for (const string &block : sampled_blocks) {
+ dictionary_buf += block;
+ }
+
+ string buf;
+ buf.resize(buf_size);
+ size_t ret = ZDICT_trainFromBuffer(&buf[0], buf_size, dictionary_buf.data(), lengths.data(), lengths.size());
+ if (ret == size_t(-1)) {
+ return "";
+ }
+ dprintf("Sampled %zu bytes in %zu blocks, built a dictionary of size %zu\n", dictionary_buf.size(), lengths.size(), ret);
+ buf.resize(ret);
+
+ sampled_blocks.clear();
+ lengths.clear();
+
+ return buf;