]> git.sesse.net Git - nageru/blob - nageru/audio_clip.cpp
Make an automated delay estimate, by way of cross-correlation.
[nageru] / nageru / audio_clip.cpp
1 #include "audio_clip.h"
2
3 #include <assert.h>
4 #include <math.h>
5 #include <zita-resampler/resampler.h>
6
7 #include <strings.h>
8
9 extern "C" {
10 #include <libavcodec/avfft.h>
11 }
12
13 #include <complex>
14
15 using namespace std;
16 using namespace std::chrono;
17
18 void AudioClip::clear()
19 {
20         lock_guard<mutex> lock(mu);
21         vals.clear();
22 }
23
24 void AudioClip::add_audio(const float *samples, size_t num_samples, double sample_rate, std::chrono::steady_clock::time_point frame_time)
25 {
26         lock_guard<mutex> lock(mu);
27         if (!vals.empty() && sample_rate != this->sample_rate) {
28                 vals.clear();
29         }
30         if (vals.empty()) {
31                 first_sample = frame_time;
32         }
33         this->sample_rate = sample_rate;
34         vals.insert(vals.end(), samples, samples + num_samples);
35 }
36
37 double AudioClip::get_length_seconds() const
38 {
39         lock_guard<mutex> lock(mu);
40         if (vals.empty()) {
41                 return 0.0;
42         }
43
44         return double(vals.size()) / sample_rate;
45 }
46
47 double AudioClip::get_length_seconds_after_base(std::chrono::steady_clock::time_point base) const
48 {
49         lock_guard<mutex> lock(mu);
50         if (vals.empty() || base < first_sample) {
51                 // NOTE: base < first_sample can happen only during race conditions.
52                 return 0.0;
53         }
54
55         return double(vals.size()) / sample_rate - duration<double>(base - first_sample).count();
56 }
57
58 bool AudioClip::empty() const
59 {
60         lock_guard<mutex> lock(mu);
61         return vals.empty();
62 }
63
64 steady_clock::time_point AudioClip::get_first_sample() const
65 {
66         lock_guard<mutex> lock(mu);
67         assert(!vals.empty());
68         return first_sample;
69 }
70
71 namespace {
72
73 float avg(const vector<float> &vals)
74 {
75         float sum = 0.0f;
76         for (float val : vals) {
77                 sum += val;
78         }
79         return sum / vals.size();
80 }
81
82 void remove_dc(vector<float> *vals)
83 {
84         float dc = avg(*vals);
85         for (float &val : *vals) {
86                 val -= dc;
87         }
88 }
89
90 double find_sumsq(const vector<double> &sum_xx, size_t first, size_t last)
91 {
92         if (first == 0) {
93                 return sum_xx[last - 1];
94         } else {
95                 return sum_xx[last - 1] - sum_xx[first - 1];
96         }
97 }
98
99 vector<double> calc_sumsq_table(const vector<float> &vals)
100 {
101         vector<double> ret;
102         ret.resize(vals.size());
103
104         double sum = 0.0;
105         for (size_t i = 0; i < vals.size(); ++i) {
106                 sum += vals[i] * vals[i];
107                 ret[i] = sum;
108         }
109         return ret;
110 }
111
112 vector<float> resample(const vector<float> &vals, float src_freq, float dst_freq)
113 {
114         Resampler resampler;
115         resampler.setup(src_freq, dst_freq, /*num_channels=*/1, /*hlen=*/32, /*frel=*/1.0);
116
117         // Prefill the sampler so that we don't get any extra delay from it.
118         resampler.inp_data = nullptr;
119         resampler.inp_count = resampler.inpsize() / 2 - 1;
120         resampler.out_data = nullptr;
121         resampler.out_count = 1024;
122         resampler.process();
123
124         vector<float> out;
125         out.resize(lrint(vals.size() * dst_freq / src_freq));
126
127         // Do the actual resampling.
128         resampler.inp_data = const_cast<float *>(vals.data());
129         resampler.inp_count = vals.size();
130         resampler.out_data = &out[0];
131         resampler.out_count = out.size();
132         resampler.process();
133
134         return out;
135 }
136
137 unique_ptr<FFTSample[], decltype(free)*> pad(const vector<float> &x, size_t N)
138 {
139         assert(x.size() <= N);
140
141         // avfft requires AVX alignment.
142         void *ptr;
143         if (posix_memalign(&ptr, 32, N * sizeof(FFTSample)) != 0) {
144                 perror("posix_memalign");
145                 abort();
146         }
147         unique_ptr<FFTSample[], decltype(free)*> ret(reinterpret_cast<FFTSample *>(ptr), free);
148         for (size_t i = 0; i < x.size(); ++i) {
149                 ret[i] = x[i];
150         }
151         for (size_t i = x.size(); i < N; ++i) {
152                 ret[i] = 0.0f;
153         }
154         return ret;
155 }
156
157 unsigned round_up_to_pow2(unsigned x)
158 {
159         --x;
160         x |= x >> 1;
161         x |= x >> 2;
162         x |= x >> 4;
163         x |= x >> 8;
164         x |= x >> 16;
165         ++x;
166         return x;
167 }
168
169 // Calculate the unnormalized circular correlation between x and y, both zero-padded to length N.
170 unique_ptr<FFTSample[], decltype(free)*> pad_and_correlate(const vector<float> &x, const vector<float> &y, unsigned N)
171 {
172         assert((N & (N - 1)) == 0);
173
174         unsigned bits = ffs(N) - 1;
175         unique_ptr<FFTSample[], decltype(free)*> padded_x = pad(x, N);
176         unique_ptr<FFTSample[], decltype(free)*> padded_y = pad(y, N);
177
178         RDFTContext *fft = av_rdft_init(bits, DFT_R2C);
179         av_rdft_calc(fft, &padded_x[0]);
180         av_rdft_calc(fft, &padded_y[0]);
181         av_rdft_end(fft);
182
183         // Now that we have FFT(X) and FFT(Y), we can compute FFT(result) = conj(FFT(X)) * FFT(Y).
184         // We reuse the X array for the result.
185
186         // The two first elements are the real values of the lowest (DC) and highest bin.
187         // (These have zero imaginary value, so that's not stored anywhere.)
188         padded_x[0] *= padded_y[0];
189         padded_x[1] *= padded_y[1];
190
191         // The remaining elements are complex values for each bin.
192         for (size_t i = 1; i < N / 2; ++i) {
193                 complex<float> xc(padded_x[i * 2 + 0], padded_x[i * 2 + 1]);
194                 complex<float> yc(padded_y[i * 2 + 0], padded_y[i * 2 + 1]);
195                 complex<float> p = conj(xc) * yc;
196                 padded_x[i * 2 + 0] = p.real();
197                 padded_x[i * 2 + 1] = p.imag();
198         }
199
200         RDFTContext *ifft = av_rdft_init(bits, IDFT_C2R);
201         av_rdft_calc(ifft, &padded_x[0]);
202         av_rdft_end(ifft);
203
204         return padded_x;
205 }
206
207 } // namespace
208
209 AudioClip::BestCorrelation AudioClip::find_best_correlation(const AudioClip *reference) const
210 {
211         // We estimate the delay between the two clips by way of (normalized) cross-correlation;
212         // if they are perfectly in sync, the cross-correlation will have its maximum at τ=0,
213         // if they're 100 samples off, the maximum will be at τ=-100 or τ=+100
214         // (depending on which one is later), and so on. This gives us single-sample accuracy,
215         // which is good enough for our use, and it reasonably cheap and simple to compute.
216         //
217         // The definition of the cross-correlation (where all sums are over the entire range i) is
218         //
219         //  R_xy(τ) = sum(x[i] y[i + τ]) / sqrt(sum(x[i]²) sum(y[i]²))
220         //
221         // This assumes a zero mean; if not, the signals will need to be normalized first.
222         // (We do this below.) It also assumes real-valued signals.
223         //
224         // We want to evaluate this over all τ as long as there is reasonable overlap,
225         // truncating the signals to the part where there is overlap (so the sums of x[i]²
226         // and y[i]² will also depend on τ, even though it isn't apparent in the formula).
227         // We could have done this by brute force; it would take about 70 ms, or a bit less
228         // with AVX-optimized sums. However, we can optimize it significantly with some
229         // standard algorithmic trickery:
230         //
231         // First of all, for the squared sums, we can simply use sum tables; compute the
232         // cumulative sum of x[i]² and y[i]², and then we can compute sum(x[i]², i=a..b)
233         // for any a and b by just looking up in the table and doing a subtraction.
234         //
235         // Calculating the sum_xy[τ] = sum(x[i] y[i + τ]) part efficiently for all τ
236         // requires a little signal processing. We use the identity that
237         //
238         //    FFT(sum_xy) = conj(FFT(x)) * FFT(y)
239         //
240         // or equivalently
241         //
242         //    sum_xy = IFFT(conj(FFT(x)) * FFT(y))
243         //
244         // where conj() is complex conjugate and * is (complex) pointwise multiplication.
245         // Since FFT and IFFT can be calculated in O(n log n) time, and we already link to ffmpeg,
246         // which supplies a reasonably efficient implementation, this gives us a significant
247         // speedup (down to 3–4 ms or less) and a convenient way to calculate sum_xy.
248         // FFT gives us _circular_ convolution, so sum_xy[0] is for τ=0, sum_xy[1] is for τ=1,
249         // sum_xy[N - 1] is for τ=-1 and so on. This also means we'll need to take care to
250         // zero-pad our signals so that we don't get values corrupted by wraparound.
251         // (We also need to pad to a power of two, since ffmpeg's FFT doesn't support other sizes.)
252
253         // Resample to 16 kHz, so we have a consistent sample rate between the two.
254         // It also helps performance (and removes HF noise, although noise
255         // shouldn't be a big problem for the algorithm). Note that ffmpeg's
256         // FFT drops down to a less optimized implementation for N > 65536;
257         // with our choice of parameters, N = 32768 typically, so we should be fine.
258         //
259         // We don't do sub-sample precision, so this takes us to 0.06 ms, which
260         // is more than good enough.
261         constexpr float freq_hz = 16000.0;
262         vector<float> x = resample(vals, sample_rate, freq_hz);
263         vector<float> y = resample(reference->vals, sample_rate, freq_hz);
264
265         // There should normally be no DC, but let's be sure.
266         remove_dc(&x);
267         remove_dc(&y);
268
269         // Truncate the clips so that they start at the same point.
270         if (first_sample < reference->first_sample) {
271                 int trunc_samples = lrintf(duration<double>(reference->first_sample - first_sample).count() * freq_hz);
272                 assert(trunc_samples >= 0);
273                 if (size_t(trunc_samples) >= x.size()) {
274                         return { 0.0, 0.0f / 0.0f };
275                 }
276                 x.erase(x.begin(), x.begin() + trunc_samples);
277         } else if (reference->first_sample < first_sample) {
278                 int trunc_samples = lrintf(duration<double>(first_sample - reference->first_sample).count() * freq_hz);
279                 assert(trunc_samples >= 0);
280                 if (size_t(trunc_samples) >= y.size()) {
281                         return { 0.0, 0.0f / 0.0f };
282                 }
283                 y.erase(y.begin(), y.begin() + trunc_samples);
284         }
285
286         // Truncate the clips to one second.
287         if (x.size() > int(freq_hz)) x.resize(int(freq_hz));
288         if (y.size() > int(freq_hz)) y.resize(int(freq_hz));
289
290         // Find the cumulative sum of squares. Doing this once will allow us to
291         // find sum(x[i]², i=a..b) essentially for free in all the iterations below.
292         vector<double> cumsum_xx = calc_sumsq_table(x);
293         vector<double> cumsum_yy = calc_sumsq_table(y);
294
295         unsigned N = round_up_to_pow2(x.size() + y.size() - 1);
296         unique_ptr<FFTSample[], decltype(free)*> padded_sum_xy = pad_and_correlate(x, y, N);
297
298         // We only search for delays such that there's at least 100 ms overlap between them.
299         // Less than that, and the chance of a spurious match feels rather high
300         // (and 900 ms delay measured on a one-second window sounds pretty excessive!).
301         constexpr int min_overlap = freq_hz * 0.1;
302         BestCorrelation best_corr { 0.0, 0.0f / 0.0f };  // NaN.
303
304         // First check candidates where y (the reference) has to be moved later to fit
305         // (ie., x, ourselves, is delayed -- positive delay).
306         for (size_t delay_y = 1; delay_y < x.size() - min_overlap; ++delay_y) {
307                 size_t overlap = std::min(x.size() - delay_y, y.size());
308                 float sum_xy = padded_sum_xy[N - int(delay_y)] * (2.0 / N);
309                 float sum_xx = find_sumsq(cumsum_xx, delay_y, delay_y + overlap);
310                 float sum_yy = find_sumsq(cumsum_yy, 0, overlap);
311                 float corr = sum_xy / sqrt(sum_xx * sum_yy);
312
313                 if (isnan(best_corr.correlation) || fabs(corr) > fabs(best_corr.correlation)) {
314                         best_corr = BestCorrelation { int(delay_y) * 1000.0f / freq_hz, corr };
315                 }
316         }
317
318         // Then where x (ourselves) has to be moved later to fit (ie., negative delay).
319         for (size_t delay_x = 0; delay_x < y.size() - min_overlap; ++delay_x) {
320                 size_t overlap = std::min(x.size(), y.size() - delay_x);
321                 float sum_xy = padded_sum_xy[delay_x] * (2.0 / N);
322                 float sum_xx = find_sumsq(cumsum_xx, 0, overlap);
323                 float sum_yy = find_sumsq(cumsum_yy, delay_x, delay_x + overlap);
324                 float corr = sum_xy / sqrt(sum_xx * sum_yy);
325
326                 if (isnan(best_corr.correlation) || fabs(corr) > fabs(best_corr.correlation)) {
327                         best_corr = BestCorrelation { -int(delay_x) * 1000.0f / freq_hz, corr };
328                 }
329         }
330
331         return best_corr;
332 }
333
334 unique_ptr<pair<float, float>[]> AudioClip::get_min_max_peaks(unsigned width, steady_clock::time_point base) const
335 {
336         unique_ptr<pair<float, float>[]> min_max(new pair<float, float>[width]);
337         for (unsigned x = 0; x < width; ++x) {
338                 min_max[x].first = min_max[x].second = 0.0 / 0.0;  // NaN.
339         }
340
341         lock_guard<mutex> lock(mu);
342         double skip_samples = duration<double>(base - first_sample).count() * sample_rate;
343         for (size_t i = int(floor(skip_samples)); i < vals.size(); ++i) {
344                 // We display one second.
345                 int x = lrint((i - skip_samples) * (double(width) / sample_rate));
346                 if (x < 0 || x >= int(width)) continue;
347                 if (isnan(min_max[x].first)) {
348                         min_max[x].first = min_max[x].second = 0.0;
349                 }
350                 min_max[x].first = min(min_max[x].first, vals[i]);
351                 min_max[x].second = max(min_max[x].second, vals[i]);
352         }
353
354         return min_max;
355 }