Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tensorflow/lite/micro/kernels/comparisons.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -286,6 +286,19 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input2), output_shape,
output_data);
break;
case kTfLiteInt16:
requires_broadcast
? reference_ops::Broadcast4DSlowGreaterWithScaling(
data->params, input1_shape,
tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
tflite::micro::GetTensorData<int16_t>(input2), output_shape,
output_data)
: reference_ops::GreaterWithScaling(
data->params, input1_shape,
tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
tflite::micro::GetTensorData<int16_t>(input2), output_shape,
output_data);
break;
default:
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
Expand Down
54 changes: 53 additions & 1 deletion tensorflow/lite/micro/kernels/comparisons_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -126,6 +126,29 @@ void TestComparisonQuantizedInt8(const TFLMRegistration& registration,
TestComparison(registration, tensors, expected_output_data, output_data);
}

void TestComparisonQuantizedInt16(const TFLMRegistration& registration,
int* input1_dims_data, float* input1_data,
int16_t* input1_quantized, float input1_scale,
int input1_zero_point, int* input2_dims_data,
float* input2_data, int16_t* input2_quantized,
float input2_scale, int input2_zero_point,
bool* expected_output_data,
int* output_dims_data, bool* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);

TfLiteTensor tensors[tensors_size] = {
CreateQuantizedTensor(input1_data, input1_quantized, input1_dims,
input1_scale, input1_zero_point),
CreateQuantizedTensor(input2_data, input2_quantized, input2_dims,
input2_scale, input2_zero_point),
CreateTensor(output_data, output_dims),
};

TestComparison(registration, tensors, expected_output_data, output_data);
}

} // namespace
} // namespace testing
} // namespace tflite
Expand Down Expand Up @@ -656,6 +679,35 @@ TF_LITE_MICRO_TEST(GreaterQuantizedInt8WithBroadcast) {
}
}

TF_LITE_MICRO_TEST(GreaterQuantizedInt16WithBroadcast) {
  // Broadcast a scalar right-hand side against several left-hand shapes.
  // Every shape holds the same six element values, so the expected boolean
  // results are identical for each iteration.
  constexpr int kNumShapes = 4;
  constexpr int kMaxShapeSize = 5;
  int test_shapes[kNumShapes][kMaxShapeSize] = {
      {1, 6}, {2, 2, 3}, {3, 2, 1, 3}, {4, 1, 3, 1, 2}};

  for (int shape_index = 0; shape_index < kNumShapes; ++shape_index) {
    int* input1_dim = test_shapes[shape_index];
    // Scalar (single-element) second input that gets broadcast.
    int input2_dim[] = {1, 1};
    float input1_data[] = {20, -2, -71, 8, 11, 20};
    float input2_data[] = {8};

    // Elementwise input1 > 8 (note 8 > 8 is false).
    bool expected_data[] = {true, false, false, false, true, true};
    int* expected_dim = input1_dim;

    // Both inputs share the same quantization parameters.
    const float input1_scale = 0.5;
    const int input1_zero_point = -9;
    int16_t input1_quantized[6];
    int16_t input2_quantized[6];

    bool output_data[6];
    tflite::testing::TestComparisonQuantizedInt16(
        tflite::Register_GREATER(), input1_dim, input1_data, input1_quantized,
        input1_scale, input1_zero_point, input2_dim, input2_data,
        input2_quantized, input1_scale, input1_zero_point, expected_data,
        expected_dim, output_data);
  }
}

TF_LITE_MICRO_TEST(GreaterEqualQuantizedInt8WithBroadcast) {
const int num_shapes = 4;
const int max_shape_size = 5;
Expand Down
106 changes: 89 additions & 17 deletions tensorflow/lite/micro/kernels/fully_connected.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -238,25 +238,97 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
if (bias == nullptr || bias->type == kTfLiteInt32) {
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
data.per_channel_output_multiplier,
reinterpret_cast<const int*>(
data.per_channel_output_shift),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output))
: tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
} else if (bias->type == kTfLiteInt64) {
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
data.per_channel_output_multiplier,
reinterpret_cast<const int*>(
data.per_channel_output_shift),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output))
: tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
}
break;
}
default: {
Expand Down
13 changes: 9 additions & 4 deletions tensorflow/lite/micro/kernels/fully_connected_common.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -95,9 +95,14 @@ TfLiteStatus CalculateOpDataFullyConnected(
filter->quantization.params);
const int per_channel_quantization_size = affine_quantization->scale->size;

// Currently only Int8 is supported for per channel quantization.
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 && filter->type != kTfLiteInt4);
// Currently only Int8/Int16 are supported for per channel quantization.
TF_LITE_ENSURE(
context,
(input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) ||
(input->type == kTfLiteInt16 && filter->type != kTfLiteInt4));

TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
per_channel_quantization_size);

TF_LITE_ENSURE_EQ(
context, per_channel_quantization_size,
Expand Down
Loading
Loading