commit c9b7a5f84c (parent 338184712d)
jomjol committed 2022-08-28 19:52:51 +02:00
223 changed files with 13226 additions and 2342 deletions

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default: {
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
MicroPrintf("Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default: {
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
MicroPrintf("Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_RELU() {
return {/*init=*/ReluInit,
/*free=*/nullptr,
/*prepare=*/ReluPrepare,
/*invoke=*/ReluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
}
TfLiteRegistration Register_RELU6() {
return {/*init=*/Relu6Init,
/*free=*/nullptr,
/*prepare=*/Relu6Prepare,
/*invoke=*/Relu6Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
}
} // namespace tflite
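The pattern repeated throughout this commit replaces the eight-field TfLiteRegistration brace initializer with the helper tflite::micro::RegisterOp(init, prepare, invoke), and switches kernel error messages from TF_LITE_KERNEL_LOG() to MicroPrintf() (hence the new micro_error_reporter.h include). A minimal sketch of what such a helper plausibly expands to, using the field names from the initializers removed above; the real definition lives in the TFLM kernel utilities:

// Sketch only, assuming the field order shown in the removed initializers.
TfLiteRegistration RegisterOpSketch(
    void* (*init)(TfLiteContext*, const char*, size_t),
    TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*),
    TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*)) {
  return {/*init=*/init,
          /*free=*/nullptr,
          /*prepare=*/prepare,
          /*invoke=*/invoke,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}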

View File

@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_ADD() {
return {/*init=*/AddInit,
/*free=*/nullptr,
/*prepare=*/AddPrepare,
/*invoke=*/AddEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}
} // namespace tflite

View File

@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_ADD_N() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace arg_min_max
TfLiteRegistration Register_ARG_MAX() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/arg_min_max::ArgMaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
}
TfLiteRegistration Register_ARG_MIN() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/arg_min_max::ArgMinEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
}
} // namespace micro

View File

@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_ASSIGN_VARIABLE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_BROADCAST_ARGS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/BroadcastArgsPrepare,
/*invoke=*/BroadcastArgsEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
BroadcastArgsEval);
}
} // namespace tflite

View File

@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_BROADCAST_TO() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/BroadcastToPrepare,
/*invoke=*/BroadcastToEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
BroadcastToEval);
}
} // namespace tflite

View File

@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_CALL_ONCE() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CAST() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace ceil
TfLiteRegistration Register_CEIL() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/ceil::Prepare,
/*invoke=*/ceil::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
}
} // namespace micro

View File

@@ -108,14 +108,8 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
static TfLiteRegistration r = {/*init=*/CircularBufferInit,
/*free=*/nullptr,
/*prepare=*/CircularBufferPrepare,
/*invoke=*/CircularBufferEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
static TfLiteRegistration r = tflite::micro::RegisterOp(
CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
return &r;
}
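CIRCULAR_BUFFER (like DETECTION_POSTPROCESS further down) is a custom op whose Register_* function returns a TfLiteRegistration pointer, so the registration is kept in a function-local static to give the returned pointer a lifetime beyond the call. A hedged sketch of the same pattern with hypothetical callbacks:

// MyPrepare/MyEval are hypothetical placeholders, not names from this commit.
TfLiteRegistration* Register_MY_CUSTOM_OP() {
  // The static keeps the registration alive while the resolver holds the pointer.
  static TfLiteRegistration r =
      tflite::micro::RegisterOp(nullptr, MyPrepare, MyEval);
  return &r;
}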

View File

@@ -39,13 +39,12 @@ const int kCircularBufferCyclesMaxIndex = 0; // 'cycles_max'
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(-9);
TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
MicroContext * micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context-> AllocateTempInputTensor(node, kCircularBufferInputTensor);
TfLiteTensor* output =
micro_context-> AllocateTempOutputTensor(node, kCircularBufferOutputTensor);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kCircularBufferInputTensor);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
node, kCircularBufferOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
OpDataCircularBuffer* op_data =

View File

@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} // namespace comparisons
TfLiteRegistration Register_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::EqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::EqualEval);
}
TfLiteRegistration Register_NOT_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::NotEqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::NotEqualEval);
}
TfLiteRegistration Register_GREATER() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::GreaterEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::GreaterEval);
}
TfLiteRegistration Register_GREATER_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::GreaterEqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::GreaterEqualEval);
}
TfLiteRegistration Register_LESS() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::LessEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::LessEval);
}
TfLiteRegistration Register_LESS_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::LessEqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::LessEqualEval);
}
} // namespace micro

View File

@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, input != nullptr);
int num_dimensions = NumDimensions(input);
if (num_dimensions > 4) {
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
TF_LITE_KERNEL_LOG(
context,
"Op Concatenation does not currently support num dimensions >4 "
"Op Concatenation does not currently support num dimensions > %d "
"Tensor has %d dimensions.",
num_dimensions);
RuntimeShape::kMaxSmallSize, num_dimensions);
return kTfLiteError;
}
micro_context->DeallocateTempTfLiteTensor(input);
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace concatenation
TfLiteRegistration Register_CONCATENATION() {
return {/*init=*/concatenation::Init,
/*free=*/nullptr,
/*prepare=*/concatenation::Prepare,
/*invoke=*/concatenation::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
concatenation::Eval);
}
} // namespace micro

View File

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
break;
}
case kTfLiteInt16: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<std::int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
switch (bias->type) {
case kTfLiteInt32: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
case kTfLiteInt64: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default:
MicroPrintf("Bias type %s (%d) not supported.",
TfLiteTypeGetName(bias->type), bias->type);
return kTfLiteError;
}
break;
}
case kTfLiteInt8: {
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/ConvPrepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
}
} // namespace tflite
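CONV_2D (and DEPTHWISE_CONV_2D below) now fetches the bias through GetOptionalTensorData instead of GetTensorData, and the int16 path dispatches on the bias type (int32 or int64). A minimal sketch of the assumed contract of GetOptionalTensorData<T>(): a missing (nullptr) bias tensor maps to a nullptr data pointer instead of being treated as an error; the real helper lives in tensorflow/lite/micro/kernels/kernel_util.h.

// Sketch only, under the assumption stated above.
template <typename T>
const T* GetOptionalTensorDataSketch(const TfLiteEvalTensor* tensor) {
  return tensor == nullptr ? nullptr
                           : reinterpret_cast<const T*>(tensor->data.raw);
}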

View File

@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
TfLiteRegistration registration, int16_t* output_data);
TfLiteStatus TestConvQuantizedPerChannel(
int* input_dims_data, const float* input_data, int16_t* input_quantized,
float input_scale, int input_zero_point, int* filter_dims_data,
const float* filter_data, int8_t* filter_data_quantized,
int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
float* bias_scales, int* bias_zero_points, int* output_dims_data,
const float* expected_output_data, int16_t* expected_output_data_quantized,
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
TfLiteRegistration registration, int16_t* output_data);
} // namespace testing
} // namespace tflite

View File

@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CUMSUM() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_DEPTH_TO_SPACE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/DepthwiseConvPrepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
}
} // namespace tflite

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
// This is the most generic TfLiteRegistration. The actual supported types may
// still be target dependent. The only requirement is that every implementation
// (reference or optimized) must define this function.
TfLiteRegistration Register_DEPTHWISE_CONV_2D();
#if defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int8 activations and int8 weights and uses the latency optimized
// implementations.
TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int16 activations and int8 weights and uses the latency optimized
// implementations.
TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
#else
inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
return Register_DEPTHWISE_CONV_2D();
}
inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
return Register_DEPTHWISE_CONV_2D();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
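The new header comments introduce target-specialized registrations: with CMSIS-NN, Register_DEPTHWISE_CONV_2D_INT8()/_INT16() pick latency-optimized kernels for int8/int16 activations, otherwise they fall back to the generic registration. A hedged usage sketch from the application side (not part of this commit; it assumes the upstream MicroMutableOpResolver overload that accepts a custom registration):

// Assumption: AddDepthwiseConv2D() accepts an optional TfLiteRegistration,
// as declared in tensorflow/lite/micro/micro_mutable_op_resolver.h upstream.
tflite::MicroMutableOpResolver<1> resolver;
resolver.AddDepthwiseConv2D(tflite::Register_DEPTHWISE_CONV_2D_INT8());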

View File

@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteUInt8:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
default:
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_DEQUANTIZE() {
return {/*init=*/DequantizeInit,
/*free=*/nullptr,
/*prepare=*/DequantizePrepare,
/*invoke=*/DequantizeEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
DequantizeEval);
}
} // namespace tflite
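DequantizeEval gains a kTfLiteUInt8 input path (DequantizePrepare below is relaxed to accept it). The mapping applied by reference_ops::Dequantize is the standard affine one, real = scale * (q - zero_point); a tiny illustrative helper, with made-up numbers:

// Illustrative only; scale and zero_point come from the tensor's
// quantization parameters at runtime. Requires <cstdint>.
float DequantizeOne(uint8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}
// e.g. DequantizeOne(130, 0.5f, 128) == 1.0f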

View File

@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteUInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
if (output->type == kTfLiteInt32) {

View File

@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = static_cast<OpData*>(node->user_data);
@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
static TfLiteRegistration r = {/*init=*/Init,
/*free=*/Free,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}

View File

@@ -0,0 +1,208 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/div.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
struct OpDataDiv {
// Parameters used in the quantized paths where the output is 8bit
int32_t input1_zero_point;
int32_t input2_zero_point;
int32_t output_zero_point;
int32_t output_activation_min;
int32_t output_activation_max;
// Parameters used in all quantized paths
int32_t output_multiplier;
int output_shift;
};
TfLiteStatus CalculateOpDataDiv(TfLiteContext* context, TfLiteTensor* input1,
TfLiteTensor* input2, TfLiteTensor* output,
TfLiteDivParams* params, OpDataDiv* data) {
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, output->type);
if (output->type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
const double real_multiplier = static_cast<double>(
input1->params.scale / (input2->params.scale * output->params.scale));
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
&data->output_shift);
data->input1_zero_point = input1->params.zero_point;
data->input2_zero_point = input2->params.zero_point;
data->output_zero_point = output->params.zero_point;
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataDiv));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input1 =
micro_context->AllocateTempInputTensor(node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* input2 =
micro_context->AllocateTempInputTensor(node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
OpDataDiv* data = static_cast<OpDataDiv*>(node->user_data);
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
TF_LITE_ENSURE_STATUS(
CalculateOpDataDiv(context, input1, input2, output, params, data));
micro_context->DeallocateTempTfLiteTensor(input1);
micro_context->DeallocateTempTfLiteTensor(input2);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
const OpDataDiv* data, const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
tflite::ArithmeticParams op_params = {};
#define TF_LITE_DIV(type, opname, data_type) \
data_type output_activation_min, output_activation_max; \
CalculateActivationRange(params->activation, &output_activation_min, \
&output_activation_max); \
SetActivationParams(output_activation_min, output_activation_max, \
&op_params); \
type::opname(op_params, tflite::micro::GetTensorShape(input1), \
tflite::micro::GetTensorData<data_type>(input1), \
tflite::micro::GetTensorShape(input2), \
tflite::micro::GetTensorData<data_type>(input2), \
tflite::micro::GetTensorShape(output), \
tflite::micro::GetTensorData<data_type>(output))
bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
if (requires_broadcast) {
TF_LITE_DIV(reference_ops, BroadcastDivSlow, float);
} else {
TF_LITE_DIV(reference_ops, Div, float);
}
#undef TF_LITE_DIV
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteDivParams* params, const OpDataDiv* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
tflite::ArithmeticParams op_params = {};
#define TF_LITE_DIV(type, opname, dtype) \
type::opname(op_params, tflite::micro::GetTensorShape(input1), \
tflite::micro::GetTensorData<dtype>(input1), \
tflite::micro::GetTensorShape(input2), \
tflite::micro::GetTensorData<dtype>(input2), \
tflite::micro::GetTensorShape(output), \
tflite::micro::GetTensorData<dtype>(output))
if (input1->type == kTfLiteInt8 && input2->type == kTfLiteInt8 &&
output->type == kTfLiteInt8) {
SetActivationParams(data->output_activation_min,
data->output_activation_max, &op_params);
op_params.input1_offset = -data->input1_zero_point;
op_params.input2_offset = -data->input2_zero_point;
op_params.output_offset = data->output_zero_point;
op_params.output_multiplier = data->output_multiplier;
op_params.output_shift = data->output_shift;
bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
if (requires_broadcast) {
TF_LITE_DIV(reference_ops, BroadcastDivSlow, int8_t);
} else {
TF_LITE_DIV(reference_ops, Div, int8_t);
}
#undef TF_LITE_DIV
} else {
TF_LITE_KERNEL_LOG(
context, "Unsupported combination of input and output types in DIV.");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
auto* params = static_cast<TfLiteDivParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
auto* data = static_cast<OpDataDiv*>(node->user_data);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
EvalDiv(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteInt8) {
TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
input1, input2, output));
} else {
TF_LITE_KERNEL_LOG(context,
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_DIV() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
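In the new DIV kernel, the int8 path folds the three tensor scales into one constant: from s_out * (q_out - z_out) ≈ s1 * (q1 - z1) / (s2 * (q2 - z2)) it follows that the requantization factor is s1 / (s2 * s_out), which CalculateOpDataDiv hands to QuantizeMultiplier() to obtain output_multiplier and output_shift. A small illustrative helper (numbers are made up, not from the commit):

// Sketch of the scale folding done in CalculateOpDataDiv above.
double DivRealMultiplier(float s1, float s2, float s_out) {
  return static_cast<double>(s1 / (s2 * s_out));
}
// e.g. DivRealMultiplier(0.02f, 0.05f, 0.1f) is roughly 4.0, which
// QuantizeMultiplier() then splits into a Q31 mantissa and a power-of-two shift.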

View File

@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@ limitations under the License.
#include <cmath>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -27,6 +29,22 @@ namespace micro {
namespace elementwise {
namespace {
constexpr int kAbsNameId = 0;
constexpr int kRsrqtNameId = 1;
const int kElementwiseInputTensor = 0;
const int kElementwiseOutputTensor = 0;
struct OpDataAbsRsqrt {
int32_t multiplier;
int shift;
int input_offset;
int output_offset;
bool needs_rescale;
TfLiteQuantizationType input_quantization_type;
TfLiteType input_type;
};
bool IsNumericSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32;
}
@@ -35,16 +53,40 @@ bool IsLogicalSupportedType(const TfLiteType type) {
return type == kTfLiteBool;
}
bool IsAbsSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
}
bool IsRsqrtSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8;
}
inline void SetAbsOutputMultiplier(const float input_scale,
const float output_scale,
int32_t* multiplier, int* shift) {
QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
multiplier, shift);
}
inline void SetRsqrtOutputMultiplier(const float input_scale,
const float output_scale,
int32_t* multiplier, int* shift) {
const double scale =
1. / static_cast<double>((std::sqrt(input_scale) * output_scale));
QuantizeMultiplier(scale, multiplier, shift);
}
typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
@@ -58,9 +100,79 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType, const int op_nameid>
TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
op_data->input_type = input->type;
// For int16 type input, we support both quantized and non-quantized
// evaluation.
if (op_nameid == kAbsNameId) {
op_data->input_quantization_type = input->quantization.type;
}
if (input->type == kTfLiteInt8 ||
(input->type == kTfLiteInt16 &&
input->quantization.type != kTfLiteNoQuantization)) {
TF_LITE_ENSURE_EQ(context, input->quantization.type,
kTfLiteAffineQuantization);
TF_LITE_ENSURE_EQ(context, output->quantization.type,
kTfLiteAffineQuantization);
const auto* input_params =
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
output->quantization.params);
TF_LITE_ENSURE(context, input_params != nullptr);
TF_LITE_ENSURE(context, input_params->scale != nullptr);
TF_LITE_ENSURE(context, input_params->scale->size > 0);
TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
TF_LITE_ENSURE(context, output_params != nullptr);
TF_LITE_ENSURE(context, output_params->scale != nullptr);
TF_LITE_ENSURE(context, output_params->scale->size > 0);
TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
op_data->input_offset = input_params->zero_point->data[0];
op_data->output_offset = output_params->zero_point->data[0];
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
}
const float input_scale = input_params->scale->data[0];
const float output_scale = output_params->scale->data[0];
op_data->needs_rescale = input_scale != output_scale;
if (op_nameid == kAbsNameId && op_data->needs_rescale) {
SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
&op_data->shift);
} else if (op_nameid == kRsrqtNameId) {
SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
&op_data->shift);
}
}
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteType expected_type) {
inline TfLiteStatus EvalImplQuantized(
TfLiteContext* context, TfLiteNode* node,
T func(TfLiteContext*, TfLiteNode*, T),
TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
TfLiteType expected_type) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
@@ -68,6 +180,34 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
const T* in_data = tflite::micro::GetTensorData<T>(input);
T* out_data = tflite::micro::GetTensorData<T>(output);
for (size_t i = 0; i < num_elements; ++i) {
if (validate_input_func) {
TF_LITE_ENSURE_OK(context,
validate_input_func(context, node, in_data[i]));
}
out_data[i] = func(context, node, in_data[i]);
}
return kTfLiteOk;
}
template <typename T>
inline T AbsHelper(T i) {
return std::abs(i);
}
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteStatus validate_input_func(T),
TfLiteType expected_type) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const size_t num_elements = ElementCount(*input->dims);
const T* in_data = tflite::micro::GetTensorData<T>(input);
T* out_data = tflite::micro::GetTensorData<T>(output);
for (size_t i = 0; i < num_elements; ++i) {
if (validate_input_func) {
TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
}
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
@@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
float float_func(float)) {
return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
return EvalImpl<float>(context, node, float_func,
/*validate_input_func=*/nullptr, kTfLiteFloat32);
}
inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
bool bool_func(bool)) {
return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
return EvalImpl<bool>(context, node, bool_func,
/*validate_input_func=*/nullptr, kTfLiteBool);
}
void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
}
template <typename T>
inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
const int kMin = std::numeric_limits<T>::min();
const int kMax = std::numeric_limits<T>::max();
const int32_t value = std::abs(i - op_data->input_offset);
if (!op_data->needs_rescale) {
return static_cast<T>(
std::min(std::max(static_cast<long int>(value + op_data->output_offset),
static_cast<long int>(kMin)),
static_cast<long int>(kMax)));
}
const int32_t output = tflite::MultiplyByQuantizedMultiplier(
value, op_data->multiplier, op_data->shift) +
op_data->output_offset;
return static_cast<T>(std::min(
std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
static_cast<long int>(kMax)));
}
template <typename T>
inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
const int kMin = std::numeric_limits<T>::min();
const int kMax = std::numeric_limits<T>::max();
const int32_t value = (i - op_data->input_offset);
const int32_t kShift = 20; // Shift to keep value integer.
if (value == 0) {
// Assume that any value close to 0 represents the max output value.
return static_cast<T>(kMax);
}
int32_t inv_sqrt_multiplier;
int inv_sqrt_shift;
GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
&inv_sqrt_shift);
const int32_t data = tflite::MultiplyByQuantizedMultiplier(
static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
const int32_t output =
tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
op_data->shift - kShift) +
op_data->output_offset;
return static_cast<T>(std::min(
std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
static_cast<long int>(kMax)));
}
template <typename T>
TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
T i) {
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
"Rsqrt is only defined for positive values");
return static_cast<TfLiteStatus>(kTfLiteOk);
}
TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, std::abs);
OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
TfLiteType type = op_data->input_type;
TfLiteQuantizationType input_quantization_type =
op_data->input_quantization_type;
TfLiteStatus eval_result;
switch (type) {
case kTfLiteFloat32:
eval_result = EvalNumeric(context, node, std::abs);
break;
case kTfLiteInt8:
eval_result =
EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
/*validate_input_func=*/nullptr, type);
break;
case kTfLiteInt16:
eval_result =
input_quantization_type == kTfLiteNoQuantization
? EvalImpl<int16_t>(context, node, AbsHelper,
/*validate_input_func=*/nullptr, type)
: EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
/*validate_input_func=*/nullptr,
type);
break;
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
break;
}
return eval_result;
}
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
@@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
TfLiteType type = op_data->input_type;
switch (type) {
case kTfLiteFloat32:
return EvalImpl<float>(
context, node, [](float f) { return 1.f / std::sqrt(f); },
/*validate_input_func=*/nullptr, type);
case kTfLiteInt8:
return EvalImplQuantized<int8_t>(context, node,
elementwise::RsqrtEvalQuantized,
elementwise::validate_input_func, type);
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
@@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace elementwise
TfLiteRegistration Register_ABS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::AbsEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
elementwise::ElementWiseAbsRsqrtInit,
elementwise::PrepareAbsRsqrt<elementwise::IsAbsSupportedType,
elementwise::kAbsNameId>,
elementwise::AbsEval);
}
TfLiteRegistration Register_SIN() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::SinEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SinEval);
}
TfLiteRegistration Register_COS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::CosEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::CosEval);
}
TfLiteRegistration Register_LOG() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::LogEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval);
}
TfLiteRegistration Register_SQRT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::SqrtEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SqrtEval);
}
TfLiteRegistration Register_RSQRT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::RsqrtEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
elementwise::ElementWiseAbsRsqrtInit,
elementwise::PrepareAbsRsqrt<elementwise::IsRsqrtSupportedType,
elementwise::kRsrqtNameId>,
elementwise::RsqrtEval);
}
TfLiteRegistration Register_SQUARE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::SquareEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SquareEval);
}
TfLiteRegistration Register_LOGICAL_NOT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
/*invoke=*/elementwise::LogicalNotEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
elementwise::LogicalNotEval);
}
} // namespace micro
} // namespace ops
} // namespace tflite
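The reworked elementwise kernel adds quantized ABS and RSQRT paths. For ABS, the rescale produced by SetAbsOutputMultiplier() follows from equating real values on both sides, |s_in * (q - z_in)| = s_out * (q_out - z_out), hence q_out = z_out + (s_in / s_out) * |q - z_in|; AbsEvalQuantized computes exactly this in fixed point via MultiplyByQuantizedMultiplier plus clamping. A float-domain reference of the identity (a sketch, not the kernel code; clamping to the output range is omitted):

#include <cmath>    // std::lround
#include <cstdint>  // int32_t
#include <cstdlib>  // std::abs(int)

int32_t AbsRequantReference(int32_t q, int32_t z_in, int32_t z_out,
                            float s_in, float s_out) {
  return z_out + static_cast<int32_t>(
                     std::lround((s_in / s_out) * std::abs(q - z_in)));
}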

View File

@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_ELU() {
return {/*init=*/EluInit,
/*free=*/nullptr,
/*prepare=*/EluPrepare,
/*invoke=*/EluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
}
} // namespace tflite

View File

@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_ADD() {
return {/*init=*/AddInit,
/*free=*/nullptr,
/*prepare=*/AddPrepare,
/*invoke=*/AddEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}
} // namespace tflite

View File

@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if ESP_NN
if (input->type == kTfLiteInt8) {
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input->dims->data[3], 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output->dims->data[3], 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
conv_params_t conv_params = {
.in_offset = 0, .out_offset = 0,
.stride = {params.stride_width, params.stride_height},
.padding = {data->op_data.padding.width, data->op_data.padding.height},
.dilation = {0, 0}, .activation = {-128, 127}
};
int scratch_buf_size = esp_nn_get_conv_scratch_size(
input_width, input_height, input->dims->data[3],
output->dims->data[3], filter_width, filter_height);
&input_dims, &filter_dims, &output_dims, &conv_params);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
const int input_size = input_width * input_height * input_depth;
const int output_size = output_width * output_height * output_depth;
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input_depth, 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output_depth, 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
conv_params_t conv_params = {
.in_offset = input_offset, .out_offset = output_offset,
.stride = {stride_width, stride_height},
.padding = {pad_width, pad_height},
.dilation = {0, 0},
.activation = {activation_min, activation_max}
};
quant_data_t quant_data = {
.shift = data.op_data.per_channel_output_shift,
.mult = data.op_data.per_channel_output_multiplier
};
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
esp_nn_conv_s8(input_data + i_batch * input_size,
input_width, input_height, input_depth, input_offset,
pad_width, pad_height, stride_width, stride_height,
tflite::micro::GetTensorData<int8_t>(filter),
filter_width, filter_height,
esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
&filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorData<int32_t>(bias),
output_data + i_batch * output_size,
output_width, output_height, output_depth, output_offset,
data.op_data.per_channel_output_shift,
data.op_data.per_channel_output_multiplier,
activation_min, activation_max);
&output_dims, output_data + i_batch * output_size,
&conv_params, &quant_data);
}
} else {
reference_integer_ops::ConvPerChannel(
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
conv_total_time += esp_timer_get_time() - start_time;
long long time_this_instance = esp_timer_get_time() - start_time;
conv_total_time += time_this_instance;
//printf("time this instance: %llu\n", time_this_instance / 1000);
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
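The ESP-NN glue now passes small descriptor structs (data_dims_t, conv_params_t, quant_data_t) instead of long positional argument lists, both for the scratch-size query in Prepare() and for the per-batch esp_nn_conv_s8() call. A hedged usage sketch of the new scratch-size query, mirroring the Prepare() code above with illustrative dimensions (see the esp-nn headers for the authoritative struct layout):

data_dims_t input_dims  = {.width = 32, .height = 32, .channels = 8, 1};
data_dims_t output_dims = {.width = 32, .height = 32, .channels = 16, 1};
data_dims_t filter_dims = {.width = 3, .height = 3, 0, 0};
conv_params_t conv_params = {
    .in_offset = 0, .out_offset = 0, .stride = {1, 1},
    .padding = {1, 1}, .dilation = {0, 0}, .activation = {-128, 127}};
int scratch_buf_size = esp_nn_get_conv_scratch_size(
    &input_dims, &filter_dims, &output_dims, &conv_params);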

View File

@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
if (data.buffer_idx > -1) {
scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
}
esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input_depth, 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output_depth, 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
dw_conv_params_t conv_params = {
.in_offset = input_offset, .out_offset = output_offset,
.ch_mult = depth_multiplier,
.stride = {stride_width, stride_height},
.padding = {pad_width, pad_height}, .dilation = {0, 0},
.activation = {activation_min, activation_max}
};
quant_data_t quant_data = {
.shift = data.op_data.per_channel_output_shift,
.mult = data.op_data.per_channel_output_multiplier
};
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
input_height, input_depth, input_offset,
pad_width, pad_height,
stride_width, stride_height, depth_multiplier,
tflite::micro::GetTensorData<int8_t>(filter),
filter_width, filter_height,
esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
&filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorData<int32_t>(bias),
output_data + i_batch * output_size,
output_width, output_height, output_offset,
data.op_data.per_channel_output_shift,
data.op_data.per_channel_output_multiplier,
activation_min, activation_max);
&output_dims, output_data + i_batch * output_size,
&conv_params, &quant_data);
}
} else {
reference_integer_ops::DepthwiseConvPerChannel(
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if ESP_NN
if (input->type == kTfLiteInt8) {
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input->dims->data[3], 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output->dims->data[3], 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
dw_conv_params_t conv_params = {
.in_offset = 0, .out_offset = 0,
.ch_mult = params.depth_multiplier,
.stride = {params.stride_width, params.stride_height},
.padding = {data->op_data.padding.width, data->op_data.padding.height},
.dilation = {0, 0}, .activation = {-128, 127}
};
int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
input_width, input_height, input->dims->data[3],
params.depth_multiplier, filter_width, filter_height);
&input_dims, &filter_dims, &output_dims, &conv_params);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
dc_total_time += esp_timer_get_time() - start_time;
long long time_this_instance = esp_timer_get_time() - start_time;
dc_total_time += time_this_instance;
// printf("time this instance: %llu\n", time_this_instance / 1000);
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_MUL() {
return {/*init=*/MulInit,
/*free=*/nullptr,
/*prepare=*/MulPrepare,
/*invoke=*/MulEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
}
} // namespace tflite

View File

@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} // namespace
TfLiteRegistration Register_AVERAGE_POOL_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PoolingPrepare,
/*invoke=*/AverageEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
}
TfLiteRegistration Register_MAX_POOL_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PoolingPrepare,
/*invoke=*/MaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
}
} // namespace tflite

View File

@@ -0,0 +1,208 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/softmax.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN
#include <esp_nn.h>
#endif
long long softmax_total_time = 0;
namespace tflite {
namespace {
// Softmax parameter data that persists in user_data
const int kInt16LUTArraySize = 513;
struct NodeData {
SoftmaxParams op_data;
#if ESP_NN
int buffer_idx;
#endif
};
static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(NodeData));
}
void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
TfLiteEvalTensor* output, const NodeData* data) {
if (input->type == kTfLiteInt8) {
if (output->type == kTfLiteInt16) {
tflite::reference_ops::Softmax(
data->op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
} else {
#if ESP_NN
const int32_t input_beta_multiplier = data->op_data.input_multiplier;
const int32_t input_beta_left_shift = data->op_data.input_left_shift;
const int diff_min = data->op_data.diff_min;
const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
void *scratch_buf = NULL;
if (data->buffer_idx > -1) {
scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
}
esp_nn_set_softmax_scratch_buf(scratch_buf);
esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
input_beta_left_shift, diff_min, out_ptr);
#else
tflite::reference_ops::Softmax(
data->op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
}
} else {
tflite::reference_ops::SoftmaxInt16(
data->op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
}
}
static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TFLITE_DCHECK(node->user_data != nullptr);
NodeData data = *static_cast<NodeData*>(node->user_data);
long long start_time = esp_timer_get_time();
switch (input->type) {
case kTfLiteFloat32: {
tflite::reference_ops::Softmax(
data.op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
}
break;
case kTfLiteInt8:
case kTfLiteInt16: {
SoftmaxQuantized(context, input, output, &data);
}
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
softmax_total_time += esp_timer_get_time() - start_time;
return kTfLiteOk;
}
static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, node->user_data != nullptr);
NodeData* data = static_cast<NodeData*>(node->user_data);
// Only allocate LUTs for the kTfLiteInt16 data type
if (input->type == kTfLiteInt16) {
void* raw_exp_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
data->op_data.one_over_one_plus_x_lut =
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
}
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
}
// Populate LUT if required
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// The exp LUT is only used on negative values;
// we consider exp(-10.0) insignificant to the accumulation.
gen_lut<float, int16_t, int16_t>(
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
data->op_data.exp_lut);
gen_lut<float, int16_t, int16_t>(
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
1.0f, data->op_data.one_over_one_plus_x_lut);
data->op_data.zero_point = output->params.zero_point;
data->op_data.scale = output->params.scale;
}
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
auto ret_val =
CalculateSoftmaxParams(context, input, output, params, &data->op_data);
#if ESP_NN
if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
const int32_t input_width = input->dims->data[1];
const int32_t input_height = input->dims->data[2];
int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
input_height);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
}
}
#endif
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return ret_val;
}
} // namespace
TfLiteRegistration Register_SOFTMAX() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
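The ESP-NN path above follows the standard TFLM scratch-buffer protocol: Prepare reserves arena space and persists only an integer handle, and Eval resolves that handle into a pointer for the current invocation. A minimal sketch of the protocol, reusing the names from the kernel above:
// In Prepare: reserve arena space; only the returned index is persisted.
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
    context, scratch_buf_size, &data->buffer_idx));
// In Eval: turn the index back into a pointer, valid for this invocation only.
void* scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);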

View File

@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_EXP() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_EXPAND_DIMS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FILL() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace floor
TfLiteRegistration Register_FLOOR() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/floor::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
}
} // namespace micro

View File

@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FLOOR_DIV() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FLOOR_MOD() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
const int64_t* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
: nullptr;
tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias), bias_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
@@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
}
#endif
#if defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for a kernel variant that only supports
// int16.
TfLiteRegistration Register_FULLY_CONNECTED_INT16();
#else
// Note that while this block gets used for both reference kernels and
// optimized kernels that do not have any specialized implementation, the only
// goal here is to define a fallback that allows the reference kernel to still
// be used from applications that call a more specific kernel variant.
inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
return Register_FULLY_CONNECTED();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
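In practice an application opts into the specialized variant by passing its registration to the op resolver; on targets without CMSIS-NN the inline fallback above simply degrades to the generic kernel. A hedged sketch, assuming the MicroMutableOpResolver overload that accepts an explicit registration:
static tflite::MicroMutableOpResolver<1> op_resolver;
// Falls back to Register_FULLY_CONNECTED() when no int16 specialization exists.
op_resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT16());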

View File

@@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_GATHER() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -131,7 +131,8 @@ TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
slice_size *= params->dims->data[i];
}
int remain_flat_size = ElementCount(*params->dims);
int params_flat_size = ElementCount(*params->dims);
int remain_flat_size = params_flat_size;
// Number of elements per dimension
int dims_to_count[MAX_INDICES_ND];
@@ -147,6 +148,9 @@ TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
IndicesT index = index_data[offset];
from_pos += index * dims_to_count[j];
}
if (from_pos < 0 || from_pos + slice_size > params_flat_size) {
return kTfLiteError;
}
std::memcpy(output_data + i * slice_size, param_data + from_pos,
sizeof(ParamsT) * slice_size);
}
@@ -158,12 +162,13 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
const TfLiteEvalTensor* params,
const TfLiteEvalTensor* indices,
TfLiteEvalTensor* output) {
TfLiteStatus status = kTfLiteError;
switch (params->type) {
case kTfLiteFloat32:
return GatherNd<float, IndicesT>(params, indices, output);
status = GatherNd<float, IndicesT>(params, indices, output);
break;
case kTfLiteInt8:
return GatherNd<int8_t, IndicesT>(params, indices, output);
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
@@ -171,6 +176,10 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
}
return status;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -195,14 +204,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_GATHER_ND() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_HARD_SWISH() {
return {/*init=*/HardSwishInit,
/*free=*/nullptr,
/*prepare=*/tflite::HardSwishPrepare,
/*invoke=*/HardSwishEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare,
HardSwishEval);
}
} // namespace tflite

View File

@@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_IF() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/micro/test_helpers.h"
namespace tflite {
@@ -27,14 +27,21 @@ namespace micro {
constexpr int KernelRunner::kKernelRunnerBufferSize_;
uint8_t KernelRunner::kKernelRunnerBuffer_[];
void ClearBufferApi(TfLiteContext* context_) {
context_->GetScratchBuffer = nullptr;
context_->GetExternalContext = nullptr;
context_->AllocatePersistentBuffer = nullptr;
context_->RequestScratchBufferInArena = nullptr;
}
KernelRunner::KernelRunner(const TfLiteRegistration& registration,
TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
void* builtin_data)
void* builtin_data, TfLiteIntArray* intermediates)
: registration_(registration),
allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
kKernelRunnerBuffer_,
kKernelRunnerBufferSize_)),
allocator_(SingleArenaBufferAllocator::Create(GetMicroErrorReporter(),
kKernelRunnerBuffer_,
kKernelRunnerBufferSize_)),
mock_micro_graph_(allocator_),
fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
// Prepare TfLiteContext:
@@ -43,10 +50,8 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
context_.recommended_num_threads = 1;
context_.GetTensor = MicroContextGetTensor;
context_.GetEvalTensor = MicroContextGetEvalTensor;
tflite::micro::ClearBufferApi(&context_);
context_.AllocatePersistentBuffer = MicroContextAllocatePersistentBuffer;
context_.RequestScratchBufferInArena =
MicroContextRequestScratchBufferInArena;
context_.GetScratchBuffer = MicroContextGetScratchBuffer;
context_.recommended_num_threads = 0;
@@ -54,6 +59,7 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
node_.inputs = inputs;
node_.outputs = outputs;
node_.builtin_data = builtin_data;
node_.intermediates = intermediates;
}
bool KernelRunner::ValidateTempBufferDeallocated() {
@@ -63,12 +69,19 @@ bool KernelRunner::ValidateTempBufferDeallocated() {
TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
size_t length) {
if (registration_.init) {
tflite::micro::ClearBufferApi(&context_);
context_.AllocatePersistentBuffer = MicroContextAllocatePersistentBuffer;
node_.user_data = registration_.init(&context_, init_data, length);
}
TF_LITE_ENSURE(&context_, ValidateTempBufferDeallocated());
if (registration_.prepare) {
tflite::micro::ClearBufferApi(&context_);
context_.AllocatePersistentBuffer = MicroContextAllocatePersistentBuffer;
context_.RequestScratchBufferInArena =
MicroContextRequestScratchBufferInArena;
context_.GetExternalContext = MicroContextGetExternalContext;
TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
}
@@ -78,6 +91,9 @@ TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
}
TfLiteStatus KernelRunner::Invoke() {
tflite::micro::ClearBufferApi(&context_);
context_.GetScratchBuffer = MicroContextGetScratchBuffer;
if (registration_.invoke == nullptr) {
MicroPrintf("TfLiteRegistration missing invoke function pointer!");
return kTfLiteError;

View File

@@ -18,9 +18,9 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/fake_micro_context.h"
#include "tensorflow/lite/micro/mock_micro_graph.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
namespace tflite {
namespace micro {
@@ -35,7 +35,8 @@ class KernelRunner {
public:
KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
int tensors_size, TfLiteIntArray* inputs,
TfLiteIntArray* outputs, void* builtin_data);
TfLiteIntArray* outputs, void* builtin_data,
TfLiteIntArray* intermediates = nullptr);
// Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
// exceptions will be DebugLog'd and returned as a status code.
@@ -64,7 +65,7 @@ class KernelRunner {
TfLiteNode node_ = {};
const TfLiteRegistration& registration_;
SimpleMemoryAllocator* allocator_;
SingleArenaBufferAllocator* allocator_;
MockMicroGraph mock_micro_graph_;
FakeMicroContext fake_micro_context_;
};
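KernelRunner is the harness that kernel unit tests use to drive a TfLiteRegistration through init, prepare and invoke without a full interpreter; the new optional intermediates argument lets tests wire up intermediate tensors as well. A typical usage sketch, assuming the test has already built its tensors and index arrays:
tflite::micro::KernelRunner runner(registration, tensors, tensors_size,
                                   inputs_array, outputs_array,
                                   /*builtin_data=*/nullptr);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());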

View File

@@ -36,6 +36,21 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
} // namespace
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
return {/*init=*/init,
/*free=*/nullptr,
/*prepare=*/prepare,
/*invoke=*/invoke,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0,
/*registration_external=*/nullptr};
}
// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,

View File

@@ -27,6 +27,11 @@ limitations under the License.
namespace tflite {
namespace micro {
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
@@ -40,19 +45,33 @@ const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
const TfLiteNode* node, int index);
// Returns data for a TfLiteEvalTensor struct.
// Returns data for a TfLiteEvalTensor struct that is expected to exist.
template <typename T>
T* GetTensorData(TfLiteEvalTensor* tensor) {
return tensor != nullptr ? reinterpret_cast<T*>(tensor->data.raw) : nullptr;
TFLITE_DCHECK(tensor != nullptr);
return reinterpret_cast<T*>(tensor->data.raw);
}
// Returns const data for a TfLiteEvalTensor struct.
// Returns const data for a TfLiteEvalTensor struct that is expected to exist.
template <typename T>
const T* GetTensorData(const TfLiteEvalTensor* tensor) {
TFLITE_DCHECK(tensor != nullptr);
return reinterpret_cast<const T*>(tensor->data.raw);
}
// Returns data for a TfLiteEvalTensor struct that could be null.
template <typename T>
T* GetOptionalTensorData(TfLiteEvalTensor* tensor) {
return tensor == nullptr ? nullptr : reinterpret_cast<T*>(tensor->data.raw);
}
// Returns const data for a TfLiteEvalTensor struct that could be null.
template <typename T>
const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
return tensor == nullptr ? nullptr
: reinterpret_cast<const T*>(tensor->data.raw);
}
// Returns the shape of a TfLiteEvalTensor struct.
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
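For kernels with optional inputs (a bias tensor that may be absent, for example), the split above separates the null-asserting accessor from the null-tolerant one. A hedged sketch, assuming input and bias are eval tensors already fetched by the kernel (bias possibly null):
// Required tensor: DCHECKs that it exists.
const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
// Optional tensor: yields nullptr instead of tripping the DCHECK.
const int32_t* bias_data = tflite::micro::GetOptionalTensorData<int32_t>(bias);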

View File

@@ -136,14 +136,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_L2_POOL_2D() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/L2Prepare,
/*invoke=*/L2Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, L2Prepare, L2Eval);
}
} // namespace tflite

View File

@@ -137,14 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace l2norm
TfLiteRegistration Register_L2NORM_REF() {
return {/*init=*/l2norm::Init,
/*free=*/nullptr,
/*prepare=*/l2norm::Prepare,
/*invoke=*/l2norm::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(l2norm::Init, l2norm::Prepare, l2norm::Eval);
}
TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }

View File

@@ -88,14 +88,8 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_LEAKY_RELU() {
return {/*init=*/LeakyReluInit,
/*free=*/nullptr,
/*prepare=*/LeakyReluPrepare,
/*invoke=*/LeakyReluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(LeakyReluInit, LeakyReluPrepare,
LeakyReluEval);
}
} // namespace tflite

View File

@@ -142,14 +142,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_LOG_SOFTMAX() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/LogSoftmaxPrepare,
/*invoke=*/LogSoftmaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, LogSoftmaxPrepare, LogSoftmaxEval);
}
} // namespace tflite

View File

@@ -34,29 +34,11 @@ TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_LOGICAL_OR() {
// Init, Free, Prepare, Eval are satisfying the Interface required by
// TfLiteRegistration.
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/LogicalOrEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, LogicalOrEval);
}
TfLiteRegistration Register_LOGICAL_AND() {
// Init, Free, Prepare, Eval are satisfying the Interface required by
// TfLiteRegistration.
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/LogicalAndEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, LogicalAndEval);
}
} // namespace tflite

View File

@@ -106,13 +106,6 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_LOGISTIC() {
return {/*init=*/LogisticInit,
/*free=*/nullptr,
/*prepare=*/LogisticPrepare,
/*invoke=*/LogisticEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(LogisticInit, LogisticPrepare, LogisticEval);
}
} // namespace tflite

File diff suppressed because it is too large

View File

@@ -0,0 +1,250 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
#include <cstdint>
#include <memory>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
namespace tflite {
// Parameters for integer LSTM.
// Consider splitting this into two integer parameter structs if more fields are added.
struct IntegerLstmParameter {
int32_t effective_input_to_input_scale_a;
int32_t effective_input_to_input_scale_b;
int32_t effective_recurrent_to_input_scale_a;
int32_t effective_recurrent_to_input_scale_b;
int32_t effective_cell_to_input_scale_a;
int32_t effective_cell_to_input_scale_b;
int32_t effective_input_to_forget_scale_a;
int32_t effective_input_to_forget_scale_b;
int32_t effective_recurrent_to_forget_scale_a;
int32_t effective_recurrent_to_forget_scale_b;
int32_t effective_cell_to_forget_scale_a;
int32_t effective_cell_to_forget_scale_b;
int32_t effective_input_to_cell_scale_a;
int32_t effective_input_to_cell_scale_b;
int32_t effective_recurrent_to_cell_scale_a;
int32_t effective_recurrent_to_cell_scale_b;
int32_t effective_input_to_output_scale_a;
int32_t effective_input_to_output_scale_b;
int32_t effective_recurrent_to_output_scale_a;
int32_t effective_recurrent_to_output_scale_b;
int32_t effective_cell_to_output_scale_a;
int32_t effective_cell_to_output_scale_b;
int32_t effective_proj_scale_a;
int32_t effective_proj_scale_b;
int32_t effective_hidden_scale_a;
int32_t effective_hidden_scale_b;
int32_t layer_norm_input_scale_a;
int32_t layer_norm_input_scale_b;
int32_t layer_norm_forget_scale_a;
int32_t layer_norm_forget_scale_b;
int32_t layer_norm_cell_scale_a;
int32_t layer_norm_cell_scale_b;
int32_t layer_norm_output_scale_a;
int32_t layer_norm_output_scale_b;
// Quantized clip value for cell and projection. Zero value means no clipping.
int16_t quantized_cell_clip;
int8_t quantized_proj_clip;
int32_t hidden_zp;
int32_t cell_scale;
int32_t input_variance_guard;
int32_t forget_variance_guard;
int32_t cell_variance_guard;
int32_t output_variance_guard;
// Pre-calculate bias + zero_point * weight.
int32_t* input_to_forget_effective_bias;
int32_t* recurrent_to_forget_effective_bias;
int32_t* input_to_cell_effective_bias;
int32_t* recurrent_to_cell_effective_bias;
int32_t* input_to_output_effective_bias;
int32_t* recurrent_to_output_effective_bias;
int32_t* input_to_input_effective_bias;
int32_t* recurrent_to_input_effective_bias;
int32_t* projection_effective_bias;
// Scale and zero point for intermediate tensors.
// Used only in the 8x8_8 case.
int32_t intermediate_scale_a[8];
int32_t intermediate_scale_b[8];
int32_t intermediate_zp[12];
};
// Scales for hybrid op with integer inputs and float weights
struct HybridLstmScales {
float input_to_input_weights_scale;
float input_to_forget_weights_scale;
float input_to_cell_weights_scale;
float input_to_output_weights_scale;
float aux_input_to_input_weights_scale;
float aux_input_to_forget_weights_scale;
float aux_input_to_cell_weights_scale;
float aux_input_to_output_weights_scale;
float recurrent_to_input_weights_scale;
float recurrent_to_forget_weights_scale;
float recurrent_to_cell_weights_scale;
float recurrent_to_output_weights_scale;
float cell_to_input_weights_scale;
float cell_to_forget_weights_scale;
float cell_to_output_weights_scale;
float projection_weights_scale;
};
TfLiteStatus EvalFloatLstm(
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* input_to_input_weights,
const TfLiteEvalTensor* input_to_forget_weights,
const TfLiteEvalTensor* input_to_cell_weights,
const TfLiteEvalTensor* input_to_output_weights,
const TfLiteEvalTensor* recurrent_to_input_weights,
const TfLiteEvalTensor* recurrent_to_forget_weights,
const TfLiteEvalTensor* recurrent_to_cell_weights,
const TfLiteEvalTensor* recurrent_to_output_weights,
const TfLiteEvalTensor* cell_to_input_weights,
const TfLiteEvalTensor* cell_to_forget_weights,
const TfLiteEvalTensor* cell_to_output_weights,
const TfLiteEvalTensor* input_layer_norm_coefficients,
const TfLiteEvalTensor* forget_layer_norm_coefficients,
const TfLiteEvalTensor* cell_layer_norm_coefficients,
const TfLiteEvalTensor* output_layer_norm_coefficients,
const TfLiteEvalTensor* aux_input,
const TfLiteEvalTensor* aux_input_to_input_weights,
const TfLiteEvalTensor* aux_input_to_forget_weights,
const TfLiteEvalTensor* aux_input_to_cell_weights,
const TfLiteEvalTensor* aux_input_to_output_weights,
const TfLiteEvalTensor* input_gate_bias,
const TfLiteEvalTensor* forget_gate_bias,
const TfLiteEvalTensor* cell_gate_bias,
const TfLiteEvalTensor* output_gate_bias,
const TfLiteEvalTensor* projection_weights,
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
bool forward_sequence, bool time_major, int output_offset,
float* scratch_buffer, TfLiteEvalTensor* output_state,
TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
TfLiteStatus EvalHybridLstm(
const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* input_to_input_weights,
const TfLiteEvalTensor* input_to_input_weights_ledger,
const TfLiteEvalTensor* input_to_forget_weights,
const TfLiteEvalTensor* input_to_forget_weights_ledger,
const TfLiteEvalTensor* input_to_cell_weights,
const TfLiteEvalTensor* input_to_cell_weights_ledger,
const TfLiteEvalTensor* input_to_output_weights,
const TfLiteEvalTensor* input_to_output_weights_ledger,
const TfLiteEvalTensor* recurrent_to_input_weights,
const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
const TfLiteEvalTensor* recurrent_to_forget_weights,
const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
const TfLiteEvalTensor* recurrent_to_cell_weights,
const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
const TfLiteEvalTensor* recurrent_to_output_weights,
const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
const TfLiteEvalTensor* cell_to_input_weights,
const TfLiteEvalTensor* cell_to_forget_weights,
const TfLiteEvalTensor* cell_to_output_weights,
const TfLiteEvalTensor* input_layer_norm_coefficients,
const TfLiteEvalTensor* forget_layer_norm_coefficients,
const TfLiteEvalTensor* cell_layer_norm_coefficients,
const TfLiteEvalTensor* output_layer_norm_coefficients,
const TfLiteEvalTensor* aux_input,
const TfLiteEvalTensor* aux_input_to_input_weights,
const TfLiteEvalTensor* aux_input_to_forget_weights,
const TfLiteEvalTensor* aux_input_to_cell_weights,
const TfLiteEvalTensor* aux_input_to_output_weights,
const TfLiteEvalTensor* input_gate_bias,
const TfLiteEvalTensor* forget_gate_bias,
const TfLiteEvalTensor* cell_gate_bias,
const TfLiteEvalTensor* output_gate_bias,
const TfLiteEvalTensor* projection_weights,
const TfLiteEvalTensor* projection_weights_ledger,
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
bool forward_sequence, bool time_major, int output_offset,
float* scratch_buffer, float* input_sf, float* aux_input_sf,
float* output_state_sf, float* prod_scaling_factors,
float* recovered_cell_weights, int8_t* input_quantized,
int8_t* aux_input_quantized, int8_t* output_state_quantized,
int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
bool* compute_row_sums);
TfLiteStatus EvalInteger8x8_16Lstm(
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* input_to_input_weights,
const TfLiteEvalTensor* input_to_forget_weights,
const TfLiteEvalTensor* input_to_cell_weights,
const TfLiteEvalTensor* input_to_output_weights,
const TfLiteEvalTensor* recurrent_to_input_weights,
const TfLiteEvalTensor* recurrent_to_forget_weights,
const TfLiteEvalTensor* recurrent_to_cell_weights,
const TfLiteEvalTensor* recurrent_to_output_weights,
const TfLiteEvalTensor* cell_to_input_weights,
const TfLiteEvalTensor* cell_to_forget_weights,
const TfLiteEvalTensor* cell_to_output_weights,
const TfLiteEvalTensor* input_layer_norm_coefficients,
const TfLiteEvalTensor* forget_layer_norm_coefficients,
const TfLiteEvalTensor* cell_layer_norm_coefficients,
const TfLiteEvalTensor* output_layer_norm_coefficients,
const TfLiteEvalTensor* input_gate_bias,
const TfLiteEvalTensor* forget_gate_bias,
const TfLiteEvalTensor* cell_gate_bias,
const TfLiteEvalTensor* output_gate_bias,
const TfLiteEvalTensor* projection_weights,
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
bool forward_sequence, bool time_major,
const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5);
TfLiteStatus EvalInteger8x8_8Lstm(
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* input_to_input_weights,
const TfLiteEvalTensor* input_to_forget_weights,
const TfLiteEvalTensor* input_to_cell_weights,
const TfLiteEvalTensor* input_to_output_weights,
const TfLiteEvalTensor* recurrent_to_input_weights,
const TfLiteEvalTensor* recurrent_to_forget_weights,
const TfLiteEvalTensor* recurrent_to_cell_weights,
const TfLiteEvalTensor* recurrent_to_output_weights,
const TfLiteEvalTensor* cell_to_input_weights,
const TfLiteEvalTensor* cell_to_forget_weights,
const TfLiteEvalTensor* cell_to_output_weights,
const TfLiteEvalTensor* input_layer_norm_coefficients,
const TfLiteEvalTensor* forget_layer_norm_coefficients,
const TfLiteEvalTensor* cell_layer_norm_coefficients,
const TfLiteEvalTensor* output_layer_norm_coefficients,
const TfLiteEvalTensor* input_gate_bias,
const TfLiteEvalTensor* forget_gate_bias,
const TfLiteEvalTensor* cell_gate_bias,
const TfLiteEvalTensor* output_gate_bias,
const TfLiteEvalTensor* projection_weights,
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7);
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_

View File

@@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
namespace tflite {
// Input Tensors of size {n_batch, n_input}
constexpr int kLstmInputTensor = 0;
// Input weight tensors of size: {n_cell, n_input}
constexpr int kLstmInputToInputWeightsTensor = 1; // Optional
constexpr int kLstmInputToForgetWeightsTensor = 2;
constexpr int kLstmInputToCellWeightsTensor = 3;
constexpr int kLstmInputToOutputWeightsTensor = 4;
// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kLstmRecurrentToInputWeightsTensor = 5; // Optional
constexpr int kLstmRecurrentToForgetWeightsTensor = 6;
constexpr int kLstmRecurrentToCellWeightsTensor = 7;
constexpr int kLstmRecurrentToOutputWeightsTensor = 8;
// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kLstmCellToInputWeightsTensor = 9; // Optional
constexpr int kLstmCellToForgetWeightsTensor = 10; // Optional
constexpr int kLstmCellToOutputWeightsTensor = 11; // Optional
// Gates bias tensors of size {n_cell}
constexpr int kLstmInputGateBiasTensor = 12; // Optional
constexpr int kLstmForgetGateBiasTensor = 13;
constexpr int kLstmCellGateBiasTensor = 14;
constexpr int kLstmOutputGateBiasTensor = 15;
// Projection weight tensor of size {n_output, n_cell}
constexpr int kLstmProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kLstmProjectionBiasTensor = 17; // Optional
// These state tensors are defined as variable tensors, and will be modified by
// this op.
constexpr int kLstmOutputStateTensor = 18;
constexpr int kLstmCellStateTensor = 19;
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.
constexpr int kLstmInputLayerNormCoefficientsTensor = 20; // Optional
constexpr int kLstmForgetLayerNormCoefficientsTensor = 21; // Optional
constexpr int kLstmCellLayerNormCoefficientsTensor = 22; // Optional
constexpr int kLstmOutputLayerNormCoefficientsTensor = 23; // Optional
// Output tensors.
constexpr int kLstmOutputTensor = 0;
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
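These indices map directly onto the op's input list, so a kernel fetches the tensors through the usual eval-tensor helpers; optional entries may come back null. A minimal sketch:
const TfLiteEvalTensor* input =
    tflite::micro::GetEvalInput(context, node, kLstmInputTensor);
const TfLiteEvalTensor* input_to_input_weights =
    tflite::micro::GetEvalInput(context, node, kLstmInputToInputWeightsTensor);
// Optional tensor: absent for the CIFG variant, which has no input gate.
const bool use_cifg = (input_to_input_weights == nullptr);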

View File

@@ -115,29 +115,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace maximum_minimum
TfLiteRegistration Register_MAXIMUM() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/
maximum_minimum::Eval<maximum_minimum::kReference,
maximum_minimum::MaximumOp>,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, nullptr,
maximum_minimum::Eval<maximum_minimum::kReference,
maximum_minimum::MaximumOp>);
}
TfLiteRegistration Register_MINIMUM() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/
maximum_minimum::Eval<maximum_minimum::kReference,
maximum_minimum::MinimumOp>,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, nullptr,
maximum_minimum::Eval<maximum_minimum::kReference,
maximum_minimum::MinimumOp>);
}
} // namespace micro

View File

@@ -76,11 +76,14 @@ TfLiteRegistration Register_SHAPE();
TfLiteRegistration Register_SLICE();
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
TfLiteRegistration Register_SPACE_TO_DEPTH();
TfLiteRegistration Register_SQUARED_DIFFERENCE();
TfLiteRegistration Register_SQUEEZE();
TfLiteRegistration Register_SUB();
TfLiteRegistration Register_SVDF();
TfLiteRegistration Register_TRANSPOSE();
TfLiteRegistration Register_TRANSPOSE_CONV();
// TODO(b/230666079): resolve conflict with xtensa implementation
TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration Register_VAR_HANDLE();
TfLiteRegistration Register_WHILE();
TfLiteRegistration Register_ZEROS_LIKE();
@@ -103,14 +106,12 @@ TfLiteRegistration Register_LESS_EQUAL();
TfLiteRegistration Register_LOG();
TfLiteRegistration Register_LOGICAL_NOT();
TfLiteRegistration Register_MAXIMUM();
TfLiteRegistration Register_MEAN();
TfLiteRegistration Register_MINIMUM();
TfLiteRegistration Register_NEG();
TfLiteRegistration Register_NOT_EQUAL();
TfLiteRegistration Register_PACK();
TfLiteRegistration Register_PAD();
TfLiteRegistration Register_PADV2();
TfLiteRegistration Register_REDUCE_MAX();
TfLiteRegistration Register_RESHAPE();
TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR();
TfLiteRegistration Register_ROUND();
@@ -121,7 +122,6 @@ TfLiteRegistration Register_SPLIT_V();
TfLiteRegistration Register_SQRT();
TfLiteRegistration Register_SQUARE();
TfLiteRegistration Register_STRIDED_SLICE();
TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration Register_UNPACK();
TfLiteRegistration Register_L2_NORMALIZATION();
TfLiteRegistration Register_TANH();

View File

@@ -0,0 +1,809 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <utility>
#include "fixedpoint/fixedpoint.h" // from @gemmlowp
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace micro_tensor_utils {
namespace {
const int32_t kInt16Max = std::numeric_limits<int16_t>::max();
const int32_t kInt16Min = std::numeric_limits<int16_t>::min();
} // namespace
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min_value,
float* max_value, float* scaling_factor) {
auto minmax = std::minmax_element(values, values + size);
*min_value = *minmax.first;
*max_value = *minmax.second;
PortableSymmetricQuantizeFloats(values, size, quantized_values, *min_value,
*max_value, scaling_factor);
}
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float min_value,
float max_value, float* scaling_factor) {
const int32_t kScale = 127;
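// Symmetric quantization: the largest-magnitude input maps to +/-127 and zero
// maps exactly to zero. Illustrative example: for values spanning [-0.5, 2.0]
// the range below is 2.0, scaling_factor = 2.0 / 127, and an input of 1.0
// quantizes to TfLiteRound(1.0 * 127 / 2.0) = 64.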
const float range = std::max(std::abs(min_value), std::abs(max_value));
if (range == 0) {
memset(quantized_values, 0, size * sizeof(int8_t));
*scaling_factor = 1;
return;
}
*scaling_factor = range / kScale;
const float scaling_factor_inv = kScale / range;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value =
static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
// Clamp, just in case of some odd numeric offset.
quantized_values[i] = static_cast<int8_t>(
std::min(kScale, std::max(-kScale, quantized_value)));
}
}
void PortableAsymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values,
float* scaling_factor, int32_t* offset) {
const int32_t kMinScale = -128;
const int32_t kMaxScale = 127;
const double qmin_double = kMinScale;
const double qmax_double = kMaxScale;
const auto minmax = std::minmax_element(values, values + size);
const double rmin = static_cast<double>(std::min(0.0f, *minmax.first));
const double rmax = static_cast<double>(std::max(0.0f, *minmax.second));
if (rmin == rmax) {
memset(quantized_values, 0, size * sizeof(int8_t));
*scaling_factor = 1;
*offset = 0;
return;
} else {
double scale = (rmax - rmin) / (qmax_double - qmin_double);
const double zero_point_from_min = qmin_double - rmin / scale;
const double zero_point_from_max = qmax_double - rmax / scale;
const double zero_point_from_min_error =
std::abs(qmin_double) + std::abs(rmin / scale);
const double zero_point_from_max_error =
std::abs(qmax_double) + std::abs(rmax / scale);
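// Pick the zero-point candidate anchored at whichever end of the range gives
// the smaller error term, then nudge it into [-128, 127] below so that real
// zero remains exactly representable.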
const double zero_point_double =
zero_point_from_min_error < zero_point_from_max_error
? zero_point_from_min
: zero_point_from_max;
int8_t nudged_zero_point = 0;
if (zero_point_double <= qmin_double) {
nudged_zero_point = kMinScale;
} else if (zero_point_double >= qmax_double) {
nudged_zero_point = kMaxScale;
} else {
nudged_zero_point = static_cast<int8_t>(round(zero_point_double));
}
*scaling_factor = scale;
*offset = nudged_zero_point;
}
const float scaling_factor_inv = 1.0f / *scaling_factor;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value = static_cast<int32_t>(
TfLiteRound(*offset + values[i] * scaling_factor_inv));
quantized_values[i] =
std::min(kMaxScale, std::max(kMinScale, quantized_value));
}
}
void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
int m_rows, int m_cols,
const float* vector,
int n_batch, float* result) {
float* result_in_batch = result;
for (int b = 0; b < n_batch; b++) {
const float* matrix_ptr = matrix;
for (int r = 0; r < m_rows; r++) {
float dot_prod = 0.0f;
const float* vector_in_batch = vector + b * m_cols;
for (int c = 0; c < m_cols; c++) {
dot_prod += *matrix_ptr++ * *vector_in_batch++;
}
*result_in_batch += dot_prod;
++result_in_batch;
}
}
}
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result) {
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
const float batch_scaling_factor = scaling_factors[batch];
// Get the address of the first row.
const int8_t* row_ptr = matrix;
for (int row = 0; row < m_rows; ++row) {
// Initialize the dot product sum for the row to 0.
int32_t dotprod = 0;
// TODO(b/230666277): remove this
#if defined(__GNUC__)
// Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */);
#endif
for (int col = 0; col < m_cols; ++col, ++row_ptr) {
dotprod += (*row_ptr) * (vectors[col]);
} // for col
*result += dotprod * batch_scaling_factor;
++result;
} // for row
} // for batch
}
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) {
if (input_offset == nullptr) {
PortableMatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
return;
}
if (!compute_row_sums || *compute_row_sums) {
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
if (compute_row_sums) {
*compute_row_sums = false;
}
}
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
const float batch_scaling_factor = scaling_factors[batch];
const int32_t batch_offset = input_offset[batch];
const int8_t* row_ptr = matrix;
for (int row = 0; row < m_rows; ++row) {
int32_t dotprod = 0;
float scale = batch_scaling_factor;
if (per_channel_scale) {
scale *= per_channel_scale[row];
}
#if defined(__GNUC__)
// Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */);
#endif
for (int col = 0; col < m_cols; ++col, ++row_ptr) {
dotprod += (*row_ptr) * vectors[col];
} // for col
dotprod -= row_sums[row] * batch_offset;
*result += dotprod * scale;
++result;
} // for row
} // for batch
}
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
const int kBlockSize = 4;
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
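// Block-sparse layout: segments[row]..segments[row + 1] enumerate the non-zero
// 1x4 blocks of a row, and indices[i] holds each block's index along the
// columns (actual starting column = indices[i] * kBlockSize).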
for (int batch = 0; batch < n_batch; batch++) {
const float* matrix_ptr = matrix;
for (int row = 0; row < m_rows; row++) {
float dot_prod = 0.0f;
const float* vector_in_batch = vector + batch * m_cols;
for (int i = segments[row]; i < segments[row + 1]; i++) {
const int block_start_index = indices[i] * kBlockSize;
const float* vector_block_in_batch_ptr =
vector_in_batch + block_start_index;
for (int c = 0; c < kBlockSize; c++) {
dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
}
}
result[batch * m_rows + row] += dot_prod;
}
}
}
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
int n_batch, const int32_t input_offset, const int32_t output_multiplier,
const int32_t output_shift, const int32_t output_offset,
const int32_t output_activation_min, const int32_t output_activation_max,
int8_t* __restrict__ result) {
const int kBlockSize = 16;
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
for (int batch = 0; batch < n_batch; ++batch) {
const int8_t* matrix_ptr = matrix;
for (int row = 0; row < m_rows; ++row) {
int32_t dot_prod = 0;
const int8_t* vector_in_batch = vector + batch * m_cols;
for (int i = segments[row]; i < segments[row + 1]; ++i) {
const int block_start_index = indices[i] * kBlockSize;
const int8_t* vector_block_in_batch_ptr =
vector_in_batch + block_start_index;
for (int c = 0; c < kBlockSize; c++) {
dot_prod += *matrix_ptr * *vector_block_in_batch_ptr++;
dot_prod += *matrix_ptr++ * input_offset;
}
}
const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0;
dot_prod = MultiplyByQuantizedMultiplier(dot_prod + bias_value,
output_multiplier, output_shift);
dot_prod += output_offset;
result[batch * m_rows + row] =
static_cast<int8_t>(ActivationFunctionWithMinMax(
dot_prod, output_activation_min, output_activation_max));
}
}
}
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
float* __restrict__ result) {
const int kBlockSize = 16;
TFLITE_DCHECK_EQ( // NOLINT
m_cols % kBlockSize, 0);
for (int batch = 0; batch < n_batch; batch++) {
const float* matrix_ptr = matrix;
const uint8_t* ledger_ptr = ledger;
for (int row = 0; row < m_rows; row++) {
float dot_prod = 0.0f;
int num_nonzero_blocks = *ledger_ptr++;
if (num_nonzero_blocks > 0) {
const float* vector_in_batch = vector + batch * m_cols;
for (int i = 0; i < num_nonzero_blocks; i++) {
const int block_start_index = *ledger_ptr++ * kBlockSize;
const float* vector_block_in_batch_ptr =
vector_in_batch + block_start_index;
for (int c = 0; c < kBlockSize; c++) {
dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
}
}
}
result[batch * m_rows + row] += dot_prod;
}
}
}
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
const float* scaling_factors, int n_batch, float* __restrict__ result) {
static const int kBlockSize = 16;
TFLITE_DCHECK_EQ( // NOLINT
m_cols % kBlockSize, 0);
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
const float batch_scaling_factor = scaling_factors[batch];
const uint8_t* ledger_ptr = ledger;
// Get the address of the first row.
const int8_t* row_ptr = matrix;
for (int row = 0; row < m_rows; ++row) {
// Initialize the dot product sum for the row to 0.
int32_t dotprod = 0;
#if defined(__GNUC__)
// Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */);
#endif
int num_nonzero_blocks = *ledger_ptr++;
for (int i = 0; i < num_nonzero_blocks; i++) {
const int block_start_index = *ledger_ptr++ * kBlockSize;
const int8_t* vector_block_ptr = vectors + block_start_index;
for (int c = 0; c < kBlockSize; c++) {
dotprod += (*row_ptr++) * (*vector_block_ptr++);
} // for block
} // for num_nonzero_blocks
result[batch * m_rows + row] += dotprod * batch_scaling_factor;
} // for row
} // for batch
}
template <typename T>
void PortableMatrixBatchVectorMultiplyAccumulateImpl(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
T* output) {
const int16_t output_max = std::numeric_limits<T>::max();
const int16_t output_min = std::numeric_limits<T>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int row = 0; row < n_output; ++row) {
int32_t acc = bias[row];
for (int col = 0; col < n_input; ++col) {
int8_t input_val = input[batch * n_input + col];
int8_t weights_val = input_to_gate_weights[row * n_input + col];
acc += input_val * weights_val;
}
acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
acc += output_zp;
acc += output[batch * n_output + row];
if (acc > output_max) {
acc = output_max;
}
if (acc < output_min) {
acc = output_min;
}
output[batch * n_output + row] = static_cast<T>(acc);
}
}
}
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int16_t* output, CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulateImpl(
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
n_output, output_zp, output);
}
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int8_t* output, CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulateImpl(
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
n_output, output_zp, output);
}
void PortableMatrixBatchVectorMultiply(const int8_t* input,
int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input,
int32_t n_cell, int8_t* gate_output,
int8_t gate_output_zp) {
const int32_t int8_max = std::numeric_limits<int8_t>::max();
const int32_t int8_min = std::numeric_limits<int8_t>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int row = 0; row < n_cell; ++row) {
int32_t acc = 0;
for (int col = 0; col < n_input; ++col) {
int32_t input_val = input[batch * n_input + col];
int8_t weights_val = input_to_gate_weights[row * n_input + col];
acc += (input_val - input_zeropoint) * weights_val;
}
acc = MultiplyByQuantizedMultiplier(acc, input_to_gate_effective_scale_a,
input_to_gate_effective_scale_b);
acc += gate_output_zp;
if (acc > int8_max) {
acc = int8_max;
}
if (acc < int8_min) {
acc = int8_min;
}
gate_output[batch * n_cell + row] = static_cast<int8_t>(acc);
}
}
}
void PortableMatrixBatchVectorMultiply(
const int16_t* hidden, const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
int32_t n_output, int32_t output_zp, int8_t* proj_output) {
const int16_t int8_max = std::numeric_limits<int8_t>::max();
const int16_t int8_min = std::numeric_limits<int8_t>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int row = 0; row < n_output; ++row) {
int64_t acc = gate_bias[row];
for (int col = 0; col < n_hidden; ++col) {
int16_t input_val = hidden[batch * n_hidden + col];
int8_t weights_val = hidden_to_output_weights[row * n_hidden + col];
int64_t curr = acc;
acc += input_val * weights_val;
if (input_val * weights_val > 0 && acc < curr) {
acc = std::numeric_limits<int32_t>::max();
}
if (input_val * weights_val < 0 && acc > curr) {
acc = std::numeric_limits<int32_t>::min();
}
}
acc = MultiplyByQuantizedMultiplier(acc, proj_effective_scale_a,
proj_effective_scale_b);
acc += output_zp;
if (acc > int8_max) {
acc = int8_max;
}
if (acc < int8_min) {
acc = int8_min;
}
proj_output[batch * n_output + row] = acc;
}
}
}
void PortableApplyLayerNorm(const int16_t* input,
const int16_t* layer_norm_weights,
const int32_t* bias, int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b, int32_t variance_limit,
int n_batch, int n_input, int16_t* output) {
// The square of std::pow(2, 10), which is the extra factor that makes sure
// normalized values have enough resolution.
static const int kTwoToPower20 = 1 << 20;
for (int i = 0; i < n_batch; ++i) {
int64_t sum = 0;
int64_t sum_sq = 0;
for (int j = 0; j < n_input; ++j) {
const int32_t index = i * n_input + j;
int32_t val = static_cast<int32_t>(input[index]);
sum += val;
sum_sq += val * val;
}
int32_t mean =
static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
// TODO(b/173994730): Avoids overflow but only works for POT n_input.
int32_t temp = kTwoToPower20 / n_input;
int64_t variance =
sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
int32_t variance2 = static_cast<int32_t>(variance / kTwoToPower20);
if (variance2 < 1) {
variance2 = variance_limit;
}
int32_t stddev_inverse_a;
int stddev_inverse_b;
GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1,
&stddev_inverse_a, &stddev_inverse_b);
for (int j = 0; j < n_input; ++j) {
const int32_t index = i * n_input + j;
int32_t val = static_cast<int32_t>(input[index]);
int32_t shifted = 1024 * val - mean;
int32_t rescaled = MultiplyByQuantizedMultiplier(
shifted, stddev_inverse_a, stddev_inverse_b);
int64_t val3 = rescaled * layer_norm_weights[j] + bias[j];
int32_t val4 =
static_cast<int32_t>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
int32_t val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
layer_norm_scale_b + 12);
val5 = std::min(std::max(kInt16Min, val5), kInt16Max);
output[index] = static_cast<int16_t>(val5);
}
}
}
void PortableApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16_t>::max();
const int32_t int16_min = std::numeric_limits<int16_t>::min();
const float layer_norm_scale =
layer_norm_scale_a *
std::pow(2.0, static_cast<double>(layer_norm_scale_b - 31));
const float bias_scale =
static_cast<float>(std::pow(2.0, -10)) * layer_norm_scale;
for (int batch = 0; batch < n_batch; ++batch) {
float sum = 0.0f;
float sum_sq = 0.0f;
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float value = static_cast<float>(input[index]);
sum += value;
sum_sq += value * value;
}
const float mean = sum / n_input;
float stddev_inv = 0.0f;
const float variance = sum_sq / n_input - mean * mean;
if (variance == 0) {
stddev_inv = 1.0f / std::sqrt(1e-8f);
} else {
stddev_inv = 1.0f / std::sqrt(variance);
}
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float normalized_value =
(static_cast<float>(input[index]) - mean) * stddev_inv;
const float weighted_normalized_value =
normalized_value * layer_norm_weights[i] * layer_norm_scale +
bias[i] * bias_scale;
const int32_t quant_output = static_cast<int32_t>(round(
weighted_normalized_value * static_cast<float>(std::pow(2, 12))));
output[index] = std::min(int16_max, std::max(int16_min, quant_output));
}
}
}
void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
int32_t scalar, int32_t n_row,
int32_t n_col, int32_t* output) {
for (int i = 0; i < n_row; ++i) {
int32_t row_sum = 0;
for (int j = 0; j < n_col; ++j) {
row_sum += *matrix++;
}
output[i] += row_sum * scalar;
}
}
void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
for (int batch = 0; batch < n_batch; ++batch) {
for (int c = 0; c < n_input; c++) {
using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
const int index = batch * n_input + c;
F3 sigmoid_input = F3::FromRaw(input[index]);
F0 sigmoid_output = gemmlowp::logistic(sigmoid_input);
output[index] = sigmoid_output.raw();
}
}
}
void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16_t>::max();
const int32_t int16_min = std::numeric_limits<int16_t>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float float_input =
input[index] * static_cast<float>(std::pow(2, -12));
const float float_output = 1.0f / (1.0f + std::exp(-float_input));
const int32_t quant_output = static_cast<int32_t>(
float_output * static_cast<float>(std::pow(2, 15)));
const int32_t quant_output_clamped =
std::min(int16_max, std::max(int16_min, quant_output));
output[index] = static_cast<int16_t>(quant_output_clamped);
}
}
}
template <int IntegerBits>
void PortableApplyTanhImpl(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
using FX = gemmlowp::FixedPoint<std::int16_t, IntegerBits>;
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
FX tanh_input = FX::FromRaw(input[index]);
F0 tanh_output = gemmlowp::tanh(tanh_input);
output[index] = tanh_output.raw();
}
}
}
void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
int32_t n_batch, int32_t n_input, int16_t* output) {
if (integer_bits > 6) {
TFLITE_ASSERT_FALSE;
}
#define DISPATCH_TANH(i) \
case i: \
PortableApplyTanhImpl<i>(input, n_batch, n_input, output); \
break;
switch (integer_bits) {
DISPATCH_TANH(0);
DISPATCH_TANH(1);
DISPATCH_TANH(2);
DISPATCH_TANH(3);
DISPATCH_TANH(4);
DISPATCH_TANH(5);
DISPATCH_TANH(6);
default:
return;
}
#undef DISPATCH_TANH
}
void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int32_t integer_bits,
int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16_t>::max();
const int32_t int16_min = std::numeric_limits<int16_t>::min();
const double two = 2.0;
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float float_input =
input[index] * std::pow(two, static_cast<double>(integer_bits));
const float float_output = std::tanh(float_input);
const int32_t quant_output = static_cast<int32_t>(
float_output * static_cast<float>(std::pow(2, 15)));
const int32_t quant_output_clamped =
std::min(int16_max, std::max(int16_min, quant_output));
output[index] = static_cast<int16_t>(quant_output_clamped);
}
}
}
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int shift, int16_t* output) {
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const int16_t a = input_1[index];
const int16_t b = input_2[index];
const int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
output[index] =
static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(value, shift));
}
}
}
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
int32_t multiplier, int32_t shift, int32_t n_batch,
int32_t n_input, int32_t output_zp, int8_t* output) {
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const int16_t a = input_1[index];
const int16_t b = input_2[index];
int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
value -= output_zp;
value = std::min(std::max(static_cast<int32_t>(-128), value),
static_cast<int32_t>(127));
output[index] = static_cast<int8_t>(value);
}
}
}
void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int16_t* output) {
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
int32_t sum = input_1[index] + input_2[index];
const int32_t sum_clamped = std::min(kInt16Max, std::max(kInt16Min, sum));
output[index] = static_cast<int16_t>(sum_clamped);
}
}
}
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size) {
float result = 0.0;
for (int v = 0; v < v_size; v++) {
result += *vector1++ * *vector2++;
}
return result;
}
namespace {
inline int32_t VectorVectorDotProduct(const int16_t* vector1,
const int16_t* vector2, int v_size) {
int32_t result = 0;
for (int v = 0; v < v_size; v++) {
result += *vector1++ * *vector2++;
}
return result;
}
} // namespace
void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
const int16_t* vector2,
int v_size, int n_batch,
int32_t* result) {
for (int b = 0; b < n_batch; b++) {
result[b] = VectorVectorDotProduct(vector1, vector2, v_size);
vector1 += v_size;
vector2 += v_size;
}
}
void PortableVectorBatchVectorCwiseProductAccumulate(
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
int32_t multiplier, int shift, int16_t* result) {
for (int b = 0; b < n_batch; b++) {
for (int v = 0; v < v_size; v++) {
int32_t prod = vector[v] * *batch_vector++;
prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
int32_t output = prod + *result;
output = std::max(std::min(static_cast<int32_t>(32767), output),
static_cast<int32_t>(-32768));
*result++ = output;
}
}
}
void PortableSub1Vector(const float* vector, int v_size, float* result) {
for (int v = 0; v < v_size; v++) {
*result++ = 1.0f - *vector++;
}
}
void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
static const int16_t kOne = 32767;
for (int v = 0; v < v_size; v++) {
*result++ = kOne - *vector++;
}
}
void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
const float scale, float* result) {
for (int v = 0; v < v_size; ++v) {
*result++ = scale * *vector++;
}
}
void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
float* __restrict__ output_vector,
int v_size, int n_batch) {
for (int batch = 0; batch < n_batch; ++batch) {
float sum = 0.0f;
for (int i = 0; i < v_size; ++i) {
sum += input_vector[i];
}
const float mean = sum / v_size;
float sum_diff_sq = 0.0f;
for (int i = 0; i < v_size; ++i) {
const float diff = input_vector[i] - mean;
sum_diff_sq += diff * diff;
}
const float variance = sum_diff_sq / v_size;
constexpr float kNormalizationConstant = 1e-8f;
const float stddev_inv =
1.0f / std::sqrt(variance + kNormalizationConstant);
for (int i = 0; i < v_size; ++i) {
output_vector[i] = (input_vector[i] - mean) * stddev_inv;
}
input_vector += v_size;
output_vector += v_size;
}
}
void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b,
int32_t n_batch, int32_t n_cell,
int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16_t>::max();
const int32_t int16_min = std::numeric_limits<int16_t>::min();
for (int i = 0; i < n_batch * n_cell; ++i) {
int32_t x = static_cast<int32_t>(input[i]) - static_cast<int32_t>(input_zp);
int32_t h =
static_cast<int32_t>(recurrent[i]) - static_cast<int32_t>(recurrent_zp);
int32_t x_scaled = MultiplyByQuantizedMultiplier(x, input_effective_scale_a,
input_effective_scale_b);
int32_t h_scaled = MultiplyByQuantizedMultiplier(
h, recurrent_effective_scale_a, recurrent_effective_scale_b);
int32_t y = h_scaled + x_scaled;
if (y > int16_max) {
y = int16_max;
}
if (y < int16_min) {
y = int16_min;
}
output[i] = static_cast<int16_t>(y);
}
}
} // namespace micro_tensor_utils
} // namespace tflite

View File

@@ -0,0 +1,874 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file and the associated .cc file are branched from
// tensorflow/lite/kernels/internal/reference/portable_tensor_utils*
// TFLM needs to create its own because the original files are coupled with
// the tensor_utils module, which we cannot reuse due to its use of the
// Eigen library.
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
#include <algorithm>
#include <cmath>
#include <cstdint>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
#endif
namespace tflite {
// Not all backends support CpuBackendContext usage, so forward declare to avoid
// pulling in its implementation.
// TODO(b/230666277): consider removing this since micro does not utilize it
class CpuBackendContext;
namespace micro_tensor_utils {
template <typename T>
inline bool PortableIsZeroVector(const T* vector, int v_size) {
for (int i = 0; i < v_size; ++i) {
if (vector[i] != 0) {
return false;
}
}
return true;
}
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min_value,
float* max_value, float* scaling_factor);
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float min_value,
float max_value, float* scaling_factor);
void PortableAsymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values,
float* scaling_factor, int32_t* offset);
// Multiply a matrix by a batch vector, and store results in a batch-size
// vector.
void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
int m_rows, int m_cols,
const float* vector,
int n_batch, float* result);
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result);
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context);
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vector, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
CpuBackendContext* context);
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
float* __restrict__ result);
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
int n_batch, const int32_t input_offset, const int32_t output_multiplier,
const int32_t output_shift, const int32_t output_offset,
const int32_t output_activation_min, const int32_t output_activation_max,
int8_t* __restrict__ result);
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
const float* scaling_factors, int n_batch, float* __restrict__ result);
// Dot product of two vectors.
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size);
void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
const int16_t* vector2,
int v_size, int n_batch,
int32_t* result);
void PortableVectorBatchVectorCwiseProductAccumulate(
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
int32_t multiplier, int shift, int16_t* result);
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int16_t* output, CpuBackendContext* context);
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int8_t* output, CpuBackendContext* context);
void PortableMatrixBatchVectorMultiply(const int8_t* input,
int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input,
int32_t n_cell, int8_t* gate_output,
int8_t gate_output_zp);
void PortableMatrixBatchVectorMultiply(
const int16_t* hidden, const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
int32_t n_output, int32_t output_zp, int8_t* proj_output);
void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
int32_t scalar, int32_t n_row,
int32_t n_col, int32_t* output);
void PortableApplyLayerNorm(const int16_t* input,
const int16_t* layer_norm_weights,
const int32_t* bias, int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b, int32_t variance_limit,
int n_batch, int n_input, int16_t* output);
void PortableApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output);
void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
int32_t n_batch, int32_t n_input, int16_t* output);
void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int32_t integer_bits,
int16_t* output);
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int shift, int16_t* output);
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
int32_t multiplier, int32_t shift, int32_t n_batch,
int32_t n_input, int32_t output_zp, int8_t* output);
void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int16_t* output);
template <typename T>
inline void PortableCwiseClipping(T* vector, const int v_size,
const T& clipping_value) {
for (int i = 0; i < v_size; i++) {
vector[i] = std::max(std::min(clipping_value, vector[i]),
static_cast<T>(-clipping_value));
}
}
// Batch vector initialization with another vector.
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
// Compute "1.0f - elements of vector" (used in CIFG).
void PortableSub1Vector(const float* vector, int v_size, float* result);
void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result);
// Multiply all elements of vector with a scalar.
void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
float* result);
// Reduce-sum on a vector:
// input_vector: pointer to input vector.
// output_vector: pointer to output vector.
// output_size: output vector size.
// reduction_size: number of consecutive elements from input vector which are
// added to get one element of output.
template <typename INPUT, typename OUTPUT>
inline void PortableReductionSumVector(const INPUT* input_vector,
OUTPUT* output_vector, int output_size,
int reduction_size) {
for (int o = 0; o < output_size; o++) {
OUTPUT result = 0;
for (int r = 0; r < reduction_size; r++) {
result += input_vector[r];
}
output_vector[o] = result;
input_vector += reduction_size;
}
}
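// Illustrative usage sketch of the reduce-sum helper above; the example
// function and buffer names are hypothetical and only show the call shape.
inline void ExamplePortableReductionSumVector() {
  const int32_t input[6] = {1, 2, 3, 10, 20, 30};
  int32_t output[2] = {0, 0};
  // Sums groups of 3 consecutive elements: output[0] == 6, output[1] == 60.
  PortableReductionSumVector(input, output, /*output_size=*/2,
                             /*reduction_size=*/3);
}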
// Layer norm for each batch.
void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
float* __restrict__ output_vector,
int v_size, int n_batch);
// Saturate Add.
void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b,
int32_t n_batch, int32_t n_cell,
int16_t* output);
// Add another vector for each batch in the batch vector.
template <typename T>
inline void VectorBatchVectorAdd(const T* vector, int v_size, int n_batch,
T* batch_vector) {
for (int b = 0; b < n_batch; b++) {
for (int i = 0; i < v_size; ++i) {
batch_vector[i] += vector[i];
}
batch_vector += v_size;
}
}
// Cwise product of two vectors.
template <typename T>
inline void VectorVectorCwiseProduct(const T* vector1, const T* vector2,
int v_size, T* result) {
for (int v = 0; v < v_size; v++) {
*result++ = *vector1++ * *vector2++;
}
}
// Cwise product of a vector and a batch-vector.
template <typename T>
inline void VectorBatchVectorCwiseProduct(const T* vector, int v_size,
const T* batch_vector, int n_batch,
T* result) {
for (int b = 0; b < n_batch; b++) {
VectorVectorCwiseProduct(vector, batch_vector, v_size, result);
// Update the pointers.
result += v_size;
batch_vector += v_size;
}
}
// Reduce-sum on a float input vector:
// input_vector: float pointer to input vector.
// output_vector: float pointer to output vector.
// output_size: output vector size.
// reduction_size: number of consecutive elements from input vector which are
// added to get one element of output.
inline void ReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size) {
PortableReductionSumVector(input_vector, output_vector, output_size,
reduction_size);
}
// Same as above but input/output is 32 bit integer.
inline void ReductionSumVector(const int32_t* input_vector,
int32_t* output_vector, int output_size,
int reduction_size) {
PortableReductionSumVector(input_vector, output_vector, output_size,
reduction_size);
}
// Same as above but input is 8 bit integer.
inline void ReductionSumVector(const int8_t* input_vector,
int32_t* output_vector, int output_size,
int reduction_size) {
PortableReductionSumVector(input_vector, output_vector, output_size,
reduction_size);
}
// Cwise product and accumulate of two vectors. Since it's a MAC operation, the
// assumption here is that result array is initialized to valid values.
template <typename T>
inline void VectorVectorCwiseProductAccumulate(const T* __restrict__ vector1,
const T* __restrict__ vector2,
int v_size,
T* __restrict__ result) {
for (int v = 0; v < v_size; v++) {
*result++ += *vector1++ * *vector2++;
}
}
// Batch vector initialization with another vector.
template <typename T>
inline void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch,
T* batch_vector) {
for (int b = 0; b < n_batch; b++) {
std::copy_n(vector, v_size, batch_vector + b * v_size);
}
}
inline void SymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
float* max, float* scaling_factor) {
PortableSymmetricQuantizeFloats(values, size, quantized_values, min, max,
scaling_factor);
}
inline void SymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float min_value,
float max_value, float* scaling_factor) {
PortableSymmetricQuantizeFloats(values, size, quantized_values, min_value,
max_value, scaling_factor);
}
inline void AsymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values,
float* scaling_factor, int32_t* offset) {
PortableAsymmetricQuantizeFloats(values, size, quantized_values,
scaling_factor, offset);
}
// Helper function to quantize floats.
// float_data_ptr input float vectors
// n_batch number of input vectors
// n_data size of a single input vector
// quantized_data_ptr (out) vector with quantized data
// scaling_factors (out) scaling factors (one per vector)
// zero_points (out) zero points (one per vector)
// do_asymmetric controls if the quantization should be asymmetric.
inline void BatchQuantizeFloats(const float* float_data_ptr, int n_batch,
int n_data, int8_t* quantized_data_ptr,
float* scaling_factors, int32_t* zero_points,
bool do_asymmetric) {
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_data;
if (do_asymmetric) {
AsymmetricQuantizeFloats(float_data_ptr + offset, n_data,
quantized_data_ptr + offset, &scaling_factors[b],
&zero_points[b]);
} else {
float unused_min, unused_max;
SymmetricQuantizeFloats(float_data_ptr + offset, n_data,
quantized_data_ptr + offset, &unused_min,
&unused_max, &scaling_factors[b]);
}
}
}
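// Illustrative usage sketch of BatchQuantizeFloats for two batches of four
// floats each; the example function and buffer names are hypothetical.
inline void ExampleBatchQuantizeFloats() {
  const float float_data[8] = {0.5f, -1.0f, 2.0f,  0.0f,
                               1.5f, -0.5f, 0.25f, 3.0f};
  int8_t quantized[8];
  float scaling_factors[2];  // One scaling factor per batch.
  int32_t zero_points[2];    // Only written in the asymmetric path.
  BatchQuantizeFloats(float_data, /*n_batch=*/2, /*n_data=*/4, quantized,
                      scaling_factors, zero_points, /*do_asymmetric=*/false);
}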
// Check if all entries of a vector are zero for float.
inline bool IsZeroVector(const float* vector, int v_size) {
return PortableIsZeroVector(vector, v_size);
}
// Check if all entries of a vector are zero for int8_t.
inline bool IsZeroVector(const int8_t* vector, int v_size) {
return PortableIsZeroVector(vector, v_size);
}
// Apply Layer Normalization (https://arxiv.org/abs/1607.06450) to a Quantized
// vector.
// Parameters:
// - input: batch vector of size n_batch * n_input; 16 bit.
// - layer_norm_weights: the quantized layer normalization weights.
// - bias: the bias for the layer normalization.
// - layer_norm_scale_a: multiplier for scale factor.
// - layer_norm_scale_b: shift for scale factor.
// - variance_limit: the guard to make sure the inverse does not overflow.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output
inline void ApplyLayerNorm(const int16_t* input,
const int16_t* layer_norm_weights,
const int32_t* bias, int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b, int32_t variance_limit,
int n_batch, int n_input, int16_t* output) {
PortableApplyLayerNorm(input, layer_norm_weights, bias, layer_norm_scale_a,
layer_norm_scale_b, variance_limit, n_batch, n_input,
output);
}
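// Note on the scale encoding (for orientation only): the
// (layer_norm_scale_a, layer_norm_scale_b) pair represents an effective scale
// of layer_norm_scale_a * 2^(layer_norm_scale_b - 31), mirroring how
// PortableApplyLayerNormFloat reconstructs it.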
// Same as above but the internal calculation is done in float.
inline void ApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b, const int32_t* bias,
int n_batch, int n_input, int16_t* output) {
PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
layer_norm_scale_b, bias, n_batch, n_input,
output);
}
// Apply Sigmoid to a quantized vector.
// Parameters:
// - input: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output
// The input is in Q3.12 format and the output is in Q0.15 format.
inline void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
PortableApplySigmoid(input, n_batch, n_input, output);
}
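// Worked example of the formats (illustrative): a Q3.12 input of 4096
// represents 1.0; sigmoid(1.0) ~ 0.731, which in Q0.15 is roughly
// 0.731 * 32768 ~ 23950.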
// Same as above but the internal calculation is done in float.
inline void ApplySigmoidFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
PortableApplySigmoidFloat(input, n_batch, n_input, output);
}
// Apply Tanh to a quantized vector.
// Parameters:
// - integer_bits: the integer bits of the input.
// Currently supports 0, 1, 2, 3, 4, 5, 6.
// - input: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output
// The input is in Qm.15-m format and the output is in Q0.15 format.
inline void ApplyTanh(int32_t integer_bits, const int16_t* input,
int32_t n_batch, int32_t n_input, int16_t* output) {
PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
}
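// Worked example of the formats (illustrative): with integer_bits == 3 the
// input is Q3.12, so a raw value of 4096 represents 1.0; tanh(1.0) ~ 0.7616,
// which in Q0.15 is roughly 0.7616 * 32768 ~ 24957.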
// Apply Tanh to a quantized vector. The internal calculation is in float.
// - Input has 2^(integer_bits) as scale.
// - Output has Q0.15 as scale.
inline void ApplyTanhFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int32_t integer_bits,
int16_t* output) {
PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}
// Element-wise multiplication of two quantized vectors.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - shift: the shift needed to produce the output.
// - output: the 16 bit output of size n_batch * n_input.
// Output does not need to be initialized.
inline void CwiseMul(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int shift, int16_t* output) {
PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
}
// Element-wise multiplication of two quantized vectors with rescaling.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - multiplier: the multiplier part of scale.
// - shift: the shift part of scale.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 8 bit output of size n_batch * n_input.
// - output_zp: the zero point of output.
// Output does not need to be initialized.
// Multiplier ("m") and shift ("s") are connected to scale ("s") with s = m *
// 2^(s - 31).
inline void CwiseMul(const int16_t* input_1, const int16_t* input_2,
int32_t multiplier, int32_t shift, int32_t n_batch,
int32_t n_input, int32_t output_zp, int8_t* output) {
PortableCwiseMul(input_1, input_2, multiplier, shift, n_batch, n_input,
output_zp, output);
}
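// Worked example of the scale encoding (illustrative): with m = 2^30 and
// s = 1 the effective scale is 2^30 * 2^(1 - 31) = 1, so the raw int16
// products pass through unscaled before the zero-point adjustment and
// clamping to the int8 range.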
// Element-wise in-place clipping of a vector. Overloaded for float, int16_t,
// int8_t. Parameters:
// - vector: vector of size v_size.
// - v_size: the size of the vector.
// - clipping_value: the value used for clipping.
inline void CwiseClipping(float* vector, const int v_size,
const float clipping_value) {
PortableCwiseClipping(vector, v_size, clipping_value);
}
inline void CwiseClipping(int16_t* vector, const int v_size,
const int16_t clipping_value) {
PortableCwiseClipping(vector, v_size, clipping_value);
}
inline void CwiseClipping(int8_t* vector, const int v_size,
const int8_t clipping_value) {
PortableCwiseClipping(vector, v_size, clipping_value);
}
// Element-wise saturating addition of two quantized vectors without rescaling.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output of size n_batch * n_input.
// Output does not need to be initialized.
inline void CwiseAdd(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int16_t* output) {
PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
}
inline void MeanStddevNormalization(const float* input_vector,
float* output_vector, int v_size,
int n_batch) {
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
}
inline void Sub1Vector(const float* vector, int v_size, float* result) {
PortableSub1Vector(vector, v_size, result);
}
inline void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
PortableSub1Vector(vector, v_size, result);
}
// Multiply all elements of vector with a scalar.
inline void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
float* result) {
PortableVectorScalarMultiply(vector, v_size, scale, result);
}
// Saturate Add with rescale on both inputs.
inline void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b,
int32_t n_batch, int32_t n_cell,
int16_t* output) {
PortableTwoGateSaturatingAdd(
input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
input_effective_scale_b, recurrent_effective_scale_a,
recurrent_effective_scale_b, n_batch, n_cell, output);
}
// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
// dimension composed of input vectors independent from each other). The result
// of the multiplication is accumulated to the passed result buffer.
// More specifically, for a matrix M of shape [n, i] and a batched-vector
// of shape [i, batch] it will first compute the product of shape [n, batch].
// This product will be accumulated to the result buffer.
inline void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
int m_cols, const float* vector,
int n_batch, float* result) {
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
n_batch, result);
}
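// Illustrative usage sketch of the float overload above for a 2x3 matrix and
// a single batch; the example function and buffer names are hypothetical.
inline void ExampleMatrixBatchVectorMultiplyAccumulate() {
  const float matrix[6] = {1.f, 2.f, 3.f,   // row 0
                           4.f, 5.f, 6.f};  // row 1
  const float vector[3] = {1.f, 0.f, -1.f};
  float result[2] = {0.f, 0.f};  // Accumulated into, so it must be initialized.
  // After the call: result[0] == -2.f and result[1] == -2.f.
  MatrixBatchVectorMultiplyAccumulate(matrix, /*m_rows=*/2, /*m_cols=*/3,
                                      vector, /*n_batch=*/1, result);
}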
// Same as the function above, but the matrix and vectors are 8-bit quantized
// and a per-batch float scaling factor is applied; the result is accumulated
// in float.
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vector, const float* scaling_factors,
int n_batch, float* __restrict__ result) {
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
scaling_factors, n_batch, result);
}
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
context);
}
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vector, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
scaling_factors, n_batch, result);
}
// Same as the function above, but the matrix is a sparse tensor with block
// pattern 1x4.
// This function assumes that m_cols is a multiple of the block size (4 in this
// case) so that there's no incomplete block.
inline void SparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
}
// Same as the function above, but the matrix is stored in block compressed
// sparse row format with block pattern 1x16 which consists of two arrays:
// 1. A matrix array stores non-zero blocks of the matrix in row major.
// 2. A ledger array stores nrows groups, one group per row. Each group starts
// with an integer representing the number of non-zero blocks for the
// corresponding row and follows with column indexes of the first element
// of each non-zero block.
// This function assumes that
// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
inline void SparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
float* __restrict__ result) {
PortableSparseMatrixBatchVectorMultiplyAccumulate(
matrix, ledger, m_rows, m_cols, vector, n_batch, result);
}
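// Hypothetical layout example for the ledger format described above, reading
// the per-row indices as block column indices (consistent with the uint8
// block-index constraint): for a 2-row matrix with m_cols == 32 (two 1x16
// blocks per row), where row 0 has one non-zero block at block column 1 and
// row 1 has non-zero blocks at block columns 0 and 1, the ledger would be
// {1, 1, 2, 0, 1} and the matrix array would store the three non-zero blocks
// (48 values) in row-major order.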
// Same as the function above, but the matrix is a sparse tensor with block
// pattern 1x16.
// This function assumes that m_cols is a multiple of the block size (16 in this
// case) so that there's no incomplete block. Also, it assumes all offsets of
// input, output and filter are zero.
inline void SparseMatrixBatchVectorMultiplyAccumulate1x16(
const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
int n_batch, const int32_t input_offset, const int32_t output_multiplier,
const int32_t output_shift, const int32_t output_offset,
const int32_t output_activation_min, const int32_t output_activation_max,
int8_t* __restrict__ result) {
PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch,
input_offset, output_multiplier, output_shift, output_offset,
output_activation_min, output_activation_max, result);
}
// Same as the function above, but the matrix is stored in block compressed
// sparse row format with block pattern 1x16 which consists of two arrays:
// 1. A matrix array stores non-zero blocks of the matrix in row major.
// 2. A ledger array stores nrows groups, one group per row. Each group starts
// with an integer representing the number of non-zero blocks for the
// corresponding row followed by column index of the first element of
// each non-zero block.
// This function assumes that
// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
inline void SparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
const float* scaling_factors, int n_batch, float* __restrict__ result) {
PortableSparseMatrixBatchVectorMultiplyAccumulate(
matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch,
result);
}
// Same as the 8, 8, 8 integer matmul above, except that it handles a zero
// point and is non-accumulative.
// TODO(b/148688698): remove this function by folding zero point calculation in
// prepare() function.
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int16_t* output, CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulate(
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
n_output, output_zp, scratch, output, context);
}
// Same as above but has 16 bit and 8 bit input and 8 bit output.
// Used in projection when hidden is 16bit.
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* input, const int32_t* bias,
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int8_t* output, CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulate(
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
n_output, output_zp, scratch, output, context);
}
// Same as the function above, but provides a separate scaling factor for the
// matrix and the vectors. The combined per-batch scaling factors are written
// into the scaling_factor_scratch buffer.
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float matrix_scaling_factor,
const float* vector_scaling_factors, int n_batch,
float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, float* scaling_factor_scratch,
CpuBackendContext* context) {
for (int b = 0; b < n_batch; ++b) {
scaling_factor_scratch[b] =
vector_scaling_factors[b] * matrix_scaling_factor;
}
MatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factor_scratch, n_batch, result,
per_channel_scale, input_offset, scratch,
row_sums, compute_row_sums, context);
}
// Multiplies a matrix with a scalar and reduce the result on each row to a
// scalar.
// Parameters:
// - matrix: matrix of size n_row * n_col
// - scalar: the scalar that is multiplied to each element in the matrix
// - n_row: the row count of the matrix
// - n_col: the column count of the matrix
// - output: the 32bit output
// Note: Saturation is not needed because each int8 * int8 product is at most
// 2^14, and (2^31 - 1) / 2^14 (about 131072) is bigger than n_row; a non-zero
// initial output value is not exceptionally large.
inline void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
int32_t n_row, int32_t n_col,
int32_t* output) {
PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output);
}
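// Worked example (illustrative): for a 2x2 matrix {1, 2, 3, 4} and scalar 10,
// the call adds (1 + 2) * 10 = 30 to output[0] and (3 + 4) * 10 = 70 to
// output[1].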
// Same as the 8, 8, 8 integer matmul above, except that it handles a zero
// point and is non-accumulative.
// TODO(b/148688698): remove this function by folding zero point calculation in
// prepare() function.
inline void MatrixBatchVectorMultiply(const int8_t* input,
int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input,
int32_t n_cell, int8_t* gate_output,
int8_t gate_output_zp) {
PortableMatrixBatchVectorMultiply(
input, input_zeropoint, input_to_gate_weights,
input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
n_input, n_cell, gate_output, gate_output_zp);
}
// Same as above but has 16 bit and 8 bit input and 8 bit output.
// Used in projection when hidden is 16bit.
inline void MatrixBatchVectorMultiply(const int16_t* hidden,
const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a,
int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch,
int32_t n_hidden, int32_t n_output,
int32_t output_zp, int8_t* proj_output) {
PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
proj_effective_scale_a,
proj_effective_scale_b, gate_bias, n_batch,
n_hidden, n_output, output_zp, proj_output);
}
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
template <typename T>
inline void VectorBatchVectorCwiseProductAccumulate(const T* vector, int v_size,
const T* batch_vector,
int n_batch, T* result) {
for (int b = 0; b < n_batch; b++) {
VectorVectorCwiseProductAccumulate(vector, batch_vector, v_size, result);
// Update the pointers.
result += v_size;
batch_vector += v_size;
}
}
// Same as above, but inputs are 16bit integer and output is 16bit integer.
inline void VectorBatchVectorCwiseProductAccumulate(
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
int32_t multiplier, int shift, int16_t* result) {
PortableVectorBatchVectorCwiseProductAccumulate(
vector, v_size, batch_vector, n_batch, multiplier, shift, result);
}
// Apply Rectified Linear to elements of a vector.
inline void ApplyReluToVector(const float* vector, int v_size, float* result) {
for (int v = 0; v < v_size; v++) {
result[v] = std::max(0.0f, vector[v]);
}
}
// Apply Rectified Linear 1 (cap to [-1;1]) to elements of a vector
inline void ApplyRelu1ToVector(const float* vector, int v_size, float* result) {
for (int v = 0; v < v_size; v++) {
result[v] = std::max(-1.0f, std::min(vector[v], 1.0f));
}
}
// Apply Rectified Linear 6 (cap to [0;6]) to elements of a vector
inline void ApplyRelu6ToVector(const float* vector, int v_size, float* result) {
for (int v = 0; v < v_size; v++) {
result[v] = std::max(0.0f, std::min(vector[v], 6.0f));
}
}
// Apply tanh to elements of a vector
inline void ApplyTanhToVector(const float* vector, int v_size, float* result) {
for (int v = 0; v < v_size; v++) {
result[v] = std::tanh(vector[v]);
}
}
// Apply signbit to elements of a vector
inline void ApplySignbitToVector(const float* vector, int v_size,
float* result) {
for (int v = 0; v < v_size; v++) {
result[v] = std::signbit(vector[v]);
}
}
// Apply sigmoid to elements of a vector.
inline void ApplySigmoidToVector(const float* vector, int v_size,
float* result) {
for (int v = 0; v < v_size; v++) {
result[v] = 1.0f / (1.0f + std::exp(-vector[v]));
}
}
// Apply appropriate activation function to elements of a vector.
inline void ApplyActivationToVector(const float* vector, int v_size,
TfLiteFusedActivation activation,
float* result) {
switch (activation) {
case kTfLiteActNone:
return;
case kTfLiteActRelu:
return ApplyReluToVector(vector, v_size, result);
case kTfLiteActReluN1To1:
return ApplyRelu1ToVector(vector, v_size, result);
case kTfLiteActRelu6:
return ApplyRelu6ToVector(vector, v_size, result);
case kTfLiteActTanh:
return ApplyTanhToVector(vector, v_size, result);
case kTfLiteActSignBit:
return ApplySignbitToVector(vector, v_size, result);
case kTfLiteActSigmoid:
return ApplySigmoidToVector(vector, v_size, result);
}
}
} // namespace micro_tensor_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_

View File

@@ -209,14 +209,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_MIRROR_PAD() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -61,14 +61,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_MUL() {
return {/*init=*/MulInit,
/*free=*/nullptr,
/*prepare=*/MulPrepare,
/*invoke=*/MulEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
}
} // namespace tflite

View File

@@ -51,14 +51,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace neg
TfLiteRegistration Register_NEG() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/neg::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, neg::Eval);
}
} // namespace micro

View File

@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace pack
TfLiteRegistration Register_PACK() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/pack::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, pack::Eval);
}
} // namespace micro

View File

@@ -223,26 +223,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace pad
TfLiteRegistration Register_PAD() {
return {/*init=*/pad::Init,
/*free=*/nullptr,
/*prepare=*/pad::Prepare,
/*invoke=*/pad::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval);
}
// Also register Pad as PadV2.
TfLiteRegistration Register_PADV2() {
return {/*init=*/pad::Init,
/*free=*/nullptr,
/*prepare=*/pad::Prepare,
/*invoke=*/pad::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval);
}
} // namespace micro

View File

@@ -88,25 +88,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} // namespace
TfLiteRegistration Register_AVERAGE_POOL_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PoolingPrepare,
/*invoke=*/AverageEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
}
TfLiteRegistration Register_MAX_POOL_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PoolingPrepare,
/*invoke=*/MaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
}
} // namespace tflite

View File

@@ -69,14 +69,7 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_PRELU() {
return {/*init=*/PreluInit,
/*free=*/nullptr,
/*prepare=*/PreluPrepare,
/*invoke=*/PreluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(PreluInit, PreluPrepare, PreluEval);
}
} // namespace tflite

View File

@@ -34,14 +34,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} // namespace
TfLiteRegistration Register_QUANTIZE() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PrepareQuantizeReference,
/*invoke=*/EvalQuantizeReference,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PrepareQuantizeReference,
EvalQuantizeReference);
}
} // namespace tflite

View File

@@ -53,15 +53,19 @@ TfLiteStatus PrepareQuantizeReference(TfLiteContext* context,
TF_LITE_ENSURE(context, affine_quantization->scale);
TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt32 ||
input->type == kTfLiteInt16 || input->type == kTfLiteInt8);
TF_LITE_ENSURE(
context, input->type == kTfLiteFloat32 || input->type == kTfLiteInt32 ||
input->type == kTfLiteInt16 || input->type == kTfLiteInt8 ||
input->type == kTfLiteUInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteInt8 ||
output->type == kTfLiteInt16 ||
output->type == kTfLiteInt32);
output->type == kTfLiteInt32 ||
output->type == kTfLiteUInt8);
if ((input->type == kTfLiteInt16 && output->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 && output->type == kTfLiteUInt8) ||
(input->type == kTfLiteUInt8 && output->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt16) ||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt32) ||
(input->type == kTfLiteInt16 && output->type == kTfLiteInt16) ||
@@ -109,9 +113,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int16_t>(output));
return kTfLiteOk;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else if (input->type == kTfLiteInt32) {
@@ -132,9 +136,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int16_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else if (input->type == kTfLiteInt16) {
@@ -162,9 +166,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int32_t>(output));
return kTfLiteOk;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else if (input->type == kTfLiteInt8) {
@@ -179,6 +183,13 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
data->input_zero_point, data->quantization_params.zero_point,
tflite::micro::GetTensorData<int8_t>(output));
break;
case kTfLiteUInt8:
reference_ops::Requantize(
tflite::micro::GetTensorData<int8_t>(input), size,
data->requantize_output_multiplier, data->requantize_output_shift,
data->input_zero_point, data->quantization_params.zero_point,
tflite::micro::GetTensorData<uint8_t>(output));
break;
case kTfLiteInt16:
reference_ops::Requantize(
tflite::micro::GetTensorData<int8_t>(input), size,
@@ -194,15 +205,31 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int32_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else if (input->type == kTfLiteUInt8) {
size_t size = ElementCount(*input->dims);
switch (output->type) {
case kTfLiteInt8:
reference_ops::Requantize(
tflite::micro::GetTensorData<uint8_t>(input), size,
data->requantize_output_multiplier, data->requantize_output_shift,
data->input_zero_point, data->quantization_params.zero_point,
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@@ -81,14 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_READ_VARIABLE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,331 +23,50 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/reduce.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace ops {
namespace micro {
namespace reduce {
constexpr int kMaxNumberOfAxis = 4;
constexpr int kMaxNumberOfReducedAxis = 2;
struct OpData {
int32_t multiplier;
int shift;
int temp_buffer_idx;
int resolved_axis_idx;
int input_zp;
float input_scale;
int output_zp;
float output_scale;
int num_output_elements;
};
void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
// Inputs Tensor (dtype depends on quantization):
// [0] = Input
// [1] = Axis
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
// Outputs Tensor (dtype depends on quantization):
// [0] = Output
// Validate number of inputs and outputs
TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Validate axis type
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
TF_LITE_ENSURE(context, axis != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
if (input->type == kTfLiteInt8) {
OpData* data = static_cast<OpData*>(node->user_data);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
const double real_multiplier = static_cast<double>(input->params.scale) /
static_cast<double>(output->params.scale);
QuantizeMultiplier(real_multiplier, &data->multiplier, &data->shift);
micro_context->DeallocateTempTfLiteTensor(output);
}
micro_context->DeallocateTempTfLiteTensor(axis);
micro_context->DeallocateTempTfLiteTensor(input);
return kTfLiteOk;
return context->AllocatePersistentBuffer(context, sizeof(OpDataReduce));
}
TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
MicroContext* micro_context = GetMicroContext(context);
OpData* op_data = static_cast<OpData*>(node->user_data);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
op_data->input_scale = input->params.scale;
op_data->output_scale = output->params.scale;
op_data->num_output_elements = NumElements(output);
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
&op_data->temp_buffer_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
&op_data->resolved_axis_idx);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(axis);
return kTfLiteOk;
return PrepareMaxHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
}
TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
const double real_multiplier = static_cast<double>(input->params.scale) /
static_cast<double>(output->params.scale);
QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
}
int output_size = NumElements(output);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
&op_data->temp_buffer_idx);
op_data->input_zp = input->params.zero_point;
op_data->input_scale = input->params.scale;
op_data->output_zp = output->params.zero_point;
op_data->output_scale = output->params.scale;
}
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
void ResolveAxis(const int* axis_data, int axis_count,
tflite::MeanParams* op_params) {
int i = 0;
for (; i < axis_count; ++i) {
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
}
for (; i < 4; ++i) {
op_params->axis[i] = 1;
}
op_params->axis_count = axis_count;
return PrepareMeanOrSumHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
}
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TfLiteReducerParams* params =
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int temp_index[kMaxNumberOfAxis];
int resolved_axis[kMaxNumberOfReducedAxis];
tflite::MeanParams op_params;
ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
// Special case mean implementation exists for 4D mean across axes 1 and 2.
bool special_case_4d_axes_1_and_2 =
input->dims->size == 4 && op_params.axis_count == 2 &&
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
switch (input->type) {
case kTfLiteFloat32: {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
TF_LITE_ENSURE(
context,
reference_ops::Mean(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis,
tflite::micro::GetTensorData<float>(output)));
}
} break;
case kTfLiteInt8: {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_integer_ops::Mean(
op_params, op_data->multiplier, op_data->shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
} else if (op_data->input_zp == op_data->output_zp &&
op_data->input_scale == op_data->output_scale) {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::Mean(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, temp_buffer));
} else {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
op_data->input_scale, input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int8_t>(output),
op_data->output_zp, op_data->output_scale, output->dims->data,
output->dims->size, tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index, resolved_axis,
temp_buffer, false));
}
} break;
case kTfLiteInt16: {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_integer_ops::Mean(
op_params, op_data->multiplier, op_data->shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
} else if (op_data->input_zp == op_data->output_zp &&
op_data->input_scale == op_data->output_scale) {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::Mean(tflite::micro::GetTensorData<int16_t>(input),
input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int16_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index,
resolved_axis, temp_buffer));
} else {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
op_data->input_scale, input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int16_t>(output),
op_data->output_zp, op_data->output_scale, output->dims->data,
output->dims->size, tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index, resolved_axis,
temp_buffer, false));
}
} break;
default:
TF_LITE_ENSURE_MSG(context, false,
"Currently, only float32, int8 or uint8 input type "
"is supported.");
}
return kTfLiteOk;
return EvalMeanHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
}
TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);
OpData* op_data = static_cast<OpData*>(node->user_data);
// Interpret an axis tensor with null dimensions as a scalar
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<float>::lowest(),
[](const float current, const float in) -> float {
return (in > current) ? in : current;
}));
break;
case kTfLiteInt8:
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
static_cast<double>(op_data->output_scale));
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<int8_t>(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<int8_t>::lowest(),
[](const int8_t current, const int8_t in) -> int8_t {
return (in > current) ? in : current;
}));
break;
default:
TF_LITE_KERNEL_LOG(context,
"Only float32 and int8 types are supported.\n");
return kTfLiteError;
}
return kTfLiteOk;
OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
return EvalMaxHelper(context, node, op_data);
}
} // namespace reduce
TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
return EvalSumHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
}
TfLiteRegistration Register_MEAN() {
return {/*init=*/reduce::InitReduce,
/*free=*/nullptr,
/*prepare=*/reduce::PrepareMeanOrSum,
/*invoke=*/reduce::EvalMean,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalMean);
}
TfLiteRegistration Register_REDUCE_MAX() {
return {/*init=*/reduce::InitReduce,
/*free=*/nullptr,
/*prepare=*/reduce::PrepareMax,
/*invoke=*/reduce::EvalMax,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(InitReduce, PrepareMax, EvalMax);
}
TfLiteRegistration Register_SUM() {
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalSum);
}
} // namespace micro
} // namespace ops
} // namespace tflite
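Every registration in this commit follows the same pattern: the eight-field brace-initialized TfLiteRegistration is collapsed into a call to tflite::micro::RegisterOp(init, prepare, invoke). As a minimal sketch of what such a helper presumably does (an illustration only, not the library's actual implementation), it fills in the three callbacks and leaves every other field at its zero/null default:

TfLiteRegistration RegisterOpSketch(
    void* (*init)(TfLiteContext*, const char*, size_t),
    TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*),
    TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*)) {
  TfLiteRegistration registration = {};  // free, profiling_string, builtin_code,
  registration.init = init;              // custom_name and version stay defaulted
  registration.prepare = prepare;
  registration.invoke = invoke;
  return registration;
}

The name RegisterOpSketch is deliberately different from the real helper so the two are not confused; any of the three callbacks may be nullptr, as several of the registrations later in this diff show.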

View File

@@ -0,0 +1,64 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
#include <cstdint>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
extern const int kMaxNumberOfAxis;
extern const int kMaxNumberOfReducedAxis;
struct OpDataReduce {
int32_t multiplier;
int shift;
int temp_buffer_idx;
int resolved_axis_idx;
int input_zp;
float input_scale;
int output_zp;
float output_scale;
int num_output_elements;
};
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
void ReduceResolveAxis(const int* axis_data, int axis_count,
MeanParams* op_params);
TfLiteRegistration Register_MEAN();
TfLiteRegistration Register_REDUCE_MAX();
TfLiteRegistration Register_SUM();
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
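The point of this new header is that the Prepare/Eval bodies which used to live only in the reference reduce.cc can now be shared with optimized ports. A minimal sketch of a port-side MEAN kernel built on these helpers, assuming only what the header above declares (the Port* names are hypothetical):

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/reduce.h"

namespace tflite {
namespace {

void* PortInitReduce(TfLiteContext* context, const char* buffer, size_t length) {
  // Same persistent allocation the reference kernel performs.
  return context->AllocatePersistentBuffer(context, sizeof(OpDataReduce));
}

TfLiteStatus PortPrepareMean(TfLiteContext* context, TfLiteNode* node) {
  return PrepareMeanOrSumHelper(context, node,
                                static_cast<OpDataReduce*>(node->user_data));
}

TfLiteStatus PortEvalMean(TfLiteContext* context, TfLiteNode* node) {
  // A port would try its accelerated path first and fall back to the shared
  // reference helper for any unsupported type or shape.
  return EvalMeanHelper(context, node,
                        static_cast<OpDataReduce*>(node->user_data));
}

}  // namespace

TfLiteRegistration Register_MEAN() {
  return tflite::micro::RegisterOp(PortInitReduce, PortPrepareMean,
                                   PortEvalMean);
}

}  // namespace tflite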

View File

@@ -0,0 +1,374 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/reduce.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
const int kMaxNumberOfAxis = 4;
const int kMaxNumberOfReducedAxis = 2;
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
int32_t* multiplier, int* shift) {
MicroContext* micro_context = GetMicroContext(context);
// Inputs Tensor (dtype depends on quantization):
// [0] = Input
// [1] = Axis
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
// Outputs Tensor (dtype depends on quantization):
// [0] = Output
// Validate number of inputs and outputs
TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Validate axis type
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
TF_LITE_ENSURE(context, axis != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
if (input->type == kTfLiteInt8) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
const double real_multiplier = static_cast<double>(input->params.scale) /
static_cast<double>(output->params.scale);
QuantizeMultiplier(real_multiplier, multiplier, shift);
micro_context->DeallocateTempTfLiteTensor(output);
}
micro_context->DeallocateTempTfLiteTensor(axis);
micro_context->DeallocateTempTfLiteTensor(input);
return kTfLiteOk;
}
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node, &op_data->multiplier,
&op_data->shift));
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
op_data->input_scale = input->params.scale;
op_data->output_scale = output->params.scale;
op_data->num_output_elements = NumElements(output);
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
&op_data->temp_buffer_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
&op_data->resolved_axis_idx);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(axis);
return kTfLiteOk;
}
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
const double real_multiplier = static_cast<double>(input->params.scale) /
static_cast<double>(output->params.scale);
QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
}
int output_size = NumElements(output);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
&op_data->temp_buffer_idx);
op_data->input_zp = input->params.zero_point;
op_data->input_scale = input->params.scale;
op_data->output_zp = output->params.zero_point;
op_data->output_scale = output->params.scale;
}
TF_LITE_ENSURE_OK(
context,
PrepareSimple(context, node, &(op_data->multiplier), &(op_data->shift)));
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
void ResolveAxis(const int* axis_data, int axis_count,
tflite::MeanParams* op_params) {
int i = 0;
for (; i < axis_count; ++i) {
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
}
for (; i < 4; ++i) {
op_params->axis[i] = 1;
}
op_params->axis_count = axis_count;
}
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TfLiteReducerParams* params =
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int temp_index[kMaxNumberOfAxis];
int resolved_axis[kMaxNumberOfReducedAxis];
tflite::MeanParams op_params;
ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
// Special case mean implementation exists for 4D mean across axes 1 and 2.
bool special_case_4d_axes_1_and_2 =
input->dims->size == 4 && op_params.axis_count == 2 &&
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
switch (input->type) {
case kTfLiteFloat32: {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
TF_LITE_ENSURE(
context,
reference_ops::Mean(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis,
tflite::micro::GetTensorData<float>(output)));
}
} break;
case kTfLiteInt8: {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_integer_ops::Mean(
op_params, op_data->multiplier, op_data->shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
} else if (op_data->input_zp == op_data->output_zp &&
op_data->input_scale == op_data->output_scale) {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::Mean(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, temp_buffer));
} else {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
op_data->input_scale, input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int8_t>(output),
op_data->output_zp, op_data->output_scale, output->dims->data,
output->dims->size, tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index, resolved_axis,
temp_buffer, false));
}
} break;
case kTfLiteInt16: {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_integer_ops::Mean(
op_params, op_data->multiplier, op_data->shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
} else if (op_data->input_zp == op_data->output_zp &&
op_data->input_scale == op_data->output_scale) {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::Mean(tflite::micro::GetTensorData<int16_t>(input),
input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int16_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index,
resolved_axis, temp_buffer));
} else {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
op_data->input_scale, input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int16_t>(output),
op_data->output_zp, op_data->output_scale, output->dims->data,
output->dims->size, tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index, resolved_axis,
temp_buffer, false));
}
} break;
default:
TF_LITE_ENSURE_MSG(context, false,
"Currently, only float32, int8 or int16 input type "
"is supported.");
}
return kTfLiteOk;
}
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);
// Interpret an axis tensor with null dimensions as a scalar
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<float>::lowest(),
[](const float current, const float in) -> float {
return (in > current) ? in : current;
}));
break;
case kTfLiteInt8:
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
static_cast<double>(op_data->output_scale));
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<int8_t>(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<int8_t>::lowest(),
[](const int8_t current, const int8_t in) -> int8_t {
return (in > current) ? in : current;
}));
break;
default:
MicroPrintf("Only float32 and int8 types are supported.");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);
// Interpret an axis tensor with null dimensions as a scalar.
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int temp_index[kMaxNumberOfAxis];
int resolved_axis[kMaxNumberOfReducedAxis];
switch (input->type) {
case kTfLiteFloat32: {
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, /*init_value=*/0.f,
[](const float current, const float in) -> float {
return in + current;
}));
} break;
case kTfLiteInt8: {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
op_data->input_scale, input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp,
op_data->output_scale, output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, temp_buffer,
/*compute_sum=*/true));
} break;
case kTfLiteInt16: {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
op_data->input_scale, input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp,
op_data->output_scale, output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, temp_buffer,
/*compute_sum=*/true));
} break;
default:
MicroPrintf("Only float32, int8, and int16 types are supported.");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace tflite

View File

@@ -110,14 +110,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace reshape
TfLiteRegistration Register_RESHAPE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/reshape::Prepare,
/*invoke=*/reshape::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, reshape::Prepare, reshape::Eval);
}
} // namespace micro

View File

@@ -111,14 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_RESIZE_BILINEAR() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -117,14 +117,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace resize_nearest_neighbor
TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/resize_nearest_neighbor::Prepare,
/*invoke=*/resize_nearest_neighbor::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, resize_nearest_neighbor::Prepare,
resize_nearest_neighbor::Eval);
}
} // namespace micro

View File

@@ -68,14 +68,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace round
TfLiteRegistration Register_ROUND() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/round::Prepare,
/*invoke=*/round::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, round::Prepare, round::Eval);
}
} // namespace micro

View File

@@ -60,14 +60,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_SHAPE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -151,14 +151,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_SLICE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -83,14 +83,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_SOFTMAX() {
return {/*init=*/SoftmaxInit,
/*free=*/nullptr,
/*prepare=*/SoftmaxPrepare,
/*invoke=*/SoftmaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(SoftmaxInit, SoftmaxPrepare, SoftmaxEval);
}
} // namespace tflite

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,6 +23,13 @@ namespace tflite {
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length);
// Common helper function to SoftmaxPrepare.
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
const TfLiteSoftmaxParams* params,
SoftmaxParams* op_data);
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node);
// This is the most generic TfLiteRegistration. The actual supported types may
@@ -30,7 +37,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node);
// (reference or optimized) must define this function.
TfLiteRegistration Register_SOFTMAX();
#if defined(XTENSA)
#if defined(XTENSA) || defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int8 input and int16 output.
TfLiteRegistration Register_SOFTMAX_INT8_INT16();
@@ -40,6 +47,23 @@ inline TfLiteRegistration Register_SOFTMAX_INT8_INT16() {
}
#endif
#if defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int8 input/output and uses the latency optimized implementations.
TfLiteRegistration Register_SOFTMAX_INT8();
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int16 input/output and uses the latency optimized implementations.
TfLiteRegistration Register_SOFTMAX_INT16();
#else
inline TfLiteRegistration Register_SOFTMAX_INT8() { return Register_SOFTMAX(); }
inline TfLiteRegistration Register_SOFTMAX_INT16() {
return Register_SOFTMAX();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SOFTMAX_H_
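With the CMSIS_NN branch added here, application code can ask for the latency-optimized int8 or int16 softmax and transparently get the generic kernel on targets where it does not exist. A short usage sketch, assuming MicroMutableOpResolver::AddSoftmax accepts an explicit registration (treat that exact signature as an assumption of this sketch):

#include "tensorflow/lite/micro/kernels/softmax.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

TfLiteStatus RegisterInt8Softmax(tflite::MicroMutableOpResolver<1>& resolver) {
  // With CMSIS_NN defined this picks the specialized int8 kernel; otherwise
  // the inline fallback above silently returns the generic Register_SOFTMAX().
  return resolver.AddSoftmax(tflite::Register_SOFTMAX_INT8());
}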

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -28,11 +28,59 @@ namespace {
// Softmax parameter data that persists in user_data
const int kInt16LUTArraySize = 513;
TfLiteStatus InitializeLutForInt16(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
SoftmaxParams* op_data) {
// Only allocate LUTs for kTfLiteInt16 data type
if (input->type == kTfLiteInt16) {
void* raw_exp_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
op_data->one_over_one_plus_x_lut =
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
}
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
}
// Populate LUT if required
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// exp LUT is only used on negative values;
// we treat exp(-10.0) as insignificant to the accumulation
gen_lut<float, int16_t, int16_t>(
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
op_data->exp_lut);
gen_lut<float, int16_t, int16_t>(
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
1.0f, op_data->one_over_one_plus_x_lut);
op_data->zero_point = output->params.zero_point;
op_data->scale = output->params.scale;
}
return kTfLiteOk;
}
} // namespace
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
const TfLiteSoftmaxParams* params,
SoftmaxParams* op_data) {
if (InitializeLutForInt16(context, input, output, op_data) != kTfLiteOk) {
return kTfLiteError;
}
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -83,8 +131,6 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
return kTfLiteOk;
}
} // namespace
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
@@ -103,40 +149,6 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->user_data != nullptr);
SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
// Only allocate LUTs for KTfLiteInt16 data type
if (input->type == kTfLiteInt16) {
void* raw_exp_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
op_data->one_over_one_plus_x_lut =
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
}
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
}
// Populate LUT if required
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// exp LUT only used on negative values
// we consider exp(-10.0) is insignificant to accumulation
gen_lut<float, int16_t, int16_t>(
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
op_data->exp_lut);
gen_lut<float, int16_t, int16_t>(
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
1.0f, op_data->one_over_one_plus_x_lut);
op_data->zero_point = output->params.zero_point;
op_data->scale = output->params.scale;
}
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
auto ret_val =

View File

@@ -114,14 +114,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_SPACE_TO_BATCH_ND() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_SPACE_TO_DEPTH() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -120,14 +120,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace split
TfLiteRegistration Register_SPLIT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/split::Prepare,
/*invoke=*/split::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, split::Prepare, split::Eval);
}
} // namespace micro

View File

@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace split_v
TfLiteRegistration Register_SPLIT_V() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/split_v::Prepare,
/*invoke=*/split_v::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, split_v::Prepare, split_v::Eval);
}
} // namespace micro

View File

@@ -0,0 +1,247 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
ArithmeticParams arithmetic_params;
};
template <typename T>
T SquaredDifference(T input1, T input2) {
const T difference = input1 - input2;
return difference * difference;
}
void* SquaredDifferenceInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
TfLiteStatus SquaredDifferencePrepare(TfLiteContext* context,
TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
data->requires_broadcast = false;
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input1 =
micro_context->AllocateTempInputTensor(node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* input2 =
micro_context->AllocateTempInputTensor(node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
// Ensure the quantization parameters are equivalent.
if (input1->type == kTfLiteInt8) {
const auto& input1_quantization_params = input1->params;
const auto& input2_quantization_params = input2->params;
const auto& output_quantization_params = output->params;
const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
TF_LITE_ENSURE(context,
input1_quantization_params.zero_point >= integer_type_min);
TF_LITE_ENSURE(context,
input1_quantization_params.zero_point <= integer_type_max);
TF_LITE_ENSURE(context,
input2_quantization_params.zero_point >= integer_type_min);
TF_LITE_ENSURE(context,
input2_quantization_params.zero_point <= integer_type_max);
TF_LITE_ENSURE(context,
output_quantization_params.zero_point >= integer_type_min);
TF_LITE_ENSURE(context,
output_quantization_params.zero_point <= integer_type_max);
data->arithmetic_params.input1_offset =
-input1_quantization_params.zero_point;
data->arithmetic_params.input2_offset =
-input2_quantization_params.zero_point;
data->arithmetic_params.output_offset =
output_quantization_params.zero_point;
// shift to make integer for scales.
// 7 is selected so that maximum shifted result 255^2 * (1 << (7 * 2 ))
// does not overflow signed 32-bit integer
data->arithmetic_params.left_shift = 7;
const double twice_max_input_scale =
2.0 * static_cast<double>(std::max(input1_quantization_params.scale,
input2_quantization_params.scale));
const double real_input1_multiplier =
static_cast<double>(input1_quantization_params.scale) /
twice_max_input_scale;
double real_input2_multiplier =
static_cast<double>(input2_quantization_params.scale) /
twice_max_input_scale;
const double real_output_multiplier =
(twice_max_input_scale * twice_max_input_scale) /
static_cast<double>((1 << data->arithmetic_params.left_shift * 2) *
output_quantization_params.scale);
QuantizeMultiplierSmallerThanOneExp(
real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
&data->arithmetic_params.input1_shift);
QuantizeMultiplierSmallerThanOneExp(
real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
&data->arithmetic_params.input2_shift);
QuantizeMultiplierSmallerThanOneExp(
real_output_multiplier, &data->arithmetic_params.output_multiplier,
&data->arithmetic_params.output_shift);
data->arithmetic_params.quantized_activation_min =
std::numeric_limits<int8_t>::min();
data->arithmetic_params.quantized_activation_max =
std::numeric_limits<int8_t>::max();
}
data->requires_broadcast = !HaveSameShapes(input1, input2);
micro_context->DeallocateTempTfLiteTensor(input1);
micro_context->DeallocateTempTfLiteTensor(input2);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
inline int8_t SquaredDifference(int8_t x, int8_t y,
const ArithmeticParams& params) {
const int32_t input1_val = params.input1_offset + x;
const int32_t input2_val = params.input2_offset + y;
const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
const int32_t scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, params.input1_multiplier, params.input1_shift);
const int32_t scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32_t raw_diff = scaled_input1_val - scaled_input2_val;
// Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
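// Worked check of that bound: each scaled term above is at most
// 255 * (1 << left_shift) * (input_scale / twice_max_input_scale) <= 255 * 2^6,
// so |raw_diff| <= 255 * 2^7 and
// squared_raw_diff <= 255^2 * 2^14 = 65025 * 16384 = 1,065,369,600,
// which stays safely below INT32_MAX = 2,147,483,647.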
const int32_t squared_raw_diff = raw_diff * raw_diff;
const int32_t raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
squared_raw_diff, params.output_multiplier, params.output_shift) +
params.output_offset;
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, raw_output));
return static_cast<int8_t>(clamped_output);
}
template <typename T>
void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
const OpData* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
const auto* op_data = static_cast<const OpData*>(node->user_data);
if (data->requires_broadcast) {
reference_integer_ops::BroadcastBinaryFunction4DSlow(
op_data->arithmetic_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<T>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<T>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<T>(output),
reference_integer_ops::CheckArithmeticParams, SquaredDifference);
} else {
const int flat_size = tflite::micro::GetTensorShape(input1).FlatSize();
reference_integer_ops::ElementWise(
flat_size, op_data->arithmetic_params,
tflite::micro::GetTensorData<int8_t>(input1),
tflite::micro::GetTensorData<int8_t>(input2),
tflite::micro::GetTensorData<int8_t>(output),
reference_integer_ops::CheckArithmeticParams, SquaredDifference);
}
}
template <typename T>
void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
const OpData* data, const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<T>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<T>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<T>(output), SquaredDifference<T>);
} else {
reference_ops::BinaryFunction<T, T, T>(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<T>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<T>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<T>(output), SquaredDifference<T>);
}
}
TfLiteStatus SquaredDifferenceEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
EvalSquaredDifference<float>(context, node, data, input1, input2, output);
} else if (output->type == kTfLiteInt32) {
EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
} else if (output->type == kTfLiteInt8) {
EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
output);
} else {
MicroPrintf(
"SquaredDifference only supports FLOAT32, INT32 and INT8 now, got %d.",
output->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_SQUARED_DIFFERENCE() {
return tflite::micro::RegisterOp(
SquaredDifferenceInit, SquaredDifferencePrepare, SquaredDifferenceEval);
}
} // namespace tflite

View File

@@ -111,14 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_SQUEEZE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -193,14 +193,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace strided_slice
TfLiteRegistration Register_STRIDED_SLICE() {
return {/*init=*/strided_slice::Init,
/*free=*/nullptr,
/*prepare=*/strided_slice::Prepare,
/*invoke=*/strided_slice::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(strided_slice::Init, strided_slice::Prepare,
strided_slice::Eval);
}
} // namespace micro

View File

@@ -162,14 +162,7 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_SUB() {
return {/*init=*/SubInit,
/*free=*/nullptr,
/*prepare=*/SubPrepare,
/*invoke=*/SubEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(SubInit, SubPrepare, SubEval);
}
} // namespace tflite

View File

@@ -100,14 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_SVDF() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PrepareSvdf,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PrepareSvdf, Eval);
}
} // namespace tflite

View File

@@ -195,14 +195,8 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace activations
TfLiteRegistration Register_TANH() {
return {/*init=*/activations::TanhInit,
/*free=*/nullptr,
/*prepare=*/activations::TanhPrepare,
/*invoke=*/activations::TanhEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
activations::TanhInit, activations::TanhPrepare, activations::TanhEval);
}
} // namespace micro
} // namespace ops

View File

@@ -116,13 +116,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_TRANSPOSE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -266,7 +266,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
@@ -282,7 +282,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
@@ -293,7 +293,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
context->GetScratchBuffer(context, data.scratch_buffer_index));
// TODO(b/192090531): Remove this once all 8x16 transpose conv models use
// 64-bit biases.
if (bias->type == kTfLiteInt16) {
if (bias != nullptr && bias->type == kTfLiteInt16) {
std::int64_t* bias_converted_buffer =
static_cast<int64_t*>(context->GetScratchBuffer(
context, data.bias_converted_buffer_index));
@@ -319,7 +319,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<std::int64_t>(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
@@ -337,14 +337,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_TRANSPOSE_CONV() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
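The three GetTensorData → GetOptionalTensorData swaps above, together with the added bias != nullptr check, make the kernel safe for TRANSPOSE_CONV nodes that carry no bias tensor. A plausible reading of that helper, stated here as an assumption rather than a copy of the library code, is a null-tolerant accessor:

#include "tensorflow/lite/micro/kernels/kernel_util.h"

template <typename T>
const T* GetOptionalTensorDataSketch(const TfLiteEvalTensor* tensor) {
  // Hypothetical equivalent: return nullptr instead of dereferencing a
  // missing optional tensor, so the reference kernels can skip the bias.
  return tensor == nullptr ? nullptr
                           : tflite::micro::GetTensorData<T>(tensor);
}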

View File

@@ -0,0 +1,244 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace testing {
// TODO(b/230666079) enable below tests for xtensa when the xtensa
// kernel is reconciled with reference kernel
#if !defined(XTENSA)
typedef struct LstmIntegerTestConfig {
const int n_batch;
const int n_input;
const int n_cell;
const int n_output;
const int sequence_length;
const bool time_major;
const bool use_cifg;
const bool use_peephole;
const bool use_projection_weights;
const bool use_projection_bias;
const bool use_layer_norm;
const bool use_8x8_8_implementation;
float intermediate_scale[5][2];
int intermediate_zp[5][2];
TfLiteAffineQuantization* intermediate_qparam;
const float* input;
int8_t* input_quant;
const float* input_to_input_weights;
int8_t* lstm_i2i_quant;
const float* input_to_forget_weights;
int8_t* lstm_i2f_quant;
const float* input_to_cell_weights;
int8_t* lstm_i2c_quant;
const float* input_to_output_weights;
int8_t* lstm_i2o_quant;
const float* recurrent_to_input_weights;
int8_t* lstm_r2i_quant;
const float* recurrent_to_forget_weights;
int8_t* lstm_r2f_quant;
const float* recurrent_to_cell_weights;
int8_t* lstm_r2c_quant;
const float* recurrent_to_output_weights;
int8_t* lstm_r2o_quant;
const float* cell_to_input_weights;
int16_t* lstm_c2i_quant;
const float* cell_to_forget_weights;
int16_t* lstm_c2f_quant;
const float* cell_to_output_weights;
int16_t* lstm_c2o_quant;
const float* input_gate_bias;
int32_t* lstm_igate_bias_quant;
const float* forget_gate_bias;
int32_t* lstm_fgate_bias_quant;
const float* cell_gate_bias;
int32_t* lstm_cgate_bias_quant;
const float* output_gate_bias;
int32_t* lstm_ogate_bias_quant;
const float* projection_weights;
int8_t* lstm_proj_w_quant;
const float* projection_bias;
int32_t* projection_bias_quant;
int16_t* output_state;
int16_t* cell_state;
const float* input_layer_norm_coefficients;
int16_t* lstm_input_layer_norm_coeff_quant;
const float* forget_layer_norm_coefficients;
int16_t* lstm_forget_layer_norm_coeff_quant;
const float* cell_layer_norm_coefficients;
int16_t* lstm_cell_layer_norm_coeff_quant;
const float* output_layer_norm_coefficients;
int16_t* lstm_output_layer_norm_coeff_quant;
int8_t* output;
const int8_t* expected_output;
bool asymmetric_quantize_inputs;
const float ranges[25][2];
} LstmIntegerTestConfig;
typedef struct LstmFloatTestConfig {
const int n_batch;
const int n_input;
const int n_cell;
const int n_output;
const int sequence_length;
const bool time_major;
const bool use_cifg;
const bool use_peephole;
const bool use_projection_weights;
const bool use_projection_bias;
const bool use_layer_norm;
const float cell_clip;
const float proj_clip;
const float* input_original;
float* input;
const float* input_to_input_weights;
const float* input_to_forget_weights;
const float* input_to_cell_weights;
const float* input_to_output_weights;
const float* recurrent_to_input_weights;
const float* recurrent_to_forget_weights;
const float* recurrent_to_cell_weights;
const float* recurrent_to_output_weights;
const float* cell_to_input_weights;
const float* cell_to_forget_weights;
const float* cell_to_output_weights;
const float* input_gate_bias;
const float* forget_gate_bias;
const float* cell_gate_bias;
const float* output_gate_bias;
const float* projection_weights;
const float* projection_bias;
float* output_state;
float* cell_state;
const float* input_layer_norm_coefficients;
const float* forget_layer_norm_coefficients;
const float* cell_layer_norm_coefficients;
const float* output_layer_norm_coefficients;
float* output;
const float* expected_output_original;
float* expected_output;
} LstmFloatTestConfig;
typedef struct LstmWeightQuantizationBuffers {
int8_t* lstm_i2i_quant;
float* lstm_i2i_scale;
int* lstm_i2i_zp;
TfLiteAffineQuantization* lstm_i2i_qparam;
int8_t* lstm_i2f_quant;
float* lstm_i2f_scale;
int* lstm_i2f_zp;
TfLiteAffineQuantization* lstm_i2f_qparam;
int8_t* lstm_i2c_quant;
float* lstm_i2c_scale;
int* lstm_i2c_zp;
TfLiteAffineQuantization* lstm_i2c_qparam;
int8_t* lstm_i2o_quant;
float* lstm_i2o_scale;
int* lstm_i2o_zp;
TfLiteAffineQuantization* lstm_i2o_qparam;
int8_t* lstm_r2i_quant;
float* lstm_r2i_scale;
int* lstm_r2i_zp;
TfLiteAffineQuantization* lstm_r2i_qparam;
int8_t* lstm_r2f_quant;
float* lstm_r2f_scale;
int* lstm_r2f_zp;
TfLiteAffineQuantization* lstm_r2f_qparam;
int8_t* lstm_r2c_quant;
float* lstm_r2c_scale;
int* lstm_r2c_zp;
TfLiteAffineQuantization* lstm_r2c_qparam;
int8_t* lstm_r2o_quant;
float* lstm_r2o_scale;
int* lstm_r2o_zp;
TfLiteAffineQuantization* lstm_r2o_qparam;
int8_t* lstm_c2i_quant;
float* lstm_c2i_scale;
int* lstm_c2i_zp;
TfLiteAffineQuantization* lstm_c2i_qparam;
int8_t* lstm_c2f_quant;
float* lstm_c2f_scale;
int* lstm_c2f_zp;
TfLiteAffineQuantization* lstm_c2f_qparam;
int8_t* lstm_c2o_quant;
float* lstm_c2o_scale;
int* lstm_c2o_zp;
TfLiteAffineQuantization* lstm_c2o_qparam;
int8_t* lstm_proj_w_quant;
float* lstm_proj_w_scale;
int* lstm_proj_w_zp;
TfLiteAffineQuantization* lstm_proj_w_qparam;
} LstmWeightQuantizationBuffers;
extern LstmIntegerTestConfig lstm_integer_no_peephole_config;
extern LstmIntegerTestConfig lstm_integer_peephole_config;
extern LstmFloatTestConfig lstm_no_cifg_no_peephole_no_proj_config;
extern LstmFloatTestConfig lstm_cifg_peephole_no_proj_config;
extern LstmFloatTestConfig lstm_no_cifg_peephole_proj_config;
extern LstmFloatTestConfig lstm_no_cifg_peephole_proj_bias_config;
extern LstmWeightQuantizationBuffers lstm_no_cifg_no_peephole_no_proj_buffers;
extern LstmWeightQuantizationBuffers lstm_cifg_peephole_no_proj_buffers;
extern LstmWeightQuantizationBuffers lstm_no_cifg_peephole_proj_buffers;
extern LstmFloatTestConfig cifg_peephole_no_proj_config_layer_norm;
#endif // !defined(XTENSA)
} // namespace testing
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_

View File

@@ -103,14 +103,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace unpack
TfLiteRegistration Register_UNPACK() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/unpack::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, unpack::Eval);
}
} // namespace micro

Some files were not shown because too many files have changed in this diff.