Merge branch 'master' of /srv/git.sesse.net/www/c64tapwav
[c64tapwav] / cleaner.cpp
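// "Cleans" tapes by finding the most likely path through a hidden Markov model (HMM),
// using the Viterbi algorithm. Usually works much worse than e.g. TAPclean;
// you have been warned :-)
//
// Takes in a cycles.plot (from decode) on stdin, i.e. one "time length" pair per pulse.
// Writes the cleaned pulse train to cleaned.tap, and prints every input pulse together
// with the state code it was decoded as to stdout.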
6
7 #include <stdio.h>
8 #include <string.h>
9 #include <math.h>
10 #include <assert.h>
11 #include <vector>
12 #include <map>
13 #include <algorithm>
14 #include <memory>
15
16 #include "tap.h"
17
// Number of hypotheses kept alive after pruning at each pulse (see main()).
#define STATE_BREADTH 75

// Nominal ROM sync pulse length, in cycles.
// From SLC, reportedly from measuring the CIA chip.
#define SYNC_REFERENCE 378.0

// Nominal ROM loader pulse lengths, in cycles.
// From http://c64tapes.org/dokuwiki/doku.php?id=loaders:rom_loader.
#define ROM_SHORT_PULSE_LENGTH (0x30 * 8)  // 384
#define ROM_MEDIUM_PULSE_LENGTH (0x42 * 8)  // 528
#define ROM_LONG_PULSE_LENGTH (0x56 * 8)  // 688

// Nominal Novaload pulse lengths, in cycles.
// From TAPclean sources.
#define NOVA_SHORT_PULSE_LENGTH 288
#define NOVA_LONG_PULSE_LENGTH 688

using namespace std;

// Which loader (or a pause) a hypothesis believes the current pulse belongs to.
enum State { STATE_ROM_SYNC, STATE_ROM_LOADER, STATE_NOVA, STATE_PAUSE };
// Pulse classes; the values double as indexes into rom_lengths[] / nova_lengths[].
enum ROMPulseType { ROM_PULSE_TYPE_SHORT = 0, ROM_PULSE_MEDIUM, ROM_PULSE_LONG };
enum NOVAPulseType { NOVA_PULSE_SHORT = 0, NOVA_PULSE_LONG };

// Number of currently live Node objects (incremented in Node(), decremented
// when unref() destroys one). Used for the RAM limit and usage report.
static int total_alloc = 0;

// Nominal pulse lengths, indexed by ROMPulseType / NOVAPulseType.
static const double rom_lengths[] = { ROM_SHORT_PULSE_LENGTH, ROM_MEDIUM_PULSE_LENGTH, ROM_LONG_PULSE_LENGTH };
static const double nova_lengths[] = { NOVA_SHORT_PULSE_LENGTH, NOVA_LONG_PULSE_LENGTH };

// One hypothesis in the Viterbi search: the state the tape is believed to be
// in after some pulse, the accumulated cost of the path reaching it, and a
// back-pointer to the previous hypothesis for backtracking.
//
// Nodes are reference-counted because many hypotheses share a common history.
// A new node starts with refcount 1 (owned by its creator), and each node holds
// one reference to its predecessor. The destructor is private; nodes are only
// destroyed through unref().
struct Node {
        Node() : prev_node(NULL), refcount(1) { ++total_alloc; }

        void ref() const { ++refcount; }

        // Drop one reference. When the count reaches zero, the node is deleted and
        // its reference to the predecessor is dropped in turn. This is done in a
        // loop rather than by recursion, so freeing a long abandoned path cannot
        // overflow the stack.
        void unref() const
        {
                const Node *node = this;
                while (node != NULL && --node->refcount == 0) {
                        const Node *prev = node->prev_node;
                        delete node;
                        --total_alloc;
                        node = prev;
                }
        }

        // Point this node at `node` (may be NULL), taking a reference to it.
        // The new reference is taken before the old one is released, so that
        // re-setting the same predecessor can never free it in between.
        void set_prev_node(const Node *node)
        {
                if (node != NULL) {
                        node->ref();
                }
                if (prev_node != NULL) {
                        prev_node->unref();
                }
                prev_node = node;
        }

        const Node *get_prev_node() const { return prev_node; }

        // Viterbi information.
        double cost;
private:
        const Node* prev_node;
public:
        // Pulse length (in cycles) this hypothesis emits into the cleaned tape;
        // the start node uses -1 and is skipped on output.
        int emit;

        // State information.
        State state;

        union {
                // For STATE_ROM_SYNC only.
                struct {
                        int num_pulses_left;
                };

                // For STATE_ROM_LOADER.
                struct {
                        ROMPulseType last_pulse_type;
                        bool first_in_pair;  // If the _last_ pulse was first in the pair.
                };

                // For STATE_NOVA.
                struct {
                        NOVAPulseType last_nova_pulse_type;
                };
        };

private:
        mutable int refcount;
        ~Node() {}
        Node(const Node &) = delete;
};
83
84 struct StateLessThanComparator {
85         bool operator() (const Node *a, const Node *b) const
86         {
87                 if (a->state != b->state)
88                         return (a->state < b->state);
89
90                 if (a->state == STATE_ROM_SYNC) {
91                         if (a->num_pulses_left != b->num_pulses_left)
92                                 return (a->num_pulses_left < b->num_pulses_left);
93                 }
94
95                 if (a->state == STATE_ROM_LOADER) {
96                         if (a->last_pulse_type != b->last_pulse_type)
97                                 return (a->last_pulse_type < b->last_pulse_type);
98                         if (a->first_in_pair != b->first_in_pair)
99                                 return (a->first_in_pair < b->first_in_pair);
100                 }
101                 
102                 if (a->state == STATE_NOVA) {
103                         if (a->last_nova_pulse_type != b->last_nova_pulse_type)
104                                 return (a->last_nova_pulse_type < b->last_nova_pulse_type);
105                 }
106
107                 return false;
108         }
109 };
110
111 struct StateEqualsComparator {
112         StateLessThanComparator lt;
113
114         bool operator() (const Node *a, const Node *b) const
115         {
116                 return !(lt(a, b) || lt(b, a));
117         }
118 };
119
120 struct CompareByCost {
121         bool operator() (const Node *a, const Node *b) const
122         {
123                 return a->cost < b->cost;
124         }
125 };
126
// Cost of explaining a pulse of `length` cycles as one whose nominal length is
// `reference`, when its length relative to the previous pulse was `ratio` and
// the model expects `reference_ratio`.
// The ratio error is measured multiplicatively (too long by a factor k costs the
// same as too short by k) and squared. The absolute distance is only a weak
// (1e-4 per cycle) tie-breaker.
double penalty(double length, double reference, double ratio, double reference_ratio)
{
        double relative_error;
        if (ratio > reference_ratio) {
                relative_error = ratio / reference_ratio - 1.0;
        } else {
                relative_error = reference_ratio / ratio - 1.0;
        }
        const double absolute_error = fabs(length - reference);
        return relative_error * relative_error + 1e-4 * absolute_error;
}
133
// Penalty for leaving a ROM sync run while the counter still reads
// `num_pulses_left` (the counter starts at 27136 after a pause). This is the
// distance between sqrt(27136) and sqrt(27136 - num_pulses_left), so it is
// zero when the counter has run down to exactly 0.
double pulse_penalty(int num_pulses_left)
{
        const int pulses_counted = 27136 - num_pulses_left;
        return fabs(sqrt(27136.0) - sqrt(pulses_counted));
}
140
141 void possibly_add_state(Node *node, vector<Node *> *next_states)
142 {
143         assert(node->cost >= node->get_prev_node()->cost);
144         next_states->push_back(node);
145 }
146                 
147 void extend_state(const Node *prev, double last_length, double length, double ratio, vector<Node *> *next_states)
148 {
149         // If this pulse is  long, it means we could transition into another type.
150         if (length > 700.0) {
151                 double cost = 0.0;
152                 if (prev->state == STATE_ROM_SYNC) {
153                         // We don't really want to jump directly from sync to Novaload sync...
154                         cost = pulse_penalty(prev->num_pulses_left);
155                 } else if (prev->state == STATE_ROM_LOADER) {
156                         // Jumping in the middle of a bit is bad, too
157                         cost = 50.0;
158                 } else if (prev->state == STATE_NOVA) {
159                         // is this too close to a long bit?
160                         double prev_ref = nova_lengths[prev->last_nova_pulse_type];
161                         double long_pulse_would_be = last_length * (NOVA_LONG_PULSE_LENGTH / prev_ref);
162                         double ratio_penalty = max(2.5 - length / long_pulse_would_be, 0.0);
163                         double distance_penalty = max(1000.0 - length, 0.0);
164                         cost = ratio_penalty * ratio_penalty + 1e-4 * distance_penalty;
165                 } else {
166                         // Make sure that marking things as pauses are not _that_ painless...
167                         cost = 1e-4 * max(1000.0 - length, 0.0);
168                 }
169
170                 Node *n = new Node;
171                 n->cost = prev->cost + cost; // + 10.0;
172                 n->set_prev_node(prev);
173                 n->emit = length;
174
175                 n->state = STATE_PAUSE;
176                 possibly_add_state(n, next_states);
177         }
178
179         // If in STATE_ROM_SYNC, it's possible that this was another sync pulse.
180         if (prev->state == STATE_ROM_SYNC || prev->state == STATE_PAUSE) {
181                 Node *n = new Node;
182                 if (prev->state == STATE_PAUSE) {
183                         n->cost = prev->cost + penalty(length, SYNC_REFERENCE, 1.0, 1.0);
184                 } else {
185                         n->cost = prev->cost + penalty(length, SYNC_REFERENCE, ratio, 1.0);
186                 }
187                 n->set_prev_node(prev);
188                 n->emit = SYNC_REFERENCE;
189
190                 n->state = STATE_ROM_SYNC;      
191                 if (prev->state == STATE_PAUSE) {
192                         n->num_pulses_left = 27136;
193                 } else {
194                         n->num_pulses_left = prev->num_pulses_left - 1;
195                 }
196                 possibly_add_state(n, next_states);
197         }
198
199         // If in STATE_ROM_SYNC, maybe we transitioned into ROM_LOADER.
200         // That always starts with a (L,M).
201         if (prev->state == STATE_ROM_SYNC) {
202                 Node *n = new Node;
203                 n->cost = prev->cost +
204                         penalty(length, ROM_LONG_PULSE_LENGTH, ratio, ROM_LONG_PULSE_LENGTH / SYNC_REFERENCE) +
205                         0.1 * pulse_penalty(prev->num_pulses_left);
206                 n->set_prev_node(prev);
207                 n->emit = ROM_LONG_PULSE_LENGTH;
208
209                 n->state = STATE_ROM_LOADER;
210                 n->last_pulse_type = ROM_PULSE_LONG;
211                 n->first_in_pair = true;
212                 possibly_add_state(n, next_states);
213         }
214         
215         // If in ROM_LOADER, we could have short, medium or long pulses.
216         if (prev->state == STATE_ROM_LOADER) {
217                 for (int pulse_type = ROM_PULSE_TYPE_SHORT; pulse_type <= ROM_PULSE_LONG; ++pulse_type) {
218
219                         // Filter illegal ROM loader pairs.
220                         if (prev->last_pulse_type == ROM_PULSE_LONG && pulse_type == ROM_PULSE_LONG) {
221                                 continue;
222                         }
223                         if (prev->first_in_pair) {
224                                 if (prev->last_pulse_type == ROM_PULSE_TYPE_SHORT && pulse_type != ROM_PULSE_MEDIUM) {
225                                         continue;
226                                 }
227                                 if (prev->last_pulse_type == ROM_PULSE_MEDIUM && pulse_type != ROM_PULSE_TYPE_SHORT) {
228                                         continue;
229                                 }
230                                 if (pulse_type == ROM_PULSE_LONG) {
231                                         continue;
232                                 }
233                         }
234
235                         Node *n = new Node;
236                         n->cost = prev->cost +
237                                 penalty(length, rom_lengths[pulse_type], ratio, rom_lengths[pulse_type] / rom_lengths[prev->last_pulse_type]);
238                         n->set_prev_node(prev);
239                         n->emit = rom_lengths[pulse_type];
240
241                         if (prev->first_in_pair && (prev->last_pulse_type == ROM_PULSE_LONG && pulse_type == ROM_PULSE_TYPE_SHORT)) {
242                                 // (L,S) = end-of-data-marker
243                                 n->state = STATE_ROM_SYNC;
244                                 n->num_pulses_left = 0x4f;  // http://c64tapes.org/dokuwiki/doku.php?id=loaders:rom_loader
245                         } else {
246                                 n->state = STATE_ROM_LOADER;
247                                 n->last_pulse_type = ROMPulseType(pulse_type);
248                                 n->first_in_pair = !prev->first_in_pair;
249                         }
250                         possibly_add_state(n, next_states);
251                 }               
252         }
253
254         // If in STATE_NOVA, we only have long and short pulses.
255         if (prev->state == STATE_NOVA || prev->state == STATE_PAUSE || prev->state == STATE_ROM_SYNC) {
256                 for (int pulse_type = NOVA_PULSE_SHORT; pulse_type <= NOVA_PULSE_LONG; ++pulse_type) {
257                         Node *n = new Node;
258                         if (prev->state == STATE_PAUSE) {
259                                 // going from PAUSE to NOVA is only realistic after a pretty long pause.
260                                 n->cost = prev->cost + fabs(2000.0f - last_length) + penalty(length, nova_lengths[pulse_type], 1.0, 1.0);
261                         } else if (prev->state == STATE_ROM_SYNC) {
262                                 n->cost = prev->cost +
263                                         penalty(length, NOVA_SHORT_PULSE_LENGTH, ratio, nova_lengths[pulse_type] / SYNC_REFERENCE) +
264                                         0.1 * pulse_penalty(prev->num_pulses_left);
265                         } else {
266                                 n->cost = prev->cost +
267                                         penalty(length, nova_lengths[pulse_type], ratio, nova_lengths[pulse_type] / nova_lengths[prev->last_nova_pulse_type]);
268                         }
269                         n->set_prev_node(prev);
270                         n->emit = nova_lengths[pulse_type];
271
272                         n->state = STATE_NOVA;
273                         n->last_nova_pulse_type = NOVAPulseType(pulse_type);
274                         possibly_add_state(n, next_states);
275                 }
276         }
277
278         // TODO: Other loader types
279 }
280
281 int main(int argc, char **argv)
282 {
283         Node *start = new Node;
284         start->cost = 0.0;
285         start->emit = -1;
286         start->state = STATE_PAUSE;
287
288         vector<Node *> states;
289         states.push_back(start);
290
291         double last_length = SYNC_REFERENCE;
292
293         int pulse_num = 0;
294         int max_total_alloc = 0;
295         vector<pair<double, double>> original_pulses;
296         for ( ;; ) {
297                 double time, length;
298                 if (scanf("%lf %lf", &time, &length) != 2) {
299                         break;
300                 }
301                 original_pulses.push_back(make_pair(time, length));
302
303                 ++pulse_num;
304
305                 max_total_alloc = max(total_alloc, max_total_alloc);
306                 if (total_alloc > 20000000) {
307                         printf("More than 20M states reached (out of RAM); aborting at pulse %d.\n", pulse_num);
308                         break;
309                 }
310
311                 //if (length > 2000) {
312                 //      break;
313                 //}
314
315                 double ratio = length / last_length;
316         
317                 vector<Node *> next_states;
318                 for (unsigned i = 0; i < states.size(); ++i) {
319                         extend_state(states[i], last_length, length, ratio, &next_states);
320
321                         // We no longer need this state; if it doesn't have any children,
322                         // it can go away.
323                         states[i]->unref();
324                 }
325                 states.clear();
326
327                 // Remove duplicates, tie-breaking by score.
328                 sort(next_states.begin(), next_states.end(), CompareByCost());
329                 stable_sort(next_states.begin(), next_states.end(), StateLessThanComparator());
330
331                 // unique and move in one step. Do not use std::unique(),
332                 // it has very wrong behavior for pointers!
333                 for (unsigned i = 0; i < next_states.size(); ++i) {
334                         if (i > 0 && StateEqualsComparator()(next_states[i], states.back())) {
335                                 next_states[i]->unref();
336                         } else {
337                                 states.push_back(next_states[i]);
338                         }
339                 }
340                 assert(!states.empty());
341                         
342                 // Prune unlikely next_states to save time and memory.
343                 if (states.size() >= STATE_BREADTH) {
344                         sort(states.begin(), states.end(), CompareByCost());  // nth element?
345                         for (unsigned i = STATE_BREADTH; i < states.size(); ++i) {
346                                 states[i]->unref();
347                         }
348                         states.resize(STATE_BREADTH);
349                 }
350                 last_length = length;
351
352                 if (pulse_num % 1000 == 0) {
353                         fprintf(stderr, "\nProcessing pulses... %d [%5.2f,%7.2f]  ", pulse_num, time, length);
354                         for (unsigned i = 0; i < 3 && i < states.size(); ++i) {
355                                 fprintf(stderr, " [%d:state=%d cost=%f]", i, states[i]->state, states[i]->cost);
356                         }
357                 }       
358         }
359
360         // Find the best final node.
361         const Node *best_node = NULL;
362         for (unsigned i = 0; i < states.size(); ++i) {
363                 if (states[i]->state == STATE_ROM_SYNC) {
364                         states[i]->cost += 0.1 * pulse_penalty(states[i]->num_pulses_left);
365                 }
366
367                 if (best_node == NULL || states[i]->cost < best_node->cost) {
368                         best_node = states[i];
369                 }
370         }
371
372         fprintf(stderr, "\rTotal cost is %f. Peak RAM usage %.2f MB, pluss malloc overhead.\n",
373                 best_node->cost, sizeof(Node) * max_total_alloc / 1048576.0);
374
375         // Backtrack.
376         vector<const Node *> cleaned;
377         while (best_node != NULL) {
378                 cleaned.push_back(best_node);
379                 best_node = best_node->get_prev_node();
380         }
381
382         reverse(cleaned.begin(), cleaned.end());
383
384         for (unsigned i = 0; i < cleaned.size(); ++i) {
385                 //fprintf(stderr, "%d: state %d emit %d cost %f\n", i, cleaned[i]->state, cleaned[i]->emit, cleaned[i]->cost);
386                 if (cleaned[i]->emit <= 0) {
387                         continue;
388                 }
389                 int state;
390                 if (cleaned[i]->state == STATE_ROM_SYNC) {
391                         state = 0;
392                 } else if (cleaned[i]->state == STATE_ROM_LOADER) {
393                         state = 1 + cleaned[i]->last_pulse_type;
394                 } else if (cleaned[i]->state == STATE_NOVA) {
395                         state = 4 + cleaned[i]->last_nova_pulse_type;
396                 } else if (cleaned[i]->state == STATE_PAUSE) {
397                         state = 6;
398                 } else {
399                         assert(false);
400                 }
401         //      printf("%d\n", cleaned[i]->emit);
402                 printf("%f %f %d\n", original_pulses[i - 1].first, original_pulses[i - 1].second, state);
403         }
404
405         // output TAP file      
406         FILE *fp = fopen("cleaned.tap", "wb");
407         std::vector<char> tap_data;
408         for (unsigned i = 0; i < cleaned.size(); ++i) {
409                 if (cleaned[i]->emit <= 0) {
410                         continue;
411                 }
412                 int len = lrintf(cleaned[i]->emit / TAP_RESOLUTION);
413                 if (len <= 255) {
414                         tap_data.push_back(len);
415                 } else {
416                         int overflow_len = lrintf(cleaned[i]->emit);
417                         tap_data.push_back(0);
418                         tap_data.push_back(overflow_len & 0xff);
419                         tap_data.push_back((overflow_len >> 8) & 0xff);
420                         tap_data.push_back(overflow_len >> 16);
421                 }
422         }
423
424         tap_header hdr;
425         memcpy(hdr.identifier, "C64-TAPE-RAW", 12);
426         hdr.version = 1;
427         hdr.reserved[0] = hdr.reserved[1] = hdr.reserved[2] = 0;
428         hdr.data_len = tap_data.size();
429
430         fwrite(&hdr, sizeof(hdr), 1, fp);
431         fwrite(tap_data.data(), tap_data.size(), 1, fp);
432         fclose(fp);
433 }