]> git.sesse.net Git - narabu/blob - ryg_rans/renormalize.cpp
f748de9e6ca395ffb75c03e1e4410c6b90ed8417
[narabu] / ryg_rans / renormalize.cpp
1 // Copyright (c) 2017, Steinar H. Gunderson
2 // All rights reserved.
3 // 
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted.
6 // 
7 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
8 // “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
9 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
10 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
11 // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
12 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
13 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
14 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
15 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
16 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
17 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
18
19 #include "renormalize.h"
20
21 #include <assert.h>
22 #include <math.h>
23
24 #include <unordered_map>
25 #include <map>
26 #include <memory>
27 #include <utility>
28
29 using std::equal_to;
30 using std::hash;
31 using std::max;
32 using std::min;
33 using std::make_pair;
34 using std::pair;
35 using std::unique_ptr;
36 using std::unordered_map;
37
38 namespace {
39
40 struct OptimalChoice {
41         double cost;  // In bits.
42         uint32_t chosen_freq;
43 };
44 struct CacheKey {
45         int num_syms;
46         int available_slots;
47
48         bool operator== (const CacheKey &other) const
49         {
50                 return num_syms == other.num_syms && available_slots == other.available_slots;
51         }
52 };
53 struct HashCacheKey {
54         size_t operator() (const CacheKey &key) const
55         {
56                 return hash<int64_t>()((uint64_t(key.available_slots) << 32) | key.num_syms);
57         }
58 };
59 using CacheMap = unordered_map<CacheKey, OptimalChoice, HashCacheKey>;
60
61 // Find, recursively, the optimal cost of encoding the symbols [0, num_syms),
62 // assuming an optimal distribution of those symbols to "available_slots".
63 // The cache is used for memoization, and also to remember the best choice.
64 // No frequency can be zero.
65 //
66 // Returns HUGE_VAL if there's no legal mapping.
67 double FindOptimalCost(uint32_t *cum_freqs, int num_syms, int available_slots, const double *log2cache, CacheMap *cache)
68 {
69         static int k = 0;
70         if (num_syms == 0) {
71                 // Encoding zero symbols needs zero bits.
72                 return 0.0;
73         }
74         if (num_syms > available_slots) {
75                 // Every (non-zero-frequency) symbol needs at least one slot.
76                 return HUGE_VAL;
77         }
78         if (num_syms == 1) {
79                 return cum_freqs[1] * log2cache[available_slots];
80         }
81
82         CacheKey cache_key{num_syms, available_slots};
83         auto insert_result = cache->insert(make_pair(cache_key, OptimalChoice()));
84         if (!insert_result.second) {
85                 // There was already an item in the cache, so return it.
86                 return insert_result.first->second.cost;
87         }
88
89         // Minimize the number of total bits spent as a function of how many slots
90         // we assign to this symbol.
91         //
92         // The cost function is convex (at least in practice; I suppose also in
93         // theory because it's the sum of an increasing and a decreasing function?).
94         // Find a reasonable guess and see in what direction the function is decreasing,
95         // then iterate until we either hit the end or we start increasing again.
96         //
97         // Since the function is a sum of log() terms, it is differentiable, and we
98         // could in theory use this; however, it doesn't seem to be worth the complexity.
99         uint32_t freq = cum_freqs[num_syms] - cum_freqs[num_syms - 1];
100         assert(freq > 0);
101         double guess = lrint(available_slots * double(freq) / cum_freqs[num_syms]);
102
103         int x1 = max<int>(floor(guess), 1);
104         int x2 = x1 + 1;
105
106         double cost1 = freq * log2cache[x1] + FindOptimalCost(cum_freqs, num_syms - 1, available_slots - x1, log2cache, cache);
107         double cost2 = freq * log2cache[x2] + FindOptimalCost(cum_freqs, num_syms - 1, available_slots - x2, log2cache, cache);
108
109         int x;
110         int direction;  // -1 or +1.
111         double best_cost;
112         if (isinf(cost1) && isinf(cost2)) {
113                 // The cost isn't infinite due to the first term, so we need to go downwards
114                 // to give the second term more room to breathe.
115                 x = x1;
116                 best_cost = cost1;
117                 direction = -1;
118         } else if (cost1 < cost2) {
119                 x = x1;
120                 best_cost = cost1;
121                 direction = -1;
122         } else {
123                 x = x2;
124                 best_cost = cost2;
125                 direction = 1;
126         }
127         int best_choice = x;
128
129         for ( ;; ) {
130                 x += direction;
131                 if (x == 0 || x > available_slots) {
132                         // We hit the end; we can't assign zero slots to this symbol,
133                         // and we can't assign more slots than we have. This extreme
134                         // is the best choice.
135                         break;
136                 }
137                 double cost = freq * log2cache[x] + FindOptimalCost(cum_freqs, num_syms - 1, available_slots - x, log2cache, cache);
138                 if (cost > best_cost) {
139                         // The cost started increasing again, so we've found the optimal choice.
140                         break;
141                 }
142                 best_choice = x;
143                 best_cost = cost;
144         }
145         insert_result.first->second.cost = best_cost;
146         insert_result.first->second.chosen_freq = best_choice;
147         return best_cost;
148 }
149
150 }  // namespace
151
152 void OptimalRenormalize(uint32_t *cum_freqs, uint32_t num_syms, uint32_t target_total)
153 {
154         // First remove all symbols that have a zero frequency; they tend to
155         // complicate the analysis. We'll put them back afterwards.
156         unique_ptr<uint32_t[]> remapped_cum_freqs(new uint32_t[num_syms + 1]);
157         unique_ptr<uint32_t[]> mapping(new uint32_t[num_syms + 1]);
158
159         uint32_t new_num_syms = 0;
160         remapped_cum_freqs[0] = 0;
161         for (uint32_t i = 0; i < num_syms; ++i) {
162                 if (cum_freqs[i + 1] == cum_freqs[i]) {
163                         continue;
164                 }
165                 mapping[new_num_syms] = i;
166                 remapped_cum_freqs[new_num_syms + 1] = cum_freqs[i + 1];
167                 new_num_syms++;
168         }
169
170         // Calculate the cost of encoding a symbol with frequency f/target_total.
171         // We call log2() quite a lot, so it's best to cache it once at the start.
172         unique_ptr<double[]> log2cache(new double[target_total + 1]);
173         for (uint32_t i = 0; i <= target_total; ++i) {
174                 log2cache[i] = -log2(i * (1.0 / target_total));
175         }
176
177         CacheMap cache;
178         FindOptimalCost(remapped_cum_freqs.get(), new_num_syms, target_total, log2cache.get(), &cache);
179
180         for (uint32_t i = 0; i <= num_syms; ++i) {
181                 cum_freqs[i] = 0;
182         }
183
184         // Reconstruct the optimal choices from the cache. Note that during this,
185         // cum_freq contains frequencies, _not_ cumulative frequencies.
186         int available_slots = target_total;
187         for (int symbol_idx = new_num_syms; symbol_idx --> 0; ) {  // :-)
188                 uint32_t freq;
189                 if (symbol_idx == 0) {
190                         // Last symbol isn't in the cache, but it's obvious what the answer is.
191                         freq = available_slots;
192                 } else {
193                         CacheKey cache_key{symbol_idx + 1, available_slots};
194                         assert(cache.count(cache_key));
195                         freq = cache[cache_key].chosen_freq;
196                 }
197                 cum_freqs[mapping[symbol_idx]] = freq;
198                 assert(available_slots >= freq);
199                 available_slots -= freq;
200         }
201
202         // Convert the frequencies back to cumulative frequencies.
203         uint32_t total = 0;
204         for (uint32_t i = 0; i <= num_syms; ++i) {
205                 uint32_t freq = cum_freqs[i];
206                 cum_freqs[i] = total;
207                 total += freq;
208         }
209 }