]> git.sesse.net Git - c64tapwav/blobdiff - decode.cpp
Merge branch 'master' of /srv/git.sesse.net/www/c64tapwav
[c64tapwav] / decode.cpp
index 3662566737d4d4fc415a525d2071f9185ce0e8b0..d2d6fc89a26aeecee613f6ddb446d35571d35365 100644 (file)
@@ -67,20 +67,27 @@ static float auto_level_freq = 200.0;
 static float min_level = 0.05f;
 
 // search for the value <limit> between [x,x+1]
+template<bool fast>
 double find_crossing(const std::vector<float> &pcm, int x, float limit)
 {
-       double upper = x;
-       double lower = x + 1;
-       while (lower - upper > 1e-3) {
-               double mid = 0.5f * (upper + lower);
-               if (lanczos_interpolate(pcm, mid) > limit) {
-                       upper = mid;
-               } else {
-                       lower = mid;
+       if (fast) {
+               // Do simple linear interpolation.
+               return x + (limit - pcm[x]) / (pcm[x + 1] - pcm[x]);
+       } else {
+               // Binary search for the zero crossing as given by Lanczos interpolation.
+               double upper = x;
+               double lower = x + 1;
+               while (lower - upper > 1e-3) {
+                       double mid = 0.5f * (upper + lower);
+                       if (lanczos_interpolate(pcm, mid) > limit) {
+                               upper = mid;
+                       } else {
+                               lower = mid;
+                       }
                }
-       }
 
-       return 0.5f * (upper + lower);
+               return 0.5f * (upper + lower);
+       }
 }
 
 struct pulse {
@@ -188,6 +195,7 @@ static struct option long_options[] = {
        {"rc-filter",        required_argument, 0, 'r' },
        {"output-filtered",  0,                 0, 'F' },
        {"crop",             required_argument, 0, 'c' },
+       {"train",            required_argument, 0, 't' },
        {"quiet",            0,                 0, 'q' },
        {"help",             0,                 0, 'h' },
        {0,                  0,                 0, 0   }
@@ -396,6 +404,7 @@ std::vector<float> do_rc_filter(const std::vector<float>& pcm, float freq, int s
        return filtered_pcm;
 }
 
+template<bool fast>
 std::vector<pulse> detect_pulses(const std::vector<float> &pcm, float hysteresis_upper_limit, float hysteresis_lower_limit, int sample_rate)
 {
        std::vector<pulse> pulses;
@@ -409,7 +418,7 @@ std::vector<pulse> detect_pulses(const std::vector<float> &pcm, float hysteresis
                } else if (pcm[i] < hysteresis_lower_limit) {
                        if (state == ABOVE) {
                                // down-flank!
-                               double t = find_crossing(pcm, i - 1, hysteresis_lower_limit) * (1.0 / sample_rate) + crop_start;
+                               double t = find_crossing<fast>(pcm, i - 1, hysteresis_lower_limit) * (1.0 / sample_rate) + crop_start;
                                if (last_downflank > 0) {
                                        pulse p;
                                        p.time = t;
@@ -526,8 +535,8 @@ void spsa_train(const std::vector<float> &pcm, int sample_rate)
                        vals2[i] = std::max(std::min(vals[i] + c * p[i], 1.0f), -1.0f);
                }
 
-               std::vector<pulse> pulses1 = detect_pulses(do_fir_filter(pcm, vals1 + 2), vals1[0], vals1[1], sample_rate);
-               std::vector<pulse> pulses2 = detect_pulses(do_fir_filter(pcm, vals2 + 2), vals2[0], vals2[1], sample_rate);
+               std::vector<pulse> pulses1 = detect_pulses<true>(do_fir_filter(pcm, vals1 + 2), vals1[0], vals1[1], sample_rate);
+               std::vector<pulse> pulses2 = detect_pulses<true>(do_fir_filter(pcm, vals2 + 2), vals2[0], vals2[1], sample_rate);
                float badness1 = eval_badness(pulses1, 1.0);
                float badness2 = eval_badness(pulses2, 1.0);
 
@@ -611,7 +620,7 @@ int main(int argc, char **argv)
                exit(0);
        }
 
-       std::vector<pulse> pulses = detect_pulses(pcm, hysteresis_upper_limit, hysteresis_lower_limit, sample_rate);
+       std::vector<pulse> pulses = detect_pulses<false>(pcm, hysteresis_upper_limit, hysteresis_lower_limit, sample_rate);
 
        double calibration_factor = 1.0;
        if (do_calibrate) {