Add a Viterbi cleaner.
[c64tapwav] / cleaner.cpp
1 // "Cleans" tapes by finding the most likely path through a hidden Markov model (HMM),
2 // using the Viterbi algorithm. Usually works much worse than e.g. TAPclean;
3 // you have been warned :-)
4 //
5 // Takes in a cycles.plot (from decode) on stdin.
6
7 #include <stdio.h>
8 #include <string.h>
9 #include <math.h>
10 #include <assert.h>
11 #include <vector>
12 #include <map>
13 #include <algorithm>
14 #include <memory>
15
16 #include "tap.h"
17
// Beam width: after each input pulse, main() keeps at most this many of the
// cheapest hypotheses and releases the rest.
constexpr int STATE_BREADTH = 75;

// Length of a ROM-loader sync pulse.
// From SLC, reportedly from measuring the CIA chip.
constexpr double SYNC_REFERENCE = 378.0;

// ROM loader pulse lengths (TAP byte value * 8), from
// http://c64tapes.org/dokuwiki/doku.php?id=loaders:rom_loader.
constexpr int ROM_SHORT_PULSE_LENGTH = 0x30 * 8;   // 384
constexpr int ROM_MEDIUM_PULSE_LENGTH = 0x42 * 8;  // 528
constexpr int ROM_LONG_PULSE_LENGTH = 0x56 * 8;    // 688

// Novaload pulse lengths, taken from the TAPclean sources.
constexpr int NOVA_SHORT_PULSE_LENGTH = 288;
constexpr int NOVA_LONG_PULSE_LENGTH = 688;

using namespace std;
33
enum State { STATE_ROM_SYNC, STATE_ROM_LOADER, STATE_NOVA };
enum ROMPulseType { ROM_PULSE_TYPE_SHORT = 0, ROM_PULSE_MEDIUM, ROM_PULSE_LONG };
enum NOVAPulseType { NOVA_PULSE_SHORT = 0, NOVA_PULSE_LONG, NOVA_PULSE_NONE };

// Number of Node objects currently alive (incremented in the constructor,
// decremented when unref() destroys a node); main() uses it for RAM accounting.
static int total_alloc = 0;

// One node of the Viterbi trellis: a model state, the accumulated cost of the
// best path reaching it, the pulse to emit, and a back-pointer to the node it
// was extended from.
//
// Nodes are intrusively reference counted. A new node starts with one
// reference owned by its creator, and every node holds one reference on its
// predecessor, so a whole path stays alive as long as its last node does.
// Release with unref(); the destructor is private so nodes are never deleted
// directly.
struct Node {
        Node() : cost(0.0), prev_node(NULL), emit(0), refcount(1) { ++total_alloc; }

        // Takes one additional reference.
        void ref() const { ++refcount; }

        // Drops one reference. When a node's count reaches zero it is
        // destroyed and its reference on the predecessor is dropped in turn.
        // Walks the chain iteratively so that freeing a very long path cannot
        // overflow the stack.
        void unref() const
        {
                const Node *node = this;
                while (node != NULL && --node->refcount == 0) {
                        const Node *prev = node->prev_node;
                        delete node;
                        --total_alloc;
                        node = prev;
                }
        }

        // Makes |node| (may be NULL) this node's predecessor, releasing the
        // previous one. The new reference is taken before the old one is
        // dropped, so re-setting the current predecessor cannot free it.
        void set_prev_node(const Node *node)
        {
                if (node) {
                        node->ref();
                }
                if (prev_node) {
                        prev_node->unref();
                }
                prev_node = node;
        }
        const Node *get_prev_node() const { return prev_node; }

        // Viterbi information.
        double cost;  // Accumulated cost of the path ending in this node.
private:
        const Node* prev_node;  // Owned reference; NULL for the start node.
public:
        int emit;  // Pulse length to output for this step; main() skips values <= 0.

        // State information.
        State state;

        // Which member is meaningful depends on |state|.
        union {
                // For STATE_ROM_SYNC only.
                struct {
                        int num_pulses_left;
                };

                // For STATE_ROM_LOADER.
                struct {
                        ROMPulseType last_pulse_type;
                        bool first_in_pair;  // If the _last_ pulse was first in the pair.
                };

                // For STATE_NOVA.
                struct {
                        NOVAPulseType last_nova_pulse_type;
                };
        };

private:
        mutable int refcount;
        ~Node() {}

        // Nodes are shared through reference counts; copying would duplicate
        // the predecessor reference without taking it.
        Node(const Node &) = delete;
        Node &operator=(const Node &) = delete;
};
80
81 struct StateLessThanComparator {
82         bool operator() (const Node *a, const Node *b) const
83         {
84                 if (a->state != b->state)
85                         return (a->state < b->state);
86
87                 if (a->state == STATE_ROM_SYNC) {
88                         if (a->num_pulses_left != b->num_pulses_left)
89                                 return (a->num_pulses_left < b->num_pulses_left);
90                 }
91
92                 if (a->state == STATE_ROM_LOADER) {
93                         if (a->last_pulse_type != b->last_pulse_type)
94                                 return (a->last_pulse_type < b->last_pulse_type);
95                         if (a->first_in_pair != b->first_in_pair)
96                                 return (a->first_in_pair < b->first_in_pair);
97                 }
98                 
99                 if (a->state == STATE_NOVA) {
100                         if (a->last_nova_pulse_type != b->last_nova_pulse_type)
101                                 return (a->last_nova_pulse_type < b->last_nova_pulse_type);
102                 }
103
104                 return false;
105         }
106 };
107
108 struct StateEqualsComparator {
109         StateLessThanComparator lt;
110
111         bool operator() (const Node *a, const Node *b) const
112         {
113                 return !(lt(a, b) || lt(b, a));
114         }
115 };
116
117 struct CompareByCost {
118         bool operator() (const Node *a, const Node *b) const
119         {
120                 return a->cost < b->cost;
121         }
122 };
123
// Cost of explaining a measured pulse as a pulse of length |reference|.
// The dominant term is the squared relative mismatch between the observed
// ratio to the previous pulse (|ratio|) and the expected one
// (|reference_ratio|), measured symmetrically (larger / smaller - 1).
// A small term proportional to the absolute length error breaks ties.
double penalty(double length, double reference, double ratio, double reference_ratio)
{
        double relative_error;
        if (ratio > reference_ratio) {
                relative_error = ratio / reference_ratio;
        } else {
                relative_error = reference_ratio / ratio;
        }
        relative_error -= 1.0;

        const double absolute_error = fabs(length - reference);
        return relative_error * relative_error + absolute_error * 1e-4;
}
130
131 void possibly_add_state(Node *node, vector<Node *> *next_states)
132 {
133         next_states->push_back(node);
134 }
135                 
136 void extend_state(const Node *prev, double last_length, double length, double ratio, vector<Node *> *next_states)
137 {
138         // If this pulse is really long, it means we could transition into another type.
139         if (length > 2000.0) {
140                 double cost = 0.0;
141                 if (prev->state == STATE_ROM_SYNC) {
142                         // We don't really want to jump directly from sync to Novaload sync...
143                         cost = min(fabs(0x4f - prev->num_pulses_left), fabs(27136 - prev->num_pulses_left));
144                 } else if (prev->state == STATE_ROM_LOADER) {
145                         // Jumping in the middle of a bit is bad, too
146                         cost = 50.0;
147                 }
148
149                 // OK, so it could be Nova
150                 {
151                         Node *n = new Node;
152                         n->cost = prev->cost + cost; // + 10.0;
153                         n->set_prev_node(prev);
154                         n->emit = length;
155
156                         n->state = STATE_NOVA;
157                         n->last_nova_pulse_type = NOVA_PULSE_NONE;
158                         possibly_add_state(n, next_states);
159                 }
160                 // or maybe the ROM loader
161                 {
162                         Node *n = new Node;
163                         n->cost = prev->cost + cost; // + 10.0;
164                         n->set_prev_node(prev);
165                         n->emit = length;
166
167                         n->state = STATE_ROM_SYNC;
168                         n->num_pulses_left = 27136;
169                         possibly_add_state(n, next_states);
170                 }
171                 return;  // hack
172         }
173
174         // If in STATE_ROM_SYNC, it's possible that this was another sync pulse.
175         if (prev->state == STATE_ROM_SYNC) {
176                 Node *n = new Node;
177                 n->cost = prev->cost + penalty(length, SYNC_REFERENCE, ratio, 1.0);
178                 n->set_prev_node(prev);
179                 n->emit = SYNC_REFERENCE;
180
181                 n->state = STATE_ROM_SYNC;      
182                 n->num_pulses_left = prev->num_pulses_left - 1;
183                 possibly_add_state(n, next_states);
184         }
185
186         // If in STATE_ROM_SYNC, maybe we transitioned into ROM_LOADER.
187         // That always starts with a (L,M).
188         if (prev->state == STATE_ROM_SYNC) {
189                 Node *n = new Node;
190                 n->cost = prev->cost +
191                         penalty(length, ROM_LONG_PULSE_LENGTH, ratio, ROM_LONG_PULSE_LENGTH / SYNC_REFERENCE) +
192                         0.1 * fabs(prev->num_pulses_left);
193                 n->set_prev_node(prev);
194                 n->emit = ROM_LONG_PULSE_LENGTH;
195
196                 n->state = STATE_ROM_LOADER;
197                 n->last_pulse_type = ROM_PULSE_LONG;
198                 n->first_in_pair = true;
199                 possibly_add_state(n, next_states);
200         }
201         
202         // If in STATE_ROM_SYNC, we could also seemingly transition into Nova data.
203         if (prev->state == STATE_ROM_SYNC) {
204                 Node *n = new Node;
205                 n->cost = prev->cost +
206                         penalty(length, NOVA_SHORT_PULSE_LENGTH, ratio, NOVA_SHORT_PULSE_LENGTH / SYNC_REFERENCE) +
207                         0.1 * fabs(prev->num_pulses_left);
208                 n->set_prev_node(prev);
209                 n->emit = NOVA_SHORT_PULSE_LENGTH;
210
211                 n->state = STATE_NOVA;
212                 n->last_nova_pulse_type = NOVA_PULSE_NONE;
213                 possibly_add_state(n, next_states);
214         }
215
216         // If in ROM_LOADER, we could have short, medium or long pulses.
217         if (prev->state == STATE_ROM_LOADER) {
218                 static const double lengths[] = { ROM_SHORT_PULSE_LENGTH, ROM_MEDIUM_PULSE_LENGTH, ROM_LONG_PULSE_LENGTH };
219                 for (int pulse_type = ROM_PULSE_TYPE_SHORT; pulse_type <= ROM_PULSE_LONG; ++pulse_type) {
220
221                         // Filter illegal ROM loader pairs.
222                         if (prev->last_pulse_type == ROM_PULSE_LONG && pulse_type == ROM_PULSE_LONG) {
223                                 continue;
224                         }
225                         if (prev->first_in_pair) {
226                                 if (prev->last_pulse_type == ROM_PULSE_TYPE_SHORT && pulse_type != ROM_PULSE_MEDIUM) {
227                                         continue;
228                                 }
229                                 if (prev->last_pulse_type == ROM_PULSE_MEDIUM && pulse_type != ROM_PULSE_TYPE_SHORT) {
230                                         continue;
231                                 }
232                                 if (pulse_type == ROM_PULSE_LONG) {
233                                         continue;
234                                 }
235                         }
236
237                         Node *n = new Node;
238                         n->cost = prev->cost +
239                                 penalty(length, lengths[pulse_type], ratio, lengths[pulse_type] / lengths[prev->last_pulse_type]);
240                         n->set_prev_node(prev);
241                         n->emit = lengths[pulse_type];
242
243                         if (prev->first_in_pair && (prev->last_pulse_type == ROM_PULSE_LONG && pulse_type == ROM_PULSE_TYPE_SHORT)) {
244                                 // (L,S) = end-of-data-marker
245                                 n->state = STATE_ROM_SYNC;
246                                 n->num_pulses_left = 0x4f;  // http://c64tapes.org/dokuwiki/doku.php?id=loaders:rom_loader
247                         } else {
248                                 n->state = STATE_ROM_LOADER;
249                                 n->last_pulse_type = ROMPulseType(pulse_type);
250                                 n->first_in_pair = !prev->first_in_pair;
251                         }
252                         possibly_add_state(n, next_states);
253                 }               
254         }
255
256         // If in STATE_NOVA, we only have long and short pulses.
257         if (prev->state == STATE_NOVA) {  // hack
258                 static const double lengths[] = { NOVA_SHORT_PULSE_LENGTH, NOVA_LONG_PULSE_LENGTH };
259                 
260                 for (int pulse_type = NOVA_PULSE_SHORT; pulse_type <= NOVA_PULSE_LONG; ++pulse_type) {
261                         
262
263                         Node *n = new Node;
264                         if (prev->last_nova_pulse_type == NOVA_PULSE_NONE) {
265                                 n->cost = prev->cost +
266                                         penalty(length, lengths[pulse_type], length / NOVA_SHORT_PULSE_LENGTH, lengths[pulse_type] / NOVA_SHORT_PULSE_LENGTH);
267                         } else {
268                                 n->cost = prev->cost +
269                                         penalty(length, lengths[pulse_type], ratio, lengths[pulse_type] / lengths[prev->last_nova_pulse_type]);
270                         }
271                         n->set_prev_node(prev);
272                         n->emit = lengths[pulse_type];
273
274                         n->state = STATE_NOVA;
275                         n->last_nova_pulse_type = NOVAPulseType(pulse_type);
276                         possibly_add_state(n, next_states);
277                 }
278         }
279
280         // TODO: Other loader types
281 }
282
283 int main(int argc, char **argv)
284 {
285         Node *start = new Node;
286         start->cost = 0.0;
287         start->emit = -1;
288         start->state = STATE_ROM_SYNC;
289         start->num_pulses_left = 27136;
290
291         vector<Node *> states;
292         states.push_back(start);
293
294         double last_length = SYNC_REFERENCE;
295
296         int pulse_num = 0;
297         int max_total_alloc = 0;
298         for ( ;; ) {
299                 double time, length;
300                 if (scanf("%lf %lf", &time, &length) != 2) {
301                         break;
302                 }
303
304                 ++pulse_num;
305
306                 if (pulse_num % 1000 == 0) {
307                         fprintf(stderr, "\rProcessing pulses... %d", pulse_num);
308                 }       
309
310                 max_total_alloc = max(total_alloc, max_total_alloc);
311                 if (total_alloc > 20000000) {
312                         printf("More than 20M states reached (out of RAM); aborting at pulse %d.\n", pulse_num);
313                         break;
314                 }
315
316                 //if (length > 2000) {
317                 //      break;
318                 //}
319
320                 double ratio = length / last_length;
321         
322                 vector<Node *> next_states;
323                 for (unsigned i = 0; i < states.size(); ++i) {
324                         extend_state(states[i], last_length, length, ratio, &next_states);
325
326                         // We no longer need this state; if it doesn't have any children,
327                         // it can go away.
328                         states[i]->unref();
329                 }
330                 states.clear();
331
332                 // Remove duplicates, tie-breaking by score.
333                 sort(next_states.begin(), next_states.end(), CompareByCost());
334                 stable_sort(next_states.begin(), next_states.end(), StateLessThanComparator());
335
336                 // unique and move in one step. Do not use std::unique(),
337                 // it has very wrong behavior for pointers!
338                 for (unsigned i = 0; i < next_states.size(); ++i) {
339                         if (i > 0 && StateEqualsComparator()(next_states[i], states.back())) {
340                                 next_states[i]->unref();
341                         } else {
342                                 states.push_back(next_states[i]);
343                         }
344                 }
345                 assert(!states.empty());
346                         
347                 // Prune unlikely next_states to save time and memory.
348                 if (states.size() >= STATE_BREADTH) {
349                         sort(states.begin(), states.end(), CompareByCost());
350                         for (unsigned i = STATE_BREADTH; i < states.size(); ++i) {
351                                 states[i]->unref();
352                         }
353                         states.resize(STATE_BREADTH);
354                 }
355                 last_length = length;
356         }
357
358         // Find the best final node.
359         const Node *best_node = NULL;
360         for (unsigned i = 0; i < states.size(); ++i) {
361                 if (states[i]->state == STATE_ROM_SYNC) {
362                         states[i]->cost += 0.1 * fabs(states[i]->num_pulses_left);
363                 }
364
365                 if (best_node == NULL || states[i]->cost < best_node->cost) {
366                         best_node = states[i];
367                 }
368         }
369
370         fprintf(stderr, "\rTotal cost is %f. Peak RAM usage %.2f MB, pluss malloc overhead.\n",
371                 best_node->cost, sizeof(Node) * max_total_alloc / 1048576.0);
372
373         // Backtrack.
374         vector<const Node *> cleaned;
375         while (best_node != NULL) {
376                 cleaned.push_back(best_node);
377                 best_node = best_node->get_prev_node();
378         }
379
380         reverse(cleaned.begin(), cleaned.end());
381
382         for (unsigned i = 0; i < cleaned.size(); ++i) {
383                 //fprintf(stderr, "%d: state %d emit %d cost %f\n", i, cleaned[i]->state, cleaned[i]->emit, cleaned[i]->cost);
384                 if (cleaned[i]->emit <= 0) {
385                         continue;
386                 }
387                 printf("%d\n", cleaned[i]->emit);
388         }
389
390         // output TAP file      
391         FILE *fp = fopen("cleaned.tap", "wb");
392         std::vector<char> tap_data;
393         for (unsigned i = 0; i < cleaned.size(); ++i) {
394                 if (cleaned[i]->emit <= 0) {
395                         continue;
396                 }
397                 int len = lrintf(cleaned[i]->emit / TAP_RESOLUTION);
398                 if (len <= 255) {
399                         tap_data.push_back(len);
400                 } else {
401                         int overflow_len = lrintf(cleaned[i]->emit);
402                         tap_data.push_back(0);
403                         tap_data.push_back(overflow_len & 0xff);
404                         tap_data.push_back((overflow_len >> 8) & 0xff);
405                         tap_data.push_back(overflow_len >> 16);
406                 }
407         }
408
409         tap_header hdr;
410         memcpy(hdr.identifier, "C64-TAPE-RAW", 12);
411         hdr.version = 1;
412         hdr.reserved[0] = hdr.reserved[1] = hdr.reserved[2] = 0;
413         hdr.data_len = tap_data.size();
414
415         fwrite(&hdr, sizeof(hdr), 1, fp);
416         fwrite(tap_data.data(), tap_data.size(), 1, fp);
417         fclose(fp);
418 }