]> git.sesse.net Git - c64tapwav/blobdiff - decode.cpp
Add a simple README for license reasons.
[c64tapwav] / decode.cpp
index f9e525eeea07b8ebaf351d87d9870be36d1c22cb..d2d6fc89a26aeecee613f6ddb446d35571d35365 100644 (file)
@@ -7,6 +7,9 @@
 #include <assert.h>
 #include <limits.h>
 #include <getopt.h>
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
 #include <vector>
 #include <algorithm>
 
@@ -64,20 +67,27 @@ static float auto_level_freq = 200.0;
 static float min_level = 0.05f;
 
 // search for the value <limit> between [x,x+1]
+template<bool fast>
 double find_crossing(const std::vector<float> &pcm, int x, float limit)
 {
-       double upper = x;
-       double lower = x + 1;
-       while (lower - upper > 1e-3) {
-               double mid = 0.5f * (upper + lower);
-               if (lanczos_interpolate(pcm, mid) > limit) {
-                       upper = mid;
-               } else {
-                       lower = mid;
+       if (fast) {
+               // Do simple linear interpolation.
+               return x + (limit - pcm[x]) / (pcm[x + 1] - pcm[x]);
+       } else {
+               // Binary search for the zero crossing as given by Lanczos interpolation.
+               double upper = x;
+               double lower = x + 1;
+               while (lower - upper > 1e-3) {
+                       double mid = 0.5f * (upper + lower);
+                       if (lanczos_interpolate(pcm, mid) > limit) {
+                               upper = mid;
+                       } else {
+                               lower = mid;
+                       }
                }
-       }
 
-       return 0.5f * (upper + lower);
+               return 0.5f * (upper + lower);
+       }
 }
 
 struct pulse {
@@ -185,6 +195,7 @@ static struct option long_options[] = {
        {"rc-filter",        required_argument, 0, 'r' },
        {"output-filtered",  0,                 0, 'F' },
        {"crop",             required_argument, 0, 'c' },
+       {"train",            required_argument, 0, 't' },
        {"quiet",            0,                 0, 'q' },
        {"help",             0,                 0, 'h' },
        {0,                  0,                 0, 0   }
@@ -328,17 +339,30 @@ std::vector<float> crop(const std::vector<float>& pcm, float crop_start, float c
        return std::vector<float>(pcm.begin() + start_sample, pcm.begin() + end_sample);
 }
 
-// TODO: Support AVX here.
 std::vector<float> do_fir_filter(const std::vector<float>& pcm, const float* filter)
 {
        std::vector<float> filtered_pcm;
-       filtered_pcm.reserve(pcm.size());
-       for (unsigned i = NUM_FILTER_COEFF; i < pcm.size(); ++i) {
+       filtered_pcm.resize(pcm.size());
+       unsigned i = NUM_FILTER_COEFF;
+#ifdef __AVX__
+       unsigned avx_end = i + ((pcm.size() - i) & ~7);
+       for ( ; i < avx_end; i += 8) {
+               __m256 s = _mm256_setzero_ps();
+               for (int j = 0; j < NUM_FILTER_COEFF; ++j) {
+                       __m256 f = _mm256_set1_ps(filter[j]);
+                       s = _mm256_fmadd_ps(f, _mm256_load_ps(&pcm[i - j]), s);
+               }
+               _mm256_storeu_ps(&filtered_pcm[i], s);
+       }
+#endif
+       // Do what we couldn't do with AVX (which is everything for non-AVX machines)
+       // as scalar code.
+       for (; i < pcm.size(); ++i) {
                float s = 0.0f;
                for (int j = 0; j < NUM_FILTER_COEFF; ++j) {
                        s += filter[j] * pcm[i - j];
                }
-               filtered_pcm.push_back(s);
+               filtered_pcm[i] = s;
        }
 
        if (output_filtered) {
@@ -380,6 +404,7 @@ std::vector<float> do_rc_filter(const std::vector<float>& pcm, float freq, int s
        return filtered_pcm;
 }
 
+template<bool fast>
 std::vector<pulse> detect_pulses(const std::vector<float> &pcm, float hysteresis_upper_limit, float hysteresis_lower_limit, int sample_rate)
 {
        std::vector<pulse> pulses;
@@ -393,7 +418,7 @@ std::vector<pulse> detect_pulses(const std::vector<float> &pcm, float hysteresis
                } else if (pcm[i] < hysteresis_lower_limit) {
                        if (state == ABOVE) {
                                // down-flank!
-                               double t = find_crossing(pcm, i - 1, hysteresis_lower_limit) * (1.0 / sample_rate) + crop_start;
+                               double t = find_crossing<fast>(pcm, i - 1, hysteresis_lower_limit) * (1.0 / sample_rate) + crop_start;
                                if (last_downflank > 0) {
                                        pulse p;
                                        p.time = t;
@@ -510,8 +535,8 @@ void spsa_train(const std::vector<float> &pcm, int sample_rate)
                        vals2[i] = std::max(std::min(vals[i] + c * p[i], 1.0f), -1.0f);
                }
 
-               std::vector<pulse> pulses1 = detect_pulses(do_fir_filter(pcm, vals1 + 2), vals1[0], vals1[1], sample_rate);
-               std::vector<pulse> pulses2 = detect_pulses(do_fir_filter(pcm, vals2 + 2), vals2[0], vals2[1], sample_rate);
+               std::vector<pulse> pulses1 = detect_pulses<true>(do_fir_filter(pcm, vals1 + 2), vals1[0], vals1[1], sample_rate);
+               std::vector<pulse> pulses2 = detect_pulses<true>(do_fir_filter(pcm, vals2 + 2), vals2[0], vals2[1], sample_rate);
                float badness1 = eval_badness(pulses1, 1.0);
                float badness2 = eval_badness(pulses2, 1.0);
 
@@ -595,7 +620,7 @@ int main(int argc, char **argv)
                exit(0);
        }
 
-       std::vector<pulse> pulses = detect_pulses(pcm, hysteresis_upper_limit, hysteresis_lower_limit, sample_rate);
+       std::vector<pulse> pulses = detect_pulses<false>(pcm, hysteresis_upper_limit, hysteresis_lower_limit, sample_rate);
 
        double calibration_factor = 1.0;
        if (do_calibrate) {