]> git.sesse.net Git - narabu/blobdiff - ryg_rans/main.cpp
Add support for optimal renormalization.
[narabu] / ryg_rans / main.cpp
index a9b9a6eaceca81c6866446b0f8422bd079793f02..becf81b394a4c998da4b21ebd69269f5b56c9963 100644 (file)
@@ -7,6 +7,7 @@
 #include <assert.h>
 
 #include "rans_byte.h"
+#include "renormalize.h"
 
 // This is just the sample program. All the meat is in rans_byte.h.
 
@@ -77,43 +78,8 @@ void SymbolStats::normalize_freqs(uint32_t target_total)
     assert(target_total >= 256);
     
     calc_cum_freqs();
-    uint32_t cur_total = cum_freqs[256];
-    
-    // resample distribution based on cumulative freqs
-    for (int i = 1; i <= 256; i++)
-        cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
-
-    // if we nuked any non-0 frequency symbol to 0, we need to steal
-    // the range to make the frequency nonzero from elsewhere.
-    //
-    // this is not at all optimal, i'm just doing the first thing that comes to mind.
-    for (int i=0; i < 256; i++) {
-        if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
-            // symbol i was set to zero freq
-
-            // find best symbol to steal frequency from (try to steal from low-freq ones)
-            uint32_t best_freq = ~0u;
-            int best_steal = -1;
-            for (int j=0; j < 256; j++) {
-                uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
-                if (freq > 1 && freq < best_freq) {
-                    best_freq = freq;
-                    best_steal = j;
-                }
-            }
-            assert(best_steal != -1);
-
-            // and steal from it!
-            if (best_steal < i) {
-                for (int j = best_steal + 1; j <= i; j++)
-                    cum_freqs[j]--;
-            } else {
-                assert(best_steal > i);
-                for (int j = i + 1; j <= best_steal; j++)
-                    cum_freqs[j]++;
-            }
-        }
-    }
+
+    OptimalRenormalize(cum_freqs, 256, target_total);
 
     // calculate updated freqs and make sure we didn't screw anything up
     assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);