-struct SymbolStats
-{
- uint32_t freqs[NUM_SYMS];
- uint32_t cum_freqs[NUM_SYMS + 1];
-
- void clear();
- void calc_cum_freqs();
- void normalize_freqs(uint32_t target_total);
-};
-
-void SymbolStats::clear()
-{
- for (int i=0; i < NUM_SYMS; i++)
- freqs[i] = 0;
-}
-
-void SymbolStats::calc_cum_freqs()
-{
- cum_freqs[0] = 0;
- for (int i=0; i < NUM_SYMS; i++)
- cum_freqs[i+1] = cum_freqs[i] + freqs[i];
-}
-
-void SymbolStats::normalize_freqs(uint32_t target_total)
-{
- uint64_t real_freq[NUM_SYMS + 1]; // hack
-
- assert(target_total >= NUM_SYMS);
-
- calc_cum_freqs();
- uint32_t cur_total = cum_freqs[NUM_SYMS];
-
- if (cur_total == 0) return;
-
- double ideal_cost = 0.0;
- for (int i = 1; i <= NUM_SYMS; i++)
- {
- real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
- if (real_freq[i] > 0)
- ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
- }
-
- OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
-
- // calculate updated freqs and make sure we didn't screw anything up
- assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
- for (int i=0; i < NUM_SYMS; i++) {
- if (freqs[i] == 0)
- assert(cum_freqs[i+1] == cum_freqs[i]);
- else
- assert(cum_freqs[i+1] > cum_freqs[i]);
-
- // calc updated freq
- freqs[i] = cum_freqs[i+1] - cum_freqs[i];
- }
-
- double calc_cost = 0.0;
- for (int i = 1; i <= NUM_SYMS; i++)
- {
- uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
- if (real_freq[i] > 0)
- calc_cost -= real_freq[i] * log2(freq / double(target_total));
- }
-
- static double total_loss = 0.0;
- total_loss += calc_cost - ideal_cost;
- static double total_loss_with_dp = 0.0;
- double optimal_cost = 0.0;
- //total_loss_with_dp += optimal_cost - ideal_cost;
- printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
- ideal_cost, optimal_cost,
- calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
-}
-
-SymbolStats stats[128];
-
-const int luma_mapping[64] = {
- 0, 0, 1, 1, 2, 2, 3, 3,
- 0, 0, 1, 2, 2, 2, 3, 3,
- 1, 1, 2, 2, 2, 3, 3, 3,
- 1, 1, 2, 2, 2, 3, 3, 3,
- 1, 2, 2, 2, 2, 3, 3, 3,
- 2, 2, 2, 2, 3, 3, 3, 3,
- 2, 2, 3, 3, 3, 3, 3, 3,
- 3, 3, 3, 3, 3, 3, 3, 3,
-};
-
-int pick_stats_for(int x, int y)
-{
- return luma_mapping[y * 8 + x];
-}
-
-class RansEncoder {
-public:
- RansEncoder()
- {
- out_buf.reset(new uint8_t[out_max_size]);
- clear();
- }
-
- void init_prob(SymbolStats &s)
- {
- for (int i = 0; i < NUM_SYMS; i++) {
- //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
- RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
- }
- sign_bias = s.cum_freqs[NUM_SYMS];
- }
-
- void clear()
- {
- out_end = out_buf.get() + out_max_size;
- ptr = out_end; // *end* of output buffer
- RansEncInit(&rans);
- }
-
- uint32_t save_block(FILE *codedfp) // Returns number of bytes.
- {
- RansEncFlush(&rans, &ptr);
- //printf("post-flush = %08x\n", rans);
-
- uint32_t num_rans_bytes = out_end - ptr;
- if (num_rans_bytes == last_block.size() &&
- memcmp(last_block.data(), ptr, last_block.size()) == 0) {
- write_varint(0, codedfp);
- clear();
- return 1;
- } else {
- last_block = string((const char *)ptr, num_rans_bytes);
- }
-
- write_varint(num_rans_bytes, codedfp);
- //fwrite(&num_rans_bytes, 1, 4, codedfp);
- fwrite(ptr, 1, num_rans_bytes, codedfp);
-
- //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
-
-
- clear();
-
- //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
- return num_rans_bytes;
- //return num_rans_bytes;
- }
-
- void encode_coeff(short signed_k)
- {
- //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
- unsigned short k = abs(signed_k);
- if (k >= ESCAPE_LIMIT) {
- // Put the coefficient as a 1/(2^12) symbol _before_
- // the 255 coefficient, since the decoder will read the
- // 255 coefficient first.
- RansEncPut(&rans, &ptr, k, 1, prob_bits);
- k = ESCAPE_LIMIT;
- }
- RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
- if (signed_k < 0) {
- rans += sign_bias;
- }
- }
-
-private:
- static constexpr size_t out_max_size = 32 << 20; // 32 MB.
- static constexpr size_t max_num_sign = 1048576; // Way too big. And actually bytes.
-
- unique_ptr<uint8_t[]> out_buf;
- uint8_t *out_end;
- uint8_t *ptr;
- RansState rans;
- RansEncSymbol esyms[NUM_SYMS];
- uint32_t sign_bias;
-
- std::string last_block;
-};
-