Skip to content

Commit f11527a

Browse files
committed
Attempt at progressively feeding the Session to bypass type checking in
a sane way.
1 parent d2f1ebe commit f11527a

File tree

2 files changed

+93
-63
lines changed

2 files changed

+93
-63
lines changed

onnxruntime/src/error.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,16 @@ pub enum OrtError {
120120
#[derive(Error, Debug)]
121121
pub enum NonMatchingDimensionsError {
122122
/// Number of inputs from model does not match number of inputs from inference call
123-
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
123+
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?}")]
124124
InputsCount {
125125
/// Number of input dimensions used by inference call
126126
inference_input_count: usize,
127127
/// Number of input dimensions defined in model
128128
model_input_count: usize,
129-
/// Input dimensions used by inference call
130-
inference_input: Vec<Vec<usize>>,
131-
/// Input dimensions defined in model
132-
model_input: Vec<Vec<Option<u32>>>,
129+
// Input dimensions used by inference call
130+
// inference_input: Vec<Vec<usize>>,
131+
// Input dimensions defined in model
132+
// model_input: Vec<Vec<Option<u32>>>,
133133
},
134134
}
135135

onnxruntime/src/session.rs

Lines changed: 88 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,14 @@ impl<'a> SessionBuilder<'a> {
215215
let outputs = (0..num_output_nodes)
216216
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
217217
.collect::<Result<Vec<Output>>>()?;
218+
let input_ort_values = Vec::with_capacity(num_output_nodes as usize);
218219

219220
Ok(Session {
220221
env: self.env,
221222
session_ptr,
222223
allocator_ptr,
223224
memory_info,
225+
input_ort_values,
224226
inputs,
225227
outputs,
226228
})
@@ -271,12 +273,14 @@ impl<'a> SessionBuilder<'a> {
271273
let outputs = (0..num_output_nodes)
272274
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
273275
.collect::<Result<Vec<Output>>>()?;
276+
let input_ort_values = Vec::with_capacity(num_output_nodes as usize);
274277

275278
Ok(Session {
276279
env: self.env,
277280
session_ptr,
278281
allocator_ptr,
279282
memory_info,
283+
input_ort_values,
280284
inputs,
281285
outputs,
282286
})
@@ -290,6 +294,7 @@ pub struct Session<'a> {
290294
session_ptr: *mut sys::OrtSession,
291295
allocator_ptr: *mut sys::OrtAllocator,
292296
memory_info: MemoryInfo,
297+
input_ort_values: Vec<*const sys::OrtValue>,
293298
/// Information about the ONNX's inputs as stored in loaded file
294299
pub inputs: Vec<Input>,
295300
/// Information about the ONNX's outputs as stored in loaded file
@@ -357,6 +362,26 @@ impl<'a> Drop for Session<'a> {
357362
}
358363

359364
impl<'a> Session<'a> {
365+
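/// Validate one input array against the next expected model input, convert it to an ORT tensor and queue it for a later `inner_run()`.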
366+
pub fn feed<'s, 't, 'm, TIn, D>(&'s mut self, input_array: Array<TIn, D>) -> Result<()>
367+
where
368+
TIn: TypeToTensorElementDataType + Debug + Clone,
369+
D: ndarray::Dimension,
370+
'm: 't, // 'm outlives 't (memory info outlives tensor)
371+
's: 'm, // 's outlives 'm (session outlives memory info)
372+
{
373+
self.validate_input_shapes(&input_array)?;
374+
// The C API expects pointers for the arrays (pointers to C-arrays)
375+
let input_ort_tensor: OrtTensor<TIn, D> =
376+
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)?;
377+
378+
let input_ort_value: *const sys::OrtValue = input_ort_tensor.c_ptr as *const sys::OrtValue;
379+
std::mem::forget(input_ort_tensor);
380+
self.input_ort_values.push(input_ort_value);
381+
382+
Ok(())
383+
}
384+
360385
/// Run the input data through the ONNX graph, performing inference.
361386
///
362387
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
@@ -371,9 +396,46 @@ impl<'a> Session<'a> {
371396
'm: 't, // 'm outlives 't (memory info outlives tensor)
372397
's: 'm, // 's outlives 'm (session outlives memory info)
373398
{
374-
self.validate_input_shapes(&input_arrays)?;
375-
399+
input_arrays
400+
.into_iter()
401+
.for_each(|input_array| self.feed(input_array).unwrap());
402+
self.inner_run()
403+
}
404+
/// Run the inputs previously queued with `feed()` through the ONNX graph, performing inference.
405+
///
406+
/// Note that ONNX models can have multiple inputs; one tensor per model input
407+
/// must already have been queued with `feed()`, otherwise an error is returned.
408+
pub fn inner_run<'s, 't, 'm>(
409+
&'s mut self,
410+
// input_arrays: Vec<Array<TIn, D>>,
411+
) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>>
412+
where
413+
'm: 't, // 'm outlives 't (memory info outlives tensor)
414+
's: 'm, // 's outlives 'm (session outlives memory info)
415+
{
376416
// Build arguments to Run()
417+
if self.input_ort_values.len() != self.inputs.len() {
418+
error!(
419+
"Non-matching number of inputs: {} (inference) vs {} (model)",
420+
self.input_ort_values.len(),
421+
self.inputs.len()
422+
);
423+
return Err(OrtError::NonMatchingDimensions(
424+
NonMatchingDimensionsError::InputsCount {
425+
inference_input_count: 0,
426+
model_input_count: 0,
427+
// inference_input: input_arrays
428+
// .iter()
429+
// .map(|input_array| input_array.shape().to_vec())
430+
// .collect(),
431+
// model_input: self
432+
// .inputs
433+
// .iter()
434+
// .map(|input| input.dimensions.clone())
435+
// .collect(),
436+
},
437+
));
438+
}
377439

378440
let input_names: Vec<String> = self.inputs.iter().map(|input| input.name.clone()).collect();
379441
let input_names_cstring: Vec<CString> = input_names
@@ -403,33 +465,22 @@ impl<'a> Session<'a> {
403465
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> =
404466
vec![std::ptr::null_mut(); self.outputs.len()];
405467

406-
// The C API expects pointers for the arrays (pointers to C-arrays)
407-
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
408-
.into_iter()
409-
.map(|input_array| {
410-
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
411-
})
412-
.collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
413-
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
414-
.iter()
415-
.map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue)
416-
.collect();
417-
418468
let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
419469

420470
let status = unsafe {
421471
g_ort().Run.unwrap()(
422472
self.session_ptr,
423473
run_options_ptr,
424474
input_names_ptr.as_ptr(),
425-
input_ort_values.as_ptr(),
426-
input_ort_values.len() as u64, // C API expects a u64, not isize
475+
self.input_ort_values.as_ptr(),
476+
self.input_ort_values.len() as u64, // C API expects a u64, not isize
427477
output_names_ptr.as_ptr(),
428478
output_names_ptr.len() as u64, // C API expects a u64, not isize
429479
output_tensor_ptrs.as_mut_ptr(),
430480
)
431481
};
432482
status_to_result(status).map_err(OrtError::Run)?;
483+
self.input_ort_values.iter().for_each(std::mem::drop);
433484

434485
let memory_info_ref = &self.memory_info;
435486
let outputs: Result<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> =
@@ -494,7 +545,7 @@ impl<'a> Session<'a> {
494545
// Tensor::from_array(self, array)
495546
// }
496547

497-
fn validate_input_shapes<TIn, D>(&mut self, input_arrays: &[Array<TIn, D>]) -> Result<()>
548+
fn validate_input_shapes<TIn, D>(&mut self, input_array: &Array<TIn, D>) -> Result<()>
498549
where
499550
TIn: TypeToTensorElementDataType + Debug + Clone,
500551
D: ndarray::Dimension,
@@ -504,62 +555,41 @@ impl<'a> Session<'a> {
504555
// Make sure all dimensions match (except dynamic ones)
505556

506557
// Verify length of inputs
507-
if input_arrays.len() != self.inputs.len() {
508-
error!(
509-
"Non-matching number of inputs: {} (inference) vs {} (model)",
510-
input_arrays.len(),
511-
self.inputs.len()
512-
);
513-
return Err(OrtError::NonMatchingDimensions(
514-
NonMatchingDimensionsError::InputsCount {
515-
inference_input_count: 0,
516-
model_input_count: 0,
517-
inference_input: input_arrays
518-
.iter()
519-
.map(|input_array| input_array.shape().to_vec())
520-
.collect(),
521-
model_input: self
522-
.inputs
523-
.iter()
524-
.map(|input| input.dimensions.clone())
525-
.collect(),
526-
},
527-
));
528-
}
529558

530559
// Verify the number of dimensions of each individual input
531-
let inputs_different_length = input_arrays
532-
.iter()
533-
.zip(self.inputs.iter())
534-
.any(|(l, r)| l.shape().len() != r.dimensions.len());
535-
if inputs_different_length {
560+
let current_input = self.input_ort_values.len();
561+
if current_input > self.inputs.len() {
536562
error!(
537-
"Different input lengths: {:?} vs {:?}",
538-
self.inputs, input_arrays
563+
"Attempting to feed too many inputs, expecting {:?} inputs",
564+
self.inputs.len()
539565
);
540566
panic!(
541-
"Different input lengths: {:?} vs {:?}",
542-
self.inputs, input_arrays
567+
"Attempting to feed too many inputs, expecting {:?} inputs",
568+
self.inputs.len()
543569
);
544570
}
571+
let input = &self.inputs[current_input];
572+
if input_array.shape().len() != input.dimensions().count() {
573+
error!("Different input lengths: {:?} vs {:?}", input, input_array);
574+
panic!("Different input lengths: {:?} vs {:?}", input, input_array);
575+
}
545576

546-
// Verify shape of each individual inputs
547-
let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
548-
let l_shape = l.shape();
549-
let r_shape = r.dimensions.as_slice();
550-
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
551-
Some(r3) => *r3 as usize != *l2,
552-
None => false, // None means dynamic size; in that case shape always match
553-
})
577+
let l = input_array;
578+
let r = input;
579+
let l_shape = l.shape();
580+
let r_shape = r.dimensions.as_slice();
581+
let inputs_different_shape = l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
582+
Some(r3) => *r3 as usize != *l2,
583+
None => false, // None means dynamic size; in that case shape always match
554584
});
555585
if inputs_different_shape {
556586
error!(
557587
"Different input lengths: {:?} vs {:?}",
558-
self.inputs, input_arrays
588+
self.inputs, input_array
559589
);
560590
panic!(
561591
"Different input lengths: {:?} vs {:?}",
562-
self.inputs, input_arrays
592+
self.inputs, input_array
563593
);
564594
}
565595

0 commit comments

Comments
 (0)