--- /dev/null
+#include <stdio.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <math.h>
+
+//#include "ryg_rans/rans64.h"
+#include "ryg_rans/rans_byte.h"
+
+#include <memory>
+
+#define WIDTH 1280
+#define HEIGHT 720
+#define NUM_SYMS 256
+#define ESCAPE_LIMIT (NUM_SYMS - 1)
+
+using namespace std;
+
+void fdct_int32(short *const In);
+void idct_int32(short *const In);
+
+unsigned char pix[WIDTH * HEIGHT];
+short coeff[WIDTH * HEIGHT];
+
+static const unsigned char std_luminance_quant_tbl[64] = {
+#if 0
+ 16, 11, 10, 16, 24, 40, 51, 61,
+ 12, 12, 14, 19, 26, 58, 60, 55,
+ 14, 13, 16, 24, 40, 57, 69, 56,
+ 14, 17, 22, 29, 51, 87, 80, 62,
+ 18, 22, 37, 56, 68, 109, 103, 77,
+ 24, 35, 55, 64, 81, 104, 113, 92,
+ 49, 64, 78, 87, 103, 121, 120, 101,
+ 72, 92, 95, 98, 112, 100, 103, 99
+#endif
+ 16, 16, 19, 22, 26, 27, 29, 34,
+ 16, 16, 22, 24, 27, 29, 34, 37,
+ 19, 22, 26, 27, 29, 34, 34, 38,
+ 22, 22, 26, 27, 29, 34, 37, 40,
+ 22, 26, 27, 29, 32, 35, 40, 48,
+ 26, 27, 29, 32, 35, 40, 48, 58,
+ 26, 27, 29, 34, 38, 46, 56, 69,
+ 27, 29, 35, 38, 46, 56, 69, 83
+};
+
+struct SymbolStats
+{
+ uint32_t freqs[NUM_SYMS];
+ uint32_t cum_freqs[NUM_SYMS + 1];
+
+ void clear();
+ void count_freqs(uint8_t const* in, size_t nbytes);
+ void calc_cum_freqs();
+ void normalize_freqs(uint32_t target_total);
+};
+
+void SymbolStats::clear()
+{
+ for (int i=0; i < NUM_SYMS; i++)
+ freqs[i] = 0;
+}
+
+void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
+{
+ clear();
+
+ for (size_t i=0; i < nbytes; i++)
+ freqs[in[i]]++;
+}
+
+void SymbolStats::calc_cum_freqs()
+{
+ cum_freqs[0] = 0;
+ for (int i=0; i < NUM_SYMS; i++)
+ cum_freqs[i+1] = cum_freqs[i] + freqs[i];
+}
+
+void SymbolStats::normalize_freqs(uint32_t target_total)
+{
+ assert(target_total >= NUM_SYMS);
+
+ calc_cum_freqs();
+ uint32_t cur_total = cum_freqs[NUM_SYMS];
+
+ if (cur_total == 0) return;
+
+ // resample distribution based on cumulative freqs
+ for (int i = 1; i <= NUM_SYMS; i++)
+ cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
+
+ // if we nuked any non-0 frequency symbol to 0, we need to steal
+ // the range to make the frequency nonzero from elsewhere.
+ //
+ // this is not at all optimal, i'm just doing the first thing that comes to mind.
+ for (int i=0; i < NUM_SYMS; i++) {
+ if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
+ // symbol i was set to zero freq
+
+ // find best symbol to steal frequency from (try to steal from low-freq ones)
+ uint32_t best_freq = ~0u;
+ int best_steal = -1;
+ for (int j=0; j < NUM_SYMS; j++) {
+ uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
+ if (freq > 1 && freq < best_freq) {
+ best_freq = freq;
+ best_steal = j;
+ }
+ }
+ assert(best_steal != -1);
+
+ // and steal from it!
+ if (best_steal < i) {
+ for (int j = best_steal + 1; j <= i; j++)
+ cum_freqs[j]--;
+ } else {
+ assert(best_steal > i);
+ for (int j = i + 1; j <= best_steal; j++)
+ cum_freqs[j]++;
+ }
+ }
+ }
+
+ // calculate updated freqs and make sure we didn't screw anything up
+ assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
+ for (int i=0; i < NUM_SYMS; i++) {
+ if (freqs[i] == 0)
+ assert(cum_freqs[i+1] == cum_freqs[i]);
+ else
+ assert(cum_freqs[i+1] > cum_freqs[i]);
+
+ // calc updated freq
+ freqs[i] = cum_freqs[i+1] - cum_freqs[i];
+ }
+}
+
+SymbolStats stats[64];
+
+int pick_stats_for(int y, int x)
+{
+ //return std::min<int>(hypot(x, y), 7);
+ return std::min<int>(x + y, 7);
+ //if (x + y >= 7) return 7;
+ //return x + y;
+// return y * 8 + x;
+#if 0
+ if (y == 0 && x == 0) {
+ return 0;
+ } else {
+ return 1;
+ }
+#endif
+}
+
+
+void write_varint(int x, FILE *fp)
+{
+ while (x >= 128) {
+ putc((x & 0x7f) | 0x80, fp);
+ x >>= 7;
+ }
+ putc(x, fp);
+}
+
+class RansEncoder {
+public:
+ static constexpr uint32_t prob_bits = 12;
+ static constexpr uint32_t prob_scale = 1 << prob_bits;
+
+ RansEncoder()
+ {
+ out_buf.reset(new uint8_t[out_max_size]);
+ sign_buf.reset(new uint8_t[max_num_sign]);
+ clear();
+ }
+
+ void init_prob(const SymbolStats &s)
+ {
+ for (int i = 0; i < NUM_SYMS; i++) {
+ //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
+ RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits);
+ }
+ }
+
+ void clear()
+ {
+ out_end = out_buf.get() + out_max_size;
+ sign_end = sign_buf.get() + max_num_sign;
+ ptr = out_end; // *end* of output buffer
+ sign_ptr = sign_end; // *end* of output buffer
+ RansEncInit(&rans);
+ free_sign_bits = 0;
+ }
+
+ uint32_t save_block(FILE *codedfp) // Returns number of bytes.
+ {
+ RansEncFlush(&rans, &ptr);
+ //printf("post-flush = %08x\n", rans);
+
+ uint32_t num_rans_bytes = out_end - ptr;
+ write_varint(num_rans_bytes, codedfp);
+ //fwrite(&num_rans_bytes, 1, 4, codedfp);
+ fwrite(ptr, 1, num_rans_bytes, codedfp);
+
+ //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
+
+ if (free_sign_bits > 0) {
+ *sign_ptr <<= free_sign_bits;
+ }
+
+#if 1
+ uint32_t num_sign_bytes = sign_end - sign_ptr;
+ write_varint((num_sign_bytes << 3) | free_sign_bits, codedfp);
+ fwrite(sign_ptr, 1, num_sign_bytes, codedfp);
+#endif
+
+ clear();
+
+ //printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
+ return num_rans_bytes + num_sign_bytes;
+ //return num_rans_bytes;
+ }
+
+ void encode_coeff(short signed_k)
+ {
+ //printf("encoding coeff %d\n", signed_k);
+ short k = abs(signed_k);
+ if (k >= ESCAPE_LIMIT) {
+ // Put the coefficient as a 1/(2^12) symbol _before_
+ // the 255 coefficient, since the decoder will read the
+ // 255 coefficient first.
+ RansEncPut(&rans, &ptr, k, 1, prob_bits);
+ k = ESCAPE_LIMIT;
+ }
+ if (k != 0) {
+#if 1
+ if (free_sign_bits == 0) {
+ --sign_ptr;
+ *sign_ptr = 0;
+ free_sign_bits = 8;
+ }
+ *sign_ptr <<= 1;
+ *sign_ptr |= (signed_k < 0);
+ --free_sign_bits;
+#else
+ RansEncPut(&rans, &ptr, (k < 0) ? prob_scale / 2 : 0, prob_scale / 2, prob_bits);
+#endif
+ }
+ RansEncPutSymbol(&rans, &ptr, &esyms[k]);
+ }
+
+private:
+ static constexpr size_t out_max_size = 32 << 20; // 32 MB.
+ static constexpr size_t max_num_sign = 1048576; // Way too big. And actually bytes.
+
+ unique_ptr<uint8_t[]> out_buf, sign_buf;
+ uint8_t *out_end, *sign_end;
+ uint8_t *ptr, *sign_ptr;
+ RansState rans;
+ size_t free_sign_bits;
+ RansEncSymbol esyms[NUM_SYMS];
+};
+
+int main(void)
+{
+ FILE *fp = fopen("pic.pgm", "rb");
+ fread(pix, 1, WIDTH * HEIGHT, fp);
+ fclose(fp);
+
+ double sum_sq_err = 0.0;
+
+ for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+ for (unsigned xb = 0; xb < WIDTH; xb += 8) {
+ // Read one block
+ short in[64];
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ in[y * 8 + x] = pix[(yb + y) * WIDTH + (xb + x)];
+ }
+ }
+
+ // FDCT it
+ fdct_int32(in);
+
+ //constexpr int extra_deadzone = 64;
+ constexpr int extra_deadzone = 4;
+
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ short *c = &in[y * 8 + x];
+ *c <<= 3;
+ *c = copysign(std::max(abs(*c) - extra_deadzone, 0), *c);
+ //*c /= std_luminance_quant_tbl[y * 8 + x];
+ *c = (int)(double(*c) / std_luminance_quant_tbl[y * 8 + x]);
+#if 0
+ if (x != 0 || y != 0) {
+ int ss = 1;
+ if (::abs(int(*c)) <= ss) {
+ *c = 0; // eeh
+ } else if (*c > 0) {
+ *c -= ss; // eeh
+ } else {
+ *c += ss; // eeh
+ }
+ }
+#endif
+ }
+ }
+
+ // Store it
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ coeff[(yb + y) * WIDTH + (xb + x)] = in[y * 8 + x];
+ }
+ }
+
+ // and back
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ in[y * 8 + x] *= std_luminance_quant_tbl[y * 8 + x];
+ if (in[y * 8 + x] > 0) {
+ in[y * 8 + x] += extra_deadzone;
+ } else if (in[y * 8 + x] < 0) {
+ in[y * 8 + x] -= extra_deadzone;
+ }
+ in[y * 8 + x] >>= 3;
+ }
+ }
+
+ idct_int32(in);
+
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ int k = in[y * 8 + x];
+ if (k < 0) k = 0;
+ if (k > 255) k = 255;
+ uint8_t *ptr = &pix[(yb + y) * WIDTH + (xb + x)];
+ sum_sq_err += (*ptr - k) * (*ptr - k);
+ *ptr = k;
+ }
+ }
+ }
+ }
+ double mse = sum_sq_err / double(WIDTH * HEIGHT);
+ double psnr_db = 20 * log10(255.0 / sqrt(mse));
+ printf("psnr = %.2f dB\n", psnr_db);
+
+ // DC coefficient pred from the right to left
+ for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+ for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
+ coeff[yb * WIDTH + xb] -= coeff[yb * WIDTH + (xb + 8)];
+ }
+ }
+
+ fp = fopen("reconstructed.pgm", "wb");
+ fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
+ fwrite(pix, 1, WIDTH * HEIGHT, fp);
+ fclose(fp);
+
+ // For each coefficient, make some tables.
+ size_t extra_bits = 0, sign_bits = 0;
+ for (unsigned i = 0; i < 64; ++i) {
+ stats[i].clear();
+ }
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ SymbolStats &s = stats[pick_stats_for(x, y)];
+
+ for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+ for (unsigned xb = 0; xb < WIDTH; xb += 8) {
+ short k = abs(coeff[(yb + y) * WIDTH + (xb + x)]);
+ if (k >= ESCAPE_LIMIT) {
+ //printf("coeff (%d,%d) had value %d\n", y, x, k);
+ k = ESCAPE_LIMIT;
+ extra_bits += 12; // escape this one
+ }
+ //if (y != 0 || x != 0) ++sign_bits;
+ if (k != 0) ++sign_bits;
+ ++s.freqs[k];
+ }
+ }
+ }
+ }
+ for (unsigned i = 0; i < 64; ++i) {
+#if 0
+ printf("coeff %i:", i);
+ for (unsigned j = 0; j <= ESCAPE_LIMIT; ++j) {
+ printf(" %d", stats[i].freqs[j]);
+ }
+ printf("\n");
+#endif
+ stats[i].normalize_freqs(RansEncoder::prob_scale);
+ }
+
+ FILE *codedfp = fopen("coded.dat", "wb");
+ if (codedfp == nullptr) {
+ perror("coded.dat");
+ exit(1);
+ }
+
+ // TODO: varint or something on the freqs
+ for (unsigned i = 0; i < 64; ++i) {
+ if (stats[i].cum_freqs[NUM_SYMS] == 0) {
+ continue;
+ }
+ printf("writing table %d\n", i);
+#if 0
+ for (unsigned j = 0; j <= NUM_SYMS; ++j) {
+ uint16_t freq = stats[i].cum_freqs[j];
+ fwrite(&freq, 1, sizeof(freq), codedfp);
+ printf("%d: %d\n", j, stats[i].freqs[j]);
+ }
+#else
+ // TODO: rather gamma-k or something
+ for (unsigned j = 0; j < NUM_SYMS; ++j) {
+ // write_varint(stats[i].freqs[j], codedfp);
+ }
+#endif
+ }
+
+ RansEncoder rans_encoder;
+
+ size_t tot_bytes = 0;
+ for (unsigned y = 0; y < 8; ++y) {
+ for (unsigned x = 0; x < 8; ++x) {
+ SymbolStats &s = stats[pick_stats_for(x, y)];
+
+ rans_encoder.init_prob(s);
+
+ // need to reverse later
+ rans_encoder.clear();
+ size_t num_bytes = 0;
+ for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+ for (unsigned xb = 0; xb < WIDTH; xb += 8) {
+ int k = coeff[(yb + y) * WIDTH + (xb + x)];
+ //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
+ rans_encoder.encode_coeff(k);
+ }
+ if (yb % 16 == 8) {
+ num_bytes += rans_encoder.save_block(codedfp);
+ }
+ }
+ if (HEIGHT % 16 != 0) {
+ num_bytes += rans_encoder.save_block(codedfp);
+ }
+ tot_bytes += num_bytes;
+ printf("coeff %d: %ld bytes\n", y * 8 + x, num_bytes);
+ }
+ }
+ printf("%ld bytes + %ld sign bits (%ld) + %ld escape bits (%ld) = %ld total bytes\n",
+ tot_bytes - sign_bits / 8 - extra_bits / 8,
+ sign_bits,
+ sign_bits / 8,
+ extra_bits,
+ extra_bits / 8,
+ tot_bytes);
+}