1
- /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1
+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0;
33
33
34
34
struct OpData {
35
35
ConcatenationParams params;
36
+
37
+ #ifdef USE_TFLM_COMPRESSION
38
+
39
+ // scratch buffers for compressed tensors
40
+ int scratch_indices[kMaxInputNum ];
41
+
42
+ #endif // USE_TFLM_COMPRESSION
36
43
};
37
44
38
45
// Handles negative axis index, coerces to positive index value.
@@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
52
59
inline void GetAllInputTensorShapes (const TfLiteContext* context,
53
60
const TfLiteNode* node,
54
61
RuntimeShape all_shapes[kMaxInputNum ]) {
55
- TFLITE_DCHECK (context != nullptr );
56
- TFLITE_DCHECK (node != nullptr );
57
62
for (int i = 0 ; i < node->inputs ->size ; ++i) {
58
63
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput (context, node, i);
59
64
RuntimeShape shape = tflite::micro::GetTensorShape (t);
@@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
73
78
template <typename T>
74
79
inline void GetAllInputTensorData (const TfLiteContext* context,
75
80
const TfLiteNode* node,
76
- T* all_data[kMaxInputNum ]) {
77
- TFLITE_DCHECK (context != nullptr );
78
- TFLITE_DCHECK (node != nullptr );
81
+ const T* all_data[kMaxInputNum ]) {
82
+ #ifdef USE_TFLM_COMPRESSION
83
+ const OpData* data = static_cast <const OpData*>(node->user_data );
84
+ MicroContext* micro_context = GetMicroContext (context);
85
+ #endif // USE_TFLM_COMPRESSION
86
+
79
87
for (int i = 0 ; i < node->inputs ->size ; ++i) {
80
88
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput (context, node, i);
89
+ #ifdef USE_TFLM_COMPRESSION
90
+ const CompressionTensorData* comp_td =
91
+ micro_context->GetTensorCompressionData (node, i);
92
+ all_data[i] = tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
93
+ data->scratch_indices [i]);
94
+ #else // USE_TFLM_COMPRESSION
81
95
all_data[i] = tflite::micro::GetTensorData<T>(t);
96
+ #endif // USE_TFLM_COMPRESSION
82
97
}
83
98
}
84
99
@@ -88,16 +103,17 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
88
103
RuntimeShape inputs_shape[kMaxInputNum ];
89
104
const RuntimeShape* inputs_shape_ptr[kMaxInputNum ];
90
105
const data_type* inputs_data[kMaxInputNum ];
106
+ TFLITE_DCHECK (context != nullptr );
107
+ TFLITE_DCHECK (node != nullptr );
108
+ TFLITE_DCHECK (node->user_data != nullptr );
109
+ const OpData* data = static_cast <const OpData*>(node->user_data );
91
110
GetAllInputTensorShapes (context, node, inputs_shape);
92
111
GetShapesPointers (inputs_shape, node->inputs ->size , inputs_shape_ptr);
93
112
GetAllInputTensorData (context, node, inputs_data);
94
113
95
114
TfLiteEvalTensor* output =
96
115
tflite::micro::GetEvalOutput (context, node, kOutputTensor );
97
116
98
- TFLITE_DCHECK (node->user_data != nullptr );
99
- const OpData* data = static_cast <const OpData*>(node->user_data );
100
-
101
117
reference_ops::Concatenation (data->params , inputs_shape_ptr, inputs_data,
102
118
tflite::micro::GetTensorShape (output),
103
119
tflite::micro::GetTensorData<data_type>(output));
@@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
126
142
TfLiteType output_type = output_tensor->type ;
127
143
128
144
micro_context->DeallocateTempTfLiteTensor (input_tensor);
129
- micro_context->DeallocateTempTfLiteTensor (output_tensor);
130
145
131
146
// Check activation and input type
132
147
TF_LITE_ENSURE_EQ (context, params->activation , kTfLiteActNone );
@@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
136
151
input_type == kTfLiteInt64 || input_type == kTfLiteBool );
137
152
138
153
// Output type must match input type
139
- TF_LITE_ENSURE_EQ (context, output_type, input_type);
154
+ TF_LITE_ENSURE_TYPES_EQ (context, output_type, input_type);
140
155
141
156
// This implementation does not support large number of input tensors
142
157
const int num_inputs = NumInputs (node);
143
158
TF_LITE_ENSURE (context, num_inputs <= kMaxInputNum );
144
159
145
- // Shapes with dimensions >4 are not yet supported with static allocation.
160
+ // Calculate OpData.
161
+ TFLITE_DCHECK (node->user_data != nullptr );
162
+ OpData* data = static_cast <OpData*>(node->user_data );
163
+
164
+ // Shapes with dimensions > kMaxSmallSize are not yet supported with static
165
+ // allocation.
146
166
for (int i = 0 ; i < num_inputs; ++i) {
147
167
TfLiteTensor* input = micro_context->AllocateTempInputTensor (node, i);
148
168
TF_LITE_ENSURE (context, input != nullptr );
169
+ TF_LITE_ENSURE_TYPES_EQ (context, input->type , input_type);
149
170
int num_dimensions = NumDimensions (input);
150
171
151
172
if (num_dimensions > RuntimeShape::kMaxSmallSize ) {
@@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
155
176
RuntimeShape::kMaxSmallSize , num_dimensions);
156
177
return kTfLiteError ;
157
178
}
179
+
180
+ if (input_type == kTfLiteInt8 ) {
181
+ // Make sure there is no re-scaling needed for Int8 quantized kernel. This
182
+ // is a restriction we introduced to Int8 kernels.
183
+ TF_LITE_ENSURE_EQ (context, static_cast <double >(input->params .scale ),
184
+ static_cast <double >(output_tensor->params .scale ));
185
+ TF_LITE_ENSURE_EQ (context, input->params .zero_point ,
186
+ output_tensor->params .zero_point );
187
+ } else if (input_type == kTfLiteInt16 ) {
188
+ // Make sure that all Int16 inputs have a null zero-point.
189
+ TF_LITE_ENSURE_EQ (context, input->params .zero_point , 0 );
190
+ }
191
+
192
+ #ifdef USE_TFLM_COMPRESSION
193
+
194
+ // Compression scratch buffers.
195
+ // These will only be allocated if the tensor is compressed.
196
+ data->scratch_indices [i] =
197
+ micro_context->AllocateDecompressionScratchBuffer (node, i);
198
+
199
+ #endif // USE_TFLM_COMPRESSION
200
+
158
201
micro_context->DeallocateTempTfLiteTensor (input);
159
202
}
160
203
161
- // Calculate OpData.
162
- TFLITE_DCHECK (node->user_data != nullptr );
163
- OpData* data = static_cast <OpData*>(node->user_data );
164
-
165
- TfLiteTensor* output =
166
- micro_context->AllocateTempOutputTensor (node, kOutputTensor );
167
- TF_LITE_ENSURE (context, output != nullptr );
204
+ if (input_type == kTfLiteInt16 ) {
205
+ TF_LITE_ENSURE_EQ (context, output_tensor->params .zero_point , 0 );
206
+ }
168
207
169
208
switch (output_type) { // Already know in/outtypes are same.
170
209
case kTfLiteBool :
171
210
case kTfLiteFloat32 :
211
+ case kTfLiteInt8 :
172
212
case kTfLiteInt16 :
173
213
case kTfLiteInt32 :
174
214
case kTfLiteInt64 : {
175
- data->params .axis = CalculatePositiveAxis (params->axis , output);
176
- data->params .inputs_count = node->inputs ->size ;
177
- break ;
178
- }
179
- case kTfLiteInt8 : {
180
- data->params .axis = CalculatePositiveAxis (params->axis , output);
215
+ data->params .axis = CalculatePositiveAxis (params->axis , output_tensor);
181
216
data->params .inputs_count = node->inputs ->size ;
182
-
183
- float * input_scales =
184
- reinterpret_cast <float *>(context->AllocatePersistentBuffer (
185
- context, node->inputs ->size * sizeof (float )));
186
-
187
- int32_t * input_zero_points =
188
- reinterpret_cast <int32_t *>(context->AllocatePersistentBuffer (
189
- context, node->inputs ->size * sizeof (int32_t )));
190
-
191
- // Allocate persistent scale and zeropoint buffers.
192
- // Store input scale and zero point values in OpParams:
193
- for (int i = 0 ; i < node->inputs ->size ; ++i) {
194
- TfLiteTensor* t = micro_context->AllocateTempInputTensor (node, i);
195
- TF_LITE_ENSURE (context, t != nullptr );
196
- input_scales[i] = t->params .scale ;
197
- input_zero_points[i] = t->params .zero_point ;
198
- micro_context->DeallocateTempTfLiteTensor (t);
199
- }
200
-
201
- data->params .input_scale = input_scales;
202
- data->params .input_zeropoint = input_zero_points;
203
- data->params .output_zeropoint = output->params .zero_point ;
204
- data->params .output_scale = output->params .scale ;
205
217
break ;
206
218
}
207
219
default :
208
- MicroPrintf (" Op Concatenation does not currently support Type '%s'." ,
220
+ MicroPrintf (" Op Concatenation does not currently support type '%s'." ,
209
221
TfLiteTypeGetName (output_type));
210
222
return kTfLiteError ;
211
223
}
212
224
213
- micro_context->DeallocateTempTfLiteTensor (output );
225
+ micro_context->DeallocateTempTfLiteTensor (output_tensor );
214
226
215
227
return kTfLiteOk ;
216
228
}
0 commit comments