Skip to content

Commit ba96712

Browse files
committed
Merge branch 'ci/sync_gh_tflite-micro' into 'master'
Sync esp-tflite-micro from GitHub (upstream commit 870924). See merge request app-frameworks/esp-tflite-micro!162.
2 parents 14079aa + 1658fbb commit ba96712
2 parents 14079aa + 1658fbb commit ba96712

14 files changed: +546 additions, -134 deletions (lines changed)

tensorflow/lite/micro/kernels/assign_variable.cc

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "tensorflow/lite/micro/micro_graph.h"
2727
#include "tensorflow/lite/micro/micro_log.h"
2828
#include "tensorflow/lite/micro/micro_resource_variable.h"
29+
#include "tensorflow/lite/micro/micro_utils.h"
2930
#include "tensorflow/lite/schema/schema_generated.h"
3031

3132
namespace tflite {
@@ -35,6 +36,20 @@ namespace {
3536
constexpr int kInputVariableId = 0;
3637
constexpr int kInputValue = 1;
3738

39+
#ifdef USE_TFLM_COMPRESSION
40+
41+
struct OpData {
42+
// scratch buffer for compressed input tensor
43+
int scratch_index;
44+
};
45+
46+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
47+
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
48+
return context->AllocatePersistentBuffer(context, sizeof(OpData));
49+
}
50+
51+
#endif // USE_TFLM_COMPRESSION
52+
3853
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
3954
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
4055
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
@@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
7085
context, input_value));
7186
}
7287

88+
#ifdef USE_TFLM_COMPRESSION
89+
90+
TFLITE_DCHECK(node->user_data != nullptr);
91+
OpData* data = static_cast<OpData*>(node->user_data);
92+
// Compression scratch buffers.
93+
// These will only be allocated if the tensor is compressed.
94+
data->scratch_index =
95+
micro_context->AllocateDecompressionScratchBuffer(node, kInputValue);
96+
97+
#endif // USE_TFLM_COMPRESSION
98+
7399
micro_context->DeallocateTempTfLiteTensor(input_value);
74100
return kTfLiteOk;
75101
}
@@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
93119
"ResourceVariables and pass it to the interpreter.");
94120
return kTfLiteError;
95121
}
122+
123+
#ifdef USE_TFLM_COMPRESSION
124+
OpData* data = static_cast<OpData*>(node->user_data);
125+
const CompressionTensorData* comp_td =
126+
micro_context->GetTensorCompressionData(node, kInputValue);
127+
const void* buffer = tflite::micro::GetTensorData<void>(
128+
micro_context, input_value, comp_td, data->scratch_index);
129+
#else // USE_TFLM_COMPRESSION
130+
const void* buffer = tflite::micro::GetTensorData<void>(input_value);
131+
#endif // USE_TFLM_COMPRESSION
132+
96133
TF_LITE_ENSURE_OK(context,
97-
resources->Assign(input_id->data.i32[0], input_value));
134+
resources->Assign(input_id->data.i32[0],
135+
EvalTensorBytes(input_value), buffer));
98136
return kTfLiteOk;
99137
}
100138

101139
} // namespace.
102140

141+
#ifdef USE_TFLM_COMPRESSION
142+
143+
TFLMRegistration Register_ASSIGN_VARIABLE() {
144+
return tflite::micro::RegisterOp(Init, Prepare, Eval);
145+
146+
#else // USE_TFLM_COMPRESSION
147+
103148
TFLMRegistration Register_ASSIGN_VARIABLE() {
104149
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
150+
151+
#endif // USE_TFLM_COMPRESSION
105152
}
106153

107154
} // namespace tflite

tensorflow/lite/micro/kernels/concatenation.cc

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0;
3333

3434
struct OpData {
3535
ConcatenationParams params;
36+
37+
#ifdef USE_TFLM_COMPRESSION
38+
39+
// scratch buffers for compressed tensors
40+
int scratch_indices[kMaxInputNum];
41+
42+
#endif // USE_TFLM_COMPRESSION
3643
};
3744

3845
// Handles negative axis index, coerces to positive index value.
@@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
5259
inline void GetAllInputTensorShapes(const TfLiteContext* context,
5360
const TfLiteNode* node,
5461
RuntimeShape all_shapes[kMaxInputNum]) {
55-
TFLITE_DCHECK(context != nullptr);
56-
TFLITE_DCHECK(node != nullptr);
5762
for (int i = 0; i < node->inputs->size; ++i) {
5863
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
5964
RuntimeShape shape = tflite::micro::GetTensorShape(t);
@@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
7378
template <typename T>
7479
inline void GetAllInputTensorData(const TfLiteContext* context,
7580
const TfLiteNode* node,
76-
T* all_data[kMaxInputNum]) {
77-
TFLITE_DCHECK(context != nullptr);
78-
TFLITE_DCHECK(node != nullptr);
81+
const T* all_data[kMaxInputNum]) {
82+
#ifdef USE_TFLM_COMPRESSION
83+
const OpData* data = static_cast<const OpData*>(node->user_data);
84+
MicroContext* micro_context = GetMicroContext(context);
85+
#endif // USE_TFLM_COMPRESSION
86+
7987
for (int i = 0; i < node->inputs->size; ++i) {
8088
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
89+
#ifdef USE_TFLM_COMPRESSION
90+
const CompressionTensorData* comp_td =
91+
micro_context->GetTensorCompressionData(node, i);
92+
all_data[i] = tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
93+
data->scratch_indices[i]);
94+
#else // USE_TFLM_COMPRESSION
8195
all_data[i] = tflite::micro::GetTensorData<T>(t);
96+
#endif // USE_TFLM_COMPRESSION
8297
}
8398
}
8499

@@ -88,16 +103,17 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
88103
RuntimeShape inputs_shape[kMaxInputNum];
89104
const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
90105
const data_type* inputs_data[kMaxInputNum];
106+
TFLITE_DCHECK(context != nullptr);
107+
TFLITE_DCHECK(node != nullptr);
108+
TFLITE_DCHECK(node->user_data != nullptr);
109+
const OpData* data = static_cast<const OpData*>(node->user_data);
91110
GetAllInputTensorShapes(context, node, inputs_shape);
92111
GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
93112
GetAllInputTensorData(context, node, inputs_data);
94113

95114
TfLiteEvalTensor* output =
96115
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
97116

98-
TFLITE_DCHECK(node->user_data != nullptr);
99-
const OpData* data = static_cast<const OpData*>(node->user_data);
100-
101117
reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
102118
tflite::micro::GetTensorShape(output),
103119
tflite::micro::GetTensorData<data_type>(output));
@@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
126142
TfLiteType output_type = output_tensor->type;
127143

128144
micro_context->DeallocateTempTfLiteTensor(input_tensor);
129-
micro_context->DeallocateTempTfLiteTensor(output_tensor);
130145

131146
// Check activation and input type
132147
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
136151
input_type == kTfLiteInt64 || input_type == kTfLiteBool);
137152

138153
// Output type must match input type
139-
TF_LITE_ENSURE_EQ(context, output_type, input_type);
154+
TF_LITE_ENSURE_TYPES_EQ(context, output_type, input_type);
140155

141156
// This implementation does not support large number of input tensors
142157
const int num_inputs = NumInputs(node);
143158
TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);
144159

145-
// Shapes with dimensions >4 are not yet supported with static allocation.
160+
// Calculate OpData.
161+
TFLITE_DCHECK(node->user_data != nullptr);
162+
OpData* data = static_cast<OpData*>(node->user_data);
163+
164+
// Shapes with dimensions > kMaxSmallSize are not yet supported with static
165+
// allocation.
146166
for (int i = 0; i < num_inputs; ++i) {
147167
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
148168
TF_LITE_ENSURE(context, input != nullptr);
169+
TF_LITE_ENSURE_TYPES_EQ(context, input->type, input_type);
149170
int num_dimensions = NumDimensions(input);
150171

151172
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
@@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
155176
RuntimeShape::kMaxSmallSize, num_dimensions);
156177
return kTfLiteError;
157178
}
179+
180+
if (input_type == kTfLiteInt8) {
181+
// Make sure there is no re-scaling needed for Int8 quantized kernel. This
182+
// is a restriction we introduced to Int8 kernels.
183+
TF_LITE_ENSURE_EQ(context, static_cast<double>(input->params.scale),
184+
static_cast<double>(output_tensor->params.scale));
185+
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
186+
output_tensor->params.zero_point);
187+
} else if (input_type == kTfLiteInt16) {
188+
// Make sure that all Int16 inputs have a null zero-point.
189+
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
190+
}
191+
192+
#ifdef USE_TFLM_COMPRESSION
193+
194+
// Compression scratch buffers.
195+
// These will only be allocated if the tensor is compressed.
196+
data->scratch_indices[i] =
197+
micro_context->AllocateDecompressionScratchBuffer(node, i);
198+
199+
#endif // USE_TFLM_COMPRESSION
200+
158201
micro_context->DeallocateTempTfLiteTensor(input);
159202
}
160203

161-
// Calculate OpData.
162-
TFLITE_DCHECK(node->user_data != nullptr);
163-
OpData* data = static_cast<OpData*>(node->user_data);
164-
165-
TfLiteTensor* output =
166-
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
167-
TF_LITE_ENSURE(context, output != nullptr);
204+
if (input_type == kTfLiteInt16) {
205+
TF_LITE_ENSURE_EQ(context, output_tensor->params.zero_point, 0);
206+
}
168207

169208
switch (output_type) { // Already know in/outtypes are same.
170209
case kTfLiteBool:
171210
case kTfLiteFloat32:
211+
case kTfLiteInt8:
172212
case kTfLiteInt16:
173213
case kTfLiteInt32:
174214
case kTfLiteInt64: {
175-
data->params.axis = CalculatePositiveAxis(params->axis, output);
176-
data->params.inputs_count = node->inputs->size;
177-
break;
178-
}
179-
case kTfLiteInt8: {
180-
data->params.axis = CalculatePositiveAxis(params->axis, output);
215+
data->params.axis = CalculatePositiveAxis(params->axis, output_tensor);
181216
data->params.inputs_count = node->inputs->size;
182-
183-
float* input_scales =
184-
reinterpret_cast<float*>(context->AllocatePersistentBuffer(
185-
context, node->inputs->size * sizeof(float)));
186-
187-
int32_t* input_zero_points =
188-
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
189-
context, node->inputs->size * sizeof(int32_t)));
190-
191-
// Allocate persistent scale and zeropoint buffers.
192-
// Store input scale and zero point values in OpParams:
193-
for (int i = 0; i < node->inputs->size; ++i) {
194-
TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i);
195-
TF_LITE_ENSURE(context, t != nullptr);
196-
input_scales[i] = t->params.scale;
197-
input_zero_points[i] = t->params.zero_point;
198-
micro_context->DeallocateTempTfLiteTensor(t);
199-
}
200-
201-
data->params.input_scale = input_scales;
202-
data->params.input_zeropoint = input_zero_points;
203-
data->params.output_zeropoint = output->params.zero_point;
204-
data->params.output_scale = output->params.scale;
205217
break;
206218
}
207219
default:
208-
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
220+
MicroPrintf("Op Concatenation does not currently support type '%s'.",
209221
TfLiteTypeGetName(output_type));
210222
return kTfLiteError;
211223
}
212224

213-
micro_context->DeallocateTempTfLiteTensor(output);
225+
micro_context->DeallocateTempTfLiteTensor(output_tensor);
214226

215227
return kTfLiteOk;
216228
}

tensorflow/lite/micro/kernels/conv.cc

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -45,15 +45,35 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
4545
TFLITE_DCHECK(node->user_data != nullptr);
4646
const auto& data = *(static_cast<const OpDataConv*>(node->user_data));
4747

48+
#ifdef USE_TFLM_COMPRESSION
49+
50+
MicroContext* micro_context = GetMicroContext(context);
51+
52+
const CompressionTensorData* weights_comp_td =
53+
micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
54+
const CompressionTensorData* bias_comp_td =
55+
micro_context->GetTensorCompressionData(node, kConvBiasTensor);
56+
57+
#endif // USE_TFLM_COMPRESSION
58+
4859
switch (input->type) { // Already know in/out types are same.
4960
case kTfLiteFloat32: {
5061
tflite::reference_ops::Conv(
5162
ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
5263
tflite::micro::GetTensorData<float>(input),
5364
tflite::micro::GetTensorShape(filter),
65+
#ifdef USE_TFLM_COMPRESSION
66+
tflite::micro::GetTensorData<float>(micro_context, filter,
67+
weights_comp_td,
68+
data.weights_scratch_index),
69+
tflite::micro::GetTensorShape(bias),
70+
tflite::micro::GetOptionalTensorData<float>(
71+
micro_context, bias, bias_comp_td, data.bias_scratch_index),
72+
#else // USE_TFLM_COMPRESSION
5473
tflite::micro::GetTensorData<float>(filter),
5574
tflite::micro::GetTensorShape(bias),
5675
tflite::micro::GetOptionalTensorData<float>(bias),
76+
#endif // USE_TFLM_COMPRESSION
5777
tflite::micro::GetTensorShape(output),
5878
tflite::micro::GetTensorData<float>(output),
5979
tflite::micro::GetTensorShape(nullptr), nullptr);
@@ -67,9 +87,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
6787
tflite::micro::GetTensorShape(input),
6888
tflite::micro::GetTensorData<int16_t>(input),
6989
tflite::micro::GetTensorShape(filter),
90+
#ifdef USE_TFLM_COMPRESSION
91+
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
92+
weights_comp_td,
93+
data.weights_scratch_index),
94+
tflite::micro::GetTensorShape(bias),
95+
tflite::micro::GetOptionalTensorData<int32_t>(
96+
micro_context, bias, bias_comp_td, data.bias_scratch_index),
97+
#else // USE_TFLM_COMPRESSION
7098
tflite::micro::GetTensorData<int8_t>(filter),
7199
tflite::micro::GetTensorShape(bias),
72100
tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
101+
#endif // USE_TFLM_COMPRESSION
73102
tflite::micro::GetTensorShape(output),
74103
tflite::micro::GetTensorData<int16_t>(output));
75104
} else if (bias->type == kTfLiteInt64) {
@@ -79,9 +108,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
79108
tflite::micro::GetTensorShape(input),
80109
tflite::micro::GetTensorData<int16_t>(input),
81110
tflite::micro::GetTensorShape(filter),
111+
#ifdef USE_TFLM_COMPRESSION
112+
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
113+
weights_comp_td,
114+
data.weights_scratch_index),
115+
tflite::micro::GetTensorShape(bias),
116+
tflite::micro::GetTensorData<int64_t>(
117+
micro_context, bias, bias_comp_td, data.bias_scratch_index),
118+
#else // USE_TFLM_COMPRESSION
82119
tflite::micro::GetTensorData<int8_t>(filter),
83120
tflite::micro::GetTensorShape(bias),
84-
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
121+
tflite::micro::GetTensorData<std::int64_t>(bias),
122+
#endif // USE_TFLM_COMPRESSION
85123
tflite::micro::GetTensorShape(output),
86124
tflite::micro::GetTensorData<int16_t>(output));
87125
} else {
@@ -119,9 +157,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
119157
tflite::micro::GetTensorShape(input),
120158
tflite::micro::GetTensorData<int8_t>(input),
121159
tflite::micro::GetTensorShape(filter),
160+
#ifdef USE_TFLM_COMPRESSION
161+
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
162+
weights_comp_td,
163+
data.weights_scratch_index),
164+
tflite::micro::GetTensorShape(bias),
165+
tflite::micro::GetOptionalTensorData<int32_t>(
166+
micro_context, bias, bias_comp_td, data.bias_scratch_index),
167+
#else // USE_TFLM_COMPRESSION
122168
tflite::micro::GetTensorData<int8_t>(filter),
123169
tflite::micro::GetTensorShape(bias),
124170
tflite::micro::GetOptionalTensorData<int32_t>(bias),
171+
#endif // USE_TFLM_COMPRESSION
125172
tflite::micro::GetTensorShape(output),
126173
tflite::micro::GetTensorData<int8_t>(output));
127174
break;

0 commit comments

Comments (0)