X-Git-Url: https://git.sesse.net/?p=c64tapwav;a=blobdiff_plain;f=decode.cpp;h=d2d6fc89a26aeecee613f6ddb446d35571d35365;hp=a90ae8e6944167dc6051fb69c1f7e3f3450d8f2e;hb=HEAD;hpb=31b7a4c1f1b1a90c0598f67b9f859fc66c53debb diff --git a/decode.cpp b/decode.cpp index a90ae8e..d2d6fc8 100644 --- a/decode.cpp +++ b/decode.cpp @@ -7,6 +7,9 @@ #include #include #include +#ifdef __AVX__ +#include +#endif #include #include @@ -64,20 +67,27 @@ static float auto_level_freq = 200.0; static float min_level = 0.05f; // search for the value between [x,x+1] +template double find_crossing(const std::vector &pcm, int x, float limit) { - double upper = x; - double lower = x + 1; - while (lower - upper > 1e-3) { - double mid = 0.5f * (upper + lower); - if (lanczos_interpolate(pcm, mid) > limit) { - upper = mid; - } else { - lower = mid; + if (fast) { + // Do simple linear interpolation. + return x + (limit - pcm[x]) / (pcm[x + 1] - pcm[x]); + } else { + // Binary search for the zero crossing as given by Lanczos interpolation. + double upper = x; + double lower = x + 1; + while (lower - upper > 1e-3) { + double mid = 0.5f * (upper + lower); + if (lanczos_interpolate(pcm, mid) > limit) { + upper = mid; + } else { + lower = mid; + } } - } - return 0.5f * (upper + lower); + return 0.5f * (upper + lower); + } } struct pulse { @@ -185,6 +195,7 @@ static struct option long_options[] = { {"rc-filter", required_argument, 0, 'r' }, {"output-filtered", 0, 0, 'F' }, {"crop", required_argument, 0, 'c' }, + {"train", required_argument, 0, 't' }, {"quiet", 0, 0, 'q' }, {"help", 0, 0, 'h' }, {0, 0, 0, 0 } @@ -328,17 +339,30 @@ std::vector crop(const std::vector& pcm, float crop_start, float c return std::vector(pcm.begin() + start_sample, pcm.begin() + end_sample); } -// TODO: Support AVX here. std::vector do_fir_filter(const std::vector& pcm, const float* filter) { std::vector filtered_pcm; - filtered_pcm.reserve(pcm.size()); - for (unsigned i = NUM_FILTER_COEFF; i < pcm.size(); ++i) { + filtered_pcm.resize(pcm.size()); + unsigned i = NUM_FILTER_COEFF; +#ifdef __AVX__ + unsigned avx_end = i + ((pcm.size() - i) & ~7); + for ( ; i < avx_end; i += 8) { + __m256 s = _mm256_setzero_ps(); + for (int j = 0; j < NUM_FILTER_COEFF; ++j) { + __m256 f = _mm256_set1_ps(filter[j]); + s = _mm256_fmadd_ps(f, _mm256_load_ps(&pcm[i - j]), s); + } + _mm256_storeu_ps(&filtered_pcm[i], s); + } +#endif + // Do what we couldn't do with AVX (which is everything for non-AVX machines) + // as scalar code. + for (; i < pcm.size(); ++i) { float s = 0.0f; for (int j = 0; j < NUM_FILTER_COEFF; ++j) { s += filter[j] * pcm[i - j]; } - filtered_pcm.push_back(s); + filtered_pcm[i] = s; } if (output_filtered) { @@ -380,6 +404,7 @@ std::vector do_rc_filter(const std::vector& pcm, float freq, int s return filtered_pcm; } +template std::vector detect_pulses(const std::vector &pcm, float hysteresis_upper_limit, float hysteresis_lower_limit, int sample_rate) { std::vector pulses; @@ -393,7 +418,7 @@ std::vector detect_pulses(const std::vector &pcm, float hysteresis } else if (pcm[i] < hysteresis_lower_limit) { if (state == ABOVE) { // down-flank! - double t = find_crossing(pcm, i - 1, hysteresis_lower_limit) * (1.0 / sample_rate) + crop_start; + double t = find_crossing(pcm, i - 1, hysteresis_lower_limit) * (1.0 / sample_rate) + crop_start; if (last_downflank > 0) { pulse p; p.time = t; @@ -510,8 +535,8 @@ void spsa_train(const std::vector &pcm, int sample_rate) vals2[i] = std::max(std::min(vals[i] + c * p[i], 1.0f), -1.0f); } - std::vector pulses1 = detect_pulses(do_fir_filter(pcm, vals1 + 2), vals1[0], vals1[1], sample_rate); - std::vector pulses2 = detect_pulses(do_fir_filter(pcm, vals2 + 2), vals2[0], vals2[1], sample_rate); + std::vector pulses1 = detect_pulses(do_fir_filter(pcm, vals1 + 2), vals1[0], vals1[1], sample_rate); + std::vector pulses2 = detect_pulses(do_fir_filter(pcm, vals2 + 2), vals2[0], vals2[1], sample_rate); float badness1 = eval_badness(pulses1, 1.0); float badness2 = eval_badness(pulses2, 1.0); @@ -528,11 +553,11 @@ void spsa_train(const std::vector &pcm, int sample_rate) std::swap(pulses1, pulses2); } if (badness1 < best_badness) { - printf("\nNew best filter (badness=%f):", badness1); + fprintf(stderr, "\nNew best filter (badness=%f):", badness1); for (int i = 0; i < NUM_FILTER_COEFF; ++i) { - printf(" %.5f", vals1[i + 2]); + fprintf(stderr, " %.5f", vals1[i + 2]); } - printf(", hysteresis limits = %f %f\n", vals1[0], vals1[1]); + fprintf(stderr, ", hysteresis limits = %f %f\n", vals1[0], vals1[1]); best_badness = badness1; find_kmeans(pulses1, 1.0, train_snap_points); @@ -541,8 +566,8 @@ void spsa_train(const std::vector &pcm, int sample_rate) output_cycle_plot(pulses1, 1.0); } } - printf("%d ", n); - fflush(stdout); + fprintf(stderr, "%d ", n); + fflush(stderr); } } @@ -595,7 +620,7 @@ int main(int argc, char **argv) exit(0); } - std::vector pulses = detect_pulses(pcm, hysteresis_upper_limit, hysteresis_lower_limit, sample_rate); + std::vector pulses = detect_pulses(pcm, hysteresis_upper_limit, hysteresis_lower_limit, sample_rate); double calibration_factor = 1.0; if (do_calibrate) {