@@ -2284,6 +2284,60 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
22842284 }
22852285}
22862286
2287+ static void log_mel_spectrogram_worker_thread (int ith, const std::vector<float > &hann, const float *samples,
2288+ int n_samples, int fft_size, int fft_step, int n_threads,
2289+ const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
2290+ std::vector<float > fft_in (fft_size, 0.0 );
2291+ std::vector<float > fft_out (2 * fft_size);
2292+ int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2 );
2293+
2294+ for (int i = ith; i < mel.n_len ; i += n_threads) {
2295+ const int offset = i * fft_step;
2296+
2297+ // apply Hanning window
2298+ for (int j = 0 ; j < fft_size; j++) {
2299+ if (offset + j < n_samples) {
2300+ fft_in[j] = hann[j] * samples[offset + j];
2301+ } else {
2302+ fft_in[j] = 0.0 ;
2303+ }
2304+ }
2305+
2306+ // FFT -> mag^2
2307+ fft (fft_in, fft_out);
2308+
2309+ for (int j = 0 ; j < fft_size; j++) {
2310+ fft_out[j] = (fft_out[2 * j + 0 ] * fft_out[2 * j + 0 ] + fft_out[2 * j + 1 ] * fft_out[2 * j + 1 ]);
2311+ }
2312+ for (int j = 1 ; j < fft_size / 2 ; j++) {
2313+ fft_out[j] += fft_out[fft_size - j];
2314+ }
2315+
2316+ if (speed_up) {
2317+ // scale down in the frequency domain results in a speed up in the time domain
2318+ for (int j = 0 ; j < n_fft; j++) {
2319+ fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1 ]);
2320+ }
2321+ }
2322+
2323+ // mel spectrogram
2324+ for (int j = 0 ; j < mel.n_mel ; j++) {
2325+ double sum = 0.0 ;
2326+
2327+ for (int k = 0 ; k < n_fft; k++) {
2328+ sum += fft_out[k] * filters.data [j * n_fft + k];
2329+ }
2330+ if (sum < 1e-10 ) {
2331+ sum = 1e-10 ;
2332+ }
2333+
2334+ sum = log10 (sum);
2335+
2336+ mel.data [j * mel.n_len + i] = sum;
2337+ }
2338+ }
2339+ }
2340+
22872341// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
22882342static bool log_mel_spectrogram (
22892343 whisper_state & wstate,
@@ -2310,81 +2364,22 @@ static bool log_mel_spectrogram(
23102364 mel.n_len = (n_samples)/fft_step;
23112365 mel.data .resize (mel.n_mel *mel.n_len );
23122366
2313- const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2 );
2314-
23152367 // printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
23162368 // printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
23172369
2318- std::vector<std::thread> workers (n_threads);
2319- for (int iw = 0 ; iw < n_threads; ++iw) {
2320- workers[iw] = std::thread ([&](int ith) {
2321- std::vector<float > fft_in;
2322- fft_in.resize (fft_size);
2323- for (int i = 0 ; i < fft_size; i++) {
2324- fft_in[i] = 0.0 ;
2325- }
2326-
2327- std::vector<float > fft_out;
2328- fft_out.resize (2 *fft_size);
2329-
2330- for (int i = ith; i < mel.n_len ; i += n_threads) {
2331- const int offset = i*fft_step;
2332-
2333- // apply Hanning window
2334- for (int j = 0 ; j < fft_size; j++) {
2335- if (offset + j < n_samples) {
2336- fft_in[j] = hann[j]*samples[offset + j];
2337- } else {
2338- fft_in[j] = 0.0 ;
2339- }
2340- }
2341-
2342- // FFT -> mag^2
2343- fft (fft_in, fft_out);
2344-
2345- for (int j = 0 ; j < fft_size; j++) {
2346- fft_out[j] = (fft_out[2 *j + 0 ]*fft_out[2 *j + 0 ] + fft_out[2 *j + 1 ]*fft_out[2 *j + 1 ]);
2347- }
2348- for (int j = 1 ; j < fft_size/2 ; j++) {
2349- // if (i == 0) {
2350- // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
2351- // }
2352- fft_out[j] += fft_out[fft_size - j];
2353- }
2354- if (i == 0 ) {
2355- // for (int j = 0; j < fft_size; j++) {
2356- // printf("%d: %e\n", j, fft_out[j]);
2357- // }
2358- }
2359-
2360- if (speed_up) {
2361- // scale down in the frequency domain results in a speed up in the time domain
2362- for (int j = 0 ; j < n_fft; j++) {
2363- fft_out[j] = 0.5 *(fft_out[2 *j] + fft_out[2 *j + 1 ]);
2364- }
2365- }
2366-
2367- // mel spectrogram
2368- for (int j = 0 ; j < mel.n_mel ; j++) {
2369- double sum = 0.0 ;
2370-
2371- for (int k = 0 ; k < n_fft; k++) {
2372- sum += fft_out[k]*filters.data [j*n_fft + k];
2373- }
2374- if (sum < 1e-10 ) {
2375- sum = 1e-10 ;
2376- }
2377-
2378- sum = log10 (sum);
2379-
2380- mel.data [j*mel.n_len + i] = sum;
2381- }
2382- }
2383- }, iw);
2384- }
2370+ if (n_threads == 1 ) {
2371+ log_mel_spectrogram_worker_thread (0 , hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
2372+ } else {
2373+ std::vector<std::thread> workers (n_threads);
2374+ for (int iw = 0 ; iw < n_threads; ++iw) {
2375+ workers[iw] = std::thread (log_mel_spectrogram_worker_thread, iw, std::cref (hann), samples,
2376+ n_samples, fft_size, fft_step, n_threads,
2377+ std::cref (filters), speed_up, std::ref (mel));
2378+ }
23852379
2386- for (int iw = 0 ; iw < n_threads; ++iw) {
2387- workers[iw].join ();
2380+ for (int iw = 0 ; iw < n_threads; ++iw) {
2381+ workers[iw].join ();
2382+ }
23882383 }
23892384
23902385 // clamping and normalization
0 commit comments