mirror of
https://github.com/jomjol/AI-on-the-edge-device.git
synced 2025-12-10 21:46:55 +03:00
Rolling 20220924
This commit is contained in:
@@ -92,6 +92,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddResizeNearestNeighbor();
|
||||
AddRound();
|
||||
AddRsqrt();
|
||||
AddSelectV2();
|
||||
AddShape();
|
||||
AddSin();
|
||||
AddSlice();
|
||||
@@ -102,6 +103,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddSplitV();
|
||||
AddSqrt();
|
||||
AddSquare();
|
||||
AddSquaredDifference();
|
||||
AddSqueeze();
|
||||
AddStridedSlice();
|
||||
AddSub();
|
||||
@@ -110,6 +112,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddTanh();
|
||||
AddTranspose();
|
||||
AddTransposeConv();
|
||||
AddUnidirectionalSequenceLSTM();
|
||||
AddUnpack();
|
||||
AddVarHandle();
|
||||
AddWhile();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -59,6 +59,19 @@ TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params,
|
||||
|
||||
TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
// Generic must define registration function.
|
||||
TfLiteRegistration Register_ADD();
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_ADD_INT8();
|
||||
|
||||
TfLiteRegistration Register_ADD_INT16();
|
||||
#else
|
||||
// Fallback registration
|
||||
inline TfLiteRegistration Register_ADD_INT8() { return Register_ADD(); }
|
||||
|
||||
inline TfLiteRegistration Register_ADD_INT16() { return Register_ADD(); }
|
||||
#endif
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_ADD_H_
|
||||
|
||||
@@ -121,8 +121,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
context, kTfLiteActNone, output, &data->output_activation_min,
|
||||
&data->output_activation_max));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -198,8 +198,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} else if (output->type == kTfLiteInt8) {
|
||||
EvalAddNQuantized<int8_t>(context, node, output);
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -70,21 +70,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Only float32, uint8_t and int8_t are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf(
|
||||
"Only float32, uint8_t and int8_t are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Only int32_t are supported currently, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Only int32_t are supported currently, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Only int32_t are supported currently, got %s.",
|
||||
TfLiteTypeGetName(axis->type));
|
||||
MicroPrintf("Only int32_t are supported currently, got %s.",
|
||||
TfLiteTypeGetName(axis->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -90,7 +90,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -118,8 +118,8 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_data);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -210,8 +210,8 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_data);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -288,8 +288,8 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_data);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -366,8 +366,8 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_data);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -444,8 +444,8 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_data);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -522,8 +522,8 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_data);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context,
|
||||
input_type == kTfLiteFloat32 || input_type == kTfLiteInt8 ||
|
||||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
|
||||
input_type == kTfLiteInt64);
|
||||
input_type == kTfLiteInt64 || input_type == kTfLiteBool);
|
||||
|
||||
// Output type must match input type
|
||||
TF_LITE_ENSURE_EQ(context, output_type, input_type);
|
||||
@@ -149,8 +149,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
int num_dimensions = NumDimensions(input);
|
||||
|
||||
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context,
|
||||
MicroPrintf(
|
||||
"Op Concatenation does not currently support num dimensions > %d "
|
||||
"Tensor has %d dimensions.",
|
||||
RuntimeShape::kMaxSmallSize, num_dimensions);
|
||||
@@ -168,6 +167,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
switch (output_type) { // Already know in/outtypes are same.
|
||||
case kTfLiteBool:
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteInt16:
|
||||
case kTfLiteInt32:
|
||||
@@ -205,9 +205,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Op Concatenation does not currently support Type '%s'.",
|
||||
TfLiteTypeGetName(output_type));
|
||||
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
|
||||
TfLiteTypeGetName(output_type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -238,11 +237,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt16:
|
||||
EvalUnquantized<int16_t>(context, node);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
EvalUnquantized<bool>(context, node);
|
||||
break;
|
||||
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Op Concatenation does not currently support Type '%s'.",
|
||||
TfLiteTypeGetName(output_type));
|
||||
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
|
||||
TfLiteTypeGetName(output_type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (axis < 0) axis += input_shape.DimensionsCount();
|
||||
|
||||
if (axis < 0 || axis >= input_shape.DimensionsCount()) {
|
||||
TF_LITE_KERNEL_LOG(context, "CUMSUM Invalid axis: %d", axis);
|
||||
MicroPrintf("CUMSUM Invalid axis: %d", axis);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -156,9 +156,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} break;
|
||||
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"CUMSUM only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("CUMSUM only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,9 +124,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -82,8 +82,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -162,8 +162,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
#undef TF_LITE_DIV
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Unsupported combination of input and output types in DIV.");
|
||||
MicroPrintf("Unsupported combination of input and output types in DIV.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -189,10 +188,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
|
||||
input1, input2, output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"DIV only supports FLOAT32, quantized INT8 "
|
||||
"now, got type %s (%d).",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
MicroPrintf(
|
||||
"DIV only supports FLOAT32, quantized INT8 "
|
||||
"now, got type %s (%d).",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -90,8 +90,8 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
if (!IsSupportedType(input->type)) {
|
||||
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Input data type %s (%d) is not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -112,8 +112,8 @@ TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
if (!IsSupportedType(input->type)) {
|
||||
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Input data type %s (%d) is not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -317,8 +317,8 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
type);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
|
||||
TfLiteTypeGetName(type));
|
||||
MicroPrintf("Current data type %s is not supported.",
|
||||
TfLiteTypeGetName(type));
|
||||
return kTfLiteError;
|
||||
break;
|
||||
}
|
||||
@@ -355,8 +355,8 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
elementwise::validate_input_func, type);
|
||||
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
|
||||
TfLiteTypeGetName(type));
|
||||
MicroPrintf("Current data type %s is not supported.",
|
||||
TfLiteTypeGetName(type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@@ -426,4 +426,4 @@ TfLiteRegistration Register_LOGICAL_NOT() {
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
} // namespace tflite
|
||||
|
||||
@@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@@ -136,9 +135,8 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "ELU only supports float32 and int8 currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("ELU only supports float32 and int8 currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
|
||||
@@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
|
||||
@@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
|
||||
@@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@@ -63,8 +64,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
static_cast<size_t>(flat_size),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) currently not supported by Exp.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) currently not supported by Exp.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -31,8 +31,7 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
|
||||
int32_t* axis_value) {
|
||||
const int axis_dims = (tflite::GetTensorShape(axis)).DimensionsCount();
|
||||
if (axis_dims > 1) {
|
||||
TF_LITE_KERNEL_LOG(context, "Axis has only one element for Expand_Dims.",
|
||||
axis_dims);
|
||||
MicroPrintf("Axis has only one element for Expand_Dims.", axis_dims);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -41,9 +40,8 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
|
||||
*axis_value = axis_ptr[0];
|
||||
return kTfLiteOk;
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Axis type %s (%d) not supported by Expand_Dims.",
|
||||
TfLiteTypeGetName(axis->type), axis->type);
|
||||
MicroPrintf("Axis type %s (%d) not supported by Expand_Dims.",
|
||||
TfLiteTypeGetName(axis->type), axis->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@@ -99,8 +97,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
output->type = input->type;
|
||||
if (IsDynamicTensor(axis)) {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"DynamicTensor is not yet supported by Expand_Dims.");
|
||||
MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
|
||||
@@ -135,8 +132,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(input), flat_size);
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context,
|
||||
MicroPrintf(
|
||||
"Expand_Dims only currently supports int8 and float32, got %d.",
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
|
||||
@@ -53,9 +53,8 @@ TfLiteStatus EnsureEq(TfLiteContext* context, const TfLiteIntArray* array,
|
||||
case kTfLiteInt64:
|
||||
return EnsureEqImpl<int64_t>(context, array, tensor);
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"cannot compare int array to tensor of type %d.",
|
||||
tensor->type);
|
||||
MicroPrintf("cannot compare int array to tensor of type %d.",
|
||||
tensor->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@@ -123,9 +122,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
FillImpl<int8_t>(value, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Fill only currently supports float32 for input 1, got %d.",
|
||||
TfLiteTypeGetName(value->type));
|
||||
MicroPrintf("Fill only currently supports float32 for input 1, got %d.",
|
||||
TfLiteTypeGetName(value->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ TfLiteStatus EvalFloorDiv(TfLiteContext* context,
|
||||
// Validate the denominator.
|
||||
for (int i = 0; i < tflite::ElementCount(*input2->dims); ++i) {
|
||||
if (std::equal_to<T>()(denominator_data[i], 0)) {
|
||||
TF_LITE_KERNEL_LOG(context, "Division by 0");
|
||||
MicroPrintf("Division by 0");
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@@ -113,8 +113,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalFloorDiv<float>(context, input1, input2, output);
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_DIV.",
|
||||
TfLiteTypeGetName(input1->type));
|
||||
MicroPrintf("Type '%s' is not supported by FLOOR_DIV.",
|
||||
TfLiteTypeGetName(input1->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
output);
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_MOD.",
|
||||
TfLiteTypeGetName(input1->type));
|
||||
MicroPrintf("Type '%s' is not supported by FLOOR_MOD.",
|
||||
TfLiteTypeGetName(input1->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,8 +141,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +118,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt32:
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Positions of type '%s' are not supported by gather.",
|
||||
TfLiteTypeGetName(coords->type));
|
||||
MicroPrintf("Positions of type '%s' are not supported by gather.",
|
||||
TfLiteTypeGetName(coords->type));
|
||||
return kTfLiteError;
|
||||
break;
|
||||
}
|
||||
@@ -134,8 +133,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt8:
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type '%s' is not supported by gather.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
break;
|
||||
}
|
||||
@@ -207,8 +206,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Gather<int8_t, int32_t>(params, input, coords, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type '%s' is not supported by gather.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -47,9 +47,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt8:
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Params of type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(params->type));
|
||||
MicroPrintf("Params of type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(params->type));
|
||||
return kTfLiteError;
|
||||
break;
|
||||
}
|
||||
@@ -57,9 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt32:
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Indices of type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(indices->type));
|
||||
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(indices->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -67,22 +65,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int indices_rank = NumDimensions(indices);
|
||||
const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
|
||||
if (params_rank < 1) {
|
||||
TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
|
||||
MicroPrintf("Params must be at least a vector.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (indices_rank < 1) {
|
||||
TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
|
||||
MicroPrintf("Indices must be at least a vector.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (indices_nd > params_rank) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Index innermost dimension length must be <= params rank.");
|
||||
MicroPrintf("Index innermost dimension length must be <= params rank.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (indices_nd > MAX_INDICES_ND) {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Index innermost dimension length must not exceed %d.",
|
||||
MAX_INDICES_ND);
|
||||
MicroPrintf("Index innermost dimension length must not exceed %d.",
|
||||
MAX_INDICES_ND);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -171,13 +167,12 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
|
||||
status = GatherNd<int8_t, IndicesT>(params, indices, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Params type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(params->type));
|
||||
MicroPrintf("Params type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(params->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (status != kTfLiteOk) {
|
||||
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
|
||||
MicroPrintf("gather_nd index out of bounds");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
@@ -195,9 +190,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalGatherNd<int32_t>(context, params, indices, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Indices of type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(indices->type));
|
||||
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
|
||||
TfLiteTypeGetName(indices->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,5 +106,17 @@ TfLiteStatus KernelRunner::Invoke() {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus KernelRunner::Free() {
|
||||
tflite::micro::ClearBufferApi(&context_);
|
||||
context_.GetScratchBuffer = MicroContextGetScratchBuffer;
|
||||
|
||||
if (registration_.free == nullptr) {
|
||||
MicroPrintf("TfLiteRegistration missing free function pointer!");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
registration_.free(&context_, node_.user_data);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace micro
|
||||
} // namespace tflite
|
||||
} // namespace tflite
|
||||
@@ -48,6 +48,11 @@ class KernelRunner {
|
||||
// passed into the constructor of this class.
|
||||
TfLiteStatus Invoke();
|
||||
|
||||
// Calls Free on a given TfLiteRegistration pointer(if it's implemented).
|
||||
// After successful Free, kTfLiteOk status will be returned. If Free is not
|
||||
// implemented for a given kernel kTfLiteError will be returned.
|
||||
TfLiteStatus Free();
|
||||
|
||||
// Returns a pointer to the internal MockMicroGraph which KernelRunner uses
|
||||
// to stub out MicroGraph methods and track invocations on each subgraph.
|
||||
MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; }
|
||||
|
||||
@@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace micro {
|
||||
@@ -39,9 +40,10 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
|
||||
TfLiteRegistration RegisterOp(
|
||||
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
|
||||
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
|
||||
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
|
||||
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
|
||||
void (*free)(TfLiteContext* context, void* buffer)) {
|
||||
return {/*init=*/init,
|
||||
/*free=*/nullptr,
|
||||
/*free=*/free,
|
||||
/*prepare=*/prepare,
|
||||
/*invoke=*/invoke,
|
||||
/*profiling_string=*/nullptr,
|
||||
@@ -160,6 +162,46 @@ TfLiteStatus CopyOpInputsToOpOutputs(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Args:
|
||||
// 1. int8_t tensor_data - int8_t buffer of unknown size who's data you'd
|
||||
// like
|
||||
// to print
|
||||
// 2. int n_btyes - a small int representing number of bytes you want to
|
||||
// print
|
||||
// to debug output. It should always be <= tensor_data's size.
|
||||
// 3. prefix - optional message you'd like to print before printing bytes
|
||||
//
|
||||
// Purpose:
|
||||
// Function takes in paramaters above and prints n_bytes bytes from the
|
||||
// tensor_data buffer. This can be use to debug the output of a model and it's
|
||||
// op.
|
||||
|
||||
void PrintNBytes(const int8_t* tensor_data, int n_bytes, const char* prefix) {
|
||||
if (prefix != nullptr) {
|
||||
MicroPrintf("%s", prefix);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_bytes; ++i) {
|
||||
MicroPrintf(" %x", tensor_data[i]);
|
||||
}
|
||||
MicroPrintf("\n");
|
||||
}
|
||||
|
||||
// same as the PrintNBytes above but the buffer needs to be extracted out of the
|
||||
// TfLiteEvalTensor*
|
||||
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
|
||||
const char* prefix) {
|
||||
const int8_t* tensor_data = tflite::micro::GetTensorData<int8_t>(tensor);
|
||||
PrintNBytes(tensor_data, n_bytes, prefix);
|
||||
}
|
||||
|
||||
// same as the PrintNBytes above but the buffer needs to be extracted out of the
|
||||
// TfLiteEvalTensor*
|
||||
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes, const char* prefix) {
|
||||
const int8_t* tensor_data = tflite::GetTensorData<int8_t>(tensor);
|
||||
PrintNBytes(tensor_data, n_bytes, prefix);
|
||||
}
|
||||
|
||||
TfLiteStatus CopyOpInputsToSubgraphInputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
|
||||
@@ -21,8 +21,10 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace micro {
|
||||
@@ -30,7 +32,20 @@ namespace micro {
|
||||
TfLiteRegistration RegisterOp(
|
||||
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
|
||||
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
|
||||
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
|
||||
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
|
||||
void (*free)(TfLiteContext* context, void* buffer) = nullptr);
|
||||
|
||||
// Prints out n bytes in a int8_t buffer as hex
|
||||
void PrintNBytes(const int8_t* tensor_data, int n_bytes,
|
||||
const char* prefix = nullptr);
|
||||
|
||||
// Prints out the the n bytes in a TfLiteEvalTensor as hex
|
||||
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
|
||||
const char* prefix = nullptr);
|
||||
|
||||
// Prints out the the n bytes in a TfLiteTensor as hex
|
||||
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes,
|
||||
const char* prefix = nullptr);
|
||||
|
||||
// Returns a mutable tensor for a given input index. is_variable must be checked
|
||||
// during prepare when the full TfLiteTensor is available.
|
||||
|
||||
@@ -125,9 +125,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
L2EvalFloat(*params, *input, &op_params, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"L2_POOL_2D only supports float32 currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("L2_POOL_2D only supports float32 currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -126,8 +126,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Output type is %s, requires float.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -132,9 +132,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"LOG_SOFTMAX only supports float32, int8, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("LOG_SOFTMAX only supports float32, int8, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
@@ -530,11 +532,20 @@ void CalculateLstmGateInteger8x8_16(
|
||||
// Apply activation
|
||||
switch (activation) {
|
||||
case kTfLiteActSigmoid:
|
||||
micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
|
||||
break;
|
||||
case kTfLiteActTanh:
|
||||
micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
|
||||
|
||||
reference_integer_ops::Logistic(
|
||||
0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
|
||||
n_batch * n_cell /*NumElements(input->dims)*/,
|
||||
gate /* tflite::micro::GetTensorData<int16_t>(input) */,
|
||||
gate /*tflite::micro::GetTensorData<int16_t>(output) */);
|
||||
|
||||
break;
|
||||
case kTfLiteActTanh: {
|
||||
int32_t dims_data = n_batch * n_cell;
|
||||
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
|
||||
reference_integer_ops::Tanh(0, 0, tanh_inp_shape, gate, tanh_inp_shape,
|
||||
gate);
|
||||
} break;
|
||||
default:
|
||||
// Only Sigmoid or Tanh is used.
|
||||
TFLITE_ASSERT_FALSE;
|
||||
@@ -599,7 +610,7 @@ void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
|
||||
// - scratch1: scratch area of size n_batch*n_cell
|
||||
// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
|
||||
void CalculateLstmOutputInteger8x8_16(
|
||||
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
|
||||
int n_batch, int n_cell, int n_output, int16_t* cell_state,
|
||||
int32_t cell_state_scale, const int16_t* output_gate,
|
||||
int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
|
||||
const int8_t* projection_weights, int32_t proj_scale_a,
|
||||
@@ -607,8 +618,23 @@ void CalculateLstmOutputInteger8x8_16(
|
||||
int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
|
||||
int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
|
||||
// Note: unlike float/hybrid, the activation is always Tanh.
|
||||
micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
|
||||
n_cell, scratch0);
|
||||
|
||||
{
|
||||
int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3;
|
||||
int32_t dims_data = n_batch * n_cell;
|
||||
if (tanh_input_left_shift < 0) /* handling negative shift value */
|
||||
{
|
||||
int32_t i;
|
||||
tanh_input_left_shift = -tanh_input_left_shift;
|
||||
for (i = 0; i < dims_data; i++) {
|
||||
cell_state[i] = cell_state[i] >> tanh_input_left_shift;
|
||||
}
|
||||
tanh_input_left_shift = 0;
|
||||
}
|
||||
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
|
||||
reference_integer_ops::Tanh(0, tanh_input_left_shift, tanh_inp_shape,
|
||||
cell_state, tanh_inp_shape, scratch0);
|
||||
}
|
||||
micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
|
||||
hidden_scale_b, n_batch, n_cell, hidden_zp,
|
||||
scratch1);
|
||||
|
||||
@@ -98,15 +98,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLiteOperation<int64_t, OpType>(context, node, op_context);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s (%d) is not supported by Maximum/Minimum.",
|
||||
TfLiteTypeGetName(op_context.output->type),
|
||||
op_context.output->type);
|
||||
MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
|
||||
TfLiteTypeGetName(op_context.output->type),
|
||||
op_context.output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Kernel type not supported by Maximum/Minimum.");
|
||||
MicroPrintf("Kernel type not supported by Maximum/Minimum.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -72,6 +72,7 @@ TfLiteRegistration Register_READ_VARIABLE();
|
||||
TfLiteRegistration Register_RELU();
|
||||
TfLiteRegistration Register_RELU6();
|
||||
TfLiteRegistration Register_RESIZE_BILINEAR();
|
||||
TfLiteRegistration Register_SELECT_V2();
|
||||
TfLiteRegistration Register_SHAPE();
|
||||
TfLiteRegistration Register_SLICE();
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
||||
@@ -79,6 +80,7 @@ TfLiteRegistration Register_SPACE_TO_DEPTH();
|
||||
TfLiteRegistration Register_SQUARED_DIFFERENCE();
|
||||
TfLiteRegistration Register_SQUEEZE();
|
||||
TfLiteRegistration Register_SUB();
|
||||
TfLiteRegistration Register_SUM();
|
||||
TfLiteRegistration Register_SVDF();
|
||||
TfLiteRegistration Register_TRANSPOSE();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV();
|
||||
|
||||
@@ -663,7 +663,7 @@ void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
const int16_t b = input_2[index];
|
||||
int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
|
||||
value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
|
||||
value -= output_zp;
|
||||
value += output_zp;
|
||||
value = std::min(std::max(static_cast<int32_t>(-128), value),
|
||||
static_cast<int32_t>(127));
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -60,6 +60,15 @@ void EvalMulFloatReference(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input2,
|
||||
TfLiteEvalTensor* output);
|
||||
|
||||
// Generic must define registration function.
|
||||
TfLiteRegistration Register_MUL();
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_MUL_INT8();
|
||||
#else
|
||||
// Fallback registration
|
||||
inline TfLiteRegistration Register_MUL_INT8() { return Register_MUL(); }
|
||||
#endif
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_MUL_H_
|
||||
|
||||
@@ -41,8 +41,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->axis);
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Type '%s' is not supported by pack.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,8 +213,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} break;
|
||||
default:
|
||||
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type %s not currently supported by Pad.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -45,8 +45,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
AveragePoolingEvalQuantized(context, node, params, data, input, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Input type %s is not currently supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -73,8 +73,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
MaxPoolingEvalQuantized(context, node, params, data, input, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
@@ -66,6 +67,19 @@ void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input,
|
||||
TfLiteEvalTensor* output);
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_AVERAGE_POOL_2D_INT8();
|
||||
|
||||
TfLiteRegistration Register_MAX_POOL_2D_INT8();
|
||||
#else
|
||||
inline TfLiteRegistration Register_AVERAGE_POOL_2D_INT8() {
|
||||
return tflite::Register_AVERAGE_POOL_2D();
|
||||
}
|
||||
|
||||
inline TfLiteRegistration Register_MAX_POOL_2D_INT8() {
|
||||
return tflite::Register_MAX_POOL_2D();
|
||||
}
|
||||
#endif
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_POOLING_H_
|
||||
|
||||
@@ -61,9 +61,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Only float32 and uint8_t are supported currently, got %d.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Only float32 and uint8_t are supported currently, got %d.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
const int kMaxNumberOfAxis = 4;
|
||||
const int kMaxNumberOfAxis = 5;
|
||||
const int kMaxNumberOfReducedAxis = 2;
|
||||
|
||||
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
@@ -55,8 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params =
|
||||
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
|
||||
if (params->half_pixel_centers && params->align_corners) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "If half_pixel_centers is True, align_corners must be False.");
|
||||
MicroPrintf("If half_pixel_centers is True, align_corners must be False.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -100,8 +99,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float or int8.",
|
||||
output->type);
|
||||
MicroPrintf("Output type is %d, requires float or int8.", output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@@ -55,7 +54,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
output->type = input->type;
|
||||
|
||||
if (!IsConstantTensor(size)) {
|
||||
TF_LITE_KERNEL_LOG(context, "Dynamic tensors are unsupported in tfmicro.");
|
||||
MicroPrintf("Dynamic tensors are unsupported in tfmicro.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/select.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
constexpr int kInputTensorCondition = 0;
|
||||
constexpr int kInputTensorX = 1;
|
||||
constexpr int kInputTensorY = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
struct OpData {
|
||||
bool requires_broadcast;
|
||||
// True if input condition is scalar or input condition has rank one and
|
||||
// matches the first dimension of other inputs.
|
||||
bool has_low_rank_input_condition;
|
||||
};
|
||||
|
||||
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
auto* data = static_cast<OpData*>(
|
||||
context->AllocatePersistentBuffer(context, sizeof(OpData)));
|
||||
data->requires_broadcast = false;
|
||||
data->has_low_rank_input_condition = false;
|
||||
return data;
|
||||
}
|
||||
|
||||
TfLiteStatus CheckBroadcastShape(TfLiteContext* context,
|
||||
const TfLiteTensor* input1,
|
||||
const TfLiteTensor* input2,
|
||||
const TfLiteTensor* input3,
|
||||
const TfLiteIntArray* output_shape) {
|
||||
const int dims1 = NumDimensions(input1);
|
||||
const int dims2 = NumDimensions(input2);
|
||||
const int dims3 = NumDimensions(input3);
|
||||
const int out_dims = std::max(std::max(dims1, dims2), dims3);
|
||||
TF_LITE_ENSURE_EQ(context, out_dims, output_shape->size);
|
||||
|
||||
for (int i = 0; i < out_dims; ++i) {
|
||||
const int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
|
||||
const int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
|
||||
const int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
|
||||
const int min_value = std::min(std::min(d1, d2), d3);
|
||||
int max_value = std::max(std::max(d1, d2), d3);
|
||||
// If one dimention is 0, others must be 0 or 1.
|
||||
if (min_value == 0) max_value = 0;
|
||||
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
|
||||
!(d3 == 1 || d3 == max_value)) {
|
||||
MicroPrintf("Given shapes are not broadcastable.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
TF_LITE_ENSURE_EQ(context, output_shape->data[out_dims - i - 1], max_value);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input_condition =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
|
||||
|
||||
TfLiteTensor* input_x =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorX);
|
||||
|
||||
TfLiteTensor* input_y =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorY);
|
||||
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
// Input must be bool.
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type);
|
||||
output->type = input_x->type;
|
||||
|
||||
// Respect the original output shape when there are mixed shapes to represent
|
||||
// a scalar data.
|
||||
if (GetTensorShape(input_condition).FlatSize() == 1 &&
|
||||
GetTensorShape(input_x).FlatSize() == 1 &&
|
||||
GetTensorShape(input_y).FlatSize() == 1 &&
|
||||
GetTensorShape(output).FlatSize() == 1) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
bool same_shape = HaveSameShapes(input_condition, input_x) &&
|
||||
HaveSameShapes(input_x, input_y);
|
||||
if (!same_shape) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CheckBroadcastShape(context, input_condition, input_x, input_y,
|
||||
output->dims));
|
||||
data->requires_broadcast = true;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_condition);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_x);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_y);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input_condition =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
|
||||
|
||||
TfLiteTensor* input_x =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorX);
|
||||
|
||||
TfLiteTensor* input_y =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorY);
|
||||
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
#define TF_LITE_SELECT(type, op) \
|
||||
reference_ops::op(GetTensorShape(input_condition), \
|
||||
GetTensorData<bool>(input_condition), \
|
||||
GetTensorShape(input_x), GetTensorData<type>(input_x), \
|
||||
GetTensorShape(input_y), GetTensorData<type>(input_y), \
|
||||
GetTensorShape(output), GetTensorData<type>(output));
|
||||
|
||||
#define TF_LITE_SWITCH(type, op) \
|
||||
switch (type) { \
|
||||
case kTfLiteFloat32: \
|
||||
TF_LITE_SELECT(float, op); \
|
||||
break; \
|
||||
case kTfLiteInt8: \
|
||||
TF_LITE_SELECT(int8_t, op); \
|
||||
break; \
|
||||
case kTfLiteInt16: \
|
||||
TF_LITE_SELECT(int16_t, op); \
|
||||
break; \
|
||||
default: \
|
||||
MicroPrintf("Does not support type other than %s, but got %s", \
|
||||
"int8|int16|float32", TfLiteTypeGetName(type)); \
|
||||
return kTfLiteError; \
|
||||
}
|
||||
|
||||
if (data->has_low_rank_input_condition) {
|
||||
MicroPrintf("Not yet implemented.");
|
||||
return kTfLiteError;
|
||||
} else if (data->requires_broadcast) {
|
||||
TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow);
|
||||
} else {
|
||||
TF_LITE_SWITCH(input_x->type, Select);
|
||||
}
|
||||
|
||||
#undef TF_LITE_SELECT
|
||||
#undef TF_LITE_SWITCH
|
||||
micro_context->DeallocateTempTfLiteTensor(input_condition);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_x);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_y);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// SelectV2 op selects values of 'x' if the corresponding value of 'condition'
|
||||
// is true or the value of 'y' if false. There are valid condition input sizes:
|
||||
//
|
||||
// 1. Either the same shape (in which case the select is elementwise), or
|
||||
// 2. Broadcastable shapes between 'condition', 'x' and 'y'.
|
||||
TfLiteRegistration Register_SELECT_V2() {
|
||||
return tflite::micro::RegisterOp(tflite::SelectInit, tflite::SelectPrepare,
|
||||
tflite::SelectEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -47,8 +47,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
if (output->type != kTfLiteInt32) {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
MicroPrintf("Output type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
return kTfLiteError;
|
||||
} else {
|
||||
ExtractShape(input, tflite::micro::GetTensorData<int32_t>(output));
|
||||
|
||||
@@ -106,8 +106,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetBeginAndSizeVectors<int64_t>(input->dims->size, begin, size,
|
||||
op_params.begin, op_params.size);
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Begin tensor type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Begin tensor type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -75,8 +75,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,8 +104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -109,9 +109,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return SplitImpl<int32_t>(context, node, input, axis_value);
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type %s currently not supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -90,8 +90,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
|
||||
if (input->type == kTfLiteString) {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -183,9 +183,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int32_t>(output));
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
reference_ops::StridedSlice(op_params,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<bool>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<bool>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -82,7 +82,7 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node);
|
||||
// (reference or optimized) must define this function.
|
||||
TfLiteRegistration Register_SVDF();
|
||||
|
||||
#if defined(HEXAGON)
|
||||
#if defined(HEXAGON) || defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_SVDF_INT8();
|
||||
|
||||
#else
|
||||
|
||||
@@ -185,9 +185,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type), context);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,10 +103,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s is currently not supported by Transpose. "
|
||||
"Only float32 and int8 is supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf(
|
||||
"Type %s is currently not supported by Transpose. "
|
||||
"Only float32 and int8 is supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -327,8 +327,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -24,7 +24,7 @@ namespace testing {
|
||||
// kernel is reconciled with reference kernel
|
||||
#if !defined(XTENSA)
|
||||
|
||||
typedef struct LstmIntegerTestConfig {
|
||||
struct LstmIntegerTestConfig {
|
||||
const int n_batch;
|
||||
const int n_input;
|
||||
const int n_cell;
|
||||
@@ -100,9 +100,9 @@ typedef struct LstmIntegerTestConfig {
|
||||
|
||||
bool asymmetric_quantize_inputs;
|
||||
const float ranges[25][2];
|
||||
} LstmIntegerTestConfig;
|
||||
};
|
||||
|
||||
typedef struct LstmFloatTestConfig {
|
||||
struct LstmFloatTestConfig {
|
||||
const int n_batch;
|
||||
const int n_input;
|
||||
const int n_cell;
|
||||
@@ -153,9 +153,9 @@ typedef struct LstmFloatTestConfig {
|
||||
float* output;
|
||||
const float* expected_output_original;
|
||||
float* expected_output;
|
||||
} LstmFloatTestConfig;
|
||||
};
|
||||
|
||||
typedef struct LstmWeightQuantizationBuffers {
|
||||
struct LstmWeightQuantizationBuffers {
|
||||
int8_t* lstm_i2i_quant;
|
||||
float* lstm_i2i_scale;
|
||||
int* lstm_i2i_zp;
|
||||
@@ -215,7 +215,7 @@ typedef struct LstmWeightQuantizationBuffers {
|
||||
float* lstm_proj_w_scale;
|
||||
int* lstm_proj_w_zp;
|
||||
TfLiteAffineQuantization* lstm_proj_w_qparam;
|
||||
} LstmWeightQuantizationBuffers;
|
||||
};
|
||||
|
||||
extern LstmIntegerTestConfig lstm_integer_no_peephole_config;
|
||||
|
||||
|
||||
@@ -91,8 +91,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type '%s' is not supported by unpack.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,10 +70,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
resetZeros(tflite::micro::GetTensorData<float>(output), flat_size);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"ZerosLike only currently supports int64, int32, "
|
||||
"and float32, got %d.",
|
||||
input->type);
|
||||
MicroPrintf(
|
||||
"ZerosLike only currently supports int64, int32, "
|
||||
"and float32, got %d.",
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -323,8 +323,12 @@ TfLiteStatus AllocationInfoBuilder::GetOfflinePlannedOffsets(
|
||||
if (model_->metadata()) {
|
||||
for (size_t i = 0; i < model_->metadata()->size(); ++i) {
|
||||
auto metadata = model_->metadata()->Get(i);
|
||||
if (strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
|
||||
strlen(kOfflineMemAllocMetadata)) == 0) {
|
||||
const size_t metadata_name_size = (size_t)metadata->name()->size();
|
||||
|
||||
if ((strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
|
||||
std::min(metadata_name_size,
|
||||
strlen(kOfflineMemAllocMetadata))) == 0) &&
|
||||
metadata_name_size == strlen(kOfflineMemAllocMetadata)) {
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
|
||||
model_->buffers();
|
||||
auto* buffer = (*buffers)[metadata->buffer()];
|
||||
|
||||
@@ -509,14 +509,15 @@ TfLiteStatus MicroAllocator::FinishModelAllocation(
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Allocate scratch buffer metadata and buffers for variable tensors.
|
||||
// Allocate scratch buffer metadata.
|
||||
TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles(
|
||||
scratch_buffer_handles, scratch_buffer_request_count_));
|
||||
|
||||
// Allocate buffers for variable tensors.
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
|
||||
subgraph_idx++) {
|
||||
const SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx);
|
||||
TFLITE_DCHECK(subgraph != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles(
|
||||
scratch_buffer_handles, scratch_buffer_request_count_));
|
||||
TF_LITE_ENSURE_STATUS(AllocateVariables(
|
||||
subgraph, subgraph_allocations[subgraph_idx].tensors));
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
|
||||
// of a sequential, array of ScratchBufferHandle allocations in the tail
|
||||
// section. These allocations are indexed by the request API defined in the
|
||||
// TfLiteContext struct.
|
||||
typedef struct {
|
||||
struct ScratchBufferRequest {
|
||||
// Number of bytes required by the buffer. The actual allocated size might be
|
||||
// greater than `bytes` due to buffer alignment.
|
||||
size_t bytes;
|
||||
@@ -63,29 +63,29 @@ typedef struct {
|
||||
// have `before` = node_idx and `after` = node_idx.
|
||||
int node_idx;
|
||||
int subgraph_idx;
|
||||
} ScratchBufferRequest;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
typedef struct {
|
||||
struct NodeAndRegistration {
|
||||
TfLiteNode node;
|
||||
const TfLiteRegistration* registration;
|
||||
} NodeAndRegistration;
|
||||
};
|
||||
|
||||
// Holds a pointer to a buffer for a scratch buffer requested by a kernel during
|
||||
// the model prepare stage. This struct is allocated in-place and allows for
|
||||
// quick pointer-indexed lookup for speed during model inference.
|
||||
typedef struct {
|
||||
struct ScratchBufferHandle {
|
||||
// Pointer to location of the scratch buffer:
|
||||
uint8_t* data;
|
||||
} ScratchBufferHandle;
|
||||
};
|
||||
|
||||
// Stores all per-subgraph allocations. This includes the node and registration
|
||||
// array, tensor list and scratch buffer handles for each subgraph.
|
||||
typedef struct {
|
||||
// array, and tensor list for each subgraph.
|
||||
struct SubgraphAllocations {
|
||||
NodeAndRegistration* node_and_registrations;
|
||||
TfLiteEvalTensor* tensors;
|
||||
} SubgraphAllocations;
|
||||
};
|
||||
|
||||
// Allocator responsible for allocating memory for all intermediate tensors
|
||||
// necessary to invoke a model.
|
||||
|
||||
@@ -317,7 +317,17 @@ TfLiteTensor* MicroInterpreter::output(size_t index) {
|
||||
}
|
||||
return output_tensors_[index];
|
||||
}
|
||||
// Repurposing free subgraphs to reset state for some ops for now
|
||||
// will reset api is made. See b/220940833#comment25 for more context.
|
||||
TfLiteStatus MicroInterpreter::Reset() {
|
||||
TfLiteStatus status = graph_.FreeSubgraphs();
|
||||
if (status != kTfLiteOk) {
|
||||
return status;
|
||||
}
|
||||
return graph_.ResetVariableTensors();
|
||||
}
|
||||
|
||||
// TODO: remove this API completely in favor of MicroInterpreter::Reset
|
||||
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
|
||||
return graph_.ResetVariableTensors();
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/portable_type_to_tflitetype.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
// Copied from tensorflow/lite/version.h to avoid a dependency chain into
|
||||
/// Copied from tensorflow/lite/version.h to avoid a dependency chain into
|
||||
// tensorflow/core.
|
||||
#define TFLITE_SCHEMA_VERSION (3)
|
||||
|
||||
@@ -116,6 +116,11 @@ class MicroInterpreter {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Reset the state to be what you would expect when the interpreter is first
|
||||
// created. i.e. after Init and Prepare is called for the very first time.
|
||||
TfLiteStatus Reset();
|
||||
|
||||
// TODO(b/244457206): remove this in favor of Reset()
|
||||
// Reset all variable tensors to the default value.
|
||||
TfLiteStatus ResetVariableTensors();
|
||||
|
||||
|
||||
@@ -24,11 +24,13 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/compatibility.h"
|
||||
#include "tensorflow/lite/micro/kernels/add.h"
|
||||
#include "tensorflow/lite/micro/kernels/conv.h"
|
||||
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
|
||||
#include "tensorflow/lite/micro/kernels/ethosu.h"
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||
#include "tensorflow/lite/micro/kernels/pooling.h"
|
||||
#include "tensorflow/lite/micro/kernels/reduce.h"
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
@@ -140,9 +142,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::Register_ASSIGN_VARIABLE(), ParseAssignVariable);
|
||||
}
|
||||
|
||||
TfLiteStatus AddAveragePool2D() {
|
||||
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
|
||||
tflite::Register_AVERAGE_POOL_2D(), ParsePool);
|
||||
TfLiteStatus AddAveragePool2D(
|
||||
const TfLiteRegistration& registration = Register_AVERAGE_POOL_2D()) {
|
||||
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, registration, ParsePool);
|
||||
}
|
||||
|
||||
TfLiteStatus AddBatchToSpaceNd() {
|
||||
@@ -363,9 +365,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
|
||||
}
|
||||
|
||||
TfLiteStatus AddMaxPool2D() {
|
||||
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
|
||||
tflite::Register_MAX_POOL_2D(), ParsePool);
|
||||
TfLiteStatus AddMaxPool2D(
|
||||
const TfLiteRegistration& registration = Register_MAX_POOL_2D()) {
|
||||
return AddBuiltin(BuiltinOperator_MAX_POOL_2D, registration, ParsePool);
|
||||
}
|
||||
|
||||
TfLiteStatus AddMirrorPad() {
|
||||
@@ -382,8 +384,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_MINIMUM(), ParseMinimum);
|
||||
}
|
||||
|
||||
TfLiteStatus AddMul() {
|
||||
return AddBuiltin(BuiltinOperator_MUL, tflite::Register_MUL(), ParseMul);
|
||||
TfLiteStatus AddMul(const TfLiteRegistration& registration = Register_MUL()) {
|
||||
return AddBuiltin(BuiltinOperator_MUL, registration, ParseMul);
|
||||
}
|
||||
|
||||
TfLiteStatus AddNeg() {
|
||||
@@ -466,6 +468,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
|
||||
}
|
||||
|
||||
TfLiteStatus AddSelectV2() {
|
||||
return AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2(),
|
||||
ParseSelectV2);
|
||||
}
|
||||
|
||||
TfLiteStatus AddShape() {
|
||||
return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape);
|
||||
}
|
||||
@@ -519,6 +526,12 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_SQUARE(), ParseSquare);
|
||||
}
|
||||
|
||||
TfLiteStatus AddSquaredDifference() {
|
||||
return AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE,
|
||||
tflite::Register_SQUARED_DIFFERENCE(),
|
||||
ParseSquaredDifference);
|
||||
}
|
||||
|
||||
TfLiteStatus AddStridedSlice() {
|
||||
return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
|
||||
tflite::ops::micro::Register_STRIDED_SLICE(),
|
||||
|
||||
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
@@ -67,4 +68,48 @@ void MicroProfiler::LogCsv() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
void MicroProfiler::LogTicksPerTagCsv() {
|
||||
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
|
||||
MicroPrintf(
|
||||
"\"Unique Tag\",\"Total ticks across all events with that tag.\"");
|
||||
int total_ticks = 0;
|
||||
for (int i = 0; i < num_events_; ++i) {
|
||||
uint32_t ticks = end_ticks_[i] - start_ticks_[i];
|
||||
TFLITE_DCHECK(tags_[i] != nullptr);
|
||||
int position = FindExistingOrNextPosition(tags_[i]);
|
||||
TFLITE_DCHECK(position >= 0);
|
||||
total_ticks_per_tag[position].tag = tags_[i];
|
||||
total_ticks_per_tag[position].ticks =
|
||||
total_ticks_per_tag[position].ticks + ticks;
|
||||
total_ticks += ticks;
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_events_; ++i) {
|
||||
TicksPerTag each_tag_entry = total_ticks_per_tag[i];
|
||||
if (each_tag_entry.tag == nullptr) {
|
||||
break;
|
||||
}
|
||||
MicroPrintf("%s, %d", each_tag_entry.tag, each_tag_entry.ticks);
|
||||
}
|
||||
MicroPrintf("total number of ticks, %d", total_ticks);
|
||||
#endif
|
||||
}
|
||||
|
||||
// This method finds a particular array element in the total_ticks_per_tag array
|
||||
// with the matching tag_name passed in the method. If it can find a
|
||||
// matching array element that has the same tag_name, then it will return the
|
||||
// position of the matching element. But if it unable to find a matching element
|
||||
// with the given tag_name, it will return the next available empty position
|
||||
// from the array.
|
||||
int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) {
|
||||
int pos = 0;
|
||||
for (; pos < num_events_; pos++) {
|
||||
TicksPerTag each_tag_entry = total_ticks_per_tag[pos];
|
||||
if (each_tag_entry.tag == nullptr ||
|
||||
strcmp(each_tag_entry.tag, tag_name) == 0) {
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
return pos < num_events_ ? pos : -1;
|
||||
}
|
||||
} // namespace tflite
|
||||
|
||||
@@ -61,6 +61,11 @@ class MicroProfiler {
|
||||
// Separated Value) form.
|
||||
void LogCsv() const;
|
||||
|
||||
// Prints total ticks for each unique tag in CSV format.
|
||||
// Output will have one row for each unique tag along with the
|
||||
// total ticks summed across all events with that particular tag.
|
||||
void LogTicksPerTagCsv();
|
||||
|
||||
private:
|
||||
// Maximum number of events that this class can keep track of. If we call
|
||||
// AddEvent more than kMaxEvents number of times, then the oldest event's
|
||||
@@ -72,6 +77,17 @@ class MicroProfiler {
|
||||
uint32_t end_ticks_[kMaxEvents];
|
||||
int num_events_ = 0;
|
||||
|
||||
struct TicksPerTag {
|
||||
const char* tag;
|
||||
uint32_t ticks;
|
||||
};
|
||||
// In practice, the number of tags will be much lower than the number of
|
||||
// events. But it is theoretically possible that each event to be unique and
|
||||
// hence we allow total_ticks_per_tag to have kMaxEvents entries.
|
||||
TicksPerTag total_ticks_per_tag[kMaxEvents] = {};
|
||||
|
||||
int FindExistingOrNextPosition(const char* tag_name);
|
||||
|
||||
TF_LITE_REMOVE_VIRTUAL_DELETE;
|
||||
};
|
||||
|
||||
|
||||
@@ -163,10 +163,12 @@ TfLiteStatus RecordingMicroAllocator::AllocateNodeAndRegistrations(
|
||||
|
||||
TfLiteStatus status =
|
||||
MicroAllocator::AllocateNodeAndRegistrations(model, subgraph_allocations);
|
||||
|
||||
RecordAllocationUsage(allocations,
|
||||
recorded_node_and_registration_array_data_);
|
||||
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
|
||||
subgraph_idx++) {
|
||||
RecordAllocationUsage(allocations,
|
||||
recorded_node_and_registration_array_data_);
|
||||
// The allocation count in SingleArenaBufferAllocator will only be 1. To
|
||||
// provide better logging, decrement by 1 and add in the actual number of
|
||||
// operators used in the graph: The allocation for this recording will
|
||||
@@ -176,8 +178,12 @@ TfLiteStatus RecordingMicroAllocator::AllocateNodeAndRegistrations(
|
||||
// potential for fragmentation, manually adjust the accounting by
|
||||
// decrementing by 1 and adding the actual number of nodes used in the
|
||||
// graph:
|
||||
recorded_node_and_registration_array_data_.count +=
|
||||
model->subgraphs()->Get(subgraph_idx)->operators()->size() - 1;
|
||||
if (model->subgraphs()->Get(subgraph_idx)->operators()) {
|
||||
recorded_node_and_registration_array_data_.count +=
|
||||
model->subgraphs()->Get(subgraph_idx)->operators()->size() - 1;
|
||||
} else {
|
||||
recorded_node_and_registration_array_data_.count -= 1;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
@@ -188,9 +194,11 @@ TfLiteStatus RecordingMicroAllocator::AllocateTfLiteEvalTensors(
|
||||
|
||||
TfLiteStatus status =
|
||||
MicroAllocator::AllocateTfLiteEvalTensors(model, subgraph_allocations);
|
||||
|
||||
RecordAllocationUsage(allocations, recorded_tflite_eval_tensor_data_);
|
||||
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
|
||||
subgraph_idx++) {
|
||||
RecordAllocationUsage(allocations, recorded_tflite_eval_tensor_data_);
|
||||
// The allocation for this recording will always be 1. This is because the
|
||||
// parent class mallocs one large allocation for the number of tensors in
|
||||
// the graph (e.g. sizeof(TfLiteEvalTensor) * num_tensors). To prevent extra
|
||||
|
||||
Reference in New Issue
Block a user