Mirror of https://github.com/jomjol/AI-on-the-edge-device.git (synced 2025-12-09 21:17:06 +03:00)

Commit: Rolling 20220924
@@ -51,18 +51,28 @@ set(lib_srcs
"${tflite_dir}/kernels/internal/quantization_util.cc"
"${tflite_dir}/schema/schema_utils.cc")

set(priv_req esp-nn)

# include component requirements which were introduced after IDF version 4.1
if("${IDF_VERSION_MAJOR}.${IDF_VERSION_MINOR}" VERSION_GREATER "4.1")
list(APPEND priv_req esp_timer driver)
endif()

idf_component_register(
SRCS "${lib_srcs}"
INCLUDE_DIRS "." "third_party/gemmlowp"
"third_party/flatbuffers/include"
"third_party/ruy"
"third_party/kissfft"
REQUIRES "esp-nn")
REQUIRES ${pub_req}
PRIV_REQUIRES ${priv_req})

# Reduce the level of paranoia to be able to compile TF sources
target_compile_options(${COMPONENT_LIB} PRIVATE
-Wno-maybe-uninitialized
-Wno-missing-field-initializers
-Wno-error=sign-compare
-Wno-error=double-promotion
-DESP_NN # enables ESP-NN optimizations by Espressif
-Wno-type-limits)
@@ -185,6 +185,7 @@ typedef enum {
kTfLiteBuiltinUnsortedSegmentSum = 155,
kTfLiteBuiltinAtan2 = 156,
kTfLiteBuiltinUnsortedSegmentMin = 157,
kTfLiteBuiltinSign = 158,
} TfLiteBuiltinOperator;

#ifdef __cplusplus
@@ -283,4 +283,25 @@ const char* TfLiteTypeGetName(TfLiteType type) {

TfLiteDelegate TfLiteDelegateCreate() { return TfLiteDelegate{}; }

struct TfLiteOpaqueDelegateStruct* TfLiteOpaqueDelegateCreate(
const TfLiteOpaqueDelegateBuilder* opaque_delegate_builder) {
if (!opaque_delegate_builder) return nullptr;

TfLiteDelegate* result = new TfLiteDelegate{};
result->opaque_delegate_builder = new TfLiteOpaqueDelegateBuilder{};
*(result->opaque_delegate_builder) = *opaque_delegate_builder;

return reinterpret_cast<struct TfLiteOpaqueDelegateStruct*>(result);
}

void TfLiteOpaqueDelegateDelete(
const struct TfLiteOpaqueDelegateStruct* opaque_delegate) {
if (!opaque_delegate) return;

const TfLiteDelegate* tflite_delegate =
reinterpret_cast<const TfLiteDelegate*>(opaque_delegate);
delete tflite_delegate->opaque_delegate_builder;
delete tflite_delegate;
}

} // extern "C"
@@ -63,6 +63,8 @@ typedef enum TfLiteExternalContextType {
struct TfLiteContext;
struct TfLiteDelegate;
struct TfLiteRegistration;
struct TfLiteOpaqueDelegateStruct;
struct TfLiteOpaqueDelegateBuilder;

// An external context is a collection of information unrelated to the TF Lite
// framework, but useful to a subset of the ops. TF Lite knows very little
@@ -973,7 +975,7 @@ typedef enum TfLiteDelegateFlags {
typedef struct TfLiteDelegate {
// Data that delegate needs to identify itself. This data is owned by the
// delegate. The delegate is owned in the user code, so the delegate is
// responsible for doing this when it is destroyed.
// responsible for deallocating this when it is destroyed.
void* data_;

// Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
@@ -1010,12 +1012,83 @@ typedef struct TfLiteDelegate {

// Bitmask flags. See the comments in `TfLiteDelegateFlags`.
int64_t flags;

// The opaque delegate builder associated with this object. If set then the
// TF Lite runtime will give precedence to this field. E.g. instead of
// invoking 'Prepare' via the function pointer inside the 'TfLiteDelegate'
// object, the runtime will first check if the corresponding function
// pointer inside 'opaque_delegate_builder' is set and if so invoke that.
//
// If this field is non-null, then the 'Prepare' field (of the
// 'TfLiteDelegate') should be null.
struct TfLiteOpaqueDelegateBuilder* opaque_delegate_builder;
} TfLiteDelegate;

// Build a 'null' delegate, with all the fields properly set to their default
// values.
TfLiteDelegate TfLiteDelegateCreate(void);

// `TfLiteOpaqueDelegateBuilder` is used for constructing
// `TfLiteOpaqueDelegateStruct`, see `TfLiteOpaqueDelegateCreate` below. Note:
// This struct is not ABI stable.
//
// For forward source compatibility `TfLiteOpaqueDelegateBuilder` objects should
// be brace-initialized, so that all fields (including any that might be added
// in the future) get zero-initialized. The purpose of each field is exactly
// the same as with `TfLiteDelegate`.
//
// WARNING: This is an experimental interface that is subject to change.
typedef struct TfLiteOpaqueDelegateBuilder {
// Data that delegate needs to identify itself. This data is owned by the
// delegate. The delegate is owned in the user code, so the delegate is
// responsible for deallocating this when it is destroyed.
void* data;
// Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
// delegate a view of the current graph through TfLiteContext*. It typically
// will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
// to ask the TensorFlow lite runtime to create macro-nodes to represent
// delegated subgraphs of the original graph.
TfLiteStatus (*Prepare)(TfLiteOpaqueContext* context, // NOLINT
struct TfLiteOpaqueDelegateStruct* delegate,
void* data);
// Copies the data from delegate buffer handle into raw memory of the given
// 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
// long as it follows the rules for kTfLiteDynamic tensors, in which case this
// cannot be null.
TfLiteStatus (*CopyFromBufferHandle)( // NOLINT
TfLiteOpaqueContext* context, struct TfLiteOpaqueDelegateStruct* delegate,
void* data, TfLiteBufferHandle buffer_handle, TfLiteOpaqueTensor* tensor);
// Copies the data from raw memory of the given 'tensor' to delegate buffer
// handle. This can be null if the delegate doesn't use its own buffer.
TfLiteStatus (*CopyToBufferHandle)( // NOLINT
TfLiteOpaqueContext* context, struct TfLiteOpaqueDelegateStruct* delegate,
void* data, TfLiteBufferHandle buffer_handle, TfLiteOpaqueTensor* tensor);
// Frees the Delegate Buffer Handle. Note: This only frees the handle, but
// this doesn't release the underlying resource (e.g. textures). The
// resources are either owned by application layer or the delegate.
// This can be null if the delegate doesn't use its own buffer.
void (*FreeBufferHandle)(TfLiteOpaqueContext* context, // NOLINT
struct TfLiteOpaqueDelegateStruct* delegate,
void* data, TfLiteBufferHandle* handle);
// Bitmask flags. See the comments in `TfLiteDelegateFlags`.
int64_t flags;
} TfLiteOpaqueDelegateBuilder;

// Creates an opaque delegate and returns its address. The opaque delegate will
// behave according to the provided 'opaque_delegate_builder'. The lifetime of
// the fields within the 'opaque_delegate_builder' must outlive any interaction
// between the runtime and the returned 'TfLiteOpaqueDelegateStruct'. The
// returned address should be passed to 'TfLiteOpaqueDelegateDelete' for
// deletion. If 'opaque_delegate_builder' is a null pointer, then a null
// pointer will be returned.
struct TfLiteOpaqueDelegateStruct* TfLiteOpaqueDelegateCreate(
const TfLiteOpaqueDelegateBuilder* opaque_delegate_builder);

// Deletes the provided opaque 'delegate'. This function has no effect if the
// 'delegate' is a null pointer.
void TfLiteOpaqueDelegateDelete(
const struct TfLiteOpaqueDelegateStruct* delegate);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
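A minimal usage sketch of the opaque-delegate API declared above (not part of the commit; the kMyDelegateState variable and the empty Prepare body are illustrative placeholders). The builder is brace-initialized, as the comment above recommends, so any unset fields stay zero:

static TfLiteStatus MyDelegatePrepare(TfLiteOpaqueContext* context,
                                      struct TfLiteOpaqueDelegateStruct* delegate,
                                      void* data) {
  // A real delegate would inspect the graph here and claim nodes.
  return kTfLiteOk;
}

static int kMyDelegateState = 0;  // illustrative per-delegate data

void OpaqueDelegateLifetimeSketch() {
  // Brace-initialize so fields added in future versions stay zeroed.
  TfLiteOpaqueDelegateBuilder builder{};
  builder.data = &kMyDelegateState;
  builder.Prepare = MyDelegatePrepare;

  struct TfLiteOpaqueDelegateStruct* delegate =
      TfLiteOpaqueDelegateCreate(&builder);  // copies the builder
  // ... hand 'delegate' to the runtime ...
  TfLiteOpaqueDelegateDelete(delegate);  // releases the copy made by Create
}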
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This provides a few C++ helpers that are useful for manipulating C structures
// in C++.
/// \file
/// This provides a few C++ helpers that are useful for manipulating C
/// structures in C++.
#ifndef TENSORFLOW_LITE_CONTEXT_UTIL_H_
#define TENSORFLOW_LITE_CONTEXT_UTIL_H_

@@ -23,13 +24,14 @@ limitations under the License.

namespace tflite {

// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite
// C api uses. Can't use the google array_view, since we can't depend on even
/// Provides a range iterable wrapper for TfLiteIntArray* (C lists) that TfLite
/// C api uses.
// Can't use the google array_view, since we can't depend on even
// absl for embedded device reasons.
class TfLiteIntArrayView {
public:
// Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null
// and this view does not take ownership of it.
/// Construct a view of a TfLiteIntArray*. Note, `int_array` should be
/// non-null and this view does not take ownership of it.
explicit TfLiteIntArrayView(const TfLiteIntArray* int_array)
: int_array_(int_array) {}
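A small illustration of the wrapper documented above (not from this commit); it assumes TfLiteIntArrayView exposes begin()/end(), which is what "range iterable" implies, so it can drive a range-based for loop:

int SumDims(const TfLiteIntArray* dims) {
  int total = 0;
  // The view does not own 'dims'; 'dims' must stay alive for the loop.
  for (int d : tflite::TfLiteIntArrayView(dims)) {
    total += d;
  }
  return total;
}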
@@ -457,6 +457,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseRsqrt(op, error_reporter, allocator, builtin_data);
}

case BuiltinOperator_SELECT_V2: {
return ParseSelectV2(op, error_reporter, allocator, builtin_data);
}

case BuiltinOperator_SHAPE: {
return ParseShape(op, error_reporter, allocator, builtin_data);
}
@@ -865,7 +869,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_RELU_0_TO_1:
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_SELECT:
case BuiltinOperator_SELECT_V2:
case BuiltinOperator_SLICE:
case BuiltinOperator_TILE:
case BuiltinOperator_TOPK_V2:
@@ -881,6 +884,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_UNSORTED_SEGMENT_PROD:
case BuiltinOperator_UNSORTED_SEGMENT_SUM:
case BuiltinOperator_ATAN2:
case BuiltinOperator_SIGN:
case BuiltinOperator_WHERE:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
@@ -1982,6 +1986,14 @@ TfLiteStatus ParseRsqrt(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
return kTfLiteOk;
}

// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseSelectV2(const Operator*, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}

TfLiteStatus ParseShape(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
SafeBuiltinDataAllocator safe_allocator(allocator);

@@ -319,6 +319,10 @@ TfLiteStatus ParseRound(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseRsqrt(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);

TfLiteStatus ParseSelectV2(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

TfLiteStatus ParseShape(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
@@ -1,3 +1,5 @@
#include <cstdint>

#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_common.h"

#define FIXED_POINT 16
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_HARD_SWISH_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_HARD_SWISH_H_

#include <algorithm>

@@ -165,4 +165,4 @@ inline void HardSwish(const HardSwishParams& params,
} // namespace reference_ops
} // namespace tflite

#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_HARD_SWISH_H_
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_

#include <algorithm>
#include <complex>

#include "tensorflow/lite/kernels/internal/common.h"

@@ -61,6 +62,20 @@ inline void Mul(const ArithmeticParams& params,
}
}

inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const std::complex<float>* input1_data,
const RuntimeShape& input2_shape,
const std::complex<float>* input2_data,
const RuntimeShape& output_shape,
std::complex<float>* output_data) {
const int flat_size =
MatchingExtendedShapeFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = input1_data[i] * input2_data[i];
}
}

inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const uint8_t* input1_data,
const RuntimeShape& input2_shape, const uint8_t* input2_data,
@@ -162,6 +177,37 @@ void BroadcastMul4DSlow(const ArithmeticParams& params,
}
}

inline void BroadcastMul4DSlow(const ArithmeticParams& params,
const RuntimeShape& unextended_input1_shape,
const std::complex<float>* input1_data,
const RuntimeShape& unextended_input2_shape,
const std::complex<float>* input2_data,
const RuntimeShape& unextended_output_shape,
std::complex<float>* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);

NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);

for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
output_data[Offset(output_shape, b, y, x, c)] =
input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
input2_data[SubscriptToIndex(desc2, b, y, x, c)];
}
}
}
}
}

} // namespace reference_ops
} // namespace tflite
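A short, hypothetical call of the complex-valued Mul overload added above (not part of the commit). The ArithmeticParams argument is only passed through by this overload, so a value-initialized instance is enough:

#include <complex>

void ComplexMulSketch() {
  const std::complex<float> a[4] = {{1, 0}, {0, 1}, {2, 2}, {3, -1}};
  const std::complex<float> b[4] = {{1, 0}, {0, -1}, {1, 1}, {2, 0}};
  std::complex<float> out[4];

  int32_t len = 4;
  tflite::RuntimeShape shape(1, &len);  // a single dimension of size 4
  tflite::ArithmeticParams params{};    // unused by the complex path

  // Element-wise product: out[i] = a[i] * b[i].
  tflite::reference_ops::Mul(params, shape, a, shape, b, shape, out);
}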
@@ -0,0 +1,151 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_

#include <cmath>

#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {

template <typename D, typename T>
void Select(const RuntimeShape& input_condition_shape,
const D* input_condition_data, const RuntimeShape& input_x_shape,
const T* input_x_data, const RuntimeShape& input_y_shape,
const T* input_y_data, const RuntimeShape& output_shape,
T* output_data) {
ruy::profiler::ScopeLabel label("Select");
int64_t flatsize;
// Allow select operator executions on mixed scalar tensors and one element
// tensors.
if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1) {
flatsize = 1;
} else {
flatsize = MatchingFlatSize(input_condition_shape, input_x_shape,
input_y_shape, output_shape);
}
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] =
input_condition_data[i] ? input_x_data[i] : input_y_data[i];
}
}

template <typename D, typename T>
void RankOneSelect(const RuntimeShape& input_condition_shape,
const D* input_condition_data,
const RuntimeShape& input_x_shape, const T* input_x_data,
const RuntimeShape& input_y_shape, const T* input_y_data,
const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("Select/RankOneSelect");
const int64_t outer_size = input_condition_shape.FlatSize();
int64_t inner_size;
if (input_condition_shape.DimensionsCount() == 0) {
inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
} else {
TFLITE_DCHECK_EQ(
MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
outer_size);
inner_size =
MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
}

int64_t offset = 0;
for (int64_t i = 0; i < outer_size; i++) {
const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
offset += inner_size;
}
}

template <typename D, typename T>
void BroadcastSelect5DSlow(const RuntimeShape& input_condition_shape,
const D* input_condition_data,
const RuntimeShape& input_x_shape,
const T* input_x_data,
const RuntimeShape& input_y_shape,
const T* input_y_data,
const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("Select/BroadcastSelectSlow");
TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 5);
TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 5);
TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 5);
TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 5);

NdArrayDesc<5> desc_condition;
NdArrayDesc<5> desc_x;
NdArrayDesc<5> desc_y;
NdArrayDesc<5> desc_output;
const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(5, output_shape);
CopyDimsToDesc(extended_output_shape, &desc_output);
NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
input_y_shape, &desc_condition, &desc_x,
&desc_y);

// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
// trailing dimension changing most rapidly (channels has the smallest
// stride, typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for
// the best cache behavior.
for (int n = 0; n < desc_output.extents[0]; ++n) {
int out_idx_n = desc_output.extents[1] * n;
int cond_idx_n = desc_condition.strides[0] * n;
int in_idx1_n = desc_x.strides[0] * n;
int in_idx2_n = desc_y.strides[0] * n;
for (int b = 0; b < desc_output.extents[1]; ++b) {
int out_idx_b = (out_idx_n + b) * desc_output.extents[2];
int cond_idx_b = cond_idx_n + desc_condition.strides[1] * b;
int in_idx1_b = in_idx1_n + desc_x.strides[1] * b;
int in_idx2_b = in_idx2_n + desc_y.strides[1] * b;
for (int y = 0; y < desc_output.extents[2]; ++y) {
int out_idx_y = (out_idx_b + y) * desc_output.extents[3];
int cond_idx_y = cond_idx_b + desc_condition.strides[2] * y;
int in_idx1_y = in_idx1_b + desc_x.strides[2] * y;
int in_idx2_y = in_idx2_b + desc_y.strides[2] * y;
for (int x = 0; x < desc_output.extents[3]; ++x) {
int out_idx = (out_idx_y + x) * desc_output.extents[4];
int cond_idx = cond_idx_y + desc_condition.strides[3] * x;
int in_idx1 = in_idx1_y + desc_x.strides[3] * x;
int in_idx2 = in_idx2_y + desc_y.strides[3] * x;
for (int c = 0; c < desc_output.extents[4]; ++c) {
output_data[out_idx] = input_condition_data[cond_idx]
? input_x_data[in_idx1]
: input_y_data[in_idx2];
out_idx++;
cond_idx += desc_condition.strides[4];
in_idx1 += desc_x.strides[4];
in_idx2 += desc_y.strides[4];
}
}
}
}
}

} // namespace reference_ops
} // namespace tflite

#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_
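A minimal, illustrative call of the new reference Select kernel (not from the commit); the condition, both inputs, and the output share one shape, so the flat element-wise path is taken:

void SelectSketch() {
  const bool cond[3] = {true, false, true};
  const float x[3] = {1.0f, 2.0f, 3.0f};
  const float y[3] = {10.0f, 20.0f, 30.0f};
  float out[3];

  int32_t len = 3;
  tflite::RuntimeShape shape(1, &len);

  tflite::reference_ops::Select(shape, cond, shape, x, shape, y, shape, out);
  // out is now {1.0f, 20.0f, 3.0f}: x where cond is true, y otherwise.
}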
@@ -92,6 +92,7 @@ AllOpsResolver::AllOpsResolver() {
AddResizeNearestNeighbor();
AddRound();
AddRsqrt();
AddSelectV2();
AddShape();
AddSin();
AddSlice();
@@ -102,6 +103,7 @@ AllOpsResolver::AllOpsResolver() {
AddSplitV();
AddSqrt();
AddSquare();
AddSquaredDifference();
AddSqueeze();
AddStridedSlice();
AddSub();
@@ -110,6 +112,7 @@ AllOpsResolver::AllOpsResolver() {
AddTanh();
AddTranspose();
AddTransposeConv();
AddUnidirectionalSequenceLSTM();
AddUnpack();
AddVarHandle();
AddWhile();
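AllOpsResolver registers everything; a deployment that only needs the newly wired-up op would normally use selective registration instead. A hedged sketch (it assumes MicroMutableOpResolver gains AddSelectV2 along with this change):

tflite::MicroMutableOpResolver<2> resolver;  // capacity for two ops
resolver.AddSelectV2();
resolver.AddShape();
// Pass 'resolver' to the MicroInterpreter as usual.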
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -59,6 +59,19 @@ TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params,

TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node);

// Generic must define registration function.
TfLiteRegistration Register_ADD();

#if defined(CMSIS_NN)
TfLiteRegistration Register_ADD_INT8();

TfLiteRegistration Register_ADD_INT16();
#else
// Fallback registration
inline TfLiteRegistration Register_ADD_INT8() { return Register_ADD(); }

inline TfLiteRegistration Register_ADD_INT16() { return Register_ADD(); }
#endif
} // namespace tflite

#endif // TENSORFLOW_LITE_MICRO_KERNELS_ADD_H_
@@ -121,8 +121,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
context, kTfLiteActNone, output, &data->output_activation_min,
&data->output_activation_max));
} else {
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

@@ -198,8 +198,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} else if (output->type == kTfLiteInt8) {
EvalAddNQuantized<int8_t>(context, node, output);
} else {
TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("ADD_N only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
@@ -70,21 +70,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf(
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context,
"Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
TF_LITE_KERNEL_LOG(context, "Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
return kTfLiteError;
}
@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

@@ -90,7 +90,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
@@ -118,8 +118,8 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -210,8 +210,8 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -288,8 +288,8 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -366,8 +366,8 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -444,8 +444,8 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -522,8 +522,8 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
output_data);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteInt8 ||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64);
input_type == kTfLiteInt64 || input_type == kTfLiteBool);

// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
@@ -149,8 +149,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int num_dimensions = NumDimensions(input);

if (num_dimensions > RuntimeShape::kMaxSmallSize) {
TF_LITE_KERNEL_LOG(
context,
MicroPrintf(
"Op Concatenation does not currently support num dimensions > %d "
"Tensor has %d dimensions.",
RuntimeShape::kMaxSmallSize, num_dimensions);
@@ -168,6 +167,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);

switch (output_type) { // Already know in/outtypes are same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt16:
case kTfLiteInt32:
@@ -205,9 +205,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(
context, "Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}

@@ -238,11 +237,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16:
EvalUnquantized<int16_t>(context, node);
break;
case kTfLiteBool:
EvalUnquantized<bool>(context, node);
break;

default:
TF_LITE_KERNEL_LOG(
context, "Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}
@@ -123,7 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis < 0) axis += input_shape.DimensionsCount();

if (axis < 0 || axis >= input_shape.DimensionsCount()) {
TF_LITE_KERNEL_LOG(context, "CUMSUM Invalid axis: %d", axis);
MicroPrintf("CUMSUM Invalid axis: %d", axis);
return kTfLiteError;
}

@@ -156,9 +156,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} break;

default: {
TF_LITE_KERNEL_LOG(context,
"CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}

@@ -124,9 +124,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(
context, "DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
MicroPrintf("DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
@@ -82,8 +82,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
return kTfLiteOk;

@@ -162,8 +162,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
#undef TF_LITE_DIV
} else {
TF_LITE_KERNEL_LOG(
context, "Unsupported combination of input and output types in DIV.");
MicroPrintf("Unsupported combination of input and output types in DIV.");
return kTfLiteError;
}

@@ -189,10 +188,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
input1, input2, output));
} else {
TF_LITE_KERNEL_LOG(context,
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
MicroPrintf(
"DIV only supports FLOAT32, quantized INT8 "
"now, got type %s (%d).",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}
@@ -90,8 +90,8 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}

@@ -112,8 +112,8 @@ TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}

@@ -317,8 +317,8 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
type);
break;
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
MicroPrintf("Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
break;
}
@@ -355,8 +355,8 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
elementwise::validate_input_func, type);

default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
MicroPrintf("Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}
@@ -426,4 +426,4 @@ TfLiteRegistration Register_LOGICAL_NOT() {

} // namespace micro
} // namespace ops
} // namespace tflite
} // namespace tflite
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"

namespace tflite {
namespace {
@@ -136,9 +135,8 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(
context, "ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

#include "freertos/FreeRTOS.h"
#include <esp_timer.h>

#if ESP_NN

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

#include "freertos/FreeRTOS.h"
#include <esp_timer.h>

#if ESP_NN

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

#include "freertos/FreeRTOS.h"
#include <esp_timer.h>

#if ESP_NN
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"

namespace tflite {
namespace {
@@ -63,8 +64,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
static_cast<size_t>(flat_size),
tflite::micro::GetTensorData<float>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) currently not supported by Exp.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) currently not supported by Exp.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return kTfLiteOk;

@@ -31,8 +31,7 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
int32_t* axis_value) {
const int axis_dims = (tflite::GetTensorShape(axis)).DimensionsCount();
if (axis_dims > 1) {
TF_LITE_KERNEL_LOG(context, "Axis has only one element for Expand_Dims.",
axis_dims);
MicroPrintf("Axis has only one element for Expand_Dims.", axis_dims);
return kTfLiteError;
}

@@ -41,9 +40,8 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
*axis_value = axis_ptr[0];
return kTfLiteOk;
} else {
TF_LITE_KERNEL_LOG(context,
"Axis type %s (%d) not supported by Expand_Dims.",
TfLiteTypeGetName(axis->type), axis->type);
MicroPrintf("Axis type %s (%d) not supported by Expand_Dims.",
TfLiteTypeGetName(axis->type), axis->type);
return kTfLiteError;
}
}
@@ -99,8 +97,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
output->type = input->type;
if (IsDynamicTensor(axis)) {
TF_LITE_KERNEL_LOG(context,
"DynamicTensor is not yet supported by Expand_Dims.");
MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
return kTfLiteError;
}
TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
@@ -135,8 +132,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input), flat_size);
} break;
default:
TF_LITE_KERNEL_LOG(
context,
MicroPrintf(
"Expand_Dims only currently supports int8 and float32, got %d.",
input->type);
return kTfLiteError;
@@ -53,9 +53,8 @@ TfLiteStatus EnsureEq(TfLiteContext* context, const TfLiteIntArray* array,
case kTfLiteInt64:
return EnsureEqImpl<int64_t>(context, array, tensor);
default:
TF_LITE_KERNEL_LOG(context,
"cannot compare int array to tensor of type %d.",
tensor->type);
MicroPrintf("cannot compare int array to tensor of type %d.",
tensor->type);
return kTfLiteError;
}
}
@@ -123,9 +122,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
FillImpl<int8_t>(value, output);
break;
default:
TF_LITE_KERNEL_LOG(
context, "Fill only currently supports float32 for input 1, got %d.",
TfLiteTypeGetName(value->type));
MicroPrintf("Fill only currently supports float32 for input 1, got %d.",
TfLiteTypeGetName(value->type));
return kTfLiteError;
}
@@ -74,7 +74,7 @@ TfLiteStatus EvalFloorDiv(TfLiteContext* context,
// Validate the denominator.
for (int i = 0; i < tflite::ElementCount(*input2->dims); ++i) {
if (std::equal_to<T>()(denominator_data[i], 0)) {
TF_LITE_KERNEL_LOG(context, "Division by 0");
MicroPrintf("Division by 0");
return kTfLiteError;
}
}
@@ -113,8 +113,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalFloorDiv<float>(context, input1, input2, output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_DIV.",
TfLiteTypeGetName(input1->type));
MicroPrintf("Type '%s' is not supported by FLOOR_DIV.",
TfLiteTypeGetName(input1->type));
return kTfLiteError;
}
}

@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by FLOOR_MOD.",
TfLiteTypeGetName(input1->type));
MicroPrintf("Type '%s' is not supported by FLOOR_MOD.",
TfLiteTypeGetName(input1->type));
return kTfLiteError;
}
}

@@ -141,8 +141,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}

default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
return kTfLiteError;
}
}
@@ -118,9 +118,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(coords->type));
MicroPrintf("Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(coords->type));
return kTfLiteError;
break;
}
@@ -134,8 +133,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}
@@ -207,8 +206,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<int8_t, int32_t>(params, input, coords, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
MicroPrintf("Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}
@@ -47,9 +47,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
MicroPrintf("Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
break;
}
@@ -57,9 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}

@@ -67,22 +65,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int indices_rank = NumDimensions(indices);
const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
if (params_rank < 1) {
TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
MicroPrintf("Params must be at least a vector.");
return kTfLiteError;
}
if (indices_rank < 1) {
TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
MicroPrintf("Indices must be at least a vector.");
return kTfLiteError;
}
if (indices_nd > params_rank) {
TF_LITE_KERNEL_LOG(
context, "Index innermost dimension length must be <= params rank.");
MicroPrintf("Index innermost dimension length must be <= params rank.");
return kTfLiteError;
}
if (indices_nd > MAX_INDICES_ND) {
TF_LITE_KERNEL_LOG(context,
"Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
MicroPrintf("Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
return kTfLiteError;
}

@@ -171,13 +167,12 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
MicroPrintf("Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
MicroPrintf("gather_nd index out of bounds");
}
return status;
}
@@ -195,9 +190,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalGatherNd<int32_t>(context, params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
}
@@ -106,5 +106,17 @@ TfLiteStatus KernelRunner::Invoke() {
return kTfLiteOk;
}

TfLiteStatus KernelRunner::Free() {
tflite::micro::ClearBufferApi(&context_);
context_.GetScratchBuffer = MicroContextGetScratchBuffer;

if (registration_.free == nullptr) {
MicroPrintf("TfLiteRegistration missing free function pointer!");
return kTfLiteError;
}

registration_.free(&context_, node_.user_data);
return kTfLiteOk;
}
} // namespace micro
} // namespace tflite
} // namespace tflite
@@ -48,6 +48,11 @@ class KernelRunner {
// passed into the constructor of this class.
TfLiteStatus Invoke();

// Calls Free on a given TfLiteRegistration pointer(if it's implemented).
// After successful Free, kTfLiteOk status will be returned. If Free is not
// implemented for a given kernel kTfLiteError will be returned.
TfLiteStatus Free();

// Returns a pointer to the internal MockMicroGraph which KernelRunner uses
// to stub out MicroGraph methods and track invocations on each subgraph.
MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; }
@@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"

namespace tflite {
namespace micro {
@@ -39,9 +40,10 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*free)(TfLiteContext* context, void* buffer)) {
return {/*init=*/init,
/*free=*/nullptr,
/*free=*/free,
/*prepare=*/prepare,
/*invoke=*/invoke,
/*profiling_string=*/nullptr,
@@ -160,6 +162,46 @@ TfLiteStatus CopyOpInputsToOpOutputs(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

// Args:
// 1. int8_t tensor_data - int8_t buffer of unknown size who's data you'd
// like
// to print
// 2. int n_btyes - a small int representing number of bytes you want to
// print
// to debug output. It should always be <= tensor_data's size.
// 3. prefix - optional message you'd like to print before printing bytes
//
// Purpose:
// Function takes in paramaters above and prints n_bytes bytes from the
// tensor_data buffer. This can be use to debug the output of a model and it's
// op.

void PrintNBytes(const int8_t* tensor_data, int n_bytes, const char* prefix) {
if (prefix != nullptr) {
MicroPrintf("%s", prefix);
}

for (int i = 0; i < n_bytes; ++i) {
MicroPrintf(" %x", tensor_data[i]);
}
MicroPrintf("\n");
}

// same as the PrintNBytes above but the buffer needs to be extracted out of the
// TfLiteEvalTensor*
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
const char* prefix) {
const int8_t* tensor_data = tflite::micro::GetTensorData<int8_t>(tensor);
PrintNBytes(tensor_data, n_bytes, prefix);
}

// same as the PrintNBytes above but the buffer needs to be extracted out of the
// TfLiteEvalTensor*
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes, const char* prefix) {
const int8_t* tensor_data = tflite::GetTensorData<int8_t>(tensor);
PrintNBytes(tensor_data, n_bytes, prefix);
}

TfLiteStatus CopyOpInputsToSubgraphInputs(TfLiteContext* context,
TfLiteNode* node,
MicroGraph* graph_info,
@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"

namespace tflite {
namespace micro {
@@ -30,7 +32,20 @@ namespace micro {
TfLiteRegistration RegisterOp(
void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*free)(TfLiteContext* context, void* buffer) = nullptr);

// Prints out n bytes in a int8_t buffer as hex
void PrintNBytes(const int8_t* tensor_data, int n_bytes,
const char* prefix = nullptr);

// Prints out the the n bytes in a TfLiteEvalTensor as hex
void PrintNBytes(const TfLiteEvalTensor* tensor, int n_bytes,
const char* prefix = nullptr);

// Prints out the the n bytes in a TfLiteTensor as hex
void PrintNBytes(const TfLiteTensor* tensor, int n_bytes,
const char* prefix = nullptr);

// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
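An illustrative sketch of the extended RegisterOp declared above combined with the new PrintNBytes debug helper (not from the commit; MY_OP and the near-empty callbacks are placeholders). The fourth RegisterOp argument is the optional free callback, which defaults to nullptr:

namespace {

void* MyOpInit(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;  // no per-op state in this sketch
}

TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

TfLiteStatus MyOpInvoke(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  // Dump the first 8 output bytes as hex while debugging.
  tflite::micro::PrintNBytes(output, 8, "MY_OP output:");
  return kTfLiteOk;
}

void MyOpFree(TfLiteContext* context, void* buffer) {
  // Release anything allocated in MyOpInit.
}

}  // namespace

TfLiteRegistration Register_MY_OP() {
  return tflite::micro::RegisterOp(MyOpInit, MyOpPrepare, MyOpInvoke, MyOpFree);
}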
@@ -125,9 +125,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
L2EvalFloat(*params, *input, &op_params, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"L2_POOL_2D only supports float32 currently, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("L2_POOL_2D only supports float32 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

@@ -126,8 +126,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
MicroPrintf("Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

@@ -132,9 +132,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context,
"LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
MicroPrintf("LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -530,11 +532,20 @@ void CalculateLstmGateInteger8x8_16(
// Apply activation
switch (activation) {
case kTfLiteActSigmoid:
micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
break;
case kTfLiteActTanh:
micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);

reference_integer_ops::Logistic(
0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
n_batch * n_cell /*NumElements(input->dims)*/,
gate /* tflite::micro::GetTensorData<int16_t>(input) */,
gate /*tflite::micro::GetTensorData<int16_t>(output) */);

break;
case kTfLiteActTanh: {
int32_t dims_data = n_batch * n_cell;
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
reference_integer_ops::Tanh(0, 0, tanh_inp_shape, gate, tanh_inp_shape,
gate);
} break;
default:
// Only Sigmoid or Tanh is used.
TFLITE_ASSERT_FALSE;
@@ -599,7 +610,7 @@ void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
|
||||
// - scratch1: scratch area of size n_batch*n_cell
|
||||
// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
|
||||
void CalculateLstmOutputInteger8x8_16(
|
||||
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
|
||||
int n_batch, int n_cell, int n_output, int16_t* cell_state,
|
||||
int32_t cell_state_scale, const int16_t* output_gate,
|
||||
int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
|
||||
const int8_t* projection_weights, int32_t proj_scale_a,
|
||||
@@ -607,8 +618,23 @@ void CalculateLstmOutputInteger8x8_16(
|
||||
int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
|
||||
int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
|
||||
// Note: unlike float/hybrid, the activation is always Tanh.
|
||||
micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
|
||||
n_cell, scratch0);
|
||||
|
||||
{
|
||||
int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3;
|
||||
int32_t dims_data = n_batch * n_cell;
|
||||
if (tanh_input_left_shift < 0) /* handling negative shift value */
|
||||
{
|
||||
int32_t i;
|
||||
tanh_input_left_shift = -tanh_input_left_shift;
|
||||
for (i = 0; i < dims_data; i++) {
|
||||
cell_state[i] = cell_state[i] >> tanh_input_left_shift;
|
||||
}
|
||||
tanh_input_left_shift = 0;
|
||||
}
|
||||
RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
|
||||
reference_integer_ops::Tanh(0, tanh_input_left_shift, tanh_inp_shape,
|
||||
cell_state, tanh_inp_shape, scratch0);
|
||||
}
|
||||
micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
|
||||
hidden_scale_b, n_batch, n_cell, hidden_zp,
|
||||
scratch1);
|
||||
|
||||
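For context, reference_integer_ops::Tanh only accepts a non-negative input shift, so the block above computes the shift as (15 + cell_state_scale) - 3 and, when that value is negative, pre-scales cell_state with a right shift instead. A worked example with hypothetical scales:

// With cell_state_scale = -10: shift = (15 + (-10)) - 3 = 2, passed to Tanh as-is.
// With cell_state_scale = -15: shift = (15 + (-15)) - 3 = -3, so cell_state is
// right-shifted by 3 in place and Tanh is then called with a shift of 0,
// which is equivalent up to rounding.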
@@ -98,15 +98,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLiteOperation<int64_t, OpType>(context, node, op_context);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s (%d) is not supported by Maximum/Minimum.",
|
||||
TfLiteTypeGetName(op_context.output->type),
|
||||
op_context.output->type);
|
||||
MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
|
||||
TfLiteTypeGetName(op_context.output->type),
|
||||
op_context.output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Kernel type not supported by Maximum/Minimum.");
|
||||
MicroPrintf("Kernel type not supported by Maximum/Minimum.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -72,6 +72,7 @@ TfLiteRegistration Register_READ_VARIABLE();
|
||||
TfLiteRegistration Register_RELU();
|
||||
TfLiteRegistration Register_RELU6();
|
||||
TfLiteRegistration Register_RESIZE_BILINEAR();
|
||||
TfLiteRegistration Register_SELECT_V2();
|
||||
TfLiteRegistration Register_SHAPE();
|
||||
TfLiteRegistration Register_SLICE();
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
||||
@@ -79,6 +80,7 @@ TfLiteRegistration Register_SPACE_TO_DEPTH();
|
||||
TfLiteRegistration Register_SQUARED_DIFFERENCE();
|
||||
TfLiteRegistration Register_SQUEEZE();
|
||||
TfLiteRegistration Register_SUB();
|
||||
TfLiteRegistration Register_SUM();
|
||||
TfLiteRegistration Register_SVDF();
|
||||
TfLiteRegistration Register_TRANSPOSE();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV();
|
||||
|
||||
@@ -663,7 +663,7 @@ void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
      const int16_t b = input_2[index];
      int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
      value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
      value -= output_zp;
      value += output_zp;
      value = std::min(std::max(static_cast<int32_t>(-128), value),
                       static_cast<int32_t>(127));

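The sign flip above matters for any non-zero output zero point: after requantizing the int16 product, the result must be mapped onto the output's zero point before clamping to the int8 range. A minimal sketch of the corrected path, assuming the surrounding tensor-utils helpers (MultiplyByQuantizedMultiplier, <algorithm>) are in scope:

inline int8_t RequantizeProduct(int16_t a, int16_t b, int32_t multiplier,
                                int shift, int32_t output_zp) {
  int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
  value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
  value += output_zp;  // add, not subtract, the output zero point
  return static_cast<int8_t>(
      std::min(std::max(static_cast<int32_t>(-128), value),
               static_cast<int32_t>(127)));
}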
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -60,6 +60,15 @@ void EvalMulFloatReference(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input2,
|
||||
TfLiteEvalTensor* output);
|
||||
|
||||
// Generic must define registration function.
|
||||
TfLiteRegistration Register_MUL();
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_MUL_INT8();
|
||||
#else
|
||||
// Fallback registration
|
||||
inline TfLiteRegistration Register_MUL_INT8() { return Register_MUL(); }
|
||||
#endif
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_MUL_H_
|
||||
|
||||
@@ -41,8 +41,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -95,8 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
data->axis);
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Type '%s' is not supported by pack.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,8 +213,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} break;
|
||||
default:
|
||||
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type %s not currently supported by Pad.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -45,8 +45,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
AveragePoolingEvalQuantized(context, node, params, data, input, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Input type %s is not currently supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@@ -73,8 +73,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
MaxPoolingEvalQuantized(context, node, params, data, input, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
@@ -66,6 +67,19 @@ void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input,
|
||||
TfLiteEvalTensor* output);
|
||||
|
||||
#if defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_AVERAGE_POOL_2D_INT8();
|
||||
|
||||
TfLiteRegistration Register_MAX_POOL_2D_INT8();
|
||||
#else
|
||||
inline TfLiteRegistration Register_AVERAGE_POOL_2D_INT8() {
|
||||
return tflite::Register_AVERAGE_POOL_2D();
|
||||
}
|
||||
|
||||
inline TfLiteRegistration Register_MAX_POOL_2D_INT8() {
|
||||
return tflite::Register_MAX_POOL_2D();
|
||||
}
|
||||
#endif
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_POOLING_H_
|
||||
|
||||
@@ -61,9 +61,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Only float32 and uint8_t are supported currently, got %d.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Only float32 and uint8_t are supported currently, got %d.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
const int kMaxNumberOfAxis = 4;
|
||||
const int kMaxNumberOfAxis = 5;
|
||||
const int kMaxNumberOfReducedAxis = 2;
|
||||
|
||||
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
@@ -55,8 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params =
|
||||
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
|
||||
if (params->half_pixel_centers && params->align_corners) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "If half_pixel_centers is True, align_corners must be False.");
|
||||
MicroPrintf("If half_pixel_centers is True, align_corners must be False.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@@ -100,8 +99,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float or int8.",
|
||||
output->type);
|
||||
MicroPrintf("Output type is %d, requires float or int8.", output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@@ -55,7 +54,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
output->type = input->type;
|
||||
|
||||
if (!IsConstantTensor(size)) {
|
||||
TF_LITE_KERNEL_LOG(context, "Dynamic tensors are unsupported in tfmicro.");
|
||||
MicroPrintf("Dynamic tensors are unsupported in tfmicro.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/select.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
constexpr int kInputTensorCondition = 0;
|
||||
constexpr int kInputTensorX = 1;
|
||||
constexpr int kInputTensorY = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
struct OpData {
|
||||
bool requires_broadcast;
|
||||
// True if input condition is scalar or input condition has rank one and
|
||||
// matches the first dimension of other inputs.
|
||||
bool has_low_rank_input_condition;
|
||||
};
|
||||
|
||||
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
auto* data = static_cast<OpData*>(
|
||||
context->AllocatePersistentBuffer(context, sizeof(OpData)));
|
||||
data->requires_broadcast = false;
|
||||
data->has_low_rank_input_condition = false;
|
||||
return data;
|
||||
}
|
||||
|
||||
TfLiteStatus CheckBroadcastShape(TfLiteContext* context,
|
||||
const TfLiteTensor* input1,
|
||||
const TfLiteTensor* input2,
|
||||
const TfLiteTensor* input3,
|
||||
const TfLiteIntArray* output_shape) {
|
||||
const int dims1 = NumDimensions(input1);
|
||||
const int dims2 = NumDimensions(input2);
|
||||
const int dims3 = NumDimensions(input3);
|
||||
const int out_dims = std::max(std::max(dims1, dims2), dims3);
|
||||
TF_LITE_ENSURE_EQ(context, out_dims, output_shape->size);
|
||||
|
||||
for (int i = 0; i < out_dims; ++i) {
|
||||
const int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
|
||||
const int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
|
||||
const int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
|
||||
const int min_value = std::min(std::min(d1, d2), d3);
|
||||
int max_value = std::max(std::max(d1, d2), d3);
|
||||
// If one dimension is 0, others must be 0 or 1.
|
||||
if (min_value == 0) max_value = 0;
|
||||
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
|
||||
!(d3 == 1 || d3 == max_value)) {
|
||||
MicroPrintf("Given shapes are not broadcastable.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
TF_LITE_ENSURE_EQ(context, output_shape->data[out_dims - i - 1], max_value);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
TfLiteTensor* input_condition =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
|
||||
|
||||
TfLiteTensor* input_x =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorX);
|
||||
|
||||
TfLiteTensor* input_y =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorY);
|
||||
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
// Input must be bool.
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type);
|
||||
output->type = input_x->type;
|
||||
|
||||
// Respect the original output shape when there are mixed shapes to represent
|
||||
// a scalar data.
|
||||
if (GetTensorShape(input_condition).FlatSize() == 1 &&
|
||||
GetTensorShape(input_x).FlatSize() == 1 &&
|
||||
GetTensorShape(input_y).FlatSize() == 1 &&
|
||||
GetTensorShape(output).FlatSize() == 1) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
bool same_shape = HaveSameShapes(input_condition, input_x) &&
|
||||
HaveSameShapes(input_x, input_y);
|
||||
if (!same_shape) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CheckBroadcastShape(context, input_condition, input_x, input_y,
|
||||
output->dims));
|
||||
data->requires_broadcast = true;
|
||||
}
|
||||
|
||||
micro_context->DeallocateTempTfLiteTensor(input_condition);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_x);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_y);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
MicroContext* micro_context = GetMicroContext(context);
|
||||
|
||||
TfLiteTensor* input_condition =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
|
||||
|
||||
TfLiteTensor* input_x =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorX);
|
||||
|
||||
TfLiteTensor* input_y =
|
||||
micro_context->AllocateTempInputTensor(node, kInputTensorY);
|
||||
|
||||
TfLiteTensor* output =
|
||||
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
|
||||
|
||||
#define TF_LITE_SELECT(type, op) \
|
||||
reference_ops::op(GetTensorShape(input_condition), \
|
||||
GetTensorData<bool>(input_condition), \
|
||||
GetTensorShape(input_x), GetTensorData<type>(input_x), \
|
||||
GetTensorShape(input_y), GetTensorData<type>(input_y), \
|
||||
GetTensorShape(output), GetTensorData<type>(output));
|
||||
|
||||
#define TF_LITE_SWITCH(type, op) \
|
||||
switch (type) { \
|
||||
case kTfLiteFloat32: \
|
||||
TF_LITE_SELECT(float, op); \
|
||||
break; \
|
||||
case kTfLiteInt8: \
|
||||
TF_LITE_SELECT(int8_t, op); \
|
||||
break; \
|
||||
case kTfLiteInt16: \
|
||||
TF_LITE_SELECT(int16_t, op); \
|
||||
break; \
|
||||
default: \
|
||||
MicroPrintf("Does not support type other than %s, but got %s", \
|
||||
"int8|int16|float32", TfLiteTypeGetName(type)); \
|
||||
return kTfLiteError; \
|
||||
}
|
||||
|
||||
if (data->has_low_rank_input_condition) {
|
||||
MicroPrintf("Not yet implemented.");
|
||||
return kTfLiteError;
|
||||
} else if (data->requires_broadcast) {
|
||||
TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow);
|
||||
} else {
|
||||
TF_LITE_SWITCH(input_x->type, Select);
|
||||
}
|
||||
|
||||
#undef TF_LITE_SELECT
|
||||
#undef TF_LITE_SWITCH
|
||||
micro_context->DeallocateTempTfLiteTensor(input_condition);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_x);
|
||||
micro_context->DeallocateTempTfLiteTensor(input_y);
|
||||
micro_context->DeallocateTempTfLiteTensor(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// SelectV2 op selects values of 'x' if the corresponding value of 'condition'
// is true, or the value of 'y' if false. Valid condition input shapes are:
//
// 1. The same shape as 'x' and 'y' (in which case the select is elementwise), or
// 2. Shapes broadcastable between 'condition', 'x' and 'y'.
TfLiteRegistration Register_SELECT_V2() {
|
||||
return tflite::micro::RegisterOp(tflite::SelectInit, tflite::SelectPrepare,
|
||||
tflite::SelectEval);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
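A minimal usage sketch for the new kernel (the model, arena, and size names are hypothetical); it is enabled through the AddSelectV2 resolver entry added further below in this commit:

#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

tflite::MicroMutableOpResolver<1> resolver;
if (resolver.AddSelectV2() != kTfLiteOk) {
  MicroPrintf("Failed to register SELECT_V2");
}
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                     kTensorArenaSize);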
@@ -47,8 +47,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
if (output->type != kTfLiteInt32) {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
MicroPrintf("Output type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
return kTfLiteError;
|
||||
} else {
|
||||
ExtractShape(input, tflite::micro::GetTensorData<int32_t>(output));
|
||||
|
||||
@@ -106,8 +106,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetBeginAndSizeVectors<int64_t>(input->dims->size, begin, size,
|
||||
op_params.begin, op_params.size);
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Begin tensor type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Begin tensor type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -75,8 +75,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,8 +104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -109,9 +109,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("SPACE_TO_DEPTH only supports FLOAT32 and INT8, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -111,8 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return SplitImpl<int32_t>(context, node, input, axis_value);
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type %s currently not supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -90,8 +90,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
|
||||
if (input->type == kTfLiteString) {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -183,9 +183,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int32_t>(output));
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
reference_ops::StridedSlice(op_params,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<bool>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<bool>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -82,7 +82,7 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node);
|
||||
// (reference or optimized) must define this function.
|
||||
TfLiteRegistration Register_SVDF();
|
||||
|
||||
#if defined(HEXAGON)
|
||||
#if defined(HEXAGON) || defined(CMSIS_NN)
|
||||
TfLiteRegistration Register_SVDF_INT8();
|
||||
|
||||
#else
|
||||
|
||||
@@ -185,9 +185,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
} break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type));
|
||||
MicroPrintf("Input %s, output %s not supported.",
|
||||
TfLiteTypeGetName(input->type),
|
||||
TfLiteTypeGetName(output->type), context);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,10 +103,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s is currently not supported by Transpose. "
|
||||
"Only float32 and int8 is supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf(
|
||||
"Type %s is currently not supported by Transpose. "
|
||||
"Only float32 and int8 is supported",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
|
||||
@@ -327,8 +327,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -24,7 +24,7 @@ namespace testing {
|
||||
// kernel is reconciled with reference kernel
|
||||
#if !defined(XTENSA)
|
||||
|
||||
typedef struct LstmIntegerTestConfig {
|
||||
struct LstmIntegerTestConfig {
|
||||
const int n_batch;
|
||||
const int n_input;
|
||||
const int n_cell;
|
||||
@@ -100,9 +100,9 @@ typedef struct LstmIntegerTestConfig {
|
||||
|
||||
bool asymmetric_quantize_inputs;
|
||||
const float ranges[25][2];
|
||||
} LstmIntegerTestConfig;
|
||||
};
|
||||
|
||||
typedef struct LstmFloatTestConfig {
|
||||
struct LstmFloatTestConfig {
|
||||
const int n_batch;
|
||||
const int n_input;
|
||||
const int n_cell;
|
||||
@@ -153,9 +153,9 @@ typedef struct LstmFloatTestConfig {
|
||||
float* output;
|
||||
const float* expected_output_original;
|
||||
float* expected_output;
|
||||
} LstmFloatTestConfig;
|
||||
};
|
||||
|
||||
typedef struct LstmWeightQuantizationBuffers {
|
||||
struct LstmWeightQuantizationBuffers {
|
||||
int8_t* lstm_i2i_quant;
|
||||
float* lstm_i2i_scale;
|
||||
int* lstm_i2i_zp;
|
||||
@@ -215,7 +215,7 @@ typedef struct LstmWeightQuantizationBuffers {
|
||||
float* lstm_proj_w_scale;
|
||||
int* lstm_proj_w_zp;
|
||||
TfLiteAffineQuantization* lstm_proj_w_qparam;
|
||||
} LstmWeightQuantizationBuffers;
|
||||
};
|
||||
|
||||
extern LstmIntegerTestConfig lstm_integer_no_peephole_config;
|
||||
|
||||
|
||||
@@ -91,8 +91,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
MicroPrintf("Type '%s' is not supported by unpack.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,10 +70,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
resetZeros(tflite::micro::GetTensorData<float>(output), flat_size);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"ZerosLike only currently supports int64, int32, "
|
||||
"and float32, got %d.",
|
||||
input->type);
|
||||
MicroPrintf(
|
||||
"ZerosLike only currently supports int64, int32, "
|
||||
"and float32, got %d.",
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
||||
@@ -323,8 +323,12 @@ TfLiteStatus AllocationInfoBuilder::GetOfflinePlannedOffsets(
|
||||
if (model_->metadata()) {
|
||||
for (size_t i = 0; i < model_->metadata()->size(); ++i) {
|
||||
auto metadata = model_->metadata()->Get(i);
|
||||
if (strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
|
||||
strlen(kOfflineMemAllocMetadata)) == 0) {
|
||||
const size_t metadata_name_size = (size_t)metadata->name()->size();
|
||||
|
||||
if ((strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
|
||||
std::min(metadata_name_size,
|
||||
strlen(kOfflineMemAllocMetadata))) == 0) &&
|
||||
metadata_name_size == strlen(kOfflineMemAllocMetadata)) {
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
|
||||
model_->buffers();
|
||||
auto* buffer = (*buffers)[metadata->buffer()];
|
||||
|
||||
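The added length check above turns the metadata lookup into an exact-name match rather than a prefix match. A sketch of the difference, assuming kOfflineMemAllocMetadata is "OfflineMemoryAllocation" and the longer name is hypothetical:

// Prefix-only: strncmp("OfflineMemoryAllocationV2", kOfflineMemAllocMetadata,
//                      strlen(kOfflineMemAllocMetadata)) == 0
//   -> would match even though the names differ.
// Exact: the same comparison plus
//        metadata_name_size == strlen(kOfflineMemAllocMetadata)
//   -> rejects the longer name and only accepts the exact key.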
@@ -509,14 +509,15 @@ TfLiteStatus MicroAllocator::FinishModelAllocation(
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Allocate scratch buffer metadata and buffers for variable tensors.
|
||||
// Allocate scratch buffer metadata.
|
||||
TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles(
|
||||
scratch_buffer_handles, scratch_buffer_request_count_));
|
||||
|
||||
// Allocate buffers for variable tensors.
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
|
||||
subgraph_idx++) {
|
||||
const SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx);
|
||||
TFLITE_DCHECK(subgraph != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles(
|
||||
scratch_buffer_handles, scratch_buffer_request_count_));
|
||||
TF_LITE_ENSURE_STATUS(AllocateVariables(
|
||||
subgraph, subgraph_allocations[subgraph_idx].tensors));
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
|
||||
// of a sequential array of ScratchBufferHandle allocations in the tail
|
||||
// section. These allocations are indexed by the request API defined in the
|
||||
// TfLiteContext struct.
|
||||
typedef struct {
|
||||
struct ScratchBufferRequest {
|
||||
// Number of bytes required by the buffer. The actual allocated size might be
|
||||
// greater than `bytes` due to buffer alignment.
|
||||
size_t bytes;
|
||||
@@ -63,29 +63,29 @@ typedef struct {
|
||||
// have `before` = node_idx and `after` = node_idx.
|
||||
int node_idx;
|
||||
int subgraph_idx;
|
||||
} ScratchBufferRequest;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
typedef struct {
|
||||
struct NodeAndRegistration {
|
||||
TfLiteNode node;
|
||||
const TfLiteRegistration* registration;
|
||||
} NodeAndRegistration;
|
||||
};
|
||||
|
||||
// Holds a pointer to a buffer for a scratch buffer requested by a kernel during
|
||||
// the model prepare stage. This struct is allocated in-place and allows for
|
||||
// quick pointer-indexed lookup for speed during model inference.
|
||||
typedef struct {
|
||||
struct ScratchBufferHandle {
|
||||
// Pointer to location of the scratch buffer:
|
||||
uint8_t* data;
|
||||
} ScratchBufferHandle;
|
||||
};
|
||||
|
||||
// Stores all per-subgraph allocations. This includes the node and registration
|
||||
// array, tensor list and scratch buffer handles for each subgraph.
|
||||
typedef struct {
|
||||
// array, and tensor list for each subgraph.
|
||||
struct SubgraphAllocations {
|
||||
NodeAndRegistration* node_and_registrations;
|
||||
TfLiteEvalTensor* tensors;
|
||||
} SubgraphAllocations;
|
||||
};
|
||||
|
||||
// Allocator responsible for allocating memory for all intermediate tensors
|
||||
// necessary to invoke a model.
|
||||
|
||||
@@ -317,7 +317,17 @@ TfLiteTensor* MicroInterpreter::output(size_t index) {
  }
  return output_tensors_[index];
}
// Repurposing FreeSubgraphs() to reset per-op state for some ops for now,
// until a dedicated reset API is made. See b/220940833#comment25 for more
// context.
TfLiteStatus MicroInterpreter::Reset() {
  TfLiteStatus status = graph_.FreeSubgraphs();
  if (status != kTfLiteOk) {
    return status;
  }
  return graph_.ResetVariableTensors();
}

// TODO: remove this API completely in favor of MicroInterpreter::Reset
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
  return graph_.ResetVariableTensors();
}

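A brief usage sketch for the new Reset() entry point (interpreter construction elided): one interpreter can serve several independent input streams with a stateful model.

if (interpreter.Invoke() != kTfLiteOk) { /* handle error */ }
// ... read outputs for the first stream ...
// Reset frees per-op state via FreeSubgraphs() and resets variable tensors so
// the next stream starts from a clean state.
if (interpreter.Reset() != kTfLiteOk) { /* handle error */ }
if (interpreter.Invoke() != kTfLiteOk) { /* handle error */ }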
@@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/portable_type_to_tflitetype.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
// Copied from tensorflow/lite/version.h to avoid a dependency chain into
|
||||
/// Copied from tensorflow/lite/version.h to avoid a dependency chain into
|
||||
// tensorflow/core.
|
||||
#define TFLITE_SCHEMA_VERSION (3)
|
||||
|
||||
@@ -116,6 +116,11 @@ class MicroInterpreter {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
  // Reset the state to what you would expect when the interpreter is first
  // created, i.e. as if Init and Prepare had just been called for the very
  // first time.
TfLiteStatus Reset();
|
||||
|
||||
// TODO(b/244457206): remove this in favor of Reset()
|
||||
// Reset all variable tensors to the default value.
|
||||
TfLiteStatus ResetVariableTensors();
|
||||
|
||||
|
||||
@@ -24,11 +24,13 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/compatibility.h"
|
||||
#include "tensorflow/lite/micro/kernels/add.h"
|
||||
#include "tensorflow/lite/micro/kernels/conv.h"
|
||||
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
|
||||
#include "tensorflow/lite/micro/kernels/ethosu.h"
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||
#include "tensorflow/lite/micro/kernels/pooling.h"
|
||||
#include "tensorflow/lite/micro/kernels/reduce.h"
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
@@ -140,9 +142,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::Register_ASSIGN_VARIABLE(), ParseAssignVariable);
|
||||
}
|
||||
|
||||
TfLiteStatus AddAveragePool2D() {
|
||||
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
|
||||
tflite::Register_AVERAGE_POOL_2D(), ParsePool);
|
||||
TfLiteStatus AddAveragePool2D(
|
||||
const TfLiteRegistration& registration = Register_AVERAGE_POOL_2D()) {
|
||||
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, registration, ParsePool);
|
||||
}
|
||||
|
||||
TfLiteStatus AddBatchToSpaceNd() {
|
||||
@@ -363,9 +365,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
|
||||
}
|
||||
|
||||
TfLiteStatus AddMaxPool2D() {
|
||||
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
|
||||
tflite::Register_MAX_POOL_2D(), ParsePool);
|
||||
TfLiteStatus AddMaxPool2D(
|
||||
const TfLiteRegistration& registration = Register_MAX_POOL_2D()) {
|
||||
return AddBuiltin(BuiltinOperator_MAX_POOL_2D, registration, ParsePool);
|
||||
}
|
||||
|
||||
TfLiteStatus AddMirrorPad() {
|
||||
@@ -382,8 +384,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_MINIMUM(), ParseMinimum);
|
||||
}
|
||||
|
||||
TfLiteStatus AddMul() {
|
||||
return AddBuiltin(BuiltinOperator_MUL, tflite::Register_MUL(), ParseMul);
|
||||
TfLiteStatus AddMul(const TfLiteRegistration& registration = Register_MUL()) {
|
||||
return AddBuiltin(BuiltinOperator_MUL, registration, ParseMul);
|
||||
}
|
||||
|
||||
TfLiteStatus AddNeg() {
|
||||
@@ -466,6 +468,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
|
||||
}
|
||||
|
||||
TfLiteStatus AddSelectV2() {
|
||||
return AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2(),
|
||||
ParseSelectV2);
|
||||
}
|
||||
|
||||
TfLiteStatus AddShape() {
|
||||
return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape);
|
||||
}
|
||||
@@ -519,6 +526,12 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_SQUARE(), ParseSquare);
|
||||
}
|
||||
|
||||
TfLiteStatus AddSquaredDifference() {
|
||||
return AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE,
|
||||
tflite::Register_SQUARED_DIFFERENCE(),
|
||||
ParseSquaredDifference);
|
||||
}
|
||||
|
||||
TfLiteStatus AddStridedSlice() {
|
||||
return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
|
||||
tflite::ops::micro::Register_STRIDED_SLICE(),
|
||||
|
||||
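The new registration parameters on AddAveragePool2D, AddMaxPool2D, and AddMul pair with the *_INT8() registrations introduced in pooling.h and mul.h earlier in this commit. A sketch assuming a CMSIS-NN build:

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

tflite::MicroMutableOpResolver<3> resolver;
resolver.AddAveragePool2D(tflite::Register_AVERAGE_POOL_2D_INT8());
resolver.AddMaxPool2D(tflite::Register_MAX_POOL_2D_INT8());
resolver.AddMul(tflite::Register_MUL_INT8());
// Without CMSIS_NN these helpers fall back to the generic kernels, so the same
// call sites compile either way; the default arguments keep existing code
// working unchanged.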
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
@@ -67,4 +68,48 @@ void MicroProfiler::LogCsv() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
void MicroProfiler::LogTicksPerTagCsv() {
|
||||
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
|
||||
MicroPrintf(
|
||||
"\"Unique Tag\",\"Total ticks across all events with that tag.\"");
|
||||
int total_ticks = 0;
|
||||
for (int i = 0; i < num_events_; ++i) {
|
||||
uint32_t ticks = end_ticks_[i] - start_ticks_[i];
|
||||
TFLITE_DCHECK(tags_[i] != nullptr);
|
||||
int position = FindExistingOrNextPosition(tags_[i]);
|
||||
TFLITE_DCHECK(position >= 0);
|
||||
total_ticks_per_tag[position].tag = tags_[i];
|
||||
total_ticks_per_tag[position].ticks =
|
||||
total_ticks_per_tag[position].ticks + ticks;
|
||||
total_ticks += ticks;
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_events_; ++i) {
|
||||
TicksPerTag each_tag_entry = total_ticks_per_tag[i];
|
||||
if (each_tag_entry.tag == nullptr) {
|
||||
break;
|
||||
}
|
||||
MicroPrintf("%s, %d", each_tag_entry.tag, each_tag_entry.ticks);
|
||||
}
|
||||
MicroPrintf("total number of ticks, %d", total_ticks);
|
||||
#endif
|
||||
}
|
||||
|
||||
// This method finds a particular array element in the total_ticks_per_tag array
// with the matching tag_name passed into the method. If it can find a matching
// array element that has the same tag_name, it will return the position of the
// matching element. But if it is unable to find a matching element with the
// given tag_name, it will return the next available empty position in the
// array.
int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) {
  int pos = 0;
  for (; pos < num_events_; pos++) {
    TicksPerTag each_tag_entry = total_ticks_per_tag[pos];
    if (each_tag_entry.tag == nullptr ||
        strcmp(each_tag_entry.tag, tag_name) == 0) {
      return pos;
    }
  }
  return pos < num_events_ ? pos : -1;
}
}  // namespace tflite

|
||||
@@ -61,6 +61,11 @@ class MicroProfiler {
|
||||
// Separated Value) form.
|
||||
void LogCsv() const;
|
||||
|
||||
// Prints total ticks for each unique tag in CSV format.
|
||||
// Output will have one row for each unique tag along with the
|
||||
// total ticks summed across all events with that particular tag.
|
||||
void LogTicksPerTagCsv();
|
||||
|
||||
private:
|
||||
// Maximum number of events that this class can keep track of. If we call
|
||||
// AddEvent more than kMaxEvents number of times, then the oldest event's
|
||||
@@ -72,6 +77,17 @@ class MicroProfiler {
|
||||
uint32_t end_ticks_[kMaxEvents];
|
||||
int num_events_ = 0;
|
||||
|
||||
  struct TicksPerTag {
    const char* tag;
    uint32_t ticks;
  };
  // In practice, the number of tags will be much lower than the number of
  // events. But it is theoretically possible for each event to be unique, and
  // hence we allow total_ticks_per_tag to have kMaxEvents entries.
  TicksPerTag total_ticks_per_tag[kMaxEvents] = {};

int FindExistingOrNextPosition(const char* tag_name);
|
||||
|
||||
TF_LITE_REMOVE_VIRTUAL_DELETE;
|
||||
};
|
||||
|
||||
|
||||
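A usage sketch for the per-tag CSV output (model and resolver setup elided; the MicroInterpreter constructor overload taking a profiler argument is assumed here):

#include "tensorflow/lite/micro/micro_profiler.h"

tflite::MicroProfiler profiler;
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                     kTensorArenaSize, nullptr, &profiler);
interpreter.Invoke();
// One CSV row per unique tag with its summed ticks, followed by a total row.
profiler.LogTicksPerTagCsv();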
@@ -163,10 +163,12 @@ TfLiteStatus RecordingMicroAllocator::AllocateNodeAndRegistrations(
|
||||
|
||||
TfLiteStatus status =
|
||||
MicroAllocator::AllocateNodeAndRegistrations(model, subgraph_allocations);
|
||||
|
||||
RecordAllocationUsage(allocations,
|
||||
recorded_node_and_registration_array_data_);
|
||||
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
|
||||
subgraph_idx++) {
|
||||
RecordAllocationUsage(allocations,
|
||||
recorded_node_and_registration_array_data_);
|
||||
// The allocation count in SingleArenaBufferAllocator will only be 1. To
|
||||
// provide better logging, decrement by 1 and add in the actual number of
|
||||
// operators used in the graph: The allocation for this recording will
|
||||
@@ -176,8 +178,12 @@ TfLiteStatus RecordingMicroAllocator::AllocateNodeAndRegistrations(
|
||||
// potential for fragmentation, manually adjust the accounting by
|
||||
// decrementing by 1 and adding the actual number of nodes used in the
|
||||
// graph:
|
||||
recorded_node_and_registration_array_data_.count +=
|
||||
model->subgraphs()->Get(subgraph_idx)->operators()->size() - 1;
|
||||
if (model->subgraphs()->Get(subgraph_idx)->operators()) {
|
||||
recorded_node_and_registration_array_data_.count +=
|
||||
model->subgraphs()->Get(subgraph_idx)->operators()->size() - 1;
|
||||
} else {
|
||||
recorded_node_and_registration_array_data_.count -= 1;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
@@ -188,9 +194,11 @@ TfLiteStatus RecordingMicroAllocator::AllocateTfLiteEvalTensors(
|
||||
|
||||
TfLiteStatus status =
|
||||
MicroAllocator::AllocateTfLiteEvalTensors(model, subgraph_allocations);
|
||||
|
||||
RecordAllocationUsage(allocations, recorded_tflite_eval_tensor_data_);
|
||||
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
|
||||
subgraph_idx++) {
|
||||
RecordAllocationUsage(allocations, recorded_tflite_eval_tensor_data_);
|
||||
// The allocation for this recording will always be 1. This is because the
|
||||
// parent class mallocs one large allocation for the number of tensors in
|
||||
// the graph (e.g. sizeof(TfLiteEvalTensor) * num_tensors). To prevent extra
|
||||
|
||||
File diff suppressed because it is too large
@@ -150,7 +150,7 @@
|
||||
|
||||
#define FLATBUFFERS_VERSION_MAJOR 2
|
||||
#define FLATBUFFERS_VERSION_MINOR 0
|
||||
#define FLATBUFFERS_VERSION_REVISION 5
|
||||
#define FLATBUFFERS_VERSION_REVISION 6
|
||||
#define FLATBUFFERS_STRING_EXPAND(X) #X
|
||||
#define FLATBUFFERS_STRING(X) FLATBUFFERS_STRING_EXPAND(X)
|
||||
namespace flatbuffers {
|
||||
@@ -270,9 +270,12 @@ namespace flatbuffers {
|
||||
#endif // !FLATBUFFERS_HAS_NEW_STRTOD
|
||||
|
||||
#ifndef FLATBUFFERS_LOCALE_INDEPENDENT
|
||||
// Enable locale independent functions {strtof_l, strtod_l,strtoll_l, strtoull_l}.
|
||||
#if ((defined(_MSC_VER) && _MSC_VER >= 1800) || \
|
||||
(defined(_XOPEN_VERSION) && (_XOPEN_VERSION>=700)) && (!defined(__ANDROID_API__) || (defined(__ANDROID_API__) && (__ANDROID_API__>=21))))
|
||||
// Enable locale independent functions {strtof_l, strtod_l,strtoll_l,
|
||||
// strtoull_l}.
|
||||
#if (defined(_MSC_VER) && _MSC_VER >= 1800) || \
|
||||
(defined(__ANDROID_API__) && __ANDROID_API__>= 21) || \
|
||||
(defined(_XOPEN_VERSION) && (_XOPEN_VERSION >= 700)) && \
|
||||
(!defined(__Fuchsia__) && !defined(__ANDROID_API__))
|
||||
#define FLATBUFFERS_LOCALE_INDEPENDENT 1
|
||||
#else
|
||||
#define FLATBUFFERS_LOCALE_INDEPENDENT 0
|
||||
@@ -338,8 +341,17 @@ typedef uintmax_t largest_scalar_t;
|
||||
// In 32bits, this evaluates to 2GB - 1
|
||||
#define FLATBUFFERS_MAX_BUFFER_SIZE ((1ULL << (sizeof(::flatbuffers::soffset_t) * 8 - 1)) - 1)
|
||||
|
||||
// The minimum size buffer that can be a valid flatbuffer.
|
||||
// Includes the offset to the root table (uoffset_t), the offset to the vtable
|
||||
// of the root table (soffset_t), the size of the vtable (uint16_t), and the
|
||||
// size of the referring table (uint16_t).
|
||||
#define FLATBUFFERS_MIN_BUFFER_SIZE sizeof(uoffset_t) + sizeof(soffset_t) + \
|
||||
sizeof(uint16_t) + sizeof(uint16_t)
|
||||
|
||||
// We support aligning the contents of buffers up to this size.
|
||||
#define FLATBUFFERS_MAX_ALIGNMENT 16
|
||||
#ifndef FLATBUFFERS_MAX_ALIGNMENT
|
||||
#define FLATBUFFERS_MAX_ALIGNMENT 32
|
||||
#endif
|
||||
|
||||
/// @brief The length of a FlatBuffer file header.
|
||||
static const size_t kFileIdentifierLength = 4;
|
||||
|
||||
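Worked out under the default 32-bit offset configuration (an assumption; uoffset_t and soffset_t are 4 bytes each there), the new minimum-size constant evaluates to 12 bytes:

// Evaluated inside namespace flatbuffers:
//   sizeof(uoffset_t) + sizeof(soffset_t) + sizeof(uint16_t) + sizeof(uint16_t)
//     = 4 + 4 + 2 + 2 = 12
static_assert(FLATBUFFERS_MIN_BUFFER_SIZE == 12,
              "smallest valid flatbuffer with default offset widths");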
@@ -18,6 +18,7 @@
|
||||
#define FLATBUFFERS_FLATBUFFER_BUILDER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
|
||||
#include "flatbuffers/allocator.h"
|
||||
#include "flatbuffers/array.h"
|
||||
@@ -42,14 +43,15 @@ inline voffset_t FieldIndexToOffset(voffset_t field_id) {
|
||||
return static_cast<voffset_t>((field_id + fixed_fields) * sizeof(voffset_t));
|
||||
}
|
||||
|
||||
template<typename T, typename Alloc>
|
||||
template<typename T, typename Alloc = std::allocator<T>>
|
||||
const T *data(const std::vector<T, Alloc> &v) {
|
||||
// Eventually the returned pointer gets passed down to memcpy, so
|
||||
// we need it to be non-null to avoid undefined behavior.
|
||||
static uint8_t t;
|
||||
return v.empty() ? reinterpret_cast<const T *>(&t) : &v.front();
|
||||
}
|
||||
template<typename T, typename Alloc> T *data(std::vector<T, Alloc> &v) {
|
||||
template<typename T, typename Alloc = std::allocator<T>>
|
||||
T *data(std::vector<T, Alloc> &v) {
|
||||
// Eventually the returned pointer gets passed down to memcpy, so
|
||||
// we need it to be non-null to avoid undefined behavior.
|
||||
static uint8_t t;
|
||||
@@ -285,9 +287,7 @@ class FlatBufferBuilder {
|
||||
FieldLoc fl = { off, field };
|
||||
buf_.scratch_push_small(fl);
|
||||
num_field_loc++;
|
||||
if (field > max_voffset_) {
|
||||
max_voffset_ = field;
|
||||
}
|
||||
if (field > max_voffset_) { max_voffset_ = field; }
|
||||
}
|
||||
|
||||
// Like PushElement, but additionally tracks the field this represents.
|
||||
@@ -443,6 +443,7 @@ class FlatBufferBuilder {
|
||||
// Aligns such that when "len" bytes are written, an object can be written
|
||||
// after it with "alignment" without padding.
|
||||
void PreAlign(size_t len, size_t alignment) {
|
||||
if (len == 0) return;
|
||||
TrackMinAlign(alignment);
|
||||
buf_.fill(PaddingBytes(GetSize() + len, alignment));
|
||||
}
|
||||
@@ -601,12 +602,14 @@ class FlatBufferBuilder {
|
||||
// This is useful when storing a nested_flatbuffer in a vector of bytes,
|
||||
// or when storing SIMD floats, etc.
|
||||
void ForceVectorAlignment(size_t len, size_t elemsize, size_t alignment) {
|
||||
if (len == 0) return;
|
||||
FLATBUFFERS_ASSERT(VerifyAlignmentRequirements(alignment));
|
||||
PreAlign(len * elemsize, alignment);
|
||||
}
|
||||
|
||||
// Similar to ForceVectorAlignment but for String fields.
|
||||
void ForceStringAlignment(size_t len, size_t alignment) {
|
||||
if (len == 0) return;
|
||||
FLATBUFFERS_ASSERT(VerifyAlignmentRequirements(alignment));
|
||||
PreAlign((len + 1) * sizeof(char), alignment);
|
||||
}
|
||||
@@ -642,6 +645,27 @@ class FlatBufferBuilder {
|
||||
return Offset<Vector<T>>(EndVector(len));
|
||||
}
|
||||
|
||||
/// @brief Serialize an array like object into a FlatBuffer `vector`.
|
||||
/// @tparam T The data type of the array elements.
|
||||
/// @tparam C The type of the array.
|
||||
/// @param[in] array A reference to an array like object of type `T` to
|
||||
/// serialize into the buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, class C> Offset<Vector<T>> CreateVector(const C &array) {
|
||||
return CreateVector(array.data(), array.size());
|
||||
}
|
||||
|
||||
/// @brief Serialize an initializer list into a FlatBuffer `vector`.
|
||||
/// @tparam T The data type of the initializer list elements.
|
||||
/// @param[in] v The value of the initializer list.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T>
|
||||
Offset<Vector<T>> CreateVector(std::initializer_list<T> v) {
|
||||
return CreateVector(v.begin(), v.size());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Offset<Vector<Offset<T>>> CreateVector(const Offset<T> *v, size_t len) {
|
||||
StartVector(len, sizeof(Offset<T>));
|
||||
@@ -655,7 +679,7 @@ class FlatBufferBuilder {
|
||||
/// buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename Alloc>
|
||||
template<typename T, typename Alloc = std::allocator<T>>
|
||||
Offset<Vector<T>> CreateVector(const std::vector<T, Alloc> &v) {
|
||||
return CreateVector(data(v), v.size());
|
||||
}
|
||||
@@ -706,15 +730,18 @@ class FlatBufferBuilder {
|
||||
return CreateVector(elems);
|
||||
}
|
||||
|
||||
/// @brief Serialize a `std::vector<std::string>` into a FlatBuffer `vector`.
|
||||
/// @brief Serialize a `std::vector<StringType>` into a FlatBuffer `vector`,
/// where StringType is any type that is accepted by the CreateString()
/// overloads.
/// This is a convenience function for a common case.
|
||||
/// @param v A const reference to the `std::vector` to serialize into the
|
||||
/// buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename Alloc>
|
||||
template<typename StringType = std::string,
|
||||
typename Alloc = std::allocator<StringType>>
|
||||
Offset<Vector<Offset<String>>> CreateVectorOfStrings(
|
||||
const std::vector<std::string, Alloc> &v) {
|
||||
const std::vector<StringType, Alloc> &v) {
|
||||
return CreateVectorOfStrings(v.cbegin(), v.cend());
|
||||
}
|
||||
|
||||
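With the StringType parameter, the vector's element type only has to be something the CreateString() overloads accept. A short usage sketch (builder and names are illustrative):

#include <vector>
#include "flatbuffers/flatbuffer_builder.h"

flatbuffers::FlatBufferBuilder builder;
std::vector<const char*> names = {"alpha", "beta", "gamma"};
// StringType is deduced as const char*; each element goes through CreateString().
auto names_vec = builder.CreateVectorOfStrings(names);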
@@ -841,7 +868,7 @@ class FlatBufferBuilder {
|
||||
/// serialize into the buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename Alloc>
|
||||
template<typename T, typename Alloc = std::allocator<T>>
|
||||
Offset<Vector<const T *>> CreateVectorOfStructs(
|
||||
const std::vector<T, Alloc> &v) {
|
||||
return CreateVectorOfStructs(data(v), v.size());
|
||||
@@ -857,7 +884,7 @@ class FlatBufferBuilder {
|
||||
/// to the FlatBuffer struct.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename S, typename Alloc>
|
||||
template<typename T, typename S, typename Alloc = std::allocator<T>>
|
||||
Offset<Vector<const T *>> CreateVectorOfNativeStructs(
|
||||
const std::vector<S, Alloc> &v, T (*const pack_func)(const S &)) {
|
||||
return CreateVectorOfNativeStructs<T, S>(data(v), v.size(), pack_func);
|
||||
@@ -871,7 +898,7 @@ class FlatBufferBuilder {
|
||||
/// serialize into the buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename S, typename Alloc>
|
||||
template<typename T, typename S, typename Alloc = std::allocator<S>>
|
||||
Offset<Vector<const T *>> CreateVectorOfNativeStructs(
|
||||
const std::vector<S, Alloc> &v) {
|
||||
return CreateVectorOfNativeStructs<T, S>(data(v), v.size());
|
||||
@@ -892,7 +919,7 @@ class FlatBufferBuilder {
|
||||
/// serialize into the buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename Alloc>
|
||||
template<typename T, typename Alloc = std::allocator<T>>
|
||||
Offset<Vector<const T *>> CreateVectorOfSortedStructs(
|
||||
std::vector<T, Alloc> *v) {
|
||||
return CreateVectorOfSortedStructs(data(*v), v->size());
|
||||
@@ -906,7 +933,7 @@ class FlatBufferBuilder {
|
||||
/// serialize into the buffer as a `vector`.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename S, typename Alloc>
|
||||
template<typename T, typename S, typename Alloc = std::allocator<T>>
|
||||
Offset<Vector<const T *>> CreateVectorOfSortedNativeStructs(
|
||||
std::vector<S, Alloc> *v) {
|
||||
return CreateVectorOfSortedNativeStructs<T, S>(data(*v), v->size());
|
||||
@@ -922,7 +949,7 @@ class FlatBufferBuilder {
|
||||
/// where the vector is stored.
|
||||
template<typename T>
|
||||
Offset<Vector<const T *>> CreateVectorOfSortedStructs(T *v, size_t len) {
|
||||
std::sort(v, v + len, StructKeyComparator<T>());
|
||||
std::stable_sort(v, v + len, StructKeyComparator<T>());
|
||||
return CreateVectorOfStructs(v, len);
|
||||
}
|
||||
|
||||
@@ -941,7 +968,7 @@ class FlatBufferBuilder {
|
||||
extern T Pack(const S &);
|
||||
auto structs = StartVectorOfStructs<T>(len);
|
||||
for (size_t i = 0; i < len; i++) { structs[i] = Pack(v[i]); }
|
||||
std::sort(structs, structs + len, StructKeyComparator<T>());
|
||||
std::stable_sort(structs, structs + len, StructKeyComparator<T>());
|
||||
return EndVectorOfStructs<T>(len);
|
||||
}
|
||||
|
||||
@@ -973,7 +1000,7 @@ class FlatBufferBuilder {
|
||||
template<typename T>
|
||||
Offset<Vector<Offset<T>>> CreateVectorOfSortedTables(Offset<T> *v,
|
||||
size_t len) {
|
||||
std::sort(v, v + len, TableKeyComparator<T>(buf_));
|
||||
std::stable_sort(v, v + len, TableKeyComparator<T>(buf_));
|
||||
return CreateVector(v, len);
|
||||
}
|
||||
|
||||
@@ -984,7 +1011,7 @@ class FlatBufferBuilder {
|
||||
/// offsets to store in the buffer in sorted order.
|
||||
/// @return Returns a typed `Offset` into the serialized data indicating
|
||||
/// where the vector is stored.
|
||||
template<typename T, typename Alloc>
|
||||
template<typename T, typename Alloc = std::allocator<T>>
|
||||
Offset<Vector<Offset<T>>> CreateVectorOfSortedTables(
|
||||
std::vector<Offset<T>, Alloc> *v) {
|
||||
return CreateVectorOfSortedTables(data(*v), v->size());
|
||||
@@ -1074,7 +1101,7 @@ class FlatBufferBuilder {
|
||||
void SwapBufAllocator(FlatBufferBuilder &other) {
|
||||
buf_.swap_allocator(other.buf_);
|
||||
}
|
||||
|
||||
|
||||
/// @brief The length of a FlatBuffer file header.
|
||||
static const size_t kFileIdentifierLength =
|
||||
::flatbuffers::kFileIdentifierLength;
|
||||
|
||||
@@ -226,27 +226,13 @@ struct TypeTable {
|
||||
};
|
||||
|
||||
// String which identifies the current version of FlatBuffers.
|
||||
// flatbuffer_version_string is used by Google developers to identify which
|
||||
// applications uploaded to Google Play are using this library. This allows
|
||||
// the development team at Google to determine the popularity of the library.
|
||||
// How it works: Applications that are uploaded to the Google Play Store are
|
||||
// scanned for this version string. We track which applications are using it
|
||||
// to measure popularity. You are free to remove it (of course) but we would
|
||||
// appreciate if you left it in.
|
||||
inline const char *flatbuffers_version_string() {
|
||||
return "FlatBuffers " FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MAJOR) "."
|
||||
FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MINOR) "."
|
||||
FLATBUFFERS_STRING(FLATBUFFERS_VERSION_REVISION);
|
||||
}
|
||||
|
||||
// Weak linkage is culled by VS & doesn't work on cygwin.
|
||||
// clang-format off
|
||||
#if !defined(_WIN32) && !defined(__CYGWIN__)
|
||||
|
||||
extern volatile __attribute__((weak)) const char *flatbuffer_version_string;
|
||||
volatile __attribute__((weak)) const char *flatbuffer_version_string =
|
||||
"FlatBuffers "
|
||||
FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MAJOR) "."
|
||||
FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MINOR) "."
|
||||
FLATBUFFERS_STRING(FLATBUFFERS_VERSION_REVISION);
|
||||
|
||||
#endif // !defined(_WIN32) && !defined(__CYGWIN__)
|
||||
|
||||
#define FLATBUFFERS_DEFINE_BITMASK_OPERATORS(E, T)\
|
||||
inline E operator | (E lhs, E rhs){\
|
||||
return E(T(lhs) | T(rhs));\
|
||||
|
||||
@@ -156,6 +156,7 @@ inline uint64_t ReadUInt64(const uint8_t *data, uint8_t byte_width) {
// TODO: GCC apparently replaces memcpy by a rep movsb, but only if count is a
// constant, which here it isn't. Test if memcpy is still faster than
// the conditionals in ReadSizedScalar. Can also use inline asm.

// clang-format off
#if defined(_MSC_VER) && defined(_M_X64) && !defined(_M_ARM64EC)
// This is 64-bit Windows only, __movsb does not work on 32-bit Windows.
@@ -371,10 +372,7 @@ void AppendToString(std::string &s, T &&v, bool keys_quoted) {
class Reference {
public:
Reference()
: data_(nullptr),
parent_width_(0),
byte_width_(0),
type_(FBT_NULL) {}
: data_(nullptr), parent_width_(0), byte_width_(0), type_(FBT_NULL) {}

Reference(const uint8_t *data, uint8_t parent_width, uint8_t byte_width,
Type type)
@@ -590,7 +588,23 @@ class Reference {
auto keys = m.Keys();
auto vals = m.Values();
for (size_t i = 0; i < keys.size(); i++) {
keys[i].ToString(true, keys_quoted, s);
bool kq = keys_quoted;
if (!kq) {
// FlexBuffers keys may contain arbitrary characters, only allow
// unquoted if it looks like an "identifier":
const char *p = keys[i].AsKey();
if (!flatbuffers::is_alpha(*p) && *p != '_') {
kq = true;
} else {
while (*++p) {
if (!flatbuffers::is_alnum(*p) && *p != '_') {
kq = true;
break;
}
}
}
}
keys[i].ToString(true, kq, s);
s += ": ";
vals[i].ToString(true, keys_quoted, s);
if (i < keys.size() - 1) s += ", ";
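A self-contained sketch of the behaviour the hunk above introduces when dumping a map with keys_quoted == false: identifier-like keys may stay unquoted, anything else is now forced into quotes.

#include <string>
#include "flatbuffers/flexbuffers.h"

std::string DumpExample() {
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Int("plain_key", 1);      // identifier-like, may stay unquoted
    fbb.Int("needs quotes!", 2);  // space and '!' force quoting now
  });
  fbb.Finish();
  std::string s;
  flexbuffers::GetRoot(fbb.GetBuffer())
      .ToString(/*strings_quoted=*/true, /*keys_quoted=*/false, s);
  return s;
}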
@@ -1424,10 +1438,12 @@ class Builder FLATBUFFERS_FINAL_CLASS {

template<typename T> static Type GetScalarType() {
static_assert(flatbuffers::is_scalar<T>::value, "Unrelated types");
return flatbuffers::is_floating_point<T>::value ? FBT_FLOAT
: flatbuffers::is_same<T, bool>::value
? FBT_BOOL
: (flatbuffers::is_unsigned<T>::value ? FBT_UINT : FBT_INT);
return flatbuffers::is_floating_point<T>::value
? FBT_FLOAT
: flatbuffers::is_same<T, bool>::value
? FBT_BOOL
: (flatbuffers::is_unsigned<T>::value ? FBT_UINT
: FBT_INT);
}

public:
@@ -1660,8 +1676,7 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
// comes at the cost of using additional memory the same size of
// the buffer being verified, so it is by default off.
std::vector<uint8_t> *reuse_tracker = nullptr,
bool _check_alignment = true,
size_t max_depth = 64)
bool _check_alignment = true, size_t max_depth = 64)
: buf_(buf),
size_(buf_len),
depth_(0),
@@ -1704,18 +1719,16 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
auto o = static_cast<size_t>(p - buf_);
return VerifyBefore(o, len);
}

bool VerifyByteWidth(size_t width) {
return Check(width == 1 || width == 2 || width == 4 || width == 8);
}

bool VerifyType(int type) {
return Check(type >= 0 && type < FBT_MAX_TYPE);
}
bool VerifyType(int type) { return Check(type >= 0 && type < FBT_MAX_TYPE); }

bool VerifyOffset(uint64_t off, const uint8_t *p) {
return Check(off <= static_cast<uint64_t>(size_)) &&
off <= static_cast<uint64_t>(p - buf_);
off <= static_cast<uint64_t>(p - buf_);
}

bool VerifyAlignment(const uint8_t *p, size_t size) const {
@@ -1723,16 +1736,16 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
return Check((o & (size - 1)) == 0 || !check_alignment_);
}

// Macro, since we want to escape from parent function & use lazy args.
#define FLEX_CHECK_VERIFIED(P, PACKED_TYPE) \
if (reuse_tracker_) { \
auto packed_type = PACKED_TYPE; \
auto existing = (*reuse_tracker_)[P - buf_]; \
if (existing == packed_type) return true; \
/* Fail verification if already set with different type! */ \
if (!Check(existing == 0)) return false; \
(*reuse_tracker_)[P - buf_] = packed_type; \
}
// Macro, since we want to escape from parent function & use lazy args.
#define FLEX_CHECK_VERIFIED(P, PACKED_TYPE) \
if (reuse_tracker_) { \
auto packed_type = PACKED_TYPE; \
auto existing = (*reuse_tracker_)[P - buf_]; \
if (existing == packed_type) return true; \
/* Fail verification if already set with different type! */ \
if (!Check(existing == 0)) return false; \
(*reuse_tracker_)[P - buf_] = packed_type; \
}
bool VerifyVector(Reference r, const uint8_t *p, Type elem_type) {
// Any kind of nesting goes thru this function, so guard against that
@@ -1742,19 +1755,19 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
if (!Check(depth_ <= max_depth_ && num_vectors_ <= max_vectors_))
return false;
auto size_byte_width = r.byte_width_;
FLEX_CHECK_VERIFIED(p, PackedType(Builder::WidthB(size_byte_width), r.type_));
if (!VerifyBeforePointer(p, size_byte_width))
return false;
if (!VerifyBeforePointer(p, size_byte_width)) return false;
FLEX_CHECK_VERIFIED(p - size_byte_width,
PackedType(Builder::WidthB(size_byte_width), r.type_));
auto sized = Sized(p, size_byte_width);
auto num_elems = sized.size();
auto elem_byte_width =
r.type_ == FBT_STRING || r.type_ == FBT_BLOB ? uint8_t(1) : r.byte_width_;
auto elem_byte_width = r.type_ == FBT_STRING || r.type_ == FBT_BLOB
? uint8_t(1)
: r.byte_width_;
auto max_elems = SIZE_MAX / elem_byte_width;
if (!Check(num_elems < max_elems))
return false;  // Protect against byte_size overflowing.
auto byte_size = num_elems * elem_byte_width;
if (!VerifyFromPointer(p, byte_size))
return false;
if (!VerifyFromPointer(p, byte_size)) return false;
if (elem_type == FBT_NULL) {
// Verify type bytes after the vector.
if (!VerifyFromPointer(p + byte_size, num_elems)) return false;
@@ -1775,28 +1788,25 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
bool VerifyKeys(const uint8_t *p, uint8_t byte_width) {
// The vector part of the map has already been verified.
const size_t num_prefixed_fields = 3;
if (!VerifyBeforePointer(p, byte_width * num_prefixed_fields))
return false;
if (!VerifyBeforePointer(p, byte_width * num_prefixed_fields)) return false;
p -= byte_width * num_prefixed_fields;
auto off = ReadUInt64(p, byte_width);
if (!VerifyOffset(off, p))
return false;
if (!VerifyOffset(off, p)) return false;
auto key_byte_with =
static_cast<uint8_t>(ReadUInt64(p + byte_width, byte_width));
if (!VerifyByteWidth(key_byte_with))
return false;
static_cast<uint8_t>(ReadUInt64(p + byte_width, byte_width));
if (!VerifyByteWidth(key_byte_with)) return false;
return VerifyVector(Reference(p, byte_width, key_byte_with, FBT_VECTOR_KEY),
p - off, FBT_KEY);
}

bool VerifyKey(const uint8_t* p) {
bool VerifyKey(const uint8_t *p) {
FLEX_CHECK_VERIFIED(p, PackedType(BIT_WIDTH_8, FBT_KEY));
while (p < buf_ + size_)
if (*p++) return true;
return false;
}

#undef FLEX_CHECK_VERIFIED
#undef FLEX_CHECK_VERIFIED

bool VerifyTerminator(const String &s) {
return VerifyFromPointer(reinterpret_cast<const uint8_t *>(s.c_str()),
@@ -1814,37 +1824,26 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
}
// All remaining types are an offset.
auto off = ReadUInt64(r.data_, r.parent_width_);
if (!VerifyOffset(off, r.data_))
return false;
if (!VerifyOffset(off, r.data_)) return false;
auto p = r.Indirect();
if (!VerifyAlignment(p, r.byte_width_))
return false;
if (!VerifyAlignment(p, r.byte_width_)) return false;
switch (r.type_) {
case FBT_INDIRECT_INT:
case FBT_INDIRECT_UINT:
case FBT_INDIRECT_FLOAT:
return VerifyFromPointer(p, r.byte_width_);
case FBT_KEY:
return VerifyKey(p);
case FBT_INDIRECT_FLOAT: return VerifyFromPointer(p, r.byte_width_);
case FBT_KEY: return VerifyKey(p);
case FBT_MAP:
return VerifyVector(r, p, FBT_NULL) &&
VerifyKeys(p, r.byte_width_);
case FBT_VECTOR:
return VerifyVector(r, p, FBT_NULL);
case FBT_VECTOR_INT:
return VerifyVector(r, p, FBT_INT);
return VerifyVector(r, p, FBT_NULL) && VerifyKeys(p, r.byte_width_);
case FBT_VECTOR: return VerifyVector(r, p, FBT_NULL);
case FBT_VECTOR_INT: return VerifyVector(r, p, FBT_INT);
case FBT_VECTOR_BOOL:
case FBT_VECTOR_UINT:
return VerifyVector(r, p, FBT_UINT);
case FBT_VECTOR_FLOAT:
return VerifyVector(r, p, FBT_FLOAT);
case FBT_VECTOR_KEY:
return VerifyVector(r, p, FBT_KEY);
case FBT_VECTOR_UINT: return VerifyVector(r, p, FBT_UINT);
case FBT_VECTOR_FLOAT: return VerifyVector(r, p, FBT_FLOAT);
case FBT_VECTOR_KEY: return VerifyVector(r, p, FBT_KEY);
case FBT_VECTOR_STRING_DEPRECATED:
// Use of FBT_KEY here intentional, see elsewhere.
return VerifyVector(r, p, FBT_KEY);
case FBT_BLOB:
return VerifyVector(r, p, FBT_UINT);
case FBT_BLOB: return VerifyVector(r, p, FBT_UINT);
case FBT_STRING:
return VerifyVector(r, p, FBT_UINT) &&
VerifyTerminator(String(p, r.byte_width_));
@@ -1859,12 +1858,10 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
case FBT_VECTOR_FLOAT4: {
uint8_t len = 0;
auto vtype = ToFixedTypedVectorElementType(r.type_, &len);
if (!VerifyType(vtype))
return false;
if (!VerifyType(vtype)) return false;
return VerifyFromPointer(p, r.byte_width_ * len);
}
default:
return false;
default: return false;
}
}
@@ -1874,8 +1871,7 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
auto end = buf_ + size_;
auto byte_width = *--end;
auto packed_type = *--end;
return VerifyByteWidth(byte_width) &&
Check(end - buf_ >= byte_width) &&
return VerifyByteWidth(byte_width) && Check(end - buf_ >= byte_width) &&
VerifyRef(Reference(end - byte_width, byte_width, packed_type));
}

@@ -1890,27 +1886,14 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
std::vector<uint8_t> *reuse_tracker_;
};

// Utility function that contructs the Verifier for you, see above for parameters.
// Utility function that contructs the Verifier for you, see above for
// parameters.
inline bool VerifyBuffer(const uint8_t *buf, size_t buf_len,
std::vector<uint8_t> *reuse_tracker = nullptr) {
Verifier verifier(buf, buf_len, reuse_tracker);
return verifier.VerifyBuffer();
}

#ifdef FLATBUFFERS_H_
// This is a verifier utility function that works together with the
// FlatBuffers verifier, which should only be present if flatbuffer.h
// has been included (which it typically is in generated code).
inline bool VerifyNestedFlexBuffer(const flatbuffers::Vector<uint8_t> *nv,
flatbuffers::Verifier &verifier) {
if (!nv) return true;
return verifier.Check(
flexbuffers::VerifyBuffer(nv->data(), nv->size(),
verifier.GetFlexReuseTracker()));
}
#endif

} // namespace flexbuffers

#if defined(_MSC_VER)
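A self-contained sketch of how the free VerifyBuffer above is typically called before reading an untrusted FlexBuffer. The reuse tracker is optional; a zero-filled vector at least as large as the buffer is a safe choice whether or not the verifier resizes it internally.

#include <vector>
#include "flatbuffers/flexbuffers.h"

bool BuildAndVerify() {
  flexbuffers::Builder fbb;
  fbb.Vector([&]() { fbb.Int(1); fbb.Int(2); fbb.Int(3); });
  fbb.Finish();
  const std::vector<uint8_t> &buf = fbb.GetBuffer();
  // Lets the verifier remember sub-objects it has already checked.
  std::vector<uint8_t> reuse_tracker(buf.size(), 0);
  return flexbuffers::VerifyBuffer(buf.data(), buf.size(), &reuse_tracker);
}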
@@ -26,16 +26,20 @@
#include <memory>
#include <limits>

// Detect C++17 compatible compiler.
// __cplusplus >= 201703L - a compiler has support of 'static inline' variables.
#if defined(FLATBUFFERS_USE_STD_OPTIONAL) \
|| (defined(__cplusplus) && __cplusplus >= 201703L) \
|| (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L))
#ifndef FLATBUFFERS_USE_STD_OPTIONAL
// Detect C++17 compatible compiler.
// __cplusplus >= 201703L - a compiler has support of 'static inline' variables.
#if (defined(__cplusplus) && __cplusplus >= 201703L) \
|| (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
#define FLATBUFFERS_USE_STD_OPTIONAL 1
#else
#define FLATBUFFERS_USE_STD_OPTIONAL 0
#endif // (defined(__cplusplus) && __cplusplus >= 201703L) ...
#endif // FLATBUFFERS_USE_STD_OPTIONAL

#if FLATBUFFERS_USE_STD_OPTIONAL
#include <optional>
#ifndef FLATBUFFERS_USE_STD_OPTIONAL
#define FLATBUFFERS_USE_STD_OPTIONAL
#endif
#endif // defined(FLATBUFFERS_USE_STD_OPTIONAL) ...
#endif

// The __cpp_lib_span is the predefined feature macro.
#if defined(FLATBUFFERS_USE_STD_SPAN)
@@ -128,7 +132,7 @@ namespace flatbuffers {
};
#endif  // defined(FLATBUFFERS_TEMPLATES_ALIASES)

#ifdef FLATBUFFERS_USE_STD_OPTIONAL
#if FLATBUFFERS_USE_STD_OPTIONAL
template<class T>
using Optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
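With the hunk above, FLATBUFFERS_USE_STD_OPTIONAL is always defined (to 0 or 1) and can be overridden by the build; when it is 1, flatbuffers::Optional is simply std::optional. A minimal sketch, assuming the usual has_value()/value() API is available in either configuration:

#include <cstdint>
#include "flatbuffers/stl_emulation.h"

int32_t ReadHpOrDefault(flatbuffers::Optional<int32_t> maybe_hp) {
  // With std::optional underneath this is the familiar optional interface;
  // the header-provided fallback exposes the same calls on pre-C++17 builds.
  return maybe_hp.has_value() ? maybe_hp.value() : 0;
}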
@@ -284,13 +288,13 @@ FLATBUFFERS_CONSTEXPR std::size_t dynamic_extent = static_cast<std::size_t>(-1);
namespace internal {
// This is SFINAE helper class for checking of a common condition:
// > This overload only participates in overload resolution
// > Check whether a pointer to an array of U can be converted
// > to a pointer to an array of E.
// This helper is used for checking of 'U -> const U'.
template<class E, std::size_t Extent, class U, std::size_t N>
// > Check whether a pointer to an array of From can be converted
// > to a pointer to an array of To.
// This helper is used for checking of 'From -> const From'.
template<class To, std::size_t Extent, class From, std::size_t N>
struct is_span_convertable {
using type =
typename std::conditional<std::is_convertible<U (*)[], E (*)[]>::value
typename std::conditional<std::is_convertible<From (*)[], To (*)[]>::value
&& (Extent == dynamic_extent || N == Extent),
int, void>::type;
};
@@ -362,13 +366,9 @@ class span FLATBUFFERS_FINAL_CLASS {

#if !defined(FLATBUFFERS_SPAN_MINIMAL)
using Iterator = internal::SpanIterator<T>;
using ConstIterator = internal::SpanIterator<const T>;

Iterator begin() const { return Iterator(data()); }
Iterator end() const { return Iterator(data() + size()); }

ConstIterator cbegin() const { return ConstIterator(data()); }
ConstIterator cend() const { return ConstIterator(data() + size()); }
#endif

// Returns a reference to the idx-th element of the sequence.
@@ -462,45 +462,45 @@ class span FLATBUFFERS_FINAL_CLASS {
private:
// This is a naive implementation with 'count_' member even if (Extent != dynamic_extent).
pointer const data_;
const size_type count_;
size_type count_;
};
#endif  // defined(FLATBUFFERS_USE_STD_SPAN)

#if !defined(FLATBUFFERS_SPAN_MINIMAL)
template<class U, std::size_t N>
template<class ElementType, std::size_t Extent>
FLATBUFFERS_CONSTEXPR_CPP11
flatbuffers::span<U, N> make_span(U(&arr)[N]) FLATBUFFERS_NOEXCEPT {
return span<U, N>(arr);
flatbuffers::span<ElementType, Extent> make_span(ElementType(&arr)[Extent]) FLATBUFFERS_NOEXCEPT {
return span<ElementType, Extent>(arr);
}

template<class U, std::size_t N>
template<class ElementType, std::size_t Extent>
FLATBUFFERS_CONSTEXPR_CPP11
flatbuffers::span<const U, N> make_span(const U(&arr)[N]) FLATBUFFERS_NOEXCEPT {
return span<const U, N>(arr);
flatbuffers::span<const ElementType, Extent> make_span(const ElementType(&arr)[Extent]) FLATBUFFERS_NOEXCEPT {
return span<const ElementType, Extent>(arr);
}

template<class U, std::size_t N>
template<class ElementType, std::size_t Extent>
FLATBUFFERS_CONSTEXPR_CPP11
flatbuffers::span<U, N> make_span(std::array<U, N> &arr) FLATBUFFERS_NOEXCEPT {
return span<U, N>(arr);
flatbuffers::span<ElementType, Extent> make_span(std::array<ElementType, Extent> &arr) FLATBUFFERS_NOEXCEPT {
return span<ElementType, Extent>(arr);
}

template<class U, std::size_t N>
template<class ElementType, std::size_t Extent>
FLATBUFFERS_CONSTEXPR_CPP11
flatbuffers::span<const U, N> make_span(const std::array<U, N> &arr) FLATBUFFERS_NOEXCEPT {
return span<const U, N>(arr);
flatbuffers::span<const ElementType, Extent> make_span(const std::array<ElementType, Extent> &arr) FLATBUFFERS_NOEXCEPT {
return span<const ElementType, Extent>(arr);
}

template<class U, std::size_t N>
template<class ElementType, std::size_t Extent>
FLATBUFFERS_CONSTEXPR_CPP11
flatbuffers::span<U, dynamic_extent> make_span(U *first, std::size_t count) FLATBUFFERS_NOEXCEPT {
return span<U, dynamic_extent>(first, count);
flatbuffers::span<ElementType, dynamic_extent> make_span(ElementType *first, std::size_t count) FLATBUFFERS_NOEXCEPT {
return span<ElementType, dynamic_extent>(first, count);
}

template<class U, std::size_t N>
template<class ElementType, std::size_t Extent>
FLATBUFFERS_CONSTEXPR_CPP11
flatbuffers::span<const U, dynamic_extent> make_span(const U *first, std::size_t count) FLATBUFFERS_NOEXCEPT {
return span<const U, dynamic_extent>(first, count);
flatbuffers::span<const ElementType, dynamic_extent> make_span(const ElementType *first, std::size_t count) FLATBUFFERS_NOEXCEPT {
return span<const ElementType, dynamic_extent>(first, count);
}
#endif // !defined(FLATBUFFERS_SPAN_MINIMAL)
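The overloads above only rename the template parameters (U/N become ElementType/Extent, matching std::span); behaviour is unchanged. A small usage sketch, assuming FLATBUFFERS_SPAN_MINIMAL is not defined:

#include <array>
#include <cstdint>
#include "flatbuffers/stl_emulation.h"

void SpanExamples() {
  uint8_t raw[4] = {1, 2, 3, 4};
  std::array<float, 3> arr = {1.0f, 2.0f, 3.0f};
  auto s1 = flatbuffers::make_span(raw);  // span<uint8_t, 4>
  auto s2 = flatbuffers::make_span(arr);  // span<float, 3>
  (void)s1;
  (void)s2;
}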
@@ -112,20 +112,22 @@ class Table {

// Verify a particular field.
template<typename T>
bool VerifyField(const Verifier &verifier, voffset_t field) const {
bool VerifyField(const Verifier &verifier, voffset_t field,
size_t align) const {
// Calling GetOptionalFieldOffset should be safe now thanks to
// VerifyTable().
auto field_offset = GetOptionalFieldOffset(field);
// Check the actual field.
return !field_offset || verifier.Verify<T>(data_, field_offset);
return !field_offset || verifier.VerifyField<T>(data_, field_offset, align);
}

// VerifyField for required fields.
template<typename T>
bool VerifyFieldRequired(const Verifier &verifier, voffset_t field) const {
bool VerifyFieldRequired(const Verifier &verifier, voffset_t field,
size_t align) const {
auto field_offset = GetOptionalFieldOffset(field);
return verifier.Check(field_offset != 0) &&
verifier.Verify<T>(data_, field_offset);
verifier.VerifyField<T>(data_, field_offset, align);
}

// Versions for offsets.
@@ -163,4 +165,4 @@ inline flatbuffers::Optional<bool> Table::GetOptional<uint8_t, bool>(

} // namespace flatbuffers

#endif // FLATBUFFERS_TABLE_H_
#endif // FLATBUFFERS_TABLE_H_
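Table::VerifyField and VerifyFieldRequired now take an explicit alignment, which regenerated code must pass. A hypothetical sketch of what a generated Verify method looks like after this change; the field types and vtable offsets below are placeholders, not code from this commit:

bool Verify(flatbuffers::Verifier &verifier) const {
  return VerifyTableStart(verifier) &&
         // Scalar fields now pass their own alignment requirement.
         VerifyField<int16_t>(verifier, /*field=*/4, /*align=*/2) &&
         VerifyField<uint8_t>(verifier, /*field=*/6, /*align=*/1) &&
         verifier.EndTable();
}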
@@ -17,8 +17,8 @@
#ifndef FLATBUFFERS_UTIL_H_
#define FLATBUFFERS_UTIL_H_

#include <errno.h>
#include <ctype.h>
#include <errno.h>

#include "flatbuffers/base.h"
#include "flatbuffers/stl_emulation.h"
@@ -30,8 +30,8 @@
#endif

#ifndef FLATBUFFERS_PREFER_PRINTF
# include <sstream>
# include <iomanip>
# include <sstream>
#else  // FLATBUFFERS_PREFER_PRINTF
# include <float.h>
# include <stdio.h>
@@ -454,6 +454,9 @@ std::string StripPath(const std::string &filepath);
// Strip the last component of the path + separator.
std::string StripFileName(const std::string &filepath);

std::string StripPrefix(const std::string &filepath,
const std::string &prefix_to_remove);

// Concatenates a path with a filename, regardless of whether the path
// ends in a separator or not.
std::string ConCatPathFileName(const std::string &path,
@@ -691,6 +694,32 @@ bool ReadEnvironmentVariable(const char *var_name,
// MSVC specific: Send all assert reports to STDOUT to prevent CI hangs.
void SetupDefaultCRTReportMode();

enum class Case {
kUnknown = 0,
// TheQuickBrownFox
kUpperCamel = 1,
// theQuickBrownFox
kLowerCamel = 2,
// the_quick_brown_fox
kSnake = 3,
// THE_QUICK_BROWN_FOX
kScreamingSnake = 4,
// THEQUICKBROWNFOX
kAllUpper = 5,
// thequickbrownfox
kAllLower = 6,
// the-quick-brown-fox
kDasher = 7,
// THEQuiCKBr_ownFox (or whatever you want, we won't change it)
kKeep = 8,
// the_quick_brown_fox123 (as opposed to the_quick_brown_fox_123)
kSnake2 = 9,
};

// Convert the `input` string of case `input_case` to the specified `output_case`.
std::string ConvertCase(const std::string &input, Case output_case,
Case input_case = Case::kSnake);

} // namespace flatbuffers

#endif // FLATBUFFERS_UTIL_H_
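A usage sketch for the new Case/ConvertCase declarations above; these drive identifier renaming in the code generators, and the implementation lives in the library's util.cpp, so this needs the full flatbuffers library rather than headers alone:

#include <string>
#include "flatbuffers/util.h"

void CaseExamples() {
  // The input is treated as snake_case by default.
  std::string a = flatbuffers::ConvertCase("quick_brown_fox",
                                           flatbuffers::Case::kUpperCamel);
  // a should be "QuickBrownFox".
  std::string b = flatbuffers::ConvertCase("quick_brown_fox",
                                           flatbuffers::Case::kScreamingSnake);
  // b should be "QUICK_BROWN_FOX".
  (void)a;
  (void)b;
}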
@@ -19,6 +19,7 @@

#include "flatbuffers/base.h"
#include "flatbuffers/buffer.h"
#include "flatbuffers/stl_emulation.h"

namespace flatbuffers {

@@ -326,6 +327,24 @@ FLATBUFFERS_CONSTEXPR_CPP11 flatbuffers::span<const uint8_t> make_bytes_span(
return span<const uint8_t>(vec.Data(), vec.size() * sizeof(U));
}

// Convenient helper functions to get a span of any vector, regardless
// of whether it is null or not (the field is not set).
template<class U>
FLATBUFFERS_CONSTEXPR_CPP11 flatbuffers::span<U> make_span(Vector<U> *ptr)
FLATBUFFERS_NOEXCEPT {
static_assert(Vector<U>::is_span_observable,
"wrong type U, only LE-scalar, or byte types are allowed");
return ptr ? make_span(*ptr) : span<U>();
}

template<class U>
FLATBUFFERS_CONSTEXPR_CPP11 flatbuffers::span<const U> make_span(
const Vector<U> *ptr) FLATBUFFERS_NOEXCEPT {
static_assert(Vector<U>::is_span_observable,
"wrong type U, only LE-scalar, or byte types are allowed");
return ptr ? make_span(*ptr) : span<const U>();
}

// Represent a vector much like the template above, but in this case we
// don't know what the element types are (used with reflection.h).
class VectorOfAny {
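A hypothetical sketch of why the nullable-pointer make_span overloads above are convenient: a generated vector accessor can be wrapped directly, with no null check at the call site. "monster->inventory()" is a placeholder for any generated field returning a flatbuffers::Vector<uint8_t> pointer:

// Assumes a generated accessor such as:
//   const flatbuffers::Vector<uint8_t> *inventory() const;
flatbuffers::span<const uint8_t> bytes =
    flatbuffers::make_span(monster->inventory());
// An unset field yields an empty span, so iteration is always safe.
for (uint8_t b : bytes) { (void)b; }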
@@ -18,7 +18,6 @@
#define FLATBUFFERS_VERIFIER_H_

#include "flatbuffers/base.h"
#include "flatbuffers/util.h"
#include "flatbuffers/vector.h"

namespace flatbuffers {
@@ -26,22 +25,24 @@ namespace flatbuffers {
// Helper class to verify the integrity of a FlatBuffer
class Verifier FLATBUFFERS_FINAL_CLASS {
public:
Verifier(const uint8_t *buf, size_t buf_len, uoffset_t _max_depth = 64,
uoffset_t _max_tables = 1000000, bool _check_alignment = true)
Verifier(const uint8_t *const buf, const size_t buf_len,
const uoffset_t _max_depth = 64,
const uoffset_t _max_tables = 1000000,
const bool _check_alignment = true)
: buf_(buf),
size_(buf_len),
depth_(0),
max_depth_(_max_depth),
num_tables_(0),
max_tables_(_max_tables),
upper_bound_(0),
check_alignment_(_check_alignment),
upper_bound_(0),
depth_(0),
num_tables_(0),
flex_reuse_tracker_(nullptr) {
FLATBUFFERS_ASSERT(size_ < FLATBUFFERS_MAX_BUFFER_SIZE);
}

// Central location where any verification failures register.
bool Check(bool ok) const {
bool Check(const bool ok) const {
// clang-format off
#ifdef FLATBUFFERS_DEBUG_VERIFICATION_FAILURE
FLATBUFFERS_ASSERT(ok);
@@ -55,7 +56,7 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
}

// Verify any range within the buffer.
bool Verify(size_t elem, size_t elem_len) const {
bool Verify(const size_t elem, const size_t elem_len) const {
// clang-format off
#ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
auto upper_bound = elem + elem_len;
@@ -66,48 +67,52 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
return Check(elem_len < size_ && elem <= size_ - elem_len);
}

template<typename T> bool VerifyAlignment(size_t elem) const {
return Check((elem & (sizeof(T) - 1)) == 0 || !check_alignment_);
bool VerifyAlignment(const size_t elem, const size_t align) const {
return Check((elem & (align - 1)) == 0 || !check_alignment_);
}

// Verify a range indicated by sizeof(T).
template<typename T> bool Verify(size_t elem) const {
return VerifyAlignment<T>(elem) && Verify(elem, sizeof(T));
template<typename T> bool Verify(const size_t elem) const {
return VerifyAlignment(elem, sizeof(T)) && Verify(elem, sizeof(T));
}

bool VerifyFromPointer(const uint8_t *p, size_t len) {
auto o = static_cast<size_t>(p - buf_);
return Verify(o, len);
bool VerifyFromPointer(const uint8_t *const p, const size_t len) {
return Verify(static_cast<size_t>(p - buf_), len);
}

// Verify relative to a known-good base pointer.
bool Verify(const uint8_t *base, voffset_t elem_off, size_t elem_len) const {
return Verify(static_cast<size_t>(base - buf_) + elem_off, elem_len);
bool VerifyFieldStruct(const uint8_t *const base, const voffset_t elem_off,
const size_t elem_len, const size_t align) const {
const auto f = static_cast<size_t>(base - buf_) + elem_off;
return VerifyAlignment(f, align) && Verify(f, elem_len);
}

template<typename T>
bool Verify(const uint8_t *base, voffset_t elem_off) const {
return Verify(static_cast<size_t>(base - buf_) + elem_off, sizeof(T));
bool VerifyField(const uint8_t *const base, const voffset_t elem_off,
const size_t align) const {
const auto f = static_cast<size_t>(base - buf_) + elem_off;
return VerifyAlignment(f, align) && Verify(f, sizeof(T));
}
// Verify a pointer (may be NULL) of a table type.
template<typename T> bool VerifyTable(const T *table) {
template<typename T> bool VerifyTable(const T *const table) {
return !table || table->Verify(*this);
}

// Verify a pointer (may be NULL) of any vector type.
template<typename T> bool VerifyVector(const Vector<T> *vec) const {
template<typename T> bool VerifyVector(const Vector<T> *const vec) const {
return !vec || VerifyVectorOrString(reinterpret_cast<const uint8_t *>(vec),
sizeof(T));
}

// Verify a pointer (may be NULL) of a vector to struct.
template<typename T> bool VerifyVector(const Vector<const T *> *vec) const {
template<typename T>
bool VerifyVector(const Vector<const T *> *const vec) const {
return VerifyVector(reinterpret_cast<const Vector<T> *>(vec));
}

// Verify a pointer (may be NULL) to string.
bool VerifyString(const String *str) const {
bool VerifyString(const String *const str) const {
size_t end;
return !str || (VerifyVectorOrString(reinterpret_cast<const uint8_t *>(str),
1, &end) &&
@@ -116,24 +121,24 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
}

// Common code between vectors and strings.
bool VerifyVectorOrString(const uint8_t *vec, size_t elem_size,
size_t *end = nullptr) const {
auto veco = static_cast<size_t>(vec - buf_);
bool VerifyVectorOrString(const uint8_t *const vec, const size_t elem_size,
size_t *const end = nullptr) const {
const auto veco = static_cast<size_t>(vec - buf_);
// Check we can read the size field.
if (!Verify<uoffset_t>(veco)) return false;
// Check the whole array. If this is a string, the byte past the array
// must be 0.
auto size = ReadScalar<uoffset_t>(vec);
auto max_elems = FLATBUFFERS_MAX_BUFFER_SIZE / elem_size;
const auto size = ReadScalar<uoffset_t>(vec);
const auto max_elems = FLATBUFFERS_MAX_BUFFER_SIZE / elem_size;
if (!Check(size < max_elems))
return false;  // Protect against byte_size overflowing.
auto byte_size = sizeof(size) + elem_size * size;
const auto byte_size = sizeof(size) + elem_size * size;
if (end) *end = veco + byte_size;
return Verify(veco, byte_size);
}

// Special case for string contents, after the above has been called.
bool VerifyVectorOfStrings(const Vector<Offset<String>> *vec) const {
bool VerifyVectorOfStrings(const Vector<Offset<String>> *const vec) const {
if (vec) {
for (uoffset_t i = 0; i < vec->size(); i++) {
if (!VerifyString(vec->Get(i))) return false;
@@ -143,7 +148,8 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
}

// Special case for table contents, after the above has been called.
template<typename T> bool VerifyVectorOfTables(const Vector<Offset<T>> *vec) {
template<typename T>
bool VerifyVectorOfTables(const Vector<Offset<T>> *const vec) {
if (vec) {
for (uoffset_t i = 0; i < vec->size(); i++) {
if (!vec->Get(i)->Verify(*this)) return false;
@@ -153,29 +159,40 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
}

__supress_ubsan__("unsigned-integer-overflow") bool VerifyTableStart(
const uint8_t *table) {
const uint8_t *const table) {
// Check the vtable offset.
auto tableo = static_cast<size_t>(table - buf_);
const auto tableo = static_cast<size_t>(table - buf_);
if (!Verify<soffset_t>(tableo)) return false;
// This offset may be signed, but doing the subtraction unsigned always
// gives the result we want.
auto vtableo = tableo - static_cast<size_t>(ReadScalar<soffset_t>(table));
const auto vtableo =
tableo - static_cast<size_t>(ReadScalar<soffset_t>(table));
// Check the vtable size field, then check vtable fits in its entirety.
return VerifyComplexity() && Verify<voffset_t>(vtableo) &&
VerifyAlignment<voffset_t>(ReadScalar<voffset_t>(buf_ + vtableo)) &&
Verify(vtableo, ReadScalar<voffset_t>(buf_ + vtableo));
if (!(VerifyComplexity() && Verify<voffset_t>(vtableo) &&
VerifyAlignment(ReadScalar<voffset_t>(buf_ + vtableo),
sizeof(voffset_t))))
return false;
const auto vsize = ReadScalar<voffset_t>(buf_ + vtableo);
return Check((vsize & 1) == 0) && Verify(vtableo, vsize);
}
template<typename T>
bool VerifyBufferFromStart(const char *identifier, size_t start) {
bool VerifyBufferFromStart(const char *const identifier, const size_t start) {
// Buffers have to be of some size to be valid. The reason it is a runtime
// check instead of static_assert, is that nested flatbuffers go through
// this call and their size is determined at runtime.
if (!Check(size_ >= FLATBUFFERS_MIN_BUFFER_SIZE)) return false;

// If an identifier is provided, check that we have a buffer
if (identifier && !Check((size_ >= 2 * sizeof(flatbuffers::uoffset_t) &&
BufferHasIdentifier(buf_ + start, identifier)))) {
return false;
}

// Call T::Verify, which must be in the generated code for this type.
auto o = VerifyOffset(start);
return o && reinterpret_cast<const T *>(buf_ + start + o)->Verify(*this)
const auto o = VerifyOffset(start);
return Check(o != 0) &&
reinterpret_cast<const T *>(buf_ + start + o)->Verify(*this)
// clang-format off
#ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
&& GetComputedSize()
@@ -185,9 +202,14 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
}

template<typename T>
bool VerifyNestedFlatBuffer(const Vector<uint8_t> *buf,
const char *identifier) {
bool VerifyNestedFlatBuffer(const Vector<uint8_t> *const buf,
const char *const identifier) {
// An empty buffer is OK as it indicates not present.
if (!buf) return true;

// If there is a nested buffer, it must be greater than the min size.
if(!Check(buf->size() >= FLATBUFFERS_MIN_BUFFER_SIZE)) return false;

Verifier nested_verifier(buf->data(), buf->size());
return nested_verifier.VerifyBuffer<T>(identifier);
}
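A hypothetical sketch of how VerifyNestedFlatBuffer is used from a generated Verify method for a field declared as a nested_flatbuffer; "Weapon" and "nested_data" are placeholders rather than code from this commit:

bool VerifyNestedWeapon(const flatbuffers::Vector<uint8_t> *nested_data,
                        flatbuffers::Verifier &verifier) {
  // An absent field passes; a present field must hold a complete, valid
  // buffer whose root is the nested type.
  return verifier.VerifyNestedFlatBuffer<Weapon>(nested_data, nullptr);
}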
@@ -195,19 +217,20 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
// Verify this whole buffer, starting with root type T.
template<typename T> bool VerifyBuffer() { return VerifyBuffer<T>(nullptr); }

template<typename T> bool VerifyBuffer(const char *identifier) {
template<typename T> bool VerifyBuffer(const char *const identifier) {
return VerifyBufferFromStart<T>(identifier, 0);
}

template<typename T> bool VerifySizePrefixedBuffer(const char *identifier) {
template<typename T>
bool VerifySizePrefixedBuffer(const char *const identifier) {
return Verify<uoffset_t>(0U) &&
ReadScalar<uoffset_t>(buf_) == size_ - sizeof(uoffset_t) &&
Check(ReadScalar<uoffset_t>(buf_) == size_ - sizeof(uoffset_t)) &&
VerifyBufferFromStart<T>(identifier, sizeof(uoffset_t));
}
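A hypothetical end-to-end sketch of verifying an untrusted buffer with this class; "Monster" and the "MONS" file identifier are placeholders for a generated root type and its identifier:

bool IsBufferTrustworthy(const uint8_t *buf, size_t buf_len) {
  flatbuffers::Verifier verifier(buf, buf_len, /*_max_depth=*/64,
                                 /*_max_tables=*/1000000,
                                 /*_check_alignment=*/true);
  // For buffers written with FinishSizePrefixed(), use
  // verifier.VerifySizePrefixedBuffer<Monster>("MONS") instead.
  return verifier.VerifyBuffer<Monster>("MONS");
}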
uoffset_t VerifyOffset(size_t start) const {
uoffset_t VerifyOffset(const size_t start) const {
if (!Verify<uoffset_t>(start)) return 0;
auto o = ReadScalar<uoffset_t>(buf_ + start);
const auto o = ReadScalar<uoffset_t>(buf_ + start);
// May not point to itself.
if (!Check(o != 0)) return 0;
// Can't wrap around / buffers are max 2GB.
@@ -218,7 +241,8 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
return o;
}

uoffset_t VerifyOffset(const uint8_t *base, voffset_t start) const {
uoffset_t VerifyOffset(const uint8_t *const base,
const voffset_t start) const {
return VerifyOffset(static_cast<size_t>(base - buf_) + start);
}

@@ -255,23 +279,23 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
// clang-format on
}

std::vector<uint8_t> *GetFlexReuseTracker() {
return flex_reuse_tracker_;
}
std::vector<uint8_t> *GetFlexReuseTracker() { return flex_reuse_tracker_; }

void SetFlexReuseTracker(std::vector<uint8_t> *rt) {
void SetFlexReuseTracker(std::vector<uint8_t> *const rt) {
flex_reuse_tracker_ = rt;
}

private:
const uint8_t *buf_;
size_t size_;
uoffset_t depth_;
uoffset_t max_depth_;
uoffset_t num_tables_;
uoffset_t max_tables_;
const size_t size_;
const uoffset_t max_depth_;
const uoffset_t max_tables_;
const bool check_alignment_;

mutable size_t upper_bound_;
bool check_alignment_;

uoffset_t depth_;
uoffset_t num_tables_;
std::vector<uint8_t> *flex_reuse_tracker_;
};