commit c9b7a5f84c (parent 338184712d)
Author: jomjol
Date: 2022-08-28 19:52:51 +02:00
223 changed files with 13226 additions and 2342 deletions

View File

@@ -179,6 +179,12 @@ typedef enum {
kTfLiteBuiltinMultinomial = 149,
kTfLiteBuiltinGelu = 150,
kTfLiteBuiltinDynamicUpdateSlice = 151,
kTfLiteBuiltinRelu0To1 = 152,
kTfLiteBuiltinUnsortedSegmentProd = 153,
kTfLiteBuiltinUnsortedSegmentMax = 154,
kTfLiteBuiltinUnsortedSegmentSum = 155,
kTfLiteBuiltinAtan2 = 156,
kTfLiteBuiltinUnsortedSegmentMin = 157,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@@ -113,7 +113,13 @@ typedef struct TfLiteQuantizationParams {
} TfLiteQuantizationParams;
// --------------------------------------------------------------------------
// Opaque types used by c_api_opaque.h.
// Opaque types used by c_api.h, c_api_opaque.h and common.h.
// TfLiteOpaqueContext is an opaque version of TfLiteContext;
typedef struct TfLiteOpaqueContext TfLiteOpaqueContext;
// TfLiteOpaqueNode is an opaque version of TfLiteNode;
typedef struct TfLiteOpaqueNode TfLiteOpaqueNode;
// TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;
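For orientation, a minimal sketch of how client code works with these opaque handles: only pointers cross the API boundary, and fields are reached through accessor functions, so the underlying struct layout can change between runtime versions without breaking the ABI. The accessor names below follow c_api_opaque.h but should be treated as assumptions, since that header is not part of this hunk.

// Sketch only: accessor names assumed from c_api_opaque.h.
void InspectTensor(const TfLiteOpaqueTensor* tensor) {
  TfLiteType type = TfLiteOpaqueTensorType(tensor);  // element type
  int32_t rank = TfLiteOpaqueTensorNumDims(tensor);  // number of dimensions
  void* data = TfLiteOpaqueTensorData(tensor);       // raw data pointer
  (void)type; (void)rank; (void)data;
}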

View File

@@ -14,7 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/c_api_types.h"
#ifdef TF_LITE_TENSORFLOW_PROFILER
#include "tensorflow/lite/tensorflow_profiler_logger.h"
#endif
#ifndef TF_LITE_STATIC_MEMORY
#include <stdlib.h>
@@ -99,7 +103,12 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
void TfLiteTensorDataFree(TfLiteTensor* t) {
if (t->allocation_type == kTfLiteDynamic ||
t->allocation_type == kTfLitePersistentRo) {
free(t->data.raw);
if (t->data.raw) {
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorDealloc(t);
#endif
free(t->data.raw);
}
}
t->data.raw = nullptr;
}
@@ -161,7 +170,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
t->dims = nullptr;
if (t->dims_signature) {
TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature);
}
t->dims_signature = nullptr;
@@ -191,16 +200,12 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
}
TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
if (!src || !dst)
return kTfLiteOk;
if (src->bytes != dst->bytes)
return kTfLiteError;
if (src == dst)
return kTfLiteOk;
if (!src || !dst) return kTfLiteOk;
if (src->bytes != dst->bytes) return kTfLiteError;
if (src == dst) return kTfLiteOk;
dst->type = src->type;
if (dst->dims)
TfLiteIntArrayFree(dst->dims);
if (dst->dims) TfLiteIntArrayFree(dst->dims);
dst->dims = TfLiteIntArrayCopy(src->dims);
memcpy(dst->data.raw, src->data.raw, src->bytes);
dst->buffer_handle = src->buffer_handle;
@@ -218,8 +223,17 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
// TODO(b/145340303): Tensor data should be aligned.
if (!tensor->data.raw) {
tensor->data.raw = (char*)malloc(num_bytes);
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
#endif
} else if (num_bytes > tensor->bytes) {
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorDealloc(tensor);
#endif
tensor->data.raw = (char*)realloc(tensor->data.raw, num_bytes);
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
#endif
}
tensor->bytes = num_bytes;
}
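The hooks above keep the profiler's running byte count balanced: every transition of t->data.raw is bracketed by a dealloc/alloc notification, and the calls compile away unless the build defines TF_LITE_TENSORFLOW_PROFILER. Condensed into one hypothetical helper (ReallocTraced is not part of the commit), the idiom looks like this:

static void ReallocTraced(TfLiteTensor* t, size_t n) {
#ifdef TF_LITE_TENSORFLOW_PROFILER
  tflite::OnTfLiteTensorDealloc(t);  // the old block is about to be replaced
#endif
  t->data.raw = (char*)realloc(t->data.raw, n);
#ifdef TF_LITE_TENSORFLOW_PROFILER
  tflite::OnTfLiteTensorAlloc(t, n);  // report the new block and its size
#endif
}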

View File

@@ -173,9 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
} \
} while (false)
#else // TF_LITE_STRIP_ERROR_STRINGS
#define UNUSED(...) (void)sizeof(#__VA_ARGS__)
#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
#define ARGS_UNUSED(...) (void)sizeof(#__VA_ARGS__)
#define TF_LITE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
#endif // TF_LITE_STRIP_ERROR_STRINGS
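The stringizing trick is what makes the stripped variants free: #__VA_ARGS__ pastes the arguments into a string literal, so they are never evaluated (any side effects in log arguments are compiled out), and sizeof of that literal is a compile-time constant that the (void) cast discards. The rename from UNUSED to ARGS_UNUSED likely just avoids collisions with identically named macros elsewhere. A sketch:

void LogStripped(int code) {
  // Expands to (void)sizeof("code"); nothing is evaluated at runtime.
  ARGS_UNUSED(code);
}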
// Check whether value is true, and if not return kTfLiteError from
@@ -842,6 +842,12 @@ typedef struct TfLiteContext {
size_t* bytes);
} TfLiteContext;
// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
// for C API which doesn't use internal types (such as `TfLiteContext`) but only
// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each
// field is exactly the same as with `TfLiteRegistration`.
typedef struct TfLiteRegistrationExternal TfLiteRegistrationExternal;
typedef struct TfLiteRegistration {
// Initializes the op from serialized data.
// Called only *once* for the lifetime of the op, so any one-time allocations
@@ -903,8 +909,31 @@ typedef struct TfLiteRegistration {
// Note: It is the responsibility of the registration binder to set this
// properly.
int version;
// The external version of `TfLiteRegistration`. Since internal types (such as
// `TfLiteContext`) can't be used in the C API without breaking ABI stability,
// C API users provide `TfLiteRegistrationExternal` to implement custom
// ops. We keep it inside of `TfLiteRegistration` and use it to route
// callbacks properly.
TfLiteRegistrationExternal* registration_external;
} TfLiteRegistration;
// Old version of `TfLiteRegistration` to maintain binary backward
// compatibility.
// WARNING: This structure is deprecated / not an official part of the API.
// It should be only used for binary backward compatibility.
typedef struct TfLiteRegistration_V1 {
void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
void (*free)(TfLiteContext* context, void* buffer);
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
const char* (*profiling_string)(const TfLiteContext* context,
const TfLiteNode* node);
int32_t builtin_code;
const char* custom_name;
int version;
} TfLiteRegistration_V1;
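To make the layout concrete, a sketch of a stateless custom-op registration; MyPrepare, MyInvoke, and the op name are hypothetical, and the field order follows the struct shown above:

static TfLiteStatus MyPrepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;  // shape validation / output resizing would go here
}
static TfLiteStatus MyInvoke(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;  // the kernel computation would go here
}
static TfLiteRegistration my_op = {
    nullptr,         // init (stateless op)
    nullptr,         // free
    MyPrepare,       // prepare
    MyInvoke,        // invoke
    nullptr,         // profiling_string
    0,               // builtin_code (unused for custom ops)
    "MY_CUSTOM_OP",  // custom_name
    1,               // version
    nullptr,         // registration_external (only for stable C API ops)
};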
// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
// values should be 1, 2, 4, 8, ...etc.
typedef enum TfLiteDelegateFlags {

View File

@@ -493,6 +493,11 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseSquare(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_SQUARED_DIFFERENCE: {
return ParseSquaredDifference(op, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_SQUEEZE: {
return ParseSqueeze(op, error_reporter, allocator, builtin_data);
}
@@ -840,14 +845,25 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
// ok for now, since there is no call implementation either.
case BuiltinOperator_CALL:
case BuiltinOperator_COMPLEX_ABS:
case BuiltinOperator_CONCAT_EMBEDDINGS:
case BuiltinOperator_COS:
case BuiltinOperator_CUSTOM:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_DYNAMIC_UPDATE_SLICE:
case BuiltinOperator_EMBEDDING_LOOKUP:
case BuiltinOperator_EQUAL:
case BuiltinOperator_HASHTABLE_FIND:
case BuiltinOperator_HASHTABLE_IMPORT:
case BuiltinOperator_HASHTABLE_SIZE:
case BuiltinOperator_IMAG:
case BuiltinOperator_MATRIX_DIAG:
case BuiltinOperator_MATRIX_SET_DIAG:
case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
case BuiltinOperator_RELU_N1_TO_1:
case BuiltinOperator_RELU_0_TO_1:
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_SELECT:
case BuiltinOperator_SELECT_V2:
case BuiltinOperator_SLICE:
@@ -855,23 +871,17 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
case BuiltinOperator_RANGE:
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_REVERSE_V2:
case BuiltinOperator_WHERE:
case BuiltinOperator_RANK:
case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_RFFT2D:
case BuiltinOperator_IMAG:
case BuiltinOperator_REAL:
case BuiltinOperator_COMPLEX_ABS:
case BuiltinOperator_HASHTABLE_FIND:
case BuiltinOperator_HASHTABLE_IMPORT:
case BuiltinOperator_HASHTABLE_SIZE:
case BuiltinOperator_DYNAMIC_UPDATE_SLICE:
case BuiltinOperator_RFFT2D:
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_REVERSE_V2:
case BuiltinOperator_UNSORTED_SEGMENT_MAX:
case BuiltinOperator_UNSORTED_SEGMENT_MIN:
case BuiltinOperator_UNSORTED_SEGMENT_PROD:
case BuiltinOperator_UNSORTED_SEGMENT_SUM:
case BuiltinOperator_ATAN2:
case BuiltinOperator_WHERE:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;
@@ -2189,6 +2199,14 @@ TfLiteStatus ParseSquare(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseSquaredDifference(const Operator*, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseStridedSlice(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,

View File

@@ -356,6 +356,11 @@ TfLiteStatus ParseSqrt(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseSquare(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSquaredDifference(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseStridedSlice(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,

View File

@@ -23,6 +23,16 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate.
// This is used for cases (e.g. when using "TF Lite with Google Play Services")
// where the TF Lite runtime might be built using a newer (or older)
// version of the TF Lite sources than the app, and hence might have a
// different definition of the TfLiteDelegate type. TF Lite APIs use
// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to
// refer to a delegate defined with that potentially different version
// of the TfLiteDelegate type.
struct TfLiteOpaqueDelegateStruct;
namespace tflite {
/// Abstract interface that returns TfLiteRegistrations given op codes or custom
@@ -37,8 +47,10 @@ class OpResolver {
virtual const TfLiteRegistration* FindOp(const char* op,
int version) const = 0;
// Represents a sequence of delegates.
using TfLiteDelegatePtrVector =
std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
// Returns optional delegates for resolving and handling ops in the flatbuffer
// model. This may be used in addition to the standard TfLiteRegistration
// lookup for graph resolution.
@@ -47,16 +59,55 @@ class OpResolver {
return {};
}
// Represent a function that creates a TfLite delegate instance.
// Represents a function that creates a TfLite delegate instance.
using TfLiteDelegateCreator =
std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
int /*num_threads*/)>;
// Represents a sequence of delegate creator functions.
using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
// Returns a vector of delegate creators to create optional delegates for
// resolving and handling ops in the flatbuffer model. This may be used in
// addition to the standard TfLiteRegistration lookup for graph resolution.
//
// Note that this method is not used (will not be called) if you are using
// TF Lite in Google Play Services; the GetOpaqueDelegateCreators method
// (see below) is used for that case.
virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
// TODO(b/202712825): it would be nice if we could avoid the need for separate
// "opaque" types & methods for use only with TF Lite in Google Play Services.
// Represents an opaque delegate instance.
// WARNING: Experimental interface, subject to change.
using TfLiteOpaqueDelegatePtr =
std::unique_ptr<TfLiteOpaqueDelegateStruct,
void (*)(TfLiteOpaqueDelegateStruct*)>;
// Represents a function that creates an opaque delegate instance.
// WARNING: Experimental interface, subject to change.
using TfLiteOpaqueDelegateCreator =
std::function<TfLiteOpaqueDelegatePtr(int /*num_threads*/)>;
// Represents a sequence of opaque delegate creator functions.
// WARNING: Experimental interface, subject to change.
using TfLiteOpaqueDelegateCreators = std::vector<TfLiteOpaqueDelegateCreator>;
// Returns a vector of opaque delegate creators to create optional opaque
// delegates for resolving and handling ops in the flatbuffer model. This may
// be used in addition to the standard TfLiteRegistration lookup for graph
// resolution.
//
// Note that this method will be called only if you are using TF Lite in
// Google Play Services; if you are using regular TF Lite, GetDelegateCreators
// (see above) is used instead.
//
// WARNING: Experimental interface, subject to change.
virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const {
return {};
}
virtual ~OpResolver() {}
private:
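For context, a sketch of how a resolver subclass could supply an optional delegate through the creator hook above; MutableOpResolver is the usual base class, while MakeMyDelegate and FreeMyDelegate are hypothetical:

class MyOpResolver : public tflite::MutableOpResolver {
 public:
  TfLiteDelegateCreators GetDelegateCreators() const final {
    return {[](int num_threads) {
      // Build the delegate lazily, sized to the interpreter's thread count.
      return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
          MakeMyDelegate(num_threads),
          [](TfLiteDelegate* d) { FreeMyDelegate(d); });
    }};
  }
};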

View File

@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/microfrontend/lib/fft.h"
#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
#include <string.h>
#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
void FftCompute(struct FftState* state, const int16_t* input,
int input_scale_shift) {
@@ -37,9 +37,9 @@ void FftCompute(struct FftState* state, const int16_t* input,
// Apply the FFT.
kissfft_fixed16::kiss_fftr(
reinterpret_cast<kissfft_fixed16::kiss_fftr_cfg>(state->scratch),
state->input,
reinterpret_cast<kissfft_fixed16::kiss_fft_cpx*>(state->output));
reinterpret_cast<kissfft_fixed16::kiss_fftr_cfg>(state->scratch),
state->input,
reinterpret_cast<kissfft_fixed16::kiss_fft_cpx*>(state->output));
}
void FftInit(struct FftState* state) {

View File

@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h"
#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
#include <stdio.h>
#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
int FftPopulateState(struct FftState* state, size_t input_size) {
state->input_size = input_size;
state->fft_size = 1;

View File

@@ -31,4 +31,3 @@ namespace kissfft_fixed16 {
#undef KISS_FFT_H
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_KISS_FFT_INT16_H_

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#include <algorithm>
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK

View File

@@ -86,6 +86,16 @@ using int32 = std::int32_t;
using uint32 = std::uint32_t;
#endif // !defined(TF_LITE_STATIC_MEMORY)
// Allow for cross-compiler usage of function signatures - currently used for
// specifying named RUY profiler regions in templated methods.
#if defined(_MSC_VER)
#define TFLITE_PRETTY_FUNCTION __FUNCSIG__
#elif defined(__GNUC__)
#define TFLITE_PRETTY_FUNCTION __PRETTY_FUNCTION__
#else
#define TFLITE_PRETTY_FUNCTION __func__
#endif
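A sketch of the use case named in the comment: labeling a RUY profiler scope inside a templated kernel, where plain __func__ would drop the template arguments. The ScopeLabel API is assumed from ruy's instrumentation header:

template <typename T>
void RunKernel() {
  // Under GCC this expands to something like
  // "void RunKernel() [with T = float]".
  ruy::profiler::ScopeLabel label(TFLITE_PRETTY_FUNCTION);
  // ... kernel body ...
}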
// TFLITE_DEPRECATED()
//
// Duplicated from absl/base/macros.h to avoid pulling in that library.

View File

@@ -324,7 +324,7 @@ void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
// - n_input: the size for input and output.
// - output: the 16 bit output
// The input is in Qm.15-m format and the output is in Q0.15 format.
void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
void ApplyTanh(int32_t intger_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
// Apply Tanh to a quantized vector. The internal calculation is in float.

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
#include <algorithm>
#include <type_traits>
#include "fixedpoint/fixedpoint.h"

View File

@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

View File

@@ -0,0 +1,247 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline void DivCheckArithmeticParams(const ArithmeticParams& params) {
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
// Input offset is negative input zero point. Activation tensors are
// asymmetric quantized so they span the full int8 range.
constexpr int32_t max_value =
static_cast<int32_t>(std::numeric_limits<T>::max());
TFLITE_DCHECK_GE(params.input1_offset, -max_value);
TFLITE_DCHECK_LE(params.input1_offset, max_value);
TFLITE_DCHECK_GE(params.input2_offset, -max_value);
TFLITE_DCHECK_LE(params.input2_offset, max_value);
TFLITE_DCHECK_GE(params.output_offset, -max_value);
TFLITE_DCHECK_LE(params.output_offset, max_value);
}
// Element-wise div that can often be used for inner loop of broadcast Div as
// well as the non-broadcast Div.
template <typename T>
inline void DivElementwise(int size, const ArithmeticParams& params,
const T* input1_data, const T* input2_data,
T* output_data) {
DivCheckArithmeticParams<T>(params);
for (int i = 0; i < size; ++i) {
int32_t input1_val = params.input1_offset + input1_data[i];
int32_t input2_val = params.input2_offset + input2_data[i];
TFLITE_DCHECK_NE(input2_val, 0);
if (input2_val < 0) {
// Invert signs to avoid a negative input2_val as input2_inv needs to be
// positive to be used as multiplier of MultiplyByQuantizedMultiplier.
input1_val = -input1_val;
input2_val = -input2_val;
}
int recip_shift;
const int32_t input2_inv = GetReciprocal(input2_val, 31, &recip_shift);
const int headroom = CountLeadingSignBits(input1_val);
const int32_t unscaled_quotient =
MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv,
headroom);
const int total_shift = params.output_shift - recip_shift - headroom;
const int32_t unclamped_result =
params.output_offset +
MultiplyByQuantizedMultiplierSmallerThanOneExp(
unscaled_quotient, params.output_multiplier, total_shift);
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, unclamped_result));
output_data[i] = static_cast<T>(clamped_output);
}
}
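In fixed-point terms, a sketch of what each iteration above computes (offsets are the negated zero points):

q_{out} = \mathrm{clamp}\big(o_{out} + \mathrm{round}\big(\tfrac{q_1 + o_1}{q_2 + o_2} \cdot M \cdot 2^{s}\big)\big)

where M is output_multiplier in Q0.31 and s is output_shift. The division itself is realized as multiplication by the Q31 reciprocal from GetReciprocal, with headroom and recip_shift folded into the final shift so every intermediate stays within 32 bits.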
inline void Div(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const uint8_t* input1_data,
const RuntimeShape& input2_shape, const uint8_t* input2_data,
const RuntimeShape& output_shape, uint8_t* output_data) {
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
DivElementwise(flat_size, params, input1_data, input2_data, output_data);
}
inline void Div(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int8_t* input1_data,
const RuntimeShape& input2_shape, const int8_t* input2_data,
const RuntimeShape& output_shape, int8_t* output_data) {
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
DivElementwise(flat_size, params, input1_data, input2_data, output_data);
}
template <typename T, int N = 5>
inline void BroadcastDivSlowQuantized(
const ArithmeticParams& params, const RuntimeShape& unextended_input1_shape,
const T* input1_data, const RuntimeShape& unextended_input2_shape,
const T* input2_data, const RuntimeShape& unextended_output_shape,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
NdArrayDesc<N> desc1;
NdArrayDesc<N> desc2;
NdArrayDesc<N> output_desc;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
&output_desc);
DivCheckArithmeticParams<T>(params);
auto div_func = [&](int indexes[N]) {
int32_t input1_val =
params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
int32_t input2_val =
params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
TFLITE_DCHECK_NE(input2_val, 0);
if (input2_val < 0) {
// Invert signs to avoid a negative input2_val as input2_inv needs to be
// positive to be used as multiplier of MultiplyByQuantizedMultiplier.
input1_val = -input1_val;
input2_val = -input2_val;
}
int recip_shift;
const int32_t input2_inv = GetReciprocal(input2_val, 31, &recip_shift);
const int headroom = CountLeadingSignBits(input1_val);
const int32_t unscaled_quotient =
MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv,
headroom);
const int total_shift = params.output_shift - recip_shift - headroom;
const int32_t unclamped_result =
params.output_offset +
MultiplyByQuantizedMultiplierSmallerThanOneExp(
unscaled_quotient, params.output_multiplier, total_shift);
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, unclamped_result));
output_data[SubscriptToIndex(output_desc, indexes)] =
static_cast<T>(clamped_output);
};
NDOpsHelper<N>(output_desc, div_func);
}
template <int N = 5>
inline void BroadcastDivSlow(const ArithmeticParams& params,
const RuntimeShape& unextended_input1_shape,
const uint8_t* input1_data,
const RuntimeShape& unextended_input2_shape,
const uint8_t* input2_data,
const RuntimeShape& unextended_output_shape,
uint8_t* output_data) {
BroadcastDivSlowQuantized<uint8_t, N>(
params, unextended_input1_shape, input1_data, unextended_input2_shape,
input2_data, unextended_output_shape, output_data);
}
template <int N = 5>
inline void BroadcastDivSlow(const ArithmeticParams& params,
const RuntimeShape& unextended_input1_shape,
const int8_t* input1_data,
const RuntimeShape& unextended_input2_shape,
const int8_t* input2_data,
const RuntimeShape& unextended_output_shape,
int8_t* output_data) {
BroadcastDivSlowQuantized<int8_t, N>(
params, unextended_input1_shape, input1_data, unextended_input2_shape,
input2_data, unextended_output_shape, output_data);
}
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
template <typename T, int N = 5>
void BroadcastDivSlow(const ArithmeticParams& params,
const RuntimeShape& unextended_input1_shape,
const T* input1_data,
const RuntimeShape& unextended_input2_shape,
const T* input2_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
T output_activation_min;
T output_activation_max;
GetActivationParams(params, &output_activation_min, &output_activation_max);
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
NdArrayDesc<N> desc1;
NdArrayDesc<N> desc2;
NdArrayDesc<N> output_desc;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
&output_desc);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
// trailing dimension changing most rapidly (channels has the smallest
// stride, typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
auto div_func = [&](int indexes[N]) {
output_data[SubscriptToIndex(output_desc, indexes)] =
ActivationFunctionWithMinMax(
input1_data[SubscriptToIndex(desc1, indexes)] /
input2_data[SubscriptToIndex(desc2, indexes)],
output_activation_min, output_activation_max);
};
NDOpsHelper<N>(output_desc, div_func);
}
template <typename T>
inline void Div(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
T output_activation_min;
T output_activation_max;
GetActivationParams(params, &output_activation_min, &output_activation_max);
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
output_activation_max);
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
@@ -23,9 +25,9 @@ namespace tflite {
namespace reference_ops {
inline int16_t SaturatingLeftShift(int16_t value, int amount) {
int32_t result = static_cast<int32_t>(value) * (1 << amount);
result = std::min<int32_t>(result, std::numeric_limits<int16_t>::max());
result = std::max<int32_t>(result, std::numeric_limits<int16_t>::min());
int64_t result = static_cast<int64_t>(value) * (1 << amount);
result = std::min<int64_t>(result, std::numeric_limits<int16_t>::max());
result = std::max<int64_t>(result, std::numeric_limits<int16_t>::min());
return result;
}
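The widening matters because the clamp runs after the multiply: for value = 16384 (2^14) and amount = 18 the product is 2^32, which overflows int32_t (undefined behavior before std::min can saturate it), while in int64_t the product is well-defined and then clamps to 32767.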

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_
#include <algorithm>
#include <limits>
#include "tensorflow/lite/kernels/internal/common.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,11 +15,101 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_integer_ops {
// For per-channel functions, the quantization spec requires weights to be
// symmetric
// (https://www.tensorflow.org/lite/performance/quantization_spec#symmetric_vs_asymmetric),
// so zero_point (params.weights_offset) is always 0.
// However, for per-tensor functions, params.weights_offset is still applied for
// backward compatibility.
inline void FullyConnectedPerChannel(
const FullyConnectedParams& params, const int32_t* output_multiplier,
const int* output_shift, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
const int32_t* bias_data, const RuntimeShape& output_shape,
int8_t* output_data) {
const int32_t input_offset = params.input_offset;
const int32_t output_offset = params.output_offset;
const int32_t output_activation_min = params.quantized_activation_min;
const int32_t output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int filter_dim_count = filter_shape.DimensionsCount();
const int batches = output_shape.Dims(0);
const int output_depth = output_shape.Dims(1);
TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
for (int b = 0; b < batches; ++b) {
for (int out_c = 0; out_c < output_depth; ++out_c) {
int32_t acc = 0;
for (int d = 0; d < accum_depth; ++d) {
int32_t input_val = input_data[b * accum_depth + d];
int32_t filter_val = filter_data[out_c * accum_depth + d];
acc += filter_val * (input_val + input_offset);
}
if (bias_data) {
acc += bias_data[out_c];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier[out_c],
output_shift[out_c]);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
output_data[out_c + output_depth * b] = static_cast<int8_t>(acc);
}
}
}
template <typename AccumScalar>
inline void FullyConnectedPerChannel(
const FullyConnectedParams& params, const int32_t* output_multiplier,
const int* output_shift, const RuntimeShape& input_shape,
const int16_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
const AccumScalar* bias_data, const RuntimeShape& output_shape,
int16_t* output_data) {
const int32_t output_activation_min = params.quantized_activation_min;
const int32_t output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int filter_dim_count = filter_shape.DimensionsCount();
const int output_dim_count = output_shape.DimensionsCount();
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
const int output_depth = output_shape.Dims(output_dim_count - 1);
TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
for (int b = 0; b < batches; ++b) {
for (int out_c = 0; out_c < output_depth; ++out_c) {
AccumScalar acc = 0;
for (int d = 0; d < accum_depth; ++d) {
int32_t input_val = input_data[b * accum_depth + d];
int32_t filter_val = filter_data[out_c * accum_depth + d];
acc += filter_val * input_val;
}
if (bias_data) {
acc += bias_data[out_c];
}
int32_t acc_scaled = MultiplyByQuantizedMultiplier(
acc, output_multiplier[out_c], output_shift[out_c]);
acc_scaled = std::max(acc_scaled, output_activation_min);
acc_scaled = std::min(acc_scaled, output_activation_max);
output_data[out_c + output_depth * b] = static_cast<int16_t>(acc_scaled);
}
}
}
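Both variants reduce to the same per-channel requantization; as a sketch in the code's own names,

y[b, c] = \mathrm{clamp}\big(o_{out} + \mathrm{MBQM}\big(\textstyle\sum_d w[c,d] \cdot (x[b,d] + o_{in}) + bias[c],\ M_c,\ s_c\big)\big)

where MBQM is MultiplyByQuantizedMultiplier, M_c = output_multiplier[c], and s_c = output_shift[c]; there is no weights offset because per-channel weights are symmetrically quantized (zero_point = 0). The int16 variant additionally drops both activation offsets (o_{in} = o_{out} = 0) and accumulates in AccumScalar, typically a 64-bit type.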
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& filter_shape,

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,7 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_
#include <algorithm>
#include <limits>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_
#include <algorithm>
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"

View File

@@ -15,7 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_POOLING_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_POOLING_H_
#include <algorithm>
#include <limits>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_
#include <algorithm>
#include <limits>
#include "fixedpoint/fixedpoint.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_POOLING_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_POOLING_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
#include <algorithm>
#include <cmath>
#include "tensorflow/lite/kernels/internal/cppmath.h"

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
#include <algorithm>
#include <limits>
#include "fixedpoint/fixedpoint.h"

View File

@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
#include <algorithm>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

View File

@@ -27,6 +27,11 @@ class RuntimeShape {
public:
RuntimeShape& operator=(RuntimeShape const&) = delete;
// RuntimeShape in TFLM supports up to 5 dimensions.
// The name kMaxSmallSize comes from the same file in the upstream
// tensorflow lite repo and needs to be kept the same for maximum reuse.
static constexpr int kMaxSmallSize = 5;
RuntimeShape() : size_(0) {}
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {}
@@ -104,11 +109,9 @@ class RuntimeShape {
sizeof(int32_t) * shape.DimensionsCount());
}
// A maximum of 4 dimensions are supported on TFLM.
static constexpr int kMaxSize = 5;
int32_t size_;
union {
int32_t dims_[kMaxSize];
int32_t dims_[kMaxSmallSize];
};
};

View File

@@ -974,11 +974,11 @@ struct StridedSliceParams {
int8_t strides_count;
int32_t strides[5];
int16_t begin_mask;
int16_t ellipsis_mask;
int16_t end_mask;
int16_t new_axis_mask;
int16_t shrink_axis_mask;
uint16_t begin_mask;
uint16_t ellipsis_mask;
uint16_t end_mask;
uint16_t new_axis_mask;
uint16_t shrink_axis_mask;
};
struct TanhParams {

View File

@@ -177,6 +177,14 @@ inline int64_t NumElements(const TfLiteTensor* t) {
return NumElements(t->dims);
}
inline int64_t NumElements(const int* dims, int num_dims) {
int64_t count = 1;
for (int i = 0; i < num_dims; ++i) {
count *= dims[i];
}
return count;
}
// Determines whether tensor is constant.
// TODO(b/138199592): Introduce new query which checks for constant OR
// persistent-read-only, which would be useful for most tensor kernels that
@@ -308,7 +316,7 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const TfLiteTensor* input3,
TfLiteIntArray** output_shape);
// Return the size of given type in bytes. Return 0 in in case of string.
// Return the size of given type in bytes. Return 0 in case of string.
int TfLiteTypeGetSize(TfLiteType type);
// Whether the current platform is mobile (Android or iOS).

View File

@@ -43,6 +43,7 @@ AllOpsResolver::AllOpsResolver() {
AddDepthwiseConv2D();
AddDequantize();
AddDetectionPostprocess();
AddDiv();
AddElu();
AddEqual();
AddEthosU();
@@ -104,6 +105,7 @@ AllOpsResolver::AllOpsResolver() {
AddSqueeze();
AddStridedSlice();
AddSub();
AddSum();
AddSvdf();
AddTanh();
AddTranspose();

View File

@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
@@ -97,4 +97,4 @@ class INonPersistentBufferAllocator {
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_

View File

@@ -0,0 +1,170 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
NonPersistentArenaBufferAllocator::NonPersistentArenaBufferAllocator(
uint8_t* buffer, size_t buffer_size)
: buffer_head_(buffer),
buffer_tail_(buffer + buffer_size),
head_temp_(buffer),
next_temp_(buffer) {}
NonPersistentArenaBufferAllocator::~NonPersistentArenaBufferAllocator() {}
// Allocates a temporary buffer. This buffer is not resizable.
uint8_t* NonPersistentArenaBufferAllocator::AllocateTemp(size_t size,
size_t alignment) {
uint8_t* const aligned_result = AlignPointerUp(next_temp_, alignment);
const size_t available_memory = buffer_tail_ - aligned_result;
if (available_memory < size) {
MicroPrintf(
"Failed to allocate temp memory. Requested: %u, "
"available %u, missing: %u",
size, available_memory, size - available_memory);
return nullptr;
}
next_temp_ = aligned_result + size;
temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(aligned_result);
temp_buffer_count_++;
return aligned_result;
}
// Signals that a temporary buffer is no longer needed.
void NonPersistentArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(temp_buf);
temp_buffer_count_--;
}
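The XOR bookkeeping here is a constant-space leak check: AllocateTemp XORs each returned pointer into temp_buffer_ptr_check_sum_ and DeallocateTemp XORs it back out, so after matched allocate/deallocate pairs the sum returns to zero (x ^ x = 0) without the allocator ever storing the outstanding pointers. An unmatched or wrong-pointer deallocation leaves a nonzero residue, which IsAllTempDeallocated reports below; the separate counter guards the unlikely case where distinct unbalanced pointers happen to XOR to zero.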
// Returns true if all temporary buffers are already deallocated.
bool NonPersistentArenaBufferAllocator::IsAllTempDeallocated() {
if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
MicroPrintf(
"Number of allocated temp buffers: %d. Checksum passing status: %d",
temp_buffer_count_, !temp_buffer_ptr_check_sum_);
return false;
}
return true;
}
// Signals that all temporary allocations can be reclaimed. TFLM calls this
// API when it knows that all temporary buffers that it requested have been
// deallocated. The goal of this API is to let implementations of
// INonPersistentBufferAllocator reuse the buffer with reasonable
// complexity.
TfLiteStatus NonPersistentArenaBufferAllocator::ResetTempAllocations() {
if (!IsAllTempDeallocated()) {
MicroPrintf(
"All temp buffers must be freed before calling ResetTempAllocations()");
return kTfLiteError;
}
next_temp_ = head_temp_;
return kTfLiteOk;
}
// Returns a buffer that is resizable via ResizeBuffer().
uint8_t* NonPersistentArenaBufferAllocator::AllocateResizableBuffer(
size_t size, size_t alignment) {
// Only supports one resizable buffer, which starts at the buffer head.
uint8_t* expected_resizable_buf = AlignPointerUp(buffer_head_, alignment);
if (resizable_buffer_allocated_) {
MicroPrintf(
"Cannot allocate a new resizable buffer when one is already allocated");
return nullptr;
}
if (ResizeBuffer(expected_resizable_buf, size, alignment) == kTfLiteOk) {
resizable_buffer_allocated_ = true;
return expected_resizable_buf;
}
return nullptr;
}
// Resizes a buffer that is previously returned by the AllocateResizableBuffer.
// Note that ResizeBuffer(old_resizable_buf, 0, 1) effectively deallocates
// a previous allocated resizable buffer.
TfLiteStatus NonPersistentArenaBufferAllocator::ResizeBuffer(
uint8_t* resizable_buf, size_t size, size_t alignment) {
// Only supports one resizable buffer, which starts at the buffer head.
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
if (resizable_buf != expect_resizable_buf) {
MicroPrintf("Internal error: buffer is not resizable");
return kTfLiteError;
}
if (head_temp_ != next_temp_) {
MicroPrintf("ResetTempAllocations() is not called before ResizeBuffer().");
return kTfLiteError;
}
const size_t available_memory = buffer_tail_ - expect_resizable_buf;
if (available_memory < size) {
MicroPrintf(
"Failed to resize buffer. Requested: %u, available %u, missing: %u",
size, available_memory, size - available_memory);
return kTfLiteError;
}
head_temp_ = expect_resizable_buf + size;
next_temp_ = head_temp_;
return kTfLiteOk;
}
// Frees up the memory occupied by the resizable buffer.
TfLiteStatus NonPersistentArenaBufferAllocator::DeallocateResizableBuffer(
uint8_t* resizable_buf) {
TfLiteStatus status = ResizeBuffer(resizable_buf, 0, 1);
if (status == kTfLiteOk) {
resizable_buffer_allocated_ = false;
}
return status;
}
// Returns a pointer pointing to the start of the overlay memory, which is
// used for activation tensors and scratch buffers by kernels at Invoke stage.
uint8_t* NonPersistentArenaBufferAllocator::GetOverlayMemoryAddress() const {
return buffer_head_;
}
// Reserves the size of the overlay memory. This overlay is reserved for the
// kernels at Invoke stage. This is referred to as the overlay because before
// the Invoke stage, the same memory can be used for temp buffers. The layout of
// the memory is planned by the memory planner separately at Invoke stage.
TfLiteStatus
NonPersistentArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
size_t size, size_t alignment) {
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
return ResizeBuffer(expect_resizable_buf, size, alignment);
}
// Returns the size of non-persistent buffer in use.
size_t NonPersistentArenaBufferAllocator::GetNonPersistentUsedBytes() const {
return (next_temp_ - buffer_head_);
}
// Returns the number of bytes available with a given alignment. This number
// takes into account any temporary allocations.
size_t NonPersistentArenaBufferAllocator::GetAvailableMemory(
size_t alignment) const {
uint8_t* const aligned_temp = AlignPointerUp(next_temp_, alignment);
uint8_t* const aligned_tail = AlignPointerDown(buffer_tail_, alignment);
return aligned_tail - aligned_temp;
}
} // namespace tflite

View File

@@ -0,0 +1,105 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
#include "tensorflow/lite/micro/compatibility.h"
namespace tflite {
// Implements INonPersistentBufferAllocator on an arena that is dedicated to
// non-persistent buffers.
class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator {
public:
NonPersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
virtual ~NonPersistentArenaBufferAllocator();
// Allocates a temporary buffer. This buffer is not resizable.
uint8_t* AllocateTemp(size_t size, size_t alignment) override;
// Signals that a temporary buffer is no longer needed.
void DeallocateTemp(uint8_t* buf) override;
// Returns true if all temporary buffers are already deallocated.
bool IsAllTempDeallocated() override;
// Signals that all temporary allocations can be reclaimed. TFLM calls this
// API when it knows that all temporary buffers that it requested have been
// deallocated.
TfLiteStatus ResetTempAllocations() override;
// Returns a buffer that is resizable via ResizeBuffer().
uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) override;
// Resizes a buffer that is previously returned by the
// AllocateResizableBuffer.
TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size,
size_t alignment) override;
// Frees up the memory occupied by the resizable buffer.
TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) override;
// Returns a pointer pointing to the start of the overlay memory, which is
// used for activation tensors and scratch buffers by kernels at Invoke stage.
uint8_t* GetOverlayMemoryAddress() const override;
// Reserves the size of the overlay memory. This overlay is reserved for the
// kernels at Invoke stage. This is referred to as the overlay because before
// the Invoke stage, the same memory can be used for temp buffers. The layout of
// the memory is planned by the memory planner separately at Invoke stage.
TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size,
size_t alignment) override;
// Returns the size of non-persistent buffer in use.
size_t GetNonPersistentUsedBytes() const override;
// Returns the number of bytes available with a given alignment. This number
// takes into account any temporary allocations.
size_t GetAvailableMemory(size_t alignment) const override;
TF_LITE_REMOVE_VIRTUAL_DELETE
private:
// The memory arena that this allocator manages.
uint8_t* const buffer_head_;
uint8_t* const buffer_tail_;
// The whole region is split into two parts:
// buffer_head_ to head_temp_ - 1 belongs to the only resizable buffer.
// head_temp_ to buffer_tail_ can be used for (non-resizable) temp buffers.
uint8_t* head_temp_;
// next_temp_ points to the next available temp buffer allocation address and
// its range is between head_temp_ and buffer_tail_
uint8_t* next_temp_;
// XOR Check sum for outstanding temp buffers.
// If all temp buffers are deallocated OR no temp buffers are allocated,
// temp_buffer_ptr_check_sum_ == nullptr.
intptr_t temp_buffer_ptr_check_sum_ = 0;
// Count of outstanding temp buffers.
int temp_buffer_count_ = 0;
bool resizable_buffer_allocated_ = false;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
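Putting the contract together, a minimal usage sketch (sizes and alignments are illustrative only):

uint8_t arena[1024];
tflite::NonPersistentArenaBufferAllocator allocator(arena, sizeof(arena));

// Temp buffers come from the free region; every one of them must be
// deallocated before the temp space can be reclaimed.
uint8_t* scratch = allocator.AllocateTemp(/*size=*/64, /*alignment=*/16);
allocator.DeallocateTemp(scratch);
allocator.ResetTempAllocations();  // succeeds: all temps are deallocated

// The single resizable buffer grows from the buffer head.
uint8_t* plan =
    allocator.AllocateResizableBuffer(/*size=*/256, /*alignment=*/16);
allocator.ResizeBuffer(plan, /*size=*/512, /*alignment=*/16);
allocator.DeallocateResizableBuffer(plan);  // same as resizing to 0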

View File

@@ -0,0 +1,52 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
PersistentArenaBufferAllocator::PersistentArenaBufferAllocator(
uint8_t* buffer, size_t buffer_size)
: buffer_head_(buffer),
buffer_tail_(buffer + buffer_size),
tail_temp_(buffer_tail_) {}
PersistentArenaBufferAllocator::~PersistentArenaBufferAllocator() {}
uint8_t* PersistentArenaBufferAllocator::AllocatePersistentBuffer(
size_t size, size_t alignment) {
uint8_t* const aligned_result =
AlignPointerDown(tail_temp_ - size, alignment);
if (aligned_result < buffer_head_) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
const size_t missing_memory = buffer_head_ - aligned_result;
MicroPrintf(
"Failed to allocate tail memory. Requested: %u, "
"available %u, missing: %u",
size, size - missing_memory, missing_memory);
#endif
return nullptr;
}
tail_temp_ = aligned_result;
return aligned_result;
}
size_t PersistentArenaBufferAllocator::GetPersistentUsedBytes() const {
return buffer_tail_ - tail_temp_;
}
} // namespace tflite

View File

@@ -0,0 +1,59 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
#include "tensorflow/lite/micro/compatibility.h"
namespace tflite {
// PersistentArenaBufferAllocator is an implementation of
// IPersistentBufferAllocator interface on an arena that is dedicated for
// persistent buffers.
class PersistentArenaBufferAllocator : public IPersistentBufferAllocator {
public:
PersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
virtual ~PersistentArenaBufferAllocator();
// Allocates persistent memory. The persistent buffer is never freed.
// Returns nullptr if errors occurred.
uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) override;
// Returns the size of all persistent allocations in bytes.
size_t GetPersistentUsedBytes() const override;
TF_LITE_REMOVE_VIRTUAL_DELETE
private:
// The memory arena that this allocator manages.
uint8_t* const buffer_head_;
uint8_t* const buffer_tail_;
// The whole region is split into two parts:
// tail_temp_ to buffer_tail_ contains allocated buffers;
// buffer_head_ to tail_temp_ - 1 is the still-available space.
// So in essence, the allocated region grows from the bottom and emulates
// SingleArenaBufferAllocator's persistent part.
uint8_t* tail_temp_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
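A matching usage sketch (sizes illustrative); persistent allocations grow downward from the arena tail and are never individually freed:

uint8_t arena[512];
tflite::PersistentArenaBufferAllocator allocator(arena, sizeof(arena));
uint8_t* weights =
    allocator.AllocatePersistentBuffer(/*size=*/128, /*alignment=*/16);
if (weights != nullptr) {
  // At least 128 bytes are now in use; alignment padding can add more.
  size_t used = allocator.GetPersistentUsedBytes();
  (void)used;
}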

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/recording_simple_memory_allocator.h"
#include "tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h"
#include <new>
@@ -21,47 +21,49 @@ limitations under the License.
namespace tflite {
RecordingSimpleMemoryAllocator::RecordingSimpleMemoryAllocator(
RecordingSingleArenaBufferAllocator::RecordingSingleArenaBufferAllocator(
ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size)
: SimpleMemoryAllocator(error_reporter, buffer_head, buffer_size),
: SingleArenaBufferAllocator(error_reporter, buffer_head, buffer_size),
requested_head_bytes_(0),
requested_tail_bytes_(0),
used_bytes_(0),
alloc_count_(0) {}
RecordingSimpleMemoryAllocator::~RecordingSimpleMemoryAllocator() {}
RecordingSingleArenaBufferAllocator::~RecordingSingleArenaBufferAllocator() {}
RecordingSimpleMemoryAllocator* RecordingSimpleMemoryAllocator::Create(
ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size) {
RecordingSingleArenaBufferAllocator*
RecordingSingleArenaBufferAllocator::Create(ErrorReporter* error_reporter,
uint8_t* buffer_head,
size_t buffer_size) {
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(buffer_head != nullptr);
RecordingSimpleMemoryAllocator tmp =
RecordingSimpleMemoryAllocator(error_reporter, buffer_head, buffer_size);
RecordingSingleArenaBufferAllocator tmp = RecordingSingleArenaBufferAllocator(
error_reporter, buffer_head, buffer_size);
uint8_t* allocator_buffer =
tmp.AllocatePersistentBuffer(sizeof(RecordingSimpleMemoryAllocator),
alignof(RecordingSimpleMemoryAllocator));
uint8_t* allocator_buffer = tmp.AllocatePersistentBuffer(
sizeof(RecordingSingleArenaBufferAllocator),
alignof(RecordingSingleArenaBufferAllocator));
// Use the default copy constructor to populate internal states.
return new (allocator_buffer) RecordingSimpleMemoryAllocator(tmp);
return new (allocator_buffer) RecordingSingleArenaBufferAllocator(tmp);
}
size_t RecordingSimpleMemoryAllocator::GetRequestedBytes() const {
size_t RecordingSingleArenaBufferAllocator::GetRequestedBytes() const {
return requested_head_bytes_ + requested_tail_bytes_;
}
size_t RecordingSimpleMemoryAllocator::GetUsedBytes() const {
size_t RecordingSingleArenaBufferAllocator::GetUsedBytes() const {
return used_bytes_;
}
size_t RecordingSimpleMemoryAllocator::GetAllocatedCount() const {
size_t RecordingSingleArenaBufferAllocator::GetAllocatedCount() const {
return alloc_count_;
}
TfLiteStatus RecordingSimpleMemoryAllocator::ResizeBuffer(
TfLiteStatus RecordingSingleArenaBufferAllocator::ResizeBuffer(
uint8_t* resizable_buf, size_t size, size_t alignment) {
const uint8_t* previous_head = head();
TfLiteStatus status =
SimpleMemoryAllocator::ResizeBuffer(resizable_buf, size, alignment);
SingleArenaBufferAllocator::ResizeBuffer(resizable_buf, size, alignment);
if (status == kTfLiteOk) {
used_bytes_ += head() - previous_head;
requested_head_bytes_ = size;
@@ -69,11 +71,11 @@ TfLiteStatus RecordingSimpleMemoryAllocator::ResizeBuffer(
return status;
}
uint8_t* RecordingSimpleMemoryAllocator::AllocatePersistentBuffer(
uint8_t* RecordingSingleArenaBufferAllocator::AllocatePersistentBuffer(
size_t size, size_t alignment) {
const uint8_t* previous_tail = tail();
uint8_t* result =
SimpleMemoryAllocator::AllocatePersistentBuffer(size, alignment);
SingleArenaBufferAllocator::AllocatePersistentBuffer(size, alignment);
if (result != nullptr) {
used_bytes_ += previous_tail - tail();
requested_tail_bytes_ += size;
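The recording wrapper tracks both what callers requested and what the arena actually consumed; the two diverge through alignment padding. A hedged sketch of how the counters relate (arena and error_reporter are illustrative placeholders):

// Sketch: a 10-byte persistent request with 8-byte alignment may consume
// more than 10 bytes of arena once the returned pointer is aligned.
tflite::RecordingSingleArenaBufferAllocator* recorder =
    tflite::RecordingSingleArenaBufferAllocator::Create(
        error_reporter, arena, sizeof(arena));
recorder->AllocatePersistentBuffer(/*size=*/10, /*alignment=*/8);
// recorder->GetRequestedBytes()  == 10
// recorder->GetUsedBytes()       >= 10  (includes alignment padding)
// recorder->GetAllocatedCount()  == 1   (successful allocations recorded)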

View File

@@ -13,28 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
namespace tflite {
// Utility class used to log allocations of a SimpleMemoryAllocator. Should only
// be used in debug/evaluation settings or unit tests to evaluate allocation
// usage.
class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
// Utility class used to log allocations of a SingleArenaBufferAllocator. Should
// only be used in debug/evaluation settings or unit tests to evaluate
// allocation usage.
class RecordingSingleArenaBufferAllocator : public SingleArenaBufferAllocator {
public:
RecordingSimpleMemoryAllocator(ErrorReporter* error_reporter,
uint8_t* buffer_head, size_t buffer_size);
RecordingSingleArenaBufferAllocator(ErrorReporter* error_reporter,
uint8_t* buffer_head, size_t buffer_size);
// TODO(b/157615197): Cleanup constructors/destructor and use factory
// functions.
~RecordingSimpleMemoryAllocator() override;
~RecordingSingleArenaBufferAllocator() override;
static RecordingSimpleMemoryAllocator* Create(ErrorReporter* error_reporter,
uint8_t* buffer_head,
size_t buffer_size);
static RecordingSingleArenaBufferAllocator* Create(
ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size);
// Returns the number of bytes requested from the head or tail.
size_t GetRequestedBytes() const;
@@ -62,4 +61,4 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SINGLE_ARENA_BUFFER_ALLOCATOR_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include <cstddef>
#include <cstdint>
@@ -29,42 +29,45 @@ limitations under the License.
namespace tflite {
SimpleMemoryAllocator::SimpleMemoryAllocator(ErrorReporter* error_reporter,
uint8_t* buffer_head,
uint8_t* buffer_tail)
: error_reporter_(error_reporter),
SingleArenaBufferAllocator::SingleArenaBufferAllocator(
ErrorReporter* error_reporter, uint8_t* buffer_head, uint8_t* buffer_tail)
:
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
error_reporter_(error_reporter),
#endif
buffer_head_(buffer_head),
buffer_tail_(buffer_tail),
head_(buffer_head),
tail_(buffer_tail),
temp_(buffer_head_) {}
temp_(buffer_head_) {
}
SimpleMemoryAllocator::SimpleMemoryAllocator(ErrorReporter* error_reporter,
uint8_t* buffer,
size_t buffer_size)
: SimpleMemoryAllocator(error_reporter, buffer, buffer + buffer_size) {}
SingleArenaBufferAllocator::SingleArenaBufferAllocator(
ErrorReporter* error_reporter, uint8_t* buffer, size_t buffer_size)
: SingleArenaBufferAllocator(error_reporter, buffer, buffer + buffer_size) {
}
/* static */
SimpleMemoryAllocator* SimpleMemoryAllocator::Create(
SingleArenaBufferAllocator* SingleArenaBufferAllocator::Create(
ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size) {
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(buffer_head != nullptr);
SimpleMemoryAllocator tmp =
SimpleMemoryAllocator(error_reporter, buffer_head, buffer_size);
SingleArenaBufferAllocator tmp =
SingleArenaBufferAllocator(error_reporter, buffer_head, buffer_size);
// Allocate enough bytes from the buffer to create a SimpleMemoryAllocator.
// The new instance will use the current adjusted tail buffer from the tmp
// allocator instance.
// Allocate enough bytes from the buffer to create a
// SingleArenaBufferAllocator. The new instance will use the current adjusted
// tail buffer from the tmp allocator instance.
uint8_t* allocator_buffer = tmp.AllocatePersistentBuffer(
sizeof(SimpleMemoryAllocator), alignof(SimpleMemoryAllocator));
sizeof(SingleArenaBufferAllocator), alignof(SingleArenaBufferAllocator));
// Use the default copy constructor to populate internal states.
return new (allocator_buffer) SimpleMemoryAllocator(tmp);
return new (allocator_buffer) SingleArenaBufferAllocator(tmp);
}
SimpleMemoryAllocator::~SimpleMemoryAllocator() {}
SingleArenaBufferAllocator::~SingleArenaBufferAllocator() {}
uint8_t* SimpleMemoryAllocator::AllocateResizableBuffer(size_t size,
size_t alignment) {
uint8_t* SingleArenaBufferAllocator::AllocateResizableBuffer(size_t size,
size_t alignment) {
// Only supports one resizable buffer, which starts at the buffer head.
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
if (ResizeBuffer(expect_resizable_buf, size, alignment) == kTfLiteOk) {
@@ -73,20 +76,20 @@ uint8_t* SimpleMemoryAllocator::AllocateResizableBuffer(size_t size,
return nullptr;
}
TfLiteStatus SimpleMemoryAllocator::DeallocateResizableBuffer(
TfLiteStatus SingleArenaBufferAllocator::DeallocateResizableBuffer(
uint8_t* resizable_buf) {
return ResizeBuffer(resizable_buf, 0, 1);
}
TfLiteStatus SimpleMemoryAllocator::ReserveNonPersistentOverlayMemory(
TfLiteStatus SingleArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
size_t size, size_t alignment) {
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
return ResizeBuffer(expect_resizable_buf, size, alignment);
}
TfLiteStatus SimpleMemoryAllocator::ResizeBuffer(uint8_t* resizable_buf,
size_t size,
size_t alignment) {
TfLiteStatus SingleArenaBufferAllocator::ResizeBuffer(uint8_t* resizable_buf,
size_t size,
size_t alignment) {
// Only supports one resizable buffer, which starts at the buffer head.
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
if (head_ != temp_ || resizable_buf != expect_resizable_buf) {
@@ -112,8 +115,8 @@ TfLiteStatus SimpleMemoryAllocator::ResizeBuffer(uint8_t* resizable_buf,
return kTfLiteOk;
}
uint8_t* SimpleMemoryAllocator::AllocatePersistentBuffer(size_t size,
size_t alignment) {
uint8_t* SingleArenaBufferAllocator::AllocatePersistentBuffer(
size_t size, size_t alignment) {
uint8_t* const aligned_result = AlignPointerDown(tail_ - size, alignment);
if (aligned_result < head_) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
@@ -129,7 +132,8 @@ uint8_t* SimpleMemoryAllocator::AllocatePersistentBuffer(size_t size,
return aligned_result;
}
uint8_t* SimpleMemoryAllocator::AllocateTemp(size_t size, size_t alignment) {
uint8_t* SingleArenaBufferAllocator::AllocateTemp(size_t size,
size_t alignment) {
uint8_t* const aligned_result = AlignPointerUp(temp_, alignment);
const size_t available_memory = tail_ - aligned_result;
if (available_memory < size) {
@@ -145,12 +149,12 @@ uint8_t* SimpleMemoryAllocator::AllocateTemp(size_t size, size_t alignment) {
return aligned_result;
}
void SimpleMemoryAllocator::DeallocateTemp(uint8_t* temp_buf) {
void SingleArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
temp_buffer_ptr_check_sum_ ^= (reinterpret_cast<intptr_t>(temp_buf));
temp_buffer_count_--;
}
bool SimpleMemoryAllocator::IsAllTempDeallocated() {
bool SingleArenaBufferAllocator::IsAllTempDeallocated() {
if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
MicroPrintf(
"Number of allocated temp buffers: %d. Checksum passing status: %d",
@@ -160,7 +164,7 @@ bool SimpleMemoryAllocator::IsAllTempDeallocated() {
return true;
}
TfLiteStatus SimpleMemoryAllocator::ResetTempAllocations() {
TfLiteStatus SingleArenaBufferAllocator::ResetTempAllocations() {
// TODO(b/209453859): enable error check based on IsAllTempDeallocated after
// every AllocateTemp call has been paired with DeallocateTemp
if (!IsAllTempDeallocated()) {
@@ -172,34 +176,34 @@ TfLiteStatus SimpleMemoryAllocator::ResetTempAllocations() {
return kTfLiteOk;
}
uint8_t* SimpleMemoryAllocator::GetOverlayMemoryAddress() const {
uint8_t* SingleArenaBufferAllocator::GetOverlayMemoryAddress() const {
return buffer_head_;
}
size_t SimpleMemoryAllocator::GetNonPersistentUsedBytes() const {
size_t SingleArenaBufferAllocator::GetNonPersistentUsedBytes() const {
return std::max(head_ - buffer_head_, temp_ - buffer_head_);
}
size_t SimpleMemoryAllocator::GetPersistentUsedBytes() const {
size_t SingleArenaBufferAllocator::GetPersistentUsedBytes() const {
return buffer_tail_ - tail_;
}
size_t SimpleMemoryAllocator::GetAvailableMemory(size_t alignment) const {
size_t SingleArenaBufferAllocator::GetAvailableMemory(size_t alignment) const {
uint8_t* const aligned_temp = AlignPointerUp(temp_, alignment);
uint8_t* const aligned_tail = AlignPointerDown(tail_, alignment);
return aligned_tail - aligned_temp;
}
size_t SimpleMemoryAllocator::GetUsedBytes() const {
size_t SingleArenaBufferAllocator::GetUsedBytes() const {
return GetPersistentUsedBytes() + GetNonPersistentUsedBytes();
}
size_t SimpleMemoryAllocator::GetBufferSize() const {
size_t SingleArenaBufferAllocator::GetBufferSize() const {
return buffer_tail_ - buffer_head_;
}
uint8_t* SimpleMemoryAllocator::head() const { return head_; }
uint8_t* SingleArenaBufferAllocator::head() const { return head_; }
uint8_t* SimpleMemoryAllocator::tail() const { return tail_; }
uint8_t* SingleArenaBufferAllocator::tail() const { return tail_; }
} // namespace tflite
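Taken together, the renamed class still implements the same split arena: persistent buffers grow down from the tail, the resizable head buffer and temporaries grow up from the head, and temps must be balanced before a reset. A minimal lifecycle sketch (the arena size is illustrative):

#include <cstdint>

#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"

void SingleArenaSketch(tflite::ErrorReporter* error_reporter) {
  static uint8_t arena[512];  // illustrative arena
  tflite::SingleArenaBufferAllocator allocator(error_reporter, arena,
                                               sizeof(arena));
  // Persistent data comes from the tail and is never returned.
  uint8_t* persistent = allocator.AllocatePersistentBuffer(32, 8);
  // Temporaries come from the head and must be given back explicitly.
  uint8_t* temp = allocator.AllocateTemp(64, 8);
  allocator.DeallocateTemp(temp);    // balances the count and checksum
  allocator.ResetTempAllocations();  // succeeds once temps are balanced
  (void)persistent;
}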

View File

@@ -13,37 +13,37 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/ibuffer_allocator.h"
namespace tflite {
// TODO(petewarden): This allocator never frees up or reuses any memory, even
// though we have enough information about lifetimes of the tensors to do so.
// This makes it pretty wasteful, so we should use a more intelligent method.
class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
public IPersistentBufferAllocator {
class SingleArenaBufferAllocator : public INonPersistentBufferAllocator,
public IPersistentBufferAllocator {
public:
// TODO(b/157615197): Cleanup constructors/destructor and use factory
// functions.
SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer_head,
uint8_t* buffer_tail);
SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer,
size_t buffer_size);
virtual ~SimpleMemoryAllocator();
SingleArenaBufferAllocator(ErrorReporter* error_reporter,
uint8_t* buffer_head, uint8_t* buffer_tail);
SingleArenaBufferAllocator(ErrorReporter* error_reporter, uint8_t* buffer,
size_t buffer_size);
virtual ~SingleArenaBufferAllocator();
// Creates a new SimpleMemoryAllocator from a given buffer head and size.
static SimpleMemoryAllocator* Create(ErrorReporter* error_reporter,
uint8_t* buffer_head,
size_t buffer_size);
// Creates a new SingleArenaBufferAllocator from a given buffer head and size.
static SingleArenaBufferAllocator* Create(ErrorReporter* error_reporter,
uint8_t* buffer_head,
size_t buffer_size);
// Resizes a buffer that is previously returned by the
// AllocateResizableBuffer. In current implementation, it Adjusts the head
@@ -126,7 +126,9 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
private:
size_t GetBufferSize() const;
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
ErrorReporter* error_reporter_;
#endif
uint8_t* buffer_head_;
uint8_t* buffer_tail_;
uint8_t* head_;
@@ -147,4 +149,4 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SINGLE_ARENA_BUFFER_ALLOCATOR_H_

View File

@@ -16,10 +16,10 @@ limitations under the License.
#include "tensorflow/lite/micro/fake_micro_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
namespace tflite {
namespace {
@@ -30,7 +30,7 @@ static uint8_t dummy_tensor_arena[KDummyTensorArenaSize];
} // namespace
FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
SimpleMemoryAllocator* allocator,
SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph)
: MicroContext(
MicroAllocator::Create(dummy_tensor_arena, KDummyTensorArenaSize,
@@ -67,10 +67,10 @@ TfLiteEvalTensor* FakeMicroContext::GetEvalTensor(int tensor_index) {
}
void* FakeMicroContext::AllocatePersistentBuffer(size_t bytes) {
// FakeMicroContext uses SimpleMemoryAllocator, which does not automatically
// apply the buffer alignment like MicroAllocator.
// The buffer alignment is potentially wasteful but allows the
// fake_micro_context to work correctly with optimized kernels.
// FakeMicroContext uses SingleArenaBufferAllocator, which does not
// automatically apply the buffer alignment like MicroAllocator. The buffer
// alignment is potentially wasteful but allows the fake_micro_context to work
// correctly with optimized kernels.
return allocator_->AllocatePersistentBuffer(bytes,
MicroArenaBufferAlignment());
}

View File

@@ -23,7 +23,7 @@ namespace tflite {
// A fake of MicroContext for kernel util tests.
class FakeMicroContext : public MicroContext {
public:
FakeMicroContext(TfLiteTensor* tensors, SimpleMemoryAllocator* allocator,
FakeMicroContext(TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph);
void* AllocatePersistentBuffer(size_t bytes) override;
@@ -46,7 +46,7 @@ class FakeMicroContext : public MicroContext {
TfLiteTensor* tensors_;
int allocated_tensor_count_ = 0;
SimpleMemoryAllocator* allocator_;
SingleArenaBufferAllocator* allocator_;
TF_LITE_REMOVE_VIRTUAL_DELETE
};

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default: {
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
MicroPrintf("Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default: {
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
MicroPrintf("Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_RELU() {
return {/*init=*/ReluInit,
/*free=*/nullptr,
/*prepare=*/ReluPrepare,
/*invoke=*/ReluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
}
TfLiteRegistration Register_RELU6() {
return {/*init=*/Relu6Init,
/*free=*/nullptr,
/*prepare=*/Relu6Prepare,
/*invoke=*/Relu6Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
}
} // namespace tflite
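From here on, almost every kernel swaps the hand-written TfLiteRegistration initializer for tflite::micro::RegisterOp. The helper's own definition is not part of this diff; judging from the initializers it replaces, a plausible sketch is:

// Hypothetical sketch of the RegisterOp helper (its real definition is not
// shown in this diff): fill in the three callbacks and default every other
// TfLiteRegistration field, exactly as the replaced brace initializers did.
TfLiteRegistration RegisterOpSketch(
    void* (*init)(TfLiteContext*, const char*, size_t),
    TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*),
    TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*)) {
  return {/*init=*/init,
          /*free=*/nullptr,
          /*prepare=*/prepare,
          /*invoke=*/invoke,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}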

View File

@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_ADD() {
return {/*init=*/AddInit,
/*free=*/nullptr,
/*prepare=*/AddPrepare,
/*invoke=*/AddEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}
} // namespace tflite

View File

@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_ADD_N() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace arg_min_max
TfLiteRegistration Register_ARG_MAX() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/arg_min_max::ArgMaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
}
TfLiteRegistration Register_ARG_MIN() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/arg_min_max::ArgMinEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
}
} // namespace micro

View File

@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_ASSIGN_VARIABLE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_BROADCAST_ARGS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/BroadcastArgsPrepare,
/*invoke=*/BroadcastArgsEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
BroadcastArgsEval);
}
} // namespace tflite
} // namespace tflite

View File

@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_BROADCAST_TO() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/BroadcastToPrepare,
/*invoke=*/BroadcastToEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
BroadcastToEval);
}
} // namespace tflite
} // namespace tflite

View File

@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_CALL_ONCE() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite

View File

@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CAST() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace ceil
TfLiteRegistration Register_CEIL() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/ceil::Prepare,
/*invoke=*/ceil::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
}
} // namespace micro

View File

@@ -108,14 +108,8 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
static TfLiteRegistration r = {/*init=*/CircularBufferInit,
/*free=*/nullptr,
/*prepare=*/CircularBufferPrepare,
/*invoke=*/CircularBufferEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
static TfLiteRegistration r = tflite::micro::RegisterOp(
CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
return &r;
}

View File

@@ -39,13 +39,12 @@ const int kCircularBufferCyclesMaxIndex = 0; // 'cycles_max'
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(-9);
TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
MicroContext * micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context-> AllocateTempInputTensor(node, kCircularBufferInputTensor);
TfLiteTensor* output =
micro_context-> AllocateTempOutputTensor(node, kCircularBufferOutputTensor);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kCircularBufferInputTensor);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
node, kCircularBufferOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
OpDataCircularBuffer* op_data =

View File

@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} // namespace comparisons
TfLiteRegistration Register_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::EqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::EqualEval);
}
TfLiteRegistration Register_NOT_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::NotEqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::NotEqualEval);
}
TfLiteRegistration Register_GREATER() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::GreaterEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::GreaterEval);
}
TfLiteRegistration Register_GREATER_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::GreaterEqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::GreaterEqualEval);
}
TfLiteRegistration Register_LESS() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::LessEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::LessEval);
}
TfLiteRegistration Register_LESS_EQUAL() {
return {/*init=*/comparisons::Init,
/*free=*/nullptr,
/*prepare=*/comparisons::Prepare,
/*invoke=*/comparisons::LessEqualEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
comparisons::LessEqualEval);
}
} // namespace micro

View File

@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, input != nullptr);
int num_dimensions = NumDimensions(input);
if (num_dimensions > 4) {
if (num_dimensions > RuntimeShape::kMaxSmallSize) {
TF_LITE_KERNEL_LOG(
context,
"Op Concatenation does not currently support num dimensions >4 "
"Op Concatenation does not currently support num dimensions > %d "
"Tensor has %d dimensions.",
num_dimensions);
RuntimeShape::kMaxSmallSize, num_dimensions);
return kTfLiteError;
}
micro_context->DeallocateTempTfLiteTensor(input);
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace concatenation
TfLiteRegistration Register_CONCATENATION() {
return {/*init=*/concatenation::Init,
/*free=*/nullptr,
/*prepare=*/concatenation::Prepare,
/*invoke=*/concatenation::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
concatenation::Eval);
}
} // namespace micro

View File

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
break;
}
case kTfLiteInt16: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<std::int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
switch (bias->type) {
case kTfLiteInt32: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
case kTfLiteInt64: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default:
MicroPrintf("Bias type %s (%d) not supported.",
TfLiteTypeGetName(bias->type), bias->type);
return kTfLiteError;
}
break;
}
case kTfLiteInt8: {
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/ConvPrepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
}
} // namespace tflite
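Note the bias arguments above switch from GetTensorData to GetOptionalTensorData, since a conv bias may be absent. A plausible sketch of the distinction, assuming the optional variant merely guards against a null tensor:

// Hypothetical sketch: GetTensorData assumes the tensor is present, while
// GetOptionalTensorData tolerates a missing optional input (e.g. bias)
// by returning nullptr instead of dereferencing it.
template <typename T>
const T* GetOptionalTensorDataSketch(const TfLiteEvalTensor* tensor) {
  return tensor == nullptr ? nullptr
                           : reinterpret_cast<const T*>(tensor->data.raw);
}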

View File

@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
TfLiteRegistration registration, int16_t* output_data);
TfLiteStatus TestConvQuantizedPerChannel(
int* input_dims_data, const float* input_data, int16_t* input_quantized,
float input_scale, int input_zero_point, int* filter_dims_data,
const float* filter_data, int8_t* filter_data_quantized,
int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
float* bias_scales, int* bias_zero_points, int* output_dims_data,
const float* expected_output_data, int16_t* expected_output_data_quantized,
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
TfLiteRegistration registration, int16_t* output_data);
} // namespace testing
} // namespace tflite

View File

@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CUMSUM() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_DEPTH_TO_SPACE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite

View File

@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/DepthwiseConvPrepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
}
} // namespace tflite

View File

@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
// This is the most generic TfLiteRegistration. The actual supported types may
// still be target dependent. The only requirement is that every implementation
// (reference or optimized) must define this function.
TfLiteRegistration Register_DEPTHWISE_CONV_2D();
#if defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for a kernel variant that only supports
// int8 activations and int8 weights and uses the latency-optimized
// implementations.
TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
// Returns a TfLiteRegistration struct for a kernel variant that only supports
// int16 activations and int8 weights and uses the latency-optimized
// implementations.
TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
#else
inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
return Register_DEPTHWISE_CONV_2D();
}
inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
return Register_DEPTHWISE_CONV_2D();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
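Illustrative use of the new variant registrations: on CMSIS-NN builds the INT8/INT16 symbols resolve to latency-optimized kernels, while the inline stubs above let the same caller code compile unchanged against the reference kernel.

// Illustrative only: callers can name the reduced variant directly and let
// the build configuration decide whether it is optimized or a fallback.
TfLiteRegistration depthwise_int8 = tflite::Register_DEPTHWISE_CONV_2D_INT8();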

View File

@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteUInt8:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
default:
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_DEQUANTIZE() {
return {/*init=*/DequantizeInit,
/*free=*/nullptr,
/*prepare=*/DequantizePrepare,
/*invoke=*/DequantizeEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
DequantizeEval);
}
} // namespace tflite
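The new kTfLiteUInt8 case applies the same affine dequantization as the int8/int16 paths, real = scale * (q - zero_point). A worked one-element example with illustrative parameters:

#include <cstdint>

// Worked example of affine dequantization with illustrative parameters
// (scale = 0.5, zero_point = 128, the usual asymmetric uint8 setup):
float DequantizeOneSketch(uint8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}
// DequantizeOneSketch(130, 0.5f, 128) == 1.0f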

View File

@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteUInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
if (output->type == kTfLiteInt32) {

View File

@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = static_cast<OpData*>(node->user_data);
@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
static TfLiteRegistration r = {/*init=*/Init,
/*free=*/Free,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}

View File

@@ -0,0 +1,208 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/div.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
struct OpDataDiv {
// Parameters used in the quantized paths where the output is 8bit
int32_t input1_zero_point;
int32_t input2_zero_point;
int32_t output_zero_point;
int32_t output_activation_min;
int32_t output_activation_max;
// Parameters used in all quantized paths
int32_t output_multiplier;
int output_shift;
};
TfLiteStatus CalculateOpDataDiv(TfLiteContext* context, TfLiteTensor* input1,
TfLiteTensor* input2, TfLiteTensor* output,
TfLiteDivParams* params, OpDataDiv* data) {
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, output->type);
if (output->type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
const double real_multiplier = static_cast<double>(
input1->params.scale / (input2->params.scale * output->params.scale));
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
&data->output_shift);
data->input1_zero_point = input1->params.zero_point;
data->input2_zero_point = input2->params.zero_point;
data->output_zero_point = output->params.zero_point;
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataDiv));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input1 =
micro_context->AllocateTempInputTensor(node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* input2 =
micro_context->AllocateTempInputTensor(node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
OpDataDiv* data = static_cast<OpDataDiv*>(node->user_data);
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
TF_LITE_ENSURE_STATUS(
CalculateOpDataDiv(context, input1, input2, output, params, data));
micro_context->DeallocateTempTfLiteTensor(input1);
micro_context->DeallocateTempTfLiteTensor(input2);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
const OpDataDiv* data, const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
tflite::ArithmeticParams op_params = {};
#define TF_LITE_DIV(type, opname, data_type) \
data_type output_activation_min, output_activation_max; \
CalculateActivationRange(params->activation, &output_activation_min, \
&output_activation_max); \
SetActivationParams(output_activation_min, output_activation_max, \
&op_params); \
type::opname(op_params, tflite::micro::GetTensorShape(input1), \
tflite::micro::GetTensorData<data_type>(input1), \
tflite::micro::GetTensorShape(input2), \
tflite::micro::GetTensorData<data_type>(input2), \
tflite::micro::GetTensorShape(output), \
tflite::micro::GetTensorData<data_type>(output))
bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
if (requires_broadcast) {
TF_LITE_DIV(reference_ops, BroadcastDivSlow, float);
} else {
TF_LITE_DIV(reference_ops, Div, float);
}
#undef TF_LITE_DIV
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteDivParams* params, const OpDataDiv* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
tflite::ArithmeticParams op_params = {};
#define TF_LITE_DIV(type, opname, dtype) \
type::opname(op_params, tflite::micro::GetTensorShape(input1), \
tflite::micro::GetTensorData<dtype>(input1), \
tflite::micro::GetTensorShape(input2), \
tflite::micro::GetTensorData<dtype>(input2), \
tflite::micro::GetTensorShape(output), \
tflite::micro::GetTensorData<dtype>(output))
if (input1->type == kTfLiteInt8 && input2->type == kTfLiteInt8 &&
output->type == kTfLiteInt8) {
SetActivationParams(data->output_activation_min,
data->output_activation_max, &op_params);
op_params.input1_offset = -data->input1_zero_point;
op_params.input2_offset = -data->input2_zero_point;
op_params.output_offset = data->output_zero_point;
op_params.output_multiplier = data->output_multiplier;
op_params.output_shift = data->output_shift;
bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
if (requires_broadcast) {
TF_LITE_DIV(reference_ops, BroadcastDivSlow, int8_t);
} else {
TF_LITE_DIV(reference_ops, Div, int8_t);
}
#undef TF_LITE_DIV
} else {
TF_LITE_KERNEL_LOG(
context, "Unsupported combination of input and output types in DIV.");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
auto* params = static_cast<TfLiteDivParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
auto* data = static_cast<OpDataDiv*>(node->user_data);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
EvalDiv(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteInt8) {
TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
input1, input2, output));
} else {
TF_LITE_KERNEL_LOG(context,
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_DIV() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
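CalculateOpDataDiv folds the three tensor scales into a single value, real_multiplier = S1 / (S2 * So), which QuantizeMultiplier then splits into a fixed-point multiplier and a shift. A worked example with illustrative scales:

// Illustrative decomposition of the divide rescale factor:
void DivMultiplierSketch() {
  const double s1 = 0.02, s2 = 0.01, so = 0.1;    // illustrative scales
  const double real_multiplier = s1 / (s2 * so);  // == 20.0
  // QuantizeMultiplier(20.0, &multiplier, &shift) produces
  //   multiplier = 1342177280 (== 0.625 * 2^31) and shift = 5,
  // because 20.0 == 0.625 * 2^5.
  (void)real_multiplier;
}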

View File

@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@ limitations under the License.
#include <cmath>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -27,6 +29,22 @@ namespace micro {
namespace elementwise {
namespace {
constexpr int kAbsNameId = 0;
constexpr int kRsrqtNameId = 1;
const int kElementwiseInputTensor = 0;
const int kElementwiseOutputTensor = 0;
struct OpDataAbsRsqrt {
int32_t multiplier;
int shift;
int input_offset;
int output_offset;
bool needs_rescale;
TfLiteQuantizationType input_quantization_type;
TfLiteType input_type;
};
bool IsNumericSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32;
}
@@ -35,16 +53,40 @@ bool IsLogicalSupportedType(const TfLiteType type) {
return type == kTfLiteBool;
}
bool IsAbsSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
}
bool IsRsqrtSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8;
}
inline void SetAbsOutputMultiplier(const float input_scale,
const float output_scale,
int32_t* multiplier, int* shift) {
QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
multiplier, shift);
}
inline void SetRsqrtOutputMultiplier(const float input_scale,
const float output_scale,
int32_t* multiplier, int* shift) {
const double scale =
1. / static_cast<double>((std::sqrt(input_scale) * output_scale));
QuantizeMultiplier(scale, multiplier, shift);
}
typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
@@ -58,9 +100,79 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType, const int op_nameid>
TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
op_data->input_type = input->type;
// For int16 type input, we support both quantized and non-quantized
// evaluation.
if (op_nameid == kAbsNameId) {
op_data->input_quantization_type = input->quantization.type;
}
if (input->type == kTfLiteInt8 ||
(input->type == kTfLiteInt16 &&
input->quantization.type != kTfLiteNoQuantization)) {
TF_LITE_ENSURE_EQ(context, input->quantization.type,
kTfLiteAffineQuantization);
TF_LITE_ENSURE_EQ(context, output->quantization.type,
kTfLiteAffineQuantization);
const auto* input_params =
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
output->quantization.params);
TF_LITE_ENSURE(context, input_params != nullptr);
TF_LITE_ENSURE(context, input_params->scale != nullptr);
TF_LITE_ENSURE(context, input_params->scale->size > 0);
TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
TF_LITE_ENSURE(context, output_params != nullptr);
TF_LITE_ENSURE(context, output_params->scale != nullptr);
TF_LITE_ENSURE(context, output_params->scale->size > 0);
TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
op_data->input_offset = input_params->zero_point->data[0];
op_data->output_offset = output_params->zero_point->data[0];
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
}
const float input_scale = input_params->scale->data[0];
const float output_scale = output_params->scale->data[0];
op_data->needs_rescale = input_scale != output_scale;
if (op_nameid == kAbsNameId && op_data->needs_rescale) {
SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
&op_data->shift);
} else if (op_nameid == kRsrqtNameId) {
SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
&op_data->shift);
}
}
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteType expected_type) {
inline TfLiteStatus EvalImplQuantized(
TfLiteContext* context, TfLiteNode* node,
T func(TfLiteContext*, TfLiteNode*, T),
TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
TfLiteType expected_type) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
@@ -68,6 +180,34 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
const T* in_data = tflite::micro::GetTensorData<T>(input);
T* out_data = tflite::micro::GetTensorData<T>(output);
for (size_t i = 0; i < num_elements; ++i) {
if (validate_input_func) {
TF_LITE_ENSURE_OK(context,
validate_input_func(context, node, in_data[i]));
}
out_data[i] = func(context, node, in_data[i]);
}
return kTfLiteOk;
}
template <typename T>
inline T AbsHelper(T i) {
return std::abs(i);
}
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteStatus validate_input_func(T),
TfLiteType expected_type) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const size_t num_elements = ElementCount(*input->dims);
const T* in_data = tflite::micro::GetTensorData<T>(input);
T* out_data = tflite::micro::GetTensorData<T>(output);
for (size_t i = 0; i < num_elements; ++i) {
if (validate_input_func) {
TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
}
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
@@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
float float_func(float)) {
return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
return EvalImpl<float>(context, node, float_func,
/*validate_input_func=*/nullptr, kTfLiteFloat32);
}
inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
bool bool_func(bool)) {
return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
return EvalImpl<bool>(context, node, bool_func,
/*validate_input_func=*/nullptr, kTfLiteBool);
}
void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
}
template <typename T>
inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
const int kMin = std::numeric_limits<T>::min();
const int kMax = std::numeric_limits<T>::max();
const int32_t value = std::abs(i - op_data->input_offset);
if (!op_data->needs_rescale) {
return static_cast<T>(
std::min(std::max(static_cast<long int>(value + op_data->output_offset),
static_cast<long int>(kMin)),
static_cast<long int>(kMax)));
}
const int32_t output = tflite::MultiplyByQuantizedMultiplier(
value, op_data->multiplier, op_data->shift) +
op_data->output_offset;
return static_cast<T>(std::min(
std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
static_cast<long int>(kMax)));
}
template <typename T>
inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
const int kMin = std::numeric_limits<T>::min();
const int kMax = std::numeric_limits<T>::max();
const int32_t value = (i - op_data->input_offset);
const int32_t kShift = 20; // Shift to keep value integer.
if (value == 0) {
// Assume that any value close to 0 represents the max output value.
return static_cast<T>(kMax);
}
int32_t inv_sqrt_multiplier;
int inv_sqrt_shift;
GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
&inv_sqrt_shift);
const int32_t data = tflite::MultiplyByQuantizedMultiplier(
static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
const int32_t output =
tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
op_data->shift - kShift) +
op_data->output_offset;
return static_cast<T>(std::min(
std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
static_cast<long int>(kMax)));
}
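// Reading of the fixed-point path above: GetInvSqrtQuantizedMultiplierExp
// approximates 1 / sqrt(value) as a quantized multiplier and exponent;
// kShift = 20 temporarily scales the intermediate up so it survives as an
// integer, and the final MultiplyByQuantizedMultiplier with
// (op_data->shift - kShift) folds that scaling back out while rescaling
// into the output quantization, before the output zero point is added.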
template <typename T>
TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
T i) {
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
"Rsqrt is only defined for positive values");
return kTfLiteOk;
}
TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, std::abs);
OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
TfLiteType type = op_data->input_type;
TfLiteQuantizationType input_quantization_type =
op_data->input_quantization_type;
TfLiteStatus eval_result;
switch (type) {
case kTfLiteFloat32:
eval_result = EvalNumeric(context, node, std::abs);
break;
case kTfLiteInt8:
eval_result =
EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
/*validate_input_func=*/nullptr, type);
break;
case kTfLiteInt16:
eval_result =
input_quantization_type == kTfLiteNoQuantization
? EvalImpl<int16_t>(context, node, AbsHelper,
/*validate_input_func=*/nullptr, type)
: EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
/*validate_input_func=*/nullptr,
type);
break;
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
break;
}
return eval_result;
}
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
@@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
TfLiteType type = op_data->input_type;
switch (type) {
case kTfLiteFloat32:
return EvalImpl<float>(
context, node, [](float f) { return 1.f / std::sqrt(f); },
/*validate_input_func=*/nullptr, type);
case kTfLiteInt8:
return EvalImplQuantized<int8_t>(context, node,
elementwise::RsqrtEvalQuantized,
elementwise::validate_input_func, type);
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
@@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace elementwise
TfLiteRegistration Register_ABS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::AbsEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
elementwise::ElementWiseAbsRsqrtInit,
elementwise::PrepareAbsRsqrt<elementwise::IsAbsSupportedType,
elementwise::kAbsNameId>,
elementwise::AbsEval);
}
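// [Editorial sketch, not part of the commit] The tflite::micro::RegisterOp
// helper used throughout this commit presumably just packs the three
// callbacks into a TfLiteRegistration and defaults the fields each kernel
// previously spelled out by hand, roughly:
//
//   TfLiteRegistration RegisterOp(
//       void* (*init)(TfLiteContext*, const char*, size_t),
//       TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*),
//       TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*)) {
//     return {/*init=*/init, /*free=*/nullptr, /*prepare=*/prepare,
//             /*invoke=*/invoke, /*profiling_string=*/nullptr,
//             /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0};
//   }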
TfLiteRegistration Register_SIN() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::SinEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SinEval);
}
TfLiteRegistration Register_COS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::CosEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::CosEval);
}
TfLiteRegistration Register_LOG() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::LogEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval);
}
TfLiteRegistration Register_SQRT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::SqrtEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SqrtEval);
}
TfLiteRegistration Register_RSQRT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::RsqrtEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
elementwise::ElementWiseAbsRsqrtInit,
elementwise::PrepareAbsRsqrt<elementwise::IsRsqrtSupportedType,
elementwise::kRsrqtNameId>,
elementwise::RsqrtEval);
}
TfLiteRegistration Register_SQUARE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
/*invoke=*/elementwise::SquareEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SquareEval);
}
TfLiteRegistration Register_LOGICAL_NOT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
/*invoke=*/elementwise::LogicalNotEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(
nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
elementwise::LogicalNotEval);
}
} // namespace micro
} // namespace ops
} // namespace tflite
} // namespace tflite
View File
@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_ELU() {
return {/*init=*/EluInit,
/*free=*/nullptr,
/*prepare=*/EluPrepare,
/*invoke=*/EluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
}
} // namespace tflite
View File
@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_ADD() {
return {/*init=*/AddInit,
/*free=*/nullptr,
/*prepare=*/AddPrepare,
/*invoke=*/AddEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}
} // namespace tflite
View File
@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if ESP_NN
if (input->type == kTfLiteInt8) {
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input->dims->data[3], 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output->dims->data[3], 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
conv_params_t conv_params = {
.in_offset = 0, .out_offset = 0,
.stride = {params.stride_width, params.stride_height},
.padding = {data->op_data.padding.width, data->op_data.padding.height},
.dilation = {0, 0}, .activation = {-128, 127}
};
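// [Editorial note] The bare trailing values in the struct initializers
// above (", 1" and ", 0, 0") fill the members that follow the last
// designated one positionally; mixing designated and positional
// initializers this way relies on compiler extensions in C++, and which
// fields the bare values land in depends on esp-nn's struct layout.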
int scratch_buf_size = esp_nn_get_conv_scratch_size(
input_width, input_height, input->dims->data[3],
output->dims->data[3], filter_width, filter_height);
&input_dims, &filter_dims, &output_dims, &conv_params);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
const int input_size = input_width * input_height * input_depth;
const int output_size = output_width * output_height * output_depth;
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input_depth, 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output_depth, 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
conv_params_t conv_params = {
.in_offset = input_offset, .out_offset = output_offset,
.stride = {stride_width, stride_height},
.padding = {pad_width, pad_height},
.dilation = {0, 0},
.activation = {activation_min, activation_max}
};
quant_data_t quant_data = {
.shift = data.op_data.per_channel_output_shift,
.mult = data.op_data.per_channel_output_multiplier
};
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
esp_nn_conv_s8(input_data + i_batch * input_size,
input_width, input_height, input_depth, input_offset,
pad_width, pad_height, stride_width, stride_height,
tflite::micro::GetTensorData<int8_t>(filter),
filter_width, filter_height,
esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
&filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorData<int32_t>(bias),
output_data + i_batch * output_size,
output_width, output_height, output_depth, output_offset,
data.op_data.per_channel_output_shift,
data.op_data.per_channel_output_multiplier,
activation_min, activation_max);
&output_dims, output_data + i_batch * output_size,
&conv_params, &quant_data);
}
} else {
reference_integer_ops::ConvPerChannel(
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
conv_total_time += esp_timer_get_time() - start_time;
long long time_this_instance = esp_timer_get_time() - start_time;
conv_total_time += time_this_instance;
//printf("time this instance: %llu\n", time_this_instance / 1000);
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
if (data.buffer_idx > -1) {
scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
}
esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input_depth, 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output_depth, 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
dw_conv_params_t conv_params = {
.in_offset = input_offset, .out_offset = output_offset,
.ch_mult = depth_multiplier,
.stride = {stride_width, stride_height},
.padding = {pad_width, pad_height}, .dilation = {0, 0},
.activation = {activation_min, activation_max}
};
quant_data_t quant_data = {
.shift = data.op_data.per_channel_output_shift,
.mult = data.op_data.per_channel_output_multiplier
};
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
input_height, input_depth, input_offset,
pad_width, pad_height,
stride_width, stride_height, depth_multiplier,
tflite::micro::GetTensorData<int8_t>(filter),
filter_width, filter_height,
esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
&filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorData<int32_t>(bias),
output_data + i_batch * output_size,
output_width, output_height, output_offset,
data.op_data.per_channel_output_shift,
data.op_data.per_channel_output_multiplier,
activation_min, activation_max);
&output_dims, output_data + i_batch * output_size,
&conv_params, &quant_data);
}
} else {
reference_integer_ops::DepthwiseConvPerChannel(
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if ESP_NN
if (input->type == kTfLiteInt8) {
data_dims_t input_dims = {
.width = input_width, .height = input_height,
.channels = input->dims->data[3], 1
};
data_dims_t output_dims = {
.width = output_width, .height = output_height,
.channels = output->dims->data[3], 1
};
data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
dw_conv_params_t conv_params = {
.in_offset = 0, .out_offset = 0,
.ch_mult = params.depth_multiplier,
.stride = {params.stride_width, params.stride_height},
.padding = {data->op_data.padding.width, data->op_data.padding.height},
.dilation = {0, 0}, .activation = {-128, 127}
};
int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
input_width, input_height, input->dims->data[3],
params.depth_multiplier, filter_width, filter_height);
&input_dims, &filter_dims, &output_dims, &conv_params);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
dc_total_time += esp_timer_get_time() - start_time;
long long time_this_instance = esp_timer_get_time() - start_time;
dc_total_time += time_this_instance;
// printf("time this instance: %llu\n", time_this_instance / 1000);
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_MUL() {
return {/*init=*/MulInit,
/*free=*/nullptr,
/*prepare=*/MulPrepare,
/*invoke=*/MulEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
}
} // namespace tflite
View File
@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} // namespace
TfLiteRegistration Register_AVERAGE_POOL_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PoolingPrepare,
/*invoke=*/AverageEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
}
TfLiteRegistration Register_MAX_POOL_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/PoolingPrepare,
/*invoke=*/MaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
}
} // namespace tflite
View File
@@ -0,0 +1,208 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/softmax.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "freertos/FreeRTOS.h"
#include <esp_timer.h>
#if ESP_NN
#include <esp_nn.h>
#endif
long long softmax_total_time = 0;
namespace tflite {
namespace {
// Softmax parameter data that persists in user_data
const int kInt16LUTArraySize = 513;
struct NodeData {
SoftmaxParams op_data;
#if ESP_NN
int buffer_idx;
#endif
};
static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(NodeData));
}
void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
TfLiteEvalTensor* output, const NodeData* data) {
if (input->type == kTfLiteInt8) {
if (output->type == kTfLiteInt16) {
tflite::reference_ops::Softmax(
data->op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
} else {
#if ESP_NN
const int32_t input_beta_multiplier = data->op_data.input_multiplier;
const int32_t input_beta_left_shift = data->op_data.input_left_shift;
const int diff_min = data->op_data.diff_min;
const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
void *scratch_buf = NULL;
if (data->buffer_idx > -1) {
scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
}
esp_nn_set_softmax_scratch_buf(scratch_buf);
esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
input_beta_left_shift, diff_min, out_ptr);
#else
tflite::reference_ops::Softmax(
data->op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
}
} else {
tflite::reference_ops::SoftmaxInt16(
data->op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
}
}
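// [Editorial note] The ESP-NN path flattens the tensor into a 2-D view of
// outer_size rows by depth columns before calling esp_nn_softmax_s8: for
// an input of shape [2, 1, 1, 10] the trailing dimension gives depth = 10
// and outer_size = 2, i.e. one softmax over each row of 10 logits.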
static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TFLITE_DCHECK(node->user_data != nullptr);
NodeData data = *static_cast<NodeData*>(node->user_data);
long long start_time = esp_timer_get_time();
switch (input->type) {
case kTfLiteFloat32: {
tflite::reference_ops::Softmax(
data.op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
}
break;
case kTfLiteInt8:
case kTfLiteInt16: {
SoftmaxQuantized(context, input, output, &data);
}
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
softmax_total_time += esp_timer_get_time() - start_time;
return kTfLiteOk;
}
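// [Editorial note] esp_timer_get_time() returns microseconds, so
// softmax_total_time accumulates the total time spent in this kernel
// across all invocations, mirroring the conv_total_time / dc_total_time
// counters this commit adds to the ESP-NN conv kernels; the project
// presumably reads these globals elsewhere for profiling.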
static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, node->user_data != nullptr);
NodeData* data = static_cast<NodeData*>(node->user_data);
// Only allocate LUTs for kTfLiteInt16 data type
if (input->type == kTfLiteInt16) {
void* raw_exp_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
context, sizeof(int16_t) * kInt16LUTArraySize);
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
data->op_data.one_over_one_plus_x_lut =
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
}
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
}
// Populate LUT if required
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// The exp LUT is only used for negative values; we consider exp(-10.0)
// insignificant to the accumulation.
gen_lut<float, int16_t, int16_t>(
[](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
data->op_data.exp_lut);
gen_lut<float, int16_t, int16_t>(
[](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
1.0f, data->op_data.one_over_one_plus_x_lut);
data->op_data.zero_point = output->params.zero_point;
data->op_data.scale = output->params.scale;
}
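// [Editorial note] Each table holds kInt16LUTArraySize = 513 entries (512
// intervals plus the endpoint). gen_lut samples exp() over [-10, 0],
// where exp(-10) ~ 4.5e-5 is treated as negligible per the comment above,
// and 1/(1+x) over [0, 1]; both outputs lie in [-1, 1] and are presumably
// stored as Q15 fixed point in the int16 tables.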
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
auto ret_val =
CalculateSoftmaxParams(context, input, output, params, &data->op_data);
#if ESP_NN
if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
const int32_t input_width = input->dims->data[1];
const int32_t input_height = input->dims->data[2];
int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
input_height);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
}
}
#endif
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return ret_val;
}
} // namespace
TfLiteRegistration Register_SOFTMAX() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_EXP() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
View File
@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_EXPAND_DIMS() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
View File
@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FILL() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
View File
@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace floor
TfLiteRegistration Register_FLOOR() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/floor::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
}
} // namespace micro
View File
@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FLOOR_DIV() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FLOOR_MOD() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
const int64_t* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
: nullptr;
tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias), bias_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
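// [Editorial note] Unlike the int8 case above, the int16 path reads the
// optional bias as int64: the reference kernel accumulates 16-bit
// activations against 8-bit weights in a wider accumulator, so the bias
// type follows the accumulator width.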
default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
@@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
View File
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
}
#endif
#if defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int16.
TfLiteRegistration Register_FULLY_CONNECTED_INT16();
#else
// Note that while this block gets used for both reference and optimized kernels
// that do not have any specialized implementations, the only goal here is to
// define fallback implementations that allow reference kernels to still be used
// from applications that call a more specific kernel variant.
inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
return Register_FULLY_CONNECTED();
}
#endif
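// [Editorial sketch] With this fallback in place an application can
// resolve the int16-only variant unconditionally and still get the
// general reference kernel on builds without CMSIS-NN, e.g.:
//
//   TfLiteRegistration reg = tflite::Register_FULLY_CONNECTED_INT16();
//   // On CMSIS-NN builds: the int16-specialized kernel.
//   // Elsewhere: simply Register_FULLY_CONNECTED().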
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
View File
@@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_GATHER() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
View File
@@ -131,7 +131,8 @@ TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
slice_size *= params->dims->data[i];
}
int remain_flat_size = ElementCount(*params->dims);
int params_flat_size = ElementCount(*params->dims);
int remain_flat_size = params_flat_size;
// Number of elements per dimension
int dims_to_count[MAX_INDICES_ND];
@@ -147,6 +148,9 @@ TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
IndicesT index = index_data[offset];
from_pos += index * dims_to_count[j];
}
if (from_pos < 0 || from_pos + slice_size > params_flat_size) {
return kTfLiteError;
}
std::memcpy(output_data + i * slice_size, param_data + from_pos,
sizeof(ParamsT) * slice_size);
}
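// [Editorial example] For params of shape [4, 2] (params_flat_size = 8,
// slice_size = 2), an index value of 4 yields from_pos = 8, so
// from_pos + slice_size > params_flat_size and the new guard returns
// kTfLiteError instead of memcpy-ing past the end of the buffer.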
@@ -158,12 +162,13 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
const TfLiteEvalTensor* params,
const TfLiteEvalTensor* indices,
TfLiteEvalTensor* output) {
TfLiteStatus status = kTfLiteError;
switch (params->type) {
case kTfLiteFloat32:
return GatherNd<float, IndicesT>(params, indices, output);
status = GatherNd<float, IndicesT>(params, indices, output);
break;
case kTfLiteInt8:
return GatherNd<int8_t, IndicesT>(params, indices, output);
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
@@ -171,6 +176,10 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
}
return status;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -195,14 +204,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_GATHER_ND() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
View File
@@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_HARD_SWISH() {
return {/*init=*/HardSwishInit,
/*free=*/nullptr,
/*prepare=*/tflite::HardSwishPrepare,
/*invoke=*/HardSwishEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare,
HardSwishEval);
}
} // namespace tflite
View File
@@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_IF() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
Some files were not shown because too many files have changed in this diff.