b3b8aa3376f355b9fcad9e8cd3aa4fc9b1bf2b4b
[c64tapwav] / decode.cpp
1 #include <stdio.h>
2 #include <string.h>
3 #include <math.h>
4 #include <assert.h>
5 #include <limits.h>
6 #include <getopt.h>
7 #include <vector>
8 #include <algorithm>
9
10 #include "audioreader.h"
11 #include "interpolate.h"
12 #include "level.h"
13 #include "tap.h"
14
#define BUFSIZE 4096
#define C64_FREQUENCY 985248      // C64 clock rate: seconds * C64_FREQUENCY = cycles.
#define SYNC_PULSE_START 1000     // First sync pulse index used for calibration (skips the very start).
#define SYNC_PULSE_END 20000      // Calibration uses pulses up to (not including) this index.
#define SYNC_PULSE_LENGTH 378.0   // Nominal sync pulse length, in C64 cycles.
#define SYNC_TEST_TOLERANCE 1.10  // Calibrated sync pulses off by more than this factor are reported.

// SPSA options
#define NUM_FILTER_COEFF 32       // Number of FIR filter taps (--filter and --train).
#define NUM_ITER 5000             // Number of SPSA iterations.
// Stability offset in the step size sequence, approx. 10% of NUM_ITER.
// Parenthesized so that it expands correctly inside any expression
// (e.g. "x / A" would otherwise become "x / NUM_ITER / 10").
#define A (NUM_ITER / 10)
#define INITIAL_A 0.005 // A bit of trial and error...
#define INITIAL_C 0.02  // This too.
#define GAMMA 0.166     // Decay exponent of the perturbation size c_n = INITIAL_C * n^-GAMMA.
#define ALPHA 1.0       // Decay exponent of the step size a_n = INITIAL_A * (n + A)^-ALPHA.
30
31 static float hysteresis_limit = 3000.0 / 32768.0;
32 static bool do_calibrate = true;
33 static bool output_cycles_plot = false;
34 static bool use_filter = false;
35 static bool do_crop = false;
36 static float crop_start = 0.0f, crop_end = HUGE_VAL;
37 static float filter_coeff[NUM_FILTER_COEFF] = { 1.0f };  // The rest is filled with 0.
38 static bool output_filtered = false;
39 static bool quiet = false;
40 static bool do_auto_level = false;
41 static bool output_leveled = false;
42 static std::vector<float> train_snap_points;
43 static bool do_train = false;
44
45 // The minimum estimated sound level (for do_auto_level) at any given point.
46 // If you decrease this, you'll be able to amplify really silent signals
47 // by more, but you'll also increase the level of silent (ie. noise-only) segments,
48 // possibly caused misdetected pulses in these segments.
49 static float min_level = 0.05f;
50
51 // between [x,x+1]
52 double find_zerocrossing(const std::vector<float> &pcm, int x)
53 {
54         if (pcm[x] == 0) {
55                 return x;
56         }
57         if (pcm[x + 1] == 0) {
58                 return x + 1;
59         }
60
61         assert(pcm[x + 1] < 0);
62         assert(pcm[x] > 0);
63
64         double upper = x;
65         double lower = x + 1;
66         while (lower - upper > 1e-3) {
67                 double mid = 0.5f * (upper + lower);
68                 if (lanczos_interpolate(pcm, mid) > 0) {
69                         upper = mid;
70                 } else {
71                         lower = mid;
72                 }
73         }
74
75         return 0.5f * (upper + lower);
76 }
77
// One detected pulse, delimited by two consecutive down-flanks.
struct pulse {
        // When the down-flank ending this pulse occurred, in seconds from start.
        double time;
        // Time since the previous down-flank, in seconds.
        double len;
};
82
83 // Calibrate on the first ~25k pulses (skip a few, just to be sure).
84 double calibrate(const std::vector<pulse> &pulses) {
85         if (pulses.size() < SYNC_PULSE_END) {
86                 fprintf(stderr, "Too few pulses, not calibrating!\n");
87                 return 1.0;
88         }
89
90         int sync_pulse_end = -1;
91         double sync_pulse_stddev = -1.0;
92
93         // Compute the standard deviation (to check for uneven speeds).
94         // If it suddenly skyrockets, we assume that sync ended earlier
95         // than we thought (it should be 25000 cycles), and that we should
96         // calibrate on fewer cycles.
97         for (int try_end : { 2000, 4000, 5000, 7500, 10000, 15000, SYNC_PULSE_END }) {
98                 double sum2 = 0.0;
99                 for (int i = SYNC_PULSE_START; i < try_end; ++i) {
100                         double cycles = pulses[i].len * C64_FREQUENCY;
101                         sum2 += (cycles - SYNC_PULSE_LENGTH) * (cycles - SYNC_PULSE_LENGTH);
102                 }
103                 double stddev = sqrt(sum2 / (try_end - SYNC_PULSE_START - 1));
104                 if (sync_pulse_end != -1 && stddev > 5.0 && stddev / sync_pulse_stddev > 1.3) {
105                         fprintf(stderr, "Stopping at %d sync pulses because standard deviation would be too big (%.2f cycles); shorter-than-usual trailer?\n",
106                                 sync_pulse_end, stddev);
107                         break;
108                 }
109                 sync_pulse_end = try_end;
110                 sync_pulse_stddev = stddev;
111         }
112         if (!quiet) {
113                 fprintf(stderr, "Sync pulse length standard deviation: %.2f cycles\n",
114                         sync_pulse_stddev);
115         }
116
117         double sum = 0.0;
118         for (int i = SYNC_PULSE_START; i < sync_pulse_end; ++i) {
119                 sum += pulses[i].len;
120         }
121         double mean_length = C64_FREQUENCY * sum / (sync_pulse_end - SYNC_PULSE_START);
122         double calibration_factor = SYNC_PULSE_LENGTH / mean_length;
123         if (!quiet) {
124                 fprintf(stderr, "Calibrated sync pulse length: %.2f -> %.2f (change %+.2f%%)\n",
125                         mean_length, SYNC_PULSE_LENGTH, 100.0 * (calibration_factor - 1.0));
126         }
127
128         // Check for pulses outside +/- 10% (sign of misdetection).
129         for (int i = SYNC_PULSE_START; i < sync_pulse_end; ++i) {
130                 double cycles = pulses[i].len * calibration_factor * C64_FREQUENCY;
131                 if (cycles < SYNC_PULSE_LENGTH / SYNC_TEST_TOLERANCE || cycles > SYNC_PULSE_LENGTH * SYNC_TEST_TOLERANCE) {
132                         fprintf(stderr, "Sync cycle with downflank at %.6f was detected at %.0f cycles; misdetect?\n",
133                                 pulses[i].time, cycles);
134                 }
135         }
136
137         return calibration_factor;
138 }
139
140 void output_tap(const std::vector<pulse>& pulses, double calibration_factor)
141 {
142         std::vector<char> tap_data;
143         for (unsigned i = 0; i < pulses.size(); ++i) {
144                 double cycles = pulses[i].len * calibration_factor * C64_FREQUENCY;
145                 int len = lrintf(cycles / TAP_RESOLUTION);
146                 if (i > SYNC_PULSE_END && (cycles < 100 || cycles > 800)) {
147                         fprintf(stderr, "Cycle with downflank at %.6f was detected at %.0f cycles; misdetect?\n",
148                                         pulses[i].time, cycles);
149                 }
150                 if (len <= 255) {
151                         tap_data.push_back(len);
152                 } else {
153                         int overflow_len = lrintf(cycles);
154                         tap_data.push_back(0);
155                         tap_data.push_back(overflow_len & 0xff);
156                         tap_data.push_back((overflow_len >> 8) & 0xff);
157                         tap_data.push_back(overflow_len >> 16);
158                 }
159         }
160
161         tap_header hdr;
162         memcpy(hdr.identifier, "C64-TAPE-RAW", 12);
163         hdr.version = 1;
164         hdr.reserved[0] = hdr.reserved[1] = hdr.reserved[2] = 0;
165         hdr.data_len = tap_data.size();
166
167         fwrite(&hdr, sizeof(hdr), 1, stdout);
168         fwrite(tap_data.data(), tap_data.size(), 1, stdout);
169 }
170
// Long option table for getopt_long(). Each entry maps to the short option
// of the same meaning in parse_options()'s optstring ("aAm:spl:f:Fc:t:qh"),
// so every documented option works in both spellings.
static struct option long_options[] = {
        {"auto-level",       no_argument,       0, 'a' },
        {"output-leveled",   no_argument,       0, 'A' },
        {"min-level",        required_argument, 0, 'm' },
        {"no-calibrate",     no_argument,       0, 's' },
        {"plot-cycles",      no_argument,       0, 'p' },
        {"hysteresis-limit", required_argument, 0, 'l' },
        {"filter",           required_argument, 0, 'f' },
        {"output-filtered",  no_argument,       0, 'F' },
        {"crop",             required_argument, 0, 'c' },
        {"train",            required_argument, 0, 't' },
        {"quiet",            no_argument,       0, 'q' },
        {"help",             no_argument,       0, 'h' },
        {0,                  0,                 0, 0   }
};
184
185 void help()
186 {
187         fprintf(stderr, "decode [OPTIONS] AUDIO-FILE > TAP-FILE\n");
188         fprintf(stderr, "\n");
189         fprintf(stderr, "  -a, --auto-level             automatically adjust amplitude levels throughout the file\n");
190         fprintf(stderr, "  -A, --output-leveled         output leveled waveform to leveled.raw\n");
191         fprintf(stderr, "  -m, --min-level              minimum estimated sound level (0..32768) for --auto-level\n");
192         fprintf(stderr, "  -s, --no-calibrate           do not try to calibrate on sync pulse length\n");
193         fprintf(stderr, "  -p, --plot-cycles            output debugging info to cycles.plot\n");
194         fprintf(stderr, "  -l, --hysteresis-limit VAL   change amplitude threshold for ignoring pulses (0..32768)\n");
195         fprintf(stderr, "  -f, --filter C1:C2:C3:...    specify FIR filter (up to %d coefficients)\n", NUM_FILTER_COEFF);
196         fprintf(stderr, "  -F, --output-filtered        output filtered waveform to filtered.raw\n");
197         fprintf(stderr, "  -c, --crop START[:END]       use only the given part of the file\n");
198         fprintf(stderr, "  -t, --train LEN1:LEN2:...    train a filter for detecting any of the given number of cycles\n");
199         fprintf(stderr, "                               (implies --no-calibrate and --quiet unless overridden)\n");
200         fprintf(stderr, "  -q, --quiet                  suppress some informational messages\n");
201         fprintf(stderr, "  -h, --help                   display this help, then exit\n");
202         exit(1);
203 }
204
205 void parse_options(int argc, char **argv)
206 {
207         for ( ;; ) {
208                 int option_index = 0;
209                 int c = getopt_long(argc, argv, "aAm:spl:f:Fc:t:qh", long_options, &option_index);
210                 if (c == -1)
211                         break;
212
213                 switch (c) {
214                 case 'a':
215                         do_auto_level = true;
216                         break;
217
218                 case 'A':
219                         output_leveled = true;
220                         break;
221
222                 case 'm':
223                         min_level = atof(optarg) / 32768.0;
224                         break;
225
226                 case 's':
227                         do_calibrate = false;
228                         break;
229
230                 case 'p':
231                         output_cycles_plot = true;
232                         break;
233
234                 case 'l':
235                         hysteresis_limit = atof(optarg) / 32768.0;
236                         break;
237
238                 case 'f': {
239                         const char *coeffstr = strtok(optarg, ": ");
240                         int coeff_index = 0;
241                         while (coeff_index < NUM_FILTER_COEFF && coeffstr != NULL) {
242                                 filter_coeff[coeff_index++] = atof(coeffstr);
243                                 coeffstr = strtok(NULL, ": ");
244                         }
245                         use_filter = true;
246                         break;
247                 }
248
249                 case 'F':
250                         output_filtered = true;
251                         break;
252
253                 case 'c': {
254                         const char *cropstr = strtok(optarg, ":");
255                         crop_start = atof(cropstr);
256                         cropstr = strtok(NULL, ":");
257                         if (cropstr == NULL) {
258                                 crop_end = HUGE_VAL;
259                         } else {
260                                 crop_end = atof(cropstr);
261                         }
262                         do_crop = true;
263                         break;
264                 }
265
266                 case 't': {
267                         const char *cyclestr = strtok(optarg, ":");
268                         while (cyclestr != NULL) {
269                                 train_snap_points.push_back(atof(cyclestr));
270                                 cyclestr = strtok(NULL, ":");
271                         }
272                         do_train = true;
273
274                         // Set reasonable defaults (can be overridden later on the command line).
275                         do_calibrate = false;
276                         quiet = true;
277                         break;
278                 }
279
280                 case 'q':
281                         quiet = true;
282                         break;
283
284                 case 'h':
285                 default:
286                         help();
287                         exit(1);
288                 }
289         }
290 }
291
// Returns the samples of pcm between crop_start and crop_end (in seconds).
// A negative crop_start means "from the beginning" and a negative crop_end
// means "to the end". Times past the end of pcm (including HUGE_VAL) are
// clamped to its length, and an end before the start yields an empty result.
std::vector<float> crop(const std::vector<float>& pcm, float crop_start, float crop_end, int sample_rate)
{
        const size_t num_samples = pcm.size();

        // Converts a non-negative time to a sample index in [0, num_samples].
        // Out-of-range values are clamped *before* rounding, because lrintf()
        // of a value that does not fit in a long (e.g. infinity) is unspecified.
        auto to_sample = [num_samples, sample_rate](float seconds) -> size_t {
                const float pos = seconds * sample_rate;
                if (!(pos > 0.0f)) {
                        return 0;
                }
                if (!(pos < float(num_samples))) {  // Also catches +inf and NaN.
                        return num_samples;
                }
                return std::min<size_t>(lrintf(pos), num_samples);
        };

        const size_t start_sample = (crop_start >= 0.0f) ? to_sample(crop_start) : 0;
        size_t end_sample = (crop_end >= 0.0f) ? to_sample(crop_end) : num_samples;
        // Never build a range whose end lies before its start.
        end_sample = std::max(end_sample, start_sample);
        return std::vector<float>(pcm.begin() + start_sample, pcm.begin() + end_sample);
}
303
304 // TODO: Support AVX here.
305 std::vector<float> do_filter(const std::vector<float>& pcm, const float* filter)
306 {
307         std::vector<float> filtered_pcm;
308         filtered_pcm.reserve(pcm.size());
309         for (unsigned i = NUM_FILTER_COEFF; i < pcm.size(); ++i) {
310                 float s = 0.0f;
311                 for (int j = 0; j < NUM_FILTER_COEFF; ++j) {
312                         s += filter[j] * pcm[i - j];
313                 }
314                 filtered_pcm.push_back(s);
315         }
316
317         if (output_filtered) {
318                 FILE *fp = fopen("filtered.raw", "wb");
319                 fwrite(filtered_pcm.data(), filtered_pcm.size() * sizeof(filtered_pcm[0]), 1, fp);
320                 fclose(fp);
321         }
322
323         return filtered_pcm;
324 }
325
326 std::vector<pulse> detect_pulses(const std::vector<float> &pcm, int sample_rate)
327 {
328         std::vector<pulse> pulses;
329
330         // Find the flanks.
331         int last_bit = -1;
332         double last_downflank = -1;
333         for (unsigned i = 0; i < pcm.size(); ++i) {
334                 int bit = (pcm[i] > 0) ? 1 : 0;
335                 if (bit == 0 && last_bit == 1) {
336                         // Check if we ever go up above <hysteresis_limit> before we dip down again.
337                         bool true_pulse = false;
338                         unsigned j;
339                         int min_level_after = 32767;
340                         for (j = i; j < pcm.size(); ++j) {
341                                 min_level_after = std::min<int>(min_level_after, pcm[j]);
342                                 if (pcm[j] > 0) break;
343                                 if (pcm[j] < -hysteresis_limit) {
344                                         true_pulse = true;
345                                         break;
346                                 }
347                         }
348
349                         if (!true_pulse) {
350 #if 0
351                                 fprintf(stderr, "Ignored down-flank at %.6f seconds due to hysteresis (%d < %d).\n",
352                                         double(i) / sample_rate, -min_level_after, hysteresis_limit);
353 #endif
354                                 i = j;
355                                 continue;
356                         } 
357
358                         // down-flank!
359                         double t = find_zerocrossing(pcm, i - 1) * (1.0 / sample_rate) + crop_start;
360                         if (last_downflank > 0) {
361                                 pulse p;
362                                 p.time = t;
363                                 p.len = t - last_downflank;
364                                 pulses.push_back(p);
365                         }
366                         last_downflank = t;
367                 }
368                 last_bit = bit;
369         }
370         return pulses;
371 }
372
373 void output_cycle_plot(const std::vector<pulse> &pulses, double calibration_factor)
374 {
375         FILE *fp = fopen("cycles.plot", "w");
376         for (unsigned i = 0; i < pulses.size(); ++i) {
377                 double cycles = pulses[i].len * calibration_factor * C64_FREQUENCY;
378                 fprintf(fp, "%f %f\n", pulses[i].time, cycles);
379         }
380         fclose(fp);
381 }
382
383 float eval_badness(const std::vector<pulse>& pulses, double calibration_factor)
384 {
385         double sum_badness = 0.0;
386         for (unsigned i = 0; i < pulses.size(); ++i) {
387                 double cycles = pulses[i].len * calibration_factor * C64_FREQUENCY;
388                 if (cycles > 2000.0) cycles = 2000.0;  // Don't make pauses arbitrarily bad.
389                 double badness = (cycles - train_snap_points[0]) * (cycles - train_snap_points[0]);
390                 for (unsigned j = 1; j < train_snap_points.size(); ++j) {
391                         badness = std::min(badness, (cycles - train_snap_points[j]) * (cycles - train_snap_points[j]));
392                 }
393                 sum_badness += badness;
394         }
395         return sqrt(sum_badness / (pulses.size() - 1));
396 }
397
398 void spsa_train(std::vector<float> &pcm, int sample_rate)
399 {
400         // Train!
401         float filter[NUM_FILTER_COEFF] = { 1.0f };  // The rest is filled with 0.
402
403         float start_c = INITIAL_C;
404         double best_badness = HUGE_VAL;
405
406         for (int n = 1; n < NUM_ITER; ++n) {
407                 float a = INITIAL_A * pow(n + A, -ALPHA);
408                 float c = start_c * pow(n, -GAMMA);
409
410                 // find a random perturbation
411                 float p[NUM_FILTER_COEFF];
412                 float filter1[NUM_FILTER_COEFF], filter2[NUM_FILTER_COEFF];
413                 for (int i = 0; i < NUM_FILTER_COEFF; ++i) {
414                         p[i] = (rand() % 2) ? 1.0 : -1.0;
415                         filter1[i] = std::max(std::min(filter[i] - c * p[i], 1.0f), -1.0f);
416                         filter2[i] = std::max(std::min(filter[i] + c * p[i], 1.0f), -1.0f);
417                 }
418
419                 std::vector<pulse> pulses1 = detect_pulses(do_filter(pcm, filter1), sample_rate);
420                 std::vector<pulse> pulses2 = detect_pulses(do_filter(pcm, filter2), sample_rate);
421                 float badness1 = eval_badness(pulses1, 1.0);
422                 float badness2 = eval_badness(pulses2, 1.0);
423
424                 // Find the gradient estimator
425                 float g[NUM_FILTER_COEFF];
426                 for (int i = 0; i < NUM_FILTER_COEFF; ++i) {
427                         g[i] = (badness2 - badness1) / (2.0 * c * p[i]);
428                         filter[i] -= a * g[i];
429                         filter[i] = std::max(std::min(filter[i], 1.0f), -1.0f);
430                 }
431                 if (badness2 < badness1) {
432                         std::swap(badness1, badness2);
433                         std::swap(filter1, filter2);
434                         std::swap(pulses1, pulses2);
435                 }
436                 if (badness1 < best_badness) {
437                         printf("\nNew best filter (badness=%f):", badness1);
438                         for (int i = 0; i < NUM_FILTER_COEFF; ++i) {
439                                 printf(" %.5f", filter1[i]);
440                         }
441                         best_badness = badness1;
442                         printf("\n");
443
444                         if (output_cycles_plot) {
445                                 output_cycle_plot(pulses1, 1.0);
446                         }
447                 }
448                 printf("%d ", n);
449                 fflush(stdout);
450         }
451 }
452
453 int main(int argc, char **argv)
454 {
455         parse_options(argc, argv);
456
457         make_lanczos_weight_table();
458         std::vector<float> pcm;
459         int sample_rate;
460         if (!read_audio_file(argv[optind], &pcm, &sample_rate)) {
461                 exit(1);
462         }
463
464         if (do_crop) {
465                 pcm = crop(pcm, crop_start, crop_end, sample_rate);
466         }
467
468         if (use_filter) {
469                 pcm = do_filter(pcm, filter_coeff);
470         }
471
472         if (do_auto_level) {
473                 pcm = level_samples(pcm, min_level, sample_rate);
474                 if (output_leveled) {
475                         FILE *fp = fopen("leveled.raw", "wb");
476                         fwrite(pcm.data(), pcm.size() * sizeof(pcm[0]), 1, fp);
477                         fclose(fp);
478                 }
479         }
480
481 #if 0
482         for (int i = 0; i < LEN; ++i) {
483                 in[i] += rand() % 10000;
484         }
485 #endif
486
487 #if 0
488         for (int i = 0; i < LEN; ++i) {
489                 printf("%d\n", in[i]);
490         }
491 #endif
492
493         if (do_train) {
494                 spsa_train(pcm, sample_rate);
495                 exit(0);
496         }
497
498         std::vector<pulse> pulses = detect_pulses(pcm, sample_rate);
499
500         double calibration_factor = 1.0;
501         if (do_calibrate) {
502                 calibration_factor = calibrate(pulses);
503         }
504
505         if (output_cycles_plot) {
506                 output_cycle_plot(pulses, calibration_factor);
507         }
508
509         output_tap(pulses, calibration_factor);
510 }