Rolling 20220924

This commit is contained in:
jomjol
2022-09-24 21:24:50 +02:00
parent a1691a77cf
commit 68e57d5ec4
133 changed files with 5485 additions and 1810 deletions

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -59,6 +59,19 @@ TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params,
TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node);
// Generic must define registration function.
TfLiteRegistration Register_ADD();
#if defined(CMSIS_NN)
TfLiteRegistration Register_ADD_INT8();
TfLiteRegistration Register_ADD_INT16();
#else
// Fallback registration
inline TfLiteRegistration Register_ADD_INT8() { return Register_ADD(); }
inline TfLiteRegistration Register_ADD_INT16() { return Register_ADD(); }
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_ADD_H_

View File

@@ -121,8 +121,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
context, kTfLiteActNone, output, &data->output_activation_min,
&data->output_activation_max));
} else {
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
@@ -198,8 +198,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} else if (output->type == kTfLiteInt8) {
EvalAddNQuantized<int8_t>(context, node, output);
} else {
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -70,21 +70,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf(
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context,
"Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context, "Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
return kTfLiteError;
}

View File

@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -90,7 +90,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}

View File

@@ -118,8 +118,8 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -210,8 +210,8 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -288,8 +288,8 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -366,8 +366,8 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -444,8 +444,8 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -522,8 +522,8 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteInt8 ||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64);
input_type == kTfLiteInt64 || input_type == kTfLiteBool);
// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
@@ -149,8 +149,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int num_dimensions = NumDimensions(input);
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
TF_LITE_KERNEL_LOG(
context,
MicroPrintf(
"Op Concatenation does not currently support num dimensions > %d "
"Tensor has %d dimensions.",
RuntimeShape::kMaxSmallSize, num_dimensions);
@@ -168,6 +167,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
switch (output_type) { // Already know in/outtypes are same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt16:
case kTfLiteInt32:
@@ -205,9 +205,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(
context, "Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}
@@ -238,11 +237,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16:
EvalUnquantized<int16_t>(context, node);
break;
case kTfLiteBool:
EvalUnquantized<bool>(context, node);
break;
default:
TF_LITE_KERNEL_LOG(
context, "Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}

View File

@@ -123,7 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis < 0) axis += input_shape.DimensionsCount();
if (axis < 0 || axis >= input_shape.DimensionsCount()) {
TF_LITE_KERNEL_LOG(context, "CUMSUM Invalid axis: %d", axis);
MicroPrintf("CUMSUM Invalid axis: %d", axis);
return kTfLiteError;
}
@@ -156,9 +156,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} break;
default: {
TF_LITE_KERNEL_LOG(context,
"CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}

View File

@@ -124,9 +124,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(
context, "DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@@ -82,8 +82,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -162,8 +162,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
#undef TF_LITE_DIV
} else {
TF_LITE_KERNEL_LOG(
context, "Unsupported combination of input and output types in DIV.");
MicroPrintf("Unsupported combination of input and output types in DIV.");
return kTfLiteError;
}
@@ -189,10 +188,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
input1, input2, output));
} else {
TF_LITE_KERNEL_LOG(context,
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
MicroPrintf(
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}

View File

@@ -90,8 +90,8 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
@@ -112,8 +112,8 @@ TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
@@ -317,8 +317,8 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
type);
break;
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
MicroPrintf("Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
break;
}
@@ -355,8 +355,8 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
elementwise::validate_input_func, type);
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
MicroPrintf("Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}
@@ -426,4 +426,4 @@ TfLiteRegistration Register_LOGICAL_NOT() {
} // namespace micro
} // namespace ops
} // namespace tflite
} // namespace tflite

View File

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -136,9 +135,8 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(
context, "ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN

View File

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN

View File

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -63,8 +64,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
static_cast<size_t>(flat_size),
tflite::micro::GetTensorData<float>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) currently not supported by Exp.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) currently not supported by Exp.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -31,8 +31,7 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
int32_t* axis_value) {
const int axis_dims = (tflite::GetTensorShape(axis)).DimensionsCount();
if (axis_dims > 1) {
TF_LITE_KERNEL_LOG(context, "Axis has only one element for Expand_Dims.",
axis_dims);
MicroPrintf("Axis has only one element for Expand_Dims.", axis_dims);
return kTfLiteError;
}
@@ -41,9 +40,8 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
*axis_value = axis_ptr[0];
return kTfLiteOk;
} else {
TF_LITE_KERNEL_LOG(context,
"Axis type %s (%d) not supported by Expand_Dims.",
TfLiteTypeGetName(axis->type), axis->type);
MicroPrintf("Axis type %s (%d) not supported by Expand_Dims.",
TfLiteTypeGetName(axis->type), axis->type);
return kTfLiteError;
}
}
@@ -99,8 +97,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
output->type = input->type;
if (IsDynamicTensor(axis)) {
TF_LITE_KERNEL_LOG(context,
"DynamicTensor is not yet supported by Expand_Dims.");
MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
return kTfLiteError;
}
TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
@@ -135,8 +132,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input), flat_size);
} break;
default:
TF_LITE_KERNEL_LOG(
context,
MicroPrintf(
"Expand_Dims only currently supports int8 and float32, got %d.",
input->type);
return kTfLiteError;

View File

@@ -53,9 +53,8 @@ TfLiteStatus EnsureEq(TfLiteContext* context, const TfLiteIntArray* array,
case kTfLiteInt64:
return EnsureEqImpl<int64_t>(context, array, tensor);
default:
TF_LITE_KERNEL_LOG(context,
"cannot compare int array to tensor of type %d.",
tensor->type);
MicroPrintf("cannot compare int array to tensor of type %d.",
tensor->type);
return kTfLiteError;
}
}
@@ -123,9 +122,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
FillImpl<int8_t>(value, output);
break;
default:
TF_LITE_KERNEL_LOG(
context, "Fill only currently supports float32 for input 1, got %d.",
TfLiteTypeGetName(value->type));
MicroPrintf("Fill only currently supports float32 for input 1, got %d.",
TfLiteTypeGetName(value->type));
return kTfLiteError;
}

View File

@@ -74,7 +74,7 @@ TfLiteStatus EvalFloorDiv(TfLiteContext* context,
// Validate the denominator.
for (int i = 0; i < tflite::ElementCount(*input2->dims); ++i) {
if (std::equal_to<T>()(denominator_data[i], 0)) {
TF_LITE_KERNEL_LOG(context, "Division by 0");
MicroPrintf("Division by 0");
return kTfLiteError;
}
}
@@ -113,8 +113,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalFloorDiv<float>(context, input1, input2, output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_DIV.",
TfLiteTypeGetName(input1->type));
MicroPrintf("Type '%s' is not supported by FLOOR_DIV.",
TfLiteTypeGetName(input1->type));
return kTfLiteError;
}
}

View File

@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_MOD.",
TfLiteTypeGetName(input1->type));
MicroPrintf("Type '%s' is not supported by FLOOR_MOD.",
TfLiteTypeGetName(input1->type));
return kTfLiteError;
}
}

View File

@@ -141,8 +141,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
}

View File

@@ -118,9 +118,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(coords->type));
MicroPrintf("Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(coords->type));
return kTfLiteError;
break;
}
@@ -134,8 +133,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}
@@ -207,8 +206,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<int8_t, int32_t>(params, input, coords, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}

View File

@@ -47,9 +47,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
MicroPrintf("Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
break;
}
@@ -57,9 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
@@ -67,22 +65,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int indices_rank = NumDimensions(indices);
const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
if (params_rank < 1) {
TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
MicroPrintf("Params must be at least a vector.");
return kTfLiteError;
}
if (indices_rank < 1) {
TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
MicroPrintf("Indices must be at least a vector.");
return kTfLiteError;
}
if (indices_nd > params_rank) {
TF_LITE_KERNEL_LOG(
context, "Index innermost dimension length must be <= params rank.");
MicroPrintf("Index innermost dimension length must be <= params rank.");
return kTfLiteError;
}
if (indices_nd > MAX_INDICES_ND) {
TF_LITE_KERNEL_LOG(context,
"Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
MicroPrintf("Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
return kTfLiteError;
}
@@ -171,13 +167,12 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
MicroPrintf("Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
MicroPrintf("gather_nd index out of bounds");
}
return status;
}
@@ -195,9 +190,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalGatherNd<int32_t>(context, params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
}

View File

@@ -106,5 +106,17 @@ TfLiteStatus KernelRunner::Invoke() {
return kTfLiteOk;
}
TfLiteStatus KernelRunner::Free() {
tflite::micro::ClearBufferApi(&context_);
context_.GetScratchBuffer = MicroContextGetScratchBuffer;
if (registration_.free == nullptr) {
MicroPrintf("TfLiteRegistration missing free function pointer!");
return kTfLiteError;
}
registration_.free(&context_, node_.user_data);
return kTfLiteOk;
}
} // namespace micro
} // namespace tflite
} // namespace tflite

View File

@@ -48,6 +48,11 @@ class KernelRunner {
// passed into the constructor of this class.
TfLiteStatus Invoke();
// Calls Free on a given TfLiteRegistration pointer(if it's implemented).
// After successful Free, kTfLiteOk status will be returned. If Free is not
// implemented for a given kernel kTfLiteError will be returned.
TfLiteStatus Free();
// Returns a pointer to the internal MockMicroGraph which KernelRunner uses
// to stub out MicroGraph methods and track invocations on each subgraph.
MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; }

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace micro {
@@ -39,9 +40,10 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*free)(TfLiteContext* context, void* buffer)) {
return {/*init=*/init,
/*free=*/nullptr,
/*free=*/free,
/*prepare=*/prepare,
/*invoke=*/invoke,
/*profiling_string=*/nullptr,
@@ -160,6 +162,46 @@ TfLiteStatus CopyOpInputsToOpOutputs(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
// Args:
// 1. int8_t tensor_data - int8_t buffer of unknown size who's data you'd
// like
// to print
// 2. int n_btyes - a small int representing number of bytes you want to
// print
// to debug output. It should always be <= tensor_data's size.
// 3. prefix - optional message you'd like to print before printing bytes
//
// Purpose:
// Function takes in paramaters above and prints n_bytes bytes from the
// tensor_data buffer. This can be use to debug the output of a model and it's
// op.
void PrintNBytes(const int8_t* tensor_data, int n_bytes, const char* prefix) {
if (prefix != nullptr) {
MicroPrintf("%s", prefix);
}
for (int i = 0; i < n_bytes; ++i) {
MicroPrintf(" %x", tensor_data[i]);
}
MicroPrintf("\n");
}
// same as the PrintNBytes above but the buffer needs to be extracted out of the
// TfLiteEvalTensor*
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
const char* prefix) {
const int8_t* tensor_data = tflite::micro::GetTensorData<int8_t>(tensor);
PrintNBytes(tensor_data, n_bytes, prefix);
}
// same as the PrintNBytes above but the buffer needs to be extracted out of the
// TfLiteEvalTensor*
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes, const char* prefix) {
const int8_t* tensor_data = tflite::GetTensorData<int8_t>(tensor);
PrintNBytes(tensor_data, n_bytes, prefix);
}
TfLiteStatus CopyOpInputsToSubgraphInputs(TfLiteContext* context,
TfLiteNode* node,
MicroGraph* graph_info,

View File

@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace micro {
@@ -30,7 +32,20 @@ namespace micro {
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*free)(TfLiteContext* context, void* buffer) = nullptr);
// Prints out n bytes in a int8_t buffer as hex
void PrintNBytes(const int8_t* tensor_data, int n_bytes,
const char* prefix = nullptr);
// Prints out the the n bytes in a TfLiteEvalTensor as hex
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
const char* prefix = nullptr);
// Prints out the the n bytes in a TfLiteTensor as hex
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes,
const char* prefix = nullptr);
// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.

View File

@@ -125,9 +125,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
L2EvalFloat(*params, *input, &op_params, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"L2_POOL_2D only supports float32 currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("L2_POOL_2D only supports float32 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -126,8 +126,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
MicroPrintf("Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@@ -132,9 +132,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context,
"LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -530,11 +532,20 @@ void CalculateLstmGateInteger8x8_16(
// Apply activation
switch (activation) {
case kTfLiteActSigmoid:
micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
break;
case kTfLiteActTanh:
micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
reference_integer_ops::Logistic(
0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
n_batch * n_cell /*NumElements(input->dims)*/,
gate /* tflite::micro::GetTensorData<int16_t>(input) */,
gate /*tflite::micro::GetTensorData<int16_t>(output) */);
break;
case kTfLiteActTanh: {
int32_t dims_data = n_batch * n_cell;
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
reference_integer_ops::Tanh(0, 0, tanh_inp_shape, gate, tanh_inp_shape,
gate);
} break;
default:
// Only Sigmoid or Tanh is used.
TFLITE_ASSERT_FALSE;
@@ -599,7 +610,7 @@ void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
// - scratch1: scratch area of size n_batch*n_cell
// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
int n_batch, int n_cell, int n_output, int16_t* cell_state,
int32_t cell_state_scale, const int16_t* output_gate,
int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
const int8_t* projection_weights, int32_t proj_scale_a,
@@ -607,8 +618,23 @@ void CalculateLstmOutputInteger8x8_16(
int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
// Note: unlike float/hybrid, the activation is always Tanh.
micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
n_cell, scratch0);
{
int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3;
int32_t dims_data = n_batch * n_cell;
if (tanh_input_left_shift < 0) /* handling negative shift value */
{
int32_t i;
tanh_input_left_shift = -tanh_input_left_shift;
for (i = 0; i < dims_data; i++) {
cell_state[i] = cell_state[i] >> tanh_input_left_shift;
}
tanh_input_left_shift = 0;
}
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
reference_integer_ops::Tanh(0, tanh_input_left_shift, tanh_inp_shape,
cell_state, tanh_inp_shape, scratch0);
}
micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
hidden_scale_b, n_batch, n_cell, hidden_zp,
scratch1);

View File

@@ -98,15 +98,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLiteOperation<int64_t, OpType>(context, node, op_context);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Type %s (%d) is not supported by Maximum/Minimum.",
TfLiteTypeGetName(op_context.output->type),
op_context.output->type);
MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
TfLiteTypeGetName(op_context.output->type),
op_context.output->type);
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context,
"Kernel type not supported by Maximum/Minimum.");
MicroPrintf("Kernel type not supported by Maximum/Minimum.");
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -72,6 +72,7 @@ TfLiteRegistration Register_READ_VARIABLE();
TfLiteRegistration Register_RELU();
TfLiteRegistration Register_RELU6();
TfLiteRegistration Register_RESIZE_BILINEAR();
TfLiteRegistration Register_SELECT_V2();
TfLiteRegistration Register_SHAPE();
TfLiteRegistration Register_SLICE();
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
@@ -79,6 +80,7 @@ TfLiteRegistration Register_SPACE_TO_DEPTH();
TfLiteRegistration Register_SQUARED_DIFFERENCE();
TfLiteRegistration Register_SQUEEZE();
TfLiteRegistration Register_SUB();
TfLiteRegistration Register_SUM();
TfLiteRegistration Register_SVDF();
TfLiteRegistration Register_TRANSPOSE();
TfLiteRegistration Register_TRANSPOSE_CONV();

View File

@@ -663,7 +663,7 @@ void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
const int16_t b = input_2[index];
int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
value -= output_zp;
value += output_zp;
value = std::min(std::max(static_cast<int32_t>(-128), value),
static_cast<int32_t>(127));

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -60,6 +60,15 @@ void EvalMulFloatReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output);
// Generic must define registration function.
TfLiteRegistration Register_MUL();
#if defined(CMSIS_NN)
TfLiteRegistration Register_MUL_INT8();
#else
// Fallback registration
inline TfLiteRegistration Register_MUL_INT8() { return Register_MUL(); }
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_MUL_H_

View File

@@ -41,8 +41,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<float>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
data->axis);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
TfLiteTypeGetName(output->type));
MicroPrintf("Type '%s' is not supported by pack.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}

View File

@@ -213,8 +213,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type %s not currently supported by Pad.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -45,8 +45,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
AveragePoolingEvalQuantized(context, node, params, data, input, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
MicroPrintf("Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
@@ -73,8 +73,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
MaxPoolingEvalQuantized(context, node, params, data, input, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
namespace tflite {
@@ -66,6 +67,19 @@ void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input,
TfLiteEvalTensor* output);
#if defined(CMSIS_NN)
TfLiteRegistration Register_AVERAGE_POOL_2D_INT8();
TfLiteRegistration Register_MAX_POOL_2D_INT8();
#else
inline TfLiteRegistration Register_AVERAGE_POOL_2D_INT8() {
return tflite::Register_AVERAGE_POOL_2D();
}
inline TfLiteRegistration Register_MAX_POOL_2D_INT8() {
return tflite::Register_MAX_POOL_2D();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_POOLING_H_

View File

@@ -61,9 +61,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
TF_LITE_KERNEL_LOG(
context, "Only float32 and uint8_t are supported currently, got %d.",
TfLiteTypeGetName(input->type));
MicroPrintf("Only float32 and uint8_t are supported currently, got %d.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -28,7 +28,7 @@ limitations under the License.
namespace tflite {
const int kMaxNumberOfAxis = 4;
const int kMaxNumberOfAxis = 5;
const int kMaxNumberOfReducedAxis = 2;
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,

View File

@@ -55,8 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
if (params->half_pixel_centers && params->align_corners) {
TF_LITE_KERNEL_LOG(
context, "If half_pixel_centers is True, align_corners must be False.");
MicroPrintf("If half_pixel_centers is True, align_corners must be False.");
return kTfLiteError;
}
@@ -100,8 +99,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float or int8.",
output->type);
MicroPrintf("Output type is %d, requires float or int8.", output->type);
return kTfLiteError;
}

View File

@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace ops {
@@ -55,7 +54,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output->type = input->type;
if (!IsConstantTensor(size)) {
TF_LITE_KERNEL_LOG(context, "Dynamic tensors are unsupported in tfmicro.");
MicroPrintf("Dynamic tensors are unsupported in tfmicro.");
return kTfLiteError;
}

View File

@@ -0,0 +1,196 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/select.h"
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
constexpr int kInputTensorCondition = 0;
constexpr int kInputTensorX = 1;
constexpr int kInputTensorY = 2;
constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
// True if input condition is scalar or input condition has rank one and
// matches the first dimension of other inputs.
bool has_low_rank_input_condition;
};
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* data = static_cast<OpData*>(
context->AllocatePersistentBuffer(context, sizeof(OpData)));
data->requires_broadcast = false;
data->has_low_rank_input_condition = false;
return data;
}
TfLiteStatus CheckBroadcastShape(TfLiteContext* context,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
const TfLiteTensor* input3,
const TfLiteIntArray* output_shape) {
const int dims1 = NumDimensions(input1);
const int dims2 = NumDimensions(input2);
const int dims3 = NumDimensions(input3);
const int out_dims = std::max(std::max(dims1, dims2), dims3);
TF_LITE_ENSURE_EQ(context, out_dims, output_shape->size);
for (int i = 0; i < out_dims; ++i) {
const int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
const int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
const int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
const int min_value = std::min(std::min(d1, d2), d3);
int max_value = std::max(std::max(d1, d2), d3);
// If one dimention is 0, others must be 0 or 1.
if (min_value == 0) max_value = 0;
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
!(d3 == 1 || d3 == max_value)) {
MicroPrintf("Given shapes are not broadcastable.");
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, output_shape->data[out_dims - i - 1], max_value);
}
return kTfLiteOk;
}
TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input_condition =
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
TfLiteTensor* input_x =
micro_context->AllocateTempInputTensor(node, kInputTensorX);
TfLiteTensor* input_y =
micro_context->AllocateTempInputTensor(node, kInputTensorY);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
// Input must be bool.
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type);
output->type = input_x->type;
// Respect the original output shape when there are mixed shapes to represent
// a scalar data.
if (GetTensorShape(input_condition).FlatSize() == 1 &&
GetTensorShape(input_x).FlatSize() == 1 &&
GetTensorShape(input_y).FlatSize() == 1 &&
GetTensorShape(output).FlatSize() == 1) {
return kTfLiteOk;
}
bool same_shape = HaveSameShapes(input_condition, input_x) &&
HaveSameShapes(input_x, input_y);
if (!same_shape) {
TF_LITE_ENSURE_OK(
context, CheckBroadcastShape(context, input_condition, input_x, input_y,
output->dims));
data->requires_broadcast = true;
}
micro_context->DeallocateTempTfLiteTensor(input_condition);
micro_context->DeallocateTempTfLiteTensor(input_x);
micro_context->DeallocateTempTfLiteTensor(input_y);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input_condition =
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
TfLiteTensor* input_x =
micro_context->AllocateTempInputTensor(node, kInputTensorX);
TfLiteTensor* input_y =
micro_context->AllocateTempInputTensor(node, kInputTensorY);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
#define TF_LITE_SELECT(type, op) \
reference_ops::op(GetTensorShape(input_condition), \
GetTensorData<bool>(input_condition), \
GetTensorShape(input_x), GetTensorData<type>(input_x), \
GetTensorShape(input_y), GetTensorData<type>(input_y), \
GetTensorShape(output), GetTensorData<type>(output));
#define TF_LITE_SWITCH(type, op) \
switch (type) { \
case kTfLiteFloat32: \
TF_LITE_SELECT(float, op); \
break; \
case kTfLiteInt8: \
TF_LITE_SELECT(int8_t, op); \
break; \
case kTfLiteInt16: \
TF_LITE_SELECT(int16_t, op); \
break; \
default: \
MicroPrintf("Does not support type other than %s, but got %s", \
"int8|int16|float32", TfLiteTypeGetName(type)); \
return kTfLiteError; \
}
if (data->has_low_rank_input_condition) {
MicroPrintf("Not yet implemented.");
return kTfLiteError;
} else if (data->requires_broadcast) {
TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow);
} else {
TF_LITE_SWITCH(input_x->type, Select);
}
#undef TF_LITE_SELECT
#undef TF_LITE_SWITCH
micro_context->DeallocateTempTfLiteTensor(input_condition);
micro_context->DeallocateTempTfLiteTensor(input_x);
micro_context->DeallocateTempTfLiteTensor(input_y);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
// SelectV2 op selects values of 'x' if the corresponding value of 'condition'
// is true or the value of 'y' if false. There are valid condition input sizes:
//
// 1. Either the same shape (in which case the select is elementwise), or
// 2. Broadcastable shapes between 'condition', 'x' and 'y'.
TfLiteRegistration Register_SELECT_V2() {
return tflite::micro::RegisterOp(tflite::SelectInit, tflite::SelectPrepare,
tflite::SelectEval);
}
} // namespace tflite

View File

@@ -47,8 +47,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (output->type != kTfLiteInt32) {
TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
MicroPrintf("Output type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
} else {
ExtractShape(input, tflite::micro::GetTensorData<int32_t>(output));

View File

@@ -106,8 +106,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetBeginAndSizeVectors<int64_t>(input->dims->size, begin, size,
op_params.begin, op_params.size);
} else {
TF_LITE_KERNEL_LOG(context, "Begin tensor type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Begin tensor type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}

View File

@@ -75,8 +75,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
}

View File

@@ -104,8 +104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -109,9 +109,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(
context, "SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}

View File

@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return SplitImpl<int32_t>(context, node, input, axis_value);
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type %s currently not supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -90,8 +90,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
if (input->type == kTfLiteString) {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}

View File

@@ -183,9 +183,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
break;
case kTfLiteBool:
reference_ops::StridedSlice(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<bool>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -82,7 +82,7 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node);
// (reference or optimized) must define this function.
TfLiteRegistration Register_SVDF();
#if defined(HEXAGON)
#if defined(HEXAGON) || defined(CMSIS_NN)
TfLiteRegistration Register_SVDF_INT8();
#else

View File

@@ -185,9 +185,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type), context);
return kTfLiteError;
}
}

View File

@@ -103,10 +103,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context,
"Type %s is currently not supported by Transpose. "
"Only float32 and int8 is supported",
TfLiteTypeGetName(input->type));
MicroPrintf(
"Type %s is currently not supported by Transpose. "
"Only float32 and int8 is supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}

View File

@@ -327,8 +327,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -24,7 +24,7 @@ namespace testing {
// kernel is reconciled with reference kernel
#if !defined(XTENSA)
typedef struct LstmIntegerTestConfig {
struct LstmIntegerTestConfig {
const int n_batch;
const int n_input;
const int n_cell;
@@ -100,9 +100,9 @@ typedef struct LstmIntegerTestConfig {
bool asymmetric_quantize_inputs;
const float ranges[25][2];
} LstmIntegerTestConfig;
};
typedef struct LstmFloatTestConfig {
struct LstmFloatTestConfig {
const int n_batch;
const int n_input;
const int n_cell;
@@ -153,9 +153,9 @@ typedef struct LstmFloatTestConfig {
float* output;
const float* expected_output_original;
float* expected_output;
} LstmFloatTestConfig;
};
typedef struct LstmWeightQuantizationBuffers {
struct LstmWeightQuantizationBuffers {
int8_t* lstm_i2i_quant;
float* lstm_i2i_scale;
int* lstm_i2i_zp;
@@ -215,7 +215,7 @@ typedef struct LstmWeightQuantizationBuffers {
float* lstm_proj_w_scale;
int* lstm_proj_w_zp;
TfLiteAffineQuantization* lstm_proj_w_qparam;
} LstmWeightQuantizationBuffers;
};
extern LstmIntegerTestConfig lstm_integer_no_peephole_config;

View File

@@ -91,8 +91,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by unpack.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -70,10 +70,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
resetZeros(tflite::micro::GetTensorData<float>(output), flat_size);
break;
default:
TF_LITE_KERNEL_LOG(context,
"ZerosLike only currently supports int64, int32, "
"and float32, got %d.",
input->type);
MicroPrintf(
"ZerosLike only currently supports int64, int32, "
"and float32, got %d.",
input->type);
return kTfLiteError;
}
return kTfLiteOk;