@@ -15,6 +15,8 @@ limitations under the License.
15
15
16
16
#include " tensorflow/lite/micro/micro_interpreter_graph.h"
17
17
18
+ #include < algorithm>
19
+
18
20
#include " flatbuffers/flatbuffers.h" // from @flatbuffers
19
21
#include " tensorflow/lite/c/common.h"
20
22
#include " tensorflow/lite/kernels/internal/compatibility.h"
@@ -42,6 +44,34 @@ const char* OpNameFromRegistration(const TFLMRegistration* registration) {
42
44
}
43
45
}
44
46
47
+ // Check tensor shapes to determine if there are dynamic tensors present.
48
+ // Returns the index of the first dynamic tensor found, otherwise returns -1.
49
+ int CheckDynamicTensors (const TfLiteIntArray* const tensor_indices,
50
+ const TfLiteEvalTensor* const eval_tensors) {
51
+ // some operators have no tensors, so node->inputs and/or node->outputs
52
+ // can be <nullptr>. This occurs in the MicroInterpreter unit tests.
53
+ if (tensor_indices == nullptr ) {
54
+ return -1 ;
55
+ }
56
+
57
+ for (int i = 0 ; i < tensor_indices->size ; i++) {
58
+ const int tensor_index = tensor_indices->data [i];
59
+ // Skip optional tensors
60
+ if (tensor_index < 0 ) {
61
+ continue ;
62
+ }
63
+ // Check shape for dims <= 0.
64
+ // This code handles legacy scalar tensors (dims->size == 0).
65
+ const TfLiteEvalTensor* const tp = eval_tensors + tensor_index;
66
+ if (!std::all_of (tp->dims ->data , tp->dims ->data + tp->dims ->size ,
67
+ [](int dim) { return dim > 0 ; })) {
68
+ return tensor_index;
69
+ }
70
+ }
71
+
72
+ return -1 ;
73
+ }
74
+
45
75
} // namespace
46
76
47
77
MicroInterpreterGraph::MicroInterpreterGraph (
@@ -117,7 +147,7 @@ TfLiteStatus MicroInterpreterGraph::PrepareSubgraphs() {
117
147
if (registration->prepare != nullptr ) {
118
148
TfLiteStatus prepare_status = registration->prepare (context_, node);
119
149
if (prepare_status != kTfLiteOk ) {
120
- MicroPrintf (" Node %s (number %df ) failed to prepare with status %d" ,
150
+ MicroPrintf (" Node %s (number %u ) failed to prepare with status %d" ,
121
151
OpNameFromRegistration (registration),
122
152
current_operator_index_, prepare_status);
123
153
return kTfLiteError ;
@@ -126,6 +156,18 @@ TfLiteStatus MicroInterpreterGraph::PrepareSubgraphs() {
126
156
GetMicroContext (context_)->ResetDecompressionMemoryAllocations ();
127
157
#endif // USE_TFLM_COMPRESSION
128
158
}
159
+
160
+ const int dynamic_tensor_index = CheckDynamicTensors (
161
+ node->outputs , subgraph_allocations_[subgraph_idx].tensors );
162
+ if (dynamic_tensor_index != -1 ) {
163
+ MicroPrintf (
164
+ " Op#%u (%s) of subgraph %u has dynamic tensor #%d\n "
165
+ " Dynamic tensors are not supported" ,
166
+ current_operator_index_, OpNameFromRegistration (registration),
167
+ current_subgraph_index_, dynamic_tensor_index);
168
+ return kTfLiteError ;
169
+ }
170
+
129
171
allocator_->FinishPrepareNodeAllocations (
130
172
/* node_id=*/ current_operator_index_);
131
173
}
@@ -205,6 +247,7 @@ TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
205
247
subgraph_idx, subgraphs_->size ());
206
248
return kTfLiteError ;
207
249
}
250
+ TfLiteStatus invoke_status = kTfLiteOk ;
208
251
uint32_t operators_size = NumSubgraphOperators (model_, subgraph_idx);
209
252
for (current_operator_index_ = 0 ; current_operator_index_ < operators_size;
210
253
++current_operator_index_) {
@@ -226,7 +269,7 @@ TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
226
269
#endif
227
270
228
271
TFLITE_DCHECK (registration->invoke );
229
- TfLiteStatus invoke_status = registration->invoke (context_, node);
272
+ invoke_status = registration->invoke (context_, node);
230
273
#ifdef USE_TFLM_COMPRESSION
231
274
GetMicroContext (context_)->ResetDecompressionMemoryAllocations ();
232
275
#endif // USE_TFLM_COMPRESSION
@@ -243,12 +286,15 @@ TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
243
286
OpNameFromRegistration (registration),
244
287
current_operator_index_, invoke_status);
245
288
}
246
- return invoke_status;
289
+ // make sure to restore subgraph and operator indices
290
+ break ;
247
291
}
248
292
}
293
+
249
294
current_subgraph_index_ = previous_subgraph_idx;
250
295
current_operator_index_ = previous_operator_idx;
251
- return kTfLiteOk ;
296
+
297
+ return invoke_status;
252
298
}
253
299
254
300
TfLiteStatus MicroInterpreterGraph::ResetVariableTensors () {
0 commit comments