mirror of
https://github.com/jomjol/AI-on-the-edge-device.git
synced 2025-12-10 13:36:54 +03:00
Rolling 20220716_2
This commit is contained in:
@@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Only float32 is supported currently, got %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Only float32 is supported currently, got %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_RELU() {
|
||||
return {/*init=*/ReluInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/ReluPrepare,
|
||||
/*invoke=*/ReluEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_RELU6() {
|
||||
return {/*init=*/Relu6Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Relu6Prepare,
|
||||
/*invoke=*/Relu6Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_ADD() {
|
||||
return {/*init=*/AddInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/AddPrepare,
|
||||
/*invoke=*/AddEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_ADD_N() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace arg_min_max
|
||||
|
||||
TfLiteRegistration Register_ARG_MAX() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/arg_min_max::ArgMaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_ARG_MIN() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/arg_min_max::ArgMinEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_ASSIGN_VARIABLE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_BROADCAST_ARGS() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/BroadcastArgsPrepare,
|
||||
/*invoke=*/BroadcastArgsEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
|
||||
BroadcastArgsEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
} // namespace tflite
|
||||
|
||||
@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_BROADCAST_TO() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/BroadcastToPrepare,
|
||||
/*invoke=*/BroadcastToEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
|
||||
BroadcastToEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
} // namespace tflite
|
||||
|
||||
@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_CALL_ONCE() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_CAST() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace ceil
|
||||
|
||||
TfLiteRegistration Register_CEIL() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/ceil::Prepare,
|
||||
/*invoke=*/ceil::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -108,14 +108,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
|
||||
static TfLiteRegistration r = {/*init=*/CircularBufferInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/CircularBufferPrepare,
|
||||
/*invoke=*/CircularBufferEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
static TfLiteRegistration r = tflite::micro::RegisterOp(CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
|
||||
return &r;
|
||||
}
|
||||
|
||||
|
||||
@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace comparisons
|
||||
|
||||
TfLiteRegistration Register_EQUAL() {
|
||||
return {/*init=*/comparisons::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/comparisons::Prepare,
|
||||
/*invoke=*/comparisons::EqualEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
|
||||
comparisons::EqualEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_NOT_EQUAL() {
|
||||
return {/*init=*/comparisons::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/comparisons::Prepare,
|
||||
/*invoke=*/comparisons::NotEqualEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
|
||||
comparisons::NotEqualEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_GREATER() {
|
||||
return {/*init=*/comparisons::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/comparisons::Prepare,
|
||||
/*invoke=*/comparisons::GreaterEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
|
||||
comparisons::GreaterEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_GREATER_EQUAL() {
|
||||
return {/*init=*/comparisons::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/comparisons::Prepare,
|
||||
/*invoke=*/comparisons::GreaterEqualEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
|
||||
comparisons::GreaterEqualEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_LESS() {
|
||||
return {/*init=*/comparisons::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/comparisons::Prepare,
|
||||
/*invoke=*/comparisons::LessEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
|
||||
comparisons::LessEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_LESS_EQUAL() {
|
||||
return {/*init=*/comparisons::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/comparisons::Prepare,
|
||||
/*invoke=*/comparisons::LessEqualEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
|
||||
comparisons::LessEqualEval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
int num_dimensions = NumDimensions(input);
|
||||
|
||||
if (num_dimensions > 4) {
|
||||
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context,
|
||||
"Op Concatenation does not currently support num dimensions >4 "
|
||||
"Op Concatenation does not currently support num dimensions > %d "
|
||||
"Tensor has %d dimensions.",
|
||||
num_dimensions);
|
||||
RuntimeShape::kMaxSmallSize, num_dimensions);
|
||||
return kTfLiteError;
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace concatenation
|
||||
|
||||
TfLiteRegistration Register_CONCATENATION() {
|
||||
return {/*init=*/concatenation::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/concatenation::Prepare,
|
||||
/*invoke=*/concatenation::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
|
||||
concatenation::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetOptionalTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output),
|
||||
tflite::micro::GetTensorShape(nullptr), nullptr);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt16: {
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
|
||||
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<std::int64_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
switch (bias->type) {
|
||||
case kTfLiteInt32: {
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
ConvParamsQuantized(params, data),
|
||||
data.per_channel_output_multiplier, data.per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt64: {
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
ConvParamsQuantized(params, data),
|
||||
data.per_channel_output_multiplier, data.per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MicroPrintf("Bias type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(bias->type), bias->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetOptionalTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_CONV_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/ConvPrepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
|
||||
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
|
||||
TfLiteRegistration registration, int16_t* output_data);
|
||||
|
||||
TfLiteStatus TestConvQuantizedPerChannel(
|
||||
int* input_dims_data, const float* input_data, int16_t* input_quantized,
|
||||
float input_scale, int input_zero_point, int* filter_dims_data,
|
||||
const float* filter_data, int8_t* filter_data_quantized,
|
||||
int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
|
||||
float* bias_scales, int* bias_zero_points, int* output_dims_data,
|
||||
const float* expected_output_data, int16_t* expected_output_data_quantized,
|
||||
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
|
||||
TfLiteRegistration registration, int16_t* output_data);
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
|
||||
@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_CUMSUM() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_DEPTH_TO_SPACE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetOptionalTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetOptionalTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/DepthwiseConvPrepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
|
||||
|
||||
TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
// This is the most generic TfLiteRegistration. The actual supported types may
|
||||
// still be target dependent. The only requirement is that every implementation
|
||||
// (reference or optimized) must define this function.
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D();
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int8 activations and int8 weights and uses the latency optimized
|
||||
// implementations.
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
|
||||
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int16 activations and int8 weights and uses the latency optimized
|
||||
// implementations.
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
|
||||
|
||||
#else
|
||||
inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
|
||||
return Register_DEPTHWISE_CONV_2D();
|
||||
}
|
||||
|
||||
inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
|
||||
return Register_DEPTHWISE_CONV_2D();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
|
||||
|
||||
@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
reference_ops::Dequantize(data->quantization_params,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
default:
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_DEQUANTIZE() {
|
||||
return {/*init=*/DequantizeInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/DequantizePrepare,
|
||||
/*invoke=*/DequantizeEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
|
||||
DequantizeEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
|
||||
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteInt16 ||
|
||||
input->type == kTfLiteUInt8);
|
||||
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
|
||||
|
||||
if (output->type == kTfLiteInt32) {
|
||||
|
||||
@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return op_data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
|
||||
static TfLiteRegistration r = {/*init=*/Init,
|
||||
/*free=*/Free,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
return &r;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,8 @@ limitations under the License.
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
@@ -27,6 +29,22 @@ namespace micro {
|
||||
namespace elementwise {
|
||||
namespace {
|
||||
|
||||
constexpr int kAbsNameId = 0;
|
||||
constexpr int kRsrqtNameId = 1;
|
||||
|
||||
const int kElementwiseInputTensor = 0;
|
||||
const int kElementwiseOutputTensor = 0;
|
||||
|
||||
struct OpDataAbsRsqrt {
|
||||
int32_t multiplier;
|
||||
int shift;
|
||||
int input_offset;
|
||||
int output_offset;
|
||||
bool needs_rescale;
|
||||
TfLiteQuantizationType input_quantization_type;
|
||||
TfLiteType input_type;
|
||||
};
|
||||
|
||||
bool IsNumericSupportedType(const TfLiteType type) {
|
||||
return type == kTfLiteFloat32;
|
||||
}
|
||||
@@ -35,16 +53,40 @@ bool IsLogicalSupportedType(const TfLiteType type) {
|
||||
return type == kTfLiteBool;
|
||||
}
|
||||
|
||||
bool IsAbsSupportedType(const TfLiteType type) {
|
||||
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
|
||||
}
|
||||
|
||||
bool IsRsqrtSupportedType(const TfLiteType type) {
|
||||
return type == kTfLiteFloat32 || type == kTfLiteInt8;
|
||||
}
|
||||
|
||||
inline void SetAbsOutputMultiplier(const float input_scale,
|
||||
const float output_scale,
|
||||
int32_t* multiplier, int* shift) {
|
||||
QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
|
||||
multiplier, shift);
|
||||
}
|
||||
|
||||
inline void SetRsqrtOutputMultiplier(const float input_scale,
|
||||
const float output_scale,
|
||||
int32_t* multiplier, int* shift) {
|
||||
const double scale =
|
||||
1. / static_cast<double>((std::sqrt(input_scale) * output_scale));
|
||||
QuantizeMultiplier(scale, multiplier, shift);
|
||||
}
|
||||
|
||||
typedef bool (*IsSupportedType)(TfLiteType);
|
||||
template <IsSupportedType>
|
||||
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
if (!IsSupportedType(input->type)) {
|
||||
@@ -58,9 +100,79 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
typedef bool (*IsSupportedType)(TfLiteType);
|
||||
template <IsSupportedType, const int op_nameid>
|
||||
TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
if (!IsSupportedType(input->type)) {
|
||||
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
|
||||
op_data->input_type = input->type;
|
||||
|
||||
// For int16 type input, we support both quantized and non-quantized
|
||||
// evaluation.
|
||||
if (op_nameid == kAbsNameId) {
|
||||
op_data->input_quantization_type = input->quantization.type;
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteInt8 ||
|
||||
(input->type == kTfLiteInt16 &&
|
||||
input->quantization.type != kTfLiteNoQuantization)) {
|
||||
TF_LITE_ENSURE_EQ(context, input->quantization.type,
|
||||
kTfLiteAffineQuantization);
|
||||
TF_LITE_ENSURE_EQ(context, output->quantization.type,
|
||||
kTfLiteAffineQuantization);
|
||||
const auto* input_params =
|
||||
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
|
||||
const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
output->quantization.params);
|
||||
TF_LITE_ENSURE(context, input_params != nullptr);
|
||||
TF_LITE_ENSURE(context, input_params->scale != nullptr);
|
||||
TF_LITE_ENSURE(context, input_params->scale->size > 0);
|
||||
TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
|
||||
TF_LITE_ENSURE(context, output_params != nullptr);
|
||||
TF_LITE_ENSURE(context, output_params->scale != nullptr);
|
||||
TF_LITE_ENSURE(context, output_params->scale->size > 0);
|
||||
TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
|
||||
op_data->input_offset = input_params->zero_point->data[0];
|
||||
op_data->output_offset = output_params->zero_point->data[0];
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
|
||||
TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
|
||||
}
|
||||
const float input_scale = input_params->scale->data[0];
|
||||
const float output_scale = output_params->scale->data[0];
|
||||
op_data->needs_rescale = input_scale != output_scale;
|
||||
if (op_nameid == kAbsNameId && op_data->needs_rescale) {
|
||||
SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
|
||||
&op_data->shift);
|
||||
} else if (op_nameid == kRsrqtNameId) {
|
||||
SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
|
||||
&op_data->shift);
|
||||
}
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
T func(T), TfLiteType expected_type) {
|
||||
inline TfLiteStatus EvalImplQuantized(
|
||||
TfLiteContext* context, TfLiteNode* node,
|
||||
T func(TfLiteContext*, TfLiteNode*, T),
|
||||
TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
|
||||
TfLiteType expected_type) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
|
||||
@@ -68,6 +180,34 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
const T* in_data = tflite::micro::GetTensorData<T>(input);
|
||||
T* out_data = tflite::micro::GetTensorData<T>(output);
|
||||
for (size_t i = 0; i < num_elements; ++i) {
|
||||
if (validate_input_func) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
validate_input_func(context, node, in_data[i]));
|
||||
}
|
||||
out_data[i] = func(context, node, in_data[i]);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T AbsHelper(T i) {
|
||||
return std::abs(i);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
T func(T), TfLiteStatus validate_input_func(T),
|
||||
TfLiteType expected_type) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
|
||||
const size_t num_elements = ElementCount(*input->dims);
|
||||
const T* in_data = tflite::micro::GetTensorData<T>(input);
|
||||
T* out_data = tflite::micro::GetTensorData<T>(output);
|
||||
for (size_t i = 0; i < num_elements; ++i) {
|
||||
if (validate_input_func) {
|
||||
TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
|
||||
}
|
||||
out_data[i] = func(in_data[i]);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
|
||||
float float_func(float)) {
|
||||
return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
|
||||
return EvalImpl<float>(context, node, float_func,
|
||||
/*validate_input_func=*/nullptr, kTfLiteFloat32);
|
||||
}
|
||||
|
||||
inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
bool bool_func(bool)) {
|
||||
return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
|
||||
return EvalImpl<bool>(context, node, bool_func,
|
||||
/*validate_input_func=*/nullptr, kTfLiteBool);
|
||||
}
|
||||
|
||||
void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
|
||||
size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
|
||||
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
|
||||
const int kMin = std::numeric_limits<T>::min();
|
||||
const int kMax = std::numeric_limits<T>::max();
|
||||
|
||||
const int32_t value = std::abs(i - op_data->input_offset);
|
||||
if (!op_data->needs_rescale) {
|
||||
return static_cast<T>(
|
||||
std::min(std::max(static_cast<long int>(value + op_data->output_offset),
|
||||
static_cast<long int>(kMin)),
|
||||
static_cast<long int>(kMax)));
|
||||
}
|
||||
|
||||
const int32_t output = tflite::MultiplyByQuantizedMultiplier(
|
||||
value, op_data->multiplier, op_data->shift) +
|
||||
op_data->output_offset;
|
||||
return static_cast<T>(std::min(
|
||||
std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
|
||||
static_cast<long int>(kMax)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
|
||||
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
|
||||
const int kMin = std::numeric_limits<T>::min();
|
||||
const int kMax = std::numeric_limits<T>::max();
|
||||
|
||||
const int32_t value = (i - op_data->input_offset);
|
||||
const int32_t kShift = 20; // Shift to keep value integer.
|
||||
if (value == 0) {
|
||||
// Assume that any value close to 0 represents the max output value.
|
||||
return static_cast<T>(kMax);
|
||||
}
|
||||
int32_t inv_sqrt_multiplier;
|
||||
int inv_sqrt_shift;
|
||||
GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
|
||||
&inv_sqrt_shift);
|
||||
const int32_t data = tflite::MultiplyByQuantizedMultiplier(
|
||||
static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
|
||||
const int32_t output =
|
||||
tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
|
||||
op_data->shift - kShift) +
|
||||
op_data->output_offset;
|
||||
return static_cast<T>(std::min(
|
||||
std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
|
||||
static_cast<long int>(kMax)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
|
||||
T i) {
|
||||
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
|
||||
"Rsqrt is only defined for positive values");
|
||||
return static_cast<TfLiteStatus>(kTfLiteOk);
|
||||
}
|
||||
|
||||
TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, std::abs);
|
||||
OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
|
||||
TfLiteType type = op_data->input_type;
|
||||
TfLiteQuantizationType input_quantization_type =
|
||||
op_data->input_quantization_type;
|
||||
TfLiteStatus eval_result;
|
||||
|
||||
switch (type) {
|
||||
case kTfLiteFloat32:
|
||||
eval_result = EvalNumeric(context, node, std::abs);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
eval_result =
|
||||
EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
|
||||
/*validate_input_func=*/nullptr, type);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
eval_result =
|
||||
input_quantization_type == kTfLiteNoQuantization
|
||||
? EvalImpl<int16_t>(context, node, AbsHelper,
|
||||
/*validate_input_func=*/nullptr, type)
|
||||
: EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
|
||||
/*validate_input_func=*/nullptr,
|
||||
type);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
|
||||
TfLiteTypeGetName(type));
|
||||
return kTfLiteError;
|
||||
break;
|
||||
}
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
|
||||
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
|
||||
TfLiteType type = op_data->input_type;
|
||||
switch (type) {
|
||||
case kTfLiteFloat32:
|
||||
return EvalImpl<float>(
|
||||
context, node, [](float f) { return 1.f / std::sqrt(f); },
|
||||
/*validate_input_func=*/nullptr, type);
|
||||
case kTfLiteInt8:
|
||||
return EvalImplQuantized<int8_t>(context, node,
|
||||
elementwise::RsqrtEvalQuantized,
|
||||
elementwise::validate_input_func, type);
|
||||
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
|
||||
TfLiteTypeGetName(type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace elementwise
|
||||
|
||||
TfLiteRegistration Register_ABS() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::AbsEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
elementwise::ElementWiseAbsRsqrtInit,
|
||||
elementwise::PrepareAbsRsqrt<elementwise::IsAbsSupportedType,
|
||||
elementwise::kAbsNameId>,
|
||||
elementwise::AbsEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_SIN() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::SinEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::SinEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_COS() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::CosEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::CosEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_LOG() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::LogEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::LogEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_SQRT() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::SqrtEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::SqrtEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_RSQRT() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::RsqrtEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
elementwise::ElementWiseAbsRsqrtInit,
|
||||
elementwise::PrepareAbsRsqrt<elementwise::IsRsqrtSupportedType,
|
||||
elementwise::kRsrqtNameId>,
|
||||
elementwise::RsqrtEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_SQUARE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
/*invoke=*/elementwise::SquareEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::SquareEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_LOGICAL_NOT() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/
|
||||
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
|
||||
/*invoke=*/elementwise::LogicalNotEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
|
||||
elementwise::LogicalNotEval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
} // namespace tflite
|
||||
@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_ELU() {
|
||||
return {/*init=*/EluInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/EluPrepare,
|
||||
/*invoke=*/EluEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_ADD() {
|
||||
return {/*init=*/AddInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/AddPrepare,
|
||||
/*invoke=*/AddEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
#if ESP_NN
|
||||
if (input->type == kTfLiteInt8) {
|
||||
data_dims_t input_dims = {
|
||||
.width = input_width, .height = input_height,
|
||||
.channels = input->dims->data[3], 1
|
||||
};
|
||||
data_dims_t output_dims = {
|
||||
.width = output_width, .height = output_height,
|
||||
.channels = output->dims->data[3], 1
|
||||
};
|
||||
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
|
||||
conv_params_t conv_params = {
|
||||
.in_offset = 0, .out_offset = 0,
|
||||
.stride = {params.stride_width, params.stride_height},
|
||||
.padding = {data->op_data.padding.width, data->op_data.padding.height},
|
||||
.dilation = {0, 0}, .activation = {-128, 127}
|
||||
};
|
||||
|
||||
int scratch_buf_size = esp_nn_get_conv_scratch_size(
|
||||
input_width, input_height, input->dims->data[3],
|
||||
output->dims->data[3], filter_width, filter_height);
|
||||
&input_dims, &filter_dims, &output_dims, &conv_params);
|
||||
if (scratch_buf_size > 0) {
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, scratch_buf_size, &data->buffer_idx));
|
||||
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
|
||||
const int input_size = input_width * input_height * input_depth;
|
||||
const int output_size = output_width * output_height * output_depth;
|
||||
|
||||
data_dims_t input_dims = {
|
||||
.width = input_width, .height = input_height,
|
||||
.channels = input_depth, 1
|
||||
};
|
||||
data_dims_t output_dims = {
|
||||
.width = output_width, .height = output_height,
|
||||
.channels = output_depth, 1
|
||||
};
|
||||
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
|
||||
conv_params_t conv_params = {
|
||||
.in_offset = input_offset, .out_offset = output_offset,
|
||||
.stride = {stride_width, stride_height},
|
||||
.padding = {pad_width, pad_height},
|
||||
.dilation = {0, 0},
|
||||
.activation = {activation_min, activation_max}
|
||||
};
|
||||
quant_data_t quant_data = {
|
||||
.shift = data.op_data.per_channel_output_shift,
|
||||
.mult = data.op_data.per_channel_output_multiplier
|
||||
};
|
||||
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
|
||||
esp_nn_conv_s8(input_data + i_batch * input_size,
|
||||
input_width, input_height, input_depth, input_offset,
|
||||
pad_width, pad_height, stride_width, stride_height,
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
filter_width, filter_height,
|
||||
esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
|
||||
&filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
output_data + i_batch * output_size,
|
||||
output_width, output_height, output_depth, output_offset,
|
||||
data.op_data.per_channel_output_shift,
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
activation_min, activation_max);
|
||||
&output_dims, output_data + i_batch * output_size,
|
||||
&conv_params, &quant_data);
|
||||
}
|
||||
} else {
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
conv_total_time += esp_timer_get_time() - start_time;
|
||||
long long time_this_instance = esp_timer_get_time() - start_time;
|
||||
conv_total_time += time_this_instance;
|
||||
//printf("time this instance: %llu\n", time_this_instance / 1000);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_CONV_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
if (data.buffer_idx > -1) {
|
||||
scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
|
||||
}
|
||||
|
||||
esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
|
||||
|
||||
data_dims_t input_dims = {
|
||||
.width = input_width, .height = input_height,
|
||||
.channels = input_depth, 1
|
||||
};
|
||||
data_dims_t output_dims = {
|
||||
.width = output_width, .height = output_height,
|
||||
.channels = output_depth, 1
|
||||
};
|
||||
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
|
||||
dw_conv_params_t conv_params = {
|
||||
.in_offset = input_offset, .out_offset = output_offset,
|
||||
.ch_mult = depth_multiplier,
|
||||
.stride = {stride_width, stride_height},
|
||||
.padding = {pad_width, pad_height}, .dilation = {0, 0},
|
||||
.activation = {activation_min, activation_max}
|
||||
};
|
||||
quant_data_t quant_data = {
|
||||
.shift = data.op_data.per_channel_output_shift,
|
||||
.mult = data.op_data.per_channel_output_multiplier
|
||||
};
|
||||
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
|
||||
esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
|
||||
input_height, input_depth, input_offset,
|
||||
pad_width, pad_height,
|
||||
stride_width, stride_height, depth_multiplier,
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
filter_width, filter_height,
|
||||
esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
|
||||
&filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
output_data + i_batch * output_size,
|
||||
output_width, output_height, output_offset,
|
||||
data.op_data.per_channel_output_shift,
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
activation_min, activation_max);
|
||||
&output_dims, output_data + i_batch * output_size,
|
||||
&conv_params, &quant_data);
|
||||
}
|
||||
} else {
|
||||
reference_integer_ops::DepthwiseConvPerChannel(
|
||||
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
#if ESP_NN
|
||||
if (input->type == kTfLiteInt8) {
|
||||
data_dims_t input_dims = {
|
||||
.width = input_width, .height = input_height,
|
||||
.channels = input->dims->data[3], 1
|
||||
};
|
||||
data_dims_t output_dims = {
|
||||
.width = output_width, .height = output_height,
|
||||
.channels = output->dims->data[3], 1
|
||||
};
|
||||
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
|
||||
dw_conv_params_t conv_params = {
|
||||
.in_offset = 0, .out_offset = 0,
|
||||
.ch_mult = params.depth_multiplier,
|
||||
.stride = {params.stride_width, params.stride_height},
|
||||
.padding = {data->op_data.padding.width, data->op_data.padding.height},
|
||||
.dilation = {0, 0}, .activation = {-128, 127}
|
||||
};
|
||||
|
||||
int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
|
||||
input_width, input_height, input->dims->data[3],
|
||||
params.depth_multiplier, filter_width, filter_height);
|
||||
&input_dims, &filter_dims, &output_dims, &conv_params);
|
||||
if (scratch_buf_size > 0) {
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, scratch_buf_size, &data->buffer_idx));
|
||||
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
dc_total_time += esp_timer_get_time() - start_time;
|
||||
long long time_this_instance = esp_timer_get_time() - start_time;
|
||||
dc_total_time += time_this_instance;
|
||||
// printf("time this instance: %llu\n", time_this_instance / 1000);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_FULLY_CONNECTED() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MUL() {
|
||||
return {/*init=*/MulInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/MulPrepare,
|
||||
/*invoke=*/MulEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_AVERAGE_POOL_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PoolingPrepare,
|
||||
/*invoke=*/AverageEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MAX_POOL_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PoolingPrepare,
|
||||
/*invoke=*/MaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
long long softmax_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
// Softmax parameter data that persists in user_data
|
||||
const int kInt16LUTArraySize = 513;
|
||||
|
||||
struct NodeData {
|
||||
SoftmaxParams op_data;
|
||||
#if ESP_NN
|
||||
int buffer_idx;
|
||||
#endif
|
||||
};
|
||||
|
||||
static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(NodeData));
|
||||
}
|
||||
|
||||
void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
|
||||
TfLiteEvalTensor* output, const NodeData* data) {
|
||||
if (input->type == kTfLiteInt8) {
|
||||
if (output->type == kTfLiteInt16) {
|
||||
tflite::reference_ops::Softmax(
|
||||
data->op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
#if ESP_NN
|
||||
const int32_t input_beta_multiplier = data->op_data.input_multiplier;
|
||||
const int32_t input_beta_left_shift = data->op_data.input_left_shift;
|
||||
const int diff_min = data->op_data.diff_min;
|
||||
const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
|
||||
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
|
||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||
const int outer_size =
|
||||
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||
const int depth =
|
||||
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
|
||||
const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
|
||||
int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
|
||||
void *scratch_buf = NULL;
|
||||
if (data->buffer_idx > -1) {
|
||||
scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
|
||||
}
|
||||
esp_nn_set_softmax_scratch_buf(scratch_buf);
|
||||
esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
|
||||
input_beta_left_shift, diff_min, out_ptr);
|
||||
#else
|
||||
tflite::reference_ops::Softmax(
|
||||
data->op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
}
|
||||
} else {
|
||||
tflite::reference_ops::SoftmaxInt16(
|
||||
data->op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
}
|
||||
}
|
||||
|
||||
static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
NodeData data = *static_cast<NodeData*>(node->user_data);
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
tflite::reference_ops::Softmax(
|
||||
data.op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteInt16: {
|
||||
SoftmaxQuantized(context, input, output, &data);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
softmax_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
NodeData* data = static_cast<NodeData*>(node->user_data);
|
||||
// Only allocate LUTs for KTfLiteInt16 data type
|
||||
if (input->type == kTfLiteInt16) {
|
||||
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||
data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||
data->op_data.one_over_one_plus_x_lut =
|
||||
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
}
|
||||
|
||||
// Populate LUT if required
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
// exp LUT only used on negative values
|
||||
// we consider exp(-10.0) is insignificant to accumulation
|
||||
gen_lut<float, int16_t, int16_t>(
|
||||
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
|
||||
data->op_data.exp_lut);
|
||||
gen_lut<float, int16_t, int16_t>(
|
||||
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
|
||||
1.0f, data->op_data.one_over_one_plus_x_lut);
|
||||
data->op_data.zero_point = output->params.zero_point;
|
||||
data->op_data.scale = output->params.scale;
|
||||
}
|
||||
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
auto ret_val =
|
||||
CalculateSoftmaxParams(context, input, output, params, &data->op_data);
|
||||
|
||||
#if ESP_NN
|
||||
if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
|
||||
const int32_t input_width = input->dims->data[1];
|
||||
const int32_t input_height = input->dims->data[2];
|
||||
int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
|
||||
input_height);
|
||||
if (scratch_buf_size > 0) {
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, scratch_buf_size, &data->buffer_idx));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return ret_val;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SOFTMAX() {
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_EXP() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_EXPAND_DIMS() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_FILL() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace floor
|
||||
|
||||
TfLiteRegistration Register_FLOOR() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/floor::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_FLOOR_DIV() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_FLOOR_MOD() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
|
||||
node, kFullyConnectedOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
|
||||
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
|
||||
context, params->activation, input->type,
|
||||
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
break;
|
||||
}
|
||||
|
||||
case kTfLiteInt16: {
|
||||
const int64_t* bias_data =
|
||||
nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
|
||||
: nullptr;
|
||||
|
||||
tflite::reference_integer_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias), bias_data,
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
break;
|
||||
}
|
||||
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
@@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_FULLY_CONNECTED() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int16.
|
||||
TfLiteRegistration Register_FULLY_CONNECTED_INT16();
|
||||
|
||||
#else
|
||||
// Note that while this block gets used for both reference and optimized kernels
|
||||
// that do not have any specialized implementations, the only goal here is to
|
||||
// define fallback implementation that allow reference kernels to still be used
|
||||
// from applications that call a more specific kernel variant.
|
||||
|
||||
inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
|
||||
return Register_FULLY_CONNECTED();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
|
||||
|
||||
@@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_GATHER() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -195,14 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_GATHER_ND() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_HARD_SWISH() {
|
||||
return {/*init=*/HardSwishInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/tflite::HardSwishPrepare,
|
||||
/*invoke=*/HardSwishEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare,
|
||||
HardSwishEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_IF() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -15,9 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
|
||||
#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
|
||||
#include "tensorflow/lite/micro/micro_arena_constants.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
#include "tensorflow/lite/micro/simple_memory_allocator.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
|
||||
namespace tflite {
|
||||
@@ -30,7 +30,7 @@ uint8_t KernelRunner::kKernelRunnerBuffer_[];
|
||||
KernelRunner::KernelRunner(const TfLiteRegistration& registration,
|
||||
TfLiteTensor* tensors, int tensors_size,
|
||||
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
|
||||
void* builtin_data)
|
||||
void* builtin_data, TfLiteIntArray* intermediates)
|
||||
: registration_(registration),
|
||||
allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
|
||||
kKernelRunnerBuffer_,
|
||||
@@ -54,6 +54,7 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
|
||||
node_.inputs = inputs;
|
||||
node_.outputs = outputs;
|
||||
node_.builtin_data = builtin_data;
|
||||
node_.intermediates = intermediates;
|
||||
}
|
||||
|
||||
bool KernelRunner::ValidateTempBufferDeallocated() {
|
||||
|
||||
@@ -18,9 +18,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
|
||||
#include "tensorflow/lite/micro/fake_micro_context.h"
|
||||
#include "tensorflow/lite/micro/mock_micro_graph.h"
|
||||
#include "tensorflow/lite/micro/simple_memory_allocator.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace micro {
|
||||
@@ -35,7 +35,8 @@ class KernelRunner {
|
||||
public:
|
||||
KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
|
||||
int tensors_size, TfLiteIntArray* inputs,
|
||||
TfLiteIntArray* outputs, void* builtin_data);
|
||||
TfLiteIntArray* outputs, void* builtin_data,
|
||||
TfLiteIntArray* intermediates = nullptr);
|
||||
|
||||
// Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
|
||||
// exceptions will be DebugLog'd and returned as a status code.
|
||||
|
||||
@@ -36,6 +36,21 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration RegisterOp(
|
||||
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
|
||||
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
|
||||
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
|
||||
return {/*init=*/init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/prepare,
|
||||
/*invoke=*/invoke,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0,
|
||||
/*registration_external=*/nullptr};
|
||||
}
|
||||
|
||||
// Returns a mutable tensor for a given input index. is_variable must be checked
|
||||
// during prepare when the full TfLiteTensor is available.
|
||||
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
|
||||
|
||||
@@ -27,6 +27,11 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace micro {
|
||||
|
||||
TfLiteRegistration RegisterOp(
|
||||
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
|
||||
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
|
||||
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
|
||||
|
||||
// Returns a mutable tensor for a given input index. is_variable must be checked
|
||||
// during prepare when the full TfLiteTensor is available.
|
||||
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
|
||||
@@ -40,19 +45,33 @@ const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
|
||||
TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
|
||||
const TfLiteNode* node, int index);
|
||||
|
||||
// Returns data for a TfLiteEvalTensor struct.
|
||||
// Returns data for a TfLiteEvalTensor struct that are expected to exist.
|
||||
template <typename T>
|
||||
T* GetTensorData(TfLiteEvalTensor* tensor) {
|
||||
return tensor != nullptr ? reinterpret_cast<T*>(tensor->data.raw) : nullptr;
|
||||
TFLITE_DCHECK(tensor != nullptr);
|
||||
return reinterpret_cast<T*>(tensor->data.raw);
|
||||
}
|
||||
|
||||
// Returns const data for a TfLiteEvalTensor struct.
|
||||
// Returns const data for a TfLiteEvalTensor struct that are expected to exist.
|
||||
template <typename T>
|
||||
const T* GetTensorData(const TfLiteEvalTensor* tensor) {
|
||||
TFLITE_DCHECK(tensor != nullptr);
|
||||
return reinterpret_cast<const T*>(tensor->data.raw);
|
||||
}
|
||||
|
||||
// Returns data for a TfLiteEvalTensor struct that could be null.
|
||||
template <typename T>
|
||||
T* GetOptionalTensorData(TfLiteEvalTensor* tensor) {
|
||||
return tensor == nullptr ? nullptr : reinterpret_cast<T*>(tensor->data.raw);
|
||||
}
|
||||
|
||||
// Returns const data for a TfLiteEvalTensor struct that could be null.
|
||||
template <typename T>
|
||||
const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
|
||||
return tensor == nullptr ? nullptr
|
||||
: reinterpret_cast<const T*>(tensor->data.raw);
|
||||
}
|
||||
|
||||
// Returns the shape of a TfLiteEvalTensor struct.
|
||||
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
|
||||
|
||||
|
||||
@@ -136,14 +136,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_L2_POOL_2D() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/L2Prepare,
|
||||
/*invoke=*/L2Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, L2Prepare, L2Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -137,14 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace l2norm
|
||||
|
||||
TfLiteRegistration Register_L2NORM_REF() {
|
||||
return {/*init=*/l2norm::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/l2norm::Prepare,
|
||||
/*invoke=*/l2norm::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(l2norm::Init, l2norm::Prepare, l2norm::Eval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }
|
||||
|
||||
@@ -88,14 +88,8 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_LEAKY_RELU() {
|
||||
return {/*init=*/LeakyReluInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/LeakyReluPrepare,
|
||||
/*invoke=*/LeakyReluEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(LeakyReluInit, LeakyReluPrepare,
|
||||
LeakyReluEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -142,14 +142,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_LOG_SOFTMAX() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/LogSoftmaxPrepare,
|
||||
/*invoke=*/LogSoftmaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, LogSoftmaxPrepare, LogSoftmaxEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -34,29 +34,11 @@ TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_LOGICAL_OR() {
|
||||
// Init, Free, Prepare, Eval are satisfying the Interface required by
|
||||
// TfLiteRegistration.
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/LogicalOrEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, LogicalOrEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_LOGICAL_AND() {
|
||||
// Init, Free, Prepare, Eval are satisfying the Interface required by
|
||||
// TfLiteRegistration.
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/LogicalAndEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, LogicalAndEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -106,13 +106,6 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_LOGISTIC() {
|
||||
return {/*init=*/LogisticInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/LogisticPrepare,
|
||||
/*invoke=*/LogisticEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(LogisticInit, LogisticPrepare, LogisticEval);
|
||||
}
|
||||
} // namespace tflite
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,250 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Pamameters for integer LSTM.
|
||||
// Consider split this into two Integer Parameters if more fields are added.
|
||||
struct IntegerLstmParameter {
|
||||
int32_t effective_input_to_input_scale_a;
|
||||
int32_t effective_input_to_input_scale_b;
|
||||
int32_t effective_recurrent_to_input_scale_a;
|
||||
int32_t effective_recurrent_to_input_scale_b;
|
||||
int32_t effective_cell_to_input_scale_a;
|
||||
int32_t effective_cell_to_input_scale_b;
|
||||
int32_t effective_input_to_forget_scale_a;
|
||||
int32_t effective_input_to_forget_scale_b;
|
||||
int32_t effective_recurrent_to_forget_scale_a;
|
||||
int32_t effective_recurrent_to_forget_scale_b;
|
||||
int32_t effective_cell_to_forget_scale_a;
|
||||
int32_t effective_cell_to_forget_scale_b;
|
||||
int32_t effective_input_to_cell_scale_a;
|
||||
int32_t effective_input_to_cell_scale_b;
|
||||
int32_t effective_recurrent_to_cell_scale_a;
|
||||
int32_t effective_recurrent_to_cell_scale_b;
|
||||
int32_t effective_input_to_output_scale_a;
|
||||
int32_t effective_input_to_output_scale_b;
|
||||
int32_t effective_recurrent_to_output_scale_a;
|
||||
int32_t effective_recurrent_to_output_scale_b;
|
||||
int32_t effective_cell_to_output_scale_a;
|
||||
int32_t effective_cell_to_output_scale_b;
|
||||
int32_t effective_proj_scale_a;
|
||||
int32_t effective_proj_scale_b;
|
||||
int32_t effective_hidden_scale_a;
|
||||
int32_t effective_hidden_scale_b;
|
||||
int32_t layer_norm_input_scale_a;
|
||||
int32_t layer_norm_input_scale_b;
|
||||
int32_t layer_norm_forget_scale_a;
|
||||
int32_t layer_norm_forget_scale_b;
|
||||
int32_t layer_norm_cell_scale_a;
|
||||
int32_t layer_norm_cell_scale_b;
|
||||
int32_t layer_norm_output_scale_a;
|
||||
int32_t layer_norm_output_scale_b;
|
||||
// Quantized clip value for cell and projection. Zero value means no clipping.
|
||||
int16_t quantized_cell_clip;
|
||||
int8_t quantized_proj_clip;
|
||||
int32_t hidden_zp;
|
||||
int32_t cell_scale;
|
||||
|
||||
int32_t input_variance_guard;
|
||||
int32_t forget_variance_guard;
|
||||
int32_t cell_variance_guard;
|
||||
int32_t output_variance_guard;
|
||||
|
||||
// Pre-calculate bias + zero_point * weight.
|
||||
int32_t* input_to_forget_effective_bias;
|
||||
int32_t* recurrent_to_forget_effective_bias;
|
||||
int32_t* input_to_cell_effective_bias;
|
||||
int32_t* recurrent_to_cell_effective_bias;
|
||||
int32_t* input_to_output_effective_bias;
|
||||
int32_t* recurrent_to_output_effective_bias;
|
||||
int32_t* input_to_input_effective_bias;
|
||||
int32_t* recurrent_to_input_effective_bias;
|
||||
int32_t* projection_effective_bias;
|
||||
|
||||
// Scale and zero point for intermediate tensors.
|
||||
// Used only in the 8x8_8 case.
|
||||
int32_t intermediate_scale_a[8];
|
||||
int32_t intermediate_scale_b[8];
|
||||
int32_t intermediate_zp[12];
|
||||
};
|
||||
|
||||
// Scales for hybrid op with integer inputs and float weights
|
||||
struct HybridLstmScales {
|
||||
float input_to_input_weights_scale;
|
||||
float input_to_forget_weights_scale;
|
||||
float input_to_cell_weights_scale;
|
||||
float input_to_output_weights_scale;
|
||||
float aux_input_to_input_weights_scale;
|
||||
float aux_input_to_forget_weights_scale;
|
||||
float aux_input_to_cell_weights_scale;
|
||||
float aux_input_to_output_weights_scale;
|
||||
float recurrent_to_input_weights_scale;
|
||||
float recurrent_to_forget_weights_scale;
|
||||
float recurrent_to_cell_weights_scale;
|
||||
float recurrent_to_output_weights_scale;
|
||||
float cell_to_input_weights_scale;
|
||||
float cell_to_forget_weights_scale;
|
||||
float cell_to_output_weights_scale;
|
||||
float projection_weights_scale;
|
||||
};
|
||||
|
||||
TfLiteStatus EvalFloatLstm(
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* input_to_input_weights,
|
||||
const TfLiteEvalTensor* input_to_forget_weights,
|
||||
const TfLiteEvalTensor* input_to_cell_weights,
|
||||
const TfLiteEvalTensor* input_to_output_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_input_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_forget_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_cell_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_output_weights,
|
||||
const TfLiteEvalTensor* cell_to_input_weights,
|
||||
const TfLiteEvalTensor* cell_to_forget_weights,
|
||||
const TfLiteEvalTensor* cell_to_output_weights,
|
||||
const TfLiteEvalTensor* input_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* output_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* aux_input,
|
||||
const TfLiteEvalTensor* aux_input_to_input_weights,
|
||||
const TfLiteEvalTensor* aux_input_to_forget_weights,
|
||||
const TfLiteEvalTensor* aux_input_to_cell_weights,
|
||||
const TfLiteEvalTensor* aux_input_to_output_weights,
|
||||
const TfLiteEvalTensor* input_gate_bias,
|
||||
const TfLiteEvalTensor* forget_gate_bias,
|
||||
const TfLiteEvalTensor* cell_gate_bias,
|
||||
const TfLiteEvalTensor* output_gate_bias,
|
||||
const TfLiteEvalTensor* projection_weights,
|
||||
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||
bool forward_sequence, bool time_major, int output_offset,
|
||||
float* scratch_buffer, TfLiteEvalTensor* output_state,
|
||||
TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
|
||||
|
||||
TfLiteStatus EvalHybridLstm(
|
||||
const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* input_to_input_weights,
|
||||
const TfLiteEvalTensor* input_to_input_weights_ledger,
|
||||
const TfLiteEvalTensor* input_to_forget_weights,
|
||||
const TfLiteEvalTensor* input_to_forget_weights_ledger,
|
||||
const TfLiteEvalTensor* input_to_cell_weights,
|
||||
const TfLiteEvalTensor* input_to_cell_weights_ledger,
|
||||
const TfLiteEvalTensor* input_to_output_weights,
|
||||
const TfLiteEvalTensor* input_to_output_weights_ledger,
|
||||
const TfLiteEvalTensor* recurrent_to_input_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
|
||||
const TfLiteEvalTensor* recurrent_to_forget_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
|
||||
const TfLiteEvalTensor* recurrent_to_cell_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
|
||||
const TfLiteEvalTensor* recurrent_to_output_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
|
||||
const TfLiteEvalTensor* cell_to_input_weights,
|
||||
const TfLiteEvalTensor* cell_to_forget_weights,
|
||||
const TfLiteEvalTensor* cell_to_output_weights,
|
||||
const TfLiteEvalTensor* input_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* output_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* aux_input,
|
||||
const TfLiteEvalTensor* aux_input_to_input_weights,
|
||||
const TfLiteEvalTensor* aux_input_to_forget_weights,
|
||||
const TfLiteEvalTensor* aux_input_to_cell_weights,
|
||||
const TfLiteEvalTensor* aux_input_to_output_weights,
|
||||
const TfLiteEvalTensor* input_gate_bias,
|
||||
const TfLiteEvalTensor* forget_gate_bias,
|
||||
const TfLiteEvalTensor* cell_gate_bias,
|
||||
const TfLiteEvalTensor* output_gate_bias,
|
||||
const TfLiteEvalTensor* projection_weights,
|
||||
const TfLiteEvalTensor* projection_weights_ledger,
|
||||
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||
bool forward_sequence, bool time_major, int output_offset,
|
||||
float* scratch_buffer, float* input_sf, float* aux_input_sf,
|
||||
float* output_state_sf, float* prod_scaling_factors,
|
||||
float* recovered_cell_weights, int8_t* input_quantized,
|
||||
int8_t* aux_input_quantized, int8_t* output_state_quantized,
|
||||
int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
|
||||
TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
|
||||
TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
|
||||
int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
|
||||
bool* compute_row_sums);
|
||||
|
||||
TfLiteStatus EvalInteger8x8_16Lstm(
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* input_to_input_weights,
|
||||
const TfLiteEvalTensor* input_to_forget_weights,
|
||||
const TfLiteEvalTensor* input_to_cell_weights,
|
||||
const TfLiteEvalTensor* input_to_output_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_input_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_forget_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_cell_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_output_weights,
|
||||
const TfLiteEvalTensor* cell_to_input_weights,
|
||||
const TfLiteEvalTensor* cell_to_forget_weights,
|
||||
const TfLiteEvalTensor* cell_to_output_weights,
|
||||
const TfLiteEvalTensor* input_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* output_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* input_gate_bias,
|
||||
const TfLiteEvalTensor* forget_gate_bias,
|
||||
const TfLiteEvalTensor* cell_gate_bias,
|
||||
const TfLiteEvalTensor* output_gate_bias,
|
||||
const TfLiteEvalTensor* projection_weights,
|
||||
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||
bool forward_sequence, bool time_major,
|
||||
const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
|
||||
TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
|
||||
TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
|
||||
int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5);
|
||||
|
||||
TfLiteStatus EvalInteger8x8_8Lstm(
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* input_to_input_weights,
|
||||
const TfLiteEvalTensor* input_to_forget_weights,
|
||||
const TfLiteEvalTensor* input_to_cell_weights,
|
||||
const TfLiteEvalTensor* input_to_output_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_input_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_forget_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_cell_weights,
|
||||
const TfLiteEvalTensor* recurrent_to_output_weights,
|
||||
const TfLiteEvalTensor* cell_to_input_weights,
|
||||
const TfLiteEvalTensor* cell_to_forget_weights,
|
||||
const TfLiteEvalTensor* cell_to_output_weights,
|
||||
const TfLiteEvalTensor* input_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* output_layer_norm_coefficients,
|
||||
const TfLiteEvalTensor* input_gate_bias,
|
||||
const TfLiteEvalTensor* forget_gate_bias,
|
||||
const TfLiteEvalTensor* cell_gate_bias,
|
||||
const TfLiteEvalTensor* output_gate_bias,
|
||||
const TfLiteEvalTensor* projection_weights,
|
||||
const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||
TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
|
||||
TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
|
||||
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
|
||||
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7);
|
||||
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
|
||||
@@ -0,0 +1,67 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Input Tensors of size {n_batch, n_input}
|
||||
constexpr int kLstmInputTensor = 0;
|
||||
|
||||
// Input weight tensors of size: {n_cell, n_input}
|
||||
constexpr int kLstmInputToInputWeightsTensor = 1; // Optional
|
||||
constexpr int kLstmInputToForgetWeightsTensor = 2;
|
||||
constexpr int kLstmInputToCellWeightsTensor = 3;
|
||||
constexpr int kLstmInputToOutputWeightsTensor = 4;
|
||||
|
||||
// Recurrent weight tensors of size {n_cell, n_output}
|
||||
constexpr int kLstmRecurrentToInputWeightsTensor = 5; // Optional
|
||||
constexpr int kLstmRecurrentToForgetWeightsTensor = 6;
|
||||
constexpr int kLstmRecurrentToCellWeightsTensor = 7;
|
||||
constexpr int kLstmRecurrentToOutputWeightsTensor = 8;
|
||||
|
||||
// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
|
||||
constexpr int kLstmCellToInputWeightsTensor = 9; // Optional
|
||||
constexpr int kLstmCellToForgetWeightsTensor = 10; // Optional
|
||||
constexpr int kLstmCellToOutputWeightsTensor = 11; // Optional
|
||||
|
||||
// Gates bias tensors of size {n_cell}
|
||||
constexpr int kLstmInputGateBiasTensor = 12; // Optional
|
||||
constexpr int kLstmForgetGateBiasTensor = 13;
|
||||
constexpr int kLstmCellGateBiasTensor = 14;
|
||||
constexpr int kLstmOutputGateBiasTensor = 15;
|
||||
|
||||
// Projection weight tensor of size {n_output, n_cell}
|
||||
constexpr int kLstmProjectionWeightsTensor = 16; // Optional
|
||||
// Projection bias tensor of size {n_output}
|
||||
constexpr int kLstmProjectionBiasTensor = 17; // Optional
|
||||
|
||||
// These state tensors are defined as variable tensors, and will be modified by
|
||||
// this op.
|
||||
constexpr int kLstmOutputStateTensor = 18;
|
||||
constexpr int kLstmCellStateTensor = 19;
|
||||
|
||||
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
|
||||
// matrix.
|
||||
constexpr int kLstmInputLayerNormCoefficientsTensor = 20; // Optional
|
||||
constexpr int kLstmForgetLayerNormCoefficientsTensor = 21; // Optional
|
||||
constexpr int kLstmCellLayerNormCoefficientsTensor = 22; // Optional
|
||||
constexpr int kLstmOutputLayerNormCoefficientsTensor = 23; // Optional
|
||||
|
||||
// Output tensors.
|
||||
constexpr int kLstmOutputTensor = 0;
|
||||
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
|
||||
@@ -115,29 +115,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace maximum_minimum
|
||||
|
||||
TfLiteRegistration Register_MAXIMUM() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/
|
||||
maximum_minimum::Eval<maximum_minimum::kReference,
|
||||
maximum_minimum::MaximumOp>,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, nullptr,
|
||||
maximum_minimum::Eval<maximum_minimum::kReference,
|
||||
maximum_minimum::MaximumOp>);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MINIMUM() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/
|
||||
maximum_minimum::Eval<maximum_minimum::kReference,
|
||||
maximum_minimum::MinimumOp>,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
nullptr, nullptr,
|
||||
maximum_minimum::Eval<maximum_minimum::kReference,
|
||||
maximum_minimum::MinimumOp>);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -76,11 +76,14 @@ TfLiteRegistration Register_SHAPE();
|
||||
TfLiteRegistration Register_SLICE();
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
||||
TfLiteRegistration Register_SPACE_TO_DEPTH();
|
||||
TfLiteRegistration Register_SQUARED_DIFFERENCE();
|
||||
TfLiteRegistration Register_SQUEEZE();
|
||||
TfLiteRegistration Register_SUB();
|
||||
TfLiteRegistration Register_SVDF();
|
||||
TfLiteRegistration Register_TRANSPOSE();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV();
|
||||
// TODO(b/230666079): resolve conflict with xtensa implementation
|
||||
TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
|
||||
TfLiteRegistration Register_VAR_HANDLE();
|
||||
TfLiteRegistration Register_WHILE();
|
||||
TfLiteRegistration Register_ZEROS_LIKE();
|
||||
@@ -103,14 +106,12 @@ TfLiteRegistration Register_LESS_EQUAL();
|
||||
TfLiteRegistration Register_LOG();
|
||||
TfLiteRegistration Register_LOGICAL_NOT();
|
||||
TfLiteRegistration Register_MAXIMUM();
|
||||
TfLiteRegistration Register_MEAN();
|
||||
TfLiteRegistration Register_MINIMUM();
|
||||
TfLiteRegistration Register_NEG();
|
||||
TfLiteRegistration Register_NOT_EQUAL();
|
||||
TfLiteRegistration Register_PACK();
|
||||
TfLiteRegistration Register_PAD();
|
||||
TfLiteRegistration Register_PADV2();
|
||||
TfLiteRegistration Register_REDUCE_MAX();
|
||||
TfLiteRegistration Register_RESHAPE();
|
||||
TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR();
|
||||
TfLiteRegistration Register_ROUND();
|
||||
@@ -121,7 +122,6 @@ TfLiteRegistration Register_SPLIT_V();
|
||||
TfLiteRegistration Register_SQRT();
|
||||
TfLiteRegistration Register_SQUARE();
|
||||
TfLiteRegistration Register_STRIDED_SLICE();
|
||||
TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
|
||||
TfLiteRegistration Register_UNPACK();
|
||||
TfLiteRegistration Register_L2_NORMALIZATION();
|
||||
TfLiteRegistration Register_TANH();
|
||||
|
||||
@@ -0,0 +1,809 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "fixedpoint/fixedpoint.h" // from @gemmlowp
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/cppmath.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace micro_tensor_utils {
|
||||
|
||||
namespace {
|
||||
const int32_t kInt16Max = std::numeric_limits<int16_t>::max();
|
||||
const int32_t kInt16Min = std::numeric_limits<int16_t>::min();
|
||||
} // namespace
|
||||
|
||||
void PortableSymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float* min_value,
|
||||
float* max_value, float* scaling_factor) {
|
||||
auto minmax = std::minmax_element(values, values + size);
|
||||
*min_value = *minmax.first;
|
||||
*max_value = *minmax.second;
|
||||
|
||||
PortableSymmetricQuantizeFloats(values, size, quantized_values, *min_value,
|
||||
*max_value, scaling_factor);
|
||||
}
|
||||
|
||||
void PortableSymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float min_value,
|
||||
float max_value, float* scaling_factor) {
|
||||
const int32_t kScale = 127;
|
||||
const float range = std::max(std::abs(min_value), std::abs(max_value));
|
||||
if (range == 0) {
|
||||
memset(quantized_values, 0, size * sizeof(int8_t));
|
||||
*scaling_factor = 1;
|
||||
return;
|
||||
}
|
||||
*scaling_factor = range / kScale;
|
||||
const float scaling_factor_inv = kScale / range;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const int32_t quantized_value =
|
||||
static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
|
||||
// Clamp: just in case some odd numeric offset.
|
||||
quantized_values[i] = static_cast<int8_t>(
|
||||
std::min(kScale, std::max(-kScale, quantized_value)));
|
||||
}
|
||||
}
|
||||
|
||||
void PortableAsymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values,
|
||||
float* scaling_factor, int32_t* offset) {
|
||||
const int32_t kMinScale = -128;
|
||||
const int32_t kMaxScale = 127;
|
||||
const double qmin_double = kMinScale;
|
||||
const double qmax_double = kMaxScale;
|
||||
const auto minmax = std::minmax_element(values, values + size);
|
||||
const double rmin = static_cast<double>(std::min(0.0f, *minmax.first));
|
||||
const double rmax = static_cast<double>(std::max(0.0f, *minmax.second));
|
||||
if (rmin == rmax) {
|
||||
memset(quantized_values, 0, size * sizeof(int8_t));
|
||||
*scaling_factor = 1;
|
||||
*offset = 0;
|
||||
return;
|
||||
} else {
|
||||
double scale = (rmax - rmin) / (qmax_double - qmin_double);
|
||||
const double zero_point_from_min = qmin_double - rmin / scale;
|
||||
const double zero_point_from_max = qmax_double - rmax / scale;
|
||||
const double zero_point_from_min_error =
|
||||
std::abs(qmin_double) + std::abs(rmin / scale);
|
||||
const double zero_point_from_max_error =
|
||||
std::abs(qmax_double) + std::abs(rmax / scale);
|
||||
const double zero_point_double =
|
||||
zero_point_from_min_error < zero_point_from_max_error
|
||||
? zero_point_from_min
|
||||
: zero_point_from_max;
|
||||
int8_t nudged_zero_point = 0;
|
||||
if (zero_point_double <= qmin_double) {
|
||||
nudged_zero_point = kMinScale;
|
||||
} else if (zero_point_double >= qmax_double) {
|
||||
nudged_zero_point = kMaxScale;
|
||||
} else {
|
||||
nudged_zero_point = static_cast<int8_t>(round(zero_point_double));
|
||||
}
|
||||
*scaling_factor = scale;
|
||||
*offset = nudged_zero_point;
|
||||
}
|
||||
const float scaling_factor_inv = 1.0f / *scaling_factor;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const int32_t quantized_value = static_cast<int32_t>(
|
||||
TfLiteRound(*offset + values[i] * scaling_factor_inv));
|
||||
quantized_values[i] =
|
||||
std::min(kMaxScale, std::max(kMinScale, quantized_value));
|
||||
}
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
|
||||
int m_rows, int m_cols,
|
||||
const float* vector,
|
||||
int n_batch, float* result) {
|
||||
float* result_in_batch = result;
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
const float* matrix_ptr = matrix;
|
||||
for (int r = 0; r < m_rows; r++) {
|
||||
float dot_prod = 0.0f;
|
||||
const float* vector_in_batch = vector + b * m_cols;
|
||||
for (int c = 0; c < m_cols; c++) {
|
||||
dot_prod += *matrix_ptr++ * *vector_in_batch++;
|
||||
}
|
||||
*result_in_batch += dot_prod;
|
||||
++result_in_batch;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vectors, const float* scaling_factors,
|
||||
int n_batch, float* __restrict__ result) {
|
||||
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
|
||||
const float batch_scaling_factor = scaling_factors[batch];
|
||||
// Get the address of the first row.
|
||||
const int8_t* row_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; ++row) {
|
||||
// Initialize the dot product sum for the row to 0.
|
||||
int32_t dotprod = 0;
|
||||
// TODO(b/230666277): remove this
|
||||
#if defined(__GNUC__)
|
||||
// Prefetch the row to cache.
|
||||
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
|
||||
3 /* temporal locality */);
|
||||
#endif
|
||||
for (int col = 0; col < m_cols; ++col, ++row_ptr) {
|
||||
dotprod += (*row_ptr) * (vectors[col]);
|
||||
} // for col
|
||||
*result += dotprod * batch_scaling_factor;
|
||||
++result;
|
||||
} // for row
|
||||
} // for batch
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vectors, const float* scaling_factors,
|
||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||
bool* compute_row_sums, CpuBackendContext* context) {
|
||||
if (input_offset == nullptr) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
|
||||
return;
|
||||
}
|
||||
if (!compute_row_sums || *compute_row_sums) {
|
||||
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
|
||||
if (compute_row_sums) {
|
||||
*compute_row_sums = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
|
||||
const float batch_scaling_factor = scaling_factors[batch];
|
||||
const int32_t batch_offset = input_offset[batch];
|
||||
const int8_t* row_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; ++row) {
|
||||
int32_t dotprod = 0;
|
||||
float scale = batch_scaling_factor;
|
||||
if (per_channel_scale) {
|
||||
scale *= per_channel_scale[row];
|
||||
}
|
||||
#if defined(__GNUC__)
|
||||
// Prefetch the row to cache.
|
||||
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
|
||||
3 /* temporal locality */);
|
||||
#endif
|
||||
for (int col = 0; col < m_cols; ++col, ++row_ptr) {
|
||||
dotprod += (*row_ptr) * vectors[col];
|
||||
} // for col
|
||||
dotprod -= row_sums[row] * batch_offset;
|
||||
*result += dotprod * scale;
|
||||
++result;
|
||||
} // for row
|
||||
} // for batch
|
||||
}
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
const int kBlockSize = 4;
|
||||
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
|
||||
for (int batch = 0; batch < n_batch; batch++) {
|
||||
const float* matrix_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; row++) {
|
||||
float dot_prod = 0.0f;
|
||||
const float* vector_in_batch = vector + batch * m_cols;
|
||||
for (int i = segments[row]; i < segments[row + 1]; i++) {
|
||||
const int block_start_index = indices[i] * kBlockSize;
|
||||
const float* vector_block_in_batch_ptr =
|
||||
vector_in_batch + block_start_index;
|
||||
for (int c = 0; c < kBlockSize; c++) {
|
||||
dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
|
||||
}
|
||||
}
|
||||
result[batch * m_rows + row] += dot_prod;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
|
||||
const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
|
||||
int n_batch, const int32_t input_offset, const int32_t output_multiplier,
|
||||
const int32_t output_shift, const int32_t output_offset,
|
||||
const int32_t output_activation_min, const int32_t output_activation_max,
|
||||
int8_t* __restrict__ result) {
|
||||
const int kBlockSize = 16;
|
||||
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
const int8_t* matrix_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; ++row) {
|
||||
int32_t dot_prod = 0;
|
||||
const int8_t* vector_in_batch = vector + batch * m_cols;
|
||||
for (int i = segments[row]; i < segments[row + 1]; ++i) {
|
||||
const int block_start_index = indices[i] * kBlockSize;
|
||||
const int8_t* vector_block_in_batch_ptr =
|
||||
vector_in_batch + block_start_index;
|
||||
for (int c = 0; c < kBlockSize; c++) {
|
||||
dot_prod += *matrix_ptr * *vector_block_in_batch_ptr++;
|
||||
dot_prod += *matrix_ptr++ * input_offset;
|
||||
}
|
||||
}
|
||||
const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0;
|
||||
dot_prod = MultiplyByQuantizedMultiplier(dot_prod + bias_value,
|
||||
output_multiplier, output_shift);
|
||||
dot_prod += output_offset;
|
||||
result[batch * m_rows + row] =
|
||||
static_cast<int8_t>(ActivationFunctionWithMinMax(
|
||||
dot_prod, output_activation_min, output_activation_max));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
float* __restrict__ result) {
|
||||
const int kBlockSize = 16;
|
||||
TFLITE_DCHECK_EQ( // NOLINT
|
||||
m_cols % kBlockSize, 0);
|
||||
for (int batch = 0; batch < n_batch; batch++) {
|
||||
const float* matrix_ptr = matrix;
|
||||
const uint8_t* ledger_ptr = ledger;
|
||||
for (int row = 0; row < m_rows; row++) {
|
||||
float dot_prod = 0.0f;
|
||||
int num_nonzero_blocks = *ledger_ptr++;
|
||||
if (num_nonzero_blocks > 0) {
|
||||
const float* vector_in_batch = vector + batch * m_cols;
|
||||
for (int i = 0; i < num_nonzero_blocks; i++) {
|
||||
const int block_start_index = *ledger_ptr++ * kBlockSize;
|
||||
const float* vector_block_in_batch_ptr =
|
||||
vector_in_batch + block_start_index;
|
||||
for (int c = 0; c < kBlockSize; c++) {
|
||||
dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
result[batch * m_rows + row] += dot_prod;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
|
||||
const int m_cols, const int8_t* __restrict__ vectors,
|
||||
const float* scaling_factors, int n_batch, float* __restrict__ result) {
|
||||
static const int kBlockSize = 16;
|
||||
TFLITE_DCHECK_EQ( // NOLINT
|
||||
m_cols % kBlockSize, 0);
|
||||
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
|
||||
const float batch_scaling_factor = scaling_factors[batch];
|
||||
const uint8_t* ledger_ptr = ledger;
|
||||
// Get the address of the first row.
|
||||
const int8_t* row_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; ++row) {
|
||||
// Initialize the dot product sum for the row to 0.
|
||||
int32_t dotprod = 0;
|
||||
#if defined(__GNUC__)
|
||||
// Prefetch the row to cache.
|
||||
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
|
||||
3 /* temporal locality */);
|
||||
#endif
|
||||
int num_nonzero_blocks = *ledger_ptr++;
|
||||
for (int i = 0; i < num_nonzero_blocks; i++) {
|
||||
const int block_start_index = *ledger_ptr++ * kBlockSize;
|
||||
const int8_t* vector_block_ptr = vectors + block_start_index;
|
||||
for (int c = 0; c < kBlockSize; c++) {
|
||||
dotprod += (*row_ptr++) * (*vector_block_ptr++);
|
||||
} // for block
|
||||
} // for num_nonzero_blocks
|
||||
result[batch * m_rows + row] += dotprod * batch_scaling_factor;
|
||||
} // for row
|
||||
} // for batch
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PortableMatrixBatchVectorMultiplyAccumulateImpl(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
T* output) {
|
||||
const int16_t output_max = std::numeric_limits<T>::max();
|
||||
const int16_t output_min = std::numeric_limits<T>::min();
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int row = 0; row < n_output; ++row) {
|
||||
int32_t acc = bias[row];
|
||||
for (int col = 0; col < n_input; ++col) {
|
||||
int8_t input_val = input[batch * n_input + col];
|
||||
int8_t weights_val = input_to_gate_weights[row * n_input + col];
|
||||
acc += input_val * weights_val;
|
||||
}
|
||||
acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
|
||||
acc += output_zp;
|
||||
acc += output[batch * n_output + row];
|
||||
if (acc > output_max) {
|
||||
acc = output_max;
|
||||
}
|
||||
if (acc < output_min) {
|
||||
acc = output_min;
|
||||
}
|
||||
output[batch * n_output + row] = static_cast<T>(acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
int32_t* scratch, int16_t* output, CpuBackendContext* context) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulateImpl(
|
||||
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
|
||||
n_output, output_zp, output);
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
int32_t* scratch, int8_t* output, CpuBackendContext* context) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulateImpl(
|
||||
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
|
||||
n_output, output_zp, output);
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiply(const int8_t* input,
|
||||
int32_t input_zeropoint,
|
||||
const int8_t* input_to_gate_weights,
|
||||
int32_t input_to_gate_effective_scale_a,
|
||||
int32_t input_to_gate_effective_scale_b,
|
||||
int32_t n_batch, int32_t n_input,
|
||||
int32_t n_cell, int8_t* gate_output,
|
||||
int8_t gate_output_zp) {
|
||||
const int32_t int8_max = std::numeric_limits<int8_t>::max();
|
||||
const int32_t int8_min = std::numeric_limits<int8_t>::min();
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int row = 0; row < n_cell; ++row) {
|
||||
int32_t acc = 0;
|
||||
for (int col = 0; col < n_input; ++col) {
|
||||
int32_t input_val = input[batch * n_input + col];
|
||||
int8_t weights_val = input_to_gate_weights[row * n_input + col];
|
||||
acc += (input_val - input_zeropoint) * weights_val;
|
||||
}
|
||||
acc = MultiplyByQuantizedMultiplier(acc, input_to_gate_effective_scale_a,
|
||||
input_to_gate_effective_scale_b);
|
||||
acc += gate_output_zp;
|
||||
if (acc > int8_max) {
|
||||
acc = int8_max;
|
||||
}
|
||||
if (acc < int8_min) {
|
||||
acc = int8_min;
|
||||
}
|
||||
gate_output[batch * n_cell + row] = static_cast<int8_t>(acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableMatrixBatchVectorMultiply(
|
||||
const int16_t* hidden, const int8_t* hidden_to_output_weights,
|
||||
int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
|
||||
const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
|
||||
int32_t n_output, int32_t output_zp, int8_t* proj_output) {
|
||||
const int16_t int8_max = std::numeric_limits<int8_t>::max();
|
||||
const int16_t int8_min = std::numeric_limits<int8_t>::min();
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int row = 0; row < n_output; ++row) {
|
||||
int64_t acc = gate_bias[row];
|
||||
for (int col = 0; col < n_hidden; ++col) {
|
||||
int16_t input_val = hidden[batch * n_hidden + col];
|
||||
int8_t weights_val = hidden_to_output_weights[row * n_hidden + col];
|
||||
int64_t curr = acc;
|
||||
acc += input_val * weights_val;
|
||||
if (input_val * weights_val > 0 && acc < curr) {
|
||||
acc = std::numeric_limits<int32_t>::max();
|
||||
}
|
||||
if (input_val * weights_val < 0 && acc > curr) {
|
||||
acc = std::numeric_limits<int32_t>::min();
|
||||
}
|
||||
}
|
||||
acc = MultiplyByQuantizedMultiplier(acc, proj_effective_scale_a,
|
||||
proj_effective_scale_b);
|
||||
acc += output_zp;
|
||||
if (acc > int8_max) {
|
||||
acc = int8_max;
|
||||
}
|
||||
if (acc < int8_min) {
|
||||
acc = int8_min;
|
||||
}
|
||||
proj_output[batch * n_output + row] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableApplyLayerNorm(const int16_t* input,
|
||||
const int16_t* layer_norm_weights,
|
||||
const int32_t* bias, int32_t layer_norm_scale_a,
|
||||
int32_t layer_norm_scale_b, int32_t variance_limit,
|
||||
int n_batch, int n_input, int16_t* output) {
|
||||
// The square of std::pow(2, 10), which is the extra factor that makes sure
|
||||
// normalized values has enough resolution.
|
||||
static const int kTwoToPower20 = 1 << 20;
|
||||
for (int i = 0; i < n_batch; ++i) {
|
||||
int64_t sum = 0;
|
||||
int64_t sum_sq = 0;
|
||||
for (int j = 0; j < n_input; ++j) {
|
||||
const int32_t index = i * n_input + j;
|
||||
int32_t val = static_cast<int32_t>(input[index]);
|
||||
sum += val;
|
||||
sum_sq += val * val;
|
||||
}
|
||||
int32_t mean =
|
||||
static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
|
||||
// TODO(b/173994730): Avoids overflow but only works for POT n_input.
|
||||
int32_t temp = kTwoToPower20 / n_input;
|
||||
int64_t variance =
|
||||
sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
|
||||
int32_t variance2 = static_cast<int32_t>(variance / kTwoToPower20);
|
||||
if (variance2 < 1) {
|
||||
variance2 = variance_limit;
|
||||
}
|
||||
int32_t stddev_inverse_a;
|
||||
int stddev_inverse_b;
|
||||
GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1,
|
||||
&stddev_inverse_a, &stddev_inverse_b);
|
||||
|
||||
for (int j = 0; j < n_input; ++j) {
|
||||
const int32_t index = i * n_input + j;
|
||||
int32_t val = static_cast<int32_t>(input[index]);
|
||||
int32_t shifted = 1024 * val - mean;
|
||||
int32_t rescaled = MultiplyByQuantizedMultiplier(
|
||||
shifted, stddev_inverse_a, stddev_inverse_b);
|
||||
int64_t val3 = rescaled * layer_norm_weights[j] + bias[j];
|
||||
int32_t val4 =
|
||||
static_cast<int32_t>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
|
||||
int32_t val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
|
||||
layer_norm_scale_b + 12);
|
||||
val5 = std::min(std::max(kInt16Min, val5), kInt16Max);
|
||||
output[index] = static_cast<int16_t>(val5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableApplyLayerNormFloat(const int16_t* input,
|
||||
const int16_t* layer_norm_weights,
|
||||
int32_t layer_norm_scale_a,
|
||||
int32_t layer_norm_scale_b,
|
||||
const int32_t* bias, int n_batch, int n_input,
|
||||
int16_t* output) {
|
||||
const int32_t int16_max = std::numeric_limits<int16_t>::max();
|
||||
const int32_t int16_min = std::numeric_limits<int16_t>::min();
|
||||
const float layer_norm_scale =
|
||||
layer_norm_scale_a *
|
||||
std::pow(2.0, static_cast<double>(layer_norm_scale_b - 31));
|
||||
const float bias_scale =
|
||||
static_cast<float>(std::pow(2.0, -10)) * layer_norm_scale;
|
||||
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
float sum = 0.0f;
|
||||
float sum_sq = 0.0f;
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
const float value = static_cast<float>(input[index]);
|
||||
sum += value;
|
||||
sum_sq += value * value;
|
||||
}
|
||||
const float mean = sum / n_input;
|
||||
float stddev_inv = 0.0f;
|
||||
const float variance = sum_sq / n_input - mean * mean;
|
||||
if (variance == 0) {
|
||||
stddev_inv = 1.0f / std::sqrt(1e-8f);
|
||||
} else {
|
||||
stddev_inv = 1.0f / std::sqrt(variance);
|
||||
}
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
const float normalized_value =
|
||||
(static_cast<float>(input[index]) - mean) * stddev_inv;
|
||||
const float weighted_normalized_value =
|
||||
normalized_value * layer_norm_weights[i] * layer_norm_scale +
|
||||
bias[i] * bias_scale;
|
||||
const int32_t quant_output = static_cast<int32_t>(round(
|
||||
weighted_normalized_value * static_cast<float>(std::pow(2, 12))));
|
||||
output[index] = std::min(int16_max, std::max(int16_min, quant_output));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
|
||||
int32_t scalar, int32_t n_row,
|
||||
int32_t n_col, int32_t* output) {
|
||||
for (int i = 0; i < n_row; ++i) {
|
||||
int32_t row_sum = 0;
|
||||
for (int j = 0; j < n_col; ++j) {
|
||||
row_sum += *matrix++;
|
||||
}
|
||||
output[i] += row_sum * scalar;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int16_t* output) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int c = 0; c < n_input; c++) {
|
||||
using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
|
||||
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
|
||||
const int index = batch * n_input + c;
|
||||
F3 sigmoid_input = F3::FromRaw(input[index]);
|
||||
F0 sigmoid_output = gemmlowp::logistic(sigmoid_input);
|
||||
output[index] = sigmoid_output.raw();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int16_t* output) {
|
||||
const int32_t int16_max = std::numeric_limits<int16_t>::max();
|
||||
const int32_t int16_min = std::numeric_limits<int16_t>::min();
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
const float float_input =
|
||||
input[index] * static_cast<float>(std::pow(2, -12));
|
||||
const float float_output = 1.0f / (1.0f + std::exp(-float_input));
|
||||
const int32_t quant_output = static_cast<int32_t>(
|
||||
float_output * static_cast<float>(std::pow(2, 15)));
|
||||
const int32_t quant_output_clamped =
|
||||
std::min(int16_max, std::max(int16_min, quant_output));
|
||||
output[index] = static_cast<int16_t>(quant_output_clamped);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int IntegerBits>
|
||||
void PortableApplyTanhImpl(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int16_t* output) {
|
||||
using FX = gemmlowp::FixedPoint<std::int16_t, IntegerBits>;
|
||||
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
FX tanh_input = FX::FromRaw(input[index]);
|
||||
F0 tanh_output = gemmlowp::tanh(tanh_input);
|
||||
output[index] = tanh_output.raw();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
|
||||
int32_t n_batch, int32_t n_input, int16_t* output) {
|
||||
if (integer_bits > 6) {
|
||||
TFLITE_ASSERT_FALSE;
|
||||
}
|
||||
#define DISPATCH_TANH(i) \
|
||||
case i: \
|
||||
PortableApplyTanhImpl<i>(input, n_batch, n_input, output); \
|
||||
break;
|
||||
switch (integer_bits) {
|
||||
DISPATCH_TANH(0);
|
||||
DISPATCH_TANH(1);
|
||||
DISPATCH_TANH(2);
|
||||
DISPATCH_TANH(3);
|
||||
DISPATCH_TANH(4);
|
||||
DISPATCH_TANH(5);
|
||||
DISPATCH_TANH(6);
|
||||
default:
|
||||
return;
|
||||
}
|
||||
#undef DISPATCH_TANH
|
||||
}
|
||||
|
||||
void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int32_t integer_bits,
|
||||
int16_t* output) {
|
||||
const int32_t int16_max = std::numeric_limits<int16_t>::max();
|
||||
const int32_t int16_min = std::numeric_limits<int16_t>::min();
|
||||
const double two = 2.0;
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
const float float_input =
|
||||
input[index] * std::pow(two, static_cast<double>(integer_bits));
|
||||
const float float_output = std::tanh(float_input);
|
||||
const int32_t quant_output = static_cast<int32_t>(
|
||||
float_output * static_cast<float>(std::pow(2, 15)));
|
||||
const int32_t quant_output_clamped =
|
||||
std::min(int16_max, std::max(int16_min, quant_output));
|
||||
output[index] = static_cast<int16_t>(quant_output_clamped);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int shift, int16_t* output) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
const int16_t a = input_1[index];
|
||||
const int16_t b = input_2[index];
|
||||
const int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
|
||||
output[index] =
|
||||
static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(value, shift));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
int32_t multiplier, int32_t shift, int32_t n_batch,
|
||||
int32_t n_input, int32_t output_zp, int8_t* output) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
const int16_t a = input_1[index];
|
||||
const int16_t b = input_2[index];
|
||||
int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
|
||||
value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
|
||||
value -= output_zp;
|
||||
value = std::min(std::max(static_cast<int32_t>(-128), value),
|
||||
static_cast<int32_t>(127));
|
||||
|
||||
output[index] = static_cast<int8_t>(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int16_t* output) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
int32_t sum = input_1[index] + input_2[index];
|
||||
const int32_t sum_clamped = std::min(kInt16Max, std::max(kInt16Min, sum));
|
||||
output[index] = static_cast<int16_t>(sum_clamped);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
int v_size) {
|
||||
float result = 0.0;
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result += *vector1++ * *vector2++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline int32_t VectorVectorDotProduct(const int16_t* vector1,
|
||||
const int16_t* vector2, int v_size) {
|
||||
int32_t result = 0;
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result += *vector1++ * *vector2++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
|
||||
const int16_t* vector2,
|
||||
int v_size, int n_batch,
|
||||
int32_t* result) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
result[b] = VectorVectorDotProduct(vector1, vector2, v_size);
|
||||
vector1 += v_size;
|
||||
vector2 += v_size;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableVectorBatchVectorCwiseProductAccumulate(
|
||||
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
|
||||
int32_t multiplier, int shift, int16_t* result) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
int32_t prod = vector[v] * *batch_vector++;
|
||||
prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
|
||||
int32_t output = prod + *result;
|
||||
output = std::max(std::min(static_cast<int32_t>(32767), output),
|
||||
static_cast<int32_t>(-32768));
|
||||
*result++ = output;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableSub1Vector(const float* vector, int v_size, float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
*result++ = 1.0f - *vector++;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
|
||||
static const int16_t kOne = 32767;
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
*result++ = kOne - *vector++;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
|
||||
const float scale, float* result) {
|
||||
for (int v = 0; v < v_size; ++v) {
|
||||
*result++ = scale * *vector++;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
|
||||
float* __restrict__ output_vector,
|
||||
int v_size, int n_batch) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < v_size; ++i) {
|
||||
sum += input_vector[i];
|
||||
}
|
||||
const float mean = sum / v_size;
|
||||
float sum_diff_sq = 0.0f;
|
||||
for (int i = 0; i < v_size; ++i) {
|
||||
const float diff = input_vector[i] - mean;
|
||||
sum_diff_sq += diff * diff;
|
||||
}
|
||||
const float variance = sum_diff_sq / v_size;
|
||||
constexpr float kNormalizationConstant = 1e-8f;
|
||||
const float stddev_inv =
|
||||
1.0f / std::sqrt(variance + kNormalizationConstant);
|
||||
for (int i = 0; i < v_size; ++i) {
|
||||
output_vector[i] = (input_vector[i] - mean) * stddev_inv;
|
||||
}
|
||||
input_vector += v_size;
|
||||
output_vector += v_size;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
||||
const int8_t* recurrent, int8_t recurrent_zp,
|
||||
int32_t input_effective_scale_a,
|
||||
int32_t input_effective_scale_b,
|
||||
int32_t recurrent_effective_scale_a,
|
||||
int32_t recurrent_effective_scale_b,
|
||||
int32_t n_batch, int32_t n_cell,
|
||||
int16_t* output) {
|
||||
const int32_t int16_max = std::numeric_limits<int16_t>::max();
|
||||
const int32_t int16_min = std::numeric_limits<int16_t>::min();
|
||||
for (int i = 0; i < n_batch * n_cell; ++i) {
|
||||
int32_t x = static_cast<int32_t>(input[i]) - static_cast<int32_t>(input_zp);
|
||||
int32_t h =
|
||||
static_cast<int32_t>(recurrent[i]) - static_cast<int32_t>(recurrent_zp);
|
||||
int32_t x_scaled = MultiplyByQuantizedMultiplier(x, input_effective_scale_a,
|
||||
input_effective_scale_b);
|
||||
int32_t h_scaled = MultiplyByQuantizedMultiplier(
|
||||
h, recurrent_effective_scale_a, recurrent_effective_scale_b);
|
||||
int32_t y = h_scaled + x_scaled;
|
||||
if (y > int16_max) {
|
||||
y = int16_max;
|
||||
}
|
||||
if (y < int16_min) {
|
||||
y = int16_min;
|
||||
}
|
||||
output[i] = static_cast<int16_t>(y);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace micro_tensor_utils
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,874 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file and the associated .cc file is branched from
|
||||
// tensorflow/lite/kernels/internal/reference/portable_tensor_utils*
|
||||
// TFLM needs to create its own because the original files are coupled with
|
||||
// the tensor_utils module, which we cannot reuse due to its use of the
|
||||
// Eigen library.
|
||||
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define __restrict__ __restrict
|
||||
#endif
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Not all backends support CpuBackendContext usage, so forward declare to avoid
|
||||
// pulling in its implementation.
|
||||
// TODO(b/230666277): consider removing this since micro does not utilize it
|
||||
class CpuBackendContext;
|
||||
|
||||
namespace micro_tensor_utils {
|
||||
|
||||
template <typename T>
|
||||
inline bool PortableIsZeroVector(const T* vector, int v_size) {
|
||||
for (int i = 0; i < v_size; ++i) {
|
||||
if (vector[i] != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PortableSymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float* min_value,
|
||||
float* max_value, float* scaling_factor);
|
||||
|
||||
void PortableSymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float min_value,
|
||||
float max_value, float* scaling_factor);
|
||||
|
||||
void PortableAsymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values,
|
||||
float* scaling_factor, int32_t* offset);
|
||||
|
||||
// Multiply a matrix by a batch vector, and store results in a batch-size
|
||||
// vector.
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
|
||||
int m_rows, int m_cols,
|
||||
const float* vector,
|
||||
int n_batch, float* result);
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vectors, const float* scaling_factors,
|
||||
int n_batch, float* __restrict__ result);
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vectors, const float* scaling_factors,
|
||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||
bool* compute_row_sums, CpuBackendContext* context);
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vector, const float* scaling_factors,
|
||||
int n_batch, int32_t* scratch, float* __restrict__ result,
|
||||
CpuBackendContext* context);
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
float* __restrict__ result);
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
|
||||
const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
|
||||
int n_batch, const int32_t input_offset, const int32_t output_multiplier,
|
||||
const int32_t output_shift, const int32_t output_offset,
|
||||
const int32_t output_activation_min, const int32_t output_activation_max,
|
||||
int8_t* __restrict__ result);
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
|
||||
const int m_cols, const int8_t* __restrict__ vectors,
|
||||
const float* scaling_factors, int n_batch, float* __restrict__ result);
|
||||
|
||||
// Dot product of two vectors.
|
||||
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
int v_size);
|
||||
|
||||
void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
|
||||
const int16_t* vector2,
|
||||
int v_size, int n_batch,
|
||||
int32_t* result);
|
||||
|
||||
void PortableVectorBatchVectorCwiseProductAccumulate(
|
||||
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
|
||||
int32_t multiplier, int shift, int16_t* result);
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
int32_t* scratch, int16_t* output, CpuBackendContext* context);
|
||||
|
||||
void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
int32_t* scratch, int8_t* output, CpuBackendContext* context);
|
||||
|
||||
void PortableMatrixBatchVectorMultiply(const int8_t* input,
|
||||
int32_t input_zeropoint,
|
||||
const int8_t* input_to_gate_weights,
|
||||
int32_t input_to_gate_effective_scale_a,
|
||||
int32_t input_to_gate_effective_scale_b,
|
||||
int32_t n_batch, int32_t n_input,
|
||||
int32_t n_cell, int8_t* gate_output,
|
||||
int8_t gate_output_zp);
|
||||
|
||||
void PortableMatrixBatchVectorMultiply(
|
||||
const int16_t* hidden, const int8_t* hidden_to_output_weights,
|
||||
int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
|
||||
const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
|
||||
int32_t n_output, int32_t output_zp, int8_t* proj_output);
|
||||
|
||||
void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
|
||||
int32_t scalar, int32_t n_row,
|
||||
int32_t n_col, int32_t* output);
|
||||
|
||||
void PortableApplyLayerNorm(const int16_t* input,
|
||||
const int16_t* layer_norm_weights,
|
||||
const int32_t* bias, int32_t layer_norm_scale_a,
|
||||
int32_t layer_norm_scale_b, int32_t variance_limit,
|
||||
int n_batch, int n_input, int16_t* output);
|
||||
|
||||
void PortableApplyLayerNormFloat(const int16_t* input,
|
||||
const int16_t* layer_norm_weights,
|
||||
int32_t layer_norm_scale_a,
|
||||
int32_t layer_norm_scale_b,
|
||||
const int32_t* bias, int n_batch, int n_input,
|
||||
int16_t* output);
|
||||
|
||||
void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int16_t* output);
|
||||
|
||||
void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int16_t* output);
|
||||
|
||||
void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
|
||||
int32_t n_batch, int32_t n_input, int16_t* output);
|
||||
|
||||
void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int32_t integer_bits,
|
||||
int16_t* output);
|
||||
|
||||
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int shift, int16_t* output);
|
||||
|
||||
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
int32_t multiplier, int32_t shift, int32_t n_batch,
|
||||
int32_t n_input, int32_t output_zp, int8_t* output);
|
||||
|
||||
void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int16_t* output);
|
||||
|
||||
template <typename T>
|
||||
inline void PortableCwiseClipping(T* vector, const int v_size,
|
||||
const T& clipping_value) {
|
||||
for (int i = 0; i < v_size; i++) {
|
||||
vector[i] = std::max(std::min(clipping_value, vector[i]),
|
||||
static_cast<T>(-clipping_value));
|
||||
}
|
||||
}
|
||||
|
||||
// Batch vector initialization with another vector.
|
||||
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
|
||||
int n_batch, float* batch_vector);
|
||||
|
||||
// Compute "1.0f - elements of vector" (used in CIFG).
|
||||
void PortableSub1Vector(const float* vector, int v_size, float* result);
|
||||
|
||||
void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result);
|
||||
|
||||
// Multiply all elements of vector with a scalar.
|
||||
void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result);
|
||||
|
||||
// Reduce-sum on a vector:
|
||||
// input_vector: pointer to input vector.
|
||||
// output_vector: pointer to vector.
|
||||
// output_size: output vector size.
|
||||
// reduction_size: number of consecutive elements from input vector which are
|
||||
// added to get one element of output.
|
||||
template <typename INPUT, typename OUTPUT>
|
||||
inline void PortableReductionSumVector(const INPUT* input_vector,
|
||||
OUTPUT* output_vector, int output_size,
|
||||
int reduction_size) {
|
||||
for (int o = 0; o < output_size; o++) {
|
||||
OUTPUT result = 0;
|
||||
for (int r = 0; r < reduction_size; r++) {
|
||||
result += input_vector[r];
|
||||
}
|
||||
output_vector[o] = result;
|
||||
input_vector += reduction_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Layer norm for each batch.
|
||||
void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
|
||||
float* __restrict__ output_vector,
|
||||
int v_size, int n_batch);
|
||||
|
||||
// Saturate Add.
|
||||
void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
||||
const int8_t* recurrent, int8_t recurrent_zp,
|
||||
int32_t input_effective_scale_a,
|
||||
int32_t input_effective_scale_b,
|
||||
int32_t recurrent_effective_scale_a,
|
||||
int32_t recurrent_effective_scale_b,
|
||||
int32_t n_batch, int32_t n_cell,
|
||||
int16_t* output);
|
||||
|
||||
// Add another vector for each batch in the batch vector.
|
||||
template <typename T>
|
||||
inline void VectorBatchVectorAdd(const T* vector, int v_size, int n_batch,
|
||||
T* batch_vector) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
for (int i = 0; i < v_size; ++i) {
|
||||
batch_vector[i] += vector[i];
|
||||
}
|
||||
batch_vector += v_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Cwise product of two vectors.
|
||||
template <typename T>
|
||||
inline void VectorVectorCwiseProduct(const T* vector1, const T* vector2,
|
||||
int v_size, T* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
*result++ = *vector1++ * *vector2++;
|
||||
}
|
||||
}
|
||||
|
||||
// Cwise product of a vector and a batch-vector.
|
||||
template <typename T>
|
||||
inline void VectorBatchVectorCwiseProduct(const T* vector, int v_size,
|
||||
const T* batch_vector, int n_batch,
|
||||
T* result) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
VectorVectorCwiseProduct(vector, batch_vector, v_size, result);
|
||||
// Update the pointers.
|
||||
result += v_size;
|
||||
batch_vector += v_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce-sum on a float input vector:
|
||||
// input_vector: float pointer to input vector.
|
||||
// output_vector: float pointer to vector.
|
||||
// output_size: output vector size.
|
||||
// reduction_size: number of consecutive elements from input vector which are
|
||||
// added to get one element of output.
|
||||
inline void ReductionSumVector(const float* input_vector, float* output_vector,
|
||||
int output_size, int reduction_size) {
|
||||
PortableReductionSumVector(input_vector, output_vector, output_size,
|
||||
reduction_size);
|
||||
}
|
||||
|
||||
// Same as above but input/output is 32 bit integer.
|
||||
inline void ReductionSumVector(const int32_t* input_vector,
|
||||
int32_t* output_vector, int output_size,
|
||||
int reduction_size) {
|
||||
PortableReductionSumVector(input_vector, output_vector, output_size,
|
||||
reduction_size);
|
||||
}
|
||||
|
||||
// Same as above but input is 8 bit integer.
|
||||
inline void ReductionSumVector(const int8_t* input_vector,
|
||||
int32_t* output_vector, int output_size,
|
||||
int reduction_size) {
|
||||
PortableReductionSumVector(input_vector, output_vector, output_size,
|
||||
reduction_size);
|
||||
}
|
||||
|
||||
// Cwise product and accumulate of two vectors. Since it's a MAC operation, the
|
||||
// assumption here is that result array is initialized to valid values.
|
||||
template <typename T>
|
||||
inline void VectorVectorCwiseProductAccumulate(const T* __restrict__ vector1,
|
||||
const T* __restrict__ vector2,
|
||||
int v_size,
|
||||
T* __restrict__ result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
*result++ += *vector1++ * *vector2++;
|
||||
}
|
||||
}
|
||||
|
||||
// Batch vector initialization with another vector.
|
||||
template <typename T>
|
||||
inline void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch,
|
||||
T* batch_vector) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(vector, v_size, batch_vector + b * v_size);
|
||||
}
|
||||
}
|
||||
|
||||
inline void SymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float* min,
|
||||
float* max, float* scaling_factor) {
|
||||
PortableSymmetricQuantizeFloats(values, size, quantized_values, min, max,
|
||||
scaling_factor);
|
||||
}
|
||||
|
||||
inline void SymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float min_value,
|
||||
float max_value, float* scaling_factor) {
|
||||
PortableSymmetricQuantizeFloats(values, size, quantized_values, min_value,
|
||||
max_value, scaling_factor);
|
||||
}
|
||||
|
||||
inline void AsymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values,
|
||||
float* scaling_factor, int32_t* offset) {
|
||||
PortableAsymmetricQuantizeFloats(values, size, quantized_values,
|
||||
scaling_factor, offset);
|
||||
}
|
||||
|
||||
// Helper function to quantize floats.
|
||||
// float_data_ptr input float vectors
|
||||
// n_batch number of input vectors
|
||||
// n_data size of a single input vector
|
||||
// quantized_data_ptr (out) vector with quantized data
|
||||
// scaling_factors (out) scaling factors (one per vector)
|
||||
// zero_points (out) zero points (one per vector)
|
||||
// do_asymmetric controls if the quantization should be asymmetric.
|
||||
inline void BatchQuantizeFloats(const float* float_data_ptr, int n_batch,
|
||||
int n_data, int8_t* quantized_data_ptr,
|
||||
float* scaling_factors, int32_t* zero_points,
|
||||
bool do_asymmetric) {
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
const int offset = b * n_data;
|
||||
if (do_asymmetric) {
|
||||
AsymmetricQuantizeFloats(float_data_ptr + offset, n_data,
|
||||
quantized_data_ptr + offset, &scaling_factors[b],
|
||||
&zero_points[b]);
|
||||
} else {
|
||||
float unused_min, unused_max;
|
||||
SymmetricQuantizeFloats(float_data_ptr + offset, n_data,
|
||||
quantized_data_ptr + offset, &unused_min,
|
||||
&unused_max, &scaling_factors[b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all entries of a vector are zero for float.
|
||||
inline bool IsZeroVector(const float* vector, int v_size) {
|
||||
return PortableIsZeroVector(vector, v_size);
|
||||
}
|
||||
|
||||
// Check if all entries of a vector are zero for int8_t.
|
||||
inline bool IsZeroVector(const int8_t* vector, int v_size) {
|
||||
return PortableIsZeroVector(vector, v_size);
|
||||
}
|
||||
|
||||
// Apply Layer Normalization (https://arxiv.org/abs/1607.06450) to a Quantized
|
||||
// vector.
|
||||
// Parameters:
|
||||
// - input: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - layer_norm_weights: the quantized layer normalization weights.
|
||||
// - bias: the bias for the layer normalization.
|
||||
// - layer_norm_scale_a: multiplier for scale factor.
|
||||
// - layer_norm_scale_b: shift for scale factor.
|
||||
// - variance_limit: the guard to make sure the inverse does not overflow.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
// - output: the 16 bit output
|
||||
inline void ApplyLayerNorm(const int16_t* input,
|
||||
const int16_t* layer_norm_weights,
|
||||
const int32_t* bias, int32_t layer_norm_scale_a,
|
||||
int32_t layer_norm_scale_b, int32_t variance_limit,
|
||||
int n_batch, int n_input, int16_t* output) {
|
||||
PortableApplyLayerNorm(input, layer_norm_weights, bias, layer_norm_scale_a,
|
||||
layer_norm_scale_b, variance_limit, n_batch, n_input,
|
||||
output);
|
||||
}
|
||||
|
||||
// Same as above but the internal calculation is done in float.
|
||||
inline void ApplyLayerNormFloat(const int16_t* input,
|
||||
const int16_t* layer_norm_weights,
|
||||
int32_t layer_norm_scale_a,
|
||||
int32_t layer_norm_scale_b, const int32_t* bias,
|
||||
int n_batch, int n_input, int16_t* output) {
|
||||
PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
|
||||
layer_norm_scale_b, bias, n_batch, n_input,
|
||||
output);
|
||||
}
|
||||
|
||||
// Apply Sigmoid to a quantized vector.
|
||||
// Parameters:
|
||||
// - input: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
// - output: the 16 bit output
|
||||
// The input is in Q3.12 format and the output is in Q0.15 format.
|
||||
inline void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
|
||||
int16_t* output) {
|
||||
PortableApplySigmoid(input, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
// Same as above but the internal calcualtion is float.
|
||||
inline void ApplySigmoidFloat(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int16_t* output) {
|
||||
PortableApplySigmoidFloat(input, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
// Apply Tanh to a quantized vector.
|
||||
// Parameters:
|
||||
// - integer_bits: the integer bits of the input.
|
||||
// Currently supports 0, 1, 2, 3, 4, 5, 6.
|
||||
// - input: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
// - output: the 16 bit output
|
||||
// The input is in Qm.15-m format and the output is in Q0.15 format.
|
||||
inline void ApplyTanh(int32_t integer_bits, const int16_t* input,
|
||||
int32_t n_batch, int32_t n_input, int16_t* output) {
|
||||
PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
// Apply Tanh to a quantized vector. Tbe internal calculation is in float.
|
||||
// - Input has 2^(integer_bits) as scale.
|
||||
// - Output has Q0.15 as scale.
|
||||
inline void ApplyTanhFloat(const int16_t* input, int32_t n_batch,
|
||||
int32_t n_input, int32_t integer_bits,
|
||||
int16_t* output) {
|
||||
PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
|
||||
}
|
||||
|
||||
// Element-wise multiplication of two quantized vectors.
|
||||
// Parameters:
|
||||
// - input_1: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - input_2: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
// - shift: the shift needed to produce the output.
|
||||
// - output: the 16 bit output of size n_batch * n_input.
|
||||
// Output does not need to be initialized.
|
||||
inline void CwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int shift, int16_t* output) {
|
||||
PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
|
||||
}
|
||||
|
||||
// Element-wise multiplication of two quantized vectors with rescaling.
|
||||
// Parameters:
|
||||
// - input_1: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - input_2: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - multiplier: the multiplier part of scale.
|
||||
// - shift: the shift part of scale.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
// - output: the 8 bit output of size n_batch * n_input.
|
||||
// - output_zp: the zero point of output.
|
||||
// Output does not need to be initialized.
|
||||
// Multiplier ("m") and shift ("s") are connected to scale ("s") with s = m *
|
||||
// 2^(s - 31).
|
||||
inline void CwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
int32_t multiplier, int32_t shift, int32_t n_batch,
|
||||
int32_t n_input, int32_t output_zp, int8_t* output) {
|
||||
PortableCwiseMul(input_1, input_2, multiplier, shift, n_batch, n_input,
|
||||
output_zp, output);
|
||||
}
|
||||
|
||||
// Element-wise in-place clipping of a vector. Overloaded for float, int16_t,
|
||||
// int8_t. Parameters:
|
||||
// - vector: vector of size v_size.
|
||||
// - v_size: the size of the vector.
|
||||
// - clipping_value: the value used for clipping.
|
||||
inline void CwiseClipping(float* vector, const int v_size,
|
||||
const float clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
inline void CwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
inline void CwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
// Element-wise saturating addition of two quantized vectors without rescaling.
|
||||
// Parameters:
|
||||
// - input_1: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - input_2: batch vector of size n_batch * n_input; 16 bit.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
// - output: the 8 bit output of size n_batch * n_input.
|
||||
// Output does not need to be initialized.
|
||||
inline void CwiseAdd(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int16_t* output) {
|
||||
PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
inline void MeanStddevNormalization(const float* input_vector,
|
||||
float* output_vector, int v_size,
|
||||
int n_batch) {
|
||||
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
||||
}
|
||||
|
||||
inline void Sub1Vector(const float* vector, int v_size, float* result) {
|
||||
PortableSub1Vector(vector, v_size, result);
|
||||
}
|
||||
|
||||
inline void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
|
||||
PortableSub1Vector(vector, v_size, result);
|
||||
}
|
||||
|
||||
// Multiply all elements of vector with a scalar.
|
||||
inline void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result) {
|
||||
PortableVectorScalarMultiply(vector, v_size, scale, result);
|
||||
}
|
||||
|
||||
// Saturate Add with rescale on both inputs.
|
||||
inline void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
||||
const int8_t* recurrent, int8_t recurrent_zp,
|
||||
int32_t input_effective_scale_a,
|
||||
int32_t input_effective_scale_b,
|
||||
int32_t recurrent_effective_scale_a,
|
||||
int32_t recurrent_effective_scale_b,
|
||||
int32_t n_batch, int32_t n_cell,
|
||||
int16_t* output) {
|
||||
PortableTwoGateSaturatingAdd(
|
||||
input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
|
||||
input_effective_scale_b, recurrent_effective_scale_a,
|
||||
recurrent_effective_scale_b, n_batch, n_cell, output);
|
||||
}
|
||||
|
||||
// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
|
||||
// dimension composed by input vectors independent from each other). The result
|
||||
// of the multiplication is accumulated to the passed result buffer.
|
||||
// More specifically, for a matrix M of shape [n, i] and a batched-vector
|
||||
// of shape [i, batch] it will first compute the product of shape [n, batch].
|
||||
// This product will be accumulated to the result buffer.
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
|
||||
int m_cols, const float* vector,
|
||||
int n_batch, float* result) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
|
||||
n_batch, result);
|
||||
}
|
||||
|
||||
// Same as the function above, but the matrix is a sparse tensor with block
|
||||
// pattern 1x4.
|
||||
// This function assumes that m_cols is a multiple of the block size (4 in this
|
||||
// case) so that there's no incomplete block.
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vector, const float* scaling_factors,
|
||||
int n_batch, float* __restrict__ result) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
|
||||
scaling_factors, n_batch, result);
|
||||
}
|
||||
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vectors, const float* scaling_factors,
|
||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||
bool* compute_row_sums, CpuBackendContext* context) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
|
||||
per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
|
||||
context);
|
||||
}
|
||||
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vector, const float* scaling_factors,
|
||||
int n_batch, int32_t* scratch, float* __restrict__ result,
|
||||
CpuBackendContext* context) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
|
||||
scaling_factors, n_batch, result);
|
||||
}
|
||||
|
||||
// Same as the function above, but the matrix is a sparse tensor with block
|
||||
// pattern 1x4.
|
||||
// This function assumes that m_cols is a multiple of the block size (4 in this
|
||||
// case) so that there's no incomplete block.
|
||||
inline void SparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
|
||||
}
|
||||
|
||||
// Same as the function above, but the matrix is stored in block compressed
|
||||
// sparse row format with block pattern 1x16 which consists of two arrays:
|
||||
// 1. A matrix array stores non-zero blocks of the matrix in row major.
|
||||
// 2. A ledger array stores nrows groups, one group per row. Each group starts
|
||||
// with an integer representing the number of non-zero blocks for the
|
||||
// corresponding row and follows with column indexes of the first element
|
||||
// of each non-zero block.
|
||||
// This function assumes that
|
||||
// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
|
||||
// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
|
||||
inline void SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
float* __restrict__ result) {
|
||||
PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
matrix, ledger, m_rows, m_cols, vector, n_batch, result);
|
||||
}
|
||||
|
||||
// Same as the function above, but the matrix is a sparse tensor with block
|
||||
// pattern 1x16.
|
||||
// This function assumes that m_cols is a multiple of the block size (16 in this
|
||||
// case) so that there's no incomplete block. Also, it assumes all offsets of
|
||||
// input, output and filter are zero.
|
||||
inline void SparseMatrixBatchVectorMultiplyAccumulate1x16(
|
||||
const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
|
||||
int n_batch, const int32_t input_offset, const int32_t output_multiplier,
|
||||
const int32_t output_shift, const int32_t output_offset,
|
||||
const int32_t output_activation_min, const int32_t output_activation_max,
|
||||
int8_t* __restrict__ result) {
|
||||
PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
|
||||
matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch,
|
||||
input_offset, output_multiplier, output_shift, output_offset,
|
||||
output_activation_min, output_activation_max, result);
|
||||
}
|
||||
|
||||
// Same as the function above, but the matrix is stored in block compressed
|
||||
// sparse row format with block pattern 1x16 which consists of two arrays:
|
||||
// 1. A matrix array stores non-zero blocks of the matrix in row major.
|
||||
// 2. A ledger array stores nrows groups, one group per row. Each group starts
|
||||
// with an integer representing the number of non-zero blocks for the
|
||||
// corresponding row followed by column index of the first element of
|
||||
// each non-zero block.
|
||||
// This function assumes that
|
||||
// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
|
||||
// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
|
||||
inline void SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
|
||||
const int m_cols, const int8_t* __restrict__ vectors,
|
||||
const float* scaling_factors, int n_batch, float* __restrict__ result) {
|
||||
PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch,
|
||||
result);
|
||||
}
|
||||
|
||||
// Same as the above 8, 8, 8 integer matmul except for the presence of zero
|
||||
// point and non-accumulative.
|
||||
// TODO(b/148688698): remove this function by folding zero point calculation in
|
||||
// prepare() function.
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
int32_t* scratch, int16_t* output, CpuBackendContext* context) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
|
||||
n_output, output_zp, scratch, output, context);
|
||||
}
|
||||
|
||||
// Same as above but has 16 bit and 8 bit input and 8 bit output.
|
||||
// Used in projection when hidden is 16bit.
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
|
||||
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
|
||||
int32_t* scratch, int8_t* output, CpuBackendContext* context) {
|
||||
PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
|
||||
n_output, output_zp, scratch, output, context);
|
||||
}
|
||||
|
||||
// Same as the function above, but provides separate scaling factor for the
|
||||
// matrix and the vectors. The scaling factors are multiplied in the
|
||||
// scaling_factor_scratch buffer.
|
||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
||||
const int8_t* __restrict__ vectors, const float matrix_scaling_factor,
|
||||
const float* vector_scaling_factors, int n_batch,
|
||||
float* __restrict__ result, const float* per_channel_scale,
|
||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||
bool* compute_row_sums, float* scaling_factor_scratch,
|
||||
CpuBackendContext* context) {
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
scaling_factor_scratch[b] =
|
||||
vector_scaling_factors[b] * matrix_scaling_factor;
|
||||
}
|
||||
MatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
|
||||
scaling_factor_scratch, n_batch, result,
|
||||
per_channel_scale, input_offset, scratch,
|
||||
row_sums, compute_row_sums, context);
|
||||
}
|
||||
|
||||
// Multiplies a matrix with a scalar and reduce the result on each row to a
|
||||
// scalar.
|
||||
// Parameters:
|
||||
// - matrix: matrix of size n_row * n_col
|
||||
// - scalar: the scalar that is multiplied to each element in the matrix
|
||||
// - n_row: the row count of the matrix
|
||||
// - n_col: the column count of the matrix
|
||||
// - output: the 32bit output
|
||||
// Note: We do not need saturation because the int8 * int8 is safe from overflow
|
||||
// in (2^31-1) / (2^14) = 131072, which is bigger than the n_row. Non-zero
|
||||
// initial output value is not exceptionally large.
|
||||
inline void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
|
||||
int32_t n_row, int32_t n_col,
|
||||
int32_t* output) {
|
||||
PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output);
|
||||
}
|
||||
|
||||
// Same as the above 8, 8, 8 integer matmul except for the presence of zero
|
||||
// point and non-accumulative.
|
||||
// TODO(b/148688698): remove this function by folding zero point calculation in
|
||||
// prepare() function.
|
||||
inline void MatrixBatchVectorMultiply(const int8_t* input,
|
||||
int32_t input_zeropoint,
|
||||
const int8_t* input_to_gate_weights,
|
||||
int32_t input_to_gate_effective_scale_a,
|
||||
int32_t input_to_gate_effective_scale_b,
|
||||
int32_t n_batch, int32_t n_input,
|
||||
int32_t n_cell, int8_t* gate_output,
|
||||
int8_t gate_output_zp) {
|
||||
PortableMatrixBatchVectorMultiply(
|
||||
input, input_zeropoint, input_to_gate_weights,
|
||||
input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
|
||||
n_input, n_cell, gate_output, gate_output_zp);
|
||||
}
|
||||
|
||||
// Same as above but has 16 bit and 8 bit input and 8 bit output.
|
||||
// Used in projection when hidden is 16bit.
|
||||
inline void MatrixBatchVectorMultiply(const int16_t* hidden,
|
||||
const int8_t* hidden_to_output_weights,
|
||||
int32_t proj_effective_scale_a,
|
||||
int32_t proj_effective_scale_b,
|
||||
const int32_t* gate_bias, int32_t n_batch,
|
||||
int32_t n_hidden, int32_t n_output,
|
||||
int32_t output_zp, int8_t* proj_output) {
|
||||
PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
|
||||
proj_effective_scale_a,
|
||||
proj_effective_scale_b, gate_bias, n_batch,
|
||||
n_hidden, n_output, output_zp, proj_output);
|
||||
}
|
||||
|
||||
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
|
||||
// operation, the assumption here is that result array is initialized to valid
|
||||
// values.
|
||||
template <typename T>
|
||||
inline void VectorBatchVectorCwiseProductAccumulate(const T* vector, int v_size,
|
||||
const T* batch_vector,
|
||||
int n_batch, T* result) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
VectorVectorCwiseProductAccumulate(vector, batch_vector, v_size, result);
|
||||
// Update the pointers.
|
||||
result += v_size;
|
||||
batch_vector += v_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Same as above, but inputs are 16bit integer and output is 16bit integer.
|
||||
inline void VectorBatchVectorCwiseProductAccumulate(
|
||||
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
|
||||
int32_t multiplier, int shift, int16_t* result) {
|
||||
PortableVectorBatchVectorCwiseProductAccumulate(
|
||||
vector, v_size, batch_vector, n_batch, multiplier, shift, result);
|
||||
}
|
||||
|
||||
// Apply Rectified Linear to elements of a vector.
|
||||
inline void ApplyReluToVector(const float* vector, int v_size, float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = std::max(0.0f, vector[v]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Rectified Linear 1 (cap to [-1;1]) to elements of a vector
|
||||
inline void ApplyRelu1ToVector(const float* vector, int v_size, float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = std::max(-1.0f, std::min(vector[v], 1.0f));
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Rectified Linear 6 (cap to [0;6]) to elements of a vector
|
||||
inline void ApplyRelu6ToVector(const float* vector, int v_size, float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = std::max(0.0f, std::min(vector[v], 6.0f));
|
||||
}
|
||||
}
|
||||
|
||||
// Apply tanh to elements of a vector
|
||||
inline void ApplyTanhToVector(const float* vector, int v_size, float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = std::tanh(vector[v]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply signbit to elements of a vector
|
||||
inline void ApplySignbitToVector(const float* vector, int v_size,
|
||||
float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = std::signbit(vector[v]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply sigmoid to elements of a vector.
|
||||
inline void ApplySigmoidToVector(const float* vector, int v_size,
|
||||
float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = 1.0f / (1.0f + std::exp(-vector[v]));
|
||||
}
|
||||
}
|
||||
|
||||
// Apply appropriate activation function to elements of a vector.
|
||||
inline void ApplyActivationToVector(const float* vector, int v_size,
|
||||
TfLiteFusedActivation activation,
|
||||
float* result) {
|
||||
switch (activation) {
|
||||
case kTfLiteActNone:
|
||||
return;
|
||||
case kTfLiteActRelu:
|
||||
return ApplyReluToVector(vector, v_size, result);
|
||||
case kTfLiteActReluN1To1:
|
||||
return ApplyRelu1ToVector(vector, v_size, result);
|
||||
case kTfLiteActRelu6:
|
||||
return ApplyRelu6ToVector(vector, v_size, result);
|
||||
case kTfLiteActTanh:
|
||||
return ApplyTanhToVector(vector, v_size, result);
|
||||
case kTfLiteActSignBit:
|
||||
return ApplySignbitToVector(vector, v_size, result);
|
||||
case kTfLiteActSigmoid:
|
||||
return ApplySigmoidToVector(vector, v_size, result);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace micro_tensor_utils
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
|
||||
@@ -209,14 +209,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_MIRROR_PAD() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -61,14 +61,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MUL() {
|
||||
return {/*init=*/MulInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/MulPrepare,
|
||||
/*invoke=*/MulEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -51,14 +51,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace neg
|
||||
|
||||
TfLiteRegistration Register_NEG() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/neg::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, neg::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace pack
|
||||
|
||||
TfLiteRegistration Register_PACK() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/pack::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, pack::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -223,26 +223,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace pad
|
||||
|
||||
TfLiteRegistration Register_PAD() {
|
||||
return {/*init=*/pad::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/pad::Prepare,
|
||||
/*invoke=*/pad::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval);
|
||||
}
|
||||
|
||||
// Also register Pad as PadV2.
|
||||
TfLiteRegistration Register_PADV2() {
|
||||
return {/*init=*/pad::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/pad::Prepare,
|
||||
/*invoke=*/pad::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -88,25 +88,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_AVERAGE_POOL_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PoolingPrepare,
|
||||
/*invoke=*/AverageEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MAX_POOL_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PoolingPrepare,
|
||||
/*invoke=*/MaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -69,14 +69,7 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_PRELU() {
|
||||
return {/*init=*/PreluInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PreluPrepare,
|
||||
/*invoke=*/PreluEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(PreluInit, PreluPrepare, PreluEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -34,14 +34,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_QUANTIZE() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PrepareQuantizeReference,
|
||||
/*invoke=*/EvalQuantizeReference,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, PrepareQuantizeReference,
|
||||
EvalQuantizeReference);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -53,15 +53,19 @@ TfLiteStatus PrepareQuantizeReference(TfLiteContext* context,
|
||||
TF_LITE_ENSURE(context, affine_quantization->scale);
|
||||
TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteFloat32 || input->type == kTfLiteInt32 ||
|
||||
input->type == kTfLiteInt16 || input->type == kTfLiteInt8);
|
||||
TF_LITE_ENSURE(
|
||||
context, input->type == kTfLiteFloat32 || input->type == kTfLiteInt32 ||
|
||||
input->type == kTfLiteInt16 || input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteUInt8);
|
||||
TF_LITE_ENSURE(context, output->type == kTfLiteInt8 ||
|
||||
output->type == kTfLiteInt16 ||
|
||||
output->type == kTfLiteInt32);
|
||||
output->type == kTfLiteInt32 ||
|
||||
output->type == kTfLiteUInt8);
|
||||
|
||||
if ((input->type == kTfLiteInt16 && output->type == kTfLiteInt8) ||
|
||||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt8) ||
|
||||
(input->type == kTfLiteInt8 && output->type == kTfLiteUInt8) ||
|
||||
(input->type == kTfLiteUInt8 && output->type == kTfLiteInt8) ||
|
||||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt16) ||
|
||||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt32) ||
|
||||
(input->type == kTfLiteInt16 && output->type == kTfLiteInt16) ||
|
||||
@@ -109,9 +113,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
return kTfLiteOk;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else if (input->type == kTfLiteInt32) {
|
||||
@@ -132,9 +136,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else if (input->type == kTfLiteInt16) {
|
||||
@@ -162,9 +166,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int32_t>(output));
|
||||
return kTfLiteOk;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else if (input->type == kTfLiteInt8) {
|
||||
@@ -179,6 +183,13 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->input_zero_point, data->quantization_params.zero_point,
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
reference_ops::Requantize(
|
||||
tflite::micro::GetTensorData<int8_t>(input), size,
|
||||
data->requantize_output_multiplier, data->requantize_output_shift,
|
||||
data->input_zero_point, data->quantization_params.zero_point,
|
||||
tflite::micro::GetTensorData<uint8_t>(output));
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
reference_ops::Requantize(
|
||||
tflite::micro::GetTensorData<int8_t>(input), size,
|
||||
@@ -194,15 +205,31 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int32_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else if (input->type == kTfLiteUInt8) {
|
||||
size_t size = ElementCount(*input->dims);
|
||||
switch (output->type) {
|
||||
case kTfLiteInt8:
|
||||
reference_ops::Requantize(
|
||||
tflite::micro::GetTensorData<uint8_t>(input), size,
|
||||
data->requantize_output_multiplier, data->requantize_output_shift,
|
||||
data->input_zero_point, data->quantization_params.zero_point,
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -81,14 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_READ_VARIABLE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -23,331 +23,41 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/reduce.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace reduce {
|
||||
|
||||
constexpr int kMaxNumberOfAxis = 4;
|
||||
constexpr int kMaxNumberOfReducedAxis = 2;
|
||||
|
||||
struct OpData {
|
||||
int32_t multiplier;
|
||||
int shift;
|
||||
int temp_buffer_idx;
|
||||
int resolved_axis_idx;
|
||||
int input_zp;
|
||||
float input_scale;
|
||||
int output_zp;
|
||||
float output_scale;
|
||||
int num_output_elements;
|
||||
};
|
||||
|
||||
void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// Inputs Tensor (dtype depends on quantization):
|
||||
// [0] = Input
|
||||
// [1] = Axis
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
|
||||
// Outputs Tensor (dtype depends on quantization):
|
||||
// [0] = Output
|
||||
|
||||
// Validate number of inputs and outputs
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
// Validate axis type
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
|
||||
TF_LITE_ENSURE(context, axis != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
const double real_multiplier = static_cast<double>(input->params.scale) /
|
||||
static_cast<double>(output->params.scale);
|
||||
QuantizeMultiplier(real_multiplier, &data->multiplier, &data->shift);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
return kTfLiteOk;
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpDataReduce));
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
OpData* op_data = static_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
|
||||
|
||||
op_data->input_scale = input->params.scale;
|
||||
op_data->output_scale = output->params.scale;
|
||||
op_data->num_output_elements = NumElements(output);
|
||||
|
||||
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
|
||||
&op_data->temp_buffer_idx);
|
||||
context->RequestScratchBufferInArena(
|
||||
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
|
||||
&op_data->resolved_axis_idx);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
return kTfLiteOk;
|
||||
return PrepareMaxHelper(context, node,
|
||||
static_cast<OpDataReduce*>(node->user_data));
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
|
||||
const double real_multiplier = static_cast<double>(input->params.scale) /
|
||||
static_cast<double>(output->params.scale);
|
||||
QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
|
||||
}
|
||||
|
||||
int output_size = NumElements(output);
|
||||
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
|
||||
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
|
||||
&op_data->temp_buffer_idx);
|
||||
op_data->input_zp = input->params.zero_point;
|
||||
op_data->input_scale = input->params.scale;
|
||||
op_data->output_zp = output->params.zero_point;
|
||||
op_data->output_scale = output->params.scale;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
|
||||
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void ResolveAxis(const int* axis_data, int axis_count,
|
||||
tflite::MeanParams* op_params) {
|
||||
int i = 0;
|
||||
for (; i < axis_count; ++i) {
|
||||
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
|
||||
}
|
||||
for (; i < 4; ++i) {
|
||||
op_params->axis[i] = 1;
|
||||
}
|
||||
op_params->axis_count = axis_count;
|
||||
return PrepareMeanOrSumHelper(context, node,
|
||||
static_cast<OpDataReduce*>(node->user_data));
|
||||
}
|
||||
|
||||
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TfLiteReducerParams* params =
|
||||
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
int num_axis = static_cast<int>(ElementCount(*axis->dims));
|
||||
int temp_index[kMaxNumberOfAxis];
|
||||
int resolved_axis[kMaxNumberOfReducedAxis];
|
||||
|
||||
tflite::MeanParams op_params;
|
||||
ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
|
||||
|
||||
// Special case mean implementation exists for 4D mean across axes 1 and 2.
|
||||
bool special_case_4d_axes_1_and_2 =
|
||||
input->dims->size == 4 && op_params.axis_count == 2 &&
|
||||
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
|
||||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
|
||||
if (params->keep_dims && special_case_4d_axes_1_and_2) {
|
||||
reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
} else {
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::Mean(
|
||||
tflite::micro::GetTensorData<float>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<float>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_index, resolved_axis,
|
||||
tflite::micro::GetTensorData<float>(output)));
|
||||
}
|
||||
} break;
|
||||
case kTfLiteInt8: {
|
||||
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
|
||||
if (params->keep_dims && special_case_4d_axes_1_and_2) {
|
||||
reference_integer_ops::Mean(
|
||||
op_params, op_data->multiplier, op_data->shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
|
||||
} else if (op_data->input_zp == op_data->output_zp &&
|
||||
op_data->input_scale == op_data->output_scale) {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::Mean(
|
||||
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_index, resolved_axis, temp_buffer));
|
||||
} else {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum(
|
||||
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
|
||||
op_data->input_scale, input->dims->data, input->dims->size,
|
||||
tflite::micro::GetTensorData<int8_t>(output),
|
||||
op_data->output_zp, op_data->output_scale, output->dims->data,
|
||||
output->dims->size, tflite::micro::GetTensorData<int>(axis),
|
||||
num_axis, params->keep_dims, temp_index, resolved_axis,
|
||||
temp_buffer, false));
|
||||
}
|
||||
} break;
|
||||
case kTfLiteInt16: {
|
||||
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
|
||||
if (params->keep_dims && special_case_4d_axes_1_and_2) {
|
||||
reference_integer_ops::Mean(
|
||||
op_params, op_data->multiplier, op_data->shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
|
||||
} else if (op_data->input_zp == op_data->output_zp &&
|
||||
op_data->input_scale == op_data->output_scale) {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::Mean(tflite::micro::GetTensorData<int16_t>(input),
|
||||
input->dims->data, input->dims->size,
|
||||
tflite::micro::GetTensorData<int16_t>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis),
|
||||
num_axis, params->keep_dims, temp_index,
|
||||
resolved_axis, temp_buffer));
|
||||
} else {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum(
|
||||
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
|
||||
op_data->input_scale, input->dims->data, input->dims->size,
|
||||
tflite::micro::GetTensorData<int16_t>(output),
|
||||
op_data->output_zp, op_data->output_scale, output->dims->data,
|
||||
output->dims->size, tflite::micro::GetTensorData<int>(axis),
|
||||
num_axis, params->keep_dims, temp_index, resolved_axis,
|
||||
temp_buffer, false));
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_ENSURE_MSG(context, false,
|
||||
"Currently, only float32, int8 or uint8 input type "
|
||||
"is supported.");
|
||||
}
|
||||
return kTfLiteOk;
|
||||
return EvalMeanHelper(context, node,
|
||||
static_cast<OpDataReduce*>(node->user_data));
|
||||
}
|
||||
|
||||
TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TfLiteReducerParams* params =
|
||||
static_cast<TfLiteReducerParams*>(node->builtin_data);
|
||||
OpData* op_data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
// Interpret an axis tensor with null dimensions as a scalar
|
||||
int num_axis = static_cast<int>(ElementCount(*axis->dims));
|
||||
int* temp_buffer = static_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
int* resolved_axis = static_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::ReduceGeneric<float>(
|
||||
tflite::micro::GetTensorData<float>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<float>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_buffer, resolved_axis,
|
||||
std::numeric_limits<float>::lowest(),
|
||||
[](const float current, const float in) -> float {
|
||||
return (in > current) ? in : current;
|
||||
}));
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
|
||||
static_cast<double>(op_data->output_scale));
|
||||
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::ReduceGeneric<int8_t>(
|
||||
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_buffer, resolved_axis,
|
||||
std::numeric_limits<int8_t>::lowest(),
|
||||
[](const int8_t current, const int8_t in) -> int8_t {
|
||||
return (in > current) ? in : current;
|
||||
}));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Only float32 and int8 types are supported.\n");
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
|
||||
return EvalMaxHelper(context, node, op_data);
|
||||
}
|
||||
|
||||
} // namespace reduce
|
||||
|
||||
TfLiteRegistration Register_MEAN() {
|
||||
return {/*init=*/reduce::InitReduce,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/reduce::PrepareMeanOrSum,
|
||||
/*invoke=*/reduce::EvalMean,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalMean);
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_REDUCE_MAX() {
|
||||
return {/*init=*/reduce::InitReduce,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/reduce::PrepareMax,
|
||||
/*invoke=*/reduce::EvalMax,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(InitReduce, PrepareMax, EvalMax);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
extern const int kMaxNumberOfAxis;
|
||||
extern const int kMaxNumberOfReducedAxis;
|
||||
|
||||
struct OpDataReduce {
|
||||
int32_t multiplier;
|
||||
int shift;
|
||||
int temp_buffer_idx;
|
||||
int resolved_axis_idx;
|
||||
int input_zp;
|
||||
float input_scale;
|
||||
int output_zp;
|
||||
float output_scale;
|
||||
int num_output_elements;
|
||||
};
|
||||
|
||||
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data);
|
||||
|
||||
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data);
|
||||
|
||||
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data);
|
||||
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data);
|
||||
|
||||
void ReduceResolveAxis(const int* axis_data, int axis_count,
|
||||
MeanParams* op_params);
|
||||
|
||||
TfLiteRegistration Register_MEAN();
|
||||
TfLiteRegistration Register_REDUCE_MAX();
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
|
||||
@@ -0,0 +1,311 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/reduce.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
const int kMaxNumberOfAxis = 4;
|
||||
const int kMaxNumberOfReducedAxis = 2;
|
||||
|
||||
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
|
||||
int32_t* multiplier, int* shift) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// Inputs Tensor (dtype depends on quantization):
|
||||
// [0] = Input
|
||||
// [1] = Axis
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
|
||||
// Outputs Tensor (dtype depends on quantization):
|
||||
// [0] = Output
|
||||
|
||||
// Validate number of inputs and outputs
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
// Validate axis type
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
|
||||
TF_LITE_ENSURE(context, axis != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
const double real_multiplier = static_cast<double>(input->params.scale) /
|
||||
static_cast<double>(output->params.scale);
|
||||
QuantizeMultiplier(real_multiplier, multiplier, shift);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data) {
|
||||
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node, &op_data->multiplier,
|
||||
&op_data->shift));
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
|
||||
|
||||
op_data->input_scale = input->params.scale;
|
||||
op_data->output_scale = output->params.scale;
|
||||
op_data->num_output_elements = NumElements(output);
|
||||
|
||||
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
|
||||
&op_data->temp_buffer_idx);
|
||||
context->RequestScratchBufferInArena(
|
||||
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
|
||||
&op_data->resolved_axis_idx);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
|
||||
const double real_multiplier = static_cast<double>(input->params.scale) /
|
||||
static_cast<double>(output->params.scale);
|
||||
QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
|
||||
}
|
||||
|
||||
int output_size = NumElements(output);
|
||||
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
|
||||
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
|
||||
&op_data->temp_buffer_idx);
|
||||
op_data->input_zp = input->params.zero_point;
|
||||
op_data->input_scale = input->params.scale;
|
||||
op_data->output_zp = output->params.zero_point;
|
||||
op_data->output_scale = output->params.scale;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
PrepareSimple(context, node, &(op_data->multiplier), &(op_data->shift)));
|
||||
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void ResolveAxis(const int* axis_data, int axis_count,
|
||||
tflite::MeanParams* op_params) {
|
||||
int i = 0;
|
||||
for (; i < axis_count; ++i) {
|
||||
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
|
||||
}
|
||||
for (; i < 4; ++i) {
|
||||
op_params->axis[i] = 1;
|
||||
}
|
||||
op_params->axis_count = axis_count;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TfLiteReducerParams* params =
|
||||
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
|
||||
|
||||
int num_axis = static_cast<int>(ElementCount(*axis->dims));
|
||||
int temp_index[kMaxNumberOfAxis];
|
||||
int resolved_axis[kMaxNumberOfReducedAxis];
|
||||
|
||||
tflite::MeanParams op_params;
|
||||
ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
|
||||
|
||||
// Special case mean implementation exists for 4D mean across axes 1 and 2.
|
||||
bool special_case_4d_axes_1_and_2 =
|
||||
input->dims->size == 4 && op_params.axis_count == 2 &&
|
||||
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
|
||||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
|
||||
if (params->keep_dims && special_case_4d_axes_1_and_2) {
|
||||
reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
} else {
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::Mean(
|
||||
tflite::micro::GetTensorData<float>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<float>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_index, resolved_axis,
|
||||
tflite::micro::GetTensorData<float>(output)));
|
||||
}
|
||||
} break;
|
||||
case kTfLiteInt8: {
|
||||
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
|
||||
if (params->keep_dims && special_case_4d_axes_1_and_2) {
|
||||
reference_integer_ops::Mean(
|
||||
op_params, op_data->multiplier, op_data->shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
|
||||
} else if (op_data->input_zp == op_data->output_zp &&
|
||||
op_data->input_scale == op_data->output_scale) {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::Mean(
|
||||
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_index, resolved_axis, temp_buffer));
|
||||
} else {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum(
|
||||
tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
|
||||
op_data->input_scale, input->dims->data, input->dims->size,
|
||||
tflite::micro::GetTensorData<int8_t>(output),
|
||||
op_data->output_zp, op_data->output_scale, output->dims->data,
|
||||
output->dims->size, tflite::micro::GetTensorData<int>(axis),
|
||||
num_axis, params->keep_dims, temp_index, resolved_axis,
|
||||
temp_buffer, false));
|
||||
}
|
||||
} break;
|
||||
case kTfLiteInt16: {
|
||||
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
|
||||
if (params->keep_dims && special_case_4d_axes_1_and_2) {
|
||||
reference_integer_ops::Mean(
|
||||
op_params, op_data->multiplier, op_data->shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
|
||||
} else if (op_data->input_zp == op_data->output_zp &&
|
||||
op_data->input_scale == op_data->output_scale) {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::Mean(tflite::micro::GetTensorData<int16_t>(input),
|
||||
input->dims->data, input->dims->size,
|
||||
tflite::micro::GetTensorData<int16_t>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis),
|
||||
num_axis, params->keep_dims, temp_index,
|
||||
resolved_axis, temp_buffer));
|
||||
} else {
|
||||
int32_t* temp_buffer = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum(
|
||||
tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
|
||||
op_data->input_scale, input->dims->data, input->dims->size,
|
||||
tflite::micro::GetTensorData<int16_t>(output),
|
||||
op_data->output_zp, op_data->output_scale, output->dims->data,
|
||||
output->dims->size, tflite::micro::GetTensorData<int>(axis),
|
||||
num_axis, params->keep_dims, temp_index, resolved_axis,
|
||||
temp_buffer, false));
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_ENSURE_MSG(context, false,
|
||||
"Currently, only float32, int8 or uint8 input type "
|
||||
"is supported.");
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
|
||||
OpDataReduce* op_data) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TfLiteReducerParams* params =
|
||||
static_cast<TfLiteReducerParams*>(node->builtin_data);
|
||||
|
||||
// Interpret an axis tensor with null dimensions as a scalar
|
||||
int num_axis = static_cast<int>(ElementCount(*axis->dims));
|
||||
int* temp_buffer = static_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
|
||||
int* resolved_axis = static_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::ReduceGeneric<float>(
|
||||
tflite::micro::GetTensorData<float>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<float>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_buffer, resolved_axis,
|
||||
std::numeric_limits<float>::lowest(),
|
||||
[](const float current, const float in) -> float {
|
||||
return (in > current) ? in : current;
|
||||
}));
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
|
||||
static_cast<double>(op_data->output_scale));
|
||||
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::ReduceGeneric<int8_t>(
|
||||
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
|
||||
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
|
||||
output->dims->data, output->dims->size,
|
||||
tflite::micro::GetTensorData<int>(axis), num_axis,
|
||||
params->keep_dims, temp_buffer, resolved_axis,
|
||||
std::numeric_limits<int8_t>::lowest(),
|
||||
[](const int8_t current, const int8_t in) -> int8_t {
|
||||
return (in > current) ? in : current;
|
||||
}));
|
||||
break;
|
||||
default:
|
||||
MicroPrintf("Only float32 and int8 types are supported.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -110,14 +110,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace reshape
|
||||
|
||||
TfLiteRegistration Register_RESHAPE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/reshape::Prepare,
|
||||
/*invoke=*/reshape::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, reshape::Prepare, reshape::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -111,14 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_RESIZE_BILINEAR() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -117,14 +117,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace resize_nearest_neighbor
|
||||
|
||||
TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/resize_nearest_neighbor::Prepare,
|
||||
/*invoke=*/resize_nearest_neighbor::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, resize_nearest_neighbor::Prepare,
|
||||
resize_nearest_neighbor::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -68,14 +68,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace round
|
||||
|
||||
TfLiteRegistration Register_ROUND() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/round::Prepare,
|
||||
/*invoke=*/round::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, round::Prepare, round::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -60,14 +60,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SHAPE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -151,14 +151,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SLICE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -83,14 +83,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SOFTMAX() {
|
||||
return {/*init=*/SoftmaxInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/SoftmaxPrepare,
|
||||
/*invoke=*/SoftmaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(SoftmaxInit, SoftmaxPrepare, SoftmaxEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -23,6 +23,13 @@ namespace tflite {
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length);
|
||||
|
||||
// Common helper function to SoftmaxPrepare.
|
||||
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* output,
|
||||
const TfLiteSoftmaxParams* params,
|
||||
SoftmaxParams* op_data);
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
// This is the most generic TfLiteRegistration. The actual supported types may
|
||||
@@ -30,7 +37,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node);
|
||||
// (reference or optimized) must define this function.
|
||||
TfLiteRegistration Register_SOFTMAX();
|
||||
|
||||
#if defined(XTENSA)
|
||||
#if defined(XTENSA) || defined(CMSIS_NN)
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int8 input and int16 output.
|
||||
TfLiteRegistration Register_SOFTMAX_INT8_INT16();
|
||||
@@ -40,6 +47,23 @@ inline TfLiteRegistration Register_SOFTMAX_INT8_INT16() {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int8 input/output and uses the latency optimized implementations.
|
||||
TfLiteRegistration Register_SOFTMAX_INT8();
|
||||
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int16 input/output and uses the latency optimized implementations.
|
||||
TfLiteRegistration Register_SOFTMAX_INT16();
|
||||
|
||||
#else
|
||||
inline TfLiteRegistration Register_SOFTMAX_INT8() { return Register_SOFTMAX(); }
|
||||
|
||||
inline TfLiteRegistration Register_SOFTMAX_INT16() {
|
||||
return Register_SOFTMAX();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SOFTMAX_H_
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -28,11 +28,59 @@ namespace {
|
||||
// Softmax parameter data that persists in user_data
|
||||
const int kInt16LUTArraySize = 513;
|
||||
|
||||
TfLiteStatus InitializeLutForInt16(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* output,
|
||||
SoftmaxParams* op_data) {
|
||||
// Only allocate LUTs for KTfLiteInt16 data type
|
||||
if (input->type == kTfLiteInt16) {
|
||||
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||
op_data->one_over_one_plus_x_lut =
|
||||
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
}
|
||||
|
||||
// Populate LUT if required
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
// exp LUT only used on negative values
|
||||
// we consider exp(-10.0) is insignificant to accumulation
|
||||
gen_lut<float, int16_t, int16_t>(
|
||||
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
|
||||
op_data->exp_lut);
|
||||
gen_lut<float, int16_t, int16_t>(
|
||||
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
|
||||
1.0f, op_data->one_over_one_plus_x_lut);
|
||||
op_data->zero_point = output->params.zero_point;
|
||||
op_data->scale = output->params.scale;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* output,
|
||||
const TfLiteSoftmaxParams* params,
|
||||
SoftmaxParams* op_data) {
|
||||
if (InitializeLutForInt16(context, input, output, op_data) != kTfLiteOk) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
@@ -83,8 +131,6 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
|
||||
@@ -103,40 +149,6 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
|
||||
// Only allocate LUTs for KTfLiteInt16 data type
|
||||
if (input->type == kTfLiteInt16) {
|
||||
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||
op_data->one_over_one_plus_x_lut =
|
||||
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
}
|
||||
|
||||
// Populate LUT if required
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
// exp LUT only used on negative values
|
||||
// we consider exp(-10.0) is insignificant to accumulation
|
||||
gen_lut<float, int16_t, int16_t>(
|
||||
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
|
||||
op_data->exp_lut);
|
||||
gen_lut<float, int16_t, int16_t>(
|
||||
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
|
||||
1.0f, op_data->one_over_one_plus_x_lut);
|
||||
op_data->zero_point = output->params.zero_point;
|
||||
op_data->scale = output->params.scale;
|
||||
}
|
||||
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
auto ret_val =
|
||||
|
||||
@@ -114,14 +114,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SPACE_TO_DEPTH() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -120,14 +120,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace split
|
||||
|
||||
TfLiteRegistration Register_SPLIT() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/split::Prepare,
|
||||
/*invoke=*/split::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, split::Prepare, split::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace split_v
|
||||
|
||||
TfLiteRegistration Register_SPLIT_V() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/split_v::Prepare,
|
||||
/*invoke=*/split_v::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, split_v::Prepare, split_v::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
constexpr int kInputTensor1 = 0;
|
||||
constexpr int kInputTensor2 = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
struct OpData {
|
||||
bool requires_broadcast;
|
||||
ArithmeticParams arithmetic_params;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T SquaredDifference(T input1, T input2) {
|
||||
const T difference = input1 - input2;
|
||||
return difference * difference;
|
||||
}
|
||||
|
||||
void* SquaredDifferenceInit(TfLiteContext* context, const char* buffer,
|
||||
size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
}
|
||||
|
||||
TfLiteStatus SquaredDifferencePrepare(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
data->requires_broadcast = false;
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor1);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor2);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
|
||||
output->type = input2->type;
|
||||
|
||||
// Ensure the quantization parameters are equivalent.
|
||||
if (input1->type == kTfLiteInt8) {
|
||||
const auto& input1_quantization_params = input1->params;
|
||||
const auto& input2_quantization_params = input2->params;
|
||||
const auto& output_quantization_params = output->params;
|
||||
const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
|
||||
const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
|
||||
TF_LITE_ENSURE(context,
|
||||
input1_quantization_params.zero_point >= integer_type_min);
|
||||
TF_LITE_ENSURE(context,
|
||||
input1_quantization_params.zero_point <= integer_type_max);
|
||||
TF_LITE_ENSURE(context,
|
||||
input2_quantization_params.zero_point >= integer_type_min);
|
||||
TF_LITE_ENSURE(context,
|
||||
input2_quantization_params.zero_point <= integer_type_max);
|
||||
TF_LITE_ENSURE(context,
|
||||
output_quantization_params.zero_point >= integer_type_min);
|
||||
TF_LITE_ENSURE(context,
|
||||
output_quantization_params.zero_point <= integer_type_max);
|
||||
data->arithmetic_params.input1_offset =
|
||||
-input1_quantization_params.zero_point;
|
||||
data->arithmetic_params.input2_offset =
|
||||
-input2_quantization_params.zero_point;
|
||||
data->arithmetic_params.output_offset =
|
||||
output_quantization_params.zero_point;
|
||||
|
||||
// shift to make integer for scales.
|
||||
// 7 is selected so that maximum shifted result 255^2 * (1 << (7 * 2 ))
|
||||
// does not overflow signed 32-bit integer
|
||||
data->arithmetic_params.left_shift = 7;
|
||||
const double twice_max_input_scale =
|
||||
2.0 * static_cast<double>(std::max(input1_quantization_params.scale,
|
||||
input2_quantization_params.scale));
|
||||
const double real_input1_multiplier =
|
||||
static_cast<double>(input1_quantization_params.scale) /
|
||||
twice_max_input_scale;
|
||||
double real_input2_multiplier =
|
||||
static_cast<double>(input2_quantization_params.scale) /
|
||||
twice_max_input_scale;
|
||||
const double real_output_multiplier =
|
||||
(twice_max_input_scale * twice_max_input_scale) /
|
||||
static_cast<double>((1 << data->arithmetic_params.left_shift * 2) *
|
||||
output_quantization_params.scale);
|
||||
QuantizeMultiplierSmallerThanOneExp(
|
||||
real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
|
||||
&data->arithmetic_params.input1_shift);
|
||||
QuantizeMultiplierSmallerThanOneExp(
|
||||
real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
|
||||
&data->arithmetic_params.input2_shift);
|
||||
QuantizeMultiplierSmallerThanOneExp(
|
||||
real_output_multiplier, &data->arithmetic_params.output_multiplier,
|
||||
&data->arithmetic_params.output_shift);
|
||||
data->arithmetic_params.quantized_activation_min =
|
||||
std::numeric_limits<int8_t>::min();
|
||||
data->arithmetic_params.quantized_activation_max =
|
||||
std::numeric_limits<int8_t>::max();
|
||||
}
|
||||
|
||||
data->requires_broadcast = !HaveSameShapes(input1, input2);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
inline int8_t SquaredDifference(int8_t x, int8_t y,
|
||||
const ArithmeticParams& params) {
|
||||
const int32_t input1_val = params.input1_offset + x;
|
||||
const int32_t input2_val = params.input2_offset + y;
|
||||
const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
|
||||
const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
|
||||
const int32_t scaled_input1_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input1_val, params.input1_multiplier, params.input1_shift);
|
||||
const int32_t scaled_input2_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input2_val, params.input2_multiplier, params.input2_shift);
|
||||
const int32_t raw_diff = scaled_input1_val - scaled_input2_val;
|
||||
|
||||
// Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
|
||||
const int32_t squared_raw_diff = raw_diff * raw_diff;
|
||||
const int32_t raw_output =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
squared_raw_diff, params.output_multiplier, params.output_shift) +
|
||||
params.output_offset;
|
||||
const int32_t clamped_output =
|
||||
std::min(params.quantized_activation_max,
|
||||
std::max(params.quantized_activation_min, raw_output));
|
||||
return static_cast<int8_t>(clamped_output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData* data,
|
||||
const TfLiteEvalTensor* input1,
|
||||
const TfLiteEvalTensor* input2,
|
||||
TfLiteEvalTensor* output) {
|
||||
const auto* op_data = static_cast<const OpData*>(node->user_data);
|
||||
if (data->requires_broadcast) {
|
||||
reference_integer_ops::BroadcastBinaryFunction4DSlow(
|
||||
op_data->arithmetic_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<T>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<T>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<T>(output),
|
||||
reference_integer_ops::CheckArithmeticParams, SquaredDifference);
|
||||
} else {
|
||||
const int flat_size = tflite::micro::GetTensorShape(input1).FlatSize();
|
||||
reference_integer_ops::ElementWise(
|
||||
flat_size, op_data->arithmetic_params,
|
||||
tflite::micro::GetTensorData<int8_t>(input1),
|
||||
tflite::micro::GetTensorData<int8_t>(input2),
|
||||
tflite::micro::GetTensorData<int8_t>(output),
|
||||
reference_integer_ops::CheckArithmeticParams, SquaredDifference);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData* data, const TfLiteEvalTensor* input1,
|
||||
const TfLiteEvalTensor* input2,
|
||||
TfLiteEvalTensor* output) {
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
|
||||
tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<T>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<T>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<T>(output), SquaredDifference<T>);
|
||||
} else {
|
||||
reference_ops::BinaryFunction<T, T, T>(
|
||||
tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<T>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<T>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<T>(output), SquaredDifference<T>);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus SquaredDifferenceEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteEvalTensor* input1 =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor1);
|
||||
const TfLiteEvalTensor* input2 =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor2);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
EvalSquaredDifference<float>(context, node, data, input1, input2, output);
|
||||
} else if (output->type == kTfLiteInt32) {
|
||||
EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
|
||||
} else if (output->type == kTfLiteInt8) {
|
||||
EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
|
||||
output);
|
||||
} else {
|
||||
MicroPrintf(
|
||||
"SquaredDifference only supports FLOAT32, INT32 and INT8 now, got %d.",
|
||||
output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SQUARED_DIFFERENCE() {
|
||||
return tflite::micro::RegisterOp(
|
||||
SquaredDifferenceInit, SquaredDifferencePrepare, SquaredDifferenceEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -111,14 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SQUEEZE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -193,14 +193,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace strided_slice
|
||||
|
||||
TfLiteRegistration Register_STRIDED_SLICE() {
|
||||
return {/*init=*/strided_slice::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/strided_slice::Prepare,
|
||||
/*invoke=*/strided_slice::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(strided_slice::Init, strided_slice::Prepare,
|
||||
strided_slice::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -162,14 +162,7 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_SUB() {
|
||||
return {/*init=*/SubInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/SubPrepare,
|
||||
/*invoke=*/SubEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(SubInit, SubPrepare, SubEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -100,14 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SVDF() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PrepareSvdf,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, PrepareSvdf, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -195,14 +195,8 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace activations
|
||||
|
||||
TfLiteRegistration Register_TANH() {
|
||||
return {/*init=*/activations::TanhInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/activations::TanhPrepare,
|
||||
/*invoke=*/activations::TanhEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(
|
||||
activations::TanhInit, activations::TanhPrepare, activations::TanhEval);
|
||||
}
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
|
||||
@@ -116,13 +116,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_TRANSPOSE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
|
||||
}
|
||||
} // namespace tflite
|
||||
|
||||
@@ -266,7 +266,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetOptionalTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output),
|
||||
tflite::micro::GetTensorShape(nullptr), nullptr);
|
||||
@@ -282,7 +282,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetOptionalTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output),
|
||||
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
|
||||
@@ -293,7 +293,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->GetScratchBuffer(context, data.scratch_buffer_index));
|
||||
// TODO(b/192090531): Remove this once all 8x16 transpose conv models use
|
||||
// 64-bit biases.
|
||||
if (bias->type == kTfLiteInt16) {
|
||||
if (bias != nullptr && bias->type == kTfLiteInt16) {
|
||||
std::int64_t* bias_converted_buffer =
|
||||
static_cast<int64_t*>(context->GetScratchBuffer(
|
||||
context, data.bias_converted_buffer_index));
|
||||
@@ -319,7 +319,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<std::int64_t>(bias),
|
||||
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output),
|
||||
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
|
||||
@@ -337,14 +337,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,244 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
||||
// TODO(b/230666079) enable below tests for xtensa when the xtensa
|
||||
// kernel is reconciled with reference kernel
|
||||
#if !defined(XTENSA)
|
||||
|
||||
typedef struct LstmIntegerTestConfig {
|
||||
const int n_batch;
|
||||
const int n_input;
|
||||
const int n_cell;
|
||||
const int n_output;
|
||||
const int sequence_length;
|
||||
const bool time_major;
|
||||
const bool use_cifg;
|
||||
const bool use_peephole;
|
||||
const bool use_projection_weights;
|
||||
const bool use_projection_bias;
|
||||
const bool use_layer_norm;
|
||||
const bool use_8x8_8_implementation;
|
||||
float intermediate_scale[5][2];
|
||||
int intermediate_zp[5][2];
|
||||
TfLiteAffineQuantization* intermediate_qparam;
|
||||
|
||||
const float* input;
|
||||
int8_t* input_quant;
|
||||
|
||||
const float* input_to_input_weights;
|
||||
int8_t* lstm_i2i_quant;
|
||||
const float* input_to_forget_weights;
|
||||
int8_t* lstm_i2f_quant;
|
||||
const float* input_to_cell_weights;
|
||||
int8_t* lstm_i2c_quant;
|
||||
const float* input_to_output_weights;
|
||||
int8_t* lstm_i2o_quant;
|
||||
|
||||
const float* recurrent_to_input_weights;
|
||||
int8_t* lstm_r2i_quant;
|
||||
const float* recurrent_to_forget_weights;
|
||||
int8_t* lstm_r2f_quant;
|
||||
const float* recurrent_to_cell_weights;
|
||||
int8_t* lstm_r2c_quant;
|
||||
const float* recurrent_to_output_weights;
|
||||
int8_t* lstm_r2o_quant;
|
||||
|
||||
const float* cell_to_input_weights;
|
||||
int16_t* lstm_c2i_quant;
|
||||
const float* cell_to_forget_weights;
|
||||
int16_t* lstm_c2f_quant;
|
||||
const float* cell_to_output_weights;
|
||||
int16_t* lstm_c2o_quant;
|
||||
|
||||
const float* input_gate_bias;
|
||||
int32_t* lstm_igate_bias_quant;
|
||||
const float* forget_gate_bias;
|
||||
int32_t* lstm_fgate_bias_quant;
|
||||
const float* cell_gate_bias;
|
||||
int32_t* lstm_cgate_bias_quant;
|
||||
const float* output_gate_bias;
|
||||
int32_t* lstm_ogate_bias_quant;
|
||||
|
||||
const float* projection_weights;
|
||||
int8_t* lstm_proj_w_quant;
|
||||
const float* projection_bias;
|
||||
int32_t* projection_bias_quant;
|
||||
|
||||
int16_t* output_state;
|
||||
int16_t* cell_state;
|
||||
|
||||
const float* input_layer_norm_coefficients;
|
||||
int16_t* lstm_input_layer_norm_coeff_quant;
|
||||
const float* forget_layer_norm_coefficients;
|
||||
int16_t* lstm_forget_layer_norm_coeff_quant;
|
||||
const float* cell_layer_norm_coefficients;
|
||||
int16_t* lstm_cell_layer_norm_coeff_quant;
|
||||
const float* output_layer_norm_coefficients;
|
||||
int16_t* lstm_output_layer_norm_coeff_quant;
|
||||
|
||||
int8_t* output;
|
||||
const int8_t* expected_output;
|
||||
|
||||
bool asymmetric_quantize_inputs;
|
||||
const float ranges[25][2];
|
||||
} LstmIntegerTestConfig;
|
||||
|
||||
typedef struct LstmFloatTestConfig {
|
||||
const int n_batch;
|
||||
const int n_input;
|
||||
const int n_cell;
|
||||
const int n_output;
|
||||
const int sequence_length;
|
||||
const bool time_major;
|
||||
const bool use_cifg;
|
||||
const bool use_peephole;
|
||||
const bool use_projection_weights;
|
||||
const bool use_projection_bias;
|
||||
const bool use_layer_norm;
|
||||
const float cell_clip;
|
||||
const float proj_clip;
|
||||
|
||||
const float* input_original;
|
||||
float* input;
|
||||
|
||||
const float* input_to_input_weights;
|
||||
const float* input_to_forget_weights;
|
||||
const float* input_to_cell_weights;
|
||||
const float* input_to_output_weights;
|
||||
|
||||
const float* recurrent_to_input_weights;
|
||||
const float* recurrent_to_forget_weights;
|
||||
const float* recurrent_to_cell_weights;
|
||||
const float* recurrent_to_output_weights;
|
||||
|
||||
const float* cell_to_input_weights;
|
||||
const float* cell_to_forget_weights;
|
||||
const float* cell_to_output_weights;
|
||||
|
||||
const float* input_gate_bias;
|
||||
const float* forget_gate_bias;
|
||||
const float* cell_gate_bias;
|
||||
const float* output_gate_bias;
|
||||
|
||||
const float* projection_weights;
|
||||
const float* projection_bias;
|
||||
|
||||
float* output_state;
|
||||
float* cell_state;
|
||||
|
||||
const float* input_layer_norm_coefficients;
|
||||
const float* forget_layer_norm_coefficients;
|
||||
const float* cell_layer_norm_coefficients;
|
||||
const float* output_layer_norm_coefficients;
|
||||
|
||||
float* output;
|
||||
const float* expected_output_original;
|
||||
float* expected_output;
|
||||
} LstmFloatTestConfig;
|
||||
|
||||
typedef struct LstmWeightQuantizationBuffers {
|
||||
int8_t* lstm_i2i_quant;
|
||||
float* lstm_i2i_scale;
|
||||
int* lstm_i2i_zp;
|
||||
TfLiteAffineQuantization* lstm_i2i_qparam;
|
||||
|
||||
int8_t* lstm_i2f_quant;
|
||||
float* lstm_i2f_scale;
|
||||
int* lstm_i2f_zp;
|
||||
TfLiteAffineQuantization* lstm_i2f_qparam;
|
||||
|
||||
int8_t* lstm_i2c_quant;
|
||||
float* lstm_i2c_scale;
|
||||
int* lstm_i2c_zp;
|
||||
TfLiteAffineQuantization* lstm_i2c_qparam;
|
||||
|
||||
int8_t* lstm_i2o_quant;
|
||||
float* lstm_i2o_scale;
|
||||
int* lstm_i2o_zp;
|
||||
TfLiteAffineQuantization* lstm_i2o_qparam;
|
||||
|
||||
int8_t* lstm_r2i_quant;
|
||||
float* lstm_r2i_scale;
|
||||
int* lstm_r2i_zp;
|
||||
TfLiteAffineQuantization* lstm_r2i_qparam;
|
||||
|
||||
int8_t* lstm_r2f_quant;
|
||||
float* lstm_r2f_scale;
|
||||
int* lstm_r2f_zp;
|
||||
TfLiteAffineQuantization* lstm_r2f_qparam;
|
||||
|
||||
int8_t* lstm_r2c_quant;
|
||||
float* lstm_r2c_scale;
|
||||
int* lstm_r2c_zp;
|
||||
TfLiteAffineQuantization* lstm_r2c_qparam;
|
||||
|
||||
int8_t* lstm_r2o_quant;
|
||||
float* lstm_r2o_scale;
|
||||
int* lstm_r2o_zp;
|
||||
TfLiteAffineQuantization* lstm_r2o_qparam;
|
||||
|
||||
int8_t* lstm_c2i_quant;
|
||||
float* lstm_c2i_scale;
|
||||
int* lstm_c2i_zp;
|
||||
TfLiteAffineQuantization* lstm_c2i_qparam;
|
||||
|
||||
int8_t* lstm_c2f_quant;
|
||||
float* lstm_c2f_scale;
|
||||
int* lstm_c2f_zp;
|
||||
TfLiteAffineQuantization* lstm_c2f_qparam;
|
||||
|
||||
int8_t* lstm_c2o_quant;
|
||||
float* lstm_c2o_scale;
|
||||
int* lstm_c2o_zp;
|
||||
TfLiteAffineQuantization* lstm_c2o_qparam;
|
||||
|
||||
int8_t* lstm_proj_w_quant;
|
||||
float* lstm_proj_w_scale;
|
||||
int* lstm_proj_w_zp;
|
||||
TfLiteAffineQuantization* lstm_proj_w_qparam;
|
||||
} LstmWeightQuantizationBuffers;
|
||||
|
||||
extern LstmIntegerTestConfig lstm_integer_no_peephole_config;
|
||||
|
||||
extern LstmIntegerTestConfig lstm_integer_peephole_config;
|
||||
|
||||
extern LstmFloatTestConfig lstm_no_cifg_no_peephole_no_proj_config;
|
||||
|
||||
extern LstmFloatTestConfig lstm_cifg_peephole_no_proj_config;
|
||||
|
||||
extern LstmFloatTestConfig lstm_no_cifg_peephole_proj_config;
|
||||
|
||||
extern LstmFloatTestConfig lstm_no_cifg_peephole_proj_bias_config;
|
||||
|
||||
extern LstmWeightQuantizationBuffers lstm_no_cifg_no_peephole_no_proj_buffers;
|
||||
|
||||
extern LstmWeightQuantizationBuffers lstm_cifg_peephole_no_proj_buffers;
|
||||
|
||||
extern LstmWeightQuantizationBuffers lstm_no_cifg_peephole_proj_buffers;
|
||||
|
||||
extern LstmFloatTestConfig cifg_peephole_no_proj_config_layer_norm;
|
||||
|
||||
#endif // !defined(XTENSA)
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_
|
||||
@@ -103,14 +103,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace unpack
|
||||
|
||||
TfLiteRegistration Register_UNPACK() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/unpack::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(nullptr, nullptr, unpack::Eval);
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -87,14 +87,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_VAR_HANDLE() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -127,14 +127,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_WHILE() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
return tflite::micro::RegisterOp(Init, Prepare, Eval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user