@@ -73,6 +73,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
73
73
ruby_whisper_params *rwp;
74
74
rwp = ALLOC (ruby_whisper_params);
75
75
rwp->params = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
76
+ rwp->new_segment_callback = Qnil;
76
77
return Data_Wrap_Struct (klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
77
78
}
78
79
@@ -205,6 +206,28 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
205
206
};
206
207
rwp->params .encoder_begin_callback_user_data = &is_aborted;
207
208
}
209
+ {
210
+ // This cannot be used later because it is not incremented when new_segment_callback is not given.
211
+ static int n_segments = 0 ;
212
+
213
+ rwp->params .new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
214
+ VALUE callback = *(VALUE *)user_data;
215
+ if (NIL_P (callback)){
216
+ return ;
217
+ }
218
+
219
+ for (int i = 0 ; i < n_new; i++) {
220
+ const int i_segment = n_segments + i;
221
+ const char * text = whisper_full_get_segment_text_from_state (state, i_segment);
222
+ // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it
223
+ const int64_t t0 = whisper_full_get_segment_t0_from_state (state, i_segment) * 10 ;
224
+ const int64_t t1 = whisper_full_get_segment_t1_from_state (state, i_segment) * 10 ;
225
+ rb_funcall (callback, rb_intern (" call" ), 4 , rb_str_new2 (text), INT2NUM (t0), INT2NUM (t1), INT2FIX (i_segment));
226
+ }
227
+ n_segments += n_new;
228
+ };
229
+ rwp->params .new_segment_callback_user_data = &rwp->new_segment_callback ;
230
+ }
208
231
209
232
if (whisper_full_parallel (rw->context , rwp->params , pcmf32.data (), pcmf32.size (), 1 ) != 0 ) {
210
233
fprintf (stderr, " failed to process audio\n " );
@@ -365,6 +388,12 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
365
388
rwp->params .n_max_text_ctx = NUM2INT (value);
366
389
return value;
367
390
}
391
+ static VALUE ruby_whisper_params_set_new_segment_callback (VALUE self, VALUE value) {
392
+ ruby_whisper_params *rwp;
393
+ Data_Get_Struct (self, ruby_whisper_params, rwp);
394
+ rwp->new_segment_callback = value;
395
+ return value;
396
+ }
368
397
369
398
void Init_whisper () {
370
399
mWhisper = rb_define_module (" Whisper" );
@@ -412,6 +441,8 @@ void Init_whisper() {
412
441
413
442
rb_define_method (cParams, " max_text_tokens" , ruby_whisper_params_get_max_text_tokens, 0 );
414
443
rb_define_method (cParams, " max_text_tokens=" , ruby_whisper_params_set_max_text_tokens, 1 );
444
+
445
+ rb_define_method (cParams, " new_segment_callback=" , ruby_whisper_params_set_new_segment_callback, 1 );
415
446
}
416
447
#ifdef __cplusplus
417
448
}
0 commit comments