mirror of
https://github.com/jomjol/AI-on-the-edge-device.git
synced 2025-12-10 05:26:52 +03:00
Rolling 20220526
This commit is contained in:
@@ -117,15 +117,21 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kActivationsInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kActivationsInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kActivationsOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kActivationsOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
CalculateReluOpData<int8_t>(input, output, data);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -133,7 +139,9 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kActivationsInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kActivationsInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
@@ -142,6 +150,8 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->zero_int8 = input->params.zero_point;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -80,11 +80,15 @@ TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kAddInputTensor1);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kAddInputTensor1);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kAddInputTensor2);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kAddInputTensor2);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kAddOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kAddOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
OpDataAdd* data = static_cast<OpDataAdd*>(node->user_data);
|
||||
@@ -93,6 +97,9 @@ TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
CalculateOpDataAdd(context, params, input1, input2, output, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -50,18 +50,19 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, num_inputs >= 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input_tensor_first;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, kInputTensor0, &input_tensor_first));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input_tensor_first =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor0);
|
||||
TF_LITE_ENSURE(context, input_tensor_first != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
// Check that all tensors have the same shape and type.
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_tensor_first->type);
|
||||
for (int i = kInputTensor0 + 1; i < num_inputs; ++i) {
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, HaveSameShapes(input_tensor_first, input));
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_tensor_first->type, input->type);
|
||||
|
||||
@@ -72,6 +73,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context,
|
||||
input_tensor_first->params.scale == input->params.scale);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
@@ -123,6 +126,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_tensor_first);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -52,21 +52,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
input_resource_id_tensor->type == kTfLiteInt32));
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor->dims), 1);
|
||||
|
||||
const TfLiteTensor* input_value = GetInput(context, node, kInputValue);
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
TfLiteTensor* input_value =
|
||||
micro_context->AllocateTempInputTensor(node, kInputValue);
|
||||
TFLITE_DCHECK(input_value != nullptr);
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
MicroResourceVariables* resources = graph_info->GetResourceVariables();
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
MicroResourceVariables* resources = graph_info.GetResourceVariables();
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
resources->Allocate(input_resource_id_tensor->data.i32[0],
|
||||
context, input_value));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_value);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -79,14 +77,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetEvalInput(context, node, kInputValue);
|
||||
TFLITE_DCHECK(input_value != nullptr);
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
MicroResourceVariables* resources = graph_info->GetResourceVariables();
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
MicroResourceVariables* resources = graph_info.GetResourceVariables();
|
||||
if (resources == nullptr) {
|
||||
MicroPrintf(
|
||||
"ASSIGN_VARIABLE requires resource variables. Please create "
|
||||
|
||||
@@ -41,8 +41,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
|
||||
@@ -51,6 +55,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/broadcast_args.h"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
constexpr int kShape1Tensor = 0;
|
||||
constexpr int kShape2Tensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus BroadcastArgsPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, NumInputs(node) == 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* shape1 =
|
||||
micro_context->AllocateTempInputTensor(node, kShape1Tensor);
|
||||
TfLiteTensor* shape2 =
|
||||
micro_context->AllocateTempInputTensor(node, kShape2Tensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
shape1->type == kTfLiteInt32 || shape1->type == kTfLiteInt64);
|
||||
TF_LITE_ENSURE_EQ(context, shape1->type, shape2->type);
|
||||
TF_LITE_ENSURE_EQ(context, shape1->type, output->type);
|
||||
|
||||
// Ensures the shapes are 1D tensor.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(shape1), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(shape2), 1);
|
||||
|
||||
// Ensure the shape of the output tensor is compatible
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(shape1);
|
||||
micro_context->DeallocateTempTfLiteTensor(shape2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* shape1 =
|
||||
micro::GetEvalInput(context, node, kShape1Tensor);
|
||||
const TfLiteEvalTensor* shape2 =
|
||||
micro::GetEvalInput(context, node, kShape2Tensor);
|
||||
TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
if (output->type == kTfLiteInt32) {
|
||||
reference_ops::BroadcastArgs(
|
||||
micro::GetTensorShape(shape1), micro::GetTensorData<int32_t>(shape1),
|
||||
micro::GetTensorShape(shape2), micro::GetTensorData<int32_t>(shape2),
|
||||
micro::GetTensorShape(output), micro::GetTensorData<int32_t>(output));
|
||||
} else {
|
||||
reference_ops::BroadcastArgs(
|
||||
micro::GetTensorShape(shape1), micro::GetTensorData<int64_t>(shape1),
|
||||
micro::GetTensorShape(shape2), micro::GetTensorData<int64_t>(shape2),
|
||||
micro::GetTensorShape(output), micro::GetTensorData<int64_t>(output));
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_BROADCAST_ARGS() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/BroadcastArgsPrepare,
|
||||
/*invoke=*/BroadcastArgsEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,129 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kShapeTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
// Support a maximum of 5 dimensions in TFLM.
|
||||
constexpr int kMaxDims = 5;
|
||||
|
||||
TfLiteStatus ValidateOutputTensor(TfLiteContext* context, TfLiteTensor* input,
|
||||
TfLiteTensor* shape, TfLiteTensor* output) {
|
||||
// Ensures the shape is 1D tensor.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
|
||||
|
||||
// Ensure output dims is not less than input dims.
|
||||
int input_num_dims = NumDimensions(input);
|
||||
int output_num_dims = NumDimensions(output);
|
||||
int shape_num_dims = SizeOfDimension(shape, 0);
|
||||
TF_LITE_ENSURE_MSG(context, output_num_dims == shape_num_dims,
|
||||
"Output must match with the expected shape dimension.");
|
||||
TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
|
||||
"Output shape must be broadcastable from input shape.");
|
||||
TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
|
||||
"BroadcastTo only supports 1-5D tensor.");
|
||||
|
||||
// Check if output shape is broadcastable from input shape.
|
||||
auto get_shape_data = [shape](int i) -> int32_t {
|
||||
if (shape->type == kTfLiteInt32) {
|
||||
return GetTensorData<int32_t>(shape)[i];
|
||||
} else {
|
||||
return GetTensorData<int64_t>(shape)[i];
|
||||
}
|
||||
};
|
||||
|
||||
int extending_dims = output_num_dims - input_num_dims;
|
||||
for (int idx = 0; idx < input_num_dims; ++idx) {
|
||||
TF_LITE_ENSURE_MSG(
|
||||
context,
|
||||
(SizeOfDimension(input, idx) == 1 ||
|
||||
SizeOfDimension(input, idx) == get_shape_data(extending_dims + idx)),
|
||||
"Output shape must be broadcastable from input shape.");
|
||||
}
|
||||
|
||||
// Validating the shape of the output tensor.
|
||||
tflite::RuntimeShape output_shape = tflite::GetTensorShape(output);
|
||||
for (int idx = 0; idx < output_num_dims; ++idx) {
|
||||
TF_LITE_ENSURE(context, output_shape.Dims(idx) == get_shape_data(idx));
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus BroadcastToPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, NumInputs(node) == 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TfLiteTensor* shape =
|
||||
micro_context->AllocateTempInputTensor(node, kShapeTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE_MSG(context, (NumDimensions(input) <= kMaxDims),
|
||||
"BroadcastTo only supports 1-5D tensor.");
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
shape->type == kTfLiteInt32 || shape->type == kTfLiteInt64);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
|
||||
// Does not support String type due to its variable size. This limitation is
|
||||
// the same as TFLite.
|
||||
TF_LITE_ENSURE(context, input->type != kTfLiteString);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(ValidateOutputTensor(context, input, shape, output));
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(shape);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input =
|
||||
micro::GetEvalInput(context, node, kInputTensor);
|
||||
TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
// BroadcastTo op support upto 5 dims, different from 8 dims in TFLite.
|
||||
reference_ops::BroadcastTo<kMaxDims>(
|
||||
micro::GetTensorShape(input), input->data.raw,
|
||||
micro::GetTensorShape(output), output->data.raw, input->type);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_BROADCAST_TO() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/BroadcastToPrepare,
|
||||
/*invoke=*/BroadcastToEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
#include "tensorflow/lite/micro/micro_graph.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@@ -50,16 +51,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, NumInputs(node) == 0);
|
||||
TF_LITE_ENSURE(context, NumOutputs(node) == 0);
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
op_data->init_subgraph_index < graph_info->NumSubgraphs());
|
||||
op_data->init_subgraph_index < graph_info.NumSubgraphs());
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@@ -72,16 +68,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
graph_info->InvokeSubgraph(op_data->init_subgraph_index));
|
||||
graph_info.InvokeSubgraph(op_data->init_subgraph_index));
|
||||
|
||||
op_data->has_run = true;
|
||||
|
||||
|
||||
@@ -28,11 +28,19 @@ constexpr int kOutputTensor = 0;
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -83,6 +91,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt32:
|
||||
return copyToTensor(context, tflite::micro::GetTensorData<int32_t>(input),
|
||||
output, num_elements);
|
||||
case kTfLiteUInt32:
|
||||
return copyToTensor(context,
|
||||
tflite::micro::GetTensorData<uint32_t>(input), output,
|
||||
num_elements);
|
||||
case kTfLiteFloat32:
|
||||
return copyToTensor(context, tflite::micro::GetTensorData<float>(input),
|
||||
output, num_elements);
|
||||
|
||||
@@ -29,9 +29,13 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
@@ -42,6 +46,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
for (int i = 0; i < output->dims->size; ++i) {
|
||||
TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,9 +39,13 @@ const int kCircularBufferCyclesMaxIndex = 0; // 'cycles_max'
|
||||
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(-9);
|
||||
|
||||
TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input =
|
||||
GetInput(context, node, kCircularBufferInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kCircularBufferOutputTensor);
|
||||
|
||||
MicroContext * micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context-> AllocateTempInputTensor(node, kCircularBufferInputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context-> AllocateTempOutputTensor(node, kCircularBufferOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpDataCircularBuffer* op_data =
|
||||
@@ -85,6 +89,9 @@ TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
op_data->cycles_until_run = op_data->cycles_max;
|
||||
node->user_data = op_data;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -540,9 +540,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor1);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor2);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
|
||||
if (input1->type == kTfLiteInt8) {
|
||||
@@ -570,6 +574,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->params.input2_shift = input2_shift;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -115,13 +115,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteConcatenationParams* params =
|
||||
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input_tensor = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input_tensor != nullptr);
|
||||
TfLiteType input_type = input_tensor->type;
|
||||
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output_tensor =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output_tensor != nullptr);
|
||||
TfLiteType output_type = output_tensor->type;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_tensor);
|
||||
micro_context->DeallocateTempTfLiteTensor(output_tensor);
|
||||
|
||||
// Check activation and input type
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
|
||||
TF_LITE_ENSURE(context,
|
||||
@@ -138,7 +144,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
// Shapes with dimensions >4 are not yet supported with static allocation.
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const TfLiteTensor* input = GetInput(context, node, i);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
int num_dimensions = NumDimensions(input);
|
||||
|
||||
@@ -150,13 +156,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
num_dimensions);
|
||||
return kTfLiteError;
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
}
|
||||
|
||||
// Calculate OpData.
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
switch (output_type) { // Already know in/outtypes are same.
|
||||
@@ -183,10 +191,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Allocate persistent scale and zeropoint buffers.
|
||||
// Store input scale and zero point values in OpParams:
|
||||
for (int i = 0; i < node->inputs->size; ++i) {
|
||||
const TfLiteTensor* t = GetInput(context, node, i);
|
||||
TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i);
|
||||
TF_LITE_ENSURE(context, t != nullptr);
|
||||
input_scales[i] = t->params.scale;
|
||||
input_zero_points[i] = t->params.zero_point;
|
||||
micro_context->DeallocateTempTfLiteTensor(t);
|
||||
}
|
||||
|
||||
data->params.input_scale = input_scales;
|
||||
@@ -202,6 +211,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -79,7 +79,8 @@ TfLiteRegistration Register_CONV_2D();
|
||||
|
||||
#if defined(XTENSA)
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int8 inputs and outputs.
|
||||
// int8 activations and int8 weights and always calls the reference
|
||||
// implementation.
|
||||
TfLiteRegistration Register_CONV_2D_INT8REF();
|
||||
#else
|
||||
inline TfLiteRegistration Register_CONV_2D_INT8REF() {
|
||||
@@ -87,6 +88,25 @@ inline TfLiteRegistration Register_CONV_2D_INT8REF() {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int8 activations and int8 weights and uses the latency optimized
|
||||
// implementations.
|
||||
TfLiteRegistration Register_CONV_2D_INT8();
|
||||
|
||||
// Returns a TfLiteRegistration struct for kernel variant that only supports
|
||||
// int16 activations and int8 weights and uses the latency optimized
|
||||
// implementations.
|
||||
TfLiteRegistration Register_CONV_2D_INT16();
|
||||
|
||||
#else
|
||||
inline TfLiteRegistration Register_CONV_2D_INT8() { return Register_CONV_2D(); }
|
||||
|
||||
inline TfLiteRegistration Register_CONV_2D_INT16() {
|
||||
return Register_CONV_2D();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
|
||||
|
||||
@@ -93,13 +93,18 @@ TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
|
||||
params.dilation_width_factor, height, width, filter_height, filter_width,
|
||||
padding, &out_height, &out_width);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kConvBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
// Note that quantized inference requires that all tensors have their
|
||||
@@ -119,6 +124,11 @@ TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -129,12 +139,16 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
|
||||
const auto& params =
|
||||
*(static_cast<const TfLiteConvParams*>(node->builtin_data));
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
|
||||
const int input_width = input->dims->data[2];
|
||||
@@ -174,6 +188,10 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
context, node, params, input_width, input_height, filter_width,
|
||||
filter_height, output_width, output_height, input->type, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace tflite
|
||||
|
||||
@@ -47,8 +47,12 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TfLiteTensor* axis =
|
||||
micro_context->AllocateTempInputTensor(node, kAxisTensor);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
|
||||
@@ -58,7 +62,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE(context, HaveSameShapes(input, output));
|
||||
@@ -91,6 +96,10 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
&data->output_activation_max));
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -40,11 +40,14 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
|
||||
@@ -83,6 +86,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
output->dims->data[kWidthRank] = output_width;
|
||||
output->dims->data[kDepthRank] = output_channels;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -94,13 +94,18 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
|
||||
params.dilation_width_factor, height, width, filter_height, filter_width,
|
||||
padding, &out_height, &out_width);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kConvBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
// Note that quantized inference requires that all tensors have their
|
||||
@@ -120,6 +125,11 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -130,14 +140,16 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
|
||||
const auto& params =
|
||||
*(static_cast<const TfLiteDepthwiseConvParams*>(node->builtin_data));
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kDepthwiseConvOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
const TfLiteTensor* input =
|
||||
GetInput(context, node, kDepthwiseConvInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kDepthwiseConvInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter =
|
||||
GetInput(context, node, kDepthwiseConvWeightsTensor);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kDepthwiseConvWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
|
||||
const int input_width = input->dims->data[2];
|
||||
@@ -180,6 +192,10 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
context, node, params, input_width, input_height, filter_width,
|
||||
filter_height, output_width, output_height, input->type, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,10 +33,12 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// TODO(b/140515557): Add cached dequant to improve hybrid model performance.
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
@@ -54,6 +56,10 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->quantization_params.zero_point = input->params.zero_point;
|
||||
data->quantization_params.scale = static_cast<double>(input->params.scale);
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
|
||||
#include "flatbuffers/flexbuffers.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
@@ -152,14 +154,17 @@ void Free(TfLiteContext* context, void* buffer) {}
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// Inputs: box_encodings, scores, anchors
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
const TfLiteTensor* input_box_encodings =
|
||||
GetInput(context, node, kInputTensorBoxEncodings);
|
||||
const TfLiteTensor* input_class_predictions =
|
||||
GetInput(context, node, kInputTensorClassPredictions);
|
||||
const TfLiteTensor* input_anchors =
|
||||
GetInput(context, node, kInputTensorAnchors);
|
||||
TfLiteTensor* input_box_encodings =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorBoxEncodings);
|
||||
TfLiteTensor* input_class_predictions =
|
||||
micro_context->AllocateTempInputTensor(node,
|
||||
kInputTensorClassPredictions);
|
||||
TfLiteTensor* input_anchors =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorAnchors);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
|
||||
@@ -217,6 +222,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// num_detections
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_box_encodings);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_class_predictions);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_anchors);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -313,9 +322,10 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
|
||||
void DecreasingPartialArgSort(const float* values, int num_values,
|
||||
int num_to_sort, int* indices) {
|
||||
std::iota(indices, indices + num_values, 0);
|
||||
std::partial_sort(
|
||||
indices, indices + num_to_sort, indices + num_values,
|
||||
[&values](const int i, const int j) { return values[i] > values[j]; });
|
||||
std::partial_sort(indices, indices + num_to_sort, indices + num_values,
|
||||
[&values](const int i, const int j) {
|
||||
return std::tie(values[i], j) > std::tie(values[j], i);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Compare>
|
||||
|
||||
@@ -38,11 +38,13 @@ bool IsLogicalSupportedType(const TfLiteType type) {
|
||||
typedef bool (*IsSupportedType)(TfLiteType);
|
||||
template <IsSupportedType>
|
||||
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
if (!IsSupportedType(input->type)) {
|
||||
@@ -50,6 +52,9 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -80,13 +80,16 @@ void EvalUsingLookupTable(const OpData* data, const TfLiteEvalTensor* input,
|
||||
}
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
// Use LUT to handle quantized elu path.
|
||||
@@ -97,7 +100,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
};
|
||||
PopulateLookupTable<int8_t>(input, output, transform, data);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
# Info
|
||||
|
||||
These are the Espressif chipset specific replacement kernels.
|
||||
The kernels call optimized routines or reference routines depending upon optimization option selected.
|
||||
|
||||
By default optimizations are selected if available.
|
||||
To change this behaviour, please make the appropriate `ESP-NN` menu selection after running:
|
||||
|
||||
```
|
||||
idf.py menuconfig
|
||||
```
|
||||
@@ -0,0 +1,209 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/add.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/add.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
long long add_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
|
||||
void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
|
||||
const OpDataAdd* data, const TfLiteEvalTensor* input1,
|
||||
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
|
||||
tflite::ArithmeticParams op_params;
|
||||
SetActivationParams(data->output_activation_min_f32,
|
||||
data->output_activation_max_f32, &op_params);
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastAdd4DSlow(
|
||||
op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<float>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<float>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
} else {
|
||||
reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<float>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<float>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteAddParams* params, const OpDataAdd* data,
|
||||
const TfLiteEvalTensor* input1,
|
||||
const TfLiteEvalTensor* input2,
|
||||
TfLiteEvalTensor* output) {
|
||||
tflite::ArithmeticParams op_params;
|
||||
op_params.left_shift = data->left_shift;
|
||||
op_params.input1_offset = data->input1_offset;
|
||||
op_params.input1_multiplier = data->input1_multiplier;
|
||||
op_params.input1_shift = data->input1_shift;
|
||||
op_params.input2_offset = data->input2_offset;
|
||||
op_params.input2_multiplier = data->input2_multiplier;
|
||||
op_params.input2_shift = data->input2_shift;
|
||||
op_params.output_offset = data->output_offset;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_shift = data->output_shift;
|
||||
SetActivationParams(data->output_activation_min, data->output_activation_max,
|
||||
&op_params);
|
||||
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
|
||||
tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorShape(input2), &op_params);
|
||||
|
||||
switch (output->type) {
|
||||
case kTfLiteInt8: {
|
||||
if (need_broadcast) {
|
||||
reference_integer_ops::BroadcastAdd4DSlow(
|
||||
op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<int8_t>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<int8_t>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
#if ESP_NN
|
||||
const int8_t *input1_data = tflite::micro::GetTensorData<int8_t>(input1);
|
||||
const int8_t *input2_data = tflite::micro::GetTensorData<int8_t>(input2);
|
||||
int8_t *out_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
|
||||
esp_nn_add_elementwise_s8(input1_data,
|
||||
input2_data,
|
||||
data->input1_offset,
|
||||
data->input2_offset,
|
||||
data->input1_multiplier,
|
||||
data->input2_multiplier,
|
||||
data->input1_shift,
|
||||
data->input2_shift,
|
||||
data->left_shift,
|
||||
out_data,
|
||||
data->output_offset,
|
||||
data->output_multiplier,
|
||||
data->output_shift,
|
||||
data->output_activation_min,
|
||||
data->output_activation_max,
|
||||
MatchingElementsSize(tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorShape(output))
|
||||
);
|
||||
#else
|
||||
reference_integer_ops::Add(
|
||||
op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<int8_t>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<int8_t>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
}
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt16: {
|
||||
if (need_broadcast) {
|
||||
reference_ops::BroadcastAdd4DSlow(
|
||||
op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<int16_t>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<int16_t>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<int16_t>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<int16_t>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output),
|
||||
false);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* AddInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpDataAdd));
|
||||
}
|
||||
|
||||
TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpDataAdd* data = static_cast<const OpDataAdd*>(node->user_data);
|
||||
|
||||
const TfLiteEvalTensor* input1 =
|
||||
tflite::micro::GetEvalInput(context, node, kAddInputTensor1);
|
||||
const TfLiteEvalTensor* input2 =
|
||||
tflite::micro::GetEvalInput(context, node, kAddInputTensor2);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kAddOutputTensor);
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
EvalAdd(context, node, params, data, input1, input2, output);
|
||||
} else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, data,
|
||||
input1, input2, output));
|
||||
} else {
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(output->type),
|
||||
output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
add_total_time += esp_timer_get_time() - start_time;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_ADD() {
|
||||
return {/*init=*/AddInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/AddPrepare,
|
||||
/*invoke=*/AddEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,319 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/micro/kernels/conv.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
|
||||
long long conv_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct NodeData {
|
||||
OpDataConv op_data;
|
||||
#if ESP_NN
|
||||
int buffer_idx;
|
||||
#endif
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(NodeData));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
NodeData* data = static_cast<NodeData*>(node->user_data);
|
||||
const auto& params =
|
||||
*(static_cast<const TfLiteConvParams*>(node->builtin_data));
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
const int input_width = input->dims->data[2];
|
||||
const int input_height = input->dims->data[1];
|
||||
const int filter_width = filter->dims->data[2];
|
||||
const int filter_height = filter->dims->data[1];
|
||||
const int output_width = output->dims->data[2];
|
||||
const int output_height = output->dims->data[1];
|
||||
|
||||
// Dynamically allocate per-channel quantization parameters.
|
||||
const int num_channels = filter->dims->data[kConvQuantizedDimension];
|
||||
data->op_data.per_channel_output_multiplier =
|
||||
static_cast<int32_t*>(context->AllocatePersistentBuffer(
|
||||
context, num_channels * sizeof(int32_t)));
|
||||
data->op_data.per_channel_output_shift =
|
||||
static_cast<int32_t*>(context->AllocatePersistentBuffer(
|
||||
context, num_channels * sizeof(int32_t)));
|
||||
|
||||
// All per-channel quantized tensors need valid zero point and scale arrays.
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
|
||||
kTfLiteAffineQuantization);
|
||||
|
||||
const auto* affine_quantization =
|
||||
static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
|
||||
TFLITE_DCHECK(affine_quantization != nullptr);
|
||||
TFLITE_DCHECK(affine_quantization->scale != nullptr);
|
||||
TFLITE_DCHECK(affine_quantization->zero_point != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
affine_quantization->scale->size == 1 ||
|
||||
affine_quantization->scale->size ==
|
||||
filter->dims->data[kConvQuantizedDimension]);
|
||||
TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
|
||||
affine_quantization->zero_point->size);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
|
||||
context, node, params, input_width, input_height, filter_width,
|
||||
filter_height, output_width, output_height, input->type, &data->op_data));
|
||||
|
||||
#if ESP_NN
|
||||
if (input->type == kTfLiteInt8) {
|
||||
int scratch_buf_size = esp_nn_get_conv_scratch_size(
|
||||
input_width, input_height, input->dims->data[3],
|
||||
output->dims->data[3], filter_width, filter_height);
|
||||
if (scratch_buf_size > 0) {
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, scratch_buf_size, &data->buffer_idx));
|
||||
} else {
|
||||
data->buffer_idx = -1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
#if ESP_NN
|
||||
// Fixed-point per-channel-quantization convolution Int8 function wrapper.
|
||||
inline void EvalQuantizedPerChannel(
|
||||
TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
|
||||
const NodeData& data, const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
const int dilation_width_factor = params.dilation_width_factor;
|
||||
const int dilation_height_factor = params.dilation_height_factor;
|
||||
|
||||
if (dilation_width_factor == 1 && dilation_height_factor == 1) {
|
||||
// Get parameters.
|
||||
RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
|
||||
RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
|
||||
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
|
||||
RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);
|
||||
|
||||
const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
|
||||
int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
|
||||
const int32_t input_offset = -data.op_data.input_zero_point;
|
||||
const int32_t output_offset = data.op_data.output_zero_point;
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int pad_width = data.op_data.padding.width;
|
||||
const int pad_height = data.op_data.padding.height;
|
||||
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
|
||||
// Set min and max value of the output.
|
||||
const int32_t activation_min = data.op_data.output_activation_min;
|
||||
const int32_t activation_max = data.op_data.output_activation_max;
|
||||
|
||||
// Consistency check.
|
||||
TFLITE_DCHECK_LE(activation_min, activation_max);
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
|
||||
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
|
||||
|
||||
if (tflite::micro::GetTensorData<int8_t>(bias)) {
|
||||
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
|
||||
}
|
||||
|
||||
void *scratch_buf = NULL;
|
||||
if (data.buffer_idx > -1) {
|
||||
scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
|
||||
}
|
||||
esp_nn_set_conv_scratch_buf(scratch_buf);
|
||||
|
||||
const int input_size = input_width * input_height * input_depth;
|
||||
const int output_size = output_width * output_height * output_depth;
|
||||
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
|
||||
esp_nn_conv_s8(input_data + i_batch * input_size,
|
||||
input_width, input_height, input_depth, input_offset,
|
||||
pad_width, pad_height, stride_width, stride_height,
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
filter_width, filter_height,
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
output_data + i_batch * output_size,
|
||||
output_width, output_height, output_depth, output_offset,
|
||||
data.op_data.per_channel_output_shift,
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
activation_min, activation_max);
|
||||
}
|
||||
} else {
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
ConvParamsQuantized(params, data.op_data),
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
data.op_data.per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
|
||||
: nullptr;
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
const auto& params =
|
||||
*(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const auto& data = *(static_cast<const NodeData*>(node->user_data));
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
switch (input->type) { // Already know in/out types are same.
|
||||
case kTfLiteFloat32: {
|
||||
tflite::reference_ops::Conv(
|
||||
ConvParamsFloat(params, data.op_data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output),
|
||||
tflite::micro::GetTensorShape(nullptr), nullptr);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
#if ESP_NN
|
||||
EvalQuantizedPerChannel(context, node, params, data, input, filter,
|
||||
bias, output);
|
||||
#else
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
ConvParamsQuantized(params, data.op_data),
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
data.op_data.per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
//EvalQuantized
|
||||
reference_ops::Conv(ConvParamsQuantized(params, data.op_data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<uint8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<uint8_t>(output),
|
||||
tflite::micro::GetTensorShape(nullptr), nullptr,
|
||||
nullptr);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
conv_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_CONV_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,319 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include <esp_timer.h>
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
long long dc_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct NodeData {
|
||||
OpDataConv op_data;
|
||||
#if ESP_NN
|
||||
int buffer_idx;
|
||||
#endif
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(NodeData));
|
||||
}
|
||||
|
||||
#if ESP_NN
|
||||
inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteDepthwiseConvParams& params,
|
||||
const NodeData& data,
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
const int dilation_width_factor = params.dilation_width_factor;
|
||||
const int dilation_height_factor = params.dilation_height_factor;
|
||||
|
||||
if (dilation_width_factor == 1 && dilation_height_factor == 1) {
|
||||
// Get parameters.
|
||||
RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
|
||||
RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
|
||||
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
|
||||
RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);
|
||||
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
|
||||
const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
|
||||
int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
|
||||
const int depth_multiplier = params.depth_multiplier;
|
||||
const int32_t input_offset = -data.op_data.input_zero_point;
|
||||
const int32_t output_offset = data.op_data.output_zero_point;
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int pad_width = data.op_data.padding.width;
|
||||
const int pad_height = data.op_data.padding.height;
|
||||
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
|
||||
// Set min and max value of the output.
|
||||
const int32_t activation_min = data.op_data.output_activation_min;
|
||||
const int32_t activation_max = data.op_data.output_activation_max;
|
||||
|
||||
// Consistency check.
|
||||
TFLITE_DCHECK_LE(activation_min, activation_max);
|
||||
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
|
||||
|
||||
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
|
||||
if (tflite::micro::GetTensorData<int8_t>(bias)) {
|
||||
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
|
||||
}
|
||||
|
||||
const int input_size = input_width * input_height * input_depth;
|
||||
const int output_size = output_width * output_height * output_depth;
|
||||
void *scratch_buf = NULL;
|
||||
if (data.buffer_idx > -1) {
|
||||
scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
|
||||
}
|
||||
esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
|
||||
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
|
||||
esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
|
||||
input_height, input_depth, input_offset,
|
||||
pad_width, pad_height,
|
||||
stride_width, stride_height, depth_multiplier,
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
filter_width, filter_height,
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
output_data + i_batch * output_size,
|
||||
output_width, output_height, output_offset,
|
||||
data.op_data.per_channel_output_shift,
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
activation_min, activation_max);
|
||||
}
|
||||
} else {
|
||||
reference_integer_ops::DepthwiseConvPerChannel(
|
||||
DepthwiseConvParamsQuantized(params, data.op_data),
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
data.op_data.per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
NodeData* data = static_cast<NodeData*>(node->user_data);
|
||||
const TfLiteDepthwiseConvParams& params =
|
||||
*(static_cast<const TfLiteDepthwiseConvParams*>(node->builtin_data));
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
const int input_width = input->dims->data[2];
|
||||
const int input_height = input->dims->data[1];
|
||||
const int filter_width = filter->dims->data[2];
|
||||
const int filter_height = filter->dims->data[1];
|
||||
const int output_width = output->dims->data[2];
|
||||
const int output_height = output->dims->data[1];
|
||||
|
||||
// Dynamically allocate per-channel quantization parameters.
|
||||
const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
|
||||
data->op_data.per_channel_output_multiplier =
|
||||
static_cast<int32_t*>(context->AllocatePersistentBuffer(
|
||||
context, num_channels * sizeof(int32_t)));
|
||||
data->op_data.per_channel_output_shift =
|
||||
static_cast<int32_t*>(context->AllocatePersistentBuffer(
|
||||
context, num_channels * sizeof(int32_t)));
|
||||
|
||||
// All per-channel quantized tensors need valid zero point and scale arrays.
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
|
||||
kTfLiteAffineQuantization);
|
||||
|
||||
const auto* affine_quantization =
|
||||
static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
|
||||
TFLITE_DCHECK(affine_quantization != nullptr);
|
||||
TFLITE_DCHECK(affine_quantization->scale != nullptr);
|
||||
TFLITE_DCHECK(affine_quantization->zero_point != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(
|
||||
context, affine_quantization->scale->size == 1 ||
|
||||
affine_quantization->scale->size ==
|
||||
filter->dims->data[kDepthwiseConvQuantizedDimension]);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
|
||||
affine_quantization->zero_point->size);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
|
||||
context, node, params, input_width, input_height, filter_width,
|
||||
filter_height, output_width, output_height, input->type, &data->op_data));
|
||||
|
||||
#if ESP_NN
|
||||
if (input->type == kTfLiteInt8) {
|
||||
int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
|
||||
input_width, input_height, input->dims->data[3],
|
||||
params.depth_multiplier, filter_width, filter_height);
|
||||
if (scratch_buf_size > 0) {
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, scratch_buf_size, &data->buffer_idx));
|
||||
} else {
|
||||
data->buffer_idx = -1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
auto& params =
|
||||
*(reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data));
|
||||
const NodeData& data = *(static_cast<const NodeData*>(node->user_data));
|
||||
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kDepthwiseConvOutputTensor);
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kDepthwiseConvInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
switch (input->type) { // Already know in/out types are same.
|
||||
case kTfLiteFloat32:
|
||||
tflite::reference_ops::DepthwiseConv(
|
||||
DepthwiseConvParamsFloat(params, data.op_data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
#if ESP_NN
|
||||
EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
|
||||
output);
|
||||
#else
|
||||
reference_integer_ops::DepthwiseConvPerChannel(
|
||||
DepthwiseConvParamsQuantized(params, data.op_data),
|
||||
data.op_data.per_channel_output_multiplier,
|
||||
data.op_data.per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
//EvalQuantized(context, node, params, &data, input, filter, bias, output);
|
||||
reference_ops::DepthwiseConv(
|
||||
DepthwiseConvParamsQuantized(params, data.op_data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<uint8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<uint8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
dc_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,198 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
#include <esp_timer.h>
|
||||
|
||||
long long fc_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context,
|
||||
sizeof(OpDataFullyConnected));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
|
||||
const auto params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kFullyConnectedInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* filter = micro_context->AllocateTempInputTensor(
|
||||
node, kFullyConnectedWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kFullyConnectedBiasTensor);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
|
||||
node, kFullyConnectedOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
|
||||
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
|
||||
context, params->activation, input->type,
|
||||
input, filter, bias, output, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
if (bias != nullptr) {
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
const auto* params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const auto& data =
|
||||
*(static_cast<const OpDataFullyConnected*>(node->user_data));
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
// Checks in Prepare ensure input, output and filter types are all the same.
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
tflite::reference_ops::FullyConnected(
|
||||
FullyConnectedParamsFloat(params->activation),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
}
|
||||
|
||||
case kTfLiteInt8: {
|
||||
const int32_t* bias_data =
|
||||
nullptr != bias ? tflite::micro::GetTensorData<int32_t>(bias)
|
||||
: nullptr;
|
||||
#if ESP_NN
|
||||
const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
|
||||
const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
|
||||
const int filter_dim_count = filter_shape.DimensionsCount();
|
||||
const int batches = output_shape.Dims(0);
|
||||
const int output_depth = output_shape.Dims(1);
|
||||
TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
|
||||
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
|
||||
|
||||
const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
|
||||
int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
const int8_t *filter_data = tflite::micro::GetTensorData<int8_t>(filter);
|
||||
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
esp_nn_fully_connected_s8(input_data, -data.input_zero_point,
|
||||
accum_depth,
|
||||
filter_data, -data.filter_zero_point,
|
||||
bias_data, output_data, output_depth,
|
||||
data.output_zero_point,
|
||||
data.output_shift, data.output_multiplier,
|
||||
data.output_activation_min,
|
||||
data.output_activation_max);
|
||||
input_data += accum_depth;
|
||||
output_data += output_depth;
|
||||
}
|
||||
#else
|
||||
tflite::reference_integer_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias), bias_data,
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
|
||||
case kTfLiteUInt8: {
|
||||
tflite::reference_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<uint8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<uint8_t>(output));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
fc_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_FULLY_CONNECTED() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,131 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/micro/kernels/mul.h"
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/mul.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
#include <esp_timer.h>
|
||||
|
||||
long long mul_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
#if ESP_NN
|
||||
void MulEvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpDataMul* data, const TfLiteEvalTensor* input1,
|
||||
const TfLiteEvalTensor* input2,
|
||||
TfLiteEvalTensor* output) {
|
||||
tflite::ArithmeticParams op_params = {};
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
op_params.float_activation_max = data->output_activation_max_f32;
|
||||
op_params.input1_offset = -data->input1_zero_point;
|
||||
op_params.input2_offset = -data->input2_zero_point;
|
||||
op_params.output_offset = data->output_zero_point;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_shift = data->output_shift;
|
||||
|
||||
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
|
||||
tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorShape(input2), &op_params);
|
||||
|
||||
if (need_broadcast) {
|
||||
reference_integer_ops::BroadcastMul4DSlow(
|
||||
op_params, tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<int8_t>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<int8_t>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
const int8_t *input1_data = tflite::micro::GetTensorData<int8_t>(input1);
|
||||
const int8_t *input2_data = tflite::micro::GetTensorData<int8_t>(input2);
|
||||
int8_t *out_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
|
||||
esp_nn_mul_elementwise_s8(input1_data, input2_data, op_params.input1_offset,
|
||||
op_params.input2_offset, out_data, op_params.output_offset,
|
||||
op_params.output_multiplier, op_params.output_shift,
|
||||
op_params.quantized_activation_min, op_params.quantized_activation_max,
|
||||
MatchingElementsSize(tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorShape(output)));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpDataMul* data = static_cast<const OpDataMul*>(node->user_data);
|
||||
|
||||
const TfLiteEvalTensor* input1 =
|
||||
tflite::micro::GetEvalInput(context, node, kMulInput1Tensor);
|
||||
const TfLiteEvalTensor* input2 =
|
||||
tflite::micro::GetEvalInput(context, node, kMulInput2Tensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kMulOutputTensor);
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
switch (input1->type) {
|
||||
case kTfLiteInt8:
|
||||
#if ESP_NN
|
||||
MulEvalQuantized(context, node, data, input1, input2, output);
|
||||
#else
|
||||
EvalMulQuantizedReference(context, node, data, input1, input2, output);
|
||||
#endif
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
EvalMulQuantizedReference(context, node, data, input1, input2, output);
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
EvalMulFloatReference(context, node, params, data, input1, input2,
|
||||
output);
|
||||
break;
|
||||
default:
|
||||
MicroPrintf("Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input1->type), input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
mul_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MUL() {
|
||||
return {/*init=*/MulInit,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/MulPrepare,
|
||||
/*invoke=*/MulEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -0,0 +1,245 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/pooling.h"
|
||||
|
||||
#if ESP_NN
|
||||
#include <esp_nn.h>
|
||||
#endif
|
||||
|
||||
#include <esp_timer.h>
|
||||
|
||||
long long pooling_total_time = 0;
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
#if ESP_NN
|
||||
void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
|
||||
const TfLitePoolParams* params, const OpDataPooling* data,
|
||||
const TfLiteEvalTensor* input,
|
||||
TfLiteEvalTensor* output) {
|
||||
|
||||
const int stride_height = params->stride_height;
|
||||
const int stride_width = params->stride_width;
|
||||
const int filter_height = params->filter_height;
|
||||
const int filter_width = params->filter_width;
|
||||
const int activation_min = data->activation_min;
|
||||
const int activation_max = data->activation_max;
|
||||
const int pad_height = data->padding.height;
|
||||
const int pad_width = data->padding.width;
|
||||
|
||||
const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
|
||||
const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
|
||||
TFLITE_DCHECK_LE(activation_min, activation_max);
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
|
||||
const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
|
||||
int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
|
||||
const int input_size = input_width * input_height * depth;
|
||||
const int output_size = output_width * output_height * depth;
|
||||
|
||||
if (depth % 4 == 0) { // S3 version only supports channels multiple of 4
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
esp_nn_avg_pool_s8(input_data, input_width, input_height,
|
||||
output_data, output_width, output_height,
|
||||
stride_width, stride_height,
|
||||
filter_width, filter_height,
|
||||
pad_width, pad_height,
|
||||
activation_min, activation_max, depth);
|
||||
input_data += input_size;
|
||||
output_data += output_size;
|
||||
}
|
||||
} else {
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
esp_nn_avg_pool_s8_ansi(input_data, input_width, input_height,
|
||||
output_data, output_width, output_height,
|
||||
stride_width, stride_height,
|
||||
filter_width, filter_height,
|
||||
pad_width, pad_height,
|
||||
activation_min, activation_max, depth);
|
||||
input_data += input_size;
|
||||
output_data += output_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLitePoolParams* params, const OpDataPooling* data,
|
||||
const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
|
||||
|
||||
const int stride_height = params->stride_height;
|
||||
const int stride_width = params->stride_width;
|
||||
const int filter_height = params->filter_height;
|
||||
const int filter_width = params->filter_width;
|
||||
const int activation_min = data->activation_min;
|
||||
const int activation_max = data->activation_max;
|
||||
const int pad_height = data->padding.height;
|
||||
const int pad_width = data->padding.width;
|
||||
|
||||
const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
|
||||
const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
|
||||
TFLITE_DCHECK_LE(activation_min, activation_max);
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
|
||||
const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
|
||||
int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
|
||||
|
||||
const int input_size = input_width * input_height * depth;
|
||||
const int output_size = output_width * output_height * depth;
|
||||
if (depth % 4 == 0) { // S3 version only supports channels multiple of 4
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
esp_nn_max_pool_s8(input_data, input_width, input_height,
|
||||
output_data, output_width, output_height,
|
||||
stride_width, stride_height,
|
||||
filter_width, filter_height,
|
||||
pad_width, pad_height,
|
||||
activation_min, activation_max, depth);
|
||||
input_data += input_size;
|
||||
output_data += output_size;
|
||||
}
|
||||
} else {
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
esp_nn_max_pool_s8_ansi(input_data, input_width, input_height,
|
||||
output_data, output_width, output_height,
|
||||
stride_width, stride_height,
|
||||
filter_width, filter_height,
|
||||
pad_width, pad_height,
|
||||
activation_min, activation_max, depth);
|
||||
input_data += input_size;
|
||||
output_data += output_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpDataPooling* data =
|
||||
static_cast<const OpDataPooling*>(node->user_data);
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
micro::GetEvalInput(context, node, kPoolingInputTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
micro::GetEvalOutput(context, node, kPoolingOutputTensor);
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
// Inputs and outputs share the same type, guaranteed by the converter.
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
AveragePoolingEvalFloat(context, node, params, data, input, output);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
#if ESP_NN
|
||||
AverageEvalQuantized(context, node, params, data, input, output);
|
||||
#else
|
||||
AveragePoolingEvalQuantized(context, node, params, data, input, output);
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
pooling_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpDataPooling* data =
|
||||
static_cast<const OpDataPooling*>(node->user_data);
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
micro::GetEvalInput(context, node, kPoolingInputTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
micro::GetEvalOutput(context, node, kPoolingOutputTensor);
|
||||
|
||||
long long start_time = esp_timer_get_time();
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
MaxPoolingEvalFloat(context, node, params, data, input, output);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
#if ESP_NN
|
||||
MaxEvalQuantized(context, node, params, data, input, output);
|
||||
#else
|
||||
MaxPoolingEvalQuantized(context, node, params, data, input, output);
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
pooling_total_time += esp_timer_get_time() - start_time;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpDataPooling));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_AVERAGE_POOL_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PoolingPrepare,
|
||||
/*invoke=*/AverageEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
TfLiteRegistration Register_MAX_POOL_2D() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/PoolingPrepare,
|
||||
/*invoke=*/MaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -27,11 +27,15 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
|
||||
@@ -40,6 +44,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
for (int i = 0; i < output->dims->size; ++i) {
|
||||
TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -84,22 +84,31 @@ TfLiteStatus VerifyTensorDim(TfLiteContext* context, const TfLiteTensor* input,
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
const TfLiteTensor* axis;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* axis =
|
||||
micro_context->AllocateTempInputTensor(node, kAxisTensor);
|
||||
TF_LITE_ENSURE(context, axis != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
output->type = input->type;
|
||||
if (IsDynamicTensor(axis)) {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"DynamicTensor is not yet supported by Expand_Dims.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
return VerifyTensorDim(context, input, axis, output);
|
||||
TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -65,14 +65,18 @@ constexpr int kValueTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// Ensure inputs and outputs exist.
|
||||
const TfLiteTensor* dims;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
|
||||
const TfLiteTensor* value;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* dims =
|
||||
micro_context->AllocateTempInputTensor(node, kDimsTensor);
|
||||
TF_LITE_ENSURE(context, dims != nullptr);
|
||||
TfLiteTensor* value =
|
||||
micro_context->AllocateTempInputTensor(node, kValueTensor);
|
||||
TF_LITE_ENSURE(context, value != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
// The value tensor must be a scalar.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);
|
||||
@@ -90,6 +94,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, EnsureEq(context, output->dims, dims));
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(dims);
|
||||
micro_context->DeallocateTempTfLiteTensor(value);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -31,22 +31,28 @@ constexpr int kInputTensor2 = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input1;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor1, &input1));
|
||||
const TfLiteTensor* input2;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor2, &input2));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor1);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor2);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, output->type);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,22 +36,28 @@ constexpr int kOutputTensor = 0;
|
||||
// OLD-TODO(b/117912880): Support quantization.
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input1;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor1, &input1));
|
||||
const TfLiteTensor* input2;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor2, &input2));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor1);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor2);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, output->type);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
@@ -42,23 +44,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input =
|
||||
GetInput(context, node, kFullyConnectedInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kFullyConnectedInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter =
|
||||
GetInput(context, node, kFullyConnectedWeightsTensor);
|
||||
TfLiteTensor* filter = micro_context->AllocateTempInputTensor(
|
||||
node, kFullyConnectedWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kFullyConnectedBiasTensor);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
|
||||
node, kFullyConnectedOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
|
||||
return CalculateOpDataFullyConnected(context, params->activation, input->type,
|
||||
input, filter, bias, output, data);
|
||||
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
|
||||
context, params->activation, input->type,
|
||||
input, filter, bias, output, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
if (bias != nullptr) {
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
@@ -97,19 +97,23 @@ TfLiteStatus Gather(const TfLiteGatherParams* params,
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const auto* params =
|
||||
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
const TfLiteTensor* coords;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputPositions, &coords));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* coords =
|
||||
micro_context->AllocateTempInputTensor(node, kInputPositions);
|
||||
TF_LITE_ENSURE(context, coords != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
switch (coords->type) {
|
||||
case kTfLiteInt32:
|
||||
break;
|
||||
@@ -176,6 +180,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
for (int i = axis + 1; i < input->dims->size; ++i) {
|
||||
output_shape->data[output_index++] = input->dims->data[i];
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(coords);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -28,16 +28,19 @@ constexpr int kOutputTensor = 0;
|
||||
constexpr int MAX_INDICES_ND = 5;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* params;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, ¶ms));
|
||||
const TfLiteTensor* indices;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* params = micro_context->AllocateTempInputTensor(node, kParams);
|
||||
TF_LITE_ENSURE(context, params != nullptr);
|
||||
TfLiteTensor* indices =
|
||||
micro_context->AllocateTempInputTensor(node, kIndices);
|
||||
TF_LITE_ENSURE(context, indices != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
switch (params->type) {
|
||||
case kTfLiteFloat32:
|
||||
@@ -98,6 +101,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
output_shape->data[output_index++] = params->dims->data[i];
|
||||
}
|
||||
output_shape->size = output_index;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(params);
|
||||
micro_context->DeallocateTempTfLiteTensor(indices);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,13 +32,17 @@ const int kHardSwishInputTensor = 0;
|
||||
const int kHardSwishOutputTensor = 0;
|
||||
|
||||
TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kHardSwishInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kHardSwishInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kHardSwishOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kHardSwishOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
@@ -73,6 +77,9 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
¶ms->reluish_multiplier_fixedpoint_int16);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
#include "tensorflow/lite/micro/micro_graph.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@@ -50,36 +51,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, node->inputs->size > 0);
|
||||
|
||||
// The first input is the condition.
|
||||
const TfLiteTensor* cond;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
TfLiteTensor* cond = micro_context->AllocateTempInputTensor(node, 0);
|
||||
|
||||
TF_LITE_ENSURE(context, cond != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(cond), 1);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(cond);
|
||||
|
||||
// The first input of the node is the condition. The rest of inputs are
|
||||
// passed to the branch subgraphs. Therefore, the number of subgraph inputs
|
||||
// will be the number of node inputs - 1.
|
||||
size_t num_inputs = node->inputs->size - 1;
|
||||
size_t num_outputs = node->outputs->size;
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
op_data->then_subgraph_index < graph_info->NumSubgraphs());
|
||||
op_data->then_subgraph_index < graph_info.NumSubgraphs());
|
||||
TF_LITE_ENSURE(context,
|
||||
op_data->else_subgraph_index < graph_info->NumSubgraphs());
|
||||
op_data->else_subgraph_index < graph_info.NumSubgraphs());
|
||||
|
||||
TF_LITE_ENSURE_EQ(
|
||||
context, num_inputs,
|
||||
graph_info->NumSubgraphInputs(op_data->then_subgraph_index));
|
||||
TF_LITE_ENSURE_EQ(context, num_inputs,
|
||||
graph_info.NumSubgraphInputs(op_data->then_subgraph_index));
|
||||
TF_LITE_ENSURE_EQ(
|
||||
context, num_outputs,
|
||||
graph_info->NumSubgraphOutputs(op_data->then_subgraph_index));
|
||||
graph_info.NumSubgraphOutputs(op_data->then_subgraph_index));
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@@ -87,66 +85,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* cond;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
TfLiteTensor* cond = micro_context->AllocateTempInputTensor(node, 0);
|
||||
|
||||
TF_LITE_ENSURE(context, cond != nullptr);
|
||||
bool cond_value = cond->data.b[0];
|
||||
micro_context->DeallocateTempTfLiteTensor(cond);
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
|
||||
// Currently we copy the input / output between the subgraphs. This isn't
|
||||
// optimized yet.
|
||||
MicroGraph* graph_info = µ_context->graph();
|
||||
// Currently we copy the input / output between the subgraphs.
|
||||
int active_branch_subgraph_index =
|
||||
cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index;
|
||||
|
||||
for (size_t i = 0;
|
||||
i < graph_info->NumSubgraphInputs(active_branch_subgraph_index); ++i) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, i + 1);
|
||||
|
||||
TfLiteEvalTensor* subgraph_input =
|
||||
graph_info->GetSubgraphInput(active_branch_subgraph_index, i);
|
||||
|
||||
// These checks must occur in Eval since TfLiteEvalTensors are not available
|
||||
// during Prepare.
|
||||
size_t input_bytes;
|
||||
size_t subgraph_input_bytes;
|
||||
TF_LITE_ENSURE_OK(context, TfLiteEvalTensorByteLength(input, &input_bytes));
|
||||
TF_LITE_ENSURE_OK(context, TfLiteEvalTensorByteLength(
|
||||
subgraph_input, &subgraph_input_bytes));
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type);
|
||||
TF_LITE_ENSURE_EQ(context, input_bytes, subgraph_input_bytes);
|
||||
memcpy(subgraph_input->data.raw, input->data.raw, input_bytes);
|
||||
}
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
tflite::micro::CopyOpInputsToSubgraphInputs(
|
||||
context, node, graph_info, active_branch_subgraph_index,
|
||||
/*first_tensor_idx=*/1));
|
||||
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
graph_info->InvokeSubgraph(active_branch_subgraph_index));
|
||||
|
||||
for (size_t i = 0;
|
||||
i < graph_info->NumSubgraphOutputs(active_branch_subgraph_index); ++i) {
|
||||
const TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, i);
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, tflite::micro::CopySubgraphOutputsToOpOutputs(
|
||||
context, node, graph_info, active_branch_subgraph_index));
|
||||
|
||||
TfLiteEvalTensor* subgraph_output =
|
||||
graph_info->GetSubgraphOutput(active_branch_subgraph_index, i);
|
||||
|
||||
// These checks must occur in Eval since TfLiteEvalTensors are not available
|
||||
// during Prepare.
|
||||
size_t output_bytes;
|
||||
size_t subgraph_output_bytes;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
TfLiteEvalTensorByteLength(output, &output_bytes));
|
||||
TF_LITE_ENSURE_OK(context, TfLiteEvalTensorByteLength(
|
||||
subgraph_output, &subgraph_output_bytes));
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, subgraph_output->type);
|
||||
TF_LITE_ENSURE_EQ(context, output_bytes, subgraph_output_bytes);
|
||||
memcpy(output->data.raw, subgraph_output->data.raw, output_bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ namespace tflite {
|
||||
namespace micro {
|
||||
|
||||
// TODO(b/161841696): Consider moving away from global arena buffers:
|
||||
constexpr int KernelRunner::kNumScratchBuffers_;
|
||||
constexpr int KernelRunner::kKernelRunnerBufferSize_;
|
||||
uint8_t KernelRunner::kKernelRunnerBuffer_[];
|
||||
|
||||
@@ -32,22 +31,23 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
|
||||
TfLiteTensor* tensors, int tensors_size,
|
||||
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
|
||||
void* builtin_data)
|
||||
: allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
|
||||
: registration_(registration),
|
||||
allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
|
||||
kKernelRunnerBuffer_,
|
||||
kKernelRunnerBufferSize_)),
|
||||
registration_(registration),
|
||||
tensors_(tensors),
|
||||
mock_micro_graph_(allocator_) {
|
||||
mock_micro_graph_(allocator_),
|
||||
fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
|
||||
// Prepare TfLiteContext:
|
||||
context_.impl_ = static_cast<void*>(this);
|
||||
context_.ReportError = ReportOpError;
|
||||
context_.impl_ = static_cast<void*>(&fake_micro_context_);
|
||||
context_.ReportError = MicroContextReportOpError;
|
||||
context_.recommended_num_threads = 1;
|
||||
context_.GetTensor = GetTensor;
|
||||
context_.GetEvalTensor = GetEvalTensor;
|
||||
context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
|
||||
context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
|
||||
context_.GetScratchBuffer = GetScratchBuffer;
|
||||
context_.GetExecutionPlan = GetGraph;
|
||||
context_.GetTensor = MicroContextGetTensor;
|
||||
context_.GetEvalTensor = MicroContextGetEvalTensor;
|
||||
context_.AllocatePersistentBuffer = MicroContextAllocatePersistentBuffer;
|
||||
context_.RequestScratchBufferInArena =
|
||||
MicroContextRequestScratchBufferInArena;
|
||||
context_.GetScratchBuffer = MicroContextGetScratchBuffer;
|
||||
|
||||
context_.recommended_num_threads = 0;
|
||||
|
||||
// Prepare TfLiteNode:
|
||||
@@ -56,14 +56,24 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
|
||||
node_.builtin_data = builtin_data;
|
||||
}
|
||||
|
||||
bool KernelRunner::ValidateTempBufferDeallocated() {
|
||||
return fake_micro_context_.IsAllTempTfLiteTensorDeallocated();
|
||||
}
|
||||
|
||||
TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
|
||||
size_t length) {
|
||||
if (registration_.init) {
|
||||
node_.user_data = registration_.init(&context_, init_data, length);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(&context_, ValidateTempBufferDeallocated());
|
||||
|
||||
if (registration_.prepare) {
|
||||
TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(&context_, ValidateTempBufferDeallocated());
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -72,101 +82,11 @@ TfLiteStatus KernelRunner::Invoke() {
|
||||
MicroPrintf("TfLiteRegistration missing invoke function pointer!");
|
||||
return kTfLiteError;
|
||||
}
|
||||
return registration_.invoke(&context_, &node_);
|
||||
}
|
||||
|
||||
TfLiteTensor* KernelRunner::GetTensor(const struct TfLiteContext* context,
|
||||
int tensor_index) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
TF_LITE_ENSURE_STATUS(registration_.invoke(&context_, &node_));
|
||||
|
||||
return &runner->tensors_[tensor_index];
|
||||
}
|
||||
TF_LITE_ENSURE(&context_, ValidateTempBufferDeallocated());
|
||||
|
||||
TfLiteEvalTensor* KernelRunner::GetEvalTensor(
|
||||
const struct TfLiteContext* context, int tensor_index) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
|
||||
TfLiteEvalTensor* eval_tensor =
|
||||
reinterpret_cast<TfLiteEvalTensor*>(runner->allocator_->AllocateTemp(
|
||||
sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
|
||||
TFLITE_DCHECK(eval_tensor != nullptr);
|
||||
|
||||
// In unit tests, the TfLiteTensor pointer contains the source of truth for
|
||||
// buffers and values:
|
||||
eval_tensor->data = runner->tensors_[tensor_index].data;
|
||||
eval_tensor->dims = runner->tensors_[tensor_index].dims;
|
||||
eval_tensor->type = runner->tensors_[tensor_index].type;
|
||||
return eval_tensor;
|
||||
}
|
||||
|
||||
void* KernelRunner::AllocatePersistentBuffer(TfLiteContext* context,
|
||||
size_t bytes) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
|
||||
return runner->allocator_->AllocateFromTail(bytes,
|
||||
MicroArenaBufferAlignment());
|
||||
}
|
||||
|
||||
TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
|
||||
size_t bytes,
|
||||
int* buffer_index) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
TFLITE_DCHECK(buffer_index != nullptr);
|
||||
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
|
||||
if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
|
||||
MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
|
||||
kNumScratchBuffers_);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// For tests, we allocate scratch buffers from the tail and keep them around
|
||||
// for the lifetime of model. This means that the arena size in the tests will
|
||||
// be more than what we would have if the scratch buffers could share memory.
|
||||
runner->scratch_buffers_[runner->scratch_buffer_count_] =
|
||||
runner->allocator_->AllocateFromTail(bytes, MicroArenaBufferAlignment());
|
||||
TFLITE_DCHECK(runner->scratch_buffers_[runner->scratch_buffer_count_] !=
|
||||
nullptr);
|
||||
|
||||
*buffer_index = runner->scratch_buffer_count_++;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
|
||||
TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
|
||||
if (buffer_index >= runner->scratch_buffer_count_) {
|
||||
return nullptr;
|
||||
}
|
||||
return runner->scratch_buffers_[buffer_index];
|
||||
}
|
||||
|
||||
void KernelRunner::ReportOpError(struct TfLiteContext* context,
|
||||
const char* format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
GetMicroErrorReporter()->Report(format, args);
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
TfLiteStatus KernelRunner::GetGraph(struct TfLiteContext* context,
|
||||
TfLiteIntArray** args) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
*args = reinterpret_cast<TfLiteIntArray*>(runner->GetMockGraph());
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/micro/fake_micro_context.h"
|
||||
#include "tensorflow/lite/micro/mock_micro_graph.h"
|
||||
#include "tensorflow/lite/micro/simple_memory_allocator.h"
|
||||
|
||||
@@ -50,40 +51,22 @@ class KernelRunner {
|
||||
// to stub out MicroGraph methods and track invocations on each subgraph.
|
||||
MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; }
|
||||
|
||||
protected:
|
||||
static TfLiteTensor* GetTensor(const struct TfLiteContext* context,
|
||||
int tensor_index);
|
||||
static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
|
||||
int tensor_index);
|
||||
static void* AllocatePersistentBuffer(TfLiteContext* context, size_t bytes);
|
||||
static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* context,
|
||||
size_t bytes,
|
||||
int* buffer_index);
|
||||
static void* GetScratchBuffer(TfLiteContext* context, int buffer_index);
|
||||
static void ReportOpError(struct TfLiteContext* context, const char* format,
|
||||
...);
|
||||
// This method matches GetExecutionPlan from TfLiteContext since TFLM reuses
|
||||
// this method to get the MicroGraph from an operator context.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
static TfLiteStatus GetGraph(struct TfLiteContext* context,
|
||||
TfLiteIntArray** args);
|
||||
// Returns true if all temp buffer in tests are deallocated.
|
||||
// TODO(b/209453859): move this function to private after deallocation checks
|
||||
// are enabled for all kernel tests.
|
||||
bool ValidateTempBufferDeallocated();
|
||||
|
||||
private:
|
||||
static constexpr int kNumScratchBuffers_ = 12;
|
||||
|
||||
static constexpr int kKernelRunnerBufferSize_ = 10000;
|
||||
static uint8_t kKernelRunnerBuffer_[kKernelRunnerBufferSize_];
|
||||
|
||||
SimpleMemoryAllocator* allocator_ = nullptr;
|
||||
const TfLiteRegistration& registration_;
|
||||
TfLiteTensor* tensors_ = nullptr;
|
||||
MockMicroGraph mock_micro_graph_;
|
||||
|
||||
TfLiteContext context_ = {};
|
||||
TfLiteNode node_ = {};
|
||||
const TfLiteRegistration& registration_;
|
||||
|
||||
int scratch_buffer_count_ = 0;
|
||||
uint8_t* scratch_buffers_[kNumScratchBuffers_];
|
||||
SimpleMemoryAllocator* allocator_;
|
||||
MockMicroGraph mock_micro_graph_;
|
||||
FakeMicroContext fake_micro_context_;
|
||||
};
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace micro {
|
||||
@@ -119,13 +120,83 @@ TfLiteStatus CreateWritableTensorDimsWithCopy(TfLiteContext* context,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Returns a blob of payload data. The payload is subjected to interpretation by
|
||||
// the OP. This is the recommended API for an OP to get an external context. OP
|
||||
// should use this instead of directly calling GetExternalContext function in
|
||||
// context.
|
||||
void* GetExternalContext(TfLiteContext* context) {
|
||||
return reinterpret_cast<void*>(
|
||||
context->GetExternalContext(context, kTfLiteMaxExternalContexts));
|
||||
// Verify that both tensors have the same type and size, then return the size
|
||||
// of both tensors in bytes if they are the same, or -1 if they are different.
|
||||
size_t ValidateAndGetTensorSizes(const TfLiteEvalTensor* tensor1,
|
||||
const TfLiteEvalTensor* tensor2) {
|
||||
TFLITE_DCHECK(tensor1->type == tensor2->type);
|
||||
size_t tensor1_size = 0;
|
||||
size_t tensor2_size = 0;
|
||||
TfLiteEvalTensorByteLength(tensor1, &tensor1_size);
|
||||
TfLiteEvalTensorByteLength(tensor2, &tensor2_size);
|
||||
return (tensor1_size == tensor2_size) ? tensor1_size : -1;
|
||||
}
|
||||
|
||||
TfLiteStatus CopyOpInputsToOpOutputs(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, node->inputs->size == node->outputs->size);
|
||||
for (int i = 0; i < node->inputs->size; i++) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, i);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, i);
|
||||
int bytes = ValidateAndGetTensorSizes(input, output);
|
||||
TF_LITE_ENSURE(context, bytes >= 0);
|
||||
memcpy(output->data.raw, input->data.raw, bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus CopyOpInputsToSubgraphInputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
int subgraph_idx,
|
||||
int first_tensor_idx) {
|
||||
TF_LITE_ENSURE(context,
|
||||
static_cast<size_t>(node->inputs->size - first_tensor_idx) ==
|
||||
graph_info->NumSubgraphInputs(subgraph_idx));
|
||||
for (int i = 0; i < node->inputs->size - first_tensor_idx; i++) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, i + first_tensor_idx);
|
||||
TfLiteEvalTensor* subgraph_input =
|
||||
graph_info->GetSubgraphInput(subgraph_idx, i);
|
||||
int bytes = ValidateAndGetTensorSizes(input, subgraph_input);
|
||||
TF_LITE_ENSURE(context, bytes >= 0);
|
||||
memcpy(subgraph_input->data.raw, input->data.raw, bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus CopyOpOutputsToSubgraphInputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
int subgraph_idx) {
|
||||
TF_LITE_ENSURE(context, static_cast<size_t>(node->outputs->size) ==
|
||||
graph_info->NumSubgraphInputs(subgraph_idx));
|
||||
for (int i = 0; i < node->outputs->size; i++) {
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, i);
|
||||
TfLiteEvalTensor* subgraph_input =
|
||||
graph_info->GetSubgraphInput(subgraph_idx, i);
|
||||
int bytes = ValidateAndGetTensorSizes(output, subgraph_input);
|
||||
TF_LITE_ENSURE(context, bytes >= 0);
|
||||
memcpy(subgraph_input->data.raw, output->data.raw, bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus CopySubgraphOutputsToOpOutputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
int subgraph_idx) {
|
||||
TF_LITE_ENSURE(context, static_cast<size_t>(node->outputs->size) ==
|
||||
graph_info->NumSubgraphOutputs(subgraph_idx));
|
||||
for (int i = 0; i < node->outputs->size; i++) {
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, i);
|
||||
TfLiteEvalTensor* subgraph_output =
|
||||
graph_info->GetSubgraphOutput(subgraph_idx, i);
|
||||
int bytes = ValidateAndGetTensorSizes(output, subgraph_output);
|
||||
TF_LITE_ENSURE(context, bytes >= 0);
|
||||
memcpy(output->data.raw, subgraph_output->data.raw, bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
||||
@@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace micro {
|
||||
@@ -69,23 +70,33 @@ TfLiteStatus CreateWritableTensorDimsWithCopy(TfLiteContext* context,
|
||||
TfLiteTensor* tensor,
|
||||
TfLiteEvalTensor* eval_tensor);
|
||||
|
||||
// Returns a blob of payload data. The payload is subjected to interpretation by
|
||||
// the OP. This is the recommended API for an OP to get an external context. OP
|
||||
// should use this instead of directly calling GetExternalContext function in
|
||||
// context. Example usage:
|
||||
//
|
||||
// An application can set an external context through interpreter as below
|
||||
// interpreter->SetMicroExternalContext(pointer_to_your_payload);
|
||||
//
|
||||
// Inside an OP that needs this payload, it get the payload pointer by:
|
||||
// Prepare(TfliteContext * context) {
|
||||
// ...
|
||||
// payload_ptr =
|
||||
// reinterpret_cast<your_data_type>(GetMicroExternalContext(context))
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
void* GetMicroExternalContext(TfLiteContext* context);
|
||||
// Copy all op input tensors to op output tensors. Requires all op input tensor
|
||||
// shapes and types to be identical to op output tensor shapes and types.
|
||||
TfLiteStatus CopyOpInputsToOpOutputs(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
// Copy all op input tensors to subgraph input tensors. Requires all op input
|
||||
// tensor shapes and types to be identical to subgraph input tensor shapes and
|
||||
// types.
|
||||
TfLiteStatus CopyOpInputsToSubgraphInputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
int subgraph_idx,
|
||||
int first_tensor_idx);
|
||||
|
||||
// Copy all op output tensors to subgraph input tensors. Requires all op output
|
||||
// tensor shapes and types to be identical to subgraph input tensor shapes and
|
||||
// types.
|
||||
TfLiteStatus CopyOpOutputsToSubgraphInputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
int subgraph_idx);
|
||||
|
||||
// Copy all subgraph output tensors to op outputs. Requires all subgraph output
|
||||
// tensor shapes and types to be identical to op output tensor shapes and types.
|
||||
TfLiteStatus CopySubgraphOutputsToOpOutputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
MicroGraph* graph_info,
|
||||
int subgraph_idx);
|
||||
|
||||
} // namespace micro
|
||||
} // namespace tflite
|
||||
|
||||
@@ -36,15 +36,18 @@ constexpr int kTensorShapeRank = 4;
|
||||
enum { kBatchRank = 0, kHeightRank, kWidthRank, kChannelRank };
|
||||
|
||||
TfLiteStatus L2Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
auto* params = static_cast<TfLitePoolParams*>(node->builtin_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), kTensorShapeRank);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), kTensorShapeRank);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
@@ -82,6 +85,9 @@ TfLiteStatus L2Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
output->dims->data[kWidthRank] = out_width;
|
||||
output->dims->data[kChannelRank] = channels_out;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -49,11 +49,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
@@ -69,6 +72,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Our implementations don't currently support activations.
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -30,13 +30,16 @@ const int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus CalculateOpDataLeakyRelu(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
|
||||
@@ -62,6 +65,9 @@ TfLiteStatus CalculateOpDataLeakyRelu(TfLiteContext* context,
|
||||
data->output_shift_identity = static_cast<int32_t>(output_shift_identity);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -43,13 +43,16 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
TF_LITE_ENSURE(context, HaveSameShapes(input, output));
|
||||
@@ -89,6 +92,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->depth = static_cast<size_t>(input_shape.Dims(trailing_dim));
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,9 +32,13 @@ const int kLogisticOutputTensor = 0;
|
||||
TfLiteStatus CalculateArithmeticOpDataLogistic(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
OpDataLogistic* data) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kLogisticInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kLogisticInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kLogisticOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kLogisticOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
@@ -55,6 +59,53 @@ TfLiteStatus CalculateArithmeticOpDataLogistic(TfLiteContext* context,
|
||||
data->input_range_radius =
|
||||
CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteInt16) {
|
||||
static constexpr int kInputIntegerBits = 3;
|
||||
static constexpr int kOutputFractionalBits = 15;
|
||||
|
||||
// See comments in TanhPrepare about requiring zero_point==0
|
||||
// and a power-of-two ("POT") scale.
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
|
||||
int input_scale_log2_rounded;
|
||||
bool param_scale_pot =
|
||||
CheckedLog2(input->params.scale, &input_scale_log2_rounded);
|
||||
|
||||
data->input_left_shift =
|
||||
(15 - kInputIntegerBits) + input_scale_log2_rounded;
|
||||
param_scale_pot &= (data->input_left_shift == 0);
|
||||
|
||||
if (param_scale_pot) {
|
||||
data->input_multiplier = 0;
|
||||
} else {
|
||||
// Calculate multiplier to change input scale to 1/(3*4096)
|
||||
// as required by the table lookup.
|
||||
// In this scaling +/-2^17 represents +/-10.7
|
||||
double multiplier =
|
||||
static_cast<double>(input->params.scale) * 4096.0 * 3.0;
|
||||
|
||||
data->input_left_shift = 0;
|
||||
|
||||
while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
|
||||
data->input_left_shift++;
|
||||
multiplier = multiplier * 2.0;
|
||||
}
|
||||
|
||||
data->input_multiplier = static_cast<int32_t>(multiplier);
|
||||
}
|
||||
|
||||
int output_scale_log2_rounded;
|
||||
TF_LITE_ENSURE(
|
||||
context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
|
||||
TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
|
||||
-kOutputFractionalBits);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ TfLiteRegistration Register_ADD_N();
|
||||
TfLiteRegistration Register_ASSIGN_VARIABLE();
|
||||
TfLiteRegistration Register_AVERAGE_POOL_2D();
|
||||
TfLiteRegistration Register_BATCH_TO_SPACE_ND();
|
||||
TfLiteRegistration Register_BROADCAST_ARGS();
|
||||
TfLiteRegistration Register_BROADCAST_TO();
|
||||
TfLiteRegistration Register_CALL_ONCE();
|
||||
TfLiteRegistration Register_CAST();
|
||||
// TODO(b/160234179): Change custom OPs to also return by value.
|
||||
@@ -62,6 +64,7 @@ TfLiteRegistration Register_LOGICAL_AND();
|
||||
TfLiteRegistration Register_LOGICAL_OR();
|
||||
TfLiteRegistration Register_LOGISTIC();
|
||||
TfLiteRegistration Register_MAX_POOL_2D();
|
||||
TfLiteRegistration Register_MIRROR_PAD();
|
||||
TfLiteRegistration Register_PRELU();
|
||||
TfLiteRegistration Register_MUL();
|
||||
TfLiteRegistration Register_QUANTIZE();
|
||||
@@ -79,6 +82,7 @@ TfLiteRegistration Register_SVDF();
|
||||
TfLiteRegistration Register_TRANSPOSE();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV();
|
||||
TfLiteRegistration Register_VAR_HANDLE();
|
||||
TfLiteRegistration Register_WHILE();
|
||||
TfLiteRegistration Register_ZEROS_LIKE();
|
||||
|
||||
namespace ops {
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct OpDataMirrorPad {
|
||||
int input_dims;
|
||||
int output_size;
|
||||
int offset;
|
||||
int output_dims_num_elements_buffer_index;
|
||||
int input_dims_num_elements_buffer_index;
|
||||
};
|
||||
|
||||
// Helper method that fills the left and right pads.
|
||||
template <typename T>
|
||||
inline void GetPadding(const T* data, int offset, int64_t* left_pad,
|
||||
int64_t* right_pad) {
|
||||
*left_pad = static_cast<int64_t>(*(data + offset * 2));
|
||||
*right_pad = static_cast<int64_t>(*(data + offset * 2 + 1));
|
||||
}
|
||||
|
||||
// Given dimension index and the left/right padding.
|
||||
// Returns the corresponding dimension in the input array.
|
||||
inline int GetInputDimension(int padded_dimension, int left_pad, int right_pad,
|
||||
int input_dim_size, int offset) {
|
||||
if (padded_dimension < left_pad) {
|
||||
const int original_ind = left_pad + offset - 1;
|
||||
return original_ind - (std::min(padded_dimension, original_ind - offset));
|
||||
}
|
||||
padded_dimension -= left_pad;
|
||||
if (padded_dimension >= input_dim_size) {
|
||||
padded_dimension -= input_dim_size;
|
||||
const int original_ind = input_dim_size - (1 + offset);
|
||||
return original_ind - std::min(padded_dimension, original_ind);
|
||||
}
|
||||
return padded_dimension;
|
||||
}
|
||||
|
||||
// Given and index in output array, returns the index of the value
|
||||
// in input array.
|
||||
int GetFlatIndex(int index, int num_dims,
|
||||
const TfLiteEvalTensor* padding_matrix,
|
||||
const TfLiteIntArray* input_dims,
|
||||
int* output_dims_num_elements, int* input_dims_num_elements,
|
||||
const int offset) {
|
||||
int flat_index = 0;
|
||||
int64_t left_pad = 0, right_pad = 0, dimension_index, index_in_input;
|
||||
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
switch (padding_matrix->type) {
|
||||
case kTfLiteInt32:
|
||||
GetPadding(padding_matrix->data.i32, i, &left_pad, &right_pad);
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
GetPadding(padding_matrix->data.i64, i, &left_pad, &right_pad);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
dimension_index = index / output_dims_num_elements[i];
|
||||
|
||||
index_in_input = GetInputDimension(dimension_index, left_pad, right_pad,
|
||||
input_dims->data[i], offset);
|
||||
|
||||
flat_index += index_in_input * (input_dims_num_elements)[i];
|
||||
index %= output_dims_num_elements[i];
|
||||
}
|
||||
|
||||
return flat_index;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MirrorPad(const TfLiteEvalTensor* padding_matrix,
|
||||
const TfLiteIntArray* input_dims, int* output_dims_num_elements,
|
||||
int* input_dims_num_elements, const T* input_data,
|
||||
T* output_data, const int offset, const int num_dims,
|
||||
const int output_size) {
|
||||
for (int i = 0; i < output_size; ++i) {
|
||||
output_data[i] = input_data[GetFlatIndex(
|
||||
i, num_dims, padding_matrix, input_dims, output_dims_num_elements,
|
||||
input_dims_num_elements, offset)];
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
const OpDataMirrorPad* data =
|
||||
static_cast<const OpDataMirrorPad*>(node->user_data);
|
||||
|
||||
const TfLiteEvalTensor* input_tensor =
|
||||
tflite::micro::GetEvalInput(context, node, 0);
|
||||
const TfLiteEvalTensor* padding_matrix =
|
||||
tflite::micro::GetEvalInput(context, node, 1);
|
||||
|
||||
TfLiteEvalTensor* output_tensor =
|
||||
tflite::micro::GetEvalOutput(context, node, 0);
|
||||
const int input_dims = data->input_dims;
|
||||
const int output_size = data->output_size;
|
||||
|
||||
int* input_dims_num_elements = (int*)context->GetScratchBuffer(
|
||||
context, data->input_dims_num_elements_buffer_index);
|
||||
int* output_dims_num_elements = (int*)context->GetScratchBuffer(
|
||||
context, data->output_dims_num_elements_buffer_index);
|
||||
|
||||
for (int i = 0; i < input_dims; i++) {
|
||||
output_dims_num_elements[i] = 1;
|
||||
input_dims_num_elements[i] = 1;
|
||||
}
|
||||
|
||||
for (int i = input_dims - 2; i >= 0; i--) {
|
||||
output_dims_num_elements[i] =
|
||||
output_dims_num_elements[i + 1] * output_tensor->dims->data[i + 1];
|
||||
|
||||
input_dims_num_elements[i] =
|
||||
input_dims_num_elements[i + 1] * input_tensor->dims->data[i + 1];
|
||||
}
|
||||
|
||||
switch (output_tensor->type) {
|
||||
case kTfLiteFloat32: {
|
||||
MirrorPad(padding_matrix, input_tensor->dims, output_dims_num_elements,
|
||||
input_dims_num_elements,
|
||||
tflite::micro::GetTensorData<float>(input_tensor),
|
||||
tflite::micro::GetTensorData<float>(output_tensor),
|
||||
data->offset, input_dims, output_size);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
MirrorPad(padding_matrix, input_tensor->dims, output_dims_num_elements,
|
||||
input_dims_num_elements,
|
||||
tflite::micro::GetTensorData<int8_t>(input_tensor),
|
||||
tflite::micro::GetTensorData<int8_t>(output_tensor),
|
||||
data->offset, input_dims, output_size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
status = kTfLiteError;
|
||||
break;
|
||||
}
|
||||
|
||||
#undef TF_LITE_MIRROR_PAD
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpDataMirrorPad));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpDataMirrorPad* data = static_cast<OpDataMirrorPad*>(node->user_data);
|
||||
|
||||
TfLiteTensor* input_tensor = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TfLiteTensor* padding_matrix =
|
||||
micro_context->AllocateTempInputTensor(node, 1);
|
||||
TfLiteTensor* output_tensor =
|
||||
micro_context->AllocateTempOutputTensor(node, 0);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(padding_matrix), 2);
|
||||
TF_LITE_ENSURE_EQ(context, SizeOfDimension(padding_matrix, 0),
|
||||
NumDimensions(input_tensor));
|
||||
auto* params =
|
||||
reinterpret_cast<TfLiteMirrorPaddingParams*>(node->builtin_data);
|
||||
if (params == nullptr) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
data->offset =
|
||||
params->mode != TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect ? 0
|
||||
: 1;
|
||||
data->input_dims = NumDimensions(input_tensor);
|
||||
data->output_size = NumElements(output_tensor);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, data->input_dims * sizeof(int),
|
||||
&data->output_dims_num_elements_buffer_index));
|
||||
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
|
||||
context, data->input_dims * sizeof(int),
|
||||
&data->input_dims_num_elements_buffer_index));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_tensor);
|
||||
micro_context->DeallocateTempTfLiteTensor(padding_matrix);
|
||||
micro_context->DeallocateTempTfLiteTensor(output_tensor);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_MIRROR_PAD() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -37,11 +37,16 @@ void* MulInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
|
||||
TfLiteStatus CalculateOpDataMul(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteMulParams* params, OpDataMul* data) {
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kMulInput1Tensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kMulInput1Tensor);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kMulInput2Tensor);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kMulInput2Tensor);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kMulOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kMulOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
@@ -72,6 +77,9 @@ TfLiteStatus CalculateOpDataMul(TfLiteContext* context, TfLiteNode* node,
|
||||
&data->output_activation_max_f32);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -43,19 +43,26 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, /*index=*/0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* paddings = GetInput(context, node, /*index=*/1);
|
||||
TfLiteTensor* paddings =
|
||||
micro_context->AllocateTempInputTensor(node, /*index=*/1);
|
||||
TF_LITE_ENSURE(context, paddings != nullptr);
|
||||
const TfLiteTensor* constant_values =
|
||||
NumInputs(node) == 3 ? GetInput(context, node, /*index=*/2) : nullptr;
|
||||
TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
|
||||
TfLiteTensor* constant_values =
|
||||
NumInputs(node) == 3
|
||||
? micro_context->AllocateTempInputTensor(node, /*index=*/2)
|
||||
: nullptr;
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, /*index=*/0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
@@ -122,6 +129,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(paddings);
|
||||
if (constant_values != nullptr) {
|
||||
micro_context->DeallocateTempTfLiteTensor(constant_values);
|
||||
}
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -54,9 +54,13 @@ TfLiteStatus PoolingPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpDataPooling* data = static_cast<OpDataPooling*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kPoolingInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kPoolingInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kPoolingOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kPoolingOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
@@ -71,6 +75,9 @@ TfLiteStatus PoolingPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
&data->activation_max);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -84,14 +84,22 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
PreluParams* params = static_cast<PreluParams*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* alpha = GetInput(context, node, 1);
|
||||
TfLiteTensor* alpha = micro_context->AllocateTempInputTensor(node, 1);
|
||||
TF_LITE_ENSURE(context, alpha != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
return CalculatePreluParams(input, alpha, output, params);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
CalculatePreluParams(input, alpha, output, params));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(alpha);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -36,9 +36,11 @@ TfLiteStatus PrepareQuantizeReference(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
|
||||
@@ -77,6 +79,9 @@ TfLiteStatus PrepareQuantizeReference(TfLiteContext* context,
|
||||
data->quantization_params.scale = static_cast<double>(output->params.scale);
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,13 +39,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(NumInputs(node) == 1);
|
||||
TFLITE_DCHECK(NumOutputs(node) == 1);
|
||||
|
||||
const TfLiteTensor* input_resource_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input_resource_id_tensor =
|
||||
micro_context->AllocateTempInputTensor(node, kInputVariableId);
|
||||
|
||||
TFLITE_DCHECK(input_resource_id_tensor != nullptr);
|
||||
TFLITE_DCHECK(input_resource_id_tensor->type == kTfLiteResource);
|
||||
TFLITE_DCHECK(NumElements(input_resource_id_tensor) == 1);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_resource_id_tensor);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -58,14 +62,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputValue);
|
||||
TFLITE_DCHECK(output_value != nullptr);
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
MicroResourceVariables* resources = graph_info->GetResourceVariables();
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
MicroResourceVariables* resources = graph_info.GetResourceVariables();
|
||||
if (resources == nullptr) {
|
||||
MicroPrintf(
|
||||
"READ_VARIABLE requires resource variables. Please create "
|
||||
|
||||
@@ -50,10 +50,12 @@ void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// Inputs Tensor (dtype depends on quantization):
|
||||
// [0] = Input
|
||||
// [1] = Axis
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
|
||||
// Outputs Tensor (dtype depends on quantization):
|
||||
// [0] = Output
|
||||
@@ -63,28 +65,31 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
// Validate axis type
|
||||
const TfLiteTensor* axis = GetInput(context, node, 1);
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
|
||||
TF_LITE_ENSURE(context, axis != nullptr);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
const TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
const double real_multiplier = static_cast<double>(input->params.scale) /
|
||||
static_cast<double>(output->params.scale);
|
||||
QuantizeMultiplier(real_multiplier, &data->multiplier, &data->shift);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
OpData* op_data = static_cast<OpData*>(node->user_data);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
const TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
const TfLiteTensor* axis = GetInput(context, node, 1);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
|
||||
|
||||
op_data->input_scale = input->params.scale;
|
||||
op_data->output_scale = output->params.scale;
|
||||
@@ -96,13 +101,17 @@ TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
|
||||
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
|
||||
&op_data->resolved_axis_idx);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
const TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
|
||||
const double real_multiplier = static_cast<double>(input->params.scale) /
|
||||
static_cast<double>(output->params.scale);
|
||||
@@ -121,6 +130,8 @@ TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
|
||||
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
@@ -31,9 +33,13 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
// Tensorflow's Reshape allows one of the shape components to have the
|
||||
// special -1 value, meaning it will be calculated automatically based on the
|
||||
@@ -68,6 +74,9 @@ TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -93,9 +102,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Do nothing for in-place reshape.
|
||||
if (input->data.raw != output->data.raw) {
|
||||
// Otherwise perform reshape with copy.
|
||||
for (size_t i = 0; i < input_bytes; ++i) {
|
||||
output->data.raw[i] = input->data.raw[i];
|
||||
}
|
||||
memcpy(output->data.raw, input->data.raw, input_bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -30,12 +30,17 @@ constexpr int kSizeTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TfLiteTensor* size =
|
||||
micro_context->AllocateTempInputTensor(node, kSizeTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
|
||||
@@ -55,6 +60,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(size);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,12 +33,17 @@ constexpr int kSizeTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TfLiteTensor* size =
|
||||
micro_context->AllocateTempInputTensor(node, kSizeTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
// Our current implementations rely on the input being 4D,
|
||||
// and the size being 1D tensor with exactly 2 elements.
|
||||
@@ -53,6 +58,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_KERNEL_LOG(context, "Dynamic tensors are unsupported in tfmicro.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(size);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -29,9 +29,13 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
@@ -42,6 +46,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
for (int i = 0; i < output->dims->size; ++i) {
|
||||
TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,16 +45,22 @@ void GetBeginAndSizeVectors(int dimensions, const TfLiteEvalTensor* begin,
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TFLITE_DCHECK(input != nullptr);
|
||||
const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
|
||||
TfLiteTensor* begin =
|
||||
micro_context->AllocateTempInputTensor(node, kBeginTensor);
|
||||
TFLITE_DCHECK(begin != nullptr);
|
||||
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
|
||||
TfLiteTensor* size =
|
||||
micro_context->AllocateTempInputTensor(node, kSizeTensor);
|
||||
TFLITE_DCHECK(size != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TFLITE_DCHECK(output != nullptr);
|
||||
|
||||
// Ensure validity of input tensor and its dimension.
|
||||
@@ -66,6 +72,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(NumDimensions(size) == 1);
|
||||
TFLITE_DCHECK(NumElements(begin) == NumElements(size));
|
||||
TFLITE_DCHECK(NumDimensions(input) <= kMaxDim);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(begin);
|
||||
micro_context->DeallocateTempTfLiteTensor(size);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
@@ -90,12 +91,14 @@ void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
@@ -136,7 +139,12 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
return CalculateSoftmaxParams(context, input, output, params, op_data);
|
||||
auto ret_val =
|
||||
CalculateSoftmaxParams(context, input, output, params, op_data);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return ret_val;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
@@ -44,11 +44,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
|
||||
@@ -57,6 +61,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,11 +39,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
|
||||
@@ -75,6 +78,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
output->dims->data[kDepthRank] =
|
||||
input->dims->data[kDepthRank] * block_size * block_size;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,8 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* axis = GetInput(context, node, 0);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 0);
|
||||
TF_LITE_ENSURE(context, axis != nullptr);
|
||||
|
||||
// Dynamic output tensors are needed if axis tensor is not constant.
|
||||
@@ -77,6 +78,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// constant axis tensor for now.
|
||||
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
|
||||
"Non constant axis tensor not supported");
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -74,13 +74,14 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
// Dynamic output tensors are needed if axis tensor is not constant.
|
||||
// But Micro doesn't support dynamic memory allocation, so we only support
|
||||
// constant axis tensor for now.
|
||||
const TfLiteTensor* axis = GetInput(context, node, 2);
|
||||
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 2);
|
||||
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
|
||||
"Non constant axis tensor not supported");
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(axis);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -27,12 +27,19 @@ namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct SqueezeContext {
|
||||
SqueezeContext(TfLiteContext* context, TfLiteNode* node)
|
||||
: params(reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data)),
|
||||
input(GetInput(context, node, 0)),
|
||||
output(GetOutput(context, node, 0)) {}
|
||||
SqueezeContext(TfLiteContext* context, TfLiteNode* node) {
|
||||
params = reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data);
|
||||
micro_context = GetMicroContext(context);
|
||||
input = micro_context->AllocateTempInputTensor(node, 0);
|
||||
output = micro_context->AllocateTempOutputTensor(node, 0);
|
||||
}
|
||||
~SqueezeContext() {
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
}
|
||||
MicroContext* micro_context;
|
||||
TfLiteSqueezeParams* params;
|
||||
const TfLiteTensor* const input;
|
||||
TfLiteTensor* input;
|
||||
TfLiteTensor* output;
|
||||
};
|
||||
|
||||
@@ -80,18 +87,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
SqueezeContext op_context(context, node);
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
|
||||
if (op_context.input->type == kTfLiteString) {
|
||||
if (input->type == kTfLiteString) {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(op_context.input->type),
|
||||
op_context.input->type);
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes);
|
||||
memcpy(op_context.output->data.raw, op_context.input->data.raw,
|
||||
op_context.input->bytes);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
size_t input_byte_size;
|
||||
size_t output_byte_size;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
TfLiteEvalTensorByteLength(input, &input_byte_size));
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
TfLiteEvalTensorByteLength(output, &output_byte_size));
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input_byte_size, output_byte_size);
|
||||
memcpy(output->data.raw, input->data.raw, input_byte_size);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -38,18 +38,27 @@ constexpr int kOutputTensor = 0;
|
||||
struct StridedSliceContext {
|
||||
StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
|
||||
params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
|
||||
input = GetInput(context, node, kInputTensor);
|
||||
begin = GetInput(context, node, kBeginTensor);
|
||||
end = GetInput(context, node, kEndTensor);
|
||||
strides = GetInput(context, node, kStridesTensor);
|
||||
output = GetOutput(context, node, kOutputTensor);
|
||||
micro_context = GetMicroContext(context);
|
||||
input = micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
begin = micro_context->AllocateTempInputTensor(node, kBeginTensor);
|
||||
end = micro_context->AllocateTempInputTensor(node, kEndTensor);
|
||||
strides = micro_context->AllocateTempInputTensor(node, kStridesTensor);
|
||||
output = micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
dims = NumDimensions(input);
|
||||
}
|
||||
~StridedSliceContext() {
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(begin);
|
||||
micro_context->DeallocateTempTfLiteTensor(end);
|
||||
micro_context->DeallocateTempTfLiteTensor(strides);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
}
|
||||
const TfLiteStridedSliceParams* params;
|
||||
const TfLiteTensor* input;
|
||||
const TfLiteTensor* begin;
|
||||
const TfLiteTensor* end;
|
||||
const TfLiteTensor* strides;
|
||||
MicroContext* micro_context;
|
||||
TfLiteTensor* input;
|
||||
TfLiteTensor* begin;
|
||||
TfLiteTensor* end;
|
||||
TfLiteTensor* strides;
|
||||
TfLiteTensor* output;
|
||||
int dims;
|
||||
};
|
||||
|
||||
@@ -83,15 +83,24 @@ TfLiteStatus SubPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpDataSub* data = static_cast<OpDataSub*>(node->user_data);
|
||||
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kSubInputTensor1);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input1 =
|
||||
micro_context->AllocateTempInputTensor(node, kSubInputTensor1);
|
||||
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kSubInputTensor2);
|
||||
TfLiteTensor* input2 =
|
||||
micro_context->AllocateTempInputTensor(node, kSubInputTensor2);
|
||||
TF_LITE_ENSURE(context, input2 != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kSubOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kSubOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
CalculateOpDataSub(context, params, input1, input2, output, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input1);
|
||||
micro_context->DeallocateTempTfLiteTensor(input2);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -364,6 +364,8 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
// Validate Tensor Inputs (dtype depends on quantization):
|
||||
// [0] = Input, {2, batch_size, input_size}
|
||||
// [1] = Weights Feature, {2, num_filters, input_size}
|
||||
@@ -371,18 +373,19 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
|
||||
// [3] = Bias (optional), {1, num_units}
|
||||
// [4] = Activation State (variable),
|
||||
// {2, batch_size, memory_size * num_filters}
|
||||
const TfLiteTensor* input = GetInput(context, node, kSvdfInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kSvdfInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* weights_feature =
|
||||
GetInput(context, node, kSvdfWeightsFeatureTensor);
|
||||
TfLiteTensor* weights_feature =
|
||||
micro_context->AllocateTempInputTensor(node, kSvdfWeightsFeatureTensor);
|
||||
TF_LITE_ENSURE(context, weights_feature != nullptr);
|
||||
const TfLiteTensor* weights_time =
|
||||
GetInput(context, node, kSvdfWeightsTimeTensor);
|
||||
TfLiteTensor* weights_time =
|
||||
micro_context->AllocateTempInputTensor(node, kSvdfWeightsTimeTensor);
|
||||
TF_LITE_ENSURE(context, weights_time != nullptr);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kSvdfBiasTensor);
|
||||
const TfLiteTensor* activation_state =
|
||||
GetInput(context, node, kSvdfInputActivationStateTensor);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kSvdfBiasTensor);
|
||||
TfLiteTensor* activation_state = micro_context->AllocateTempInputTensor(
|
||||
node, kSvdfInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
|
||||
// Define input constants based on input tensor definition above:
|
||||
@@ -402,7 +405,8 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Validate Tensor Output:
|
||||
// [0] = float/int8_t, {2, batch_size, num_units}
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
TfLiteTensor* output = GetOutput(context, node, kSvdfOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kSvdfOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
|
||||
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
|
||||
@@ -498,6 +502,12 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, scratch_status);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(weights_feature);
|
||||
micro_context->DeallocateTempTfLiteTensor(weights_time);
|
||||
micro_context->DeallocateTempTfLiteTensor(activation_state);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -48,11 +48,14 @@ void* TanhInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
|
||||
TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
OpData* data) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
@@ -69,6 +72,62 @@ TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
data->input_range_radius =
|
||||
CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteInt16) {
|
||||
static constexpr int kInputIntegerBits = 3;
|
||||
static constexpr int kOutputFractionalBits = 15;
|
||||
|
||||
// These operators are implemented in fixed-point arithmetic,
|
||||
// which intrinsically wants symmetric ranges (zero_point==0)
|
||||
// and power-of-two scales (power-of-two is abbreviated below as POT).
|
||||
// While more general support would be possible by means of rescaling,
|
||||
// that would add some overhead and some loss of accuracy and wouldn't
|
||||
// be used at the moment as current quantized LSTM applications are
|
||||
// happy with symmetric, power-of-two-scales quantization. So we just
|
||||
// implement that narrow case only for now.
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
|
||||
int input_scale_log2_rounded;
|
||||
bool param_scale_pot =
|
||||
CheckedLog2(input->params.scale, &input_scale_log2_rounded);
|
||||
|
||||
data->input_left_shift =
|
||||
(15 - kInputIntegerBits) + input_scale_log2_rounded;
|
||||
param_scale_pot &=
|
||||
(data->input_left_shift == 0 || data->input_left_shift == 1);
|
||||
|
||||
if (param_scale_pot) {
|
||||
data->input_multiplier = 0;
|
||||
} else {
|
||||
// Calculate multiplier to change input scale to 1/(3*4096)
|
||||
// as required by the table lookup.
|
||||
// The number 3.0 in the multiplier comes from here,
|
||||
// because the interval is [-10.7, 10.7] instead of [-8, 8].
|
||||
// So, in this scaling +/-2^17 represents +/-10.7.
|
||||
|
||||
double multiplier =
|
||||
static_cast<double>(input->params.scale) * 4096.0 * 3.0;
|
||||
data->input_left_shift = 0;
|
||||
|
||||
while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
|
||||
data->input_left_shift++;
|
||||
multiplier = multiplier * 2.0;
|
||||
}
|
||||
|
||||
data->input_multiplier = static_cast<int32_t>(multiplier);
|
||||
}
|
||||
|
||||
int output_scale_log2_rounded;
|
||||
TF_LITE_ENSURE(
|
||||
context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
|
||||
TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
|
||||
-kOutputFractionalBits);
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@@ -77,10 +136,15 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
return CalculateArithmeticOpData(context, node, data);
|
||||
TF_LITE_ENSURE_OK(context, CalculateArithmeticOpData(context, node, data));
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -18,18 +18,30 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kPermTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
struct TransposeContext {
|
||||
TransposeContext(TfLiteContext* context, TfLiteNode* node) {
|
||||
input = GetInput(context, node, 0);
|
||||
perm = GetInput(context, node, 1);
|
||||
output = GetOutput(context, node, 0);
|
||||
micro_context = GetMicroContext(context);
|
||||
input = micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
perm = micro_context->AllocateTempInputTensor(node, kPermTensor);
|
||||
output = micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
}
|
||||
const TfLiteTensor* input;
|
||||
const TfLiteTensor* perm;
|
||||
~TransposeContext() {
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(perm);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
}
|
||||
MicroContext* micro_context;
|
||||
TfLiteTensor* input;
|
||||
TfLiteTensor* perm;
|
||||
TfLiteTensor* output;
|
||||
};
|
||||
|
||||
@@ -60,10 +72,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TransposeContext op_context(context, node);
|
||||
|
||||
const int32_t* perm_data = GetTensorData<int32_t>(op_context.perm);
|
||||
const int size = op_context.perm->dims->data[0];
|
||||
const TfLiteEvalTensor* perm_tensor =
|
||||
tflite::micro::GetEvalInput(context, node, kPermTensor);
|
||||
const int32_t* perm_data = perm_tensor->data.i32;
|
||||
const int size = perm_tensor->dims->data[0];
|
||||
TransposeParams params;
|
||||
params.perm_count = size;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
@@ -73,24 +85,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Transpose kernel only does rearranging values not numeric evaluations
|
||||
// on each cell. It's safe to implement per size of scalar type and this
|
||||
// trick keeps the total code size in a reasonable range.
|
||||
switch (op_context.input->type) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
reference_ops::Transpose(params, GetTensorShape(op_context.input),
|
||||
GetTensorData<float>(op_context.input),
|
||||
GetTensorShape(op_context.output),
|
||||
GetTensorData<float>(op_context.output));
|
||||
reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
reference_ops::Transpose(params, GetTensorShape(op_context.input),
|
||||
GetTensorData<int8_t>(op_context.input),
|
||||
GetTensorShape(op_context.output),
|
||||
GetTensorData<int8_t>(op_context.output));
|
||||
reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s is currently not supported by Transpose. "
|
||||
"Only float32 and int8 is supported",
|
||||
TfLiteTypeGetName(op_context.input->type));
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -94,13 +94,18 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
// Note that quantized inference requires that all tensors have their
|
||||
// parameters set. This is usually done during quantized training.
|
||||
if (data_type != kTfLiteFloat32) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kFilterTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* bias =
|
||||
micro_context->AllocateTempInputTensor(node, kBiasTensor);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
int output_channels = filter->dims->data[kConvQuantizedDimension];
|
||||
|
||||
@@ -124,6 +129,13 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
&(data->bias_converted_buffer_index)) == kTfLiteOk);
|
||||
}
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
if (bias != nullptr) {
|
||||
micro_context->DeallocateTempTfLiteTensor(bias);
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@@ -141,11 +153,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto params =
|
||||
static_cast<const TfLiteTransposeConvParams*>(node->builtin_data);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
|
||||
TfLiteTensor* filter =
|
||||
micro_context->AllocateTempInputTensor(node, kFilterTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
|
||||
// Get height and width of the output.
|
||||
@@ -212,6 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Stride
|
||||
data->params.stride_width = params->stride_width;
|
||||
data->params.stride_height = params->stride_height;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(filter);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@@ -46,14 +46,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto* params =
|
||||
reinterpret_cast<const TfLiteVarHandleParams*>(node->builtin_data);
|
||||
|
||||
// Casting to TfliteIntArray is required since we are re-using
|
||||
// GetExecutionPlan from TfLiteContext. On TFLM this method returns a
|
||||
// MicroGraph.
|
||||
// TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
|
||||
MicroGraph* graph_info;
|
||||
context->GetExecutionPlan(context,
|
||||
reinterpret_cast<TfLiteIntArray**>(&graph_info));
|
||||
MicroResourceVariables* resources = graph_info->GetResourceVariables();
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
MicroResourceVariables* resources = graph_info.GetResourceVariables();
|
||||
if (resources == nullptr) {
|
||||
MicroPrintf(
|
||||
"VAR_HANDLE requires resource variables. Please create "
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_context.h"
|
||||
#include "tensorflow/lite/micro/micro_graph.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
|
||||
struct OpData {
|
||||
int cond_subgraph_index;
|
||||
int body_subgraph_index;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
const auto* params =
|
||||
reinterpret_cast<const TfLiteWhileParams*>(node->builtin_data);
|
||||
|
||||
op_data->cond_subgraph_index = params->cond_subgraph_index;
|
||||
op_data->body_subgraph_index = params->body_subgraph_index;
|
||||
|
||||
// The first input is the condition.
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
|
||||
size_t num_inputs = node->inputs->size;
|
||||
size_t num_outputs = node->outputs->size;
|
||||
|
||||
MicroGraph& graph_info = micro_context->graph();
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
op_data->cond_subgraph_index < graph_info.NumSubgraphs());
|
||||
TF_LITE_ENSURE(context,
|
||||
op_data->body_subgraph_index < graph_info.NumSubgraphs());
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, num_inputs,
|
||||
graph_info.NumSubgraphInputs(op_data->cond_subgraph_index));
|
||||
TF_LITE_ENSURE_EQ(context, num_inputs,
|
||||
graph_info.NumSubgraphInputs(op_data->body_subgraph_index));
|
||||
TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs);
|
||||
TF_LITE_ENSURE_EQ(
|
||||
context, num_outputs,
|
||||
graph_info.NumSubgraphOutputs(op_data->body_subgraph_index));
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
|
||||
MicroGraph* graph_info = µ_context->graph();
|
||||
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
tflite::micro::CopyOpInputsToSubgraphInputs(
|
||||
context, node, graph_info, op_data->cond_subgraph_index,
|
||||
/*first_tensor_idx=*/0));
|
||||
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
graph_info->InvokeSubgraph(op_data->cond_subgraph_index));
|
||||
|
||||
TfLiteEvalTensor* cond_subgraph_output = graph_info->GetSubgraphOutput(
|
||||
op_data->cond_subgraph_index, /*tensor_idx=*/0);
|
||||
bool cond_value = cond_subgraph_output->data.b[0];
|
||||
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
tflite::micro::CopyOpInputsToSubgraphInputs(
|
||||
context, node, graph_info, op_data->body_subgraph_index,
|
||||
/*first_tensor_idx=*/0));
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
tflite::micro::CopyOpInputsToOpOutputs(context, node));
|
||||
|
||||
while (cond_value == true) {
|
||||
// Copy output of this iteration back to the body input.
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, tflite::micro::CopyOpOutputsToSubgraphInputs(
|
||||
context, node, graph_info, op_data->body_subgraph_index));
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
graph_info->InvokeSubgraph(op_data->body_subgraph_index));
|
||||
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, tflite::micro::CopySubgraphOutputsToOpOutputs(
|
||||
context, node, graph_info, op_data->body_subgraph_index));
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, tflite::micro::CopyOpOutputsToSubgraphInputs(
|
||||
context, node, graph_info, op_data->cond_subgraph_index));
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
graph_info->InvokeSubgraph(op_data->cond_subgraph_index));
|
||||
|
||||
cond_subgraph_output = graph_info->GetSubgraphOutput(
|
||||
op_data->cond_subgraph_index, /*tensor_idx=*/0);
|
||||
cond_value = cond_subgraph_output->data.b[0];
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_WHILE() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@@ -25,15 +25,20 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
TfLiteTensor* input =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
output->type = input->type;
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user