Rolling 20220924

This commit is contained in:
jomjol
2022-09-24 21:24:50 +02:00
parent a1691a77cf
commit 68e57d5ec4
133 changed files with 5485 additions and 1810 deletions

View File

@@ -92,6 +92,7 @@ AllOpsResolver::AllOpsResolver() {
AddResizeNearestNeighbor();
AddRound();
AddRsqrt();
AddSelectV2();
AddShape();
AddSin();
AddSlice();
@@ -102,6 +103,7 @@ AllOpsResolver::AllOpsResolver() {
AddSplitV();
AddSqrt();
AddSquare();
AddSquaredDifference();
AddSqueeze();
AddStridedSlice();
AddSub();
@@ -110,6 +112,7 @@ AllOpsResolver::AllOpsResolver() {
AddTanh();
AddTranspose();
AddTransposeConv();
AddUnidirectionalSequenceLSTM();
AddUnpack();
AddVarHandle();
AddWhile();
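The three ops added above land in AllOpsResolver, which links every kernel into the binary. A minimal sketch of the selective alternative, using the MicroMutableOpResolver methods that this same commit introduces for SelectV2 and SquaredDifference:

// Registers only the two new non-LSTM ops (capacity template arg = 2).
tflite::MicroMutableOpResolver<2> resolver;
resolver.AddSelectV2();           // maps BuiltinOperator_SELECT_V2
resolver.AddSquaredDifference();  // maps BuiltinOperator_SQUARED_DIFFERENCE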

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -59,6 +59,19 @@ TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params,
TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node);
// Generic must define registration function.
TfLiteRegistration Register_ADD();
#if defined(CMSIS_NN)
TfLiteRegistration Register_ADD_INT8();
TfLiteRegistration Register_ADD_INT16();
#else
// Fallback registration
inline TfLiteRegistration Register_ADD_INT8() { return Register_ADD(); }
inline TfLiteRegistration Register_ADD_INT16() { return Register_ADD(); }
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_ADD_H_
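The #if block above is the variant-registration pattern this commit rolls out across kernels: optimized builds export real INT8/INT16 entry points, while all other builds alias them to the generic kernel. A sketch of why call sites stay portable:

// Valid on every target: with CMSIS_NN this is the int8-specialized kernel,
// otherwise the inline fallback above simply returns Register_ADD().
const TfLiteRegistration add_int8 = tflite::Register_ADD_INT8();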

View File

@@ -121,8 +121,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
context, kTfLiteActNone, output, &data->output_activation_min,
&data->output_activation_max));
} else {
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
@@ -198,8 +198,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} else if (output->type == kTfLiteInt8) {
EvalAddNQuantized<int8_t>(context, node, output);
} else {
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
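This hunk is the template for the many conversions that follow: TF_LITE_KERNEL_LOG reports through the TfLiteContext's error reporter, while MicroPrintf (declared in micro_error_reporter.h, per the includes added elsewhere in this commit) is a free printf-style function, so the context argument is dropped. Schematically:

// Before: TF_LITE_KERNEL_LOG(context, "Type %s not supported.", name);
// After:  MicroPrintf("Type %s not supported.", name);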

View File

@@ -70,21 +70,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf(
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context,
"Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context, "Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
return kTfLiteError;
}

View File

@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -90,7 +90,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}

View File

@@ -118,8 +118,8 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -210,8 +210,8 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -288,8 +288,8 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -366,8 +366,8 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -444,8 +444,8 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -522,8 +522,8 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteInt8 ||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64);
input_type == kTfLiteInt64 || input_type == kTfLiteBool);
// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
@@ -149,8 +149,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int num_dimensions = NumDimensions(input);
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
TF_LITE_KERNEL_LOG(
context,
MicroPrintf(
"Op Concatenation does not currently support num dimensions > %d "
"Tensor has %d dimensions.",
RuntimeShape::kMaxSmallSize, num_dimensions);
@@ -168,6 +167,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
switch (output_type) { // Already know in/out types are the same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt16:
case kTfLiteInt32:
@@ -205,9 +205,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(
context, "Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}
@@ -238,11 +237,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16:
EvalUnquantized<int16_t>(context, node);
break;
case kTfLiteBool:
EvalUnquantized<bool>(context, node);
break;
default:
TF_LITE_KERNEL_LOG(
context, "Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}

View File

@@ -123,7 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis < 0) axis += input_shape.DimensionsCount();
if (axis < 0 || axis >= input_shape.DimensionsCount()) {
TF_LITE_KERNEL_LOG(context, "CUMSUM Invalid axis: %d", axis);
MicroPrintf("CUMSUM Invalid axis: %d", axis);
return kTfLiteError;
}
@@ -156,9 +156,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} break;
default: {
TF_LITE_KERNEL_LOG(context,
"CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}

View File

@@ -124,9 +124,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(
context, "DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@@ -82,8 +82,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -162,8 +162,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
#undef TF_LITE_DIV
} else {
TF_LITE_KERNEL_LOG(
context, "Unsupported combination of input and output types in DIV.");
MicroPrintf("Unsupported combination of input and output types in DIV.");
return kTfLiteError;
}
@@ -189,10 +188,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
input1, input2, output));
} else {
TF_LITE_KERNEL_LOG(context,
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
MicroPrintf(
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}

View File

@@ -90,8 +90,8 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
@@ -112,8 +112,8 @@ TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
@@ -317,8 +317,8 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
type);
break;
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
MicroPrintf("Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
break;
}
@@ -355,8 +355,8 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
elementwise::validate_input_func, type);
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
MicroPrintf("Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}
@@ -426,4 +426,4 @@ TfLiteRegistration Register_LOGICAL_NOT() {
} // namespace micro
} // namespace ops
} // namespace tflite
} // namespace tflite

View File

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -136,9 +135,8 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(
context, "ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN

View File

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN

View File

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -63,8 +64,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
static_cast<size_t>(flat_size),
tflite::micro::GetTensorData<float>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) currently not supported by Exp.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) currently not supported by Exp.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -31,8 +31,7 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
int32_t* axis_value) {
const int axis_dims = (tflite::GetTensorShape(axis)).DimensionsCount();
if (axis_dims > 1) {
TF_LITE_KERNEL_LOG(context, "Axis has only one element for Expand_Dims.",
axis_dims);
MicroPrintf("Axis has only one element for Expand_Dims.", axis_dims);
return kTfLiteError;
}
@@ -41,9 +40,8 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
*axis_value = axis_ptr[0];
return kTfLiteOk;
} else {
TF_LITE_KERNEL_LOG(context,
"Axis type %s (%d) not supported by Expand_Dims.",
TfLiteTypeGetName(axis->type), axis->type);
MicroPrintf("Axis type %s (%d) not supported by Expand_Dims.",
TfLiteTypeGetName(axis->type), axis->type);
return kTfLiteError;
}
}
@@ -99,8 +97,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
output->type = input->type;
if (IsDynamicTensor(axis)) {
TF_LITE_KERNEL_LOG(context,
"DynamicTensor is not yet supported by Expand_Dims.");
MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
return kTfLiteError;
}
TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
@@ -135,8 +132,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input), flat_size);
} break;
default:
TF_LITE_KERNEL_LOG(
context,
MicroPrintf(
"Expand_Dims only currently supports int8 and float32, got %d.",
input->type);
return kTfLiteError;

View File

@@ -53,9 +53,8 @@ TfLiteStatus EnsureEq(TfLiteContext* context, const TfLiteIntArray* array,
case kTfLiteInt64:
return EnsureEqImpl<int64_t>(context, array, tensor);
default:
TF_LITE_KERNEL_LOG(context,
"cannot compare int array to tensor of type %d.",
tensor->type);
MicroPrintf("cannot compare int array to tensor of type %d.",
tensor->type);
return kTfLiteError;
}
}
@@ -123,9 +122,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
FillImpl<int8_t>(value, output);
break;
default:
TF_LITE_KERNEL_LOG(
context, "Fill only currently supports float32 for input 1, got %d.",
TfLiteTypeGetName(value->type));
MicroPrintf("Fill only currently supports float32 for input 1, got %d.",
TfLiteTypeGetName(value->type));
return kTfLiteError;
}

View File

@@ -74,7 +74,7 @@ TfLiteStatus EvalFloorDiv(TfLiteContext* context,
// Validate the denominator.
for (int i = 0; i < tflite::ElementCount(*input2->dims); ++i) {
if (std::equal_to<T>()(denominator_data[i], 0)) {
TF_LITE_KERNEL_LOG(context, "Division by 0");
MicroPrintf("Division by 0");
return kTfLiteError;
}
}
@@ -113,8 +113,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalFloorDiv<float>(context, input1, input2, output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_DIV.",
TfLiteTypeGetName(input1->type));
MicroPrintf("Type '%s' is not supported by FLOOR_DIV.",
TfLiteTypeGetName(input1->type));
return kTfLiteError;
}
}

View File

@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_MOD.",
TfLiteTypeGetName(input1->type));
MicroPrintf("Type '%s' is not supported by FLOOR_MOD.",
TfLiteTypeGetName(input1->type));
return kTfLiteError;
}
}

View File

@@ -141,8 +141,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
}

View File

@@ -118,9 +118,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(coords->type));
MicroPrintf("Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(coords->type));
return kTfLiteError;
break;
}
@@ -134,8 +133,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}
@@ -207,8 +206,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<int8_t, int32_t>(params, input, coords, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}

View File

@@ -47,9 +47,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
MicroPrintf("Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
break;
}
@@ -57,9 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
@@ -67,22 +65,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int indices_rank = NumDimensions(indices);
const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
if (params_rank < 1) {
TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
MicroPrintf("Params must be at least a vector.");
return kTfLiteError;
}
if (indices_rank < 1) {
TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
MicroPrintf("Indices must be at least a vector.");
return kTfLiteError;
}
if (indices_nd > params_rank) {
TF_LITE_KERNEL_LOG(
context, "Index innermost dimension length must be <= params rank.");
MicroPrintf("Index innermost dimension length must be <= params rank.");
return kTfLiteError;
}
if (indices_nd > MAX_INDICES_ND) {
TF_LITE_KERNEL_LOG(context,
"Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
MicroPrintf("Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
return kTfLiteError;
}
@@ -171,13 +167,12 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
MicroPrintf("Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
MicroPrintf("gather_nd index out of bounds");
}
return status;
}
@@ -195,9 +190,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalGatherNd<int32_t>(context, params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
}

View File

@@ -106,5 +106,17 @@ TfLiteStatus KernelRunner::Invoke() {
return kTfLiteOk;
}
TfLiteStatus KernelRunner::Free() {
tflite::micro::ClearBufferApi(&context_);
context_.GetScratchBuffer = MicroContextGetScratchBuffer;
if (registration_.free == nullptr) {
MicroPrintf("TfLiteRegistration missing free function pointer!");
return kTfLiteError;
}
registration_.free(&context_, node_.user_data);
return kTfLiteOk;
}
} // namespace micro
} // namespace tflite
} // namespace tflite

View File

@@ -48,6 +48,11 @@ class KernelRunner {
// passed into the constructor of this class.
TfLiteStatus Invoke();
// Calls Free on a given TfLiteRegistration pointer (if it's implemented).
// After a successful Free, kTfLiteOk status will be returned. If Free is not
// implemented for a given kernel, kTfLiteError will be returned.
TfLiteStatus Free();
// Returns a pointer to the internal MockMicroGraph which KernelRunner uses
// to stub out MicroGraph methods and track invocations on each subgraph.
MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; }

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace micro {
@@ -39,9 +40,10 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*free)(TfLiteContext* context, void* buffer)) {
return {/*init=*/init,
/*free=*/nullptr,
/*free=*/free,
/*prepare=*/prepare,
/*invoke=*/invoke,
/*profiling_string=*/nullptr,
@@ -160,6 +162,46 @@ TfLiteStatus CopyOpInputsToOpOutputs(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
// Args:
// 1. int8_t tensor_data - int8_t buffer of unknown size whose data you'd
// like to print
// 2. int n_bytes - a small int representing the number of bytes you want to
// print to debug output. It should always be <= tensor_data's size.
// 3. prefix - optional message you'd like to print before printing bytes
//
// Purpose:
// Function takes in the parameters above and prints n_bytes bytes from the
// tensor_data buffer. This can be used to debug the output of a model and its
// ops.
void PrintNBytes(const int8_t* tensor_data, int n_bytes, const char* prefix) {
if (prefix != nullptr) {
MicroPrintf("%s", prefix);
}
for (int i = 0; i < n_bytes; ++i) {
MicroPrintf(" %x", tensor_data[i]);
}
MicroPrintf("\n");
}
// same as the PrintNBytes above but the buffer needs to be extracted out of the
// TfLiteEvalTensor*
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
const char* prefix) {
const int8_t* tensor_data = tflite::micro::GetTensorData<int8_t>(tensor);
PrintNBytes(tensor_data, n_bytes, prefix);
}
// same as the PrintNBytes above but the buffer needs to be extracted out of
// the TfLiteTensor*
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes, const char* prefix) {
const int8_t* tensor_data = tflite::GetTensorData<int8_t>(tensor);
PrintNBytes(tensor_data, n_bytes, prefix);
}
TfLiteStatus CopyOpInputsToSubgraphInputs(TfLiteContext* context,
TfLiteNode* node,
MicroGraph* graph_info,
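A debugging sketch for the new helpers, assuming an op's Eval with access to its output tensor:

const TfLiteEvalTensor* out =
    tflite::micro::GetEvalOutput(context, node, /*index=*/0);
// Hex-dump the first 8 bytes of the output, prefixed for grepping logs:
tflite::micro::PrintNBytes(out, 8, "op output:");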

View File

@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace micro {
@@ -30,7 +32,20 @@ namespace micro {
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*free)(TfLiteContext* context, void* buffer) = nullptr);
// Prints out n bytes in an int8_t buffer as hex
void PrintNBytes(const int8_t* tensor_data, int n_bytes,
const char* prefix = nullptr);
// Prints out the n bytes in a TfLiteEvalTensor as hex
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
const char* prefix = nullptr);
// Prints out the n bytes in a TfLiteTensor as hex
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes,
const char* prefix = nullptr);
// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
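A sketch of wiring the new optional free callback through RegisterOp; MyOpInit/MyOpPrepare/MyOpInvoke/MyOpFree are hypothetical kernel hooks, not part of this commit:

namespace {
void MyOpFree(TfLiteContext* context, void* buffer) {
  // Release or reset any state cached by init/prepare.
}
}  // namespace

TfLiteRegistration Register_MY_OP() {
  // The fourth argument defaults to nullptr, so existing call sites
  // keep compiling unchanged.
  return tflite::micro::RegisterOp(MyOpInit, MyOpPrepare, MyOpInvoke,
                                   MyOpFree);
}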

View File

@@ -125,9 +125,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
L2EvalFloat(*params, *input, &op_params, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"L2_POOL_2D only supports float32 currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("L2_POOL_2D only supports float32 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -126,8 +126,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
MicroPrintf("Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@@ -132,9 +132,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context,
"LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -530,11 +532,20 @@ void CalculateLstmGateInteger8x8_16(
// Apply activation
switch (activation) {
case kTfLiteActSigmoid:
micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
break;
case kTfLiteActTanh:
micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
reference_integer_ops::Logistic(
0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
n_batch * n_cell /*NumElements(input->dims)*/,
gate /* tflite::micro::GetTensorData<int16_t>(input) */,
gate /*tflite::micro::GetTensorData<int16_t>(output) */);
break;
case kTfLiteActTanh: {
int32_t dims_data = n_batch * n_cell;
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
reference_integer_ops::Tanh(0, 0, tanh_inp_shape, gate, tanh_inp_shape,
gate);
} break;
default:
// Only Sigmoid or Tanh is used.
TFLITE_ASSERT_FALSE;
@@ -599,7 +610,7 @@ void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
// - scratch1: scratch area of size n_batch*n_cell
// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
int n_batch, int n_cell, int n_output, int16_t* cell_state,
int32_t cell_state_scale, const int16_t* output_gate,
int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
const int8_t* projection_weights, int32_t proj_scale_a,
@@ -607,8 +618,23 @@ void CalculateLstmOutputInteger8x8_16(
int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
// Note: unlike float/hybrid, the activation is always Tanh.
micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
n_cell, scratch0);
{
int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3;
int32_t dims_data = n_batch * n_cell;
if (tanh_input_left_shift < 0) /* handling negative shift value */
{
int32_t i;
tanh_input_left_shift = -tanh_input_left_shift;
for (i = 0; i < dims_data; i++) {
cell_state[i] = cell_state[i] >> tanh_input_left_shift;
}
tanh_input_left_shift = 0;
}
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
reference_integer_ops::Tanh(0, tanh_input_left_shift, tanh_inp_shape,
cell_state, tanh_inp_shape, scratch0);
}
micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
hidden_scale_b, n_batch, n_cell, hidden_zp,
scratch1);

View File

@@ -98,15 +98,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLiteOperation<int64_t, OpType>(context, node, op_context);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Type %s (%d) is not supported by Maximum/Minimum.",
TfLiteTypeGetName(op_context.output->type),
op_context.output->type);
MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
TfLiteTypeGetName(op_context.output->type),
op_context.output->type);
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context,
"Kernel type not supported by Maximum/Minimum.");
MicroPrintf("Kernel type not supported by Maximum/Minimum.");
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -72,6 +72,7 @@ TfLiteRegistration Register_READ_VARIABLE();
TfLiteRegistration Register_RELU();
TfLiteRegistration Register_RELU6();
TfLiteRegistration Register_RESIZE_BILINEAR();
TfLiteRegistration Register_SELECT_V2();
TfLiteRegistration Register_SHAPE();
TfLiteRegistration Register_SLICE();
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
@@ -79,6 +80,7 @@ TfLiteRegistration Register_SPACE_TO_DEPTH();
TfLiteRegistration Register_SQUARED_DIFFERENCE();
TfLiteRegistration Register_SQUEEZE();
TfLiteRegistration Register_SUB();
TfLiteRegistration Register_SUM();
TfLiteRegistration Register_SVDF();
TfLiteRegistration Register_TRANSPOSE();
TfLiteRegistration Register_TRANSPOSE_CONV();

View File

@@ -663,7 +663,7 @@ void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
const int16_t b = input_2[index];
int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
value -= output_zp;
value += output_zp;
value = std::min(std::max(static_cast<int32_t>(-128), value),
static_cast<int32_t>(127));

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -60,6 +60,15 @@ void EvalMulFloatReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output);
// Generic must define registration function.
TfLiteRegistration Register_MUL();
#if defined(CMSIS_NN)
TfLiteRegistration Register_MUL_INT8();
#else
// Fallback registration
inline TfLiteRegistration Register_MUL_INT8() { return Register_MUL(); }
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_MUL_H_

View File

@@ -41,8 +41,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<float>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
data->axis);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
TfLiteTypeGetName(output->type));
MicroPrintf("Type '%s' is not supported by pack.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}

View File

@@ -213,8 +213,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type %s not currently supported by Pad.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -45,8 +45,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
AveragePoolingEvalQuantized(context, node, params, data, input, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
MicroPrintf("Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
@@ -73,8 +73,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
MaxPoolingEvalQuantized(context, node, params, data, input, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
namespace tflite {
@@ -66,6 +67,19 @@ void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input,
TfLiteEvalTensor* output);
#if defined(CMSIS_NN)
TfLiteRegistration Register_AVERAGE_POOL_2D_INT8();
TfLiteRegistration Register_MAX_POOL_2D_INT8();
#else
inline TfLiteRegistration Register_AVERAGE_POOL_2D_INT8() {
return tflite::Register_AVERAGE_POOL_2D();
}
inline TfLiteRegistration Register_MAX_POOL_2D_INT8() {
return tflite::Register_MAX_POOL_2D();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_POOLING_H_

View File

@@ -61,9 +61,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
TF_LITE_KERNEL_LOG(
context, "Only float32 and uint8_t are supported currently, got %d.",
TfLiteTypeGetName(input->type));
MicroPrintf("Only float32 and uint8_t are supported currently, got %d.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -28,7 +28,7 @@ limitations under the License.
namespace tflite {
const int kMaxNumberOfAxis = 4;
const int kMaxNumberOfAxis = 5;
const int kMaxNumberOfReducedAxis = 2;
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,

View File

@@ -55,8 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
if (params->half_pixel_centers && params->align_corners) {
TF_LITE_KERNEL_LOG(
context, "If half_pixel_centers is True, align_corners must be False.");
MicroPrintf("If half_pixel_centers is True, align_corners must be False.");
return kTfLiteError;
}
@@ -100,8 +99,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float or int8.",
output->type);
MicroPrintf("Output type is %d, requires float or int8.", output->type);
return kTfLiteError;
}

View File

@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace ops {
@@ -55,7 +54,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output->type = input->type;
if (!IsConstantTensor(size)) {
TF_LITE_KERNEL_LOG(context, "Dynamic tensors are unsupported in tfmicro.");
MicroPrintf("Dynamic tensors are unsupported in tfmicro.");
return kTfLiteError;
}

View File

@@ -0,0 +1,196 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/select.h"
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
constexpr int kInputTensorCondition = 0;
constexpr int kInputTensorX = 1;
constexpr int kInputTensorY = 2;
constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
// True if input condition is scalar or input condition has rank one and
// matches the first dimension of other inputs.
bool has_low_rank_input_condition;
};
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* data = static_cast<OpData*>(
context->AllocatePersistentBuffer(context, sizeof(OpData)));
data->requires_broadcast = false;
data->has_low_rank_input_condition = false;
return data;
}
TfLiteStatus CheckBroadcastShape(TfLiteContext* context,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
const TfLiteTensor* input3,
const TfLiteIntArray* output_shape) {
const int dims1 = NumDimensions(input1);
const int dims2 = NumDimensions(input2);
const int dims3 = NumDimensions(input3);
const int out_dims = std::max(std::max(dims1, dims2), dims3);
TF_LITE_ENSURE_EQ(context, out_dims, output_shape->size);
for (int i = 0; i < out_dims; ++i) {
const int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
const int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
const int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
const int min_value = std::min(std::min(d1, d2), d3);
int max_value = std::max(std::max(d1, d2), d3);
// If one dimension is 0, others must be 0 or 1.
if (min_value == 0) max_value = 0;
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
!(d3 == 1 || d3 == max_value)) {
MicroPrintf("Given shapes are not broadcastable.");
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, output_shape->data[out_dims - i - 1], max_value);
}
return kTfLiteOk;
}
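// Worked example (illustrative shapes, not from a test): condition [1, 2, 1],
// x [2, 1, 4] and y [1, 1, 1] broadcast to an output of [2, 2, 4]. Walking
// from the innermost dimension, each output dim is the max of the three
// input dims, and every input dim must be 1 or equal to that max, otherwise
// the shapes are rejected as non-broadcastable.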
TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input_condition =
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
TfLiteTensor* input_x =
micro_context->AllocateTempInputTensor(node, kInputTensorX);
TfLiteTensor* input_y =
micro_context->AllocateTempInputTensor(node, kInputTensorY);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
// Input must be bool.
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type);
output->type = input_x->type;
// Respect the original output shape when there are mixed shapes to represent
// a scalar data.
if (GetTensorShape(input_condition).FlatSize() == 1 &&
GetTensorShape(input_x).FlatSize() == 1 &&
GetTensorShape(input_y).FlatSize() == 1 &&
GetTensorShape(output).FlatSize() == 1) {
return kTfLiteOk;
}
bool same_shape = HaveSameShapes(input_condition, input_x) &&
HaveSameShapes(input_x, input_y);
if (!same_shape) {
TF_LITE_ENSURE_OK(
context, CheckBroadcastShape(context, input_condition, input_x, input_y,
output->dims));
data->requires_broadcast = true;
}
micro_context->DeallocateTempTfLiteTensor(input_condition);
micro_context->DeallocateTempTfLiteTensor(input_x);
micro_context->DeallocateTempTfLiteTensor(input_y);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input_condition =
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
TfLiteTensor* input_x =
micro_context->AllocateTempInputTensor(node, kInputTensorX);
TfLiteTensor* input_y =
micro_context->AllocateTempInputTensor(node, kInputTensorY);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
#define TF_LITE_SELECT(type, op) \
reference_ops::op(GetTensorShape(input_condition), \
GetTensorData<bool>(input_condition), \
GetTensorShape(input_x), GetTensorData<type>(input_x), \
GetTensorShape(input_y), GetTensorData<type>(input_y), \
GetTensorShape(output), GetTensorData<type>(output));
#define TF_LITE_SWITCH(type, op) \
switch (type) { \
case kTfLiteFloat32: \
TF_LITE_SELECT(float, op); \
break; \
case kTfLiteInt8: \
TF_LITE_SELECT(int8_t, op); \
break; \
case kTfLiteInt16: \
TF_LITE_SELECT(int16_t, op); \
break; \
default: \
MicroPrintf("Does not support type other than %s, but got %s", \
"int8|int16|float32", TfLiteTypeGetName(type)); \
return kTfLiteError; \
}
if (data->has_low_rank_input_condition) {
MicroPrintf("Not yet implemented.");
return kTfLiteError;
} else if (data->requires_broadcast) {
TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow);
} else {
TF_LITE_SWITCH(input_x->type, Select);
}
#undef TF_LITE_SELECT
#undef TF_LITE_SWITCH
micro_context->DeallocateTempTfLiteTensor(input_condition);
micro_context->DeallocateTempTfLiteTensor(input_x);
micro_context->DeallocateTempTfLiteTensor(input_y);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
// SelectV2 op selects values of 'x' if the corresponding value of 'condition'
// is true or the value of 'y' if false. There are two valid condition input forms:
//
// 1. Either the same shape (in which case the select is elementwise), or
// 2. Broadcastable shapes between 'condition', 'x' and 'y'.
TfLiteRegistration Register_SELECT_V2() {
return tflite::micro::RegisterOp(tflite::SelectInit, tflite::SelectPrepare,
tflite::SelectEval);
}
} // namespace tflite

View File

@@ -47,8 +47,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (output->type != kTfLiteInt32) {
TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
MicroPrintf("Output type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
} else {
ExtractShape(input, tflite::micro::GetTensorData<int32_t>(output));

View File

@@ -106,8 +106,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetBeginAndSizeVectors<int64_t>(input->dims->size, begin, size,
op_params.begin, op_params.size);
} else {
TF_LITE_KERNEL_LOG(context, "Begin tensor type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Begin tensor type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}

View File

@@ -75,8 +75,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
}

View File

@@ -104,8 +104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -109,9 +109,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(
context, "SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}

View File

@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return SplitImpl<int32_t>(context, node, input, axis_value);
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type %s currently not supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -90,8 +90,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
if (input->type == kTfLiteString) {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}

View File

@@ -183,9 +183,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
break;
case kTfLiteBool:
reference_ops::StridedSlice(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<bool>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -82,7 +82,7 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node);
// (reference or optimized) must define this function.
TfLiteRegistration Register_SVDF();
#if defined(HEXAGON)
#if defined(HEXAGON) || defined(CMSIS_NN)
TfLiteRegistration Register_SVDF_INT8();
#else

View File

@@ -185,9 +185,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type), context);
return kTfLiteError;
}
}

View File

@@ -103,10 +103,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context,
"Type %s is currently not supported by Transpose. "
"Only float32 and int8 is supported",
TfLiteTypeGetName(input->type));
MicroPrintf(
"Type %s is currently not supported by Transpose. "
"Only float32 and int8 is supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}

View File

@@ -327,8 +327,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -24,7 +24,7 @@ namespace testing {
// kernel is reconciled with reference kernel
#if !defined(XTENSA)
typedef struct LstmIntegerTestConfig {
struct LstmIntegerTestConfig {
const int n_batch;
const int n_input;
const int n_cell;
@@ -100,9 +100,9 @@ typedef struct LstmIntegerTestConfig {
bool asymmetric_quantize_inputs;
const float ranges[25][2];
} LstmIntegerTestConfig;
};
typedef struct LstmFloatTestConfig {
struct LstmFloatTestConfig {
const int n_batch;
const int n_input;
const int n_cell;
@@ -153,9 +153,9 @@ typedef struct LstmFloatTestConfig {
float* output;
const float* expected_output_original;
float* expected_output;
} LstmFloatTestConfig;
};
typedef struct LstmWeightQuantizationBuffers {
struct LstmWeightQuantizationBuffers {
int8_t* lstm_i2i_quant;
float* lstm_i2i_scale;
int* lstm_i2i_zp;
@@ -215,7 +215,7 @@ typedef struct LstmWeightQuantizationBuffers {
float* lstm_proj_w_scale;
int* lstm_proj_w_zp;
TfLiteAffineQuantization* lstm_proj_w_qparam;
} LstmWeightQuantizationBuffers;
};
extern LstmIntegerTestConfig lstm_integer_no_peephole_config;

View File

@@ -91,8 +91,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by unpack.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@@ -70,10 +70,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
resetZeros(tflite::micro::GetTensorData<float>(output), flat_size);
break;
default:
TF_LITE_KERNEL_LOG(context,
"ZerosLike only currently supports int64, int32, "
"and float32, got %d.",
input->type);
MicroPrintf(
"ZerosLike only currently supports int64, int32, "
"and float32, got %d.",
input->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@@ -323,8 +323,12 @@ TfLiteStatus AllocationInfoBuilder::GetOfflinePlannedOffsets(
if (model_->metadata()) {
for (size_t i = 0; i < model_->metadata()->size(); ++i) {
auto metadata = model_->metadata()->Get(i);
if (strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
strlen(kOfflineMemAllocMetadata)) == 0) {
const size_t metadata_name_size = (size_t)metadata->name()->size();
if ((strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
std::min(metadata_name_size,
strlen(kOfflineMemAllocMetadata))) == 0) &&
metadata_name_size == strlen(kOfflineMemAllocMetadata)) {
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
model_->buffers();
auto* buffer = (*buffers)[metadata->buffer()];

View File

@@ -509,14 +509,15 @@ TfLiteStatus MicroAllocator::FinishModelAllocation(
return kTfLiteError;
}
// Allocate scratch buffer metadata and buffers for variable tensors.
// Allocate scratch buffer metadata.
TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles(
scratch_buffer_handles, scratch_buffer_request_count_));
// Allocate buffers for variable tensors.
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
subgraph_idx++) {
const SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx);
TFLITE_DCHECK(subgraph != nullptr);
TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles(
scratch_buffer_handles, scratch_buffer_request_count_));
TF_LITE_ENSURE_STATUS(AllocateVariables(
subgraph, subgraph_allocations[subgraph_idx].tensors));
}
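Scratch buffer handles form a single arena-wide array indexed by request, not a per-subgraph structure, so the old placement inside the loop repeated the handle allocation once per subgraph; hoisting it above the loop performs it exactly once, while variable-tensor buffers remain allocated per subgraph.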

View File

@@ -54,7 +54,7 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
// of a sequential, array of ScratchBufferHandle allocations in the tail
// section. These allocations are indexed by the request API defined in the
// TfLiteContext struct.
typedef struct {
struct ScratchBufferRequest {
// Number of bytes required by the buffer. The actual allocated size might be
// greater than `bytes` due to buffer alignment.
size_t bytes;
@@ -63,29 +63,29 @@ typedef struct {
// have `before` = node_idx and `after` = node_idx.
int node_idx;
int subgraph_idx;
} ScratchBufferRequest;
};
} // namespace internal
typedef struct {
struct NodeAndRegistration {
TfLiteNode node;
const TfLiteRegistration* registration;
} NodeAndRegistration;
};
// Holds a pointer to a buffer for a scratch buffer requested by a kernel during
// the model prepare stage. This struct is allocated in-place and allows for
// quick pointer-indexed lookup for speed during model inference.
typedef struct {
struct ScratchBufferHandle {
// Pointer to location of the scratch buffer:
uint8_t* data;
} ScratchBufferHandle;
};
// Stores all per-subgraph allocations. This includes the node and registration
// array, tensor list and scratch buffer handles for each subgraph.
typedef struct {
// array, and tensor list for each subgraph.
struct SubgraphAllocations {
NodeAndRegistration* node_and_registrations;
TfLiteEvalTensor* tensors;
} SubgraphAllocations;
};
// Allocator responsible for allocating memory for all intermediate tensors
// necessary to invoke a model.

View File

@@ -317,7 +317,17 @@ TfLiteTensor* MicroInterpreter::output(size_t index) {
}
return output_tensors_[index];
}
// Repurposing free subgraphs to reset state for some ops for now
// until a reset API is made. See b/220940833#comment25 for more context.
TfLiteStatus MicroInterpreter::Reset() {
TfLiteStatus status = graph_.FreeSubgraphs();
if (status != kTfLiteOk) {
return status;
}
return graph_.ResetVariableTensors();
}
// TODO: remove this API completely in favor of MicroInterpreter::Reset
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
return graph_.ResetVariableTensors();
}

View File

@@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Copied from tensorflow/lite/version.h to avoid a dependency chain into
/// Copied from tensorflow/lite/version.h to avoid a dependency chain into
// tensorflow/core.
#define TFLITE_SCHEMA_VERSION (3)
@@ -116,6 +116,11 @@ class MicroInterpreter {
return nullptr;
}
// Reset the state to be what you would expect when the interpreter is first
// created, i.e. after Init and Prepare are called for the very first time.
TfLiteStatus Reset();
// TODO(b/244457206): remove this in favor of Reset()
// Reset all variable tensors to the default value.
TfLiteStatus ResetVariableTensors();
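A usage sketch, assuming an interpreter that has already gone through AllocateTensors (construction elided):

if (interpreter.Invoke() != kTfLiteOk) { /* handle error */ }
// Reset runs the kernels' Free hooks via the graph and zeroes variable
// tensors, restoring the state right after Init and Prepare:
if (interpreter.Reset() != kTfLiteOk) { /* handle error */ }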

View File

@@ -24,11 +24,13 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/add.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
#include "tensorflow/lite/micro/kernels/ethosu.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/kernels/pooling.h"
#include "tensorflow/lite/micro/kernels/reduce.h"
#include "tensorflow/lite/micro/kernels/softmax.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
@@ -140,9 +142,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::Register_ASSIGN_VARIABLE(), ParseAssignVariable);
}
TfLiteStatus AddAveragePool2D() {
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
tflite::Register_AVERAGE_POOL_2D(), ParsePool);
TfLiteStatus AddAveragePool2D(
const TfLiteRegistration& registration = Register_AVERAGE_POOL_2D()) {
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, registration, ParsePool);
}
TfLiteStatus AddBatchToSpaceNd() {
@@ -363,9 +365,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
}
TfLiteStatus AddMaxPool2D() {
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
tflite::Register_MAX_POOL_2D(), ParsePool);
TfLiteStatus AddMaxPool2D(
const TfLiteRegistration& registration = Register_MAX_POOL_2D()) {
return AddBuiltin(BuiltinOperator_MAX_POOL_2D, registration, ParsePool);
}
TfLiteStatus AddMirrorPad() {
@@ -382,8 +384,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::ops::micro::Register_MINIMUM(), ParseMinimum);
}
TfLiteStatus AddMul() {
return AddBuiltin(BuiltinOperator_MUL, tflite::Register_MUL(), ParseMul);
TfLiteStatus AddMul(const TfLiteRegistration& registration = Register_MUL()) {
return AddBuiltin(BuiltinOperator_MUL, registration, ParseMul);
}
TfLiteStatus AddNeg() {
@@ -466,6 +468,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
}
TfLiteStatus AddSelectV2() {
return AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2(),
ParseSelectV2);
}
TfLiteStatus AddShape() {
return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape);
}
@@ -519,6 +526,12 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::ops::micro::Register_SQUARE(), ParseSquare);
}
TfLiteStatus AddSquaredDifference() {
return AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE,
tflite::Register_SQUARED_DIFFERENCE(),
ParseSquaredDifference);
}
TfLiteStatus AddStridedSlice() {
return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
tflite::ops::micro::Register_STRIDED_SLICE(),
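Taken together, a sketch of the new resolver surface, assuming a CMSIS-NN build where the INT8 variants are real specializations (elsewhere they alias the generic kernels):

tflite::MicroMutableOpResolver<4> resolver;
resolver.AddAveragePool2D(tflite::Register_AVERAGE_POOL_2D_INT8());
resolver.AddMul(tflite::Register_MUL_INT8());
resolver.AddSelectV2();
resolver.AddSquaredDifference();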

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include <cinttypes>
#include <cstdint>
#include <cstring>
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
@@ -67,4 +68,48 @@ void MicroProfiler::LogCsv() const {
#endif
}
void MicroProfiler::LogTicksPerTagCsv() {
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
MicroPrintf(
"\"Unique Tag\",\"Total ticks across all events with that tag.\"");
int total_ticks = 0;
for (int i = 0; i < num_events_; ++i) {
uint32_t ticks = end_ticks_[i] - start_ticks_[i];
TFLITE_DCHECK(tags_[i] != nullptr);
int position = FindExistingOrNextPosition(tags_[i]);
TFLITE_DCHECK(position >= 0);
total_ticks_per_tag[position].tag = tags_[i];
total_ticks_per_tag[position].ticks =
total_ticks_per_tag[position].ticks + ticks;
total_ticks += ticks;
}
for (int i = 0; i < num_events_; ++i) {
TicksPerTag each_tag_entry = total_ticks_per_tag[i];
if (each_tag_entry.tag == nullptr) {
break;
}
MicroPrintf("%s, %d", each_tag_entry.tag, each_tag_entry.ticks);
}
MicroPrintf("total number of ticks, %d", total_ticks);
#endif
}
// This method finds a particular array element in the total_ticks_per_tag array
// with the matching tag_name passed in the method. If it can find a
// matching array element that has the same tag_name, then it will return the
// position of the matching element. But if it unable to find a matching element
// with the given tag_name, it will return the next available empty position
// from the array.
int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) {
int pos = 0;
for (; pos < num_events_; pos++) {
TicksPerTag each_tag_entry = total_ticks_per_tag[pos];
if (each_tag_entry.tag == nullptr ||
strcmp(each_tag_entry.tag, tag_name) == 0) {
return pos;
}
}
return pos < num_events_ ? pos : -1;
}
} // namespace tflite

View File

@@ -61,6 +61,11 @@ class MicroProfiler {
// Separated Value) form.
void LogCsv() const;
// Prints total ticks for each unique tag in CSV format.
// Output will have one row for each unique tag along with the
// total ticks summed across all events with that particular tag.
void LogTicksPerTagCsv();
private:
// Maximum number of events that this class can keep track of. If we call
// AddEvent more than kMaxEvents number of times, then the oldest event's
@@ -72,6 +77,17 @@ class MicroProfiler {
uint32_t end_ticks_[kMaxEvents];
int num_events_ = 0;
struct TicksPerTag {
const char* tag;
uint32_t ticks;
};
// In practice, the number of tags will be much lower than the number of
// events. But it is theoretically possible for each event to be unique and
// hence we allow total_ticks_per_tag to have kMaxEvents entries.
TicksPerTag total_ticks_per_tag[kMaxEvents] = {};
int FindExistingOrNextPosition(const char* tag_name);
TF_LITE_REMOVE_VIRTUAL_DELETE;
};
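A sketch of producing the per-tag CSV, using the profiler's existing BeginEvent/EndEvent API:

tflite::MicroProfiler profiler;
uint32_t handle = profiler.BeginEvent("Conv2D");
// ... profiled work ...
profiler.EndEvent(handle);
// One row per unique tag, then a total-ticks row:
profiler.LogTicksPerTagCsv();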

View File

@@ -163,10 +163,12 @@ TfLiteStatus RecordingMicroAllocator::AllocateNodeAndRegistrations(
TfLiteStatus status =
MicroAllocator::AllocateNodeAndRegistrations(model, subgraph_allocations);
RecordAllocationUsage(allocations,
recorded_node_and_registration_array_data_);
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
subgraph_idx++) {
RecordAllocationUsage(allocations,
recorded_node_and_registration_array_data_);
// The allocation count in SingleArenaBufferAllocator will only be 1. To
// provide better logging, decrement by 1 and add in the actual number of
// operators used in the graph: The allocation for this recording will
@@ -176,8 +178,12 @@ TfLiteStatus RecordingMicroAllocator::AllocateNodeAndRegistrations(
// potential for fragmentation, manually adjust the accounting by
// decrementing by 1 and adding the actual number of nodes used in the
// graph:
recorded_node_and_registration_array_data_.count +=
model->subgraphs()->Get(subgraph_idx)->operators()->size() - 1;
if (model->subgraphs()->Get(subgraph_idx)->operators()) {
recorded_node_and_registration_array_data_.count +=
model->subgraphs()->Get(subgraph_idx)->operators()->size() - 1;
} else {
recorded_node_and_registration_array_data_.count -= 1;
}
}
return status;
}
@@ -188,9 +194,11 @@ TfLiteStatus RecordingMicroAllocator::AllocateTfLiteEvalTensors(
TfLiteStatus status =
MicroAllocator::AllocateTfLiteEvalTensors(model, subgraph_allocations);
RecordAllocationUsage(allocations, recorded_tflite_eval_tensor_data_);
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
subgraph_idx++) {
RecordAllocationUsage(allocations, recorded_tflite_eval_tensor_data_);
// The allocation for this recording will always be 1. This is because the
// parent class mallocs one large allocation for the number of tensors in
// the graph (e.g. sizeof(TfLiteEvalTensor) * num_tensors). To prevent extra