@@ -215,12 +215,14 @@ impl<'a> SessionBuilder<'a> {
215
215
let outputs = ( 0 ..num_output_nodes)
216
216
. map ( |i| dangerous:: extract_output ( session_ptr, allocator_ptr, i) )
217
217
. collect :: < Result < Vec < Output > > > ( ) ?;
218
+ let input_ort_values = Vec :: with_capacity ( num_output_nodes as usize ) ;
218
219
219
220
Ok ( Session {
220
221
env : self . env ,
221
222
session_ptr,
222
223
allocator_ptr,
223
224
memory_info,
225
+ input_ort_values,
224
226
inputs,
225
227
outputs,
226
228
} )
@@ -271,12 +273,14 @@ impl<'a> SessionBuilder<'a> {
271
273
let outputs = ( 0 ..num_output_nodes)
272
274
. map ( |i| dangerous:: extract_output ( session_ptr, allocator_ptr, i) )
273
275
. collect :: < Result < Vec < Output > > > ( ) ?;
276
+ let input_ort_values = Vec :: with_capacity ( num_output_nodes as usize ) ;
274
277
275
278
Ok ( Session {
276
279
env : self . env ,
277
280
session_ptr,
278
281
allocator_ptr,
279
282
memory_info,
283
+ input_ort_values,
280
284
inputs,
281
285
outputs,
282
286
} )
@@ -290,6 +294,7 @@ pub struct Session<'a> {
290
294
session_ptr : * mut sys:: OrtSession ,
291
295
allocator_ptr : * mut sys:: OrtAllocator ,
292
296
memory_info : MemoryInfo ,
297
+ input_ort_values : Vec < * const sys:: OrtValue > ,
293
298
/// Information about the ONNX's inputs as stored in loaded file
294
299
pub inputs : Vec < Input > ,
295
300
/// Information about the ONNX's outputs as stored in loaded file
@@ -357,6 +362,26 @@ impl<'a> Drop for Session<'a> {
357
362
}
358
363
359
364
impl < ' a > Session < ' a > {
365
+ /// Somedoc
366
+ pub fn feed < ' s , ' t , ' m , TIn , D > ( & ' s mut self , input_array : Array < TIn , D > ) -> Result < ( ) >
367
+ where
368
+ TIn : TypeToTensorElementDataType + Debug + Clone ,
369
+ D : ndarray:: Dimension ,
370
+ ' m : ' t , // 'm outlives 't (memory info outlives tensor)
371
+ ' s : ' m , // 's outlives 'm (session outlives memory info)
372
+ {
373
+ self . validate_input_shapes ( & input_array) ?;
374
+ // The C API expects pointers for the arrays (pointers to C-arrays)
375
+ let input_ort_tensor: OrtTensor < TIn , D > =
376
+ OrtTensor :: from_array ( & self . memory_info , self . allocator_ptr , input_array) ?;
377
+
378
+ let input_ort_value: * const sys:: OrtValue = input_ort_tensor. c_ptr as * const sys:: OrtValue ;
379
+ std:: mem:: forget ( input_ort_tensor) ;
380
+ self . input_ort_values . push ( input_ort_value) ;
381
+
382
+ Ok ( ( ) )
383
+ }
384
+
360
385
/// Run the input data through the ONNX graph, performing inference.
361
386
///
362
387
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
@@ -371,9 +396,46 @@ impl<'a> Session<'a> {
371
396
' m : ' t , // 'm outlives 't (memory info outlives tensor)
372
397
' s : ' m , // 's outlives 'm (session outlives memory info)
373
398
{
374
- self . validate_input_shapes ( & input_arrays) ?;
375
-
399
+ input_arrays
400
+ . into_iter ( )
401
+ . for_each ( |input_array| self . feed ( input_array) . unwrap ( ) ) ;
402
+ self . inner_run ( )
403
+ }
404
+ /// Run the input data through the ONNX graph, performing inference.
405
+ ///
406
+ /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
407
+ /// used for the input data here.
408
+ pub fn inner_run < ' s , ' t , ' m > (
409
+ & ' s mut self ,
410
+ // input_arrays: Vec<Array<TIn, D>>,
411
+ ) -> Result < Vec < DynOrtTensor < ' m , ndarray:: IxDyn > > >
412
+ where
413
+ ' m : ' t , // 'm outlives 't (memory info outlives tensor)
414
+ ' s : ' m , // 's outlives 'm (session outlives memory info)
415
+ {
376
416
// Build arguments to Run()
417
+ if self . input_ort_values . len ( ) != self . inputs . len ( ) {
418
+ error ! (
419
+ "Non-matching number of inputs: {} (inference) vs {} (model)" ,
420
+ self . input_ort_values. len( ) ,
421
+ self . inputs. len( )
422
+ ) ;
423
+ return Err ( OrtError :: NonMatchingDimensions (
424
+ NonMatchingDimensionsError :: InputsCount {
425
+ inference_input_count : 0 ,
426
+ model_input_count : 0 ,
427
+ // inference_input: input_arrays
428
+ // .iter()
429
+ // .map(|input_array| input_array.shape().to_vec())
430
+ // .collect(),
431
+ // model_input: self
432
+ // .inputs
433
+ // .iter()
434
+ // .map(|input| input.dimensions.clone())
435
+ // .collect(),
436
+ } ,
437
+ ) ) ;
438
+ }
377
439
378
440
let input_names: Vec < String > = self . inputs . iter ( ) . map ( |input| input. name . clone ( ) ) . collect ( ) ;
379
441
let input_names_cstring: Vec < CString > = input_names
@@ -403,33 +465,22 @@ impl<'a> Session<'a> {
403
465
let mut output_tensor_ptrs: Vec < * mut sys:: OrtValue > =
404
466
vec ! [ std:: ptr:: null_mut( ) ; self . outputs. len( ) ] ;
405
467
406
- // The C API expects pointers for the arrays (pointers to C-arrays)
407
- let input_ort_tensors: Vec < OrtTensor < TIn , D > > = input_arrays
408
- . into_iter ( )
409
- . map ( |input_array| {
410
- OrtTensor :: from_array ( & self . memory_info , self . allocator_ptr , input_array)
411
- } )
412
- . collect :: < Result < Vec < OrtTensor < TIn , D > > > > ( ) ?;
413
- let input_ort_values: Vec < * const sys:: OrtValue > = input_ort_tensors
414
- . iter ( )
415
- . map ( |input_array_ort| input_array_ort. c_ptr as * const sys:: OrtValue )
416
- . collect ( ) ;
417
-
418
468
let run_options_ptr: * const sys:: OrtRunOptions = std:: ptr:: null ( ) ;
419
469
420
470
let status = unsafe {
421
471
g_ort ( ) . Run . unwrap ( ) (
422
472
self . session_ptr ,
423
473
run_options_ptr,
424
474
input_names_ptr. as_ptr ( ) ,
425
- input_ort_values. as_ptr ( ) ,
426
- input_ort_values. len ( ) as u64 , // C API expects a u64, not isize
475
+ self . input_ort_values . as_ptr ( ) ,
476
+ self . input_ort_values . len ( ) as u64 , // C API expects a u64, not isize
427
477
output_names_ptr. as_ptr ( ) ,
428
478
output_names_ptr. len ( ) as u64 , // C API expects a u64, not isize
429
479
output_tensor_ptrs. as_mut_ptr ( ) ,
430
480
)
431
481
} ;
432
482
status_to_result ( status) . map_err ( OrtError :: Run ) ?;
483
+ self . input_ort_values . iter ( ) . for_each ( std:: mem:: drop) ;
433
484
434
485
let memory_info_ref = & self . memory_info ;
435
486
let outputs: Result < Vec < DynOrtTensor < ndarray:: Dim < ndarray:: IxDynImpl > > > > =
@@ -494,7 +545,7 @@ impl<'a> Session<'a> {
494
545
// Tensor::from_array(self, array)
495
546
// }
496
547
497
- fn validate_input_shapes < TIn , D > ( & mut self , input_arrays : & [ Array < TIn , D > ] ) -> Result < ( ) >
548
+ fn validate_input_shapes < TIn , D > ( & mut self , input_array : & Array < TIn , D > ) -> Result < ( ) >
498
549
where
499
550
TIn : TypeToTensorElementDataType + Debug + Clone ,
500
551
D : ndarray:: Dimension ,
@@ -504,62 +555,41 @@ impl<'a> Session<'a> {
504
555
// Make sure all dimensions match (except dynamic ones)
505
556
506
557
// Verify length of inputs
507
- if input_arrays. len ( ) != self . inputs . len ( ) {
508
- error ! (
509
- "Non-matching number of inputs: {} (inference) vs {} (model)" ,
510
- input_arrays. len( ) ,
511
- self . inputs. len( )
512
- ) ;
513
- return Err ( OrtError :: NonMatchingDimensions (
514
- NonMatchingDimensionsError :: InputsCount {
515
- inference_input_count : 0 ,
516
- model_input_count : 0 ,
517
- inference_input : input_arrays
518
- . iter ( )
519
- . map ( |input_array| input_array. shape ( ) . to_vec ( ) )
520
- . collect ( ) ,
521
- model_input : self
522
- . inputs
523
- . iter ( )
524
- . map ( |input| input. dimensions . clone ( ) )
525
- . collect ( ) ,
526
- } ,
527
- ) ) ;
528
- }
529
558
530
559
// Verify length of each individual inputs
531
- let inputs_different_length = input_arrays
532
- . iter ( )
533
- . zip ( self . inputs . iter ( ) )
534
- . any ( |( l, r) | l. shape ( ) . len ( ) != r. dimensions . len ( ) ) ;
535
- if inputs_different_length {
560
+ let current_input = self . input_ort_values . len ( ) ;
561
+ if current_input > self . inputs . len ( ) {
536
562
error ! (
537
- "Different input lengths: {:?} vs {:?}" ,
538
- self . inputs, input_arrays
563
+ "Attempting to feed too many inputs, expecting {:?} inputs " ,
564
+ self . inputs. len ( )
539
565
) ;
540
566
panic ! (
541
- "Different input lengths: {:?} vs {:?}" ,
542
- self . inputs, input_arrays
567
+ "Attempting to feed too many inputs, expecting {:?} inputs " ,
568
+ self . inputs. len ( )
543
569
) ;
544
570
}
571
+ let input = & self . inputs [ current_input] ;
572
+ if input_array. shape ( ) . len ( ) != input. dimensions ( ) . count ( ) {
573
+ error ! ( "Different input lengths: {:?} vs {:?}" , input, input_array) ;
574
+ panic ! ( "Different input lengths: {:?} vs {:?}" , input, input_array) ;
575
+ }
545
576
546
- // Verify shape of each individual inputs
547
- let inputs_different_shape = input_arrays. iter ( ) . zip ( self . inputs . iter ( ) ) . any ( |( l, r) | {
548
- let l_shape = l. shape ( ) ;
549
- let r_shape = r. dimensions . as_slice ( ) ;
550
- l_shape. iter ( ) . zip ( r_shape. iter ( ) ) . any ( |( l2, r2) | match r2 {
551
- Some ( r3) => * r3 as usize != * l2,
552
- None => false , // None means dynamic size; in that case shape always match
553
- } )
577
+ let l = input_array;
578
+ let r = input;
579
+ let l_shape = l. shape ( ) ;
580
+ let r_shape = r. dimensions . as_slice ( ) ;
581
+ let inputs_different_shape = l_shape. iter ( ) . zip ( r_shape. iter ( ) ) . any ( |( l2, r2) | match r2 {
582
+ Some ( r3) => * r3 as usize != * l2,
583
+ None => false , // None means dynamic size; in that case shape always match
554
584
} ) ;
555
585
if inputs_different_shape {
556
586
error ! (
557
587
"Different input lengths: {:?} vs {:?}" ,
558
- self . inputs, input_arrays
588
+ self . inputs, input_array
559
589
) ;
560
590
panic ! (
561
591
"Different input lengths: {:?} vs {:?}" ,
562
- self . inputs, input_arrays
592
+ self . inputs, input_array
563
593
) ;
564
594
}
565
595
0 commit comments