diff --git a/README.md b/README.md index 82b14ee0..5fdd5403 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,10 @@ In other cases you can contact the developer via email: - +#include "esp_nn_defs.h" /************************** Basic math functions ****************************/ /** @@ -81,28 +80,15 @@ void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data, * optimization notes: Though input_offset is int32 type, * offset values are contained in 8 bits [-128, 127] */ -void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, - const uint16_t ch_mult, +void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max); + const dw_conv_params_t *conv_params, + const quant_data_t *quant_data); /** * @brief 2d-convolution channelwise @@ -112,43 +98,26 @@ void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data, * inputs type: int8_t, output: int8_t * input offsets: although int32_t, they are contained in 8 bits [-128, 127] */ -void esp_nn_conv_s8_ansi(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, +void esp_nn_conv_s8_ansi(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const uint16_t out_channels, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max); + const conv_params_t *conv_params, + const quant_data_t *quant_data); -int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_ch, - const uint16_t out_ch, - const uint16_t filter_wd, - const uint16_t filter_ht); +int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const conv_params_t *conv_params); void esp_nn_set_conv_scratch_buf_ansi(const void *buf); -int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const uint16_t ch_mult, - const uint16_t filter_wd, - const uint16_t filter_ht); +int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const dw_conv_params_t *conv_params); void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf); /************************** Activation functions *****************************/ @@ -252,9 +221,6 @@ int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t h */ void esp_nn_set_softmax_scratch_buf_ansi(void 
*buffer);
-/* ANSI C function to be hooked up when optimised version needed */
-void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
-
 /**
  * @brief reference softmax function
  *
@@ -268,6 +234,66 @@ void esp_nn_softmax_s8_ansi(const int8_t *input_data,
                             const int32_t diff_min,
                             int8_t *output_data);
 
+
+//////////////////////////// Generic optimisations /////////////////////////////
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief 2d-convolution channelwise optimized version
+ *
+ * @note operation: result += (input + offset) * filter
+ *
+ * inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data);
+
+/**
+ * @brief depthwise convolution per channel optimized version
+ *
+ * @note inputs type: int8_t, output: int8_t
+ *       The version used in tflite is per channel.
+ *       This version follows the same footsteps,
+ *       meaning it has a per-out_channel shift and multiplier for
+ *       requantization
+ *
+ * optimization notes: Though input_offset is int32 type,
+ *                     offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params);
+void esp_nn_set_conv_scratch_buf_opt(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params);
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
+
+/* ANSI C function to be hooked up when optimised version needed */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
+
 /**
  * @brief optimised version of softmax function
  *
diff --git a/code/components/esp-nn/include/esp_nn_defs.h b/code/components/esp-nn/include/esp_nn_defs.h
new file mode 100644
index 00000000..756d8e6f
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn_defs.h
@@ -0,0 +1,83 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
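The struct-based `_opt` declarations above collapse the long positional argument lists of the old API into the handful of structs defined in the new esp_nn_defs.h whose body follows. For orientation, a hypothetical call site could look like the sketch below; the dimensions, offsets, the static scratch bound and the esp_nn.h umbrella include are illustrative assumptions, not part of this patch.

#include <stdint.h>
#include "esp_nn.h"

static int8_t  input[8 * 8 * 4];            /* NHWC: 8x8 image, 4 channels */
static int8_t  filter[1 * 1 * 4 * 4];       /* 1x1 kernel, 4 in ch, 4 out ch */
static int32_t bias[4];
static int32_t out_mult[4], out_shift[4];   /* one pair per output channel */
static int8_t  output[8 * 8 * 4];
static int8_t  scratch[1024];               /* illustrative upper bound */

void conv_call_site(void)
{
    const data_dims_t input_dims  = { .width = 8, .height = 8, .channels = 4, .extra = 1 };
    const data_dims_t filter_dims = { .width = 1, .height = 1, .channels = 4 };
    const data_dims_t output_dims = { .width = 8, .height = 8, .channels = 4 };
    const conv_params_t conv_params = {
        .in_offset  = 0,
        .out_offset = 0,
        .stride     = { .width = 1, .height = 1 },
        .padding    = { .width = 0, .height = 0 },
        .dilation   = { .width = 1, .height = 1 },   /* kernels assume dilation 1 */
        .activation = { .min = INT8_MIN, .max = INT8_MAX },
    };
    const quant_data_t quant_data = { .shift = out_shift, .mult = out_mult };

    /* scratch negotiation: the backend reports its need (0 for the ANSI path) */
    if (esp_nn_get_conv_scratch_size_ansi(&input_dims, &filter_dims,
                                          &output_dims, &conv_params) > 0) {
        esp_nn_set_conv_scratch_buf_ansi(scratch);
    }
    esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter, bias,
                        &output_dims, output, &conv_params, &quant_data);
}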
+ +#pragma once + +#include + +/** + * @brief structure to club data dims + * this structure can be used for input, output and filter + */ +typedef struct data_dims { + int32_t width; + int32_t height; + int32_t channels; + + int32_t extra; // can be used as batch or any other param +} data_dims_t; + +/** + * @brief 2d data structure (width, height) + * + */ +typedef struct data_2d { + int32_t width; + int32_t height; +} data_2d_t; + +/** + * @brief min/max activation + */ +typedef struct act_params { + int32_t min; + int32_t max; +} act_params_t; + +/** + * @brief per channel quant data + * + * @note number of shift and mult elements are equal to output channels + */ +typedef struct quant_data { + int32_t *shift; + int32_t *mult; +} quant_data_t; + +/** + * @brief params specific to convolution 2d + * + */ +typedef struct conv_params { + int32_t in_offset; + int32_t out_offset; + data_2d_t stride; + data_2d_t padding; + data_2d_t dilation; + act_params_t activation; +} conv_params_t; + +/** + * @brief params specific to depthwise convolution 2d + * + */ +typedef struct dw_conv_params { + int32_t in_offset; + int32_t out_offset; + int32_t ch_mult; // channel multiplier. (in_ch * ch_mult = out_ch) + data_2d_t stride; + data_2d_t padding; + data_2d_t dilation; + act_params_t activation; +} dw_conv_params_t; diff --git a/code/components/esp-nn/include/esp_nn_esp32s3.h b/code/components/esp-nn/include/esp_nn_esp32s3.h index 58b544e4..0f52c943 100644 --- a/code/components/esp-nn/include/esp_nn_esp32s3.h +++ b/code/components/esp-nn/include/esp_nn_esp32s3.h @@ -19,7 +19,7 @@ #pragma once -#include +#include "esp_nn_defs.h" #include "esp_nn_ansi_headers.h" /************************** Basic math functions *****************************/ @@ -85,28 +85,15 @@ void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data, * optimization notes: Though input_offset is int32 type, * offset values are contained in 8 bits [-128, 127] */ -void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, - const uint16_t ch_mult, +void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, - int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max); + const data_dims_t *output_dims, + int8_t *output_data, + const dw_conv_params_t *conv_params, + const quant_data_t *quant_data); /** * @brief 2d - convolution channelwise @@ -116,43 +103,26 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data, * inputs type: int8_t, output: int8_t * input offsets: although int32_t, they are contained in 8 bits [-128, 127] */ -void esp_nn_conv_s8_esp32s3(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, +void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const 
uint16_t filter_ht,
                            const int32_t *bias,
-                           int8_t *out_data,
-                           const uint16_t out_wd,
-                           const uint16_t out_ht,
-                           const uint16_t out_channels,
-                           const int32_t out_offset,
-                           const int32_t *out_shift,
-                           const int32_t *out_mult,
-                           const int32_t activation_min,
-                           const int32_t activation_max);
+                           const data_dims_t *output_dims,
+                           int8_t *output_data,
+                           const conv_params_t *conv_params,
+                           const quant_data_t *quant_data);
 
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht);
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
 
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
 
 /************************** Pooling functions *****************************/
diff --git a/code/components/esp-nn/include/esp_nn_esp32.h b/code/components/esp-nn/include/esp_nn_generic_opt.h
similarity index 77%
rename from code/components/esp-nn/include/esp_nn_esp32.h
rename to code/components/esp-nn/include/esp_nn_generic_opt.h
index 03fd8216..136cba5d 100644
--- a/code/components/esp-nn/include/esp_nn_esp32.h
+++ b/code/components/esp-nn/include/esp_nn_generic_opt.h
@@ -13,28 +13,27 @@
 // limitations under the License.
 
 /**
- * @file Header definitions to include for esp_nn optimized functions for
- *       the ESP32 platform.
- *       We are hooking up just the C versions for now.
- *       The file hence is exactly same as `esp_nn_ansi_c.h`
+ * @file Header definitions to include for esp_nn generic optimisations.
+ *       For functions that do not have optimised versions, the _ansi implementations are picked.
  */
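The rename keeps platform selection a compile-time affair: call sites use the unprefixed names, and the #define block that follows routes them to the _opt implementations (or the _ansi fallbacks). A hypothetical caller, assuming the build pulls in this header:

#include "esp_nn_generic_opt.h"

void run_conv_layer(const data_dims_t *input_dims, const int8_t *input,
                    const data_dims_t *filter_dims, const int8_t *filter,
                    const int32_t *bias, const data_dims_t *output_dims,
                    int8_t *output, const conv_params_t *conv_params,
                    const quant_data_t *quant_data)
{
    /* with this header, the next call expands to esp_nn_conv_s8_opt(...) */
    esp_nn_conv_s8(input_dims, input, filter_dims, filter, bias,
                   output_dims, output, conv_params, quant_data);
}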
 
 #pragma once
 
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
 #define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
 
-#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_opt
 
-#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+#define esp_nn_conv_s8 esp_nn_conv_s8_opt
 
-#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
-#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_opt
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_opt
 
-#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
-#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_opt
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_opt
 
 #define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
 
diff --git a/code/components/esp-nn/src/common/common_functions.h b/code/components/esp-nn/src/common/common_functions.h
index 9a5f0dcc..0a74eca4 100644
--- a/code/components/esp-nn/src/common/common_functions.h
+++ b/code/components/esp-nn/src/common/common_functions.h
@@ -41,15 +41,39 @@
 
 __NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("nsau %0, %0" : "+r" (in));
     return in;
-}
-
-__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
-{
-    int32_t sign = (int32_t) (val64 >> 63);
-    int32_t to_add = sign & ((1ul << 31) - 1);
-    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
+#elif defined(__GNUC__)
+    return __builtin_clz(in);
+#else
+    int32_t count = 32;
+    uint32_t x = in, y = in >> 16;
+    if (y != 0) {
+        count -= 16;
+        x = y;
+    }
+    y = x >> 8;
+    if (y != 0) {
+        count -= 8;
+        x = y;
+    }
+    y = x >> 4;
+    if (y != 0) {
+        count -= 4;
+        x = y;
+    }
+    y = x >> 2;
+    if (y != 0) {
+        count -= 2;
+        x = y;
+    }
+    y = x >> 1;
+    if (y != 0) {
+        return count - 2;
+    }
+    return count - x;
+#endif
 }
 
 /**
@@ -57,8 +81,19 @@ __NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
  */
 __NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("clamps %0, %0, 7" : "+a"(in));
     return in;
+#else
+    return max(INT8_MIN, min(in, INT8_MAX));
+#endif
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
+{
+    int32_t sign = (int32_t) (val64 >> 63);
+    int32_t to_add = sign & ((1ul << 31) - 1);
+    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
 }
 
 __NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
@@ -144,7 +179,7 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
                                              const uint16_t pad_ht)
 {
     /* memset with pad_val */
-    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels * 2);
+    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
 
     dst += (pad_wd + input_wd + pad_wd) * channels;
     for (int i = 0; i < input_ht; i++) {
@@ -156,7 +191,6 @@
     }
 }
 
-#if 0
 static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t
*dst, for (int j = 0; j < input_wd * channels; j++) { *dst++ = *src++; } - memset(dst, pad_val, pad_wd * channels); - dst += pad_wd * channels; + if (pad_wd) { + memset(dst, pad_val, pad_wd * channels); + dst += pad_wd * channels; + } } /* pad end `pad_ht` lines at end */ - memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels); + if (pad_ht) { + memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels); + } } -#endif /** * @brief convert 8 bit input data to 16 bit diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c b/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c index d04f78e1..677c0ad8 100644 --- a/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c +++ b/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include #include -int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_ch, - const uint16_t out_ch, - const uint16_t filter_wd, - const uint16_t filter_ht) +int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const conv_params_t *conv_params) { return 0; } @@ -108,29 +106,35 @@ void esp_nn_conv_u8_ansi(const uint8_t *input_data, * Assumption 2: Pointers are valid * Assumption 3: dialation width = 1 */ -void esp_nn_conv_s8_ansi(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, +void esp_nn_conv_s8_ansi(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const uint16_t out_channels, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max) + const conv_params_t *conv_params, + const quant_data_t *quant_data) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t in_channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const uint16_t out_channels = output_dims->channels; + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx; for (out_y = 0; out_y < out_ht; out_y++) { diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c b/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c index ea8fdfa5..e13129b2 100644 --- 
a/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c +++ b/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c @@ -12,30 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include #include static int16_t *scratch_buffer = NULL; -extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_channels, - const int32_t input_offset, - const int16_t *filter_data, - const int32_t *bias, - int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const uint16_t out_channels, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max, - void *buffer /* scratch buffer */); +extern void esp_nn_conv_s8_mult8_1x1_esp32s3(const int8_t *input_data, + const uint16_t input_wd, + const uint16_t input_ht, + const uint16_t in_channels, + const int32_t input_offset, + const int8_t *filter_aligned, + const int32_t *bias, + int8_t *out_data, + const uint16_t out_wd, + const uint16_t out_ht, + const uint16_t out_channels, + const int32_t out_offset, + const int32_t *out_shift, + const int32_t *out_mult, + const int32_t activation_min, + const int32_t activation_max, + void *buffer /* scratch buffer */); extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data, const uint16_t input_wd, @@ -81,34 +81,40 @@ extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int1 extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size); -static void esp_nn_conv_s8_unrolled(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, +static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const uint16_t out_channels, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max) + const conv_params_t *conv_params, + const quant_data_t *quant_data) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t in_ch = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const uint16_t out_ch = output_dims->channels; + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx; for 
(out_y = 0; out_y < out_ht; out_y++) { for (out_x = 0; out_x < out_wd; out_x++) { - for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) { + for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) { int32_t conv_out = 0; const int32_t base_y = stride_ht * out_y - pad_ht; @@ -124,10 +130,10 @@ static void esp_nn_conv_s8_unrolled(const int8_t *input_data, for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) { const int32_t in_row = base_y + filter_y_idx; const int32_t in_col = base_x + filter_x_idx; - int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels; - int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd + - (filter_y_idx * filter_wd + filter_x_idx) * in_channels; - for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) { + int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch; + int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd + + (filter_y_idx * filter_wd + filter_x_idx) * in_ch; + for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) { conv_out += (input_data[input_base_offset + in_ch_idx] + input_offset) * filter_data[filter_base_offset + in_ch_idx]; @@ -332,18 +338,35 @@ static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data, } } -int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t in_ch, - const uint16_t out_ch, - const uint16_t filter_wd, - const uint16_t filter_ht) +int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const conv_params_t *conv_params) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t in_ch = input_dims->channels; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_ch = output_dims->channels; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + int filter_size = filter_wd * filter_ht * in_ch * out_ch; int input_size = input_wd * input_ht * in_ch; - int transpose_buf_size = 8 * in_ch; /* to store intermediate data */ + + int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */ + if (input_wd * input_ht < 8) { + transpose_buf_size = 0; // not using this for leftover + } int align_buf_size = 32; /* extra buffer for alignment */ - return 2 * (filter_size + input_size + transpose_buf_size) + align_buf_size; + if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 && + pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) { + return filter_size + transpose_buf_size + align_buf_size; + } + return 2 * (filter_size + input_size) + transpose_buf_size + align_buf_size; } void esp_nn_set_conv_scratch_buf_esp32s3(void *buf) @@ -351,29 +374,35 @@ void esp_nn_set_conv_scratch_buf_esp32s3(void *buf) scratch_buffer = (int16_t *) buf; } -void esp_nn_conv_s8_esp32s3(const int8_t *input, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, +void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims, + const int8_t *input, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t 
filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const uint16_t out_channels, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max) + const conv_params_t *conv_params, + const quant_data_t *quant_data) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const uint16_t out_channels = output_dims->channels; + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + int filter_size = filter_wd * filter_ht * channels * out_channels; int input_size = input_wd * input_ht * channels; int align_len = 16 - (filter_size & 15); @@ -387,15 +416,16 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input, if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 && pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) { - int scratch_offset = (int) (filter_data16 + filter_size); + int8_t *filter_aligned = (int8_t *) scratch_buffer; + int scratch_offset = (int) (filter_aligned + filter_size); void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15)); - esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size); - esp_nn_conv_s16_mult8_1x1_esp32s3( - input, input_wd, input_ht, channels, input_offset, filter_data16, + memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address + esp_nn_conv_s8_mult8_1x1_esp32s3( + input, input_wd, input_ht, channels, input_offset, filter_aligned, bias, out_data, out_wd, out_ht, out_channels, out_offset, out_shift, out_mult, activation_min, activation_max, scratch_buf); } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 && - (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */ + (input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */ pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) { int scratch_offset = (int) (input_data16 + input_size); void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15)); @@ -427,10 +457,7 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input, } } else { /* Basic unrolled version */ - esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset, - pad_wd, pad_ht, stride_wd, stride_ht, - filter_data, filter_wd, filter_ht, bias, - out_data, out_wd, out_ht, out_channels, out_offset, out_shift, - out_mult, activation_min, activation_max); + esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data, + bias, output_dims, out_data, conv_params, quant_data); } } diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_opt.c b/code/components/esp-nn/src/convolution/esp_nn_conv_opt.c new file mode 100644 index 00000000..be96430e --- /dev/null +++ 
b/code/components/esp-nn/src/convolution/esp_nn_conv_opt.c @@ -0,0 +1,179 @@ +// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const conv_params_t *conv_params) +{ + return 0; +} + +void esp_nn_set_conv_scratch_buf_opt(const void *buf) +{ + +} + +__attribute__ ((noinline)) +static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims, + const int8_t *input_data, + const int8_t *filter_data, + const int32_t *bias, + const data_dims_t *output_dims, + int8_t *out_data, + const conv_params_t *conv_params, + const quant_data_t *quant_data) +{ + const uint16_t input_wd = input_dims->width; + const uint16_t in_channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const uint16_t out_channels = output_dims->channels; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + + for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) { + for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) { + const int32_t *out_mult = quant_data->mult; + const int32_t *out_shift = quant_data->shift; + const int8_t *filter_ptr = filter_data; + const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels; + int32_t out_ch_idx = 0; + for (; out_ch_idx < out_channels; out_ch_idx++) { + int32_t conv_out = 0; + + const int8_t *input_ptr = input_base_ptr; + + int32_t in_ch_idx = 0; + for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) { + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + } + for (; in_ch_idx < in_channels; in_ch_idx ++) { + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + } + if (bias) { + conv_out += bias[out_ch_idx]; + } + conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++); + conv_out += out_offset; + conv_out = max(conv_out, activation_min); + conv_out = min(conv_out, activation_max); + *out_data++ = (int8_t) conv_out; + } + } + } +} + +/** + * Assumption 1: i/p channels == o/p channels + * Assumption 2: Pointers are valid + * Assumption 3: dialation width = 1 + */ +void esp_nn_conv_s8_opt(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, + const int8_t *filter_data, + const int32_t *bias, + const data_dims_t *output_dims, + int8_t *out_data, + const conv_params_t 
*conv_params, + const quant_data_t *quant_data) +{ + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + + if (filter_wd == 1 && filter_ht == 1) { + esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias, + output_dims, out_data, conv_params, quant_data); + return; + } + + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t in_channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const uint16_t out_channels = output_dims->channels; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + + int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx; + + for (out_y = 0; out_y < out_ht; out_y++) { + for (out_x = 0; out_x < out_wd; out_x++) { + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) { + int32_t conv_out = 0; + + const int32_t base_y = stride_ht * out_y - pad_ht; + const int32_t base_x = stride_wd * out_x - pad_wd; + + const int32_t filter_y_start = max(0, -base_y); + const int32_t filter_x_start = max(0, -base_x); + + const int32_t filter_y_end = min(filter_ht, input_ht - base_y); + const int32_t filter_x_end = min(filter_wd, input_wd - base_x); + + for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) { + for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) { + const int32_t in_row = base_y + filter_y_idx; + const int32_t in_col = base_x + filter_x_idx; + + const int8_t *input_ptr = input_data + + (in_row * input_wd + in_col) * in_channels; + const int8_t *filter_ptr = filter_data + + out_ch_idx * in_channels * filter_ht * filter_wd + + (filter_y_idx * filter_wd + filter_x_idx) * in_channels; + int32_t in_ch_idx = 0; + for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) { + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + } + for (; in_ch_idx < in_channels; in_ch_idx ++) { + conv_out += (*input_ptr++ + input_offset) * *filter_ptr++; + } + } + } + if (bias) { + conv_out += bias[out_ch_idx]; + } + conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++); + conv_out += out_offset; + conv_out = max(conv_out, activation_min); + conv_out = min(conv_out, activation_max); + *out_data++ = (int8_t) conv_out; + } + } + } +} diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c index 9cac6cef..1cd02e0f 100644 --- a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c +++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c @@ -12,16 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
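Every kernel in this patch, including the esp_nn_conv_s8_opt above and the depthwise versions below, requantizes its accumulator through esp_nn_multiply_by_quantized_mult_fast(acc, *out_mult, *out_shift). The helper itself is not part of this diff; the sketch below is only a portable reference for the TFLite-style fixed-point arithmetic such helpers implement, an approximation for orientation that ignores the INT32_MIN saturation corner of the doubling high multiply.

#include <stdint.h>

/* rounding right shift, ties rounded away from zero */
static int32_t rounding_rshift(int32_t x, int32_t exponent)
{
    if (exponent <= 0) {
        return x;
    }
    const int32_t mask = ((int32_t) 1 << exponent) - 1;
    const int32_t remainder = x & mask;
    const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
    return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

/* result ~= round(acc * mult * 2^(shift - 31)), with mult in [2^30, 2^31) */
static int32_t requantize_ref(int32_t acc, int32_t mult, int32_t shift)
{
    const int32_t left_shift  = shift > 0 ? shift : 0;
    const int32_t right_shift = shift > 0 ? 0 : -shift;
    /* doubling high multiply: rounded bits [62:31] of the 64-bit product */
    const int64_t prod  = ((int64_t) acc << left_shift) * (int64_t) mult;
    const int64_t nudge = prod >= 0 ? (1LL << 30) : (1 - (1LL << 30));
    const int32_t high  = (int32_t) ((prod + nudge) >> 31);
    return rounding_rshift(high, right_shift);
}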
-#include - +#include #include -int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const uint16_t ch_mult, - const uint16_t filter_wd, - const uint16_t filter_ht) +int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const dw_conv_params_t *conv_params) { return 0; } @@ -31,29 +28,35 @@ void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf) } -void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, - const uint16_t ch_mult, +void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max) + const dw_conv_params_t *conv_params, + const quant_data_t *quant_data) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + const uint16_t ch_mult = conv_params->ch_mult; + int out_idx = 0; for (int out_y = 0; out_y < out_ht; out_y++) { //height loop const int16_t base_y = (out_y * stride_ht) - pad_ht; diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c new file mode 100644 index 00000000..4afea3f3 --- /dev/null +++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c @@ -0,0 +1,291 @@ +// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
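The new esp_nn_depthwise_conv_opt.c that begins below leans on one indexing scheme throughout: NHWC input, filters stored as (filter_y, filter_x, in_channel * ch_mult), and out_ch = in_ch * ch_mult + m. A hypothetical helper condensing the index arithmetic the loops below use (all names here are illustrative, not part of the patch):

#include <stdint.h>

/* Accumulate one depthwise output element; mirrors the index math of the
 * kernels that follow. (base_y, base_x) is the top-left input coordinate
 * already adjusted for stride and padding, assumed in-bounds. */
static int32_t dw_accumulate_one(const int8_t *input, const int8_t *filter,
                                 int32_t input_wd, int32_t channels, int32_t ch_mult,
                                 int32_t filter_wd, int32_t filter_ht,
                                 int32_t base_y, int32_t base_x,
                                 int32_t ch_idx, int32_t ch_mult_idx,
                                 int32_t input_offset)
{
    /* each input channel fans out to ch_mult output channels */
    const int32_t out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
    int32_t acc = 0;
    for (int32_t fy = 0; fy < filter_ht; fy++) {
        for (int32_t fx = 0; fx < filter_wd; fx++) {
            const int32_t in_idx = ((base_y + fy) * input_wd + (base_x + fx)) * channels + ch_idx;
            const int32_t f_idx  = (fy * filter_wd + fx) * (channels * ch_mult) + out_ch_idx;
            acc += (input[in_idx] + input_offset) * filter[f_idx];
        }
    }
    return acc;   /* bias add, requantization and activation clamping follow */
}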
+ +#include +#include + +int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const dw_conv_params_t *conv_params) +{ + return 0; +} + +void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf) +{ + +} + +/* common channel multiplier == 1 case */ +__attribute__ ((noinline)) +static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, + const int8_t *filter_data, + const int32_t *bias, + const data_dims_t *output_dims, + int8_t *out_data, + const dw_conv_params_t *conv_params, + const quant_data_t *quant_data) +{ + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + + int out_idx = 0; + for (int out_y = 0; out_y < out_ht; out_y++) { //height loop + const int16_t base_y = (out_y * stride_ht) - pad_ht; + for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop + const int16_t base_x = (out_x * stride_wd) - pad_wd; + + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + + /* Select filter so as the point doesn't lie outside block */ + int filter_y_start = max(0, -base_y); + int filter_x_start = max(0, -base_x); + int filter_y_end = min(filter_ht, input_ht - base_y); + int filter_x_end = min(filter_wd, input_wd - base_x); + + int ch_idx = 0; + for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop + int32_t result0 = 0; + int32_t result1 = 0; + int32_t result2 = 0; + int32_t result3 = 0; + + for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) { + const int32_t idx_y = base_y + filter_y_idx; + for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) { + const int32_t idx_x = base_x + filter_x_idx; + int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx; + int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx; + int32_t input_val0 = input_data[input_index + 0] + input_offset; + int32_t input_val1 = input_data[input_index + 1] + input_offset; + int32_t input_val2 = input_data[input_index + 2] + input_offset; + int32_t input_val3 = input_data[input_index + 3] + input_offset; + int32_t filter_val0 = filter_data[filter_index + 0]; + int32_t filter_val1 = filter_data[filter_index + 1]; + int32_t filter_val2 = filter_data[filter_index + 2]; + int32_t filter_val3 = filter_data[filter_index + 3]; + result0 += input_val0 * filter_val0; + result1 += input_val1 * filter_val1; + result2 += input_val2 * filter_val2; + result3 += input_val3 * filter_val3; + } + } + if (bias) { + result0 += bias[ch_idx + 0]; + result1 += bias[ch_idx + 1]; + result2 += bias[ch_idx + 2]; + result3 += bias[ch_idx + 3]; + } + result0 = 
esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++); + result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++); + result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++); + result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++); + + result0 += out_offset; + result1 += out_offset; + result2 += out_offset; + result3 += out_offset; + + result0 = max(result0, activation_min); + result1 = max(result1, activation_min); + result2 = max(result2, activation_min); + result3 = max(result3, activation_min); + + result0 = min(result0, activation_max); + result1 = min(result1, activation_max); + result2 = min(result2, activation_max); + result3 = min(result3, activation_max); + + out_data[out_idx++] = result0; + out_data[out_idx++] = result1; + out_data[out_idx++] = result2; + out_data[out_idx++] = result3; + } + for (; ch_idx < channels; ch_idx++) {//channel_loop + int32_t result = 0; + + for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) { + const int32_t idx_y = base_y + filter_y_idx; + for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) { + const int32_t idx_x = base_x + filter_x_idx; + int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx; + int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx; + int32_t input_val = input_data[input_index] + input_offset; + int32_t filter_val = filter_data[filter_index]; + result += input_val * filter_val; + } + } + if (bias) { + result += bias[ch_idx]; + } + result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++); + result += out_offset; + result = max(result, activation_min); + result = min(result, activation_max); + + out_data[out_idx++] = result; + } + } + } +} + +void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, + const int8_t *filter_data, + const int32_t *bias, + const data_dims_t *output_dims, + int8_t *out_data, + const dw_conv_params_t *conv_params, + const quant_data_t *quant_data) +{ + const uint16_t ch_mult = conv_params->ch_mult; + if (ch_mult == 1) { + esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data, + bias, output_dims, out_data, conv_params, quant_data); + return; + } + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + + int out_idx = 0; + for (int out_y = 0; out_y < out_ht; out_y++) { //height loop + const int16_t base_y = (out_y * stride_ht) - pad_ht; + for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop + const int16_t base_x = (out_x * stride_wd) - pad_wd; + + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = 
quant_data->mult; + + /* Select filter so as the point doesn't lie outside block */ + int filter_y_start = max(0, -base_y); + int filter_x_start = max(0, -base_x); + int filter_y_end = min(filter_ht, input_ht - base_y); + int filter_x_end = min(filter_wd, input_wd - base_x); + + for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop + int ch_mult_idx = 0; + for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) { + int32_t result0 = 0; + int32_t result1 = 0; + int32_t result2 = 0; + int32_t result3 = 0; + const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx; + + for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) { + const int32_t idx_y = base_y + filter_y_idx; + for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) { + const int32_t idx_x = base_x + filter_x_idx; + int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx; + int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx; + int32_t input_val = input_data[input_index] + input_offset; + int32_t filter_val0 = filter_data[filter_index + 0]; + int32_t filter_val1 = filter_data[filter_index + 1]; + int32_t filter_val2 = filter_data[filter_index + 2]; + int32_t filter_val3 = filter_data[filter_index + 3]; + result0 += input_val * filter_val0; + result1 += input_val * filter_val1; + result2 += input_val * filter_val2; + result3 += input_val * filter_val3; + } + } + if (bias) { + result0 += bias[out_ch_idx + 0]; + result1 += bias[out_ch_idx + 1]; + result2 += bias[out_ch_idx + 2]; + result3 += bias[out_ch_idx + 3]; + } + result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++); + result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++); + result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++); + result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++); + + result0 += out_offset; + result1 += out_offset; + result2 += out_offset; + result3 += out_offset; + + result0 = max(result0, activation_min); + result1 = max(result1, activation_min); + result2 = max(result2, activation_min); + result3 = max(result3, activation_min); + result0 = min(result0, activation_max); + result1 = min(result1, activation_max); + result2 = min(result2, activation_max); + result3 = min(result3, activation_max); + + out_data[out_idx++] = result0; + out_data[out_idx++] = result1; + out_data[out_idx++] = result2; + out_data[out_idx++] = result3; + } + for (; ch_mult_idx < ch_mult; ch_mult_idx++) { + int32_t result = 0; + const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx; + + for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) { + const int32_t idx_y = base_y + filter_y_idx; + for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) { + const int32_t idx_x = base_x + filter_x_idx; + int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx; + int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx; + int32_t input_val = input_data[input_index] + input_offset; + int32_t filter_val = filter_data[filter_index]; + result += input_val * filter_val; + } + } + if (bias) { + result += bias[out_ch_idx]; + } + result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++); + result += out_offset; + result = max(result, activation_min); + result = min(result, activation_max); + + out_data[out_idx++] = 
result; + } + } + } + } +} diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c index c588c48f..9167a43f 100644 --- a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c +++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include #include @@ -353,17 +353,59 @@ void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data, } } -int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const uint16_t ch_mult, - const uint16_t filter_wd, - const uint16_t filter_ht) +int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims, + const data_dims_t *filter_dims, + const data_dims_t *output_dims, + const dw_conv_params_t *conv_params) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t channels = input_dims->channels; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t ch_mult = conv_params->ch_mult; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + int filter_size = filter_wd * filter_ht * channels * ch_mult; - int padding_used = ((filter_wd == 3) && (filter_ht == 3)) * 2; - int input_size = (input_wd + padding_used) * (input_ht + padding_used) * channels; - return 2 * (filter_size + input_size) + 16; //16 for alignment + int pad_width = 0, pad_height = 0; + + if ((ch_mult == 1) && (channels % 8 == 0) && (filter_wd == 3) && (filter_ht == 3)) { + if (channels % 16 == 0) { + if (pad_wd || pad_ht) { + pad_width = pad_wd * 2; + pad_height = pad_ht * 2; + } else { + // check if we need to pad additionally + pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd; + pad_height = (out_ht * stride_ht + filter_ht - 1) - input_ht; + // printf("in(%d %d %d), out(%d %d), filter (%d %d) stride (%d %d), pad (%d %d)", + // input_wd, input_ht, channels, out_wd, out_ht, filter_wd, filter_ht, + // stride_wd, stride_ht, pad_wd, pad_ht); + } + if (pad_width || pad_height) { + int input_size = (input_wd + pad_width) * (input_ht + pad_height) * channels; + // printf("ask1 %d\n", filter_size + input_size + 16); + return filter_size + input_size + 16; // 16 for alignment + } else { + // printf("ask2 %d\n", filter_size + 16); + return filter_size + 16; // 16 for alignment + } + } else { + int input_size = input_wd * input_ht * channels; + // printf("ask3 %d\n", 2 * (filter_size + input_size) + 16); + return 2 * (filter_size + input_size) + 16; // 16 for alignment + } + } else if (ch_mult % 4 == 0) { + int input_size = input_wd * input_ht * channels; + // printf("ask4 %d\n", 2 * (filter_size + input_size) + 16); + return 2 * (filter_size + input_size) + 16; // 16 for alignment + } + return 32; // just few bytes } void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf) @@ -376,29 +418,38 @@ void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf) * Assumption 2: Pointers are valid * Assumption 3: dialation width = 1 */ -void 
esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data, - const uint16_t input_wd, - const uint16_t input_ht, - const uint16_t channels, - const int32_t input_offset, - const uint16_t pad_wd, - const uint16_t pad_ht, - const uint16_t stride_wd, - const uint16_t stride_ht, - const uint16_t ch_mult, + + + +void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims, + const int8_t *input_data, + const data_dims_t *filter_dims, const int8_t *filter_data, - const uint16_t filter_wd, - const uint16_t filter_ht, const int32_t *bias, + const data_dims_t *output_dims, int8_t *out_data, - const uint16_t out_wd, - const uint16_t out_ht, - const int32_t out_offset, - const int32_t *out_shift, - const int32_t *out_mult, - const int32_t activation_min, - const int32_t activation_max) + const dw_conv_params_t *conv_params, + const quant_data_t *quant_data) { + const uint16_t input_wd = input_dims->width; + const uint16_t input_ht = input_dims->height; + const uint16_t channels = input_dims->channels; + const int32_t input_offset = conv_params->in_offset; + const int32_t out_offset = conv_params->out_offset; + const uint16_t pad_wd = conv_params->padding.width; + const uint16_t pad_ht = conv_params->padding.height; + const uint16_t stride_wd = conv_params->stride.width; + const uint16_t stride_ht = conv_params->stride.height; + const uint16_t filter_wd = filter_dims->width; + const uint16_t filter_ht = filter_dims->height; + const uint16_t out_wd = output_dims->width; + const uint16_t out_ht = output_dims->height; + const int32_t *out_shift = quant_data->shift; + const int32_t *out_mult = quant_data->mult; + const int32_t activation_min = conv_params->activation.min; + const int32_t activation_max = conv_params->activation.max; + const uint16_t ch_mult = conv_params->ch_mult; + int filter_size = filter_wd * filter_ht * channels * ch_mult; int align_len = 16 - (filter_size & 15); int input_size = input_wd * input_ht * channels; @@ -423,18 +474,27 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data, stride_wd, stride_ht, filter_aligned, bias, out_data, out_wd, out_ht, out_offset, out_shift, out_mult, activation_min, activation_max); - } else if ((pad_wd == 0) && (pad_ht == 0) && - // because this does not handle padding offset cases yet, run just for stride (1, 1). 
- // end padding of input with `-input_offset` should solve this - (stride_wd == 1) && (stride_ht == 1)) { + } else if ((channels % 16 == 0) && (pad_wd == 0) && (pad_ht == 0)) { /* process in 8 bits */ int8_t *filter_aligned = (int8_t *) scratch_buffer; + int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len; + + // check if we need to pad additionally + int pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd; + int pad_bottom = (out_ht * stride_ht + filter_ht - 1) - input_ht; + if (pad_right || pad_bottom) { // pad right and bottom + esp_nn_aligned_s8_pad_end_with_value(input_data, input_padded, input_wd, input_ht, + channels, -input_offset, pad_right, pad_bottom); + } else { + input_padded = (int8_t *) input_data; + } memcpy(filter_aligned, filter_data, filter_size); - esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_data, input_wd, input_ht, channels, input_offset, - stride_wd, stride_ht, filter_aligned, - bias, out_data, out_wd, out_ht, out_offset, out_shift, + esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + pad_right, + input_ht + pad_bottom, channels, input_offset, + stride_wd, stride_ht, filter_aligned, bias, + out_data, out_wd, out_ht, out_offset, out_shift, out_mult, activation_min, activation_max); - } else { /* (channels % 8) == 0 && pad_wd == 1 && pad_ht == 1 */ + } else { /* (channels % 8) == 0 */ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size); esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset); esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels, diff --git a/code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3 b/code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3 new file mode 100644 index 00000000..1adc4b01 --- /dev/null +++ b/code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3 @@ -0,0 +1,8 @@ +# Default configurations for ESP32-S3 + +CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y +CONFIG_ESP32S3_SPIRAM_SUPPORT=y + +CONFIG_ESP32S3_DATA_CACHE_64KB=y +CONFIG_ESP32S3_DATA_CACHE_8WAYS=y +CONFIG_ESP32S3_DATA_CACHE_LINE_64B=y diff --git a/code/components/esp-nn/tests/src/basic_math_test.c b/code/components/esp-nn/tests/src/basic_math_test.c index 5b96b990..715d7c78 100644 --- a/code/components/esp-nn/tests/src/basic_math_test.c +++ b/code/components/esp-nn/tests/src/basic_math_test.c @@ -23,7 +23,9 @@ #include "test_utils.h" #if CONFIG_IDF_CMAKE +#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC)) #define IDF_HEAP_CAPS 1 +#endif #if IDF_HEAP_CAPS #include "esp_heap_caps.h" @@ -138,6 +140,11 @@ void esp_nn_add_elementwise_s8_test() out_c_orig = out_data_c; out_opt_orig = out_data_opt; #endif + if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL || + out_opt_orig == NULL) { + printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__); + goto elementwise_add_test_cleanup; + } for (int i = 0; i < size; ++i) { input1[i] = rand() % 256 - 128; @@ -194,10 +201,10 @@ elementwise_add_test_cleanup: if (input2_orig) { free(input2_orig); } - if (out_data_c) { + if (out_c_orig) { free(out_c_orig); } - if (out_data_opt) { + if (out_opt_orig) { free(out_opt_orig); } } @@ -282,6 +289,11 @@ void esp_nn_mul_elementwise_s8_test() out_c_orig = out_data_c; out_opt_orig = out_data_opt; #endif + if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL || + out_opt_orig == NULL) { + printf(ANSI_COLOR_RED"%s error allocating 
buffers\n"ANSI_COLOR_RESET, __FUNCTION__); + goto elementwise_mult_test_cleanup; + } for (int i = 0; i < size; ++i) { input1[i] = rand() % 256 - 128; @@ -333,10 +345,10 @@ elementwise_mult_test_cleanup: if (input2_orig) { free(input2_orig); } - if (out_data_c) { + if (out_c_orig) { free(out_c_orig); } - if (out_data_opt) { + if (out_opt_orig) { free(out_opt_orig); } } diff --git a/code/components/esp-nn/tests/src/convolution_test.c b/code/components/esp-nn/tests/src/convolution_test.c index f3802257..c86bdbab 100644 --- a/code/components/esp-nn/tests/src/convolution_test.c +++ b/code/components/esp-nn/tests/src/convolution_test.c @@ -22,8 +22,9 @@ #include "test_utils.h" #if CONFIG_IDF_CMAKE +#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC)) #define IDF_HEAP_CAPS 1 - +#endif #if IDF_HEAP_CAPS #include "esp_heap_caps.h" #endif @@ -44,8 +45,8 @@ void esp_nn_depthwise_conv_s8_test() uint16_t filter_ht, filter_wd, ch_mult; uint16_t pad_wd, pad_ht, stride_wd, stride_ht; - // run for 10 iterations - for (int itr = 0; itr < 10; itr++) { + // run for 15 iterations + for (int itr = 0; itr < 15; itr++) { /* prepare data */ switch (itr) { case 0: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0) @@ -144,22 +145,52 @@ void esp_nn_depthwise_conv_s8_test() stride_wd = 2; stride_ht = 2; break; - default: - input_wd = 4; - input_ht = 4; + case 8: // same as case 7, with large parameters + input_wd = 58; + input_ht = 58; filter_ht = 3; filter_wd = 3; - ch_mult = 4; - channels = 4; - pad_wd = 1; - pad_ht = 1; - stride_wd = 1; - stride_ht = 1; + ch_mult = 1; + channels = 128; + pad_wd = 0; + pad_ht = 0; + stride_wd = 2; + stride_ht = 2; + break; + case 9: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0) stride (2,2) + input_wd = 6; + input_ht = 6; + filter_ht = 3; + filter_wd = 3; + ch_mult = 1; + channels = 16; + pad_wd = 0; + pad_ht = 0; + stride_wd = 2; + stride_ht = 2; + break; + default: + input_wd = 6; + input_ht = 6; + filter_ht = 3; + filter_wd = 3; + ch_mult = 1; + channels = 16; + stride_wd = rand() % 2 + 1; + stride_ht = stride_wd; + pad_wd = stride_wd == 1 ? 
0 : rand() % 2; + pad_ht = pad_wd; + printf("stride(%d), pad (%d)\t", stride_wd, pad_wd); break; } uint16_t out_wd = (input_wd - filter_wd + 1) / stride_wd; uint16_t out_ht = (input_ht - filter_ht + 1) / stride_ht; + if (itr == 9) { + // expect the function to handle this gracefully + out_wd += 1; + out_ht += 1; + } int in_size = input_wd * input_ht * channels; int out_size = out_wd * out_ht * channels * ch_mult; int filter_size = filter_wd * filter_ht * channels * ch_mult + 4; @@ -210,9 +241,16 @@ void esp_nn_depthwise_conv_s8_test() out_mult[i] = 0x7eb0e200 + rand() % 50; } - int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(input_wd, input_ht, - channels, ch_mult, - filter_wd, filter_ht); + data_dims_t input_dims = {.width = input_wd, .height = input_ht, .channels = channels, 1}; + data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = channels * ch_mult, 1}; + data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0}; + dw_conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset, .ch_mult = ch_mult, + .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht}, + .dilation = {0, 0}, .activation = {activation_min, activation_max}}; + quant_data_t quant_data = {.shift = out_shift, .mult = out_mult}; + + int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(&input_dims, &filter_dims, + &output_dims, &conv_params); if (scratch_buf_size > 0) { #if IDF_HEAP_CAPS scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT); @@ -234,11 +272,8 @@ void esp_nn_depthwise_conv_s8_test() } /* C function */ - esp_nn_depthwise_conv_s8_ansi(input, input_wd, input_ht, channels, input_offset, - pad_wd, pad_ht, stride_wd, stride_ht, ch_mult, - filter_data + 4, filter_wd, filter_ht, - bias + 1, out_data_c, out_wd, out_ht, out_offset, out_shift, - out_mult, activation_min, activation_max); + esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 4, + bias + 1, &output_dims, out_data_c, &conv_params, &quant_data); if (itr == 0) { profile_c_end(); @@ -246,11 +281,8 @@ void esp_nn_depthwise_conv_s8_test() } /* Optimized function */ - esp_nn_depthwise_conv_s8(input, input_wd, input_ht, channels, input_offset, - pad_wd, pad_ht, stride_wd, stride_ht, ch_mult, - filter_data + 4, filter_wd, filter_ht, - bias + 1, out_data_opt, out_wd, out_ht, out_offset, out_shift, - out_mult, activation_min, activation_max); + esp_nn_depthwise_conv_s8(&input_dims, input, &filter_dims, filter_data + 4, + bias + 1, &output_dims, out_data_opt, &conv_params, &quant_data); if (itr == 0) { /* disable profiler */ @@ -479,8 +511,16 @@ void esp_nn_conv_s8_test() out_mult[i] = 0x7f67f4f8 + rand() % 50; } - int scratch_buf_size = esp_nn_get_conv_scratch_size(in_wd, in_ht, in_channels, - out_channels, filter_wd, filter_ht); + data_dims_t input_dims = {.width = in_wd, .height = in_ht, .channels = in_channels, 1}; + data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = out_channels, 1}; + data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0}; + conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset, + .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht}, + .dilation = {0, 0}, .activation = {activation_min, activation_max}}; + quant_data_t quant_data = {.shift = out_shift, .mult = out_mult}; + + int scratch_buf_size = esp_nn_get_conv_scratch_size(&input_dims, &filter_dims, + &output_dims, &conv_params); if (scratch_buf_size > 0) { #if 
IDF_HEAP_CAPS void *scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT); @@ -502,11 +542,8 @@ void esp_nn_conv_s8_test() } /* C function */ - esp_nn_conv_s8_ansi(input, in_wd, in_ht, in_channels, input_offset, - pad_wd, pad_ht, stride_wd, stride_ht, - filter_data + 2, filter_wd, filter_ht, bias, - out_data_c, out_wd, out_ht, out_channels, out_offset, out_shift, - out_mult, activation_min, activation_max); + esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 2, + bias, &output_dims, out_data_c, &conv_params, &quant_data); if (itr == 0) { profile_c_end(); @@ -514,11 +551,8 @@ void esp_nn_conv_s8_test() } /* Optimized function */ - esp_nn_conv_s8(input, in_wd, in_ht, in_channels, input_offset, - pad_wd, pad_ht, stride_wd, stride_ht, - filter_data + 2, filter_wd, filter_ht, bias, - out_data_opt, out_wd, out_ht, out_channels, out_offset, out_shift, - out_mult, activation_min, activation_max); + esp_nn_conv_s8(&input_dims, input, &filter_dims, filter_data + 2, + bias, &output_dims, out_data_opt, &conv_params, &quant_data); if (itr == 0) { /* disable profiler */ diff --git a/code/components/esp-nn_20220716.zip b/code/components/esp-nn_20220716.zip new file mode 100644 index 00000000..53c7bef2 Binary files /dev/null and b/code/components/esp-nn_20220716.zip differ diff --git a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp index 72bcd63b..a3d47753 100644 --- a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp +++ b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp @@ -756,7 +756,7 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time) _fit = _val + _valminus; } - if (result > 10) + if (result >= 10) result = result - 10; if (result < 0) result = result + 10; diff --git a/code/components/tflite-lib/CMakeLists.txt b/code/components/tflite-lib/CMakeLists.txt index eed31a57..ab666ce0 100644 --- a/code/components/tflite-lib/CMakeLists.txt +++ b/code/components/tflite-lib/CMakeLists.txt @@ -25,7 +25,8 @@ list(REMOVE_ITEM srcs_kernels "${tfmicro_kernels_dir}/depthwise_conv.cc" "${tfmicro_kernels_dir}/fully_connected.cc" "${tfmicro_kernels_dir}/mul.cc" - "${tfmicro_kernels_dir}/pooling.cc") + "${tfmicro_kernels_dir}/pooling.cc" + "${tfmicro_kernels_dir}/softmax.cc") FILE(GLOB esp_nn_kernels "${tfmicro_kernels_dir}/esp_nn/*.cc") @@ -38,6 +39,8 @@ set(lib_srcs "${tflite_dir}/kernels/kernel_util.cc" "${tflite_dir}/micro/memory_planner/greedy_memory_planner.cc" "${tflite_dir}/micro/memory_planner/linear_memory_planner.cc" + "${tflite_dir}/micro/arena_allocator/recording_simple_memory_allocator.cc" + "${tflite_dir}/micro/arena_allocator/simple_memory_allocator.cc" "${tflite_dir}/c/common.cc" "${tflite_dir}/core/api/error_reporter.cc" "${tflite_dir}/core/api/flatbuffer_conversions.cc" diff --git a/code/components/tflite-lib/tensorflow/lite/builtin_ops.h b/code/components/tflite-lib/tensorflow/lite/builtin_ops.h index 19ce3e2c..67014928 100644 --- a/code/components/tflite-lib/tensorflow/lite/builtin_ops.h +++ b/code/components/tflite-lib/tensorflow/lite/builtin_ops.h @@ -179,6 +179,8 @@ typedef enum { kTfLiteBuiltinMultinomial = 149, kTfLiteBuiltinGelu = 150, kTfLiteBuiltinDynamicUpdateSlice = 151, + kTfLiteBuiltinRelu0To1 = 152, + kTfLiteBuiltinUnsortedSegmentProd = 153, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h b/code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h index 
7f160972..b8fdb7d1 100644 --- a/code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h +++ b/code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h @@ -518,6 +518,9 @@ typedef struct { bool approximate; } TfLiteGeluParams; +typedef struct { + int num_segments; +} TfLiteUnsortedSegmentProdParams; #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h b/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h index d2524969..d947213b 100644 --- a/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h +++ b/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h @@ -113,7 +113,13 @@ typedef struct TfLiteQuantizationParams { } TfLiteQuantizationParams; // -------------------------------------------------------------------------- -// Opaque types used by c_api_opaque.h. +// Opaque types used by c_api.h, c_api_opaque.h and common.h. + +// TfLiteOpaqueContext is an opaque version of TfLiteContext; +typedef struct TfLiteOpaqueContext TfLiteOpaqueContext; + +// TfLiteOpaqueNode is an opaque version of TfLiteNode; +typedef struct TfLiteOpaqueNode TfLiteOpaqueNode; // TfLiteOpaqueTensor is an opaque version of TfLiteTensor; typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor; diff --git a/code/components/tflite-lib/tensorflow/lite/c/common.cc b/code/components/tflite-lib/tensorflow/lite/c/common.cc index 956e9d69..8548424d 100644 --- a/code/components/tflite-lib/tensorflow/lite/c/common.cc +++ b/code/components/tflite-lib/tensorflow/lite/c/common.cc @@ -14,13 +14,33 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/c/c_api_types.h" +#ifdef TF_LITE_TENSORFLOW_PROFILER +#include + +#include "tensorflow/lite/core/macros.h" +#include "tensorflow/lite/tensorflow_profiler_logger.h" +#endif #ifndef TF_LITE_STATIC_MEMORY #include #include #endif // TF_LITE_STATIC_MEMORY +#ifdef TF_LITE_TENSORFLOW_PROFILER +namespace tflite { +// Use weak symbols here (even though they are guarded by macros) to avoid +// build breakage when building a benchmark requires TFLite runs. The main +// benchmark library should have tensor_profiler_logger dependency. 
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor, + size_t num_bytes); + +TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor); +} // namespace tflite + +#endif // TF_LITE_TENSORFLOW_PROFILER + extern "C" { size_t TfLiteIntArrayGetSizeInBytes(int size) { @@ -99,7 +119,12 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); } void TfLiteTensorDataFree(TfLiteTensor* t) { if (t->allocation_type == kTfLiteDynamic || t->allocation_type == kTfLitePersistentRo) { - free(t->data.raw); + if (t->data.raw) { +#ifdef TF_LITE_TENSORFLOW_PROFILER + tflite::OnTfLiteTensorDealloc(t); +#endif + free(t->data.raw); + } } t->data.raw = nullptr; } @@ -161,7 +186,7 @@ void TfLiteTensorFree(TfLiteTensor* t) { t->dims = nullptr; if (t->dims_signature) { - TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature); + TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature); } t->dims_signature = nullptr; @@ -191,16 +216,12 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, } TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) { - if (!src || !dst) - return kTfLiteOk; - if (src->bytes != dst->bytes) - return kTfLiteError; - if (src == dst) - return kTfLiteOk; + if (!src || !dst) return kTfLiteOk; + if (src->bytes != dst->bytes) return kTfLiteError; + if (src == dst) return kTfLiteOk; dst->type = src->type; - if (dst->dims) - TfLiteIntArrayFree(dst->dims); + if (dst->dims) TfLiteIntArrayFree(dst->dims); dst->dims = TfLiteIntArrayCopy(src->dims); memcpy(dst->data.raw, src->data.raw, src->bytes); dst->buffer_handle = src->buffer_handle; @@ -218,8 +239,17 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { // TODO(b/145340303): Tensor data should be aligned. if (!tensor->data.raw) { tensor->data.raw = (char*)malloc(num_bytes); +#ifdef TF_LITE_TENSORFLOW_PROFILER + tflite::OnTfLiteTensorAlloc(tensor, num_bytes); +#endif } else if (num_bytes > tensor->bytes) { +#ifdef TF_LITE_TENSORFLOW_PROFILER + tflite::OnTfLiteTensorDealloc(tensor); +#endif tensor->data.raw = (char*)realloc(tensor->data.raw, num_bytes); +#ifdef TF_LITE_TENSORFLOW_PROFILER + tflite::OnTfLiteTensorAlloc(tensor, num_bytes); +#endif } tensor->bytes = num_bytes; } diff --git a/code/components/tflite-lib/tensorflow/lite/c/common.h b/code/components/tflite-lib/tensorflow/lite/c/common.h index 6a109e1e..8b8ffbe8 100644 --- a/code/components/tflite-lib/tensorflow/lite/c/common.h +++ b/code/components/tflite-lib/tensorflow/lite/c/common.h @@ -173,9 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a); } \ } while (false) #else // TF_LITE_STRIP_ERROR_STRINGS -#define UNUSED(...) (void)sizeof(#__VA_ARGS__) -#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__) -#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__) +#define ARGS_UNUSED(...) (void)sizeof(#__VA_ARGS__) +#define TF_LITE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__) +#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__) #endif // TF_LITE_STRIP_ERROR_STRINGS // Check whether value is true, and if not return kTfLiteError from @@ -842,6 +842,32 @@ typedef struct TfLiteContext { size_t* bytes); } TfLiteContext;
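Editor's note: the weak `OnTfLiteTensorAlloc`/`OnTfLiteTensorDealloc` symbols in the common.cc hunk above let a profiler observe tensor allocations without forcing every binary to link one. A hedged sketch of what a strong override might look like, assuming `TF_LITE_TENSORFLOW_PROFILER` is defined at build time; the printf logging is a stand-in, not the real tensorflow_profiler_logger.

```cpp
// Hedged sketch: strong definitions that would satisfy the weak hooks
// declared above. The logging bodies are illustrative only.
#include <cstdio>

#include "tensorflow/lite/c/common.h"

namespace tflite {

void OnTfLiteTensorAlloc(TfLiteTensor* tensor, size_t num_bytes) {
  // A real profiler would record this event; printf stands in here.
  std::printf("tensor %p: allocated %zu bytes\n", static_cast<void*>(tensor),
              num_bytes);
}

void OnTfLiteTensorDealloc(TfLiteTensor* tensor) {
  std::printf("tensor %p: freed\n", static_cast<void*>(tensor));
}

}  // namespace tflite
```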
+// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration` +// for the C API which doesn't use internal types (such as `TfLiteContext`) but only +// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each +// field is exactly the same as with `TfLiteRegistration`. +typedef struct TfLiteRegistrationExternal { + // Custom op name. + const char* custom_name; + + // The version of the op. The version should be higher than 0. + const int version; + + // Initializes the op from serialized data. + void* (*init)(TfLiteOpaqueContext* context, const char* buffer, + size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteOpaqueContext* context, void* buffer); + + // Called when the inputs that this node depends on have been resized. + TfLiteStatus (*prepare)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node); + + // Called when the node is executed (should read node->inputs and write to + // node->outputs). + TfLiteStatus (*invoke)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node); +} TfLiteRegistrationExternal; + typedef struct TfLiteRegistration { // Initializes the op from serialized data. // Called only *once* for the lifetime of the op, so any one-time allocations @@ -903,8 +929,31 @@ typedef struct TfLiteRegistration { // Note: It is the responsibility of the registration binder to set this // properly. int version; + + // The external version of `TfLiteRegistration`. Since we can't use internal + // types (such as `TfLiteContext`) in the C API and still maintain ABI stability, + // a C API user provides `TfLiteRegistrationExternal` to implement custom + // ops. We keep it inside `TfLiteRegistration` and use it to route + // callbacks properly. + TfLiteRegistrationExternal* registration_external; } TfLiteRegistration; +// Old version of `TfLiteRegistration` to maintain binary backward +// compatibility. +// WARNING: This structure is deprecated / not an official part of the API. +// It should be only used for binary backward compatibility. +typedef struct TfLiteRegistration_V1 { + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + void (*free)(TfLiteContext* context, void* buffer); + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + int32_t builtin_code; + const char* custom_name; + int version; +} TfLiteRegistration_V1; + // The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the // values should be 1, 2, 4, 8, ...etc. typedef enum TfLiteDelegateFlags { diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc index e92d754f..5175d903 100644 --- a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -836,6 +836,16 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } + case BuiltinOperator_UNSORTED_SEGMENT_PROD: { + auto params = safe_allocator.Allocate<TfLiteUnsortedSegmentProdParams>(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* unsorted_segment_prod_params = + op->builtin_options_as_UnsortedSegmentProdOptions()) { + params->num_segments = unsorted_segment_prod_params->num_segments(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } // Below are the ops with no builtin_data structure. // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are // ok for now, since there is no call implementation either.
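Editor's note: to make the new `TfLiteRegistrationExternal` contract concrete, here is a hedged sketch of a custom-op registration against the struct exactly as declared above; `MyCustomOpPrepare`/`MyCustomOpInvoke` and the op name are hypothetical placeholders.

```cpp
// Hedged sketch: populating TfLiteRegistrationExternal for a custom op.
// All names are hypothetical; only the struct layout comes from the patch.
#include "tensorflow/lite/c/common.h"

static TfLiteStatus MyCustomOpPrepare(TfLiteOpaqueContext* context,
                                      TfLiteOpaqueNode* node) {
  return kTfLiteOk;  // shape/type checks would go here
}

static TfLiteStatus MyCustomOpInvoke(TfLiteOpaqueContext* context,
                                     TfLiteOpaqueNode* node) {
  return kTfLiteOk;  // read node inputs, write node outputs
}

// Field order follows the struct definition above; `version` is const and
// must be set at initialization (and be higher than 0).
static TfLiteRegistrationExternal my_custom_op = {
    /*custom_name=*/"MY_CUSTOM_OP",
    /*version=*/1,
    /*init=*/nullptr,
    /*free=*/nullptr,
    /*prepare=*/MyCustomOpPrepare,
    /*invoke=*/MyCustomOpInvoke,
};
```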
@@ -848,6 +858,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_MATRIX_DIAG: case BuiltinOperator_MATRIX_SET_DIAG: case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RELU_0_TO_1: case BuiltinOperator_SELECT: case BuiltinOperator_SELECT_V2: case BuiltinOperator_SLICE: diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h b/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h index 49ac778e..cec1f2dd 100644 --- a/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h +++ b/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h @@ -23,6 +23,16 @@ limitations under the License. #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/schema/schema_generated.h" +// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate. +// This is used for cases (e.g. when using "TF Lite with Google Play Services") +// where the TF Lite runtime might be built using a newer (or older) +// version of the TF Lite sources than the app, and hence might have a +// different definition of the TfLiteDelegate type. TF Lite APIs use +// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to +// refer to a delegate defined with that potentially different version +// of the TfLiteDelegate type. +struct TfLiteOpaqueDelegateStruct; + namespace tflite { /// Abstract interface that returns TfLiteRegistrations given op codes or custom @@ -37,8 +47,10 @@ class OpResolver { virtual const TfLiteRegistration* FindOp(const char* op, int version) const = 0; + // Represents a sequence of delegates. using TfLiteDelegatePtrVector = std::vector>; + // Returns optional delegates for resolving and handling ops in the flatbuffer // model. This may be used in addition to the standard TfLiteRegistration // lookup for graph resolution. @@ -47,16 +59,55 @@ class OpResolver { return {}; } - // Represent a function that creates a TfLite delegate instance. + // Represents a function that creates a TfLite delegate instance. using TfLiteDelegateCreator = std::function( int /*num_threads*/)>; + + // Represents a sequence of delegate creator functions. using TfLiteDelegateCreators = std::vector; + // Returns a vector of delegate creators to create optional delegates for // resolving and handling ops in the flatbuffer model. This may be used in // addition to the standard TfLiteRegistration lookup for graph resolution. + // + // Note that this method is not used (will not be called) if you are using + // TF Lite in Google Play Services; the GetOpaqueDelegateCreators method + // (see below) is used for that case. virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; } + // TODO(b/202712825): it would be nice if we could avoid the need for separate + // "opaque" types & methods for use only with TF Lite in Google Play Services. + + // Represents an opaque delegate instance. + // WARNING: Experimental interface, subject to change. + using TfLiteOpaqueDelegatePtr = + std::unique_ptr; + + // Represents a function that creates an opaque delegate instance. + // WARNING: Experimental interface, subject to change. + using TfLiteOpaqueDelegateCreator = + std::function; + + // Represents a sequence of opaque delegate creator functions. + // WARNING: Experimental interface, subject to change. + using TfLiteOpaqueDelegateCreators = std::vector; + + // Returns a vector of opaque delegate creators to create optional opaque + // delegates for resolving and handling ops in the flatbuffer model. 
This may + be used in addition to the standard TfLiteRegistration lookup for graph + // resolution. + // + // Note that this method will be called only if you are using TF Lite in + // Google Play Services; if you are using regular TF Lite, GetDelegateCreators + // (see above) is used instead. + // + // WARNING: Experimental interface, subject to change. + virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const { + return {}; + } + virtual ~OpResolver() {} private: diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h index cda1b5cf..b1204cc5 100644 --- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h +++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h @@ -23,9 +23,9 @@ namespace tflite { namespace reference_ops { inline int16_t SaturatingLeftShift(int16_t value, int amount) { - int32_t result = static_cast(value) * (1 << amount); - result = std::min(result, std::numeric_limits::max()); - result = std::max(result, std::numeric_limits::min()); + int64_t result = static_cast(value) * (1 << amount); + result = std::min(result, std::numeric_limits::max()); + result = std::max(result, std::numeric_limits::min()); return result; } diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h index 13693643..c2678b57 100644 --- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h +++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h @@ -27,6 +27,11 @@ class RuntimeShape { public: RuntimeShape& operator=(RuntimeShape const&) = delete; + // RuntimeShape in TFLM supports up to 5 dimensions. + // The name kMaxSmallSize comes from the same file in the upstream + // tensorflow lite repo and needs to be kept the same for max reuse. + static constexpr int kMaxSmallSize = 5; + RuntimeShape() : size_(0) {} explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {} @@ -104,11 +109,9 @@ class RuntimeShape { sizeof(int32_t) * shape.DimensionsCount()); } - // A maximum of 4 dimensions are supported on TFLM.
- static constexpr int kMaxSize = 5; int32_t size_; union { - int32_t dims_[kMaxSize]; + int32_t dims_[kMaxSmallSize]; }; }; diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h index 77644bc0..c44ba48e 100644 --- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h +++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h @@ -974,11 +974,11 @@ struct StridedSliceParams { int8_t strides_count; int32_t strides[5]; - int16_t begin_mask; - int16_t ellipsis_mask; - int16_t end_mask; - int16_t new_axis_mask; - int16_t shrink_axis_mask; + uint16_t begin_mask; + uint16_t ellipsis_mask; + uint16_t end_mask; + uint16_t new_axis_mask; + uint16_t shrink_axis_mask; }; struct TanhParams { diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h index 22689436..ed3a566f 100644 --- a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h +++ b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h @@ -308,7 +308,7 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, const TfLiteTensor* input3, TfLiteIntArray** output_shape); -// Return the size of given type in bytes. Return 0 in in case of string. +// Return the size of given type in bytes. Return 0 in case of string. int TfLiteTypeGetSize(TfLiteType type); // Whether the current platform is mobile (Android or iOS). diff --git a/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h similarity index 95% rename from code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h index 3767cb9f..b92d6b2d 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_ -#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_ +#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_ +#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_ #include #include @@ -97,4 +97,4 @@ class INonPersistentBufferAllocator { } // namespace tflite -#endif // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_ +#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc new file mode 100644 index 00000000..0f75d286 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc @@ -0,0 +1,165 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h" + +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_error_reporter.h" + +namespace tflite { + +NonPersistentArenaBufferAllocator::NonPersistentArenaBufferAllocator( + uint8_t* buffer, size_t buffer_size) + : buffer_head_(buffer), + buffer_tail_(buffer + buffer_size), + head_temp_(buffer), + next_temp_(buffer) {} + +NonPersistentArenaBufferAllocator::~NonPersistentArenaBufferAllocator() {} + +// Allocates a temporary buffer. This buffer is not resizable. +uint8_t* NonPersistentArenaBufferAllocator::AllocateTemp(size_t size, + size_t alignment) { + uint8_t* const aligned_result = AlignPointerUp(next_temp_, alignment); + const size_t available_memory = buffer_tail_ - aligned_result; + if (available_memory < size) { + MicroPrintf( + "Failed to allocate temp memory. Requested: %u, " + "available %u, missing: %u", + size, available_memory, size - available_memory); + return nullptr; + } + next_temp_ = aligned_result + size; + temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(aligned_result); + temp_buffer_count_++; + return aligned_result; +} + +// Signals that a temporary buffer is no longer needed. +void NonPersistentArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) { + temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(temp_buf); + temp_buffer_count_--; +} + +// Returns true if all temporary buffers are already deallocated. +bool NonPersistentArenaBufferAllocator::IsAllTempDeallocated() { + if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) { + MicroPrintf( + "Number of allocated temp buffers: %d. Checksum passing status: %d", + temp_buffer_count_, !temp_buffer_ptr_check_sum_); + return false; + } + return true; +} + +// Signals that all temporary allocations can be reclaimed. TFLM calls this +// API when it knows that all temporary buffers that it requested have been +// deallocated. The goal of this API is to allow implementations of +// INonPersistentBufferAllocator to reuse the buffer with reasonable +// complexity. +TfLiteStatus NonPersistentArenaBufferAllocator::ResetTempAllocations() { + if (!IsAllTempDeallocated()) { + MicroPrintf( + "All temp buffers must be freed before calling ResetTempAllocations()"); + return kTfLiteError; + } + next_temp_ = head_temp_; + return kTfLiteOk; +} + +// Returns a buffer that is resizable via ResizeBuffer(). +uint8_t* NonPersistentArenaBufferAllocator::AllocateResizableBuffer( + size_t size, size_t alignment) { + // Only supports one resizable buffer, which starts at the buffer head. + uint8_t* expected_resizable_buf = AlignPointerUp(buffer_head_, alignment); + + if (head_temp_ != expected_resizable_buf) { + MicroPrintf( + "Cannot allocate a new resizable buffer when one is already allocated"); + return nullptr; + } + + if (ResizeBuffer(expected_resizable_buf, size, alignment) == kTfLiteOk) { + return expected_resizable_buf; + } + return nullptr; +}
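Editor's note: the temp-buffer contract above (every `AllocateTemp` paired with a `DeallocateTemp` before `ResetTempAllocations` may be called) is worth a usage sketch; the arena size and alignments are arbitrary assumptions.

```cpp
// Hedged sketch: the temp-buffer lifecycle NonPersistentArenaBufferAllocator
// enforces. Sizes and alignments are arbitrary.
#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"

void TempBufferLifecycleSketch() {
  static uint8_t arena[1024];
  tflite::NonPersistentArenaBufferAllocator allocator(arena, sizeof(arena));

  uint8_t* a = allocator.AllocateTemp(/*size=*/64, /*alignment=*/16);
  uint8_t* b = allocator.AllocateTemp(/*size=*/32, /*alignment=*/16);

  // Both temps must be handed back before the arena can be reset ...
  allocator.DeallocateTemp(a);
  allocator.DeallocateTemp(b);

  // ... otherwise IsAllTempDeallocated() fails and this returns kTfLiteError.
  TfLiteStatus status = allocator.ResetTempAllocations();
  (void)status;  // kTfLiteOk in this balanced sequence
}
```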
+ +// Resizes a buffer that is previously returned by AllocateResizableBuffer(). +// Note that ResizeBuffer(old_resizable_buf, 0, 1) effectively deallocates +// a previously allocated resizable buffer. +TfLiteStatus NonPersistentArenaBufferAllocator::ResizeBuffer( + uint8_t* resizable_buf, size_t size, size_t alignment) { + // Only supports one resizable buffer, which starts at the buffer head. + uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment); + if (resizable_buf != expect_resizable_buf) { + MicroPrintf("Internal error: buffer is not resizable"); + return kTfLiteError; + } + if (head_temp_ != next_temp_) { + MicroPrintf("ResetTempAllocations() is not called before ResizeBuffer()."); + return kTfLiteError; + } + + const size_t available_memory = buffer_tail_ - expect_resizable_buf; + if (available_memory < size) { + MicroPrintf( + "Failed to resize buffer. Requested: %u, available %u, missing: %u", + size, available_memory, size - available_memory); + return kTfLiteError; + } + head_temp_ = expect_resizable_buf + size; + next_temp_ = head_temp_; + + return kTfLiteOk; +} + +// Frees up the memory occupied by the resizable buffer. +TfLiteStatus NonPersistentArenaBufferAllocator::DeallocateResizableBuffer( + uint8_t* resizable_buf) { + return ResizeBuffer(resizable_buf, 0, 1); +} + +// Returns a pointer to the start of the overlay memory, which is +// used for activation tensors and scratch buffers by kernels at Invoke stage. +uint8_t* NonPersistentArenaBufferAllocator::GetOverlayMemoryAddress() const { + return buffer_head_; +} + +// Reserves the size of the overlay memory. This overlay is reserved for the +// kernels at Invoke stage. This is referred to as the overlay because before +// the Invoke stage, the same memory can be used for temp buffers. The layout of +// the memory is planned by the memory planner separately at Invoke stage. +TfLiteStatus +NonPersistentArenaBufferAllocator::ReserveNonPersistentOverlayMemory( + size_t size, size_t alignment) { + uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment); + return ResizeBuffer(expect_resizable_buf, size, alignment); +} + +// Returns the size of non-persistent buffer in use. +size_t NonPersistentArenaBufferAllocator::GetNonPersistentUsedBytes() const { + return (next_temp_ - buffer_head_); +} + +// Returns the number of bytes available with a given alignment. This number +// takes into account any temporary allocations. +size_t NonPersistentArenaBufferAllocator::GetAvailableMemory( + size_t alignment) const { + uint8_t* const aligned_temp = AlignPointerUp(next_temp_, alignment); + uint8_t* const aligned_tail = AlignPointerDown(buffer_tail_, alignment); + return aligned_tail - aligned_temp; +} + +} // namespace tflite
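Editor's note: a companion sketch for the resizable-buffer half of the interface: one buffer grows from the arena head, and the same region is later reserved as the kernel overlay. Call ordering and sizes here are illustrative assumptions.

```cpp
// Hedged sketch: resizable buffer and overlay reservation. The overlay address
// is simply the arena head, per GetOverlayMemoryAddress() above.
#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"

void OverlaySketch() {
  static uint8_t arena[2048];
  tflite::NonPersistentArenaBufferAllocator allocator(arena, sizeof(arena));

  // Only one resizable buffer is supported; it starts at the buffer head.
  uint8_t* buf = allocator.AllocateResizableBuffer(/*size=*/256, /*alignment=*/16);
  if (buf == nullptr ||
      allocator.ResizeBuffer(buf, /*size=*/512, /*alignment=*/16) != kTfLiteOk) {
    return;
  }

  // At Invoke time the memory planner lays out tensors inside this region.
  allocator.ReserveNonPersistentOverlayMemory(/*size=*/512, /*alignment=*/16);
  uint8_t* overlay = allocator.GetOverlayMemoryAddress();  // == buffer head
  (void)overlay;
}
```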
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h new file mode 100644 index 00000000..aad41d3f --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h @@ -0,0 +1,104 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_ +#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_ + +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h" +#include "tensorflow/lite/micro/compatibility.h" + +namespace tflite { + +// Implements INonPersistentBufferAllocator on an arena that is dedicated to +// non-persistent buffers. +class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator { + public: + NonPersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size); + virtual ~NonPersistentArenaBufferAllocator(); + + // Allocates a temporary buffer. This buffer is not resizable. + uint8_t* AllocateTemp(size_t size, size_t alignment) override; + + // Signals that a temporary buffer is no longer needed. + void DeallocateTemp(uint8_t* buf) override; + + // Returns true if all temporary buffers are already deallocated. + bool IsAllTempDeallocated() override; + + // Signals that all temporary allocations can be reclaimed. TFLM calls this + // API when it knows that all temporary buffers that it requested have been + // deallocated. + TfLiteStatus ResetTempAllocations() override; + + // Returns a buffer that is resizable via ResizeBuffer(). + uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) override; + + // Resizes a buffer that is previously returned by + // AllocateResizableBuffer(). + TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size, + size_t alignment) override; + + // Frees up the memory occupied by the resizable buffer. + TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) override; + + // Returns a pointer to the start of the overlay memory, which is + // used for activation tensors and scratch buffers by kernels at Invoke stage. + uint8_t* GetOverlayMemoryAddress() const override; + + // Reserves the size of the overlay memory. This overlay is reserved for the + // kernels at Invoke stage. This is referred to as the overlay because before + // the Invoke stage, the same memory can be used for temp buffers. The layout of + // the memory is planned by the memory planner separately at Invoke stage. + TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size, + size_t alignment) override; + + // Returns the size of non-persistent buffer in use. + size_t GetNonPersistentUsedBytes() const override; + + // Returns the number of bytes available with a given alignment. This number + // takes into account any temporary allocations. + size_t GetAvailableMemory(size_t alignment) const override; + + TF_LITE_REMOVE_VIRTUAL_DELETE + + private: + // The memory arena that this allocator manages. + uint8_t* const buffer_head_; + uint8_t* const buffer_tail_; + + // The whole region is split into two parts: + // buffer_head_ to head_temp_ - 1 belongs to the only resizable buffer. + // head_temp_ to buffer_tail_ can be used for (non-resizable) temp buffers. + uint8_t* head_temp_; + + // next_temp_ points to the next available temp buffer allocation address and + // its range is between head_temp_ and buffer_tail_. + uint8_t* next_temp_; + + // XOR checksum for outstanding temp buffers. + // If all temp buffers are deallocated OR no temp buffers are allocated, + // temp_buffer_ptr_check_sum_ == 0. + intptr_t temp_buffer_ptr_check_sum_ = 0; + // Count of outstanding temp buffers. + int temp_buffer_count_ = 0; +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
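Editor's note: the `temp_buffer_ptr_check_sum_` member above relies on XOR's self-inverse property: folding each pointer in on allocation and again on deallocation cancels pairwise, so a nonzero residue means some temp buffer was never returned. A standalone illustration follows (not library code); the counter is still needed because distinct unbalanced pointers can, in principle, XOR to zero.

```cpp
// Hedged sketch: why the XOR checksum detects unbalanced DeallocateTemp calls.
#include <cassert>
#include <cstdint>

int main() {
  std::intptr_t checksum = 0;
  int count = 0;
  unsigned char a = 0, b = 0;  // stand-ins for two temp buffers

  // AllocateTemp folds each pointer into the checksum.
  checksum ^= reinterpret_cast<std::intptr_t>(&a); ++count;
  checksum ^= reinterpret_cast<std::intptr_t>(&b); ++count;

  // DeallocateTemp folds them back out; identical XORs cancel.
  checksum ^= reinterpret_cast<std::intptr_t>(&a); --count;
  checksum ^= reinterpret_cast<std::intptr_t>(&b); --count;

  // A balanced alloc/dealloc sequence leaves both at zero, which is exactly
  // what IsAllTempDeallocated() checks.
  assert(checksum == 0 && count == 0);
  return 0;
}
```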
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc new file mode 100644 index 00000000..0ccc8fb1 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h" + +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_error_reporter.h" + +namespace tflite { + +PersistentArenaBufferAllocator::PersistentArenaBufferAllocator( + uint8_t* buffer, size_t buffer_size) + : buffer_head_(buffer), + buffer_tail_(buffer + buffer_size), + tail_temp_(buffer_tail_) {} + +PersistentArenaBufferAllocator::~PersistentArenaBufferAllocator() {} + +uint8_t* PersistentArenaBufferAllocator::AllocatePersistentBuffer( + size_t size, size_t alignment) { + uint8_t* const aligned_result = + AlignPointerDown(tail_temp_ - size, alignment); + if (aligned_result < buffer_head_) { +#ifndef TF_LITE_STRIP_ERROR_STRINGS + const size_t missing_memory = buffer_head_ - aligned_result; + MicroPrintf( + "Failed to allocate tail memory. Requested: %u, " + "available %u, missing: %u", + size, size - missing_memory, missing_memory); +#endif + return nullptr; + } + tail_temp_ = aligned_result; + return aligned_result; +} + +size_t PersistentArenaBufferAllocator::GetPersistentUsedBytes() const { + return buffer_tail_ - tail_temp_; +} + +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h new file mode 100644 index 00000000..10145d72 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h @@ -0,0 +1,59 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_ +#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_ + +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h" +#include "tensorflow/lite/micro/compatibility.h" + +namespace tflite { + +// PersistentArenaBufferAllocator is an implementation of the +// IPersistentBufferAllocator interface on an arena that is dedicated for +// persistent buffers. +class PersistentArenaBufferAllocator : public IPersistentBufferAllocator { + public: + PersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size); + virtual ~PersistentArenaBufferAllocator(); + + // Allocates persistent memory. The persistent buffer is never freed. + // Returns nullptr if errors occurred. + uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) override; + + // Returns the size of all persistent allocations in bytes. + size_t GetPersistentUsedBytes() const override; + + TF_LITE_REMOVE_VIRTUAL_DELETE + private: + // The memory arena that this allocator manages. + uint8_t* const buffer_head_; + uint8_t* const buffer_tail_; + + // The whole region is split into two parts: + // tail_temp_ to buffer_tail_ contains allocated buffers; + // buffer_head_ to tail_temp_ - 1 belongs to still available spaces. + // So in essence, the allocated region grows from the bottom and emulates + // SimpleMemoryAllocator's persistent part. +uint8_t* tail_temp_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc similarity index 97% rename from code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc index 6d3e72bd..0efb6512 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ -#include "tensorflow/lite/micro/recording_simple_memory_allocator.h" +#include "tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h" #include diff --git a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h similarity index 87% rename from code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h index a251e940..1abe43dd 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_ -#define TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_ +#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_ +#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_ +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include "tensorflow/lite/micro/compatibility.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" namespace tflite { @@ -62,4 +62,4 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator { } // namespace tflite -#endif // TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_ +#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc similarity index 99% rename from code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc index e5d87afb..3e3ea4bd 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/micro/simple_memory_allocator.h" +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include #include diff --git a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h similarity index 95% rename from code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h index d88c4a3d..92d0e425 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_ -#define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_ +#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_ +#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_ #include #include #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h" #include "tensorflow/lite/micro/compatibility.h" -#include "tensorflow/lite/micro/ibuffer_allocator.h" namespace tflite { @@ -147,4 +147,4 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator, } // namespace tflite -#endif // TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_ +#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc index 5a5ba9ab..36dd062a 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc @@ -16,10 +16,10 @@ limitations under the License. #include "tensorflow/lite/micro/fake_micro_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include "tensorflow/lite/micro/micro_allocator.h" #include "tensorflow/lite/micro/micro_arena_constants.h" #include "tensorflow/lite/micro/micro_error_reporter.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" namespace tflite { namespace { diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc index c556ac64..e0b79631 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_utils.h" namespace tflite { @@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } default: { - TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s", - TfLiteTypeGetName(input->type)); + MicroPrintf("Only float32 is supported currently, got %s", + TfLiteTypeGetName(input->type)); return kTfLiteError; } } @@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } default: { - TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s", - TfLiteTypeGetName(input->type)); + MicroPrintf("Only float32 is supported currently, got %s", + TfLiteTypeGetName(input->type)); return kTfLiteError; } } @@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_RELU() { - return {/*init=*/ReluInit, - /*free=*/nullptr, - /*prepare=*/ReluPrepare, - /*invoke=*/ReluEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval); } TfLiteRegistration Register_RELU6() { - return {/*init=*/Relu6Init, - /*free=*/nullptr, - /*prepare=*/Relu6Prepare, - /*invoke=*/Relu6Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc index 75523d14..f75db4e5 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc @@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_ADD() { - return {/*init=*/AddInit, - /*free=*/nullptr, - /*prepare=*/AddPrepare, - /*invoke=*/AddEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc index 5d0ab724..ce064687 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc @@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_ADD_N() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc index 8217a4a0..a8aa5a48 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc @@ -104,25 +104,11 @@ TfLiteStatus 
ArgMaxEval(TfLiteContext* context, TfLiteNode* node) { } // namespace arg_min_max TfLiteRegistration Register_ARG_MAX() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/arg_min_max::ArgMaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval); } TfLiteRegistration Register_ARG_MIN() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/arg_min_max::ArgMinEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc index e28ebebb..a770d0aa 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc @@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. TfLiteRegistration Register_ASSIGN_VARIABLE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc index 07b680df..be82d942 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc @@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. 
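The pattern repeated throughout these hunks replaces the eight-field TfLiteRegistration brace initializer with tflite::micro::RegisterOp(). Judging from the fields the old initializers spell out, the helper is presumably a thin wrapper along these lines (a sketch only; the actual definition lives in TFLM's kernel utilities):

TfLiteRegistration RegisterOp(
    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
  // Fill the rarely-used fields with the same defaults the old
  // brace initializers passed explicitly at every call site.
  return {/*init=*/init,
          /*free=*/nullptr,
          /*prepare=*/prepare,
          /*invoke=*/invoke,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

This turns eight lines of boilerplate per kernel into one call and keeps the defaults in a single place.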
TfLiteRegistration Register_BATCH_TO_SPACE_ND() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc index fa333249..be2672ec 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc @@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_BROADCAST_ARGS() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/BroadcastArgsPrepare, - /*invoke=*/BroadcastArgsEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare, + BroadcastArgsEval); } -} // namespace tflite \ No newline at end of file +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc index 5302faf1..63a14db2 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc @@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_BROADCAST_TO() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/BroadcastToPrepare, - /*invoke=*/BroadcastToEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare, + BroadcastToEval); } -} // namespace tflite \ No newline at end of file +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc index 4db39f7d..200242b2 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc @@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. 
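As the ADD_N and ARG_MIN/ARG_MAX hunks above show, kernels with no per-instance state simply pass nullptr for the hooks they do not need. A minimal hypothetical op (MyPrepare and MyEval are placeholders, not TFLM symbols) would now register as:

TfLiteStatus MyPrepare(TfLiteContext* context, TfLiteNode* node);  // placeholder
TfLiteStatus MyEval(TfLiteContext* context, TfLiteNode* node);     // placeholder

TfLiteRegistration Register_MY_OP() {
  // No init hook: this op allocates no persistent buffers.
  return tflite::micro::RegisterOp(nullptr, MyPrepare, MyEval);
}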
TfLiteRegistration Register_CALL_ONCE() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc index dc651a24..a1f4516b 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc @@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_CAST() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc index d0a48f91..a390a735 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc @@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace ceil TfLiteRegistration Register_CEIL() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ceil::Prepare, - /*invoke=*/ceil::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc index bda3e66a..a66a61c5 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc @@ -108,14 +108,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration* Register_CIRCULAR_BUFFER() { - static TfLiteRegistration r = {/*init=*/CircularBufferInit, - /*free=*/nullptr, - /*prepare=*/CircularBufferPrepare, - /*invoke=*/CircularBufferEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + static TfLiteRegistration r = tflite::micro::RegisterOp(CircularBufferInit, CircularBufferPrepare, CircularBufferEval); return &r; } diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc index 925c3fb5..cff15e4d 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc @@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // namespace comparisons TfLiteRegistration Register_EQUAL() { - return {/*init=*/comparisons::Init, - /*free=*/nullptr, - /*prepare=*/comparisons::Prepare, - /*invoke=*/comparisons::EqualEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare, + comparisons::EqualEval); } TfLiteRegistration 
Register_NOT_EQUAL() { - return {/*init=*/comparisons::Init, - /*free=*/nullptr, - /*prepare=*/comparisons::Prepare, - /*invoke=*/comparisons::NotEqualEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare, + comparisons::NotEqualEval); } TfLiteRegistration Register_GREATER() { - return {/*init=*/comparisons::Init, - /*free=*/nullptr, - /*prepare=*/comparisons::Prepare, - /*invoke=*/comparisons::GreaterEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare, + comparisons::GreaterEval); } TfLiteRegistration Register_GREATER_EQUAL() { - return {/*init=*/comparisons::Init, - /*free=*/nullptr, - /*prepare=*/comparisons::Prepare, - /*invoke=*/comparisons::GreaterEqualEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare, + comparisons::GreaterEqualEval); } TfLiteRegistration Register_LESS() { - return {/*init=*/comparisons::Init, - /*free=*/nullptr, - /*prepare=*/comparisons::Prepare, - /*invoke=*/comparisons::LessEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare, + comparisons::LessEval); } TfLiteRegistration Register_LESS_EQUAL() { - return {/*init=*/comparisons::Init, - /*free=*/nullptr, - /*prepare=*/comparisons::Prepare, - /*invoke=*/comparisons::LessEqualEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare, + comparisons::LessEqualEval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc index d727a0d5..34622c22 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc @@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input != nullptr); int num_dimensions = NumDimensions(input); - if (num_dimensions > 4) { + if (num_dimensions > RuntimeShape::kMaxSmallSize) { TF_LITE_KERNEL_LOG( context, - "Op Concatenation does not currently support num dimensions >4 " + "Op Concatenation does not currently support num dimensions > %d " "Tensor has %d dimensions.", - num_dimensions); + RuntimeShape::kMaxSmallSize, num_dimensions); return kTfLiteError; } micro_context->DeallocateTempTfLiteTensor(input); @@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace concatenation TfLiteRegistration Register_CONCATENATION() { - return {/*init=*/concatenation::Init, - /*free=*/nullptr, - /*prepare=*/concatenation::Prepare, - /*invoke=*/concatenation::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare, + concatenation::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc index 0fed1223..87ea92e6 100644 
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_error_reporter.h" namespace tflite { namespace { @@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); break; } case kTfLiteInt16: { - reference_integer_ops::ConvPerChannel( - ConvParamsQuantized(params, data), data.per_channel_output_multiplier, - data.per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + switch (bias->type) { + case kTfLiteInt32: { + reference_integer_ops::ConvPerChannel( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt64: { + reference_integer_ops::ConvPerChannel( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: + MicroPrintf("Bias type %s (%d) not supported.", + TfLiteTypeGetName(bias->type), bias->type); + return kTfLiteError; + } break; } case kTfLiteInt8: { @@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; } default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); + MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type), + input->type); return kTfLiteError; } return kTfLiteOk; @@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_CONV_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/ConvPrepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, ConvPrepare, Eval); } } // 
namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h index 38b69525..47ba8ac4 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h @@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel( float output_scale, int output_zero_point, TfLiteConvParams* conv_params, TfLiteRegistration registration, int16_t* output_data); +TfLiteStatus TestConvQuantizedPerChannel( + int* input_dims_data, const float* input_data, int16_t* input_quantized, + float input_scale, int input_zero_point, int* filter_dims_data, + const float* filter_data, int8_t* filter_data_quantized, + int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized, + float* bias_scales, int* bias_zero_points, int* output_dims_data, + const float* expected_output_data, int16_t* expected_output_data_quantized, + float output_scale, int output_zero_point, TfLiteConvParams* conv_params, + TfLiteRegistration registration, int16_t* output_data); + } // namespace testing } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc index 61f7af23..eedc61fd 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc @@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_CUMSUM() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc index cce93c9c..ec000540 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc @@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_DEPTH_TO_SPACE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc index 8a58433a..d2468ff9 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc @@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), 
tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/DepthwiseConvPrepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h index 7a7eb0ba..562438d7 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv( TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node); +// This is the most generic TfLiteRegistration. The actual supported types may +// still be target dependent. The only requirement is that every implementation +// (reference or optimized) must define this function. +TfLiteRegistration Register_DEPTHWISE_CONV_2D(); + +#if defined(CMSIS_NN) +// Returns a TfLiteRegistration struct for a kernel variant that supports only +// int8 activations and int8 weights, and uses the latency-optimized +// implementations. +TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8(); + +// Returns a TfLiteRegistration struct for a kernel variant that supports only +// int16 activations and int8 weights, and uses the latency-optimized +// implementations.
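These per-type registrations exist so an application that only runs int8 models can link the smaller specialized kernel. A sketch of how that selection might look when building an op resolver (whether AddDepthwiseConv2D accepts an explicit registration argument depends on the TFLM revision, so treat this as an assumption to verify against micro_mutable_op_resolver.h):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

tflite::MicroMutableOpResolver<1> resolver;
// Assumed overload: pass the specialized registration explicitly.
// On non-CMSIS targets the inline fallbacks declared just below make
// this identical to the generic Register_DEPTHWISE_CONV_2D().
resolver.AddDepthwiseConv2D(tflite::Register_DEPTHWISE_CONV_2D_INT8());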
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16(); + +#else +inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() { + return Register_DEPTHWISE_CONV_2D(); +} + +inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() { + return Register_DEPTHWISE_CONV_2D(); +} +#endif + } // namespace tflite #endif // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc index 4438ea33..1cf7f133 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc @@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; + case kTfLiteUInt8: + reference_ops::Dequantize(data->quantization_params, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; default: MicroPrintf("Input %s, output %s not supported.", TfLiteTypeGetName(input->type), @@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_DEQUANTIZE() { - return {/*init=*/DequantizeInit, - /*free=*/nullptr, - /*prepare=*/DequantizePrepare, - /*invoke=*/DequantizeEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare, + DequantizeEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc index 4be5ad89..438f9cda 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc @@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); TF_LITE_ENSURE(context, output != nullptr); - TF_LITE_ENSURE(context, - input->type == kTfLiteInt8 || input->type == kTfLiteInt16); + TF_LITE_ENSURE(context, input->type == kTfLiteInt8 || + input->type == kTfLiteInt16 || + input->type == kTfLiteUInt8); TF_LITE_ENSURE(context, output->type == kTfLiteFloat32); if (output->type == kTfLiteInt32) { diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc index efe57e2f..326d87b5 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return op_data; } -void Free(TfLiteContext* context, void* buffer) {} - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* op_data = static_cast(node->user_data); @@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration* Register_DETECTION_POSTPROCESS() { - static TfLiteRegistration r = {/*init=*/Init, - /*free=*/Free, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - 
/*custom_name=*/nullptr, - /*version=*/0}; + static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval); return &r; } diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc index 366dd610..b1cb1dcb 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #include #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" @@ -27,6 +29,22 @@ namespace micro { namespace elementwise { namespace { +constexpr int kAbsNameId = 0; +constexpr int kRsrqtNameId = 1; + +const int kElementwiseInputTensor = 0; +const int kElementwiseOutputTensor = 0; + +struct OpDataAbsRsqrt { + int32_t multiplier; + int shift; + int input_offset; + int output_offset; + bool needs_rescale; + TfLiteQuantizationType input_quantization_type; + TfLiteType input_type; +}; + bool IsNumericSupportedType(const TfLiteType type) { return type == kTfLiteFloat32; } @@ -35,16 +53,40 @@ bool IsLogicalSupportedType(const TfLiteType type) { return type == kTfLiteBool; } +bool IsAbsSupportedType(const TfLiteType type) { + return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16; +} + +bool IsRsqrtSupportedType(const TfLiteType type) { + return type == kTfLiteFloat32 || type == kTfLiteInt8; +} + +inline void SetAbsOutputMultiplier(const float input_scale, + const float output_scale, + int32_t* multiplier, int* shift) { + QuantizeMultiplier(static_cast(input_scale / output_scale), + multiplier, shift); +} + +inline void SetRsqrtOutputMultiplier(const float input_scale, + const float output_scale, + int32_t* multiplier, int* shift) { + const double scale = + 1. 
/ static_cast((std::sqrt(input_scale) * output_scale)); + QuantizeMultiplier(scale, multiplier, shift); +} + typedef bool (*IsSupportedType)(TfLiteType); template TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { MicroContext* micro_context = GetMicroContext(context); - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor); TF_LITE_ENSURE(context, input != nullptr); - TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); + TfLiteTensor* output = + micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor); TF_LITE_ENSURE(context, output != nullptr); TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); if (!IsSupportedType(input->type)) { @@ -58,9 +100,79 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +typedef bool (*IsSupportedType)(TfLiteType); +template +TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); + TF_LITE_ENSURE(context, input != nullptr); + TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); + if (!IsSupportedType(input->type)) { + TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + + auto* op_data = static_cast(node->user_data); + op_data->input_type = input->type; + + // For int16 type input, we support both quantized and non-quantized + // evaluation. 
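The (multiplier, shift) pair produced by SetAbsOutputMultiplier() above is TFLM's usual integer encoding of a float rescale ratio. A worked example with assumed scales:

// Illustration only; the scales are made up.
const float input_scale = 0.5f;
const float output_scale = 0.25f;  // real ratio: 0.5 / 0.25 = 2.0
int32_t multiplier;
int shift;
// QuantizeMultiplier represents 2.0 as (2^30 / 2^31) * 2^2,
// so multiplier == 1 << 30 and shift == 2.
QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
                   &multiplier, &shift);
// At eval time, MultiplyByQuantizedMultiplier(value, multiplier, shift)
// then computes approximately value * 2.0 in pure integer arithmetic.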
+ if (op_nameid == kAbsNameId) { + op_data->input_quantization_type = input->quantization.type; + } + + if (input->type == kTfLiteInt8 || + (input->type == kTfLiteInt16 && + input->quantization.type != kTfLiteNoQuantization)) { + TF_LITE_ENSURE_EQ(context, input->quantization.type, + kTfLiteAffineQuantization); + TF_LITE_ENSURE_EQ(context, output->quantization.type, + kTfLiteAffineQuantization); + const auto* input_params = + reinterpret_cast(input->quantization.params); + const auto* output_params = reinterpret_cast( + output->quantization.params); + TF_LITE_ENSURE(context, input_params != nullptr); + TF_LITE_ENSURE(context, input_params->scale != nullptr); + TF_LITE_ENSURE(context, input_params->scale->size > 0); + TF_LITE_ENSURE(context, input_params->zero_point->size > 0); + TF_LITE_ENSURE(context, output_params != nullptr); + TF_LITE_ENSURE(context, output_params->scale != nullptr); + TF_LITE_ENSURE(context, output_params->scale->size > 0); + TF_LITE_ENSURE(context, output_params->zero_point->size > 0); + op_data->input_offset = input_params->zero_point->data[0]; + op_data->output_offset = output_params->zero_point->data[0]; + if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0); + TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0); + } + const float input_scale = input_params->scale->data[0]; + const float output_scale = output_params->scale->data[0]; + op_data->needs_rescale = input_scale != output_scale; + if (op_nameid == kAbsNameId && op_data->needs_rescale) { + SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier, + &op_data->shift); + } else if (op_nameid == kRsrqtNameId) { + SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier, + &op_data->shift); + } + } + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(output); + return kTfLiteOk; +} + template -inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, - T func(T), TfLiteType expected_type) { +inline TfLiteStatus EvalImplQuantized( + TfLiteContext* context, TfLiteNode* node, + T func(TfLiteContext*, TfLiteNode*, T), + TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T), + TfLiteType expected_type) { const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type); @@ -68,6 +180,34 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, const T* in_data = tflite::micro::GetTensorData(input); T* out_data = tflite::micro::GetTensorData(output); for (size_t i = 0; i < num_elements; ++i) { + if (validate_input_func) { + TF_LITE_ENSURE_OK(context, + validate_input_func(context, node, in_data[i])); + } + out_data[i] = func(context, node, in_data[i]); + } + return kTfLiteOk; +} + +template +inline T AbsHelper(T i) { + return std::abs(i); +} + +template +inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, + T func(T), TfLiteStatus validate_input_func(T), + TfLiteType expected_type) { + const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); + TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type); + const size_t num_elements = ElementCount(*input->dims); + const T* in_data = tflite::micro::GetTensorData(input); + T* out_data = tflite::micro::GetTensorData(output); + for (size_t i = 0; i < 
num_elements; ++i) { + if (validate_input_func) { + TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i])); + } out_data[i] = func(in_data[i]); } return kTfLiteOk; @@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node, float float_func(float)) { - return EvalImpl(context, node, float_func, kTfLiteFloat32); + return EvalImpl(context, node, float_func, + /*validate_input_func=*/nullptr, kTfLiteFloat32); } inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node, + bool bool_func(bool)) { - return EvalImpl(context, node, bool_func, kTfLiteBool); + return EvalImpl(context, node, bool_func, + /*validate_input_func=*/nullptr, kTfLiteBool); +} + +void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt)); +} + +template <typename T> +inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) { + const auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data); + const int kMin = std::numeric_limits<T>::min(); + const int kMax = std::numeric_limits<T>::max(); + + const int32_t value = std::abs(i - op_data->input_offset); + if (!op_data->needs_rescale) { + return static_cast<T>( + std::min(std::max(static_cast<int32_t>(value + op_data->output_offset), + static_cast<int32_t>(kMin)), + static_cast<int32_t>(kMax))); + } + + const int32_t output = tflite::MultiplyByQuantizedMultiplier( + value, op_data->multiplier, op_data->shift) + + op_data->output_offset; + return static_cast<T>(std::min( + std::max(static_cast<int32_t>(output), static_cast<int32_t>(kMin)), + static_cast<int32_t>(kMax))); +} + +template <typename T> +inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) { + const auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data); + const int kMin = std::numeric_limits<T>::min(); + const int kMax = std::numeric_limits<T>::max(); + + const int32_t value = (i - op_data->input_offset); + const int32_t kShift = 20; // Shift to keep value integer. + if (value == 0) { + // Assume that any value close to 0 represents the max output value.
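The scale algebra behind this quantized Rsqrt path, as read from the code: with in_real = (in_q - in_zp) * s_in and the requirement out_q = out_real / s_out + out_zp, substituting out_real = 1 / sqrt(in_real) gives

  out_q = (1 / sqrt(in_q - in_zp)) * (1 / (sqrt(s_in) * s_out)) + out_zp

The per-element factor 1 / sqrt(in_q - in_zp) is what GetInvSqrtQuantizedMultiplierExp() approximates from value, while the constant second factor is exactly the 1. / (sqrt(input_scale) * output_scale) that SetRsqrtOutputMultiplier() folded into (multiplier, shift) at prepare time. kShift = 20 temporarily scales the intermediate up by 2^20 so it keeps fractional precision as an integer, and the final rescale removes those 20 bits again via op_data->shift - kShift.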
+ return static_cast(kMax); + } + int32_t inv_sqrt_multiplier; + int inv_sqrt_shift; + GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier, + &inv_sqrt_shift); + const int32_t data = tflite::MultiplyByQuantizedMultiplier( + static_cast(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift); + const int32_t output = + tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier, + op_data->shift - kShift) + + op_data->output_offset; + return static_cast(std::min( + std::max(static_cast(output), static_cast(kMin)), + static_cast(kMax))); +} + +template +TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node, + T i) { + const auto* op_data = static_cast(node->user_data); + + TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset, + "Rsqrt is only defined for positive values"); + return static_cast(kTfLiteOk); } TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) { - return EvalNumeric(context, node, std::abs); + OpDataAbsRsqrt* op_data = reinterpret_cast(node->user_data); + TfLiteType type = op_data->input_type; + TfLiteQuantizationType input_quantization_type = + op_data->input_quantization_type; + TfLiteStatus eval_result; + + switch (type) { + case kTfLiteFloat32: + eval_result = EvalNumeric(context, node, std::abs); + break; + case kTfLiteInt8: + eval_result = + EvalImplQuantized(context, node, AbsEvalQuantized, + /*validate_input_func=*/nullptr, type); + break; + case kTfLiteInt16: + eval_result = + input_quantization_type == kTfLiteNoQuantization + ? EvalImpl(context, node, AbsHelper, + /*validate_input_func=*/nullptr, type) + : EvalImplQuantized(context, node, AbsEvalQuantized, + /*validate_input_func=*/nullptr, + type); + break; + default: + TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.", + TfLiteTypeGetName(type)); + return kTfLiteError; + break; + } + return eval_result; } TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { @@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { - return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); + const auto* op_data = static_cast(node->user_data); + TfLiteType type = op_data->input_type; + switch (type) { + case kTfLiteFloat32: + return EvalImpl( + context, node, [](float f) { return 1.f / std::sqrt(f); }, + /*validate_input_func=*/nullptr, type); + case kTfLiteInt8: + return EvalImplQuantized(context, node, + elementwise::RsqrtEvalQuantized, + elementwise::validate_input_func, type); + + default: + TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.", + TfLiteTypeGetName(type)); + return kTfLiteError; + } } TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) { @@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { } // namespace elementwise TfLiteRegistration Register_ABS() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::AbsEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + elementwise::ElementWiseAbsRsqrtInit, + elementwise::PrepareAbsRsqrt, + elementwise::AbsEval); } TfLiteRegistration Register_SIN() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::SinEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - 
/*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, elementwise::GenericPrepare, + elementwise::SinEval); } TfLiteRegistration Register_COS() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::CosEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, elementwise::GenericPrepare, + elementwise::CosEval); } TfLiteRegistration Register_LOG() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::LogEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, elementwise::GenericPrepare, + elementwise::LogEval); } TfLiteRegistration Register_SQRT() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::SqrtEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, elementwise::GenericPrepare, + elementwise::SqrtEval); } TfLiteRegistration Register_RSQRT() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::RsqrtEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + elementwise::ElementWiseAbsRsqrtInit, + elementwise::PrepareAbsRsqrt, + elementwise::RsqrtEval); } TfLiteRegistration Register_SQUARE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::SquareEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, elementwise::GenericPrepare, + elementwise::SquareEval); } TfLiteRegistration Register_LOGICAL_NOT() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/ - elementwise::GenericPrepare, - /*invoke=*/elementwise::LogicalNotEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, elementwise::GenericPrepare, + elementwise::LogicalNotEval); } } // namespace micro } // namespace ops -} // namespace tflite +} // namespace tflite \ No newline at end of file diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc index b2cd19cc..0b64e89d 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc @@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_ELU() { - return {/*init=*/EluInit, - /*free=*/nullptr, - /*prepare=*/EluPrepare, - /*invoke=*/EluEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc index 47a17d9f..2f1ac58d 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc +++ 
b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc @@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_ADD() { - return {/*init=*/AddInit, - /*free=*/nullptr, - /*prepare=*/AddPrepare, - /*invoke=*/AddEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc index 09260482..919dd006 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc @@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { #if ESP_NN if (input->type == kTfLiteInt8) { + data_dims_t input_dims = { + .width = input_width, .height = input_height, + .channels = input->dims->data[3], 1 + }; + data_dims_t output_dims = { + .width = output_width, .height = output_height, + .channels = output->dims->data[3], 1 + }; + data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0}; + conv_params_t conv_params = { + .in_offset = 0, .out_offset = 0, + .stride = {params.stride_width, params.stride_height}, + .padding = {data->op_data.padding.width, data->op_data.padding.height}, + .dilation = {0, 0}, .activation = {-128, 127} + }; + int scratch_buf_size = esp_nn_get_conv_scratch_size( - input_width, input_height, input->dims->data[3], - output->dims->data[3], filter_width, filter_height); + &input_dims, &filter_dims, &output_dims, &conv_params); if (scratch_buf_size > 0) { TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena( context, scratch_buf_size, &data->buffer_idx)); @@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel( const int input_size = input_width * input_height * input_depth; const int output_size = output_width * output_height * output_depth; + data_dims_t input_dims = { + .width = input_width, .height = input_height, + .channels = input_depth, 1 + }; + data_dims_t output_dims = { + .width = output_width, .height = output_height, + .channels = output_depth, 1 + }; + data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0}; + conv_params_t conv_params = { + .in_offset = input_offset, .out_offset = output_offset, + .stride = {stride_width, stride_height}, + .padding = {pad_width, pad_height}, + .dilation = {0, 0}, + .activation = {activation_min, activation_max} + }; + quant_data_t quant_data = { + .shift = data.op_data.per_channel_output_shift, + .mult = data.op_data.per_channel_output_multiplier + }; + for (int i_batch = 0; i_batch < batch_size; i_batch++) { - esp_nn_conv_s8(input_data + i_batch * input_size, - input_width, input_height, input_depth, input_offset, - pad_width, pad_height, stride_width, stride_height, - tflite::micro::GetTensorData(filter), - filter_width, filter_height, + esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size, + &filter_dims, tflite::micro::GetTensorData(filter), tflite::micro::GetTensorData(bias), - output_data + i_batch * output_size, - output_width, output_height, output_depth, output_offset, - data.op_data.per_channel_output_shift, - data.op_data.per_channel_output_multiplier, - activation_min, activation_max); + &output_dims, output_data + i_batch * output_size, + &conv_params, &quant_data); } } else { 
reference_integer_ops::ConvPerChannel( @@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTypeGetName(input->type), input->type); return kTfLiteError; } - conv_total_time += esp_timer_get_time() - start_time; + long long time_this_instance = esp_timer_get_time() - start_time; + conv_total_time += time_this_instance; + //printf("time this instance: %llu\n", time_this_instance / 1000); return kTfLiteOk; } } // namespace TfLiteRegistration Register_CONV_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc index 5f2d9d50..a2460248 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc @@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, if (data.buffer_idx > -1) { scratch_buf = context->GetScratchBuffer(context, data.buffer_idx); } + esp_nn_set_depthwise_conv_scratch_buf(scratch_buf); + data_dims_t input_dims = { + .width = input_width, .height = input_height, + .channels = input_depth, 1 + }; + data_dims_t output_dims = { + .width = output_width, .height = output_height, + .channels = output_depth, 1 + }; + data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0}; + dw_conv_params_t conv_params = { + .in_offset = input_offset, .out_offset = output_offset, + .ch_mult = depth_multiplier, + .stride = {stride_width, stride_height}, + .padding = {pad_width, pad_height}, .dilation = {0, 0}, + .activation = {activation_min, activation_max} + }; + quant_data_t quant_data = { + .shift = data.op_data.per_channel_output_shift, + .mult = data.op_data.per_channel_output_multiplier + }; + for (int i_batch = 0; i_batch < batch_size; i_batch++) { - esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width, - input_height, input_depth, input_offset, - pad_width, pad_height, - stride_width, stride_height, depth_multiplier, - tflite::micro::GetTensorData(filter), - filter_width, filter_height, + esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size, + &filter_dims, tflite::micro::GetTensorData(filter), tflite::micro::GetTensorData(bias), - output_data + i_batch * output_size, - output_width, output_height, output_offset, - data.op_data.per_channel_output_shift, - data.op_data.per_channel_output_multiplier, - activation_min, activation_max); + &output_dims, output_data + i_batch * output_size, + &conv_params, &quant_data); } } else { reference_integer_ops::DepthwiseConvPerChannel( @@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { #if ESP_NN if (input->type == kTfLiteInt8) { + data_dims_t input_dims = { + .width = input_width, .height = input_height, + .channels = input->dims->data[3], 1 + }; + data_dims_t output_dims = { + .width = output_width, .height = output_height, + .channels = output->dims->data[3], 1 + }; + data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0}; + dw_conv_params_t conv_params = { + .in_offset = 0, .out_offset = 0, + .ch_mult = params.depth_multiplier, + .stride = 
{params.stride_width, params.stride_height}, + .padding = {data->op_data.padding.width, data->op_data.padding.height}, + .dilation = {0, 0}, .activation = {-128, 127} + }; + int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size( - input_width, input_height, input->dims->data[3], - params.depth_multiplier, filter_width, filter_height); + &input_dims, &filter_dims, &output_dims, &conv_params); if (scratch_buf_size > 0) { TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena( context, scratch_buf_size, &data->buffer_idx)); @@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTypeGetName(input->type), input->type); return kTfLiteError; } - dc_total_time += esp_timer_get_time() - start_time; + long long time_this_instance = esp_timer_get_time() - start_time; + dc_total_time += time_this_instance; + // printf("time this instance: %llu\n", time_this_instance / 1000); + return kTfLiteOk; } } // namespace TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc index 5e1705da..484cffb6 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc @@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc index 0e8a82f4..02413f5c 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc @@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_MUL() { - return {/*init=*/MulInit, - /*free=*/nullptr, - /*prepare=*/MulPrepare, - /*invoke=*/MulEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc index d55bab82..b450929e 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc @@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } // namespace TfLiteRegistration Register_AVERAGE_POOL_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/PoolingPrepare, - /*invoke=*/AverageEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + 
return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval); } TfLiteRegistration Register_MAX_POOL_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/PoolingPrepare, - /*invoke=*/MaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc new file mode 100644 index 00000000..9a967839 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc @@ -0,0 +1,208 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/softmax.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/softmax.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" + +#include "freertos/FreeRTOS.h" +#include + +#if ESP_NN +#include +#endif + +long long softmax_total_time = 0; + +namespace tflite { +namespace { +// Softmax parameter data that persists in user_data +const int kInt16LUTArraySize = 513; + +struct NodeData { + SoftmaxParams op_data; +#if ESP_NN + int buffer_idx; +#endif +}; + +static void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(NodeData)); +} + +void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input, + TfLiteEvalTensor* output, const NodeData* data) { + if (input->type == kTfLiteInt8) { + if (output->type == kTfLiteInt16) { + tflite::reference_ops::Softmax( + data->op_data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else { +#if ESP_NN + const int32_t input_beta_multiplier = data->op_data.input_multiplier; + const int32_t input_beta_left_shift = data->op_data.input_left_shift; + const int diff_min = data->op_data.diff_min; + const RuntimeShape input_shape = tflite::micro::GetTensorShape(input); + const RuntimeShape output_shape = tflite::micro::GetTensorShape(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int8_t *in_ptr = 
tflite::micro::GetTensorData(input); + int8_t *out_ptr = tflite::micro::GetTensorData(output); + void *scratch_buf = NULL; + if (data->buffer_idx > -1) { + scratch_buf = context->GetScratchBuffer(context, data->buffer_idx); + } + esp_nn_set_softmax_scratch_buf(scratch_buf); + esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier, + input_beta_left_shift, diff_min, out_ptr); +#else + tflite::reference_ops::Softmax( + data->op_data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); +#endif + } + } else { + tflite::reference_ops::SoftmaxInt16( + data->op_data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } +} + +static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); + TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); + + TFLITE_DCHECK(node->user_data != nullptr); + NodeData data = *static_cast(node->user_data); + + long long start_time = esp_timer_get_time(); + switch (input->type) { + case kTfLiteFloat32: { + tflite::reference_ops::Softmax( + data.op_data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } + break; + case kTfLiteInt8: + case kTfLiteInt16: { + SoftmaxQuantized(context, input, output, &data); + } + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + softmax_total_time += esp_timer_get_time() - start_time; + return kTfLiteOk; +} + +static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE(context, NumDimensions(input) >= 1); + TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); + TF_LITE_ENSURE(context, output != nullptr); + + TF_LITE_ENSURE(context, node->user_data != nullptr); + NodeData* data = static_cast(node->user_data); + // Only allocate LUTs for KTfLiteInt16 data type + if (input->type == kTfLiteInt16) { + void* raw_exp_lut = context->AllocatePersistentBuffer( + context, sizeof(int16_t) * kInt16LUTArraySize); + TF_LITE_ENSURE(context, raw_exp_lut != nullptr); + data->op_data.exp_lut = reinterpret_cast(raw_exp_lut); + void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer( + context, sizeof(int16_t) * kInt16LUTArraySize); + TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr); + data->op_data.one_over_one_plus_x_lut = + reinterpret_cast(one_over_one_plus_x_lut); + } + + if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE(context, + input->type == kTfLiteInt8 || input->type == kTfLiteInt16); + } else { + TF_LITE_ENSURE_EQ(context, input->type, output->type); + } + + // Populate LUT if required + if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + // exp LUT only used on negative values + // we consider exp(-10.0) is insignificant to accumulation + gen_lut( + [](float value) { return std::exp(value); }, -10.0f, 0.0f, 
-1.0f, 1.0f, + data->op_data.exp_lut); + gen_lut<float, int16_t, int16_t>( + [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f, + 1.0f, data->op_data.one_over_one_plus_x_lut); + data->op_data.zero_point = output->params.zero_point; + data->op_data.scale = output->params.scale; + } + + auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data); + auto ret_val = + CalculateSoftmaxParams(context, input, output, params, &data->op_data); + +#if ESP_NN + if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) { + const int32_t input_width = input->dims->data[1]; + const int32_t input_height = input->dims->data[2]; + int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width, + input_height); + if (scratch_buf_size > 0) { + TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena( + context, scratch_buf_size, &data->buffer_idx)); + } + } +#endif + + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(output); + return ret_val; +} + +} // namespace + +TfLiteRegistration Register_SOFTMAX() { + return tflite::micro::RegisterOp(Init, Prepare, Eval); +} + +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc index d1b0f6cb..ae26f636 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc @@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_EXP() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc index 6dcba4d5..4b105bf6 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc @@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_EXPAND_DIMS() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc index d8a2b09d..9f438b89 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc @@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FILL() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc index b8be1cf0..6b2a4cc2 100644 ---
a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc @@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace floor TfLiteRegistration Register_FLOOR() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/floor::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc index d11e4969..333a1eba 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc @@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FLOOR_DIV() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc index 083bd5cb..9bb49497 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc @@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FLOOR_MOD() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc index c0be3814..a083edd7 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = micro_context->AllocateTempOutputTensor( node, kFullyConnectedOutputTensor); TF_LITE_ENSURE(context, output != nullptr); - TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); - TF_LITE_ENSURE_MSG(context, input->type == filter->type, - "Hybrid models are not supported on TFLite Micro."); TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected( context, params->activation, input->type, @@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } + case kTfLiteInt16: { + const int64_t* bias_data = + nullptr != bias ? 
tflite::micro::GetTensorData<int64_t>(bias) + : nullptr; + + tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData<int16_t>(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData<int8_t>(filter), + tflite::micro::GetTensorShape(bias), bias_data, + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int16_t>(output)); + break; + } + default: { TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", TfLiteTypeGetName(input->type), input->type); @@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h index e1215da6..93026cd5 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() { } #endif + +#if defined(CMSIS_NN) +// Returns a TfLiteRegistration struct for a kernel variant that only supports +// int16. +TfLiteRegistration Register_FULLY_CONNECTED_INT16(); + +#else +// Note that while this block gets used for both reference and optimized kernels +// that do not have any specialized implementations, the only goal here is to +// define fallback implementations that allow reference kernels to still be used +// from applications that call a more specific kernel variant.
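+// Hypothetical usage sketch (not part of the header): an application can +// always request the int16-specific variant and, when CMSIS_NN is not +// defined, transparently receive the generic kernel through this fallback: +// TfLiteRegistration reg = tflite::Register_FULLY_CONNECTED_INT16(); +// // ...resolves to Register_FULLY_CONNECTED() in this block.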
+ +inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() { + return Register_FULLY_CONNECTED(); +} + +#endif + } // namespace tflite #endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc index 0b7c23f9..6035efa7 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc @@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_GATHER() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc index c604ae15..eaa1abca 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc @@ -195,14 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_GATHER_ND() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc index 060dfc14..055e12e6 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc @@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_HARD_SWISH() { - return {/*init=*/HardSwishInit, - /*free=*/nullptr, - /*prepare=*/tflite::HardSwishPrepare, - /*invoke=*/HardSwishEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare, + HardSwishEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc index 050aeac4..39eca8b4 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc @@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. 
TfLiteRegistration Register_IF() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc index fd84dec1..341eec77 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include "tensorflow/lite/micro/micro_arena_constants.h" #include "tensorflow/lite/micro/micro_error_reporter.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/micro/test_helpers.h" namespace tflite { @@ -30,7 +30,7 @@ uint8_t KernelRunner::kKernelRunnerBuffer_[]; KernelRunner::KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* inputs, TfLiteIntArray* outputs, - void* builtin_data) + void* builtin_data, TfLiteIntArray* intermediates) : registration_(registration), allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(), kKernelRunnerBuffer_, @@ -54,6 +54,7 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration, node_.inputs = inputs; node_.outputs = outputs; node_.builtin_data = builtin_data; + node_.intermediates = intermediates; } bool KernelRunner::ValidateTempBufferDeallocated() { diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h index 9dddde50..68722edb 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include "tensorflow/lite/micro/fake_micro_context.h" #include "tensorflow/lite/micro/mock_micro_graph.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" namespace tflite { namespace micro { @@ -35,7 +35,8 @@ class KernelRunner { public: KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* inputs, - TfLiteIntArray* outputs, void* builtin_data); + TfLiteIntArray* outputs, void* builtin_data, + TfLiteIntArray* intermediates = nullptr); // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any // exceptions will be DebugLog'd and returned as a status code. 
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc index 14664b91..91c0bc91 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc @@ -36,6 +36,21 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index, } // namespace +TfLiteRegistration RegisterOp( + void* (*init)(TfLiteContext* context, const char* buffer, size_t length), + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node), + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) { + return {/*init=*/init, + /*free=*/nullptr, + /*prepare=*/prepare, + /*invoke=*/invoke, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0, + /*registration_external=*/nullptr}; +} + // Returns a mutable tensor for a given input index. is_variable must be checked // during prepare when the full TfLiteTensor is available. TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context, diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h index 6c5d7f18..d6f20c72 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h @@ -27,6 +27,11 @@ limitations under the License. namespace tflite { namespace micro { +TfLiteRegistration RegisterOp( + void* (*init)(TfLiteContext* context, const char* buffer, size_t length), + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node), + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)); + // Returns a mutable tensor for a given input index. is_variable must be checked // during prepare when the full TfLiteTensor is available. TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context, @@ -40,19 +45,33 @@ const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context, TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context, const TfLiteNode* node, int index); -// Returns data for a TfLiteEvalTensor struct. +// Returns data for a TfLiteEvalTensor struct that is expected to exist. template <typename T> T* GetTensorData(TfLiteEvalTensor* tensor) { - return tensor != nullptr ? reinterpret_cast<T*>(tensor->data.raw) : nullptr; + TFLITE_DCHECK(tensor != nullptr); + return reinterpret_cast<T*>(tensor->data.raw); } -// Returns const data for a TfLiteEvalTensor struct. +// Returns const data for a TfLiteEvalTensor struct that is expected to exist. template <typename T> const T* GetTensorData(const TfLiteEvalTensor* tensor) { TFLITE_DCHECK(tensor != nullptr); return reinterpret_cast<const T*>(tensor->data.raw); } +// Returns data for a TfLiteEvalTensor struct that could be null. +template <typename T> +T* GetOptionalTensorData(TfLiteEvalTensor* tensor) { + return tensor == nullptr ? nullptr : reinterpret_cast<T*>(tensor->data.raw); +} + +// Returns const data for a TfLiteEvalTensor struct that could be null. +template <typename T> +const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) { + return tensor == nullptr ? nullptr + : reinterpret_cast<const T*>(tensor->data.raw); +} + // Returns the shape of a TfLiteEvalTensor struct.
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor); diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc index 250cd3be..2b2a27bf 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc @@ -136,14 +136,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_L2_POOL_2D() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/L2Prepare, - /*invoke=*/L2Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, L2Prepare, L2Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc index 289e4de5..45858e78 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc @@ -137,14 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace l2norm TfLiteRegistration Register_L2NORM_REF() { - return {/*init=*/l2norm::Init, - /*free=*/nullptr, - /*prepare=*/l2norm::Prepare, - /*invoke=*/l2norm::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(l2norm::Init, l2norm::Prepare, l2norm::Eval); } TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); } diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc index 70ee3856..96c1b1b1 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc @@ -88,14 +88,8 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_LEAKY_RELU() { - return {/*init=*/LeakyReluInit, - /*free=*/nullptr, - /*prepare=*/LeakyReluPrepare, - /*invoke=*/LeakyReluEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(LeakyReluInit, LeakyReluPrepare, + LeakyReluEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc index 0af74def..5fd87612 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc @@ -142,14 +142,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_LOG_SOFTMAX() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/LogSoftmaxPrepare, - /*invoke=*/LogSoftmaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, LogSoftmaxPrepare, LogSoftmaxEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc index e2d2b5f8..c85e0c5b 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc 
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc @@ -34,29 +34,11 @@ TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_LOGICAL_OR() { - // Init, Free, Prepare, Eval are satisfying the Interface required by - // TfLiteRegistration. - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/LogicalOrEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, LogicalOrEval); } TfLiteRegistration Register_LOGICAL_AND() { - // Init, Free, Prepare, Eval are satisfying the Interface required by - // TfLiteRegistration. - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/LogicalAndEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, LogicalAndEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc index 77f94ec0..f8ac1c23 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc @@ -106,13 +106,6 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_LOGISTIC() { - return {/*init=*/LogisticInit, - /*free=*/nullptr, - /*prepare=*/LogisticPrepare, - /*invoke=*/LogisticEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(LogisticInit, LogisticPrepare, LogisticEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc new file mode 100644 index 00000000..f157a8d0 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc @@ -0,0 +1,2955 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/micro/kernels/lstm_eval.h" + +#include +#include +#include +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h" +namespace tflite { +namespace { + +void ComputeRowSums( + int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums, + int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums, + int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums, + int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums, + int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums, + int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums, + int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell, + int n_input, int n_aux_input, int n_output, + const int8_t* input_to_input_weights_ptr, + const int8_t* input_to_forget_weights_ptr, + const int8_t* input_to_cell_weights_ptr, + const int8_t* input_to_output_weights_ptr, + const int8_t* aux_input_to_input_weights_ptr, + const int8_t* aux_input_to_forget_weights_ptr, + const int8_t* aux_input_to_cell_weights_ptr, + const int8_t* aux_input_to_output_weights_ptr, + const int8_t* recurrent_to_input_weights_ptr, + const int8_t* recurrent_to_forget_weights_ptr, + const int8_t* recurrent_to_cell_weights_ptr, + const int8_t* recurrent_to_output_weights_ptr, + const int8_t* projection_weights_ptr, bool use_cifg, + const float* aux_input_ptr) { + // Compute the row sums for dequantization + if (!use_cifg) { + micro_tensor_utils::ReductionSumVector( + input_to_input_weights_ptr, input_to_input_row_sums, n_cell, n_input); + } + micro_tensor_utils::ReductionSumVector( + input_to_forget_weights_ptr, input_to_forget_row_sums, n_cell, n_input); + micro_tensor_utils::ReductionSumVector( + input_to_cell_weights_ptr, input_to_cell_row_sums, n_cell, n_input); + micro_tensor_utils::ReductionSumVector( + input_to_output_weights_ptr, input_to_output_row_sums, n_cell, n_input); + + if (aux_input_ptr) { + if (!use_cifg) { + micro_tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr, + aux_input_to_input_row_sums, + n_cell, n_aux_input); + } + micro_tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr, + aux_input_to_forget_row_sums, n_cell, + n_aux_input); + micro_tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr, + aux_input_to_cell_row_sums, n_cell, + n_aux_input); + micro_tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr, + aux_input_to_output_row_sums, n_cell, + n_aux_input); + } + if (!use_cifg) { + micro_tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr, + recurrent_to_input_row_sums, n_cell, + n_output); + } + micro_tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr, + recurrent_to_forget_row_sums, n_cell, + n_output); + micro_tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr, + recurrent_to_cell_row_sums, n_cell, + n_output); + micro_tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr, + recurrent_to_output_row_sums, n_cell, + n_output); + + if (projection_weights_ptr != nullptr) { + micro_tensor_utils::ReductionSumVector( + 
projection_weights_ptr, projection_weights_row_sums, n_output, n_cell); + } +} + +// Calculates a single LSTM gate. +// +// Implements the following formula: (* is matrix multiply) +// gate = activate(W_input * input + W_aux * aux_input + +// W_peephole * cell + W_recurrent * prev_output + bias) +// with layer norm: +// gate = activate(W_norm * normalize(...) + bias) // not adding bias inside +// +// Activation is sigmoid except for the "cell" gate (configurable, usually tanh) +// +// Parameters: +// Input vectors (to LSTM): | Size: | Optional? +// input | n_input | +// aux_input | n_aux_input | y (bidir LSTM) +// Input vectors (persistent states): +// output_state | n_output | +// cell_state | n_cell | +// 'Constant' inputs: +// input_to_gate_weights | n_cell * n_input | +// aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM) +// recurrent_to_gate_weights | n_cell * n_output | +// cell_to_gate_weights | n_cell | y (peephole) +// gate_bias | n_cell | +// layer_norm_coefficients | n_cell | y (layer norm) +// Output vector: +// gate | n_cell | +// Scalar parameters: +// n_batch - batch size / number of vectors +// n_input, n_aux_input, n_output, n_cell - size of vectors. +// activation - activation to use. +// is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero. +// use_layer_norm - if doing layer norm LSTM. +inline void CalculateLstmGateFloat( + const float* input, const float* input_to_gate_weights, + const float* aux_input, const float* aux_input_to_gate_weights, + const float* output_state, const float* recurrent_to_gate_weights, + const float* cell_state, const float* cell_to_gate_weights, + const float* layer_norm_coefficients, const float* gate_bias, + const int n_batch, const int n_input, const int n_aux_input, + const int n_output, const int n_cell, + const TfLiteFusedActivation activation, float* gate, + const bool is_input_all_zeros, const bool is_aux_input_all_zeros) { + const bool use_peephole = (cell_to_gate_weights != nullptr); + const bool use_layer_norm = (layer_norm_coefficients != nullptr); + + // Initialize scratch buffers with bias for regular lstm or initialize with + // zero for layer norm lstm. + if (use_layer_norm) { + memset(gate, 0, n_cell * n_batch * sizeof(float)); + } else { + micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, + gate); + } + // For each batch and cell: compute input_weight * input. + // Skip if input is all zeros. + if (!is_input_all_zeros) { + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, n_cell, n_input, input, n_batch, gate); + } + // For each batch and cell: compute aux_input_weight * aux_input. + // Skip if auxiliary input is not available or all zeros. + if (!is_aux_input_all_zeros) { + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, n_batch, + gate); + } + // For each batch and cell: compute recurrent_weight * output_state. 
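+ // (Unlike the input and aux-input terms above, there is no all-zeros skip + // here; the recurrent contribution is always accumulated into the gate.)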
+ micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate); + // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) + if (use_peephole) { + micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_gate_weights, n_cell, cell_state, n_batch, gate); + } + // Do layer normalization (if layer norm LSTM) + if (use_layer_norm) { + micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch); + micro_tensor_utils::VectorBatchVectorCwiseProduct( + layer_norm_coefficients, n_cell, gate, n_batch, gate); + micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate); + } + // Apply activation + micro_tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, + activation, gate); +} + +// Updates the LSTM cell state, used by both float and hybrid LSTM versions. +// +// Implements the following formula: +// cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate) +// +// With CIFG LSTM, input gate is replaced by (1-forget_gate). +// +// Parameters: +// - n_batch, n_cell: sizes of vectors +// - cell_state: input/output vector, size n_batch*n_cell +// - input_gate: input vector, size n_batch*n_cell. +// - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG +// - cell_gate: input vector, size n_batch*n_cell. +// - use_cifg: use 1-forget_gate instead of input_gate. +// - clip: if > 0, clip the resulting cell state to [-clip, +clip]. +void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state, + const float* input_gate, float* forget_gate, + const float* cell_gate, bool use_cifg, float clip) { + micro_tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state, + n_batch * n_cell, cell_state); + + if (use_cifg) { + // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as + // scratch, as input_gate array is not allocated in this case. (Be careful + // not to write to the scratch before reading the forget gate data.) + float* scratch = forget_gate; + micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch); + micro_tensor_utils::VectorVectorCwiseProductAccumulate( + cell_gate, scratch, n_batch * n_cell, cell_state); + } else { + micro_tensor_utils::VectorVectorCwiseProductAccumulate( + cell_gate, input_gate, n_batch * n_cell, cell_state); + } + if (clip > 0.0f) { + micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip); + } +} + +// Calculates the output state tensor of an LSTM step. +// +// Implements the following formula: +// output_no_projection = output_gate .* activate(cell_state) +// (elementwise vector product) +// If no projection is used: +// output = output_state = output_no_projection +// With projection: +// output = output_state = clip(W*output_no_projection + bias) +// +// Output might have a different 'stride' than n_output, so we need to copy. +// +// Parameters: +// - n_batch: batches: the number of distinct vectors in each array. +// - n_cell, n_output: sizes of vectors. +// - cell_state, output_gate: input vectors, size n_batch*n_cell. +// - projection_weights, projection_weights_scale, projection_bias: +// constant inputs, describing projection matrix and bias. +// - proj_clip: if > 0, clip the output of the projection. +// - output_state: output vector, size n_batch*n_output. Must be contiguous. +// - scratch: scratch area, size n_batch*n_cell.
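+// Worked example with hypothetical sizes: for n_batch = 2, n_cell = 4, +// n_output = 3 and projection enabled, scratch holds the 2x4 activated cell +// state masked by the output gate; the projection then multiplies each +// 4-vector by the 3x4 weight matrix, so output_state receives a 2x3 result +// (plus projection_bias if present, then the optional proj_clip).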
+void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, + const float* cell_state, const float* output_gate, + TfLiteFusedActivation activation, + const float* projection_weights, + const float* projection_bias, + const float proj_clip, float* output_state, + float* scratch) { + micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, + activation, scratch); + micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, + n_batch * n_cell, scratch); + + const bool use_projection = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + + if (use_projection) { + if (use_projection_bias) { + micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, + n_batch, output_state); + } else { + memset(output_state, 0, n_batch * n_output * sizeof(float)); + } + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights, n_output, n_cell, scratch, n_batch, output_state); + if (proj_clip > 0.0f) { + micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output, + proj_clip); + } + } else { + std::memcpy(output_state, scratch, n_batch * n_output * sizeof(float)); + } +} + +// Calculates a single LSTM gate, hybrid version. +// Implements the same functionality as CalculateLstmGateFloat. +void CalculateLstmGateHybrid( + // Input and weights + const int8_t* input, const float* input_sf, const int32_t* input_zp, + const int8_t* input_to_gate_weights, + const uint8_t* input_to_gate_weights_ledger, + const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums, + // Aux input and weights + const int8_t* aux_input, const float* aux_input_sf, + const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights, + const float aux_input_to_gate_weights_scale, + int32_t* aux_input_to_gate_row_sums, + // Output state and weights + const int8_t* output_state, const float* output_state_sf, + const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights, + const uint8_t* recurrent_to_gate_weights_ledger, + const float recurrent_to_gate_weights_scale, + int32_t* recurrent_to_gate_row_sums, + // Cell state and weights (peephole LSTM) + const float* cell_state, const int8_t* cell_to_gate_weights, + const float cell_to_gate_weights_scale, + // Layer normalization coefficients (layer norm LSTM) + gate bias + const float* layer_norm_coefficients, const float* gate_bias, + // Array sizes + const int n_batch, const int n_input, const int n_aux_input, + const int n_output, const int n_cell, + const TfLiteFusedActivation activation, + // Output + float* gate, + // Parameters for performance optimizations + const bool is_input_all_zeros, const bool is_aux_input_all_zeros, + const bool is_output_state_all_zeros, bool* compute_row_sums, + // Scratch arrays + float* scratch0, // size: n_batch + float* scratch1, // size: n_cell, only used if peephole LSTM + float* scales, // size: n_batch + int32_t* accum_scratch // For MatrixBatchVectorMultiplyAccumulate +) { + const bool use_peephole = (cell_to_gate_weights != nullptr); + const bool use_layer_norm = (layer_norm_coefficients != nullptr); + + // Initialize scratch buffers with bias for regular lstm or initialize with + // zero for layer norm lstm. + if (use_layer_norm) { + memset(gate, 0, n_cell * n_batch * sizeof(float)); + } else { + micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, + gate); + } + // For each batch and cell: compute input_weight * input. + // Skip if input is all zeros. 
+ if (!is_input_all_zeros) { + if (input_to_gate_weights_ledger != nullptr) { + for (int i = 0; i < n_batch; i++) { + scales[i] = input_to_gate_weights_scale * input_sf[i]; + } + micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input, + input, scales, n_batch, gate); + + } else { + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, n_cell, n_input, input, + input_to_gate_weights_scale, input_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch, + input_to_gate_row_sums, compute_row_sums, scratch0, nullptr); + } + } + // For each batch and cell: compute aux_input_weight * aux_input. + // Skip if auxiliary input is not available or all zeros. + if (!is_aux_input_all_zeros) { + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, + aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch, + aux_input_to_gate_row_sums, compute_row_sums, scratch0, nullptr); + } + // For each batch and cell: compute recurrent_weight * output_state. + // Skip if output state is all zeros. + if (!is_output_state_all_zeros) { + if (recurrent_to_gate_weights_ledger != nullptr) { + for (int i = 0; i < n_batch; i++) { + scales[i] = recurrent_to_gate_weights_scale * input_sf[i]; + } + micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell, + n_output, output_state, scales, n_batch, gate); + } else { + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, n_cell, n_output, output_state, + recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch, + recurrent_to_gate_row_sums, compute_row_sums, scratch0, nullptr); + } + } + // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) + if (use_peephole) { + float* recovered_cell_weights = scratch1; + micro_tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell, + cell_to_gate_weights_scale, + recovered_cell_weights); + micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state, n_batch, gate); + } + // Do layer normalization (if layer norm LSTM) + if (use_layer_norm) { + micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch); + micro_tensor_utils::VectorBatchVectorCwiseProduct( + layer_norm_coefficients, n_cell, gate, n_batch, gate); + micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate); + } + // Apply activation + micro_tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, + activation, gate); +} + +// Calculates the output state tensor of an LSTM step. See Float version too. +// +// Parameters: +// - n_batch: batches: the number of distinct vectors in each array. +// - n_cell, n_output: sizes of vectors. +// - cell_state, output_gate: input vectors, size n_batch*n_cell. +// - projection_weights, projection_weights_scale, projection_bias: +// constant inputs, describing projection matrix and bias. +// - proj_clip: if > 0, clip the output of the projection. +// - output_state: output vector, size n_batch*n_output. Must be contiguous. +// - asymmetric_quantize_inputs: parameter to control quantization. +// - projection_weights_row_sums, compute_row_sums: Data for optimized +// MatrixBatchVectorMultiplyAccumulate.
+// - scratch0: scratch area of size n_batch*n_cell +// - scratch1: scratch area of size n_batch*n_cell +// - scratch2: scratch area of size n_batch +// - scratch3: scratch area of size n_batch +// - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate +// - scales: scratch area of size n_batch +void CalculateLstmOutputHybrid( + int n_batch, int n_cell, int n_output, const float* cell_state, + const float* output_gate, TfLiteFusedActivation activation, + const int8_t* projection_weights, const uint8_t* projection_weights_ledger, + float projection_weights_scale, const float* projection_bias, + const float proj_clip, float* output_state, bool asymmetric_quantize_inputs, + int32_t* projection_weights_row_sums, bool* compute_row_sums, + float* scratch0, int8_t* scratch1, float* scratch2, int32_t* scratch3, + int32_t* scratch4, float* scales) { + micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, + activation, scratch0); + micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0, + n_batch * n_cell, scratch0); + + const bool use_projection = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + + if (use_projection) { + if (use_projection_bias) { + micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, + n_batch, output_state); + } else { + memset(output_state, 0, n_batch * n_output * sizeof(float)); + } + if (!micro_tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero output. + micro_tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, + scratch1, scratch2, scratch3, + asymmetric_quantize_inputs); + if (projection_weights_ledger != nullptr) { + for (int i = 0; i < n_batch; i++) { + scales[i] = projection_weights_scale * scratch2[i]; + } + micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + projection_weights, projection_weights_ledger, n_output, n_cell, + scratch1, scales, n_batch, output_state); + } else { + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights, n_output, n_cell, scratch1, + projection_weights_scale, scratch2, n_batch, output_state, + /*per_channel_scale=*/nullptr, scratch3, scratch4, + projection_weights_row_sums, compute_row_sums, scratch2, nullptr); + } + } + if (proj_clip > 0.0f) { + micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output, + proj_clip); + } + } else { + std::memcpy(output_state, scratch0, n_batch * n_output * sizeof(float)); + } +} + +// Calculates a single LSTM gate, int8x8_16 version. +// Implements the same functionality as CalculateLstmGateFloat. 
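+// Integer counterpart of the float gate formula above (a summary of the code +// below, not additional behaviour): +// gate = activate(requant(W_input * input) +// + requant(W_recurrent * output_state) [+ peephole] [-> layer_norm]) +// where each requant applies the matching (scale_a, scale_b) multiplier/shift +// pair and the accumulation saturates to int16.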
+void CalculateLstmGateInteger8x8_16( + // Input and weights + const int8_t* input, const int8_t* input_to_gate_weights, + const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a, + const int32_t input_to_gate_scale_b, + // Output state and weights + const int8_t* output_state, const int8_t* recurrent_to_gate_weights, + const int32_t* recurrent_to_gate_bias, + const int32_t recurrent_to_gate_scale_a, + const int32_t recurrent_to_gate_scale_b, + // Cell state and weights + const int16_t* cell_state, const int16_t* cell_to_gate_weights, + const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b, + // Layer normalization parameters (layer norm LSTM) + const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias, + const int32_t layer_norm_input_scale_a, + const int32_t layer_norm_input_scale_b, + const int32_t layer_norm_variance_guard, + // Array sizes + const int n_batch, const int n_input, const int n_output, const int n_cell, + const TfLiteFusedActivation activation, + // Output + int16_t* gate, + // Parameters for performance optimizations + // Scratch arrays + int32_t* scratch5) { + const bool use_peephole = (cell_to_gate_weights != nullptr); + const bool use_layer_norm = (layer_norm_coefficients != nullptr); + + // Initialize scratch buffers with zeros. Note that unlike float and hybrid + // versions, bias is only used in layer normalization. + memset(gate, 0, n_batch * n_cell * sizeof(int16_t)); + // For each batch and cell: compute input_weight * input. + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a, + input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate, + nullptr); + // Note: no aux_input. + + // For each batch and cell: compute recurrent_weight * output_state. + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + output_state, recurrent_to_gate_bias, recurrent_to_gate_weights, + recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output, + n_cell, 0, scratch5, gate, nullptr); + // For each batch and cell: compute cell_weight * cell_state (peephole LSTM) + if (use_peephole) { + micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_gate_weights, n_output, cell_state, n_batch, + cell_to_gate_scale_a, cell_to_gate_scale_b, gate); + } + // Do layer normalization (if layer norm LSTM) + if (use_layer_norm) { + micro_tensor_utils::ApplyLayerNorm( + gate, layer_norm_coefficients, layer_norm_bias, + layer_norm_input_scale_a, layer_norm_input_scale_b, + layer_norm_variance_guard, n_batch, n_cell, gate); + } + // Apply activation + switch (activation) { + case kTfLiteActSigmoid: + micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate); + break; + case kTfLiteActTanh: + micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate); + break; + default: + // Only Sigmoid or Tanh is used. + TFLITE_ASSERT_FALSE; + } +} + +// Updates the LSTM cell state, used by both integer LSTM versions. +// Also see UpdateLstmCellFloat. +// +// Parameters: +// - n_batch, n_cell: sizes of vectors +// - cell_state: input/output vector, size n_batch*n_cell +// - cell_state_scale: scaling factor of cell state. +// - input_gate: input vector, size n_batch*n_cell. +// - forget_gate: input/scratch vector, size n_batch*n_cell, always modified. +// - cell_gate: input vector, size n_batch*n_cell. +// - use_cifg: use 1-forget_gate instead of input_gate. +// - clip: if > 0, clip the resulting cell state to [-clip, +clip]. 
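+// Example of the fixed-point bookkeeping (hypothetical format): gates are +// int16 Q0.15; with cell_state in Q4.11 (cell_state_scale = -11), +// forget_gate * cell_state is shifted right by 15 to stay in Q4.11, while +// input_gate * cell_gate (a Q0.30 product) is shifted by +// 30 + cell_state_scale = 19 to land in the same format before the add.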
+void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state, + int32_t cell_state_scale, const int16_t* input_gate, + int16_t* forget_gate, const int16_t* cell_gate, + bool use_cifg, int16_t clip) { + // Use the forget_gate array as scratch, as input_gate array is not allocated + // in CIFG case. (Be careful not to write to the scratch before reading the + // forget gate data.) + int16_t* scratch = forget_gate; + + micro_tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15, + cell_state); + if (use_cifg) { + micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch); + micro_tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell, + 30 + cell_state_scale, scratch); + } else { + micro_tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell, + 30 + cell_state_scale, scratch); + } + micro_tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell, + cell_state); + + if (clip > 0) { + micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip); + } +} + +// Calculates the output state tensor of an LSTM step. See Float and hybrid +// versions as well. +// +// Parameters: +// - n_batch: batches: the number of distinct vectors in each array. +// - n_cell, n_output: sizes of vectors. +// - cell_state, output_gate: input vectors, size n_batch*n_cell. +// - cell_state_scale: scaling of cell_state. +// - hidden_scale_[a|b]: effective scale of cell_state.*output_gate +// - hidden_zp: zero_point for cell_state.*output_gate +// - projection_weights, proj_scale_[a|b], projection_bias: +// constant inputs, describing projection matrix and bias. +// - output_state_zp: zero point of output_state. (Input, calibrated value.) +// - quantized_proj_clip: if > 0, clip the output of the projection. +// - output_state: output vector, size n_batch*n_output. Must be contiguous. +// - scratch0: scratch area of size n_batch*n_cell +// - scratch1: scratch area of size n_batch*n_cell +// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate +void CalculateLstmOutputInteger8x8_16( + int n_batch, int n_cell, int n_output, const int16_t* cell_state, + int32_t cell_state_scale, const int16_t* output_gate, + int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp, + const int8_t* projection_weights, int32_t proj_scale_a, + int32_t proj_scale_b, const int32_t* projection_bias, + int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state, + int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) { + // Note: unlike float/hybrid, the activation is always Tanh. + micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, + n_cell, scratch0); + micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, + hidden_scale_b, n_batch, n_cell, hidden_zp, + scratch1); + + const bool use_projection = (projection_weights != nullptr); + + if (use_projection) { + // Note: no bias like in float/hybrid + memset(output_state, 0, n_batch * n_output * sizeof(int8_t)); + micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate( + scratch1, projection_bias, projection_weights, proj_scale_a, + proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2, + output_state, nullptr); + if (quantized_proj_clip > 0) { + micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output, + quantized_proj_clip); + } + } else { + std::memcpy(output_state, scratch1, n_batch * n_output * sizeof(int8_t)); + } +} + +// Calculates a single LSTM gate, int8x8_8 version. +// Implements the same functionality as CalculateLstmGateFloat.
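+// Sketch of the dataflow below (int8 tensors, int16 gate output): +// scratch0 = requant(input x input_to_gate_weight) +// scratch1 = requant(output_state x recurrent_to_gate_weight) +// gate = activate(layer_norm(scratch0 + scratch1, gate_bias)) +// using a saturating add and, unlike the 8x8_16 variant, the float-based +// layer-norm and activation helpers.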
+void CalculateLstmGateInteger8x8_8( + // Inputs and weights + const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight, + const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b, + const int32_t input_times_weights_scale_a, + const int32_t input_times_weights_scale_b, + const int32_t input_times_weights_zp, + // Output state and weights + const int8_t* output_state, const int32_t output_state_zp, + const int8_t* recurrent_to_gate_weight, + const int32_t recurrent_to_gate_scale_a, + const int32_t recurrent_to_gate_scale_b, + const int32_t output_state_times_weights_scale_a, + const int32_t output_state_times_weights_scale_b, + const int32_t output_state_times_weights_zp, + // Layer normalization parameters (layer norm LSTM) + const int16_t* layer_norm_gate_weight, + const int32_t layer_norm_gate_scale_a, + const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias, + // Array sizes + const int n_batch, const int n_input, const int n_output, const int n_cell, + const TfLiteFusedActivation activation, + // Output + int16_t* gate, + // Scratch arrays, both sized n_batch*n_cell + int8_t* scratch0, int8_t* scratch1) { + // Multiply input * input_weights => scratch0 + micro_tensor_utils::MatrixBatchVectorMultiply( + input, input_zp, input_to_gate_weight, input_to_gate_scale_a, + input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0, + input_times_weights_zp); + // Multiply output_state * recurrent_weights => scratch1 + micro_tensor_utils::MatrixBatchVectorMultiply( + output_state, output_state_zp, recurrent_to_gate_weight, + recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output, + n_cell, scratch1, output_state_times_weights_zp); + // Add scratch0 + scratch1 => gate + micro_tensor_utils::TwoGateSaturatingAdd( + scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp, + input_times_weights_scale_a, input_times_weights_scale_b, + output_state_times_weights_scale_a, output_state_times_weights_scale_b, + n_batch, n_cell, gate); + // Apply layer normalization. + micro_tensor_utils::ApplyLayerNormFloat( + gate, layer_norm_gate_weight, layer_norm_gate_scale_a, + layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate); + // Apply activation. + switch (activation) { + case kTfLiteActSigmoid: + micro_tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate); + break; + case kTfLiteActTanh: + micro_tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate); + break; + default: + // Only Sigmoid or Tanh is used. + TFLITE_ASSERT_FALSE; + } +} + +// Calculates the output state tensor of an LSTM step. See Float and hybrid +// versions as well. +// +// Parameters: +// - n_batch: batches: the number of distinct vectors in each array. +// - n_cell, n_output: sizes of vectors. +// - cell_state, output_gate: input vectors, size n_batch*n_cell. +// - projection_weights, proj_scale_[a|b], projection_bias: +// constant inputs, describing projection matrix and bias. +// - output_state_zp: zero point of the output state. +// - quantized_proj_clip: if > 0, clip the output of the projection. +// - output_state: output vector, size n_batch*n_output. Must be contiguous.
+// - scratch: scratch area of size n_batch*n_cell +void CalculateLstmOutputInteger8x8_8( + int n_batch, int n_cell, int n_output, const int16_t* cell_state, + const int16_t* output_gate, const int8_t* projection_weights, + int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias, + int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state, + int16_t* scratch) { + // Note: unlike float/hybrid, the activation is always Tanh. + micro_tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch); + micro_tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, + 15 + 15 - 15, scratch); + // Note: no bias like in float/hybrid + micro_tensor_utils::MatrixBatchVectorMultiply( + scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias, + n_batch, n_cell, n_output, output_state_zp, output_state); + if (quantized_proj_clip > 0) { + micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output, + quantized_proj_clip); + } +} + +// Performs an LSTM batch inference step for input specified by input_ptr. +// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and +// biases (*_bias_ptr), and buffers (*_scratch), along with additional +// parameters: +// - params: various LSTM params including activation, clipping, etc., +// - n_batch: size of batch, +// - n_cell: number of cells (or units), +// - n_input: the input size, +// - n_aux_input: the auxiliary input size. +// - n_output: the output size. +// - output_batch_leading_dim: the leading dimension of the output buffer. +// +// Input of size 'n_batch * n_input': +// input_ptr +// Input of size 'n_batch * n_aux_input': +// aux_input_ptr - optional (can be nullptr) +// +// LSTM weights: +// Input weights of size 'n_cell * n_input': +// input_to_input_weights - optional +// input_to_forget_weights +// input_to_cell_weights +// input_to_output_weights +// Auxiliary input weights of size 'n_cell * n_aux_input': +// aux_input_to_input_weights - optional +// aux_input_to_forget_weights - optional +// aux_input_to_cell_weights - optional +// aux_input_to_output_weights - optional +// Recurrent weights of size 'n_cell * n_output': +// recurrent_to_input_weights - optional +// recurrent_to_forget_weights +// recurrent_to_cell_weights +// recurrent_to_output_weights +// Peephole weights of size 'n_cell', representing diagonal matrices. +// cell_to_input_weights - optional +// cell_to_forget_weights - optional +// cell_to_output_weights - optional +// Projection weights of size 'n_output * n_cell' +// projection_weights_ptr - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Layer norm coefficients of size 'n_cell', representing diagonal matrices. +// input_layer_norm_coefficients_ptr - optional +// forget_layer_norm_coefficients_ptr - optional +// cell_layer_norm_coefficients_ptr - optional +// output_layer_norm_coefficients_ptr - optional +// +// The pointers to the cell and output state and the output are updated. +// +// The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned +// in batch_major order, and each step processes batch_size many inputs from +// input_ptr, and updates batch_size many cell and output states. +// +// The output_batch_dim is output.shape[-1], i.e. the last dimension of the +// output tensor, and in most cases will be equal to n_output.
It is usually not +// when we want to store the LSTM output into a slice of the output tensor, e.g. +// for bidirectional LSTMs with merge_outputs. In this case, the batched +// operations cannot be used since they assume that the batched outputs are +// contiguous, and we manually loop over the batched outputs. +inline void LstmStepFloat( + const float* input_ptr, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, const float* aux_input_ptr, + const float* aux_input_to_input_weights_ptr, + const float* aux_input_to_forget_weights_ptr, + const float* aux_input_to_cell_weights_ptr, + const float* aux_input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, + const float* input_layer_norm_coefficients_ptr, + const float* forget_layer_norm_coefficients_ptr, + const float* cell_layer_norm_coefficients_ptr, + const float* output_layer_norm_coefficients_ptr, + const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr, + const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr, + const float* projection_weights_ptr, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_aux_input, int n_output, int output_batch_leading_dim, + float* output_state_ptr, float* cell_state_ptr, float* scratch0, + float* scratch1, float* scratch2, float* scratch3, float* output_ptr) { + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + + // Make named scratch buffers. + float* input_gate_scratch = scratch0; + float* forget_gate_scratch = scratch1; + float* cell_gate_scratch = scratch2; + float* output_gate_scratch = scratch3; + + // Check if inputs are all zeros so we can skip some computations. + const bool is_input_all_zeros = + micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input); + const bool is_aux_input_all_zeros = + (aux_input_ptr == nullptr || + micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)); + if (!use_cifg) { + // Calculate the input gate. (If not CIFG.) + CalculateLstmGateFloat( + input_ptr, input_to_input_weights_ptr, aux_input_ptr, + aux_input_to_input_weights_ptr, output_state_ptr, + recurrent_to_input_weights_ptr, cell_state_ptr, + cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr, + input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + /*activation=*/kTfLiteActSigmoid, input_gate_scratch, + is_input_all_zeros, is_aux_input_all_zeros); + } + // Calculate the forget gate. + CalculateLstmGateFloat( + input_ptr, input_to_forget_weights_ptr, aux_input_ptr, + aux_input_to_forget_weights_ptr, output_state_ptr, + recurrent_to_forget_weights_ptr, cell_state_ptr, + cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr, + forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros, + is_aux_input_all_zeros); + // Calculate the cell update gate. 
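+  // Illustratively, each CalculateLstmGateFloat call computes (leaving out
+  // the optional aux-input, peephole and layer-norm contributions):
+  //   gate = activation(input x input_weights +
+  //                     output_state x recurrent_weights + bias)
+  // The input, forget and output gates use kTfLiteActSigmoid; only this cell
+  // update gate uses the user-configured params->activation (typically tanh).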
+  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
+                         aux_input_to_cell_weights_ptr, output_state_ptr,
+                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
+                         /*cell_to_gate_weights=*/nullptr,
+                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
+                         n_batch, n_input, n_aux_input, n_output, n_cell,
+                         params->activation, cell_gate_scratch,
+                         is_input_all_zeros, is_aux_input_all_zeros);
+  // Update the cell state.
+  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
+                      forget_gate_scratch, cell_gate_scratch, use_cifg,
+                      params->cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
+      aux_input_to_output_weights_ptr, output_state_ptr,
+      recurrent_to_output_weights_ptr, cell_state_ptr,
+      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
+      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros);
+  // Update the output state.
+  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
+                           output_gate_scratch, params->activation,
+                           projection_weights_ptr, projection_bias_ptr,
+                           params->proj_clip, output_state_ptr, scratch2);
+  // Copy output state to the output. Note that the output's rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::memcpy(output_ptr + b * output_batch_leading_dim,
+                output_state_ptr + b * n_output, n_output * sizeof(float));
+  }
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+//   input_ptr
+// Input of size 'n_batch * n_aux_input':
+//   aux_input_ptr                   - optional (can be nullptr)
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weights          - optional
+//   input_to_forget_weights
+//   input_to_cell_weights
+//   input_to_output_weights
+// Quantized auxiliary input weights of size 'n_cell * n_aux_input':
+//   aux_input_to_input_weights      - optional
+//   aux_input_to_forget_weights     - optional
+//   aux_input_to_cell_weights       - optional
+//   aux_input_to_output_weights     - optional
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weights      - optional
+//   recurrent_to_forget_weights
+//   recurrent_to_cell_weights
+//   recurrent_to_output_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights           - optional
+//   cell_to_forget_weights          - optional
+//   cell_to_output_weights          - optional
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weights_ptr          - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// aux_input_to_input_weights_scale - optional +// aux_input_to_forget_weights_scale - optional +// aux_input_to_cell_weights_scale - optional +// aux_input_to_output_weights_scale - optional +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale, +// cell_to_forget_weights_scale, +// cell_to_output_weights_scale, +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Layer norm coefficients of size 'n_cell', representing diagonal matrices. +// input_layer_norm_coefficients_ptr - optional +// forget_layer_norm_coefficients_ptr - optional +// cell_layer_norm_coefficients_ptr - optional +// output_layer_norm_coefficients_ptr - optional +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr (same size as input_ptr) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_output_scratch (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr - size 'n_batch * output_batch_leading_dim' +inline void LstmStepHybrid( + const float* input_ptr, const int8_t* input_to_input_weights_ptr, + const uint8_t* input_to_input_weights_ledger_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + const uint8_t* input_to_forget_weights_ledger_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, + const uint8_t* input_to_cell_weights_ledger_ptr, + float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + const uint8_t* input_to_output_weights_ledger_ptr, + float input_to_output_weights_scale, const float* aux_input_ptr, + const int8_t* aux_input_to_input_weights_ptr, + float aux_input_to_input_weights_scale, + const int8_t* aux_input_to_forget_weights_ptr, + float aux_input_to_forget_weights_scale, + const int8_t* aux_input_to_cell_weights_ptr, + float aux_input_to_cell_weights_scale, + const int8_t* aux_input_to_output_weights_ptr, + float aux_input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + const uint8_t* recurrent_to_input_weights_ledger_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + const uint8_t* recurrent_to_forget_weights_ledger_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + const uint8_t* recurrent_to_cell_weights_ledger_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + const uint8_t* recurrent_to_output_weights_ledger_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, + const float* input_layer_norm_coefficients_ptr, + const float* forget_layer_norm_coefficients_ptr, + const float* cell_layer_norm_coefficients_ptr, + const 
float* output_layer_norm_coefficients_ptr,
+    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
+    const int8_t* projection_weights_ptr,
+    const uint8_t* projection_weights_ledger_ptr,
+    float projection_weights_scale, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_aux_input, int n_output, int output_batch_leading_dim,
+    float* scratch0, float* scratch1, float* scratch2, float* scratch3,
+    float* scales, float* input_sf, float* aux_input_sf, float* output_state_sf,
+    float* scaling_factors_scratch, float* recovered_cell_weights,
+    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
+    int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
+    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
+    float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums, bool asymmetric_quantize_inputs) {
+  // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+  // Make named scratch buffers for the different gates.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
+  int32_t* input_to_input_row_sums = nullptr;
+  int32_t* input_to_forget_row_sums = nullptr;
+  int32_t* input_to_cell_row_sums = nullptr;
+  int32_t* input_to_output_row_sums = nullptr;
+  int32_t* aux_input_to_input_row_sums = nullptr;
+  int32_t* aux_input_to_forget_row_sums = nullptr;
+  int32_t* aux_input_to_cell_row_sums = nullptr;
+  int32_t* aux_input_to_output_row_sums = nullptr;
+  int32_t* recurrent_to_input_row_sums = nullptr;
+  int32_t* recurrent_to_forget_row_sums = nullptr;
+  int32_t* recurrent_to_cell_row_sums = nullptr;
+  int32_t* recurrent_to_output_row_sums = nullptr;
+  int32_t* projection_weights_row_sums = nullptr;
+
+  if (asymmetric_quantize_inputs) {
+    int num_row_sums = use_cifg ? 6 : 8;
+    if (aux_input_ptr != nullptr) {
+      num_row_sums += use_cifg ? 3 : 4;
+    }
+    if (projection_weights_ptr != nullptr) {
+      num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
+    }
+    TFLITE_DCHECK(row_sums_size == num_row_sums);
+    input_to_input_row_sums = row_sums;
+    input_to_forget_row_sums =
+        use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
+    input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
+    input_to_output_row_sums = input_to_cell_row_sums + n_cell;
+    if (aux_input_ptr != nullptr) {
+      aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
+      aux_input_to_forget_row_sums = use_cifg
+                                         ? aux_input_to_input_row_sums
+                                         : aux_input_to_input_row_sums + n_cell;
+      aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
+      aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
+    }
+    recurrent_to_input_row_sums = aux_input_ptr
+                                      ? aux_input_to_output_row_sums + n_cell
+                                      : input_to_output_row_sums + n_cell;
+    recurrent_to_forget_row_sums = use_cifg
+                                       ?
recurrent_to_input_row_sums + : recurrent_to_input_row_sums + n_cell; + recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell; + recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell; + if (projection_weights_ptr != nullptr) { + projection_weights_row_sums = recurrent_to_output_row_sums + n_cell; + } + if (*compute_row_sums) { + ComputeRowSums( + input_to_input_row_sums, input_to_forget_row_sums, + input_to_cell_row_sums, input_to_output_row_sums, + aux_input_to_input_row_sums, aux_input_to_forget_row_sums, + aux_input_to_cell_row_sums, aux_input_to_output_row_sums, + recurrent_to_input_row_sums, recurrent_to_forget_row_sums, + recurrent_to_cell_row_sums, recurrent_to_output_row_sums, + projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input, + n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, + aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + projection_weights_ptr, use_cifg, aux_input_ptr); + *compute_row_sums = false; + } + } + + // Check if inputs are all zeros so we can skip some computations. + const bool is_input_all_zeros = + micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input); + const bool is_aux_input_all_zeros = + (aux_input_ptr == nullptr || + micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)); + const bool is_output_state_all_zeros = + micro_tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output); + // Quantize inputs. + if (!is_input_all_zeros) { + micro_tensor_utils::BatchQuantizeFloats( + input_ptr, n_batch, n_input, quantized_input_ptr, input_sf, input_zp, + asymmetric_quantize_inputs); + } + if (!is_aux_input_all_zeros) { + micro_tensor_utils::BatchQuantizeFloats( + aux_input_ptr, n_batch, n_aux_input, quantized_aux_input_ptr, + aux_input_sf, aux_input_zp, asymmetric_quantize_inputs); + } + if (!is_output_state_all_zeros) { + micro_tensor_utils::BatchQuantizeFloats( + output_state_ptr, n_batch, n_output, quantized_output_state_ptr, + output_state_sf, output_state_zp, asymmetric_quantize_inputs); + } + if (!use_cifg) { + // Calculate the input gate. (If not CIFG.) + CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr, + input_to_input_weights_ledger_ptr, input_to_input_weights_scale, + input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_input_weights_ptr, + aux_input_to_input_weights_scale, aux_input_to_input_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr, + recurrent_to_input_weights_scale, recurrent_to_input_row_sums, + cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale, + input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch, + n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, + input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, + is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch, + recovered_cell_weights, scales, accum_scratch_ptr); + } + // Calculate the forget gate. 
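+  // Illustratively, each CalculateLstmGateHybrid call below accumulates
+  //   gate += (quantized_input . quantized_weights) * input_sf * weight_scale
+  // in float, using the precomputed row sums above to correct for the input
+  // zero points when asymmetric quantization is enabled, and then applies the
+  // layer-norm coefficients (if present), the bias and the gate activation.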
+ CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr, + input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale, + input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_forget_weights_ptr, + aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr, + recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums, + cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch, + n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, + forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, + is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch, + recovered_cell_weights, scales, accum_scratch_ptr); + // Calculate the cell update gate. + CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr, + input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale, + input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_cell_weights_ptr, + aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr, + recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums, + /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr, + /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr, + cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + params->activation, cell_gate_scratch, is_input_all_zeros, + is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums, + scaling_factors_scratch, recovered_cell_weights, scales, + accum_scratch_ptr); + // Update the cell state. + UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_gate_scratch, use_cifg, + params->cell_clip); + // Calculate the output gate. + CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr, + input_to_output_weights_ledger_ptr, input_to_output_weights_scale, + input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_output_weights_ptr, + aux_input_to_output_weights_scale, aux_input_to_output_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr, + recurrent_to_output_weights_scale, recurrent_to_output_row_sums, + cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale, + output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch, + n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, + output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, + is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch, + recovered_cell_weights, scales, accum_scratch_ptr); + // Update the output state. 
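+  // In essence, the output state computed below is
+  //   output_state = output_gate .* activation(cell_state)
+  // optionally projected back to n_output values and clipped to
+  // [-proj_clip, proj_clip] when projection weights are given.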
+  CalculateLstmOutputHybrid(
+      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
+      params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
+      projection_weights_scale, projection_bias_ptr, params->proj_clip,
+      output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
+      compute_row_sums, scratch2, quantized_output_scratch, input_sf, input_zp,
+      accum_scratch_ptr, scales);
+  // Copy output state to the output. Note that the output's rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::memcpy(output_ptr + b * output_batch_leading_dim,
+                output_state_ptr + b * n_output, n_output * sizeof(float));
+  }
+}
+
+// Fully quantized lstm kernel for 16 bit gate matmul output.
+//
+// Input tensor of size n_batch * n_input:
+//   input_ptr
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weight_ptr       - optional
+//   input_to_forget_weight_ptr      - optional
+//   input_to_cell_weight_ptr        - optional
+//   input_to_output_weight_ptr      - optional
+//
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weight_ptr   - optional
+//   recurrent_to_forget_weights_ptr
+//   recurrent_to_cell_weights_ptr
+//   recurrent_to_output_weights_ptr
+//
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights           - optional
+//   cell_to_forget_weights          - optional
+//   cell_to_output_weights          - optional
+//
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weight_ptr           - optional
+//
+// Weight scales (scalars) for each of the weights above.
+//   effective_input_to_input_scale_a      - optional
+//   effective_input_to_input_scale_b      - optional
+//   effective_input_to_forget_scale_a
+//   effective_input_to_forget_scale_b
+//   effective_input_to_cell_scale_a
+//   effective_input_to_cell_scale_b
+//   effective_input_to_output_scale_a
+//   effective_input_to_output_scale_b
+//   effective_recurrent_to_input_scale_a  - optional
+//   effective_recurrent_to_input_scale_b  - optional
+//   effective_recurrent_to_forget_scale_a
+//   effective_recurrent_to_forget_scale_b
+//   effective_recurrent_to_cell_scale_a
+//   effective_recurrent_to_cell_scale_b
+//   effective_recurrent_to_output_scale_a
+//   effective_recurrent_to_output_scale_b
+//   effective_proj_scale_a                - optional
+//   effective_proj_scale_b                - optional
+//
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr             - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   layer_norm_input_weight_ptr     - optional
+//   layer_norm_forget_weight_ptr    - optional
+//   layer_norm_cell_weight_ptr      - optional
+//   layer_norm_output_weight_ptr    - optional
+//
+// Layer norm scales of size 'n_cell'.
+//   layer_norm_input_scale_a        - optional
+//   layer_norm_input_scale_b        - optional
+//   layer_norm_forget_scale_a       - optional
+//   layer_norm_forget_scale_b       - optional
+//   layer_norm_cell_scale_a         - optional
+//   layer_norm_cell_scale_b         - optional
+//   layer_norm_output_scale_a       - optional
+//   layer_norm_output_scale_b       - optional
+//
+// Scalar values:
+//   quantized_cell_clip: quantized clip value for cell.
+//   quantized_proj_clip: quantized clip value for projection.
+//   cell_state_scale: the power of two scale for cell state.
+//
+// Zero points:
+//   output_state_zp: zero point of output state
+//   hidden_zp: zero point for hidden state.
+// +// Temporary pre-allocated storage for the calculation. Each is of size n_cell * +// n_batch. +// scratch0 +// scratch1 +// scratch2 +// scratch3 +// scratch4 +// scratch5: this scratch buffer is created purely for optimizing the +// MatrixBatchVectorMultiplyAccumulate. +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr - size 'n_batch * n_output' +// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then. +inline void LstmStepInteger8x8_16( + const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr, + int32_t effective_input_to_input_scale_a, + int32_t effective_input_to_input_scale_b, + const int8_t* input_to_forget_weight_ptr, + int32_t effective_input_to_forget_scale_a, + int32_t effective_input_to_forget_scale_b, + const int8_t* input_to_cell_weight_ptr, + int32_t effective_input_to_cell_scale_a, + int32_t effective_input_to_cell_scale_b, + const int8_t* input_to_output_weight_ptr, + int32_t effective_input_to_output_scale_a, + int32_t effective_input_to_output_scale_b, + const int8_t* recurrent_to_input_weight_ptr, + int32_t effective_recurrent_to_input_scale_a, + int32_t effective_recurrent_to_input_scale_b, + const int8_t* recurrent_to_forget_weight_ptr, + int32_t effective_recurrent_to_forget_scale_a, + int32_t effective_recurrent_to_forget_scale_b, + const int8_t* recurrent_to_cell_weight_ptr, + int32_t effective_recurrent_to_cell_scale_a, + int32_t effective_recurrent_to_cell_scale_b, + const int8_t* recurrent_to_output_weight_ptr, + int32_t effective_recurrent_to_output_scale_a, + int32_t effective_recurrent_to_output_scale_b, + const int16_t* cell_to_input_weight_ptr, + int32_t effective_cell_to_input_scale_a, + int32_t effective_cell_to_input_scale_b, + const int16_t* cell_to_forget_weight_ptr, + int32_t effective_cell_to_forget_scale_a, + int32_t effective_cell_to_forget_scale_b, + const int16_t* cell_to_output_weight_ptr, + int32_t effective_cell_to_output_scale_a, + int32_t effective_cell_to_output_scale_b, + const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a, + int32_t effective_proj_scale_b, int32_t hidden_zp, + int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b, + const int16_t* layer_norm_input_weight_ptr, + int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b, + const int16_t* layer_norm_forget_weight_ptr, + int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b, + const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a, + int32_t layer_norm_cell_scale_b, + const int16_t* layer_norm_output_weight_ptr, + int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b, + const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr, + const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr, + int16_t quantized_cell_clip, int8_t quantized_proj_clip, + int32_t cell_state_scale, int32_t input_variance_guard, + int32_t forget_variance_guard, int32_t cell_variance_guard, + int32_t output_variance_guard, + const int32_t* input_to_forget_effective_bias, + const int32_t* recurrent_to_forget_effective_bias, + const int32_t* input_to_cell_effective_bias, + const int32_t* recurrent_to_cell_effective_bias, + const int32_t* input_to_output_effective_bias, + const int32_t* recurrent_to_output_effective_bias, + const int32_t* input_to_input_effective_bias, + const int32_t* recurrent_to_input_effective_bias, + const int32_t* projection_effective_bias, int n_batch, int 
n_cell,
+    int n_input, int n_output, int8_t* output_state_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
+    int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int8_t* scratch4, int32_t* scratch5) {
+  // Make named scratch buffers for the different gates.
+  int16_t* input_gate_scratch = scratch0;
+  int16_t* forget_gate_scratch = scratch1;
+  int16_t* cell_gate_scratch = scratch2;
+  int16_t* output_gate_scratch = scratch3;
+
+  // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weight_ptr == nullptr);
+
+  // Check for nullptrs.
+  TFLITE_DCHECK(input_to_forget_effective_bias);
+  TFLITE_DCHECK(recurrent_to_forget_effective_bias);
+  TFLITE_DCHECK(input_to_cell_effective_bias);
+  TFLITE_DCHECK(recurrent_to_cell_effective_bias);
+  TFLITE_DCHECK(input_to_output_effective_bias);
+  TFLITE_DCHECK(recurrent_to_output_effective_bias);
+  if (!use_cifg) {
+    TFLITE_DCHECK(input_to_input_effective_bias);
+    TFLITE_DCHECK(recurrent_to_input_effective_bias);
+  }
+  const bool use_projection = (projection_weight_ptr != nullptr);
+  if (use_projection) {
+    TFLITE_DCHECK(projection_effective_bias);
+  }
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateInteger8x8_16(
+        input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
+        effective_input_to_input_scale_a, effective_input_to_input_scale_b,
+        output_state_ptr, recurrent_to_input_weight_ptr,
+        recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
+        effective_recurrent_to_input_scale_b, cell_state_ptr,
+        cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
+        effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
+        input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
+        input_variance_guard, n_batch, n_input, n_output, n_cell,
+        kTfLiteActSigmoid, input_gate_scratch, scratch5);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
+      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
+      output_state_ptr, recurrent_to_forget_weight_ptr,
+      recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
+      effective_recurrent_to_forget_scale_b, cell_state_ptr,
+      cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
+      effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
+      forget_gate_bias_ptr, layer_norm_forget_scale_a,
+      layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
+      n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
+  // Calculate the cell update gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
+      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
+      output_state_ptr, recurrent_to_cell_weight_ptr,
+      recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
+      effective_recurrent_to_cell_scale_b, cell_state_ptr,
+      /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
+      /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
+      cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
+      cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
+      cell_gate_scratch, scratch5);
+  // Update the cell state.
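+  // Roughly, the integer cell update computes
+  //   cell_state = forget_gate .* cell_state + input_gate .* cell_gate
+  // in fixed point (CIFG uses input_gate = 1 - forget_gate), saturating and,
+  // when quantized_cell_clip > 0, clipping the result.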
+  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
+                        input_gate_scratch, forget_gate_scratch,
+                        cell_gate_scratch, use_cifg, quantized_cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      output_state_ptr, recurrent_to_output_weight_ptr,
+      recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, cell_state_ptr,
+      cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
+      effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
+      output_gate_bias_ptr, layer_norm_output_scale_a,
+      layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
+      n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
+  // Update the output state.
+  CalculateLstmOutputInteger8x8_16(
+      n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
+      output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
+      hidden_zp, projection_weight_ptr, effective_proj_scale_a,
+      effective_proj_scale_b, projection_effective_bias, output_state_zp,
+      quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
+  // Copy output state to the output. Note that unlike float or hybrid, output
+  // is always contiguous.
+  std::memcpy(output_ptr, output_state_ptr,
+              n_batch * n_output * sizeof(int8_t));
+}
+
+// Fully quantized lstm kernel for 8 bit gate matmul output.
+//
+// Input tensor of size n_batch * n_input:
+//   input_ptr
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weight_ptr       - optional
+//   input_to_forget_weight_ptr      - optional
+//   input_to_cell_weight_ptr        - optional
+//   input_to_output_weight_ptr      - optional
+//
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weight_ptr   - optional
+//   recurrent_to_forget_weights_ptr
+//   recurrent_to_cell_weights_ptr
+//   recurrent_to_output_weights_ptr
+//
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights           - optional
+//   cell_to_forget_weights          - optional
+//   cell_to_output_weights          - optional
+//
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weight_ptr           - optional
+//
+// Weight scales (scalars) for each of the weights above.
+//   effective_input_to_input_scale_a      - optional
+//   effective_input_to_input_scale_b      - optional
+//   effective_input_to_forget_scale_a
+//   effective_input_to_forget_scale_b
+//   effective_input_to_cell_scale_a
+//   effective_input_to_cell_scale_b
+//   effective_input_to_output_scale_a
+//   effective_input_to_output_scale_b
+//   effective_recurrent_to_input_scale_a  - optional
+//   effective_recurrent_to_input_scale_b  - optional
+//   effective_recurrent_to_forget_scale_a
+//   effective_recurrent_to_forget_scale_b
+//   effective_recurrent_to_cell_scale_a
+//   effective_recurrent_to_cell_scale_b
+//   effective_recurrent_to_output_scale_a
+//   effective_recurrent_to_output_scale_b
+//   effective_proj_scale_a                - optional
+//   effective_proj_scale_b                - optional
+//
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr             - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   layer_norm_input_weight_ptr     - optional
+//   layer_norm_forget_weight_ptr    - optional
+//   layer_norm_cell_weight_ptr      - optional
+//   layer_norm_output_weight_ptr    - optional
+//
+// Layer norm scales of size 'n_cell'.
+//   layer_norm_input_scale_a        - optional
+//   layer_norm_input_scale_b        - optional
+//   layer_norm_forget_scale_a       - optional
+//   layer_norm_forget_scale_b       - optional
+//   layer_norm_cell_scale_a         - optional
+//   layer_norm_cell_scale_b         - optional
+//   layer_norm_output_scale_a       - optional
+//   layer_norm_output_scale_b       - optional
+//
+// Scalar values:
+//   quantized_cell_clip: quantized clip value for cell.
+//   quantized_proj_clip: quantized clip value for projection.
+//   cell_state_scale: the power of two scale for cell state.
+//
+// Zero points:
+//   input_zp: zero point for input tensor.
+//   output_state_zp: zero point of output state.
+//   hidden_zp: zero point for hidden state.
+//
+// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
+// n_batch.
+//   scratch0
+//   scratch1
+//   scratch2
+//   scratch3
+//   scratch4
+//   scratch5
+//   scratch6
+//   scratch7
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * n_output'
+//
+// The zero point calculation can be moved into Prepare() for better
+// performance.
+// TODO(b/159947023): scratch5 is unused, remove.
+inline void LstmStepInteger8x8_8(
+    const int8_t* input_ptr, int32_t input_zp,
+    const int8_t* input_to_input_weight_ptr,
+    int32_t effective_input_to_input_scale_a,
+    int32_t effective_input_to_input_scale_b,
+    const int8_t* input_to_forget_weight_ptr,
+    int32_t effective_input_to_forget_scale_a,
+    int32_t effective_input_to_forget_scale_b,
+    const int8_t* input_to_cell_weight_ptr,
+    int32_t effective_input_to_cell_scale_a,
+    int32_t effective_input_to_cell_scale_b,
+    const int8_t* input_to_output_weight_ptr,
+    int32_t effective_input_to_output_scale_a,
+    int32_t effective_input_to_output_scale_b,
+    const int8_t* recurrent_to_input_weight_ptr,
+    int32_t effective_recurrent_to_input_scale_a,
+    int32_t effective_recurrent_to_input_scale_b,
+    const int8_t* recurrent_to_forget_weight_ptr,
+    int32_t effective_recurrent_to_forget_scale_a,
+    int32_t effective_recurrent_to_forget_scale_b,
+    const int8_t* recurrent_to_cell_weight_ptr,
+    int32_t effective_recurrent_to_cell_scale_a,
+    int32_t effective_recurrent_to_cell_scale_b,
+    const int8_t* recurrent_to_output_weight_ptr,
+    int32_t effective_recurrent_to_output_scale_a,
+    int32_t effective_recurrent_to_output_scale_b,
+    const int8_t* cell_to_input_weight_ptr,
+    int32_t effective_cell_to_input_scale_a,
+    int32_t effective_cell_to_input_scale_b,
+    const int8_t* cell_to_forget_weight_ptr,
+    int32_t effective_cell_to_forget_scale_a,
+    int32_t effective_cell_to_forget_scale_b,
+    const int8_t* cell_to_output_weight_ptr,
+    int32_t effective_cell_to_output_scale_a,
+    int32_t effective_cell_to_output_scale_b,
+    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
+    int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
+    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
+    const int16_t* layer_norm_forget_weight_ptr,
+    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
+    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
+    int32_t layer_norm_cell_scale_b,
+    const int16_t* layer_norm_output_weight_ptr,
+    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
+    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
+    const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
+    const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
+    const int32_t* intermediate_zp, int16_t quantized_cell_clip,
+    int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
+    int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
+    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
+    int16_t* scratch7) {
+  // TODO(b/159066113): scratch5 is unused, remove.
+
+  // Make named scratch buffers for the different gates.
+  int16_t* forget_gate_scratch = scratch2;
+  int16_t* cell_gate_scratch = scratch3;
+  int16_t* output_gate_scratch = scratch4;
+  // Only the CIFG variant (coupled input and forget gates) is supported here,
+  // so there is no separate input gate.
+
+  // Calculate the forget gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_forget_weight_ptr,
+      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
+      intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
+      output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
+      effective_recurrent_to_forget_scale_a,
+      effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
+      intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
+      layer_norm_forget_scale_a, layer_norm_forget_scale_b,
+      forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
+  // Calculate the cell update gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_cell_weight_ptr,
+      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
+      intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
+      output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
+      effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
+      intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
+      layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
+      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
+      n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
+  // Update the cell state.
+  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
+                        /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
+                        forget_gate_scratch, cell_gate_scratch,
+                        /*use_cifg=*/true, quantized_cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_output_weight_ptr,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
+      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
+      effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, intermediate_scale_a[7],
+      intermediate_scale_b[7], intermediate_zp[11], layer_norm_output_weight_ptr,
+      layer_norm_output_scale_a, layer_norm_output_scale_b,
+      output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
+  // Update the output state.
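+  // In this fully 8-bit path the output state is, in essence,
+  //   output_state = project(output_gate .* tanh(cell_state))
+  // computed entirely in fixed point: tanh is evaluated at a Q0.15 cell state
+  // and the projection requantizes to int8 around output_state_zp, clipping
+  // when quantized_proj_clip > 0 (see CalculateLstmOutputInteger8x8_8 above).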
+  CalculateLstmOutputInteger8x8_8(
+      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
+      projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
+      projection_bias_ptr, output_state_zp, quantized_proj_clip,
+      output_state_ptr, scratch2);
+  // Copy output state to the output. Note that unlike float or hybrid, output
+  // is always contiguous.
+  std::memcpy(output_ptr, output_state_ptr,
+              n_batch * n_output * sizeof(int8_t));
+}
+
+}  // namespace
+
+TfLiteStatus EvalFloatLstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  int max_time, n_batch;
+  if (input->dims->size == 3) {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  } else {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  }
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+
+  // Index the scratch buffer pointers into the global scratch buffer.
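+  // The caller provides one flat scratch_buffer; it is carved into
+  // n_cell * n_batch sized slices below: three slices (cell, forget, output)
+  // when CIFG is used, four (input, cell, forget, output) otherwise.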
+  float* input_gate_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_gate_scratch = scratch_buffer;
+    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer;
+    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
+  }
+
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+  if (time_major) {
+    // Loop through the sequence.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr =
+          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
+      const float* aux_input_ptr = nullptr;
+      if (aux_input) {
+        aux_input_ptr =
+            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
+      }
+      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                          t_rel * output_step + output_offset;
+
+      LstmStepFloat(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_input_weights),
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_forget_weights),
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_cell_weights),
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_output_weights),
+          aux_input_ptr,
+          aux_input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(aux_input_to_input_weights),
+          aux_input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    aux_input_to_forget_weights),
+          aux_input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(aux_input_to_cell_weights),
+          aux_input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    aux_input_to_output_weights),
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(recurrent_to_input_weights),
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    recurrent_to_forget_weights),
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(recurrent_to_cell_weights),
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    recurrent_to_output_weights),
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_input_weights),
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_output_weights),
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    input_layer_norm_coefficients),
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    forget_layer_norm_coefficients),
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    cell_layer_norm_coefficients),
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    output_layer_norm_coefficients),
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(output_gate_bias),
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_weights),
+          projection_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_bias),
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim,
+          tflite::micro::GetTensorData<float>(output_state),
+          tflite::micro::GetTensorData<float>(cell_state), input_gate_scratch,
+          forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
+          output_ptr);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
+                                 time_offset * input_step;
+        const float* aux_input_ptr = nullptr;
+        if (aux_input) {
+          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
+                          time_offset * input_step;
+        }
+        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                            time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        float* output_state_ptr =
+            tflite::micro::GetTensorData<float>(output_state) +
+            b * output_batch_leading_dim;
+        float* cell_state_ptr =
+            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
+        // Offset the scratch pointers to the right batch.
+        float* input_gate_scratch_ptr =
+            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
+        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
+        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
+        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
+
+        LstmStepFloat(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_input_weights),
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_forget_weights),
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_cell_weights),
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_output_weights),
+            aux_input_ptr,
+            aux_input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_input_weights),
+            aux_input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_forget_weights),
+            aux_input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_cell_weights),
+            aux_input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_output_weights),
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_input_weights),
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_forget_weights),
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_cell_weights),
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_output_weights),
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_input_weights),
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_output_weights),
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      input_layer_norm_coefficients),
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      forget_layer_norm_coefficients),
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      cell_layer_norm_coefficients),
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      output_layer_norm_coefficients),
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(output_gate_bias),
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_weights),
+            projection_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_bias),
+            params,
+            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
+            output_batch_leading_dim, output_state_ptr, cell_state_ptr,
+            input_gate_scratch_ptr, forget_gate_scratch_ptr,
+            cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybridLstm(
+    const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_input_weights_ledger,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_forget_weights_ledger,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_cell_weights_ledger,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* input_to_output_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_weights_ledger,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* prod_scaling_factors,
+    float* recovered_cell_weights, int8_t* input_quantized,
+    int8_t* aux_input_quantized, int8_t* output_state_quantized,
+    int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
+    TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  }
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_gate_scratch = scratch_buffer;
+    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer;
+    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
+  }
+
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+
+  int32_t* input_zp_ptr = nullptr;
+  int32_t* aux_input_zp_ptr = nullptr;
+  int32_t* output_state_zp_ptr = nullptr;
+  int32_t* row_sums_ptr = nullptr;
+  if (params->asymmetric_quantize_inputs) {
+    input_zp_ptr = input_zp;
+    aux_input_zp_ptr = aux_input_zp;
+    output_state_zp_ptr = output_state_zp;
+    row_sums_ptr = row_sums;
+  }
+
+  if (time_major) {
+    // Feed the sequence into the LSTM step-by-step.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr =
+          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
+      const float* aux_input_ptr = nullptr;
+      if (aux_input) {
+        aux_input_ptr =
+            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
+      }
+      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                          t_rel * output_step + output_offset;
+      LstmStepHybrid(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+          input_to_input_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_input_weights_ledger),
+          hybrid_lstm_scales->input_to_input_weights_scale,
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+          input_to_forget_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_forget_weights_ledger),
+          hybrid_lstm_scales->input_to_forget_weights_scale,
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+          input_to_cell_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_cell_weights_ledger),
+          hybrid_lstm_scales->input_to_cell_weights_scale,
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+          input_to_output_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_output_weights_ledger),
+          hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
+          aux_input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_input_weights),
+          hybrid_lstm_scales->aux_input_to_input_weights_scale,
+          aux_input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_forget_weights),
+          hybrid_lstm_scales->aux_input_to_forget_weights_scale,
+          aux_input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(aux_input_to_cell_weights),
+          hybrid_lstm_scales->aux_input_to_cell_weights_scale,
+          aux_input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_output_weights),
+          hybrid_lstm_scales->aux_input_to_output_weights_scale,
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_input_weights),
+          recurrent_to_input_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_input_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_input_weights_scale,
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_forget_weights),
+          recurrent_to_forget_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_forget_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_forget_weights_scale,
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+          recurrent_to_cell_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_cell_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_cell_weights_scale,
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_output_weights),
+          recurrent_to_output_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_output_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_output_weights_scale,
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+          hybrid_lstm_scales->cell_to_input_weights_scale,
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+          hybrid_lstm_scales->cell_to_forget_weights_scale,
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+          hybrid_lstm_scales->cell_to_output_weights_scale,
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    input_layer_norm_coefficients),
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    forget_layer_norm_coefficients),
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    cell_layer_norm_coefficients),
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    output_layer_norm_coefficients),
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(output_gate_bias),
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(projection_weights),
+          projection_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    projection_weights_ledger),
+          hybrid_lstm_scales->projection_weights_scale,
+          projection_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_bias),
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
+          cell_gate_scratch, output_gate_scratch, scales, input_sf,
+          aux_input_sf, output_state_sf, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          output_state_quantized, cell_state_quantized,
+          tflite::micro::GetTensorData<float>(output_state),
+          tflite::micro::GetTensorData<float>(cell_state),
+          output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
+          output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
+          params->asymmetric_quantize_inputs);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
+                                 time_offset * input_step;
+        const float* aux_input_ptr = nullptr;
+        if (aux_input) {
+          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
+                          time_offset * input_step;
+        }
+        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                            time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        float* output_state_ptr =
+            tflite::micro::GetTensorData<float>(output_state) +
+            b * output_batch_leading_dim;
+        float* cell_state_ptr =
+            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
+        // Offset the scratch pointers to the right batch.
+        float* input_gate_scratch_ptr =
+            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
+        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
+        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
+        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
+
+        LstmStepHybrid(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+            input_to_input_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_input_weights_ledger),
+            hybrid_lstm_scales->input_to_input_weights_scale,
+            input_to_forget_weights == nullptr
+                ?
nullptr + : tflite::micro::GetTensorData(input_to_forget_weights), + input_to_forget_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + input_to_forget_weights_ledger), + hybrid_lstm_scales->input_to_forget_weights_scale, + input_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_cell_weights), + input_to_cell_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + input_to_cell_weights_ledger), + hybrid_lstm_scales->input_to_cell_weights_scale, + input_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_output_weights), + input_to_output_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + input_to_output_weights_ledger), + hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr, + aux_input_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + aux_input_to_input_weights), + hybrid_lstm_scales->aux_input_to_input_weights_scale, + aux_input_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + aux_input_to_forget_weights), + hybrid_lstm_scales->aux_input_to_forget_weights_scale, + aux_input_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + aux_input_to_cell_weights), + hybrid_lstm_scales->aux_input_to_cell_weights_scale, + aux_input_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + aux_input_to_output_weights), + hybrid_lstm_scales->aux_input_to_output_weights_scale, + recurrent_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_input_weights), + recurrent_to_input_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_input_weights_ledger), + hybrid_lstm_scales->recurrent_to_input_weights_scale, + recurrent_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_forget_weights), + recurrent_to_forget_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_forget_weights_ledger), + hybrid_lstm_scales->recurrent_to_forget_weights_scale, + recurrent_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_cell_weights), + recurrent_to_cell_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_cell_weights_ledger), + hybrid_lstm_scales->recurrent_to_cell_weights_scale, + recurrent_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_output_weights), + recurrent_to_output_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_output_weights_ledger), + hybrid_lstm_scales->recurrent_to_output_weights_scale, + cell_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_input_weights), + hybrid_lstm_scales->cell_to_input_weights_scale, + cell_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_forget_weights), + hybrid_lstm_scales->cell_to_forget_weights_scale, + cell_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_output_weights), + hybrid_lstm_scales->cell_to_output_weights_scale, + input_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + input_layer_norm_coefficients), + forget_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + forget_layer_norm_coefficients), + cell_layer_norm_coefficients == nullptr + ? 
nullptr + : tflite::micro::GetTensorData( + cell_layer_norm_coefficients), + output_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + output_layer_norm_coefficients), + input_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_gate_bias), + forget_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(forget_gate_bias), + cell_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_gate_bias), + output_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(output_gate_bias), + projection_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(projection_weights), + projection_weights_ledger == nullptr + ? nullptr + : tflite::micro::GetTensorData( + projection_weights_ledger), + hybrid_lstm_scales->projection_weights_scale, + projection_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(projection_bias), + params, + /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, + output_batch_leading_dim, input_gate_scratch_ptr, + forget_gate_scratch_ptr, cell_gate_scratch_ptr, + output_gate_scratch_ptr, scales, input_sf, aux_input_sf, + output_state_sf, prod_scaling_factors, recovered_cell_weights, + input_quantized, aux_input_quantized, output_state_quantized, + cell_state_quantized, output_state_ptr, cell_state_ptr, + output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr, + output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums, + params->asymmetric_quantize_inputs); + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus EvalInteger8x8_16Lstm( + const TfLiteEvalTensor* input, + const TfLiteEvalTensor* input_to_input_weights, + const TfLiteEvalTensor* input_to_forget_weights, + const TfLiteEvalTensor* input_to_cell_weights, + const TfLiteEvalTensor* input_to_output_weights, + const TfLiteEvalTensor* recurrent_to_input_weights, + const TfLiteEvalTensor* recurrent_to_forget_weights, + const TfLiteEvalTensor* recurrent_to_cell_weights, + const TfLiteEvalTensor* recurrent_to_output_weights, + const TfLiteEvalTensor* cell_to_input_weights, + const TfLiteEvalTensor* cell_to_forget_weights, + const TfLiteEvalTensor* cell_to_output_weights, + const TfLiteEvalTensor* input_layer_norm_coefficients, + const TfLiteEvalTensor* forget_layer_norm_coefficients, + const TfLiteEvalTensor* cell_layer_norm_coefficients, + const TfLiteEvalTensor* output_layer_norm_coefficients, + const TfLiteEvalTensor* input_gate_bias, + const TfLiteEvalTensor* forget_gate_bias, + const TfLiteEvalTensor* cell_gate_bias, + const TfLiteEvalTensor* output_gate_bias, + const TfLiteEvalTensor* projection_weights, + const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params, + bool forward_sequence, bool time_major, + const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp, + TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state, + TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1, + int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5) { + TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3); + const int n_input = input->dims->data[input->dims->size - 1]; + int max_time, n_batch; + if (input->dims->size == 2) { + max_time = 1; + n_batch = input->dims->data[0]; + } else { + max_time = (time_major) ? input->dims->data[0] : input->dims->data[1]; + n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0]; + } + + // n_cell and n_output will be the same size when there is no projection. 
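+  // (n_cell is dims[0] of input_to_output_weights, shaped [n_cell, n_input];
+  //  n_output is dims[1] of recurrent_to_output_weights, shaped
+  //  [n_cell, n_output].)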
+ const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Get params for time/batch/sequence. + const int output_batch_leading_dim = + output->dims->data[output->dims->size - 1]; + + if (time_major) { + const int input_step = n_batch * n_input; + const int output_step = n_batch * output_batch_leading_dim; + for (int t = 0; t < max_time; t++) { + const int t_rel = t; + int8_t* output_ptr = + tflite::micro::GetTensorData(output) + t_rel * output_step; + const int8_t* input_ptr = + tflite::micro::GetTensorData(input) + t_rel * input_step; + LstmStepInteger8x8_16( + input_ptr, + input_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_input_weights), + integer_lstm_param->effective_input_to_input_scale_a, + integer_lstm_param->effective_input_to_input_scale_b, + input_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_forget_weights), + integer_lstm_param->effective_input_to_forget_scale_a, + integer_lstm_param->effective_input_to_forget_scale_b, + input_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_cell_weights), + integer_lstm_param->effective_input_to_cell_scale_a, + integer_lstm_param->effective_input_to_cell_scale_b, + input_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_output_weights), + integer_lstm_param->effective_input_to_output_scale_a, + integer_lstm_param->effective_input_to_output_scale_b, + recurrent_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_input_weights), + integer_lstm_param->effective_recurrent_to_input_scale_a, + integer_lstm_param->effective_recurrent_to_input_scale_b, + recurrent_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_forget_weights), + integer_lstm_param->effective_recurrent_to_forget_scale_a, + integer_lstm_param->effective_recurrent_to_forget_scale_b, + recurrent_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(recurrent_to_cell_weights), + integer_lstm_param->effective_recurrent_to_cell_scale_a, + integer_lstm_param->effective_recurrent_to_cell_scale_b, + recurrent_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_output_weights), + integer_lstm_param->effective_recurrent_to_output_scale_a, + integer_lstm_param->effective_recurrent_to_output_scale_b, + cell_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_input_weights), + integer_lstm_param->effective_cell_to_input_scale_a, + integer_lstm_param->effective_cell_to_input_scale_b, + cell_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_forget_weights), + integer_lstm_param->effective_cell_to_forget_scale_a, + integer_lstm_param->effective_cell_to_forget_scale_b, + cell_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_output_weights), + integer_lstm_param->effective_cell_to_output_scale_a, + integer_lstm_param->effective_cell_to_output_scale_b, + projection_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(projection_weights), + integer_lstm_param->effective_proj_scale_a, + integer_lstm_param->effective_proj_scale_b, + integer_lstm_param->hidden_zp, + integer_lstm_param->effective_hidden_scale_a, + integer_lstm_param->effective_hidden_scale_b, + input_layer_norm_coefficients == nullptr + ? 
nullptr + : tflite::micro::GetTensorData( + input_layer_norm_coefficients), + integer_lstm_param->layer_norm_input_scale_a, + integer_lstm_param->layer_norm_input_scale_b, + forget_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + forget_layer_norm_coefficients), + integer_lstm_param->layer_norm_forget_scale_a, + integer_lstm_param->layer_norm_forget_scale_b, + cell_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + cell_layer_norm_coefficients), + integer_lstm_param->layer_norm_cell_scale_a, + integer_lstm_param->layer_norm_cell_scale_b, + output_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + output_layer_norm_coefficients), + integer_lstm_param->layer_norm_output_scale_a, + integer_lstm_param->layer_norm_output_scale_b, + input_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_gate_bias), + forget_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(forget_gate_bias), + cell_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_gate_bias), + output_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(output_gate_bias), + integer_lstm_param->quantized_cell_clip, + integer_lstm_param->quantized_proj_clip, + integer_lstm_param->cell_scale, + integer_lstm_param->input_variance_guard, + integer_lstm_param->forget_variance_guard, + integer_lstm_param->cell_variance_guard, + integer_lstm_param->output_variance_guard, + integer_lstm_param->input_to_forget_effective_bias, + integer_lstm_param->recurrent_to_forget_effective_bias, + integer_lstm_param->input_to_cell_effective_bias, + integer_lstm_param->recurrent_to_cell_effective_bias, + integer_lstm_param->input_to_output_effective_bias, + integer_lstm_param->recurrent_to_output_effective_bias, + integer_lstm_param->input_to_input_effective_bias, + integer_lstm_param->recurrent_to_input_effective_bias, + integer_lstm_param->projection_effective_bias, n_batch, n_cell, + n_input, n_output, tflite::micro::GetTensorData(output_state), + output_state_zp, tflite::micro::GetTensorData(cell_state), + output_ptr, scratch0, scratch1, scratch2, scratch3, scratch4, + scratch5); + } + } else { + for (int b = 0; b < n_batch; b++) { + const int input_step = n_input; + const int output_step = output_batch_leading_dim; + for (int t = 0; t < max_time; t++) { + // If this is the forward_sequence, step forward, otherwise step + // backwards. + const int t_rel = forward_sequence ? t : max_time - t - 1; + const int time_offset = b * max_time + t_rel; + const int8_t* input_ptr = tflite::micro::GetTensorData(input) + + time_offset * input_step; + int8_t* output_ptr = tflite::micro::GetTensorData(output) + + time_offset * output_step; + + // Offset the {output,cell}_state pointers to the right batch. + int8_t* output_state_ptr = + tflite::micro::GetTensorData(output_state) + + b * output_batch_leading_dim; + int16_t* cell_state_ptr = + tflite::micro::GetTensorData(cell_state) + b * n_cell; + + LstmStepInteger8x8_16( + input_ptr, + input_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_input_weights), + integer_lstm_param->effective_input_to_input_scale_a, + integer_lstm_param->effective_input_to_input_scale_b, + input_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_forget_weights), + integer_lstm_param->effective_input_to_forget_scale_a, + integer_lstm_param->effective_input_to_forget_scale_b, + input_to_cell_weights == nullptr + ? 
nullptr + : tflite::micro::GetTensorData(input_to_cell_weights), + integer_lstm_param->effective_input_to_cell_scale_a, + integer_lstm_param->effective_input_to_cell_scale_b, + input_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_output_weights), + integer_lstm_param->effective_input_to_output_scale_a, + integer_lstm_param->effective_input_to_output_scale_b, + recurrent_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_input_weights), + integer_lstm_param->effective_recurrent_to_input_scale_a, + integer_lstm_param->effective_recurrent_to_input_scale_b, + recurrent_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_forget_weights), + integer_lstm_param->effective_recurrent_to_forget_scale_a, + integer_lstm_param->effective_recurrent_to_forget_scale_b, + recurrent_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_cell_weights), + integer_lstm_param->effective_recurrent_to_cell_scale_a, + integer_lstm_param->effective_recurrent_to_cell_scale_b, + recurrent_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData( + recurrent_to_output_weights), + integer_lstm_param->effective_recurrent_to_output_scale_a, + integer_lstm_param->effective_recurrent_to_output_scale_b, + cell_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_input_weights), + integer_lstm_param->effective_cell_to_input_scale_a, + integer_lstm_param->effective_cell_to_input_scale_b, + cell_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_forget_weights), + integer_lstm_param->effective_cell_to_forget_scale_a, + integer_lstm_param->effective_cell_to_forget_scale_b, + cell_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_output_weights), + integer_lstm_param->effective_cell_to_output_scale_a, + integer_lstm_param->effective_cell_to_output_scale_b, + projection_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(projection_weights), + integer_lstm_param->effective_proj_scale_a, + integer_lstm_param->effective_proj_scale_b, + integer_lstm_param->hidden_zp, + integer_lstm_param->effective_hidden_scale_a, + integer_lstm_param->effective_hidden_scale_b, + input_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + input_layer_norm_coefficients), + integer_lstm_param->layer_norm_input_scale_a, + integer_lstm_param->layer_norm_input_scale_b, + forget_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + forget_layer_norm_coefficients), + integer_lstm_param->layer_norm_forget_scale_a, + integer_lstm_param->layer_norm_forget_scale_b, + cell_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + cell_layer_norm_coefficients), + integer_lstm_param->layer_norm_cell_scale_a, + integer_lstm_param->layer_norm_cell_scale_b, + output_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + output_layer_norm_coefficients), + integer_lstm_param->layer_norm_output_scale_a, + integer_lstm_param->layer_norm_output_scale_b, + input_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_gate_bias), + forget_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(forget_gate_bias), + cell_gate_bias == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_gate_bias), + output_gate_bias == nullptr + ? 
nullptr + : tflite::micro::GetTensorData(output_gate_bias), + integer_lstm_param->quantized_cell_clip, + integer_lstm_param->quantized_proj_clip, + integer_lstm_param->cell_scale, + integer_lstm_param->input_variance_guard, + integer_lstm_param->forget_variance_guard, + integer_lstm_param->cell_variance_guard, + integer_lstm_param->output_variance_guard, + integer_lstm_param->input_to_forget_effective_bias, + integer_lstm_param->recurrent_to_forget_effective_bias, + integer_lstm_param->input_to_cell_effective_bias, + integer_lstm_param->recurrent_to_cell_effective_bias, + integer_lstm_param->input_to_output_effective_bias, + integer_lstm_param->recurrent_to_output_effective_bias, + integer_lstm_param->input_to_input_effective_bias, + integer_lstm_param->recurrent_to_input_effective_bias, + integer_lstm_param->projection_effective_bias, /*n_batch=*/1, + n_cell, n_input, n_output, output_state_ptr, output_state_zp, + cell_state_ptr, output_ptr, scratch0, scratch1, scratch2, scratch3, + scratch4, scratch5); + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus EvalInteger8x8_8Lstm( + const TfLiteEvalTensor* input, + const TfLiteEvalTensor* input_to_input_weights, + const TfLiteEvalTensor* input_to_forget_weights, + const TfLiteEvalTensor* input_to_cell_weights, + const TfLiteEvalTensor* input_to_output_weights, + const TfLiteEvalTensor* recurrent_to_input_weights, + const TfLiteEvalTensor* recurrent_to_forget_weights, + const TfLiteEvalTensor* recurrent_to_cell_weights, + const TfLiteEvalTensor* recurrent_to_output_weights, + const TfLiteEvalTensor* cell_to_input_weights, + const TfLiteEvalTensor* cell_to_forget_weights, + const TfLiteEvalTensor* cell_to_output_weights, + const TfLiteEvalTensor* input_layer_norm_coefficients, + const TfLiteEvalTensor* forget_layer_norm_coefficients, + const TfLiteEvalTensor* cell_layer_norm_coefficients, + const TfLiteEvalTensor* output_layer_norm_coefficients, + const TfLiteEvalTensor* input_gate_bias, + const TfLiteEvalTensor* forget_gate_bias, + const TfLiteEvalTensor* cell_gate_bias, + const TfLiteEvalTensor* output_gate_bias, + const TfLiteEvalTensor* projection_weights, + const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params, + TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state, + TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param, + int32_t input_zp, int32_t output_state_zp, int8_t* scratch0, + int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, int16_t* scratch4, + int16_t* scratch5, int16_t* scratch6, int16_t* scratch7) { + TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3); + const int n_input = input->dims->data[input->dims->size - 1]; + int max_time, n_batch; + if (input->dims->size == 2) { + max_time = 1; + n_batch = input->dims->data[0]; + } else { + max_time = input->dims->data[0]; + n_batch = input->dims->data[1]; + } + + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Get params for time/batch/sequence. + const int output_batch_leading_dim = + output->dims->data[output->dims->size - 1]; + const int input_step = n_batch * n_input; + const int output_step = n_batch * output_batch_leading_dim; + + for (int t = 0; t < max_time; t++) { + const int t_rel = t; + int8_t* output_ptr = + tflite::micro::GetTensorData(output) + t_rel * output_step; + // Input can be int8 asymmetric or int16 symmetric. 
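+      // Either way, the zero point travels separately as the input_zp
+      // argument just below; a symmetric input simply passes input_zp == 0.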
+ const int8_t* input_ptr = + tflite::micro::GetTensorData(input) + t_rel * input_step; + LstmStepInteger8x8_8( + input_ptr, input_zp, + + input_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_input_weights), + integer_lstm_param->effective_input_to_input_scale_a, + integer_lstm_param->effective_input_to_input_scale_b, + + input_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_forget_weights), + integer_lstm_param->effective_input_to_forget_scale_a, + integer_lstm_param->effective_input_to_forget_scale_b, + + input_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_cell_weights), + integer_lstm_param->effective_input_to_cell_scale_a, + integer_lstm_param->effective_input_to_cell_scale_b, + + input_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(input_to_output_weights), + integer_lstm_param->effective_input_to_output_scale_a, + integer_lstm_param->effective_input_to_output_scale_b, + + recurrent_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(recurrent_to_input_weights), + integer_lstm_param->effective_recurrent_to_input_scale_a, + integer_lstm_param->effective_recurrent_to_input_scale_b, + + recurrent_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(recurrent_to_forget_weights), + integer_lstm_param->effective_recurrent_to_forget_scale_a, + integer_lstm_param->effective_recurrent_to_forget_scale_b, + + recurrent_to_cell_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(recurrent_to_cell_weights), + integer_lstm_param->effective_recurrent_to_cell_scale_a, + integer_lstm_param->effective_recurrent_to_cell_scale_b, + + recurrent_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(recurrent_to_output_weights), + integer_lstm_param->effective_recurrent_to_output_scale_a, + integer_lstm_param->effective_recurrent_to_output_scale_b, + + cell_to_input_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_input_weights), + integer_lstm_param->effective_cell_to_input_scale_a, + integer_lstm_param->effective_cell_to_input_scale_b, + + cell_to_forget_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_forget_weights), + integer_lstm_param->effective_cell_to_forget_scale_a, + integer_lstm_param->effective_cell_to_forget_scale_b, + + cell_to_output_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(cell_to_output_weights), + integer_lstm_param->effective_cell_to_output_scale_a, + integer_lstm_param->effective_cell_to_output_scale_b, + + projection_weights == nullptr + ? nullptr + : tflite::micro::GetTensorData(projection_weights), + integer_lstm_param->effective_proj_scale_a, + integer_lstm_param->effective_proj_scale_b, + + input_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + input_layer_norm_coefficients), + integer_lstm_param->layer_norm_input_scale_a, + integer_lstm_param->layer_norm_input_scale_b, + + forget_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + forget_layer_norm_coefficients), + integer_lstm_param->layer_norm_forget_scale_a, + integer_lstm_param->layer_norm_forget_scale_b, + + cell_layer_norm_coefficients == nullptr + ? nullptr + : tflite::micro::GetTensorData( + cell_layer_norm_coefficients), + integer_lstm_param->layer_norm_cell_scale_a, + integer_lstm_param->layer_norm_cell_scale_b, + + output_layer_norm_coefficients == nullptr + ? 
nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  output_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_output_scale_a,
+        integer_lstm_param->layer_norm_output_scale_b,
+
+        input_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+        forget_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+        cell_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+        output_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+        projection_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(projection_bias),
+
+        params, integer_lstm_param->intermediate_scale_a,
+        integer_lstm_param->intermediate_scale_b,
+        integer_lstm_param->intermediate_zp,
+        integer_lstm_param->quantized_cell_clip,
+        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
+        n_output, output_batch_leading_dim,
+        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
+        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr,
+        scratch0, scratch1, scratch2, scratch3, scratch4, scratch5, scratch6,
+        scratch7);
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h
new file mode 100644
index 00000000..218b4938
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h
@@ -0,0 +1,250 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Parameters for the integer LSTM.
+// Consider splitting this into two parameter structs if more fields are added.
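Each `*_scale_a`/`*_scale_b` pair in the struct below stores one effective rescaling factor as a Q31 fixed-point multiplier plus a power-of-two shift, and the `*_effective_bias` pointers fold a zero point into the bias ahead of time. The sketch below shows how such values are typically produced at Prepare time and applied at Eval time. `QuantizeMultiplier()` is the real utility from `tensorflow/lite/kernels/internal/quantization_util.h`; the helper names, the `scale_b <= 0` assumption, and the simplified rounding (TFLM's `MultiplyByQuantizedMultiplier()` additionally saturates) are illustrative only, and the bias sign convention depends on how the offset is defined.

```cpp
#include <cstdint>

#include "tensorflow/lite/kernels/internal/quantization_util.h"

// Decompose an effective float scale, e.g.
//   input_scale * input_to_forget_weight_scale / intermediate_scale,
// into the (multiplier, shift) pair stored in IntegerLstmParameter.
// Illustrative wrapper around the real QuantizeMultiplier().
void EffectiveScaleToPair(double effective_scale, int32_t* scale_a,
                          int32_t* scale_b) {
  int shift = 0;
  tflite::QuantizeMultiplier(effective_scale, scale_a, &shift);
  *scale_b = shift;
}

// Simplified application: out ~= acc * scale_a * 2^(scale_b - 31).
// Assumes scale_b <= 0 (effective scales below 1) and omits saturation.
int32_t ApplyEffectiveScale(int32_t acc, int32_t scale_a, int32_t scale_b) {
  const int64_t prod = static_cast<int64_t>(acc) * scale_a;
  const int64_t high = (prod + (int64_t{1} << 30)) >> 31;  // Q31 rounding.
  const int right_shift = static_cast<int>(-scale_b);
  const int64_t round =
      right_shift > 0 ? int64_t{1} << (right_shift - 1) : 0;
  return static_cast<int32_t>((high + round) >> right_shift);
}

// Fold a zero point into the bias once (the struct's
// "bias + zero_point * weight" comment), so the per-step kernels can
// multiply raw int8 values without per-element offset handling.
// Hypothetical helper; sign convention depends on the offset definition.
void PrecomputeEffectiveBias(const int8_t* weights, const int32_t* bias,
                             int32_t zero_point, int n_rows, int n_cols,
                             int32_t* effective_bias) {
  for (int r = 0; r < n_rows; ++r) {
    int32_t row_sum = 0;
    for (int c = 0; c < n_cols; ++c) row_sum += weights[r * n_cols + c];
    effective_bias[r] =
        (bias != nullptr ? bias[r] : 0) + zero_point * row_sum;
  }
}
```

Folding the zero point once keeps the 8x8 paths' inner loops as plain int8 multiply-accumulates, which is the point of precomputing these fields.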
+struct IntegerLstmParameter { + int32_t effective_input_to_input_scale_a; + int32_t effective_input_to_input_scale_b; + int32_t effective_recurrent_to_input_scale_a; + int32_t effective_recurrent_to_input_scale_b; + int32_t effective_cell_to_input_scale_a; + int32_t effective_cell_to_input_scale_b; + int32_t effective_input_to_forget_scale_a; + int32_t effective_input_to_forget_scale_b; + int32_t effective_recurrent_to_forget_scale_a; + int32_t effective_recurrent_to_forget_scale_b; + int32_t effective_cell_to_forget_scale_a; + int32_t effective_cell_to_forget_scale_b; + int32_t effective_input_to_cell_scale_a; + int32_t effective_input_to_cell_scale_b; + int32_t effective_recurrent_to_cell_scale_a; + int32_t effective_recurrent_to_cell_scale_b; + int32_t effective_input_to_output_scale_a; + int32_t effective_input_to_output_scale_b; + int32_t effective_recurrent_to_output_scale_a; + int32_t effective_recurrent_to_output_scale_b; + int32_t effective_cell_to_output_scale_a; + int32_t effective_cell_to_output_scale_b; + int32_t effective_proj_scale_a; + int32_t effective_proj_scale_b; + int32_t effective_hidden_scale_a; + int32_t effective_hidden_scale_b; + int32_t layer_norm_input_scale_a; + int32_t layer_norm_input_scale_b; + int32_t layer_norm_forget_scale_a; + int32_t layer_norm_forget_scale_b; + int32_t layer_norm_cell_scale_a; + int32_t layer_norm_cell_scale_b; + int32_t layer_norm_output_scale_a; + int32_t layer_norm_output_scale_b; + // Quantized clip value for cell and projection. Zero value means no clipping. + int16_t quantized_cell_clip; + int8_t quantized_proj_clip; + int32_t hidden_zp; + int32_t cell_scale; + + int32_t input_variance_guard; + int32_t forget_variance_guard; + int32_t cell_variance_guard; + int32_t output_variance_guard; + + // Pre-calculate bias + zero_point * weight. + int32_t* input_to_forget_effective_bias; + int32_t* recurrent_to_forget_effective_bias; + int32_t* input_to_cell_effective_bias; + int32_t* recurrent_to_cell_effective_bias; + int32_t* input_to_output_effective_bias; + int32_t* recurrent_to_output_effective_bias; + int32_t* input_to_input_effective_bias; + int32_t* recurrent_to_input_effective_bias; + int32_t* projection_effective_bias; + + // Scale and zero point for intermediate tensors. + // Used only in the 8x8_8 case. 
+ int32_t intermediate_scale_a[8]; + int32_t intermediate_scale_b[8]; + int32_t intermediate_zp[12]; +}; + +// Scales for hybrid op with integer inputs and float weights +struct HybridLstmScales { + float input_to_input_weights_scale; + float input_to_forget_weights_scale; + float input_to_cell_weights_scale; + float input_to_output_weights_scale; + float aux_input_to_input_weights_scale; + float aux_input_to_forget_weights_scale; + float aux_input_to_cell_weights_scale; + float aux_input_to_output_weights_scale; + float recurrent_to_input_weights_scale; + float recurrent_to_forget_weights_scale; + float recurrent_to_cell_weights_scale; + float recurrent_to_output_weights_scale; + float cell_to_input_weights_scale; + float cell_to_forget_weights_scale; + float cell_to_output_weights_scale; + float projection_weights_scale; +}; + +TfLiteStatus EvalFloatLstm( + const TfLiteEvalTensor* input, + const TfLiteEvalTensor* input_to_input_weights, + const TfLiteEvalTensor* input_to_forget_weights, + const TfLiteEvalTensor* input_to_cell_weights, + const TfLiteEvalTensor* input_to_output_weights, + const TfLiteEvalTensor* recurrent_to_input_weights, + const TfLiteEvalTensor* recurrent_to_forget_weights, + const TfLiteEvalTensor* recurrent_to_cell_weights, + const TfLiteEvalTensor* recurrent_to_output_weights, + const TfLiteEvalTensor* cell_to_input_weights, + const TfLiteEvalTensor* cell_to_forget_weights, + const TfLiteEvalTensor* cell_to_output_weights, + const TfLiteEvalTensor* input_layer_norm_coefficients, + const TfLiteEvalTensor* forget_layer_norm_coefficients, + const TfLiteEvalTensor* cell_layer_norm_coefficients, + const TfLiteEvalTensor* output_layer_norm_coefficients, + const TfLiteEvalTensor* aux_input, + const TfLiteEvalTensor* aux_input_to_input_weights, + const TfLiteEvalTensor* aux_input_to_forget_weights, + const TfLiteEvalTensor* aux_input_to_cell_weights, + const TfLiteEvalTensor* aux_input_to_output_weights, + const TfLiteEvalTensor* input_gate_bias, + const TfLiteEvalTensor* forget_gate_bias, + const TfLiteEvalTensor* cell_gate_bias, + const TfLiteEvalTensor* output_gate_bias, + const TfLiteEvalTensor* projection_weights, + const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params, + bool forward_sequence, bool time_major, int output_offset, + float* scratch_buffer, TfLiteEvalTensor* output_state, + TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output); + +TfLiteStatus EvalHybridLstm( + const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input, + const TfLiteEvalTensor* input_to_input_weights, + const TfLiteEvalTensor* input_to_input_weights_ledger, + const TfLiteEvalTensor* input_to_forget_weights, + const TfLiteEvalTensor* input_to_forget_weights_ledger, + const TfLiteEvalTensor* input_to_cell_weights, + const TfLiteEvalTensor* input_to_cell_weights_ledger, + const TfLiteEvalTensor* input_to_output_weights, + const TfLiteEvalTensor* input_to_output_weights_ledger, + const TfLiteEvalTensor* recurrent_to_input_weights, + const TfLiteEvalTensor* recurrent_to_input_weights_ledger, + const TfLiteEvalTensor* recurrent_to_forget_weights, + const TfLiteEvalTensor* recurrent_to_forget_weights_ledger, + const TfLiteEvalTensor* recurrent_to_cell_weights, + const TfLiteEvalTensor* recurrent_to_cell_weights_ledger, + const TfLiteEvalTensor* recurrent_to_output_weights, + const TfLiteEvalTensor* recurrent_to_output_weights_ledger, + const TfLiteEvalTensor* cell_to_input_weights, + const TfLiteEvalTensor* cell_to_forget_weights, + const TfLiteEvalTensor* 
cell_to_output_weights, + const TfLiteEvalTensor* input_layer_norm_coefficients, + const TfLiteEvalTensor* forget_layer_norm_coefficients, + const TfLiteEvalTensor* cell_layer_norm_coefficients, + const TfLiteEvalTensor* output_layer_norm_coefficients, + const TfLiteEvalTensor* aux_input, + const TfLiteEvalTensor* aux_input_to_input_weights, + const TfLiteEvalTensor* aux_input_to_forget_weights, + const TfLiteEvalTensor* aux_input_to_cell_weights, + const TfLiteEvalTensor* aux_input_to_output_weights, + const TfLiteEvalTensor* input_gate_bias, + const TfLiteEvalTensor* forget_gate_bias, + const TfLiteEvalTensor* cell_gate_bias, + const TfLiteEvalTensor* output_gate_bias, + const TfLiteEvalTensor* projection_weights, + const TfLiteEvalTensor* projection_weights_ledger, + const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params, + bool forward_sequence, bool time_major, int output_offset, + float* scratch_buffer, float* input_sf, float* aux_input_sf, + float* output_state_sf, float* prod_scaling_factors, + float* recovered_cell_weights, int8_t* input_quantized, + int8_t* aux_input_quantized, int8_t* output_state_quantized, + int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state, + TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer, + TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp, + int32_t* output_state_zp, int32_t* row_sums, int row_sums_size, + bool* compute_row_sums); + +TfLiteStatus EvalInteger8x8_16Lstm( + const TfLiteEvalTensor* input, + const TfLiteEvalTensor* input_to_input_weights, + const TfLiteEvalTensor* input_to_forget_weights, + const TfLiteEvalTensor* input_to_cell_weights, + const TfLiteEvalTensor* input_to_output_weights, + const TfLiteEvalTensor* recurrent_to_input_weights, + const TfLiteEvalTensor* recurrent_to_forget_weights, + const TfLiteEvalTensor* recurrent_to_cell_weights, + const TfLiteEvalTensor* recurrent_to_output_weights, + const TfLiteEvalTensor* cell_to_input_weights, + const TfLiteEvalTensor* cell_to_forget_weights, + const TfLiteEvalTensor* cell_to_output_weights, + const TfLiteEvalTensor* input_layer_norm_coefficients, + const TfLiteEvalTensor* forget_layer_norm_coefficients, + const TfLiteEvalTensor* cell_layer_norm_coefficients, + const TfLiteEvalTensor* output_layer_norm_coefficients, + const TfLiteEvalTensor* input_gate_bias, + const TfLiteEvalTensor* forget_gate_bias, + const TfLiteEvalTensor* cell_gate_bias, + const TfLiteEvalTensor* output_gate_bias, + const TfLiteEvalTensor* projection_weights, + const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params, + bool forward_sequence, bool time_major, + const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp, + TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state, + TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1, + int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5); + +TfLiteStatus EvalInteger8x8_8Lstm( + const TfLiteEvalTensor* input, + const TfLiteEvalTensor* input_to_input_weights, + const TfLiteEvalTensor* input_to_forget_weights, + const TfLiteEvalTensor* input_to_cell_weights, + const TfLiteEvalTensor* input_to_output_weights, + const TfLiteEvalTensor* recurrent_to_input_weights, + const TfLiteEvalTensor* recurrent_to_forget_weights, + const TfLiteEvalTensor* recurrent_to_cell_weights, + const TfLiteEvalTensor* recurrent_to_output_weights, + const TfLiteEvalTensor* cell_to_input_weights, + const TfLiteEvalTensor* cell_to_forget_weights, + const 
TfLiteEvalTensor* cell_to_output_weights, + const TfLiteEvalTensor* input_layer_norm_coefficients, + const TfLiteEvalTensor* forget_layer_norm_coefficients, + const TfLiteEvalTensor* cell_layer_norm_coefficients, + const TfLiteEvalTensor* output_layer_norm_coefficients, + const TfLiteEvalTensor* input_gate_bias, + const TfLiteEvalTensor* forget_gate_bias, + const TfLiteEvalTensor* cell_gate_bias, + const TfLiteEvalTensor* output_gate_bias, + const TfLiteEvalTensor* projection_weights, + const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params, + TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state, + TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param, + int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, + int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7); + +} // namespace tflite +#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_shared.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_shared.h new file mode 100644 index 00000000..ee34b848 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_shared.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_ + +namespace tflite { + +// Input Tensors of size {n_batch, n_input} +constexpr int kLstmInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kLstmInputToInputWeightsTensor = 1; // Optional +constexpr int kLstmInputToForgetWeightsTensor = 2; +constexpr int kLstmInputToCellWeightsTensor = 3; +constexpr int kLstmInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kLstmRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kLstmRecurrentToForgetWeightsTensor = 6; +constexpr int kLstmRecurrentToCellWeightsTensor = 7; +constexpr int kLstmRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kLstmCellToInputWeightsTensor = 9; // Optional +constexpr int kLstmCellToForgetWeightsTensor = 10; // Optional +constexpr int kLstmCellToOutputWeightsTensor = 11; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kLstmInputGateBiasTensor = 12; // Optional +constexpr int kLstmForgetGateBiasTensor = 13; +constexpr int kLstmCellGateBiasTensor = 14; +constexpr int kLstmOutputGateBiasTensor = 15; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kLstmProjectionWeightsTensor = 16; // Optional +// Projection bias tensor of size {n_output} +constexpr int kLstmProjectionBiasTensor = 17; // Optional + +// These state tensors are defined as variable tensors, and will be modified by +// this op. 
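+// Variable tensors persist across invocations: each step the kernel reads
+// the previous output/cell state from them and writes the new state back
+// in place.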
+constexpr int kLstmOutputStateTensor = 18; +constexpr int kLstmCellStateTensor = 19; + +// Layer norm coefficient tensors of size {n_cell}, representing a diagonal +// matrix. +constexpr int kLstmInputLayerNormCoefficientsTensor = 20; // Optional +constexpr int kLstmForgetLayerNormCoefficientsTensor = 21; // Optional +constexpr int kLstmCellLayerNormCoefficientsTensor = 22; // Optional +constexpr int kLstmOutputLayerNormCoefficientsTensor = 23; // Optional + +// Output tensors. +constexpr int kLstmOutputTensor = 0; + +} // namespace tflite +#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/maximum_minimum.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/maximum_minimum.cc index a6d358fb..7964f1e6 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/maximum_minimum.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/maximum_minimum.cc @@ -115,29 +115,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace maximum_minimum TfLiteRegistration Register_MAXIMUM() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/ - maximum_minimum::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, nullptr, + maximum_minimum::Eval); } TfLiteRegistration Register_MINIMUM() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/ - maximum_minimum::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + nullptr, nullptr, + maximum_minimum::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_ops.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_ops.h index 17064fee..c4dec92d 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_ops.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_ops.h @@ -76,11 +76,14 @@ TfLiteRegistration Register_SHAPE(); TfLiteRegistration Register_SLICE(); TfLiteRegistration Register_SPACE_TO_BATCH_ND(); TfLiteRegistration Register_SPACE_TO_DEPTH(); +TfLiteRegistration Register_SQUARED_DIFFERENCE(); TfLiteRegistration Register_SQUEEZE(); TfLiteRegistration Register_SUB(); TfLiteRegistration Register_SVDF(); TfLiteRegistration Register_TRANSPOSE(); TfLiteRegistration Register_TRANSPOSE_CONV(); +// TODO(b/230666079): resolve conflict with xtensa implementation +TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration Register_VAR_HANDLE(); TfLiteRegistration Register_WHILE(); TfLiteRegistration Register_ZEROS_LIKE(); @@ -103,14 +106,12 @@ TfLiteRegistration Register_LESS_EQUAL(); TfLiteRegistration Register_LOG(); TfLiteRegistration Register_LOGICAL_NOT(); TfLiteRegistration Register_MAXIMUM(); -TfLiteRegistration Register_MEAN(); TfLiteRegistration Register_MINIMUM(); TfLiteRegistration Register_NEG(); TfLiteRegistration Register_NOT_EQUAL(); TfLiteRegistration Register_PACK(); TfLiteRegistration Register_PAD(); TfLiteRegistration Register_PADV2(); -TfLiteRegistration Register_REDUCE_MAX(); TfLiteRegistration Register_RESHAPE(); TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR(); TfLiteRegistration Register_ROUND(); @@ -121,7 +122,6 @@ TfLiteRegistration Register_SPLIT_V(); TfLiteRegistration Register_SQRT(); TfLiteRegistration Register_SQUARE(); 
 TfLiteRegistration Register_STRIDED_SLICE();
-TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
 TfLiteRegistration Register_UNPACK();
 TfLiteRegistration Register_L2_NORMALIZATION();
 TfLiteRegistration Register_TANH();
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_tensor_utils.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_tensor_utils.cc
new file mode 100644
index 00000000..88b097c7
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_tensor_utils.cc
@@ -0,0 +1,809 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <utility>
+
+#include "fixedpoint/fixedpoint.h"  // from @gemmlowp
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/cppmath.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace micro_tensor_utils {
+
+namespace {
+const int32_t kInt16Max = std::numeric_limits<int16_t>::max();
+const int32_t kInt16Min = std::numeric_limits<int16_t>::min();
+}  // namespace
+
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values,
+                                     float* min_value, float* max_value,
+                                     float* scaling_factor) {
+  auto minmax = std::minmax_element(values, values + size);
+  *min_value = *minmax.first;
+  *max_value = *minmax.second;
+
+  PortableSymmetricQuantizeFloats(values, size, quantized_values, *min_value,
+                                  *max_value, scaling_factor);
+}
+
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values,
+                                     float min_value, float max_value,
+                                     float* scaling_factor) {
+  const int32_t kScale = 127;
+  const float range = std::max(std::abs(min_value), std::abs(max_value));
+  if (range == 0) {
+    memset(quantized_values, 0, size * sizeof(int8_t));
+    *scaling_factor = 1;
+    return;
+  }
+  *scaling_factor = range / kScale;
+  const float scaling_factor_inv = kScale / range;
+  for (int i = 0; i < size; ++i) {
+    const int32_t quantized_value =
+        static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
+    // Clamp: just in case some odd numeric offset.
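+    // Worked example: for values spanning [-0.5, 2.0], range = 2.0 and
+    // *scaling_factor = 2.0 / 127 ~= 0.01575; 0.5 then quantizes to
+    // round(0.5 * 127 / 2.0) = 32, which dequantizes to roughly 0.504.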
+ quantized_values[i] = static_cast( + std::min(kScale, std::max(-kScale, quantized_value))); + } +} + +void PortableAsymmetricQuantizeFloats(const float* values, const int size, + int8_t* quantized_values, + float* scaling_factor, int32_t* offset) { + const int32_t kMinScale = -128; + const int32_t kMaxScale = 127; + const double qmin_double = kMinScale; + const double qmax_double = kMaxScale; + const auto minmax = std::minmax_element(values, values + size); + const double rmin = static_cast(std::min(0.0f, *minmax.first)); + const double rmax = static_cast(std::max(0.0f, *minmax.second)); + if (rmin == rmax) { + memset(quantized_values, 0, size * sizeof(int8_t)); + *scaling_factor = 1; + *offset = 0; + return; + } else { + double scale = (rmax - rmin) / (qmax_double - qmin_double); + const double zero_point_from_min = qmin_double - rmin / scale; + const double zero_point_from_max = qmax_double - rmax / scale; + const double zero_point_from_min_error = + std::abs(qmin_double) + std::abs(rmin / scale); + const double zero_point_from_max_error = + std::abs(qmax_double) + std::abs(rmax / scale); + const double zero_point_double = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + int8_t nudged_zero_point = 0; + if (zero_point_double <= qmin_double) { + nudged_zero_point = kMinScale; + } else if (zero_point_double >= qmax_double) { + nudged_zero_point = kMaxScale; + } else { + nudged_zero_point = static_cast(round(zero_point_double)); + } + *scaling_factor = scale; + *offset = nudged_zero_point; + } + const float scaling_factor_inv = 1.0f / *scaling_factor; + for (int i = 0; i < size; ++i) { + const int32_t quantized_value = static_cast( + TfLiteRound(*offset + values[i] * scaling_factor_inv)); + quantized_values[i] = + std::min(kMaxScale, std::max(kMinScale, quantized_value)); + } +} + +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result) { + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + for (int r = 0; r < m_rows; r++) { + float dot_prod = 0.0f; + const float* vector_in_batch = vector + b * m_cols; + for (int c = 0; c < m_cols; c++) { + dot_prod += *matrix_ptr++ * *vector_in_batch++; + } + *result_in_batch += dot_prod; + ++result_in_batch; + } + } +} + +void PortableMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result) { + for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) { + const float batch_scaling_factor = scaling_factors[batch]; + // Get the address of the first row. + const int8_t* row_ptr = matrix; + for (int row = 0; row < m_rows; ++row) { + // Initialize the dot product sum for the row to 0. + int32_t dotprod = 0; + // TODO(b/230666277): remove this +#if defined(__GNUC__) + // Prefetch the row to cache. 
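+      // (The prefetch is only a compiler hint; cores without a prefetch
+      //  instruction ignore it.) The integer dot product below becomes a
+      //  float again through a single multiply, result += dotprod * sf,
+      //  where sf is this batch's factor from the symmetric quantizer above.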
+ __builtin_prefetch(row_ptr, 0 /* prefetch for read */, + 3 /* temporal locality */); +#endif + for (int col = 0; col < m_cols; ++col, ++row_ptr) { + dotprod += (*row_ptr) * (vectors[col]); + } // for col + *result += dotprod * batch_scaling_factor; + ++result; + } // for row + } // for batch +} + +void PortableMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result, const float* per_channel_scale, + const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, + bool* compute_row_sums, CpuBackendContext* context) { + if (input_offset == nullptr) { + PortableMatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); + return; + } + if (!compute_row_sums || *compute_row_sums) { + PortableReductionSumVector(matrix, row_sums, m_rows, m_cols); + if (compute_row_sums) { + *compute_row_sums = false; + } + } + + for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) { + const float batch_scaling_factor = scaling_factors[batch]; + const int32_t batch_offset = input_offset[batch]; + const int8_t* row_ptr = matrix; + for (int row = 0; row < m_rows; ++row) { + int32_t dotprod = 0; + float scale = batch_scaling_factor; + if (per_channel_scale) { + scale *= per_channel_scale[row]; + } +#if defined(__GNUC__) + // Prefetch the row to cache. + __builtin_prefetch(row_ptr, 0 /* prefetch for read */, + 3 /* temporal locality */); +#endif + for (int col = 0; col < m_cols; ++col, ++row_ptr) { + dotprod += (*row_ptr) * vectors[col]; + } // for col + dotprod -= row_sums[row] * batch_offset; + *result += dotprod * scale; + ++result; + } // for row + } // for batch +} + +void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4( + const float* __restrict__ matrix, const int32_t* __restrict__ segments, + const int32_t* __restrict__ indices, int m_rows, int m_cols, + const float* __restrict__ vector, int n_batch, float* __restrict__ result) { + const int kBlockSize = 4; + TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0); + for (int batch = 0; batch < n_batch; batch++) { + const float* matrix_ptr = matrix; + for (int row = 0; row < m_rows; row++) { + float dot_prod = 0.0f; + const float* vector_in_batch = vector + batch * m_cols; + for (int i = segments[row]; i < segments[row + 1]; i++) { + const int block_start_index = indices[i] * kBlockSize; + const float* vector_block_in_batch_ptr = + vector_in_batch + block_start_index; + for (int c = 0; c < kBlockSize; c++) { + dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++; + } + } + result[batch * m_rows + row] += dot_prod; + } + } +} + +void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( + const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments, + const int32_t* __restrict__ indices, int m_rows, int m_cols, + const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, + int n_batch, const int32_t input_offset, const int32_t output_multiplier, + const int32_t output_shift, const int32_t output_offset, + const int32_t output_activation_min, const int32_t output_activation_max, + int8_t* __restrict__ result) { + const int kBlockSize = 16; + TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0); + for (int batch = 0; batch < n_batch; ++batch) { + const int8_t* matrix_ptr = matrix; + for (int row = 0; row < m_rows; ++row) { + int32_t dot_prod = 0; + const int8_t* vector_in_batch = vector + batch * m_cols; + for (int i = 
segments[row]; i < segments[row + 1]; ++i) {
+        const int block_start_index = indices[i] * kBlockSize;
+        const int8_t* vector_block_in_batch_ptr =
+            vector_in_batch + block_start_index;
+        for (int c = 0; c < kBlockSize; c++) {
+          dot_prod += *matrix_ptr * *vector_block_in_batch_ptr++;
+          dot_prod += *matrix_ptr++ * input_offset;
+        }
+      }
+      const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0;
+      dot_prod = MultiplyByQuantizedMultiplier(dot_prod + bias_value,
+                                               output_multiplier, output_shift);
+      dot_prod += output_offset;
+      result[batch * m_rows + row] =
+          static_cast<int8_t>(ActivationFunctionWithMinMax(
+              dot_prod, output_activation_min, output_activation_max));
+    }
+  }
+}
+
+void PortableSparseMatrixBatchVectorMultiplyAccumulate(
+    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
+    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
+    float* __restrict__ result) {
+  const int kBlockSize = 16;
+  TFLITE_DCHECK_EQ(  // NOLINT
+      m_cols % kBlockSize, 0);
+  for (int batch = 0; batch < n_batch; batch++) {
+    const float* matrix_ptr = matrix;
+    const uint8_t* ledger_ptr = ledger;
+    for (int row = 0; row < m_rows; row++) {
+      float dot_prod = 0.0f;
+      int num_nonzero_blocks = *ledger_ptr++;
+      if (num_nonzero_blocks > 0) {
+        const float* vector_in_batch = vector + batch * m_cols;
+        for (int i = 0; i < num_nonzero_blocks; i++) {
+          const int block_start_index = *ledger_ptr++ * kBlockSize;
+          const float* vector_block_in_batch_ptr =
+              vector_in_batch + block_start_index;
+          for (int c = 0; c < kBlockSize; c++) {
+            dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
+          }
+        }
+      }
+      result[batch * m_rows + row] += dot_prod;
+    }
+  }
+}
+
+void PortableSparseMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
+    const int m_cols, const int8_t* __restrict__ vectors,
+    const float* scaling_factors, int n_batch, float* __restrict__ result) {
+  static const int kBlockSize = 16;
+  TFLITE_DCHECK_EQ(  // NOLINT
+      m_cols % kBlockSize, 0);
+  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
+    const float batch_scaling_factor = scaling_factors[batch];
+    const uint8_t* ledger_ptr = ledger;
+    // Get the address of the first row.
+    const int8_t* row_ptr = matrix;
+    for (int row = 0; row < m_rows; ++row) {
+      // Initialize the dot product sum for the row to 0.
+      int32_t dotprod = 0;
+#if defined(__GNUC__)
+      // Prefetch the row to cache.
+      __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
+                         3 /* temporal locality */);
+#endif
+      int num_nonzero_blocks = *ledger_ptr++;
+      for (int i = 0; i < num_nonzero_blocks; i++) {
+        const int block_start_index = *ledger_ptr++ * kBlockSize;
+        const int8_t* vector_block_ptr = vectors + block_start_index;
+        for (int c = 0; c < kBlockSize; c++) {
+          dotprod += (*row_ptr++) * (*vector_block_ptr++);
+        }  // for block
+      }  // for num_nonzero_blocks
+      result[batch * m_rows + row] += dotprod * batch_scaling_factor;
+    }  // for row
+  }  // for batch
+}
+
+template <typename T>
+void PortableMatrixBatchVectorMultiplyAccumulateImpl(
+    const int8_t* input, const int32_t* bias,
+    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+    T* output) {
+  const int16_t output_max = std::numeric_limits<T>::max();
+  const int16_t output_min = std::numeric_limits<T>::min();
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int row = 0; row < n_output; ++row) {
+      int32_t acc = bias[row];
+      for (int col = 0; col < n_input; ++col) {
+        int8_t input_val = input[batch * n_input + col];
+        int8_t weights_val = input_to_gate_weights[row * n_input + col];
+        acc += input_val * weights_val;
+      }
+      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+      acc += output_zp;
+      acc += output[batch * n_output + row];
+      if (acc > output_max) {
+        acc = output_max;
+      }
+      if (acc < output_min) {
+        acc = output_min;
+      }
+      output[batch * n_output + row] = static_cast<T>(acc);
+    }
+  }
+}
+
+void PortableMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* input, const int32_t* bias,
+    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+    int32_t* scratch, int16_t* output, CpuBackendContext* context) {
+  PortableMatrixBatchVectorMultiplyAccumulateImpl(
+      input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
+      n_output, output_zp, output);
+}
+
+void PortableMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* input, const int32_t* bias,
+    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+    int32_t* scratch, int8_t* output, CpuBackendContext* context) {
+  PortableMatrixBatchVectorMultiplyAccumulateImpl(
+      input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
+      n_output, output_zp, output);
+}
+
+void PortableMatrixBatchVectorMultiply(const int8_t* input,
+                                       int32_t input_zeropoint,
+                                       const int8_t* input_to_gate_weights,
+                                       int32_t input_to_gate_effective_scale_a,
+                                       int32_t input_to_gate_effective_scale_b,
+                                       int32_t n_batch, int32_t n_input,
+                                       int32_t n_cell, int8_t* gate_output,
+                                       int8_t gate_output_zp) {
+  const int32_t int8_max = std::numeric_limits<int8_t>::max();
+  const int32_t int8_min = std::numeric_limits<int8_t>::min();
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int row = 0; row < n_cell; ++row) {
+      int32_t acc = 0;
+      for (int col = 0; col < n_input; ++col) {
+        int32_t input_val = input[batch * n_input + col];
+        int8_t weights_val = input_to_gate_weights[row * n_input + col];
+        acc += (input_val - input_zeropoint) * weights_val;
+      }
+      acc = MultiplyByQuantizedMultiplier(acc, input_to_gate_effective_scale_a,
+                                          input_to_gate_effective_scale_b);
+      acc += gate_output_zp;
+      if (acc > int8_max) {
+        acc = int8_max;
+      }
+      if (acc < int8_min) {
+        acc = int8_min;
+      }
+      gate_output[batch * n_cell + row] = static_cast<int8_t>(acc);
+    }
+  }
+}
+
+void PortableMatrixBatchVectorMultiply(
+    const int16_t* hidden, const int8_t* hidden_to_output_weights,
+    int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
+    const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
+    int32_t n_output, int32_t output_zp, int8_t* proj_output) {
+  const int16_t int8_max = std::numeric_limits<int8_t>::max();
+  const int16_t int8_min = std::numeric_limits<int8_t>::min();
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int row = 0; row < n_output; ++row) {
+      int64_t acc = gate_bias[row];
+      for (int col = 0; col < n_hidden; ++col) {
+        int16_t input_val = hidden[batch * n_hidden + col];
+        int8_t weights_val = hidden_to_output_weights[row * n_hidden + col];
+        int64_t curr = acc;
+        acc += input_val * weights_val;
+        if (input_val * weights_val > 0 && acc < curr) {
+          acc = std::numeric_limits<int32_t>::max();
+        }
+        if (input_val * weights_val < 0 && acc > curr) {
+          acc = std::numeric_limits<int32_t>::min();
+        }
+      }
+      acc = MultiplyByQuantizedMultiplier(acc, proj_effective_scale_a,
+                                          proj_effective_scale_b);
+      acc += output_zp;
+      if (acc > int8_max) {
+        acc = int8_max;
+      }
+      if (acc < int8_min) {
+        acc = int8_min;
+      }
+      proj_output[batch * n_output + row] = acc;
+    }
+  }
+}
+
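Aside: the integer layer norm that follows works in Q10 (inputs pre-scaled by 2^10 = 1024) so the per-batch mean survives integer division, and in Q20 (2^20, the square of 2^10) for the variance. The following standalone sketch of that scaling scheme uses hypothetical input values of our own choosing and is not part of the patch:

// Standalone illustration of the Q10/Q20 scaling used by
// PortableApplyLayerNorm below (example values are ours).
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t input[4] = {100, -50, 25, 75};  // hypothetical int16 inputs
  const int n_input = 4;
  int64_t sum = 0, sum_sq = 0;
  for (int j = 0; j < n_input; ++j) {
    sum += input[j];
    sum_sq += static_cast<int64_t>(input[j]) * input[j];
  }
  // Mean in Q10: (150 * 1024) / 4 = 38400, i.e. 37.5 in real terms.
  const int32_t mean = static_cast<int32_t>(sum * 1024 / n_input);
  // Variance in Q20: sum_sq * (2^20 / n_input) - mean^2.
  const int64_t variance =
      sum_sq * ((1 << 20) / n_input) - static_cast<int64_t>(mean) * mean;
  std::printf("mean(Q10)=%d variance(Q20)=%lld\n", static_cast<int>(mean),
              static_cast<long long>(variance));
  return 0;
}

Dividing the Q10 mean by 1024 and the Q20 variance by 2^20 recovers 37.5 and 3281.25, matching the float computation on the same values.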
+void PortableApplyLayerNorm(const int16_t* input,
+                            const int16_t* layer_norm_weights,
+                            const int32_t* bias, int32_t layer_norm_scale_a,
+                            int32_t layer_norm_scale_b, int32_t variance_limit,
+                            int n_batch, int n_input, int16_t* output) {
+  // The square of std::pow(2, 10), which is the extra factor that makes sure
+  // normalized values have enough resolution.
+  static const int kTwoToPower20 = 1 << 20;
+  for (int i = 0; i < n_batch; ++i) {
+    int64_t sum = 0;
+    int64_t sum_sq = 0;
+    for (int j = 0; j < n_input; ++j) {
+      const int32_t index = i * n_input + j;
+      int32_t val = static_cast<int32_t>(input[index]);
+      sum += val;
+      sum_sq += val * val;
+    }
+    int32_t mean =
+        static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
+    // TODO(b/173994730): Avoids overflow but only works for POT n_input.
+    int32_t temp = kTwoToPower20 / n_input;
+    int64_t variance =
+        sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
+    int32_t variance2 = static_cast<int32_t>(variance / kTwoToPower20);
+    if (variance2 < 1) {
+      variance2 = variance_limit;
+    }
+    int32_t stddev_inverse_a;
+    int stddev_inverse_b;
+    GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1,
+                                     &stddev_inverse_a, &stddev_inverse_b);
+
+    for (int j = 0; j < n_input; ++j) {
+      const int32_t index = i * n_input + j;
+      int32_t val = static_cast<int32_t>(input[index]);
+      int32_t shifted = 1024 * val - mean;
+      int32_t rescaled = MultiplyByQuantizedMultiplier(
+          shifted, stddev_inverse_a, stddev_inverse_b);
+      int64_t val3 = rescaled * layer_norm_weights[j] + bias[j];
+      int32_t val4 =
+          static_cast<int32_t>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
+      int32_t val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
+                                                   layer_norm_scale_b + 12);
+      val5 = std::min(std::max(kInt16Min, val5), kInt16Max);
+      output[index] = static_cast<int16_t>(val5);
+    }
+  }
+}
+
+void PortableApplyLayerNormFloat(const int16_t* input,
+                                 const int16_t* layer_norm_weights,
+                                 int32_t layer_norm_scale_a,
+                                 int32_t layer_norm_scale_b,
+                                 const int32_t* bias, int n_batch, int n_input,
+                                 int16_t* output) {
+  const int32_t int16_max = std::numeric_limits<int16_t>::max();
+  const int32_t int16_min = std::numeric_limits<int16_t>::min();
+  const float layer_norm_scale =
+      layer_norm_scale_a *
+      std::pow(2.0, static_cast<double>(layer_norm_scale_b - 31));
+  const float bias_scale =
+      static_cast<float>(std::pow(2.0, -10)) * layer_norm_scale;
+
+  for (int batch = 0; batch < n_batch; ++batch) {
+    float sum = 0.0f;
+    float sum_sq = 0.0f;
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      const float value = static_cast<float>(input[index]);
+      sum += value;
+      sum_sq += value * value;
+    }
+    const float mean = sum / n_input;
+    float stddev_inv = 0.0f;
+    const float variance = sum_sq / n_input - mean * mean;
+    if (variance == 0) {
+      stddev_inv = 1.0f / std::sqrt(1e-8f);
+    } else {
+      stddev_inv = 1.0f / std::sqrt(variance);
+    }
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      const float normalized_value =
+          (static_cast<float>(input[index]) - mean) * stddev_inv;
+      const float weighted_normalized_value =
+          normalized_value * layer_norm_weights[i] * layer_norm_scale +
+          bias[i] * bias_scale;
+      const int32_t quant_output = static_cast<int32_t>(round(
+          weighted_normalized_value * static_cast<float>(std::pow(2, 12))));
+      output[index] = std::min(int16_max, std::max(int16_min, quant_output));
+    }
+  }
+}
+
+void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
+                                            int32_t scalar, int32_t n_row,
+                                            int32_t n_col, int32_t* output) {
+  for (int i = 0; i < n_row; ++i) {
+    int32_t row_sum = 0;
+    for (int j = 0; j < n_col; ++j) {
+      row_sum += *matrix++;
+    }
+    output[i] += row_sum * scalar;
+  }
+}
+
+void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
+                          int32_t n_input, int16_t* output) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int c = 0; c < n_input; c++) {
+      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
+      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+      const int index = batch * n_input + c;
+      F3 sigmoid_input = F3::FromRaw(input[index]);
+      F0 sigmoid_output = gemmlowp::logistic(sigmoid_input);
+      output[index] = sigmoid_output.raw();
+    }
+  }
+}
+
+void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
+                               int32_t n_input, int16_t* output) {
+  const int32_t int16_max = std::numeric_limits<int16_t>::max();
+  const int32_t int16_min = std::numeric_limits<int16_t>::min();
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      const float float_input =
+          input[index] * static_cast<float>(std::pow(2, -12));
+      const float float_output = 1.0f / (1.0f + std::exp(-float_input));
+      const int32_t quant_output = static_cast<int32_t>(
+          float_output * static_cast<float>(std::pow(2, 15)));
+      const int32_t quant_output_clamped =
+          std::min(int16_max, std::max(int16_min, quant_output));
+      output[index] = static_cast<int16_t>(quant_output_clamped);
+    }
+  }
+}
+
+template <int IntegerBits>
+void PortableApplyTanhImpl(const int16_t* input, int32_t n_batch,
+                           int32_t n_input, int16_t* output) {
+  using FX = gemmlowp::FixedPoint<std::int16_t, IntegerBits>;
+  using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      FX tanh_input = FX::FromRaw(input[index]);
+      F0 tanh_output = gemmlowp::tanh(tanh_input);
+      output[index] = tanh_output.raw();
+    }
+  }
+}
+
+void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
+                       int32_t n_batch, int32_t n_input, int16_t* output) {
+  if (integer_bits > 6) {
+    TFLITE_ASSERT_FALSE;
+  }
+#define DISPATCH_TANH(i)                                       \
+  case i:                                                      \
+    PortableApplyTanhImpl<i>(input, n_batch, n_input, output); \
+    break;
+  switch (integer_bits) {
+    DISPATCH_TANH(0);
+    DISPATCH_TANH(1);
+    DISPATCH_TANH(2);
+    DISPATCH_TANH(3);
+    DISPATCH_TANH(4);
+    DISPATCH_TANH(5);
+    DISPATCH_TANH(6);
+    default:
+      return;
+  }
+#undef DISPATCH_TANH
+}
+
+void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
+                            int32_t n_input, int32_t integer_bits,
+                            int16_t* output) {
+  const int32_t int16_max = std::numeric_limits<int16_t>::max();
+  const int32_t int16_min = std::numeric_limits<int16_t>::min();
+  const double two = 2.0;
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      const float float_input =
+          input[index] * std::pow(two, static_cast<double>(integer_bits));
+      const float float_output = std::tanh(float_input);
+      const int32_t quant_output = static_cast<int32_t>(
+          float_output * static_cast<float>(std::pow(2, 15)));
+      const int32_t quant_output_clamped =
+          std::min(int16_max, std::max(int16_min, quant_output));
+      output[index] = static_cast<int16_t>(quant_output_clamped);
+    }
+  }
+}
+
+void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
+                      int n_batch, int n_input, int shift, int16_t* output) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      const int16_t a = input_1[index];
+      const int16_t b = input_2[index];
+      const int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
+      output[index] =
+          static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(value, shift));
+    }
+  }
+}
+
+void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
+                      int32_t multiplier, int32_t shift, int32_t n_batch,
+                      int32_t n_input, int32_t output_zp, int8_t* output) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      const int16_t a = input_1[index];
+      const int16_t b = input_2[index];
+      int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
+      value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
+      value -= output_zp;
+      value = std::min(std::max(static_cast<int32_t>(-128), value),
+                       static_cast<int32_t>(127));
+
+      output[index] = static_cast<int8_t>(value);
+    }
+  }
+}
+
+void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
+                      int n_batch, int n_input, int16_t* output) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      int32_t sum = input_1[index] + input_2[index];
+      const int32_t sum_clamped = std::min(kInt16Max, std::max(kInt16Min, sum));
+      output[index] = static_cast<int16_t>(sum_clamped);
+    }
+  }
+}
+
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+                                     int v_size) {
+  float result = 0.0;
+  for (int v = 0; v < v_size; v++) {
+    result += *vector1++ * *vector2++;
+  }
+  return result;
+}
+
+namespace {
+inline int32_t VectorVectorDotProduct(const int16_t* vector1,
+                                      const int16_t* vector2, int v_size) {
+  int32_t result = 0;
+  for (int v = 0; v < v_size; v++) {
+    result += *vector1++ * *vector2++;
+  }
+  return result;
+}
+}  // namespace
+
+void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
+                                              const int16_t* vector2,
+                                              int v_size, int n_batch,
+                                              int32_t* result) {
+  for (int b = 0; b < n_batch; b++) {
+    result[b] = VectorVectorDotProduct(vector1, vector2, v_size);
+    vector1 += v_size;
+    vector2 += v_size;
+  }
+}
+
+void PortableVectorBatchVectorCwiseProductAccumulate(
+    const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
+    int32_t multiplier, int shift, int16_t* result) {
+  for (int b = 0; b < n_batch; b++) {
+    for (int v = 0; v < v_size; v++) {
+      int32_t prod = vector[v] * *batch_vector++;
+      prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
+      int32_t output = prod + *result;
+      output = std::max(std::min(static_cast<int32_t>(32767), output),
+                        static_cast<int32_t>(-32768));
+      *result++ = output;
+    }
+  }
+}
+
+void PortableSub1Vector(const float* vector, int v_size, float* result) {
+  for (int v = 0; v < v_size; v++) {
+    *result++ = 1.0f - *vector++;
+  }
+}
+
+void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
+  static const int16_t kOne = 32767;
+  for (int v = 0; v < v_size; v++) {
+    *result++ = kOne - *vector++;
+  }
+}
+
+void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
+                                  const float scale, float* result) {
+  for (int v = 0; v < v_size; ++v) {
+    *result++ = scale * *vector++;
+  }
+}
+
+void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
+                                     float* __restrict__ output_vector,
+                                     int v_size, int n_batch) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    float sum = 0.0f;
+    for (int i = 0; i < v_size; ++i) {
+      sum += input_vector[i];
+    }
+    const float mean = sum / v_size;
+    float sum_diff_sq = 0.0f;
+    for (int i = 0; i < v_size; ++i) {
+      const float diff = input_vector[i] - mean;
+      sum_diff_sq += diff * diff;
+    }
+    const float variance = sum_diff_sq / v_size;
+    constexpr float kNormalizationConstant = 1e-8f;
+    const float stddev_inv =
+        1.0f / std::sqrt(variance + kNormalizationConstant);
+    for (int i = 0; i < v_size; ++i) {
+      output_vector[i] = (input_vector[i] - mean) * stddev_inv;
+    }
+    input_vector += v_size;
+    output_vector += v_size;
+  }
+}
+
+void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
+                                  const int8_t* recurrent, int8_t recurrent_zp,
+                                  int32_t input_effective_scale_a,
+                                  int32_t input_effective_scale_b,
+                                  int32_t recurrent_effective_scale_a,
+                                  int32_t recurrent_effective_scale_b,
+                                  int32_t n_batch, int32_t n_cell,
+                                  int16_t* output) {
+  const int32_t int16_max = std::numeric_limits<int16_t>::max();
+  const int32_t int16_min = std::numeric_limits<int16_t>::min();
+  for (int i = 0; i < n_batch * n_cell; ++i) {
+    int32_t x = static_cast<int32_t>(input[i]) - static_cast<int32_t>(input_zp);
+    int32_t h =
+        static_cast<int32_t>(recurrent[i]) - static_cast<int32_t>(recurrent_zp);
+    int32_t x_scaled = MultiplyByQuantizedMultiplier(x, input_effective_scale_a,
+                                                     input_effective_scale_b);
+    int32_t h_scaled = MultiplyByQuantizedMultiplier(
+        h, recurrent_effective_scale_a, recurrent_effective_scale_b);
+    int32_t y = h_scaled + x_scaled;
+    if (y > int16_max) {
+      y = int16_max;
+    }
+    if (y < int16_min) {
+      y = int16_min;
+    }
+    output[i] = static_cast<int16_t>(y);
+  }
+}
+
+}  // namespace micro_tensor_utils
+}  // namespace tflite
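These portable kernels are plain C++ loops with no framework state, so they can be exercised directly. Below is a hypothetical round-trip check for the asymmetric quantizer; the example values are our own and the snippet assumes it links against the translation unit above:

// Hypothetical usage sketch (not part of the patch): quantize a float vector
// asymmetrically, then eyeball the round-trip via the returned scale/offset.
#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"

int main() {
  const float values[4] = {0.0f, 0.5f, -1.25f, 2.0f};
  int8_t q[4];
  float scale = 0.0f;
  int32_t offset = 0;
  tflite::micro_tensor_utils::PortableAsymmetricQuantizeFloats(
      values, 4, q, &scale, &offset);
  for (int i = 0; i < 4; ++i) {
    // Dequantized value: (q - offset) * scale.
    std::printf("%f ~ %f\n", values[i], (q[i] - offset) * scale);
  }
  return 0;
}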
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_tensor_utils.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_tensor_utils.h
new file mode 100644
index 00000000..673ba6a3
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/micro_tensor_utils.h
@@ -0,0 +1,874 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file and the associated .cc file are branched from
+// tensorflow/lite/kernels/internal/reference/portable_tensor_utils*
+// TFLM needs to create its own because the original files are coupled with
+// the tensor_utils module, which we cannot reuse due to its use of the
+// Eigen library.
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
+namespace tflite {
+
+// Not all backends support CpuBackendContext usage, so forward declare to
+// avoid pulling in its implementation.
+// TODO(b/230666277): consider removing this since micro does not utilize it
+class CpuBackendContext;
+
+namespace micro_tensor_utils {
+
+template <typename T>
+inline bool PortableIsZeroVector(const T* vector, int v_size) {
+  for (int i = 0; i < v_size; ++i) {
+    if (vector[i] != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values, float* min_value,
+                                     float* max_value, float* scaling_factor);
+
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values, float min_value,
+                                     float max_value, float* scaling_factor);
+
+void PortableAsymmetricQuantizeFloats(const float* values, const int size,
+                                      int8_t* quantized_values,
+                                      float* scaling_factor, int32_t* offset);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector.
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result); + +void PortableMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result); + +void PortableMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result, const float* per_channel_scale, + const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, + bool* compute_row_sums, CpuBackendContext* context); + +void PortableMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vector, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + CpuBackendContext* context); + +void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4( + const float* __restrict__ matrix, const int32_t* __restrict__ segments, + const int32_t* __restrict__ indices, int m_rows, int m_cols, + const float* __restrict__ vector, int n_batch, float* __restrict__ result); + +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const float* __restrict__ matrix, const uint8_t* __restrict__ ledger, + int m_rows, int m_cols, const float* __restrict__ vector, int n_batch, + float* __restrict__ result); + +void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( + const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments, + const int32_t* __restrict__ indices, int m_rows, int m_cols, + const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, + int n_batch, const int32_t input_offset, const int32_t output_multiplier, + const int32_t output_shift, const int32_t output_offset, + const int32_t output_activation_min, const int32_t output_activation_max, + int8_t* __restrict__ result); + +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result); + +// Dot product of two vectors. 
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+                                     int v_size);
+
+void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
+                                              const int16_t* vector2,
+                                              int v_size, int n_batch,
+                                              int32_t* result);
+
+void PortableVectorBatchVectorCwiseProductAccumulate(
+    const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
+    int32_t multiplier, int shift, int16_t* result);
+
+void PortableMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* input, const int32_t* bias,
+    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+    int32_t* scratch, int16_t* output, CpuBackendContext* context);
+
+void PortableMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* input, const int32_t* bias,
+    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+    int32_t* scratch, int8_t* output, CpuBackendContext* context);
+
+void PortableMatrixBatchVectorMultiply(const int8_t* input,
+                                       int32_t input_zeropoint,
+                                       const int8_t* input_to_gate_weights,
+                                       int32_t input_to_gate_effective_scale_a,
+                                       int32_t input_to_gate_effective_scale_b,
+                                       int32_t n_batch, int32_t n_input,
+                                       int32_t n_cell, int8_t* gate_output,
+                                       int8_t gate_output_zp);
+
+void PortableMatrixBatchVectorMultiply(
+    const int16_t* hidden, const int8_t* hidden_to_output_weights,
+    int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
+    const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
+    int32_t n_output, int32_t output_zp, int8_t* proj_output);
+
+void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
+                                            int32_t scalar, int32_t n_row,
+                                            int32_t n_col, int32_t* output);
+
+void PortableApplyLayerNorm(const int16_t* input,
+                            const int16_t* layer_norm_weights,
+                            const int32_t* bias, int32_t layer_norm_scale_a,
+                            int32_t layer_norm_scale_b, int32_t variance_limit,
+                            int n_batch, int n_input, int16_t* output);
+
+void PortableApplyLayerNormFloat(const int16_t* input,
+                                 const int16_t* layer_norm_weights,
+                                 int32_t layer_norm_scale_a,
+                                 int32_t layer_norm_scale_b,
+                                 const int32_t* bias, int n_batch, int n_input,
+                                 int16_t* output);
+
+void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
+                          int32_t n_input, int16_t* output);
+
+void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
+                               int32_t n_input, int16_t* output);
+
+void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
+                       int32_t n_batch, int32_t n_input, int16_t* output);
+
+void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
+                            int32_t n_input, int32_t integer_bits,
+                            int16_t* output);
+
+void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
+                      int n_batch, int n_input, int shift, int16_t* output);
+
+void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
+                      int32_t multiplier, int32_t shift, int32_t n_batch,
+                      int32_t n_input, int32_t output_zp, int8_t* output);
+
+void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
+                      int n_batch, int n_input, int16_t* output);
+
+template <typename T>
+inline void PortableCwiseClipping(T* vector, const int v_size,
+                                  const T& clipping_value) {
+  for (int i = 0; i < v_size; i++) {
+    vector[i] = std::max(std::min(clipping_value, vector[i]),
+                         static_cast<T>(-clipping_value));
+  }
+}
+
+// Batch vector initialization with another vector.
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+                                     int n_batch, float* batch_vector);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void PortableSub1Vector(const float* vector, int v_size, float* result);
+
+void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result);
+
+// Multiply all elements of vector with a scalar.
+void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+                                  float* result);
+
+// Reduce-sum on a vector:
+// input_vector: pointer to input vector.
+// output_vector: pointer to vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+template <typename INPUT, typename OUTPUT>
+inline void PortableReductionSumVector(const INPUT* input_vector,
+                                       OUTPUT* output_vector, int output_size,
+                                       int reduction_size) {
+  for (int o = 0; o < output_size; o++) {
+    OUTPUT result = 0;
+    for (int r = 0; r < reduction_size; r++) {
+      result += input_vector[r];
+    }
+    output_vector[o] = result;
+    input_vector += reduction_size;
+  }
+}
+
+// Layer norm for each batch.
+void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
+                                     float* __restrict__ output_vector,
+                                     int v_size, int n_batch);
+
+// Saturate Add.
+void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
+                                  const int8_t* recurrent, int8_t recurrent_zp,
+                                  int32_t input_effective_scale_a,
+                                  int32_t input_effective_scale_b,
+                                  int32_t recurrent_effective_scale_a,
+                                  int32_t recurrent_effective_scale_b,
+                                  int32_t n_batch, int32_t n_cell,
+                                  int16_t* output);
+
+// Add another vector for each batch in the batch vector.
+template <typename T>
+inline void VectorBatchVectorAdd(const T* vector, int v_size, int n_batch,
+                                 T* batch_vector) {
+  for (int b = 0; b < n_batch; b++) {
+    for (int i = 0; i < v_size; ++i) {
+      batch_vector[i] += vector[i];
+    }
+    batch_vector += v_size;
+  }
+}
+
+// Cwise product of two vectors.
+template <typename T>
+inline void VectorVectorCwiseProduct(const T* vector1, const T* vector2,
+                                     int v_size, T* result) {
+  for (int v = 0; v < v_size; v++) {
+    *result++ = *vector1++ * *vector2++;
+  }
+}
+
+// Cwise product of a vector and a batch-vector.
+template <typename T>
+inline void VectorBatchVectorCwiseProduct(const T* vector, int v_size,
+                                          const T* batch_vector, int n_batch,
+                                          T* result) {
+  for (int b = 0; b < n_batch; b++) {
+    VectorVectorCwiseProduct(vector, batch_vector, v_size, result);
+    // Update the pointers.
+    result += v_size;
+    batch_vector += v_size;
+  }
+}
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+// output_vector: float pointer to vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+inline void ReductionSumVector(const float* input_vector, float* output_vector,
+                               int output_size, int reduction_size) {
+  PortableReductionSumVector(input_vector, output_vector, output_size,
+                             reduction_size);
+}
+
+// Same as above but input/output is 32 bit integer.
+inline void ReductionSumVector(const int32_t* input_vector,
+                               int32_t* output_vector, int output_size,
+                               int reduction_size) {
+  PortableReductionSumVector(input_vector, output_vector, output_size,
+                             reduction_size);
+}
+
+// Same as above but input is 8 bit integer.
+inline void ReductionSumVector(const int8_t* input_vector,
+                               int32_t* output_vector, int output_size,
+                               int reduction_size) {
+  PortableReductionSumVector(input_vector, output_vector, output_size,
+                             reduction_size);
+}
+
+// Cwise product and accumulate of two vectors. Since it's a MAC operation, the
+// assumption here is that result array is initialized to valid values.
+template <typename T>
+inline void VectorVectorCwiseProductAccumulate(const T* __restrict__ vector1,
+                                               const T* __restrict__ vector2,
+                                               int v_size,
+                                               T* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    *result++ += *vector1++ * *vector2++;
+  }
+}
+
+// Batch vector initialization with another vector.
+template <typename T>
+inline void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch,
+                                    T* batch_vector) {
+  for (int b = 0; b < n_batch; b++) {
+    std::copy_n(vector, v_size, batch_vector + b * v_size);
+  }
+}
+
+inline void SymmetricQuantizeFloats(const float* values, const int size,
+                                    int8_t* quantized_values, float* min,
+                                    float* max, float* scaling_factor) {
+  PortableSymmetricQuantizeFloats(values, size, quantized_values, min, max,
+                                  scaling_factor);
+}
+
+inline void SymmetricQuantizeFloats(const float* values, const int size,
+                                    int8_t* quantized_values, float min_value,
+                                    float max_value, float* scaling_factor) {
+  PortableSymmetricQuantizeFloats(values, size, quantized_values, min_value,
+                                  max_value, scaling_factor);
+}
+
+inline void AsymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values,
+                                     float* scaling_factor, int32_t* offset) {
+  PortableAsymmetricQuantizeFloats(values, size, quantized_values,
+                                   scaling_factor, offset);
+}
+
+// Helper function to quantize floats.
+// float_data_ptr input float vectors
+// n_batch number of input vectors
+// n_data size of a single input vector
+// quantized_data_ptr (out) vector with quantized data
+// scaling_factors (out) scaling factors (one per vector)
+// zero_points (out) zero points (one per vector)
+// do_asymmetric controls if the quantization should be asymmetric.
+inline void BatchQuantizeFloats(const float* float_data_ptr, int n_batch,
+                                int n_data, int8_t* quantized_data_ptr,
+                                float* scaling_factors, int32_t* zero_points,
+                                bool do_asymmetric) {
+  for (int b = 0; b < n_batch; ++b) {
+    const int offset = b * n_data;
+    if (do_asymmetric) {
+      AsymmetricQuantizeFloats(float_data_ptr + offset, n_data,
+                               quantized_data_ptr + offset, &scaling_factors[b],
+                               &zero_points[b]);
+    } else {
+      float unused_min, unused_max;
+      SymmetricQuantizeFloats(float_data_ptr + offset, n_data,
+                              quantized_data_ptr + offset, &unused_min,
+                              &unused_max, &scaling_factors[b]);
+    }
+  }
+}
+
+// Check if all entries of a vector are zero for float.
+inline bool IsZeroVector(const float* vector, int v_size) {
+  return PortableIsZeroVector(vector, v_size);
+}
+
+// Check if all entries of a vector are zero for int8_t.
+inline bool IsZeroVector(const int8_t* vector, int v_size) {
+  return PortableIsZeroVector(vector, v_size);
+}
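BatchQuantizeFloats above simply loops one of the two quantizers over n_batch contiguous vectors. A hedged usage sketch (the example data is ours, not from the patch):

// Hypothetical sketch: per-batch asymmetric quantization of two vectors of
// three floats each; fills one scale and one zero point per vector.
#include <cstdint>
#include <cstdio>

void QuantizeTwoBatches() {
  const float input[6] = {0.1f, -0.4f, 0.9f, 2.0f, 0.0f, -1.0f};
  int8_t quantized[6];
  float scales[2];
  int32_t zero_points[2];
  tflite::micro_tensor_utils::BatchQuantizeFloats(
      input, /*n_batch=*/2, /*n_data=*/3, quantized, scales, zero_points,
      /*do_asymmetric=*/true);
  std::printf("batch0 scale=%f zp=%d\n", scales[0],
              static_cast<int>(zero_points[0]));
}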
+// Apply Layer Normalization (https://arxiv.org/abs/1607.06450) to a Quantized
+// vector.
+// Parameters:
+// - input: batch vector of size n_batch * n_input; 16 bit.
+// - layer_norm_weights: the quantized layer normalization weights.
+// - bias: the bias for the layer normalization.
+// - layer_norm_scale_a: multiplier for scale factor.
+// - layer_norm_scale_b: shift for scale factor.
+// - variance_limit: the guard to make sure the inverse does not overflow.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output
+inline void ApplyLayerNorm(const int16_t* input,
+                           const int16_t* layer_norm_weights,
+                           const int32_t* bias, int32_t layer_norm_scale_a,
+                           int32_t layer_norm_scale_b, int32_t variance_limit,
+                           int n_batch, int n_input, int16_t* output) {
+  PortableApplyLayerNorm(input, layer_norm_weights, bias, layer_norm_scale_a,
+                         layer_norm_scale_b, variance_limit, n_batch, n_input,
+                         output);
+}
+
+// Same as above but the internal calculation is done in float.
+inline void ApplyLayerNormFloat(const int16_t* input,
+                                const int16_t* layer_norm_weights,
+                                int32_t layer_norm_scale_a,
+                                int32_t layer_norm_scale_b, const int32_t* bias,
+                                int n_batch, int n_input, int16_t* output) {
+  PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
+                              layer_norm_scale_b, bias, n_batch, n_input,
+                              output);
+}
+
+// Apply Sigmoid to a quantized vector.
+// Parameters:
+// - input: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output
+// The input is in Q3.12 format and the output is in Q0.15 format.
+inline void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
+                         int16_t* output) {
+  PortableApplySigmoid(input, n_batch, n_input, output);
+}
+
+// Same as above but the internal calculation is float.
+inline void ApplySigmoidFloat(const int16_t* input, int32_t n_batch,
+                              int32_t n_input, int16_t* output) {
+  PortableApplySigmoidFloat(input, n_batch, n_input, output);
+}
+
+// Apply Tanh to a quantized vector.
+// Parameters:
+// - integer_bits: the integer bits of the input.
+//                 Currently supports 0, 1, 2, 3, 4, 5, 6.
+// - input: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output
+// The input is in Qm.15-m format and the output is in Q0.15 format.
+inline void ApplyTanh(int32_t integer_bits, const int16_t* input,
+                      int32_t n_batch, int32_t n_input, int16_t* output) {
+  PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
+}
+
+// Apply Tanh to a quantized vector. The internal calculation is in float.
+// - Input has 2^(integer_bits) as scale.
+// - Output has Q0.15 as scale.
+inline void ApplyTanhFloat(const int16_t* input, int32_t n_batch,
+                           int32_t n_input, int32_t integer_bits,
+                           int16_t* output) {
+  PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
+}
+
+// Element-wise multiplication of two quantized vectors.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - shift: the shift needed to produce the output.
+// - output: the 16 bit output of size n_batch * n_input.
+// Output does not need to be initialized.
+inline void CwiseMul(const int16_t* input_1, const int16_t* input_2,
+                     int n_batch, int n_input, int shift, int16_t* output) {
+  PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
+}
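The Q3.12 / Q0.15 convention documented above is easy to sanity-check numerically. A small sketch of ours (not from the patch), assuming the header is in scope:

// An input of 4096 in Q3.12 is 1.0; the Q0.15 output should then be near
// sigmoid(1.0) * 32768, roughly 23957.
#include <cstdint>
#include <cstdio>

void SigmoidQFormatExample() {
  const int16_t input[1] = {4096};  // 1.0 in Q3.12
  int16_t output[1];
  tflite::micro_tensor_utils::ApplySigmoid(input, /*n_batch=*/1,
                                           /*n_input=*/1, output);
  std::printf("sigmoid(1.0) ~ %f\n", output[0] / 32768.0f);
}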
+// Element-wise multiplication of two quantized vectors with rescaling.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - multiplier: the multiplier part of scale.
+// - shift: the shift part of scale.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 8 bit output of size n_batch * n_input.
+// - output_zp: the zero point of output.
+// Output does not need to be initialized.
+// Multiplier ("m") and shift ("s") are connected to the scale by:
+// scale = m * 2^(s - 31).
+inline void CwiseMul(const int16_t* input_1, const int16_t* input_2,
+                     int32_t multiplier, int32_t shift, int32_t n_batch,
+                     int32_t n_input, int32_t output_zp, int8_t* output) {
+  PortableCwiseMul(input_1, input_2, multiplier, shift, n_batch, n_input,
+                   output_zp, output);
+}
+
+// Element-wise in-place clipping of a vector. Overloaded for float, int16_t,
+// int8_t. Parameters:
+// - vector: vector of size v_size.
+// - v_size: the size of the vector.
+// - clipping_value: the value used for clipping.
+inline void CwiseClipping(float* vector, const int v_size,
+                          const float clipping_value) {
+  PortableCwiseClipping(vector, v_size, clipping_value);
+}
+
+inline void CwiseClipping(int16_t* vector, const int v_size,
+                          const int16_t clipping_value) {
+  PortableCwiseClipping(vector, v_size, clipping_value);
+}
+
+inline void CwiseClipping(int8_t* vector, const int v_size,
+                          const int8_t clipping_value) {
+  PortableCwiseClipping(vector, v_size, clipping_value);
+}
+
+// Element-wise saturating addition of two quantized vectors without rescaling.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output of size n_batch * n_input.
+// Output does not need to be initialized.
+inline void CwiseAdd(const int16_t* input_1, const int16_t* input_2,
+                     int n_batch, int n_input, int16_t* output) {
+  PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
+}
+
+inline void MeanStddevNormalization(const float* input_vector,
+                                    float* output_vector, int v_size,
+                                    int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
+}
+
+inline void Sub1Vector(const float* vector, int v_size, float* result) {
+  PortableSub1Vector(vector, v_size, result);
+}
+
+inline void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
+  PortableSub1Vector(vector, v_size, result);
+}
+
+// Multiply all elements of vector with a scalar.
+inline void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+                                 float* result) {
+  PortableVectorScalarMultiply(vector, v_size, scale, result);
+}
+
+// Saturate Add with rescale on both inputs.
+inline void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
+                                 const int8_t* recurrent, int8_t recurrent_zp,
+                                 int32_t input_effective_scale_a,
+                                 int32_t input_effective_scale_b,
+                                 int32_t recurrent_effective_scale_a,
+                                 int32_t recurrent_effective_scale_b,
+                                 int32_t n_batch, int32_t n_cell,
+                                 int16_t* output) {
+  PortableTwoGateSaturatingAdd(
+      input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
+      input_effective_scale_b, recurrent_effective_scale_a,
+      recurrent_effective_scale_b, n_batch, n_cell, output);
+}
+
+// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
+// dimension composed by input vectors independent from each other). The result
+// of the multiplication is accumulated to the passed result buffer.
+// More specifically, for a matrix M of shape [n, i] and a batched-vector
+// of shape [i, batch] it will first compute the product of shape [n, batch].
+// This product will be accumulated to the result buffer.
+inline void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+                                                int m_cols, const float* vector,
+                                                int n_batch, float* result) {
+  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+                                              n_batch, result);
+}
+
+// Same as the function above, but for int8_t values with one scaling factor
+// per batch vector (hybrid quantization).
+inline void MatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vector, const float* scaling_factors,
+    int n_batch, float* __restrict__ result) {
+  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+                                              scaling_factors, n_batch, result);
+}
+
+inline void MatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors, const float* scaling_factors,
+    int n_batch, float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+    bool* compute_row_sums, CpuBackendContext* context) {
+  PortableMatrixBatchVectorMultiplyAccumulate(
+      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
+      per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
+      context);
+}
+
+inline void MatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vector, const float* scaling_factors,
+    int n_batch, int32_t* scratch, float* __restrict__ result,
+    CpuBackendContext* context) {
+  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+                                              scaling_factors, n_batch, result);
+}
+
+// Same as the function above, but the matrix is a sparse tensor with block
+// pattern 1x4.
+// This function assumes that m_cols is a multiple of the block size (4 in this
+// case) so that there's no incomplete block.
+inline void SparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
+  PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
+      matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
+}
+
+// Same as the function above, but the matrix is stored in block compressed
+// sparse row format with block pattern 1x16 which consists of two arrays:
+// 1. A matrix array stores non-zero blocks of the matrix in row major.
+// 2. A ledger array stores nrows groups, one group per row. Each group starts
+//    with an integer representing the number of non-zero blocks for the
+//    corresponding row and follows with column indexes of the first element
+//    of each non-zero block.
+// This function assumes that
+// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
+// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
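To make the ledger layout above concrete, here is a small worked example of ours (not from the patch) for a 2 x 32 matrix with three non-zero 1x16 blocks:

// Row 0 has one non-zero block (block column 1); row 1 has two (columns 0, 1).
#include <cstdint>

const uint8_t kExampleLedger[] = {
    1, 1,     // row 0: 1 non-zero block, starting at block column 1
    2, 0, 1,  // row 1: 2 non-zero blocks, at block columns 0 and 1
};
// The companion matrix array then stores 3 * 16 int8 values: row 0's block
// followed by row 1's two blocks, in row-major order.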
+inline void SparseMatrixBatchVectorMultiplyAccumulate( + const float* __restrict__ matrix, const uint8_t* __restrict__ ledger, + int m_rows, int m_cols, const float* __restrict__ vector, int n_batch, + float* __restrict__ result) { + PortableSparseMatrixBatchVectorMultiplyAccumulate( + matrix, ledger, m_rows, m_cols, vector, n_batch, result); +} + +// Same as the function above, but the matrix is a sparse tensor with block +// pattern 1x16. +// This function assumes that m_cols is a multiple of the block size (16 in this +// case) so that there's no incomplete block. Also, it assumes all offsets of +// input, output and filter are zero. +inline void SparseMatrixBatchVectorMultiplyAccumulate1x16( + const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments, + const int32_t* __restrict__ indices, int m_rows, int m_cols, + const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, + int n_batch, const int32_t input_offset, const int32_t output_multiplier, + const int32_t output_shift, const int32_t output_offset, + const int32_t output_activation_min, const int32_t output_activation_max, + int8_t* __restrict__ result) { + PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( + matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch, + input_offset, output_multiplier, output_shift, output_offset, + output_activation_min, output_activation_max, result); +} + +// Same as the function above, but the matrix is stored in block compressed +// sparse row format with block pattern 1x16 which consists of two arrays: +// 1. A matrix array stores non-zero blocks of the matrix in row major. +// 2. A ledger array stores nrows groups, one group per row. Each group starts +// with an integer representing the number of non-zero blocks for the +// corresponding row followed by column index of the first element of +// each non-zero block. +// This function assumes that +// 1. m_cols is a multiple of 16 so that all blocks are full blocks. +// 2. m_cols < 254 * 16 so that block index can be represented by uint8. +inline void SparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result) { + PortableSparseMatrixBatchVectorMultiplyAccumulate( + matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch, + result); +} + +// Same as the above 8, 8, 8 integer matmul except for the presence of zero +// point and non-accumulative. +// TODO(b/148688698): remove this function by folding zero point calculation in +// prepare() function. +inline void MatrixBatchVectorMultiplyAccumulate( + const int8_t* input, const int32_t* bias, + const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift, + int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp, + int32_t* scratch, int16_t* output, CpuBackendContext* context) { + PortableMatrixBatchVectorMultiplyAccumulate( + input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input, + n_output, output_zp, scratch, output, context); +} + +// Same as above but has 16 bit and 8 bit input and 8 bit output. +// Used in projection when hidden is 16bit. 
+inline void MatrixBatchVectorMultiplyAccumulate( + const int8_t* input, const int32_t* bias, + const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift, + int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp, + int32_t* scratch, int8_t* output, CpuBackendContext* context) { + PortableMatrixBatchVectorMultiplyAccumulate( + input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input, + n_output, output_zp, scratch, output, context); +} + +// Same as the function above, but provides separate scaling factor for the +// matrix and the vectors. The scaling factors are multiplied in the +// scaling_factor_scratch buffer. +inline void MatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float matrix_scaling_factor, + const float* vector_scaling_factors, int n_batch, + float* __restrict__ result, const float* per_channel_scale, + const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, + bool* compute_row_sums, float* scaling_factor_scratch, + CpuBackendContext* context) { + for (int b = 0; b < n_batch; ++b) { + scaling_factor_scratch[b] = + vector_scaling_factors[b] * matrix_scaling_factor; + } + MatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factor_scratch, n_batch, result, + per_channel_scale, input_offset, scratch, + row_sums, compute_row_sums, context); +} + +// Multiplies a matrix with a scalar and reduce the result on each row to a +// scalar. +// Parameters: +// - matrix: matrix of size n_row * n_col +// - scalar: the scalar that is multiplied to each element in the matrix +// - n_row: the row count of the matrix +// - n_col: the column count of the matrix +// - output: the 32bit output +// Note: We do not need saturation because the int8 * int8 is safe from overflow +// in (2^31-1) / (2^14) = 131072, which is bigger than the n_row. Non-zero +// initial output value is not exceptionally large. +inline void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar, + int32_t n_row, int32_t n_col, + int32_t* output) { + PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output); +} + +// Same as the above 8, 8, 8 integer matmul except for the presence of zero +// point and non-accumulative. +// TODO(b/148688698): remove this function by folding zero point calculation in +// prepare() function. +inline void MatrixBatchVectorMultiply(const int8_t* input, + int32_t input_zeropoint, + const int8_t* input_to_gate_weights, + int32_t input_to_gate_effective_scale_a, + int32_t input_to_gate_effective_scale_b, + int32_t n_batch, int32_t n_input, + int32_t n_cell, int8_t* gate_output, + int8_t gate_output_zp) { + PortableMatrixBatchVectorMultiply( + input, input_zeropoint, input_to_gate_weights, + input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch, + n_input, n_cell, gate_output, gate_output_zp); +} + +// Same as above but has 16 bit and 8 bit input and 8 bit output. +// Used in projection when hidden is 16bit. 
+inline void MatrixBatchVectorMultiply(const int16_t* hidden,
+                                      const int8_t* hidden_to_output_weights,
+                                      int32_t proj_effective_scale_a,
+                                      int32_t proj_effective_scale_b,
+                                      const int32_t* gate_bias, int32_t n_batch,
+                                      int32_t n_hidden, int32_t n_output,
+                                      int32_t output_zp, int8_t* proj_output) {
+  PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
+                                    proj_effective_scale_a,
+                                    proj_effective_scale_b, gate_bias, n_batch,
+                                    n_hidden, n_output, output_zp, proj_output);
+}
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
+// operation, the assumption here is that result array is initialized to valid
+// values.
+template <typename T>
+inline void VectorBatchVectorCwiseProductAccumulate(const T* vector, int v_size,
+                                                    const T* batch_vector,
+                                                    int n_batch, T* result) {
+  for (int b = 0; b < n_batch; b++) {
+    VectorVectorCwiseProductAccumulate(vector, batch_vector, v_size, result);
+    // Update the pointers.
+    result += v_size;
+    batch_vector += v_size;
+  }
+}
+
+// Same as above, but inputs are 16bit integer and output is 16bit integer.
+inline void VectorBatchVectorCwiseProductAccumulate(
+    const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
+    int32_t multiplier, int shift, int16_t* result) {
+  PortableVectorBatchVectorCwiseProductAccumulate(
+      vector, v_size, batch_vector, n_batch, multiplier, shift, result);
+}
+
+// Apply Rectified Linear to elements of a vector.
+inline void ApplyReluToVector(const float* vector, int v_size, float* result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::max(0.0f, vector[v]);
+  }
+}
+
+// Apply Rectified Linear 1 (cap to [-1;1]) to elements of a vector
+inline void ApplyRelu1ToVector(const float* vector, int v_size, float* result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::max(-1.0f, std::min(vector[v], 1.0f));
+  }
+}
+
+// Apply Rectified Linear 6 (cap to [0;6]) to elements of a vector
+inline void ApplyRelu6ToVector(const float* vector, int v_size, float* result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::max(0.0f, std::min(vector[v], 6.0f));
+  }
+}
+
+// Apply tanh to elements of a vector
+inline void ApplyTanhToVector(const float* vector, int v_size, float* result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::tanh(vector[v]);
+  }
+}
+
+// Apply signbit to elements of a vector
+inline void ApplySignbitToVector(const float* vector, int v_size,
+                                 float* result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::signbit(vector[v]);
+  }
+}
+
+// Apply sigmoid to elements of a vector.
+inline void ApplySigmoidToVector(const float* vector, int v_size,
+                                 float* result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = 1.0f / (1.0f + std::exp(-vector[v]));
+  }
+}
+
+// Apply appropriate activation function to elements of a vector.
+inline void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result) { + switch (activation) { + case kTfLiteActNone: + return; + case kTfLiteActRelu: + return ApplyReluToVector(vector, v_size, result); + case kTfLiteActReluN1To1: + return ApplyRelu1ToVector(vector, v_size, result); + case kTfLiteActRelu6: + return ApplyRelu6ToVector(vector, v_size, result); + case kTfLiteActTanh: + return ApplyTanhToVector(vector, v_size, result); + case kTfLiteActSignBit: + return ApplySignbitToVector(vector, v_size, result); + case kTfLiteActSigmoid: + return ApplySigmoidToVector(vector, v_size, result); + } +} + +} // namespace micro_tensor_utils + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_MICRO_TENSOR_UTILS_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/mirror_pad.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/mirror_pad.cc index a19561f6..90d3bd9e 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/mirror_pad.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/mirror_pad.cc @@ -209,14 +209,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_MIRROR_PAD() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/mul.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/mul.cc index e8295197..59f006b0 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/mul.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/mul.cc @@ -61,14 +61,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_MUL() { - return {/*init=*/MulInit, - /*free=*/nullptr, - /*prepare=*/MulPrepare, - /*invoke=*/MulEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/neg.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/neg.cc index 74a95ca3..59dd8cb8 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/neg.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/neg.cc @@ -51,14 +51,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace neg TfLiteRegistration Register_NEG() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/neg::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, neg::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/pack.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/pack.cc index 098a0482..56f3b96e 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/pack.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/pack.cc @@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace pack TfLiteRegistration Register_PACK() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/pack::Eval, - 
/*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, pack::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/pad.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/pad.cc index 1428b16e..b645f983 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/pad.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/pad.cc @@ -223,26 +223,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace pad TfLiteRegistration Register_PAD() { - return {/*init=*/pad::Init, - /*free=*/nullptr, - /*prepare=*/pad::Prepare, - /*invoke=*/pad::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval); } // Also register Pad as PadV2. TfLiteRegistration Register_PADV2() { - return {/*init=*/pad::Init, - /*free=*/nullptr, - /*prepare=*/pad::Prepare, - /*invoke=*/pad::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/pooling.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/pooling.cc index b3781636..a2ef8b62 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/pooling.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/pooling.cc @@ -88,25 +88,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } // namespace TfLiteRegistration Register_AVERAGE_POOL_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/PoolingPrepare, - /*invoke=*/AverageEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval); } TfLiteRegistration Register_MAX_POOL_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/PoolingPrepare, - /*invoke=*/MaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/prelu.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/prelu.cc index dc0c32c0..54cc0e02 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/prelu.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/prelu.cc @@ -69,14 +69,7 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_PRELU() { - return {/*init=*/PreluInit, - /*free=*/nullptr, - /*prepare=*/PreluPrepare, - /*invoke=*/PreluEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(PreluInit, PreluPrepare, PreluEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize.cc index 97f5a004..b5eb9c3c 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize.cc @@ -34,14 +34,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } // 
namespace TfLiteRegistration Register_QUANTIZE() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/PrepareQuantizeReference, - /*invoke=*/EvalQuantizeReference, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, PrepareQuantizeReference, + EvalQuantizeReference); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize_common.cc index cca3489d..94220529 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize_common.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/quantize_common.cc @@ -53,15 +53,19 @@ TfLiteStatus PrepareQuantizeReference(TfLiteContext* context, TF_LITE_ENSURE(context, affine_quantization->scale); TF_LITE_ENSURE(context, affine_quantization->scale->size == 1); - TF_LITE_ENSURE(context, - input->type == kTfLiteFloat32 || input->type == kTfLiteInt32 || - input->type == kTfLiteInt16 || input->type == kTfLiteInt8); + TF_LITE_ENSURE( + context, input->type == kTfLiteFloat32 || input->type == kTfLiteInt32 || + input->type == kTfLiteInt16 || input->type == kTfLiteInt8 || + input->type == kTfLiteUInt8); TF_LITE_ENSURE(context, output->type == kTfLiteInt8 || output->type == kTfLiteInt16 || - output->type == kTfLiteInt32); + output->type == kTfLiteInt32 || + output->type == kTfLiteUInt8); if ((input->type == kTfLiteInt16 && output->type == kTfLiteInt8) || (input->type == kTfLiteInt8 && output->type == kTfLiteInt8) || + (input->type == kTfLiteInt8 && output->type == kTfLiteUInt8) || + (input->type == kTfLiteUInt8 && output->type == kTfLiteInt8) || (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) || (input->type == kTfLiteInt8 && output->type == kTfLiteInt32) || (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) || @@ -109,9 +113,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(output)); return kTfLiteOk; default: - TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", - TfLiteTypeGetName(input->type), - TfLiteTypeGetName(output->type)); + MicroPrintf("Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); return kTfLiteError; } } else if (input->type == kTfLiteInt32) { @@ -132,9 +136,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(output)); break; default: - TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", - TfLiteTypeGetName(input->type), - TfLiteTypeGetName(output->type)); + MicroPrintf("Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); return kTfLiteError; } } else if (input->type == kTfLiteInt16) { @@ -162,9 +166,9 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(output)); return kTfLiteOk; default: - TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", - TfLiteTypeGetName(input->type), - TfLiteTypeGetName(output->type)); + MicroPrintf("Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); return kTfLiteError; } } else if (input->type == kTfLiteInt8) { @@ -179,6 +183,13 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) { data->input_zero_point, data->quantization_params.zero_point, 
             tflite::micro::GetTensorData<int8_t>(output));
         break;
+      case kTfLiteUInt8:
+        reference_ops::Requantize(
+            tflite::micro::GetTensorData<int8_t>(input), size,
+            data->requantize_output_multiplier, data->requantize_output_shift,
+            data->input_zero_point, data->quantization_params.zero_point,
+            tflite::micro::GetTensorData<uint8_t>(output));
+        break;
       case kTfLiteInt16:
         reference_ops::Requantize(
             tflite::micro::GetTensorData<int8_t>(input), size,
@@ -194,15 +205,31 @@
             tflite::micro::GetTensorData<int32_t>(output));
         break;
       default:
-        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                           TfLiteTypeGetName(input->type),
-                           TfLiteTypeGetName(output->type));
+        MicroPrintf("Input %s, output %s not supported.",
+                    TfLiteTypeGetName(input->type),
+                    TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
+  } else if (input->type == kTfLiteUInt8) {
+    size_t size = ElementCount(*input->dims);
+    switch (output->type) {
+      case kTfLiteInt8:
+        reference_ops::Requantize(
+            tflite::micro::GetTensorData<uint8_t>(input), size,
+            data->requantize_output_multiplier, data->requantize_output_shift,
+            data->input_zero_point, data->quantization_params.zero_point,
+            tflite::micro::GetTensorData<int8_t>(output));
+        break;
+      default:
+        MicroPrintf("Input %s, output %s not supported.",
+                    TfLiteTypeGetName(input->type),
+                    TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
   } else {
-    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                       TfLiteTypeGetName(input->type),
-                       TfLiteTypeGetName(output->type));
+    MicroPrintf("Input %s, output %s not supported.",
+                TfLiteTypeGetName(input->type),
+                TfLiteTypeGetName(output->type));
     return kTfLiteError;
   }
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/read_variable.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/read_variable.cc
index f9124f04..422c0384 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/read_variable.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/read_variable.cc
@@ -81,14 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_READ_VARIABLE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.cc
index 40aed2d5..7e862ba1 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,331 +23,41 @@ limitations under the License.
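All of the uint8 paths added to quantize_common.cc above funnel into the same reference_ops::Requantize template, only with the input/output type arguments swapped. A self-contained sketch of the uint8-to-int8 case (values are hypothetical; the multiplier/shift pair is chosen to encode an effective scale of 1.0):

#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/kernels/internal/reference/requantize.h"

int main() {
  // Requantize uint8 (zero point 128) to int8 (zero point 0) at equal scales:
  // multiplier = 1 << 30 together with shift = 1 represents a scale of 1.0.
  const uint8_t input[4] = {0, 128, 200, 255};
  int8_t output[4];
  tflite::reference_ops::Requantize(input, /*size=*/4,
                                    /*effective_scale_multiplier=*/1 << 30,
                                    /*effective_scale_shift=*/1,
                                    /*input_zeropoint=*/128,
                                    /*output_zeropoint=*/0, output);
  for (int i = 0; i < 4; ++i) {
    printf("%d ", output[i]);  // Prints: -128 0 72 127
  }
  return 0;
}

The same real values come out re-expressed in the int8 grid, which is exactly what the new kTfLiteUInt8 cases do tensor-wide.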
#include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/reduce.h" #include "tensorflow/lite/micro/micro_utils.h" namespace tflite { -namespace ops { -namespace micro { -namespace reduce { - -constexpr int kMaxNumberOfAxis = 4; -constexpr int kMaxNumberOfReducedAxis = 2; - -struct OpData { - int32_t multiplier; - int shift; - int temp_buffer_idx; - int resolved_axis_idx; - int input_zp; - float input_scale; - int output_zp; - float output_scale; - int num_output_elements; -}; void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) { - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { - MicroContext* micro_context = GetMicroContext(context); - - // Inputs Tensor (dtype depends on quantization): - // [0] = Input - // [1] = Axis - TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); - - // Outputs Tensor (dtype depends on quantization): - // [0] = Output - - // Validate number of inputs and outputs - TF_LITE_ENSURE_EQ(context, node->inputs->size, 2); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - // Validate axis type - TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1); - TF_LITE_ENSURE(context, axis != nullptr); - TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32); - - if (input->type == kTfLiteInt8) { - OpData* data = static_cast(node->user_data); - TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); - const double real_multiplier = static_cast(input->params.scale) / - static_cast(output->params.scale); - QuantizeMultiplier(real_multiplier, &data->multiplier, &data->shift); - micro_context->DeallocateTempTfLiteTensor(output); - } - micro_context->DeallocateTempTfLiteTensor(axis); - micro_context->DeallocateTempTfLiteTensor(input); - return kTfLiteOk; + return context->AllocatePersistentBuffer(context, sizeof(OpDataReduce)); } TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); - - MicroContext* micro_context = GetMicroContext(context); - OpData* op_data = static_cast(node->user_data); - TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); - TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); - TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1); - - op_data->input_scale = input->params.scale; - op_data->output_scale = output->params.scale; - op_data->num_output_elements = NumElements(output); - - context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size, - &op_data->temp_buffer_idx); - context->RequestScratchBufferInArena( - context, sizeof(int) * static_cast(ElementCount(*axis->dims)), - &op_data->resolved_axis_idx); - - micro_context->DeallocateTempTfLiteTensor(input); - micro_context->DeallocateTempTfLiteTensor(output); - micro_context->DeallocateTempTfLiteTensor(axis); - return kTfLiteOk; + return PrepareMaxHelper(context, node, + static_cast(node->user_data)); } TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) { - MicroContext* micro_context = GetMicroContext(context); - TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); - OpData* op_data = reinterpret_cast(node->user_data); - TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); - if (input->type == 
kTfLiteInt8 || input->type == kTfLiteInt16) { - const double real_multiplier = static_cast(input->params.scale) / - static_cast(output->params.scale); - QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift); - } - - int output_size = NumElements(output); - if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) { - context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t), - &op_data->temp_buffer_idx); - op_data->input_zp = input->params.zero_point; - op_data->input_scale = input->params.scale; - op_data->output_zp = output->params.zero_point; - op_data->output_scale = output->params.scale; - } - - TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); - // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018) - micro_context->DeallocateTempTfLiteTensor(input); - micro_context->DeallocateTempTfLiteTensor(output); - return kTfLiteOk; -} - -void ResolveAxis(const int* axis_data, int axis_count, - tflite::MeanParams* op_params) { - int i = 0; - for (; i < axis_count; ++i) { - op_params->axis[i] = static_cast(axis_data[i]); - } - for (; i < 4; ++i) { - op_params->axis[i] = 1; - } - op_params->axis_count = axis_count; + return PrepareMeanOrSumHelper(context, node, + static_cast(node->user_data)); } TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { - const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); - const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1); - TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); - TfLiteReducerParams* params = - reinterpret_cast(node->builtin_data); - OpData* op_data = reinterpret_cast(node->user_data); - - int num_axis = static_cast(ElementCount(*axis->dims)); - int temp_index[kMaxNumberOfAxis]; - int resolved_axis[kMaxNumberOfReducedAxis]; - - tflite::MeanParams op_params; - ResolveAxis(tflite::micro::GetTensorData(axis), num_axis, &op_params); - - // Special case mean implementation exists for 4D mean across axes 1 and 2. - bool special_case_4d_axes_1_and_2 = - input->dims->size == 4 && op_params.axis_count == 2 && - ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || - (op_params.axis[0] == 2 && op_params.axis[1] == 1)); - - switch (input->type) { - case kTfLiteFloat32: { - // Defer to specialized implementation for 4D Mean across axes 1 & 2. - if (params->keep_dims && special_case_4d_axes_1_and_2) { - reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - } else { - TF_LITE_ENSURE( - context, - reference_ops::Mean( - tflite::micro::GetTensorData(input), input->dims->data, - input->dims->size, tflite::micro::GetTensorData(output), - output->dims->data, output->dims->size, - tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_index, resolved_axis, - tflite::micro::GetTensorData(output))); - } - } break; - case kTfLiteInt8: { - // Defer to specialized implementation for 4D Mean across axes 1 & 2. 
- if (params->keep_dims && special_case_4d_axes_1_and_2) { - reference_integer_ops::Mean( - op_params, op_data->multiplier, op_data->shift, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), op_data->input_zp, - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output), op_data->output_zp); - } else if (op_data->input_zp == op_data->output_zp && - op_data->input_scale == op_data->output_scale) { - int32_t* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - TF_LITE_ENSURE( - context, - reference_ops::Mean( - tflite::micro::GetTensorData(input), input->dims->data, - input->dims->size, tflite::micro::GetTensorData(output), - output->dims->data, output->dims->size, - tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_index, resolved_axis, temp_buffer)); - } else { - int32_t* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - TF_LITE_ENSURE( - context, - reference_ops::QuantizedMeanOrSum( - tflite::micro::GetTensorData(input), op_data->input_zp, - op_data->input_scale, input->dims->data, input->dims->size, - tflite::micro::GetTensorData(output), - op_data->output_zp, op_data->output_scale, output->dims->data, - output->dims->size, tflite::micro::GetTensorData(axis), - num_axis, params->keep_dims, temp_index, resolved_axis, - temp_buffer, false)); - } - } break; - case kTfLiteInt16: { - // Defer to specialized implementation for 4D Mean across axes 1 & 2. - if (params->keep_dims && special_case_4d_axes_1_and_2) { - reference_integer_ops::Mean( - op_params, op_data->multiplier, op_data->shift, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), op_data->input_zp, - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output), op_data->output_zp); - } else if (op_data->input_zp == op_data->output_zp && - op_data->input_scale == op_data->output_scale) { - int32_t* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - TF_LITE_ENSURE( - context, - reference_ops::Mean(tflite::micro::GetTensorData(input), - input->dims->data, input->dims->size, - tflite::micro::GetTensorData(output), - output->dims->data, output->dims->size, - tflite::micro::GetTensorData(axis), - num_axis, params->keep_dims, temp_index, - resolved_axis, temp_buffer)); - } else { - int32_t* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - TF_LITE_ENSURE( - context, - reference_ops::QuantizedMeanOrSum( - tflite::micro::GetTensorData(input), op_data->input_zp, - op_data->input_scale, input->dims->data, input->dims->size, - tflite::micro::GetTensorData(output), - op_data->output_zp, op_data->output_scale, output->dims->data, - output->dims->size, tflite::micro::GetTensorData(axis), - num_axis, params->keep_dims, temp_index, resolved_axis, - temp_buffer, false)); - } - } break; - default: - TF_LITE_ENSURE_MSG(context, false, - "Currently, only float32, int8 or uint8 input type " - "is supported."); - } - return kTfLiteOk; + return EvalMeanHelper(context, node, + static_cast(node->user_data)); } TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) { - const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); - const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1); - TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); - TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); - 
TfLiteReducerParams* params = - static_cast(node->builtin_data); - OpData* op_data = static_cast(node->user_data); - - // Interpret an axis tensor with null dimensions as a scalar - int num_axis = static_cast(ElementCount(*axis->dims)); - int* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - int* resolved_axis = static_cast( - context->GetScratchBuffer(context, op_data->resolved_axis_idx)); - switch (input->type) { - case kTfLiteFloat32: - TF_LITE_ENSURE( - context, - reference_ops::ReduceGeneric( - tflite::micro::GetTensorData(input), input->dims->data, - input->dims->size, tflite::micro::GetTensorData(output), - output->dims->data, output->dims->size, - tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_buffer, resolved_axis, - std::numeric_limits::lowest(), - [](const float current, const float in) -> float { - return (in > current) ? in : current; - })); - break; - case kTfLiteInt8: - TF_LITE_ENSURE_EQ(context, static_cast(op_data->input_scale), - static_cast(op_data->output_scale)); - TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp); - TF_LITE_ENSURE( - context, - reference_ops::ReduceGeneric( - tflite::micro::GetTensorData(input), input->dims->data, - input->dims->size, tflite::micro::GetTensorData(output), - output->dims->data, output->dims->size, - tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_buffer, resolved_axis, - std::numeric_limits::lowest(), - [](const int8_t current, const int8_t in) -> int8_t { - return (in > current) ? in : current; - })); - break; - default: - TF_LITE_KERNEL_LOG(context, - "Only float32 and int8 types are supported.\n"); - return kTfLiteError; - } - return kTfLiteOk; + OpDataReduce* op_data = static_cast(node->user_data); + return EvalMaxHelper(context, node, op_data); } -} // namespace reduce - TfLiteRegistration Register_MEAN() { - return {/*init=*/reduce::InitReduce, - /*free=*/nullptr, - /*prepare=*/reduce::PrepareMeanOrSum, - /*invoke=*/reduce::EvalMean, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalMean); } TfLiteRegistration Register_REDUCE_MAX() { - return {/*init=*/reduce::InitReduce, - /*free=*/nullptr, - /*prepare=*/reduce::PrepareMax, - /*invoke=*/reduce::EvalMax, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(InitReduce, PrepareMax, EvalMax); } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.h new file mode 100644 index 00000000..cd94b3f5 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce.h @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+extern const int kMaxNumberOfAxis;
+extern const int kMaxNumberOfReducedAxis;
+
+struct OpDataReduce {
+  int32_t multiplier;
+  int shift;
+  int temp_buffer_idx;
+  int resolved_axis_idx;
+  int input_zp;
+  float input_scale;
+  int output_zp;
+  float output_scale;
+  int num_output_elements;
+};
+
+TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
+                              OpDataReduce* op_data);
+
+TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
+                                    OpDataReduce* op_data);
+
+TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
+                           OpDataReduce* op_data);
+TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
+                            OpDataReduce* op_data);
+
+void ReduceResolveAxis(const int* axis_data, int axis_count,
+                       MeanParams* op_params);
+
+TfLiteRegistration Register_MEAN();
+TfLiteRegistration Register_REDUCE_MAX();
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_REDUCE_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce_common.cc
new file mode 100644
index 00000000..97452300
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reduce_common.cc
@@ -0,0 +1,311 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
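With OpDataReduce and the prepare/eval helpers now exported through reduce.h, a platform-specific reduce kernel can reuse all of the reference logic and only swap the registration. A hedged sketch of what such a port could look like (structure mirrors the patch; a real port would ship this as its own file replacing reduce.cc):

#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/reduce.h"

namespace tflite {
namespace {

// Reuse the shared helpers verbatim; a port would customize only the pieces
// its hardware accelerates.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return context->AllocatePersistentBuffer(context, sizeof(OpDataReduce));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return PrepareMaxHelper(context, node,
                          static_cast<OpDataReduce*>(node->user_data));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return EvalMaxHelper(context, node,
                       static_cast<OpDataReduce*>(node->user_data));
}

}  // namespace

TfLiteRegistration Register_REDUCE_MAX() {
  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite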
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
+#include "tensorflow/lite/kernels/internal/reference/reduce.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/reduce.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+namespace tflite {
+
+const int kMaxNumberOfAxis = 4;
+const int kMaxNumberOfReducedAxis = 2;
+
+TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
+                           int32_t* multiplier, int* shift) {
+  MicroContext* micro_context = GetMicroContext(context);
+
+  // Inputs Tensor (dtype depends on quantization):
+  // [0] = Input
+  // [1] = Axis
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+
+  // Outputs Tensor (dtype depends on quantization):
+  // [0] = Output
+
+  // Validate number of inputs and outputs
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  // Validate axis type
+  TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
+  TF_LITE_ENSURE(context, axis != nullptr);
+  TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
+
+  if (input->type == kTfLiteInt8) {
+    TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+    const double real_multiplier = static_cast<double>(input->params.scale) /
+                                   static_cast<double>(output->params.scale);
+    QuantizeMultiplier(real_multiplier, multiplier, shift);
+    micro_context->DeallocateTempTfLiteTensor(output);
+  }
+  micro_context->DeallocateTempTfLiteTensor(axis);
+  micro_context->DeallocateTempTfLiteTensor(input);
+  return kTfLiteOk;
+}
+
+TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
+                              OpDataReduce* op_data) {
+  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node, &op_data->multiplier,
+                                           &op_data->shift));
+
+  MicroContext* micro_context = GetMicroContext(context);
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
+
+  op_data->input_scale = input->params.scale;
+  op_data->output_scale = output->params.scale;
+  op_data->num_output_elements = NumElements(output);
+
+  context->RequestScratchBufferInArena(context,
+                                       sizeof(int) * input->dims->size,
+                                       &op_data->temp_buffer_idx);
+  context->RequestScratchBufferInArena(
+      context, sizeof(int) * static_cast<size_t>(ElementCount(*axis->dims)),
+      &op_data->resolved_axis_idx);
+
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  micro_context->DeallocateTempTfLiteTensor(axis);
+  return kTfLiteOk;
+}
+
+TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
+                                    OpDataReduce* op_data) {
+  MicroContext* micro_context = GetMicroContext(context);
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
+    const double real_multiplier =
static_cast(input->params.scale) / + static_cast(output->params.scale); + QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift); + } + + int output_size = NumElements(output); + if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) { + context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t), + &op_data->temp_buffer_idx); + op_data->input_zp = input->params.zero_point; + op_data->input_scale = input->params.scale; + op_data->output_zp = output->params.zero_point; + op_data->output_scale = output->params.scale; + } + + TF_LITE_ENSURE_OK( + context, + PrepareSimple(context, node, &(op_data->multiplier), &(op_data->shift))); + // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018) + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(output); + return kTfLiteOk; +} + +void ResolveAxis(const int* axis_data, int axis_count, + tflite::MeanParams* op_params) { + int i = 0; + for (; i < axis_count; ++i) { + op_params->axis[i] = static_cast(axis_data[i]); + } + for (; i < 4; ++i) { + op_params->axis[i] = 1; + } + op_params->axis_count = axis_count; +} + +TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node, + OpDataReduce* op_data) { + const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); + const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1); + TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); + TfLiteReducerParams* params = + reinterpret_cast(node->builtin_data); + + int num_axis = static_cast(ElementCount(*axis->dims)); + int temp_index[kMaxNumberOfAxis]; + int resolved_axis[kMaxNumberOfReducedAxis]; + + tflite::MeanParams op_params; + ResolveAxis(tflite::micro::GetTensorData(axis), num_axis, &op_params); + + // Special case mean implementation exists for 4D mean across axes 1 and 2. + bool special_case_4d_axes_1_and_2 = + input->dims->size == 4 && op_params.axis_count == 2 && + ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); + + switch (input->type) { + case kTfLiteFloat32: { + // Defer to specialized implementation for 4D Mean across axes 1 & 2. + if (params->keep_dims && special_case_4d_axes_1_and_2) { + reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else { + TF_LITE_ENSURE( + context, + reference_ops::Mean( + tflite::micro::GetTensorData(input), input->dims->data, + input->dims->size, tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), num_axis, + params->keep_dims, temp_index, resolved_axis, + tflite::micro::GetTensorData(output))); + } + } break; + case kTfLiteInt8: { + // Defer to specialized implementation for 4D Mean across axes 1 & 2. 
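+      // (Axes {1, 2} of a 4D NHWC tensor are the spatial dimensions, so this
+      // special case is effectively a global average pooling over H and W,
+      // which is why a dedicated Mean kernel exists for it.)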
+ if (params->keep_dims && special_case_4d_axes_1_and_2) { + reference_integer_ops::Mean( + op_params, op_data->multiplier, op_data->shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), op_data->input_zp, + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), op_data->output_zp); + } else if (op_data->input_zp == op_data->output_zp && + op_data->input_scale == op_data->output_scale) { + int32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::Mean( + tflite::micro::GetTensorData(input), input->dims->data, + input->dims->size, tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), num_axis, + params->keep_dims, temp_index, resolved_axis, temp_buffer)); + } else { + int32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum( + tflite::micro::GetTensorData(input), op_data->input_zp, + op_data->input_scale, input->dims->data, input->dims->size, + tflite::micro::GetTensorData(output), + op_data->output_zp, op_data->output_scale, output->dims->data, + output->dims->size, tflite::micro::GetTensorData(axis), + num_axis, params->keep_dims, temp_index, resolved_axis, + temp_buffer, false)); + } + } break; + case kTfLiteInt16: { + // Defer to specialized implementation for 4D Mean across axes 1 & 2. + if (params->keep_dims && special_case_4d_axes_1_and_2) { + reference_integer_ops::Mean( + op_params, op_data->multiplier, op_data->shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), op_data->input_zp, + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), op_data->output_zp); + } else if (op_data->input_zp == op_data->output_zp && + op_data->input_scale == op_data->output_scale) { + int32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::Mean(tflite::micro::GetTensorData(input), + input->dims->data, input->dims->size, + tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), + num_axis, params->keep_dims, temp_index, + resolved_axis, temp_buffer)); + } else { + int32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum( + tflite::micro::GetTensorData(input), op_data->input_zp, + op_data->input_scale, input->dims->data, input->dims->size, + tflite::micro::GetTensorData(output), + op_data->output_zp, op_data->output_scale, output->dims->data, + output->dims->size, tflite::micro::GetTensorData(axis), + num_axis, params->keep_dims, temp_index, resolved_axis, + temp_buffer, false)); + } + } break; + default: + TF_LITE_ENSURE_MSG(context, false, + "Currently, only float32, int8 or uint8 input type " + "is supported."); + } + return kTfLiteOk; +} + +TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node, + OpDataReduce* op_data) { + const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); + const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1); + TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); + TfLiteReducerParams* params = + 
static_cast(node->builtin_data); + + // Interpret an axis tensor with null dimensions as a scalar + int num_axis = static_cast(ElementCount(*axis->dims)); + int* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + int* resolved_axis = static_cast( + context->GetScratchBuffer(context, op_data->resolved_axis_idx)); + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE( + context, + reference_ops::ReduceGeneric( + tflite::micro::GetTensorData(input), input->dims->data, + input->dims->size, tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), num_axis, + params->keep_dims, temp_buffer, resolved_axis, + std::numeric_limits::lowest(), + [](const float current, const float in) -> float { + return (in > current) ? in : current; + })); + break; + case kTfLiteInt8: + TF_LITE_ENSURE_EQ(context, static_cast(op_data->input_scale), + static_cast(op_data->output_scale)); + TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp); + TF_LITE_ENSURE( + context, + reference_ops::ReduceGeneric( + tflite::micro::GetTensorData(input), input->dims->data, + input->dims->size, tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), num_axis, + params->keep_dims, temp_buffer, resolved_axis, + std::numeric_limits::lowest(), + [](const int8_t current, const int8_t in) -> int8_t { + return (in > current) ? in : current; + })); + break; + default: + MicroPrintf("Only float32 and int8 types are supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/reshape.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reshape.cc index d14ed82e..832ba261 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/reshape.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/reshape.cc @@ -110,14 +110,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace reshape TfLiteRegistration Register_RESHAPE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/reshape::Prepare, - /*invoke=*/reshape::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, reshape::Prepare, reshape::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_bilinear.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_bilinear.cc index 55c23846..a90057b9 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_bilinear.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_bilinear.cc @@ -111,14 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_RESIZE_BILINEAR() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc index a02159af..ce507445 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc +++ 
b/code/components/tflite-lib/tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc @@ -117,14 +117,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace resize_nearest_neighbor TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/resize_nearest_neighbor::Prepare, - /*invoke=*/resize_nearest_neighbor::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, resize_nearest_neighbor::Prepare, + resize_nearest_neighbor::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/round.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/round.cc index 76d8e6bf..0bda8783 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/round.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/round.cc @@ -68,14 +68,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace round TfLiteRegistration Register_ROUND() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/round::Prepare, - /*invoke=*/round::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, round::Prepare, round::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/shape.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/shape.cc index df962f62..02f663a8 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/shape.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/shape.cc @@ -60,14 +60,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_SHAPE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/slice.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/slice.cc index 40d9fdd7..212cf47f 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/slice.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/slice.cc @@ -151,14 +151,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_SLICE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.cc index f6a30010..c2cee3c5 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.cc @@ -83,14 +83,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_SOFTMAX() { - return {/*init=*/SoftmaxInit, - /*free=*/nullptr, - /*prepare=*/SoftmaxPrepare, - /*invoke=*/SoftmaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return 
tflite::micro::RegisterOp(SoftmaxInit, SoftmaxPrepare, SoftmaxEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.h index 8d605eab..7096d202 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,13 @@ namespace tflite { void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length); +// Common helper function to SoftmaxPrepare. +TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, + const TfLiteTensor* input, + TfLiteTensor* output, + const TfLiteSoftmaxParams* params, + SoftmaxParams* op_data); + TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node); // This is the most generic TfLiteRegistration. The actual supported types may @@ -30,7 +37,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node); // (reference or optimized) must define this function. TfLiteRegistration Register_SOFTMAX(); -#if defined(XTENSA) +#if defined(XTENSA) || defined(CMSIS_NN) // Returns a TfLiteRegistration struct for kernel variant that only supports // int8 input and int16 output. TfLiteRegistration Register_SOFTMAX_INT8_INT16(); @@ -40,6 +47,23 @@ inline TfLiteRegistration Register_SOFTMAX_INT8_INT16() { } #endif +#if defined(CMSIS_NN) +// Returns a TfLiteRegistration struct for kernel variant that only supports +// int8 input/output and uses the latency optimized implementations. +TfLiteRegistration Register_SOFTMAX_INT8(); + +// Returns a TfLiteRegistration struct for kernel variant that only supports +// int16 input/output and uses the latency optimized implementations. +TfLiteRegistration Register_SOFTMAX_INT16(); + +#else +inline TfLiteRegistration Register_SOFTMAX_INT8() { return Register_SOFTMAX(); } + +inline TfLiteRegistration Register_SOFTMAX_INT16() { + return Register_SOFTMAX(); +} +#endif + } // namespace tflite #endif // TENSORFLOW_LITE_MICRO_KERNELS_SOFTMAX_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax_common.cc index d93f5f26..b5378dae 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax_common.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/softmax_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
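The new Register_SOFTMAX_INT8/Register_SOFTMAX_INT16 entry points degrade gracefully: on builds without CMSIS_NN the inline shims simply return Register_SOFTMAX(). A hedged sketch of how application code might pin the int8-only variant, assuming the MicroMutableOpResolver overload that accepts an explicit registration:

#include "tensorflow/lite/micro/kernels/softmax.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Sketch: on CMSIS-NN builds this selects the latency-optimized int8 kernel;
// elsewhere it transparently falls back to the generic Register_SOFTMAX().
void RegisterOps(tflite::MicroMutableOpResolver<1>& resolver) {
  resolver.AddSoftmax(tflite::Register_SOFTMAX_INT8());
}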
@@ -28,11 +28,59 @@ namespace {
 // Softmax parameter data that persists in user_data
 const int kInt16LUTArraySize = 513;
 
+TfLiteStatus InitializeLutForInt16(TfLiteContext* context,
+                                   const TfLiteTensor* input,
+                                   TfLiteTensor* output,
+                                   SoftmaxParams* op_data) {
+  // Only allocate LUTs for KTfLiteInt16 data type
+  if (input->type == kTfLiteInt16) {
+    void* raw_exp_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
+    op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
+    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
+    op_data->one_over_one_plus_x_lut =
+        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
+  }
+
+  if (output->type == kTfLiteInt16) {
+    TF_LITE_ENSURE(context,
+                   input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  } else {
+    TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  }
+
+  // Populate LUT if required
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    // exp LUT only used on negative values
+    // we consider exp(-10.0) is insignificant to accumulation
+    gen_lut(
+        [](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
+        op_data->exp_lut);
+    gen_lut(
+        [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
+        1.0f, op_data->one_over_one_plus_x_lut);
+    op_data->zero_point = output->params.zero_point;
+    op_data->scale = output->params.scale;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace
+
 TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
                                     SoftmaxParams* op_data) {
+  if (InitializeLutForInt16(context, input, output, op_data) != kTfLiteOk) {
+    return kTfLiteError;
+  }
+
   if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
     if (input->type == kTfLiteInt16) {
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -83,8 +131,6 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
   return kTfLiteOk;
 }
 
-}  // namespace
-
 void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
 }
@@ -103,40 +149,6 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, node->user_data != nullptr);
   SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
 
-  // Only allocate LUTs for KTfLiteInt16 data type
-  if (input->type == kTfLiteInt16) {
-    void* raw_exp_lut = context->AllocatePersistentBuffer(
-        context, sizeof(int16_t) * kInt16LUTArraySize);
-    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
-    op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
-    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
-        context, sizeof(int16_t) * kInt16LUTArraySize);
-    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
-    op_data->one_over_one_plus_x_lut =
-        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
-  }
-
-  if (output->type == kTfLiteInt16) {
-    TF_LITE_ENSURE(context,
-                   input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
-  } else {
-    TF_LITE_ENSURE_EQ(context, input->type, output->type);
-  }
-
-  // Populate LUT if required
-  if (input->type == kTfLiteInt16) {
-    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
-    // exp LUT only used on negative values
-    // we consider
exp(-10.0) is insignificant to accumulation - gen_lut( - [](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f, - op_data->exp_lut); - gen_lut( - [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f, - 1.0f, op_data->one_over_one_plus_x_lut); - op_data->zero_point = output->params.zero_point; - op_data->scale = output->params.scale; - } auto* params = static_cast(node->builtin_data); auto ret_val = diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_batch_nd.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_batch_nd.cc index 4e01becb..21f81312 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_batch_nd.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_batch_nd.cc @@ -114,14 +114,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. TfLiteRegistration Register_SPACE_TO_BATCH_ND() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_depth.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_depth.cc index 9c0cc445..30519b27 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_depth.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/space_to_depth.cc @@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_SPACE_TO_DEPTH() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/split.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/split.cc index 5d90d983..06584d45 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/split.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/split.cc @@ -120,14 +120,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace split TfLiteRegistration Register_SPLIT() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/split::Prepare, - /*invoke=*/split::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, split::Prepare, split::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/split_v.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/split_v.cc index c1c41124..3ea35130 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/split_v.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/split_v.cc @@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace split_v TfLiteRegistration Register_SPLIT_V() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/split_v::Prepare, - /*invoke=*/split_v::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, split_v::Prepare, split_v::Eval); } } // namespace micro diff 
--git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/squared_difference.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/squared_difference.cc
new file mode 100644
index 00000000..ca924e26
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/squared_difference.cc
@@ -0,0 +1,247 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+namespace {
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+struct OpData {
+  bool requires_broadcast;
+  ArithmeticParams arithmetic_params;
+};
+
+template <typename T>
+T SquaredDifference(T input1, T input2) {
+  const T difference = input1 - input2;
+  return difference * difference;
+}
+
+void* SquaredDifferenceInit(TfLiteContext* context, const char* buffer,
+                            size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus SquaredDifferencePrepare(TfLiteContext* context,
+                                      TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  data->requires_broadcast = false;
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  TfLiteTensor* input1 =
+      micro_context->AllocateTempInputTensor(node, kInputTensor1);
+  TF_LITE_ENSURE(context, input1 != nullptr);
+  TfLiteTensor* input2 =
+      micro_context->AllocateTempInputTensor(node, kInputTensor2);
+  TF_LITE_ENSURE(context, input2 != nullptr);
+  TfLiteTensor* output =
+      micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+
+  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
+  output->type = input2->type;
+
+  // Ensure the quantization parameters are equivalent.
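+  // The int8 path below mirrors the quantized ADD/SUB setup: the offsets are
+  // the negated zero points, both inputs are rescaled against twice the
+  // larger input scale, and the output multiplier folds that rescaling (and
+  // the left shift) back out after the square.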
+  if (input1->type == kTfLiteInt8) {
+    const auto& input1_quantization_params = input1->params;
+    const auto& input2_quantization_params = input2->params;
+    const auto& output_quantization_params = output->params;
+    const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
+    const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
+    TF_LITE_ENSURE(context,
+                   input1_quantization_params.zero_point >= integer_type_min);
+    TF_LITE_ENSURE(context,
+                   input1_quantization_params.zero_point <= integer_type_max);
+    TF_LITE_ENSURE(context,
+                   input2_quantization_params.zero_point >= integer_type_min);
+    TF_LITE_ENSURE(context,
+                   input2_quantization_params.zero_point <= integer_type_max);
+    TF_LITE_ENSURE(context,
+                   output_quantization_params.zero_point >= integer_type_min);
+    TF_LITE_ENSURE(context,
+                   output_quantization_params.zero_point <= integer_type_max);
+    data->arithmetic_params.input1_offset =
+        -input1_quantization_params.zero_point;
+    data->arithmetic_params.input2_offset =
+        -input2_quantization_params.zero_point;
+    data->arithmetic_params.output_offset =
+        output_quantization_params.zero_point;
+
+    // shift to make integer for scales.
+    // 7 is selected so that maximum shifted result 255^2 * (1 << (7 * 2 ))
+    // does not overflow signed 32-bit integer
+    data->arithmetic_params.left_shift = 7;
+    const double twice_max_input_scale =
+        2.0 * static_cast<double>(std::max(input1_quantization_params.scale,
+                                           input2_quantization_params.scale));
+    const double real_input1_multiplier =
+        static_cast<double>(input1_quantization_params.scale) /
+        twice_max_input_scale;
+    double real_input2_multiplier =
+        static_cast<double>(input2_quantization_params.scale) /
+        twice_max_input_scale;
+    const double real_output_multiplier =
+        (twice_max_input_scale * twice_max_input_scale) /
+        static_cast<double>((1 << data->arithmetic_params.left_shift * 2) *
+                            output_quantization_params.scale);
+    QuantizeMultiplierSmallerThanOneExp(
+        real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
+        &data->arithmetic_params.input1_shift);
+    QuantizeMultiplierSmallerThanOneExp(
+        real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
+        &data->arithmetic_params.input2_shift);
+    QuantizeMultiplierSmallerThanOneExp(
+        real_output_multiplier, &data->arithmetic_params.output_multiplier,
+        &data->arithmetic_params.output_shift);
+    data->arithmetic_params.quantized_activation_min =
+        std::numeric_limits<int8_t>::min();
+    data->arithmetic_params.quantized_activation_max =
+        std::numeric_limits<int8_t>::max();
+  }
+
+  data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+  micro_context->DeallocateTempTfLiteTensor(input1);
+  micro_context->DeallocateTempTfLiteTensor(input2);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return kTfLiteOk;
+}
+
+inline int8_t SquaredDifference(int8_t x, int8_t y,
+                                const ArithmeticParams& params) {
+  const int32_t input1_val = params.input1_offset + x;
+  const int32_t input2_val = params.input2_offset + y;
+  const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
+  const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
+  const int32_t scaled_input1_val =
+      MultiplyByQuantizedMultiplierSmallerThanOneExp(
+          shifted_input1_val, params.input1_multiplier, params.input1_shift);
+  const int32_t scaled_input2_val =
+      MultiplyByQuantizedMultiplierSmallerThanOneExp(
+          shifted_input2_val, params.input2_multiplier, params.input2_shift);
+  const int32_t raw_diff = scaled_input1_val - scaled_input2_val;
+
+  // Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
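+  // Derivation: after offsetting, each input lies in [-255, 255]; shifting by
+  // left_shift = 7 bounds it by 255 * 2^7, and the per-input multipliers are
+  // at most 0.5, so |raw_diff| <= 255 * 2^7 and its square is at most
+  // 255^2 * 2^14, which is below 2^31.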
+ const int32_t squared_raw_diff = raw_diff * raw_diff; + const int32_t raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + squared_raw_diff, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32_t clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + return static_cast<int8_t>(clamped_output); +} + +template <typename T> +void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node, + const OpData* data, + const TfLiteEvalTensor* input1, + const TfLiteEvalTensor* input2, + TfLiteEvalTensor* output) { + const auto* op_data = static_cast<const OpData*>(node->user_data); + if (data->requires_broadcast) { + reference_integer_ops::BroadcastBinaryFunction4DSlow( + op_data->arithmetic_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData<T>(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData<T>(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<T>(output), + reference_integer_ops::CheckArithmeticParams, SquaredDifference); + } else { + const int flat_size = tflite::micro::GetTensorShape(input1).FlatSize(); + reference_integer_ops::ElementWise( + flat_size, op_data->arithmetic_params, + tflite::micro::GetTensorData<T>(input1), + tflite::micro::GetTensorData<T>(input2), + tflite::micro::GetTensorData<T>(output), + reference_integer_ops::CheckArithmeticParams, SquaredDifference); + } +} + +template <typename T> +void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node, + const OpData* data, const TfLiteEvalTensor* input1, + const TfLiteEvalTensor* input2, + TfLiteEvalTensor* output) { + if (data->requires_broadcast) { + reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>( + tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData<T>(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData<T>(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<T>(output), SquaredDifference); + } else { + reference_ops::BinaryFunction<T, T, T>( + tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData<T>(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData<T>(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<T>(output), SquaredDifference); + } +} + +TfLiteStatus SquaredDifferenceEval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteEvalTensor* input1 = + tflite::micro::GetEvalInput(context, node, kInputTensor1); + const TfLiteEvalTensor* input2 = + tflite::micro::GetEvalInput(context, node, kInputTensor2); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalSquaredDifference<float>(context, node, data, input1, input2, output); + } else if (output->type == kTfLiteInt32) { + EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output); + } else if (output->type == kTfLiteInt8) { + EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2, + output); + } else { + MicroPrintf( + "SquaredDifference only supports FLOAT32, INT32 and INT8 now, got %d.", + output->type); + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration Register_SQUARED_DIFFERENCE() { + return tflite::micro::RegisterOp( + SquaredDifferenceInit, SquaredDifferencePrepare, SquaredDifferenceEval); +} + +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/squeeze.cc
b/code/components/tflite-lib/tensorflow/lite/micro/kernels/squeeze.cc index 0eb767db..e81b5b56 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/squeeze.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/squeeze.cc @@ -111,14 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_SQUEEZE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/strided_slice.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/strided_slice.cc index d5b73b8f..832e2ccd 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/strided_slice.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/strided_slice.cc @@ -193,14 +193,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace strided_slice TfLiteRegistration Register_STRIDED_SLICE() { - return {/*init=*/strided_slice::Init, - /*free=*/nullptr, - /*prepare=*/strided_slice::Prepare, - /*invoke=*/strided_slice::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(strided_slice::Init, strided_slice::Prepare, + strided_slice::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/sub.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/sub.cc index de99149f..40bddbad 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/sub.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/sub.cc @@ -162,14 +162,7 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration Register_SUB() { - return {/*init=*/SubInit, - /*free=*/nullptr, - /*prepare=*/SubPrepare, - /*invoke=*/SubEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(SubInit, SubPrepare, SubEval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/svdf.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/svdf.cc index f8a2bed2..5994db94 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/svdf.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/svdf.cc @@ -100,14 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_SVDF() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/PrepareSvdf, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, PrepareSvdf, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/tanh.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/tanh.cc index a9f01c71..e97a9035 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/tanh.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/tanh.cc @@ -195,14 +195,8 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { } // namespace activations TfLiteRegistration Register_TANH() { - return {/*init=*/activations::TanhInit, - /*free=*/nullptr, - 
/*prepare=*/activations::TanhPrepare, - /*invoke=*/activations::TanhEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp( + activations::TanhInit, activations::TanhPrepare, activations::TanhEval); } } // namespace micro } // namespace ops diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose.cc index ba3d6e94..9f77e04d 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose.cc @@ -116,13 +116,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_TRANSPOSE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose_conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose_conv.cc index dcf007c5..0b2afd5b 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose_conv.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/transpose_conv.cc @@ -266,7 +266,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); @@ -282,7 +282,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); @@ -293,7 +293,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { context->GetScratchBuffer(context, data.scratch_buffer_index)); // TODO(b/192090531): Remove this once all 8x16 transpose conv models use // 64-bit biases. 
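// (The null check added below also covers models with no bias tensor at all;
// when an int16 bias is present, it is converted into an int64 scratch
// buffer, since the 16x8 transpose-conv kernel expects 64-bit bias values.)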
- if (bias->type == kTfLiteInt16) { + if (bias != nullptr && bias->type == kTfLiteInt16) { std::int64_t* bias_converted_buffer = static_cast(context->GetScratchBuffer( context, data.bias_converted_buffer_index)); @@ -319,7 +319,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); @@ -337,14 +337,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_TRANSPOSE_CONV() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.cc new file mode 100644 index 00000000..7f3c50e4 --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.cc @@ -0,0 +1,1696 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/lstm_eval.h" +#include "tensorflow/lite/micro/kernels/lstm_shared.h" +#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h" +#include "tensorflow/lite/micro/micro_error_reporter.h" + +namespace tflite { + +namespace { + +constexpr int scratch_index_size = 12; + +struct UnidirectionalSequenceLstmOpData { + // If the lstm is layer norm. + bool use_layer_norm; + // The scratch index. + int scratch_index[scratch_index_size]; + + int32_t row_sums_size; + int32_t* row_sums; + bool compute_row_sums = false; + + int32_t input_zero_point; + int32_t output_state_zero_point; + + IntegerLstmParameter integer_lstm_param; + HybridLstmScales hybrid_lstm_scales; +}; + +TfLiteStatus PopulateQuantizedLstmParams8x8_16( + TfLiteContext* context, TfLiteNode* node, + IntegerLstmParameter* integer_lstm_param) { + MicroContext* micro_context = GetMicroContext(context); + + // Calculate quantized clip for projection and cell. 
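// (Here the float clip thresholds are mapped into the quantized domain: each
// is divided by the relevant state/output scale and saturated to the int16
// cell-state or int8 projection range; a clip value of 0 leaves clipping
// disabled.)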
+ const auto* params = + static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data); + const float cell_clip = params->cell_clip; + const float proj_clip = params->proj_clip; + + TfLiteTensor* cell_state = + micro_context->AllocateTempInputTensor(node, kLstmCellStateTensor); + TF_LITE_ENSURE(context, cell_state != nullptr); + TF_LITE_ENSURE(context, cell_state->is_variable); + TfLiteTensor* output_tensor = + micro_context->AllocateTempOutputTensor(node, kLstmOutputTensor); + + TF_LITE_ENSURE(context, + cell_state->quantization.type != kTfLiteNoQuantization); + auto* cell_state_params = + static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params); + TF_LITE_ENSURE(context, + output_tensor->quantization.type != kTfLiteNoQuantization); + auto* proj_params = static_cast<TfLiteAffineQuantization*>( + output_tensor->quantization.params); + if (cell_clip > 0.0f) { + integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min( + std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f), + 32767.0f)); + } else { + integer_lstm_param->quantized_cell_clip = 0; + } + if (proj_clip > 0.0f) { + integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min( + std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f)); + } else { + integer_lstm_param->quantized_proj_clip = 0; + } + + // Calculate effective scales. + UnidirectionalSequenceLstmOpData* op_data = + static_cast<UnidirectionalSequenceLstmOpData*>(node->user_data); + const bool use_layer_norm = op_data->use_layer_norm; + + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kLstmInputTensor); + + TfLiteTensor* input_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + micro_context->AllocateTempInputTensor(node, + kLstmRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToOutputWeightsTensor); + + TfLiteTensor* cell_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToInputWeightsTensor); + TfLiteTensor* cell_to_forget_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToForgetWeightsTensor); + TfLiteTensor* cell_to_output_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToOutputWeightsTensor); + + TfLiteTensor* input_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmInputLayerNormCoefficientsTensor); + TfLiteTensor* forget_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmForgetLayerNormCoefficientsTensor); + TfLiteTensor* cell_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmCellLayerNormCoefficientsTensor); + TfLiteTensor* output_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmOutputLayerNormCoefficientsTensor); + + TfLiteTensor* projection_weights = micro_context->AllocateTempInputTensor( + node,
kLstmProjectionWeightsTensor); + + TfLiteTensor* output_state = + micro_context->AllocateTempInputTensor(node, kLstmOutputStateTensor); + TF_LITE_ENSURE(context, output_state != nullptr); + TF_LITE_ENSURE(context, output_state->is_variable); + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + const bool use_projection = (projection_weights != nullptr); + + // Get intermediate scales and zero points. + float intermediate_scale[5]; + int32_t intermediate_zp[5]; + for (int i = 0; i < 4; ++i) { + if (use_layer_norm) { + TfLiteTensor* intermediate = + micro_context->AllocateTempIntermediateTensor(node, i); + TF_LITE_ENSURE(context, + intermediate->quantization.type != kTfLiteNoQuantization); + auto* params_intermediate = static_cast<TfLiteAffineQuantization*>( + intermediate->quantization.params); + intermediate_scale[i] = params_intermediate->scale->data[0]; + intermediate_zp[i] = params_intermediate->zero_point->data[0]; + if (intermediate != nullptr) { + micro_context->DeallocateTempTfLiteTensor(intermediate); + } + } else { + // Q3.12 for activation functions. + intermediate_scale[i] = std::pow(2.0f, -12.0f); + intermediate_zp[i] = 0; + } + } + // In the absence of projection, hidden becomes output and this intermediate + // is ignored. + TfLiteTensor* hidden = micro_context->AllocateTempIntermediateTensor(node, 4); + TF_LITE_ENSURE(context, hidden->quantization.type != kTfLiteNoQuantization); + auto* hidden_params = + static_cast<TfLiteAffineQuantization*>(hidden->quantization.params); + intermediate_scale[4] = hidden_params->scale->data[0]; + intermediate_zp[4] = hidden_params->zero_point->data[0]; + if (hidden != nullptr) { + micro_context->DeallocateTempTfLiteTensor(hidden); + } + + // Scales. + const float default_scale = 1.0; + float input_scale = default_scale; + float input_to_input_weight_scale = default_scale; + float recurrent_to_input_weight_scale = default_scale; + float cell_to_input_weight_scale = default_scale; + float input_to_forget_weight_scale = default_scale; + float recurrent_to_forget_weight_scale = default_scale; + float cell_to_forget_weight_scale = default_scale; + float input_to_cell_weight_scale = default_scale; + float recurrent_to_cell_weight_scale = default_scale; + float input_to_output_weight_scale = default_scale; + float recurrent_to_output_weight_scale = default_scale; + float cell_to_output_weight_scale = default_scale; + float projection_weight_scale = default_scale; + float layer_norm_input_scale = default_scale; + float layer_norm_forget_scale = default_scale; + float layer_norm_cell_scale = default_scale; + float layer_norm_output_scale = default_scale; + float output_state_scale = default_scale; + int cell_scale = 1; + + // Effective scales.
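// (Every scale starts at the neutral value 1.0 so that optional tensors --
// CIFG input gate, peepholes, layer norm, projection -- contribute nothing
// to the effective scales computed below when they are absent.)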
+ float effective_input_to_input_scale = default_scale; + float effective_recurrent_to_input_scale = default_scale; + float effective_cell_to_input_scale = default_scale; + float effective_input_to_forget_scale = default_scale; + float effective_recurrent_to_forget_scale = default_scale; + float effective_cell_to_forget_scale = default_scale; + float effective_input_to_cell_scale = default_scale; + float effective_recurrent_to_cell_scale = default_scale; + float effective_input_to_output_scale = default_scale; + float effective_recurrent_to_output_scale = default_scale; + float effective_cell_to_output_scale = default_scale; + float effective_proj_scale = default_scale; + float effective_hidden_scale = default_scale; + + // Populate scales. + if (!use_cifg) { + input_to_input_weight_scale = input_to_input_weights->params.scale; + recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale; + } + + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weight_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weight_scale = cell_to_forget_weights->params.scale; + cell_to_output_weight_scale = cell_to_output_weights->params.scale; + } + + if (use_layer_norm) { + if (!use_cifg) { + layer_norm_input_scale = input_layer_norm_coefficients->params.scale; + } + layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale; + layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale; + layer_norm_output_scale = output_layer_norm_coefficients->params.scale; + } + + if (use_projection) { + projection_weight_scale = projection_weights->params.scale; + } + output_state_scale = output_state->params.scale; + + input_to_forget_weight_scale = input_to_forget_weights->params.scale; + input_to_cell_weight_scale = input_to_cell_weights->params.scale; + input_to_output_weight_scale = input_to_output_weights->params.scale; + recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale; + recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale; + recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale; + + // Check cell state (already used above) + TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale)); + // TF_LITE_ENSURE(context, cell_scale <= -9); + integer_lstm_param->cell_scale = cell_scale; + input_scale = input->params.scale; + + // Calculate effective scales. 
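// (Each effective scale folds two tensor scales and one intermediate scale
// into a single real multiplier, e.g.
//   effective_input_to_forget_scale =
//       input_to_forget_weight_scale * input_scale / intermediate_scale[1],
// which QuantizeMultiplier later decomposes into an int32 multiplier and a
// shift. As an editorial sketch only -- assuming TF Lite's frexp-based
// rounding, with hypothetical naming -- that decomposition is roughly:
//
//   void QuantizeMultiplierSketch(double real, int32_t* q, int* shift) {
//     if (real == 0.0) { *q = 0; *shift = 0; return; }
//     const double frac = std::frexp(real, shift);  // real = frac * 2^shift
//     int64_t fixed = static_cast<int64_t>(std::round(frac * (1ll << 31)));
//     if (fixed == (1ll << 31)) { fixed /= 2; ++(*shift); }  // frac rounded to 1
//     *q = static_cast<int32_t>(fixed);  // Q0.31 fixed-point multiplier
//   }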
+ if (!use_cifg) { + effective_input_to_input_scale = + input_to_input_weight_scale * input_scale / intermediate_scale[0]; + effective_recurrent_to_input_scale = recurrent_to_input_weight_scale * + output_state_scale / + intermediate_scale[0]; + } + effective_input_to_forget_scale = + input_to_forget_weight_scale * input_scale / intermediate_scale[1]; + effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale * + output_state_scale / + intermediate_scale[1]; + + effective_input_to_cell_scale = + input_to_cell_weight_scale * input_scale / intermediate_scale[2]; + effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale * + output_state_scale / + intermediate_scale[2]; + + effective_input_to_output_scale = + input_to_output_weight_scale * input_scale / intermediate_scale[3]; + effective_recurrent_to_output_scale = recurrent_to_output_weight_scale * + output_state_scale / + intermediate_scale[3]; + + effective_hidden_scale = + std::pow(2.0f, -15.0f) / intermediate_scale[4] * std::pow(2.0f, -15.0f); + + effective_proj_scale = + projection_weight_scale * intermediate_scale[4] / output_state_scale; + + if (use_peephole) { + if (!use_cifg) { + effective_cell_to_input_scale = + std::pow(2.0f, static_cast(cell_scale)) * + cell_to_input_weight_scale / intermediate_scale[0]; + } + effective_cell_to_forget_scale = + std::pow(2.0f, static_cast(cell_scale)) * + cell_to_forget_weight_scale / intermediate_scale[1]; + effective_cell_to_output_scale = + std::pow(2.0f, static_cast(cell_scale)) * + cell_to_output_weight_scale / intermediate_scale[3]; + } + + // Decompose scales. + int shift_output; + QuantizeMultiplier(static_cast(effective_input_to_input_scale), + &integer_lstm_param->effective_input_to_input_scale_a, + &shift_output); + integer_lstm_param->effective_input_to_input_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_recurrent_to_input_scale), + &integer_lstm_param->effective_recurrent_to_input_scale_a, + &shift_output); + integer_lstm_param->effective_recurrent_to_input_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_cell_to_input_scale), + &integer_lstm_param->effective_cell_to_input_scale_a, + &shift_output); + integer_lstm_param->effective_cell_to_input_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_input_to_forget_scale), + &integer_lstm_param->effective_input_to_forget_scale_a, + &shift_output); + integer_lstm_param->effective_input_to_forget_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_recurrent_to_forget_scale), + &integer_lstm_param->effective_recurrent_to_forget_scale_a, + &shift_output); + integer_lstm_param->effective_recurrent_to_forget_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_cell_to_forget_scale), + &integer_lstm_param->effective_cell_to_forget_scale_a, + &shift_output); + integer_lstm_param->effective_cell_to_forget_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_input_to_cell_scale), + &integer_lstm_param->effective_input_to_cell_scale_a, + &shift_output); + integer_lstm_param->effective_input_to_cell_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_recurrent_to_cell_scale), + &integer_lstm_param->effective_recurrent_to_cell_scale_a, + &shift_output); + integer_lstm_param->effective_recurrent_to_cell_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_input_to_output_scale), + 
&integer_lstm_param->effective_input_to_output_scale_a, + &shift_output); + integer_lstm_param->effective_input_to_output_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_recurrent_to_output_scale), + &integer_lstm_param->effective_recurrent_to_output_scale_a, + &shift_output); + integer_lstm_param->effective_recurrent_to_output_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_cell_to_output_scale), + &integer_lstm_param->effective_cell_to_output_scale_a, + &shift_output); + integer_lstm_param->effective_cell_to_output_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_proj_scale), + &integer_lstm_param->effective_proj_scale_a, + &shift_output); + integer_lstm_param->effective_proj_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(effective_hidden_scale), + &integer_lstm_param->effective_hidden_scale_a, + &shift_output); + integer_lstm_param->effective_hidden_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(layer_norm_input_scale), + &integer_lstm_param->layer_norm_input_scale_a, + &shift_output); + integer_lstm_param->layer_norm_input_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(layer_norm_forget_scale), + &integer_lstm_param->layer_norm_forget_scale_a, + &shift_output); + integer_lstm_param->layer_norm_forget_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(layer_norm_cell_scale), + &integer_lstm_param->layer_norm_cell_scale_a, + &shift_output); + integer_lstm_param->layer_norm_cell_scale_b = + static_cast(shift_output); + QuantizeMultiplier(static_cast(layer_norm_output_scale), + &integer_lstm_param->layer_norm_output_scale_a, + &shift_output); + integer_lstm_param->layer_norm_output_scale_b = + static_cast(shift_output); + + integer_lstm_param->hidden_zp = intermediate_zp[4]; + + // 10000 is used to make sure the kernel logic does not overflow. 
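// (The variance guards computed below put a floor under the variance used by
// the integer layer-norm path: max(1, 10000 * layer_norm_scale) keeps the
// guard at least 1 while growing with the layer-norm scale.)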
+ if (!use_cifg) { + integer_lstm_param->input_variance_guard = + std::max(1, static_cast(10000 * layer_norm_input_scale)); + } + integer_lstm_param->forget_variance_guard = + std::max(1, static_cast(10000 * layer_norm_forget_scale)); + integer_lstm_param->cell_variance_guard = + std::max(1, static_cast(10000 * layer_norm_cell_scale)); + integer_lstm_param->output_variance_guard = + std::max(1, static_cast(10000 * layer_norm_output_scale)); + + if (cell_state != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_state); + } + if (output_tensor != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_tensor); + } + if (input != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input); + } + if (input_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_input_weights); + } + if (input_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_forget_weights); + } + if (input_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_cell_weights); + } + if (input_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_output_weights); + } + if (recurrent_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_input_weights); + } + if (recurrent_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_forget_weights); + } + if (recurrent_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_cell_weights); + } + if (recurrent_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_output_weights); + } + if (cell_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_input_weights); + } + if (cell_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_forget_weights); + } + if (cell_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_output_weights); + } + if (input_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_layer_norm_coefficients); + } + if (forget_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(forget_layer_norm_coefficients); + } + if (cell_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_layer_norm_coefficients); + } + if (output_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_layer_norm_coefficients); + } + if (projection_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_weights); + } + if (output_state != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_state); + } + + return kTfLiteOk; +} + +// Temporary buffers used for hybrid mode +enum HybridTempBuffer { + kPrimaryScratchBuffer = 0, + kInputQuantized = 1, + kOutputStateQuantized = 2, + kCellStateQuantized = 3, + kInputScalingFactors = 4, + kOutputStateScalingFactors = 5, + kProductScalingFactors = 6, + kRecoveredCellWeights = 7, + kAccumScratch = 8, + kInputZeroPoints = 9, + kOutputStateZeroPoints = 10, + kScales = 11, + kNumHybridTempBuffers = 12, +}; + +void* UnidirectionalSequenceLstmInit(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer( + context, sizeof(UnidirectionalSequenceLstmOpData)); +} + +// Check that input tensor dimensions matches with each other. 
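// (SetHybridScales below caches each weight tensor's quantization scale,
// defaulting to 1.0f when an optional weight tensor is absent, for use by
// the hybrid float/int8 kernel.)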
+TfLiteStatus SetHybridScales(TfLiteContext* context, TfLiteNode* node) { + UnidirectionalSequenceLstmOpData* op_data = + reinterpret_cast(node->user_data); + MicroContext* micro_context = GetMicroContext(context); + + TfLiteTensor* input_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToInputWeightsTensor); + op_data->hybrid_lstm_scales.input_to_input_weights_scale = + (input_to_input_weights != nullptr) ? input_to_input_weights->params.scale + : 1.0f; + + TfLiteTensor* input_to_forget_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToForgetWeightsTensor); + op_data->hybrid_lstm_scales.input_to_forget_weights_scale = + (input_to_forget_weights != nullptr) + ? input_to_forget_weights->params.scale + : 1.0f; + + TfLiteTensor* input_to_cell_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToCellWeightsTensor); + op_data->hybrid_lstm_scales.input_to_cell_weights_scale = + (input_to_cell_weights != nullptr) ? input_to_cell_weights->params.scale + : 1.0f; + + TfLiteTensor* input_to_output_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToOutputWeightsTensor); + op_data->hybrid_lstm_scales.input_to_output_weights_scale = + (input_to_output_weights != nullptr) + ? input_to_output_weights->params.scale + : 1.0f; + + op_data->hybrid_lstm_scales.aux_input_to_input_weights_scale = 1.0f; + op_data->hybrid_lstm_scales.aux_input_to_forget_weights_scale = 1.0f; + op_data->hybrid_lstm_scales.aux_input_to_cell_weights_scale = 1.0f; + op_data->hybrid_lstm_scales.aux_input_to_output_weights_scale = 1.0f; + + TfLiteTensor* recurrent_to_input_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToInputWeightsTensor); + op_data->hybrid_lstm_scales.recurrent_to_input_weights_scale = + (recurrent_to_input_weights != nullptr) + ? recurrent_to_input_weights->params.scale + : 1.0f; + + TfLiteTensor* recurrent_to_forget_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToForgetWeightsTensor); + op_data->hybrid_lstm_scales.recurrent_to_forget_weights_scale = + (recurrent_to_forget_weights != nullptr) + ? recurrent_to_forget_weights->params.scale + : 1.0f; + + TfLiteTensor* recurrent_to_cell_weights = + micro_context->AllocateTempInputTensor(node, + kLstmRecurrentToCellWeightsTensor); + op_data->hybrid_lstm_scales.recurrent_to_cell_weights_scale = + (recurrent_to_cell_weights != nullptr) + ? recurrent_to_cell_weights->params.scale + : 1.0f; + + TfLiteTensor* recurrent_to_output_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToOutputWeightsTensor); + op_data->hybrid_lstm_scales.recurrent_to_output_weights_scale = + (recurrent_to_output_weights != nullptr) + ? recurrent_to_output_weights->params.scale + : 1.0f; + + TfLiteTensor* cell_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToInputWeightsTensor); + op_data->hybrid_lstm_scales.cell_to_input_weights_scale = + (cell_to_input_weights != nullptr) ? cell_to_input_weights->params.scale + : 1.0f; + + TfLiteTensor* cell_to_forget_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToForgetWeightsTensor); + op_data->hybrid_lstm_scales.cell_to_forget_weights_scale = + (cell_to_forget_weights != nullptr) ? cell_to_forget_weights->params.scale + : 1.0f; + + TfLiteTensor* cell_to_output_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToOutputWeightsTensor); + op_data->hybrid_lstm_scales.cell_to_output_weights_scale = + (cell_to_output_weights != nullptr) ? 
cell_to_output_weights->params.scale + : 1.0f; + + TfLiteTensor* projection_weights = micro_context->AllocateTempInputTensor( + node, kLstmProjectionWeightsTensor); + op_data->hybrid_lstm_scales.projection_weights_scale = + (projection_weights != nullptr) ? projection_weights->params.scale : 1.0f; + + if (input_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_input_weights); + } + + if (input_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_forget_weights); + } + + if (input_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_cell_weights); + } + + if (input_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_output_weights); + } + + if (recurrent_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_input_weights); + } + + if (recurrent_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_forget_weights); + } + + if (recurrent_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_cell_weights); + } + + if (recurrent_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_output_weights); + } + + if (cell_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_input_weights); + } + + if (cell_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_forget_weights); + } + + if (cell_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_output_weights); + } + + if (projection_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_weights); + } + + return kTfLiteOk; +} + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell, + bool use_layer_norm, bool is_integer) { + MicroContext* micro_context = GetMicroContext(context); + + const auto* params = reinterpret_cast(node->builtin_data); + + // Making sure clipping parameters have valid values. 
+ // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, params->cell_clip >= 0); + TF_LITE_ENSURE(context, params->proj_clip >= 0); + + TfLiteTensor* input_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToInputWeightsTensor); + if (input_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + TfLiteTensor* input_to_forget_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + TfLiteTensor* input_to_cell_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_input_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + TfLiteTensor* recurrent_to_forget_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + TfLiteTensor* recurrent_to_cell_weights = + micro_context->AllocateTempInputTensor(node, + kLstmRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + TfLiteTensor* cell_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToInputWeightsTensor); + if (cell_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_TYPES_EQ( + context, cell_to_input_weights->type, + is_integer ? 
kTfLiteInt16 : input_to_forget_weights->type); + } + + TfLiteTensor* cell_to_forget_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToForgetWeightsTensor); + if (cell_to_forget_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_TYPES_EQ( + context, cell_to_forget_weights->type, + is_integer ? kTfLiteInt16 : input_to_forget_weights->type); + } + + TfLiteTensor* cell_to_output_weights = micro_context->AllocateTempInputTensor( + node, kLstmCellToOutputWeightsTensor); + if (cell_to_output_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_TYPES_EQ( + context, cell_to_output_weights->type, + is_integer ? kTfLiteInt16 : input_to_forget_weights->type); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + TfLiteTensor* input_gate_bias = + micro_context->AllocateTempInputTensor(node, kLstmInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32); + } + } + + TfLiteTensor* forget_gate_bias = + micro_context->AllocateTempInputTensor(node, kLstmForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32); + } + + TfLiteTensor* cell_gate_bias = + micro_context->AllocateTempInputTensor(node, kLstmCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32); + } + + TfLiteTensor* output_gate_bias = + micro_context->AllocateTempInputTensor(node, kLstmOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32); + } + + TfLiteTensor* projection_weights = micro_context->AllocateTempInputTensor( + node, kLstmProjectionWeightsTensor); + if (projection_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, 
projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + TfLiteTensor* projection_bias = + micro_context->AllocateTempInputTensor(node, kLstmProjectionBiasTensor); + if (projection_bias != nullptr) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32); + } + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + const bool projecton_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + + if (use_layer_norm) { + TfLiteTensor* input_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmInputLayerNormCoefficientsTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + } else { + TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0], + n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, + kTfLiteFloat32); + } + } + + TfLiteTensor* forget_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmForgetLayerNormCoefficientsTensor); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0], + n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, + kTfLiteFloat32); + } + + TfLiteTensor* cell_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmCellLayerNormCoefficientsTensor); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0], + n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, + kTfLiteFloat32); + } + + TfLiteTensor* output_layer_norm_coefficients = + micro_context->AllocateTempInputTensor( + node, kLstmOutputLayerNormCoefficientsTensor); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0], + n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, + kTfLiteFloat32); + } + if (input_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_layer_norm_coefficients); + } + if (forget_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(forget_layer_norm_coefficients); + } + if 
(cell_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_layer_norm_coefficients); + } + if (output_layer_norm_coefficients != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_layer_norm_coefficients); + } + } + + if (input_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_input_weights); + } + if (input_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_forget_weights); + } + if (input_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_cell_weights); + } + if (recurrent_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_input_weights); + } + if (recurrent_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_forget_weights); + } + micro_context->DeallocateTempTfLiteTensor(recurrent_to_cell_weights); + if (cell_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_input_weights); + } + if (cell_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_forget_weights); + } + if (cell_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_to_output_weights); + } + if (input_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_gate_bias); + } + if (forget_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(forget_gate_bias); + } + if (cell_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_gate_bias); + } + if (output_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_gate_bias); + } + if (projection_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_weights); + } + if (projection_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_bias); + } + + return kTfLiteOk; +} + +TfLiteStatus PrecomputeZeroPointTimesWeightWithBias( + TfLiteContext* context, int32_t zero_point, + const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor, + int32_t** output) { + if (weight_tensor == nullptr) { + return kTfLiteOk; + } + + const RuntimeShape& weight_shape = GetTensorShape(weight_tensor); + TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2); + const int row = weight_shape.Dims(0); + const int col = weight_shape.Dims(1); + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + *output = static_cast( + context->AllocatePersistentBuffer(context, row * sizeof(int32_t))); + + if (bias_tensor == nullptr) { + memset(*output, 0, row * sizeof(int32_t)); + } else { + const int32_t* bias = GetTensorData(bias_tensor); + memcpy(*output, bias, row * sizeof(int32_t)); + } + if (zero_point != 0) { + const int8_t* weight = GetTensorData(weight_tensor); + micro_tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, + col, *output); + } + return kTfLiteOk; +} + +TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias( + TfLiteContext* context, UnidirectionalSequenceLstmOpData* op_data, + TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kLstmInputTensor); + TfLiteTensor* output_state = + micro_context->AllocateTempInputTensor(node, kLstmOutputStateTensor); + TF_LITE_ENSURE(context, output_state != nullptr); + TF_LITE_ENSURE(context, output_state->is_variable); + + const int32_t input_zero_point = -input->params.zero_point; + const int32_t 
output_state_zero_point = -output_state->params.zero_point; + + TfLiteTensor* input_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + micro_context->AllocateTempInputTensor(node, + kLstmInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + micro_context->AllocateTempInputTensor(node, + kLstmRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + micro_context->AllocateTempInputTensor( + node, kLstmRecurrentToOutputWeightsTensor); + + TfLiteTensor* projection_weights = micro_context->AllocateTempInputTensor( + node, kLstmProjectionWeightsTensor); + TfLiteTensor* projection_bias = + micro_context->AllocateTempInputTensor(node, kLstmProjectionBiasTensor); + + IntegerLstmParameter* integer_lstm_params = &op_data->integer_lstm_param; + + TfLiteTensor* intermediate = + micro_context->AllocateTempIntermediateTensor(node, 4); + TF_LITE_ENSURE(context, + intermediate->quantization.type != kTfLiteNoQuantization); + const auto* params = + static_cast(intermediate->quantization.params); + const int32_t hidden_zp = params->zero_point->data[0]; + + // Get bias and perform zero point calculation. + // When there is layer normalization, the gate bias does not apply to matmul + // directly: + // y = ln(w * x + w * r + w * c) + b. + const bool is_layer_norm = op_data->use_layer_norm; + + // Forget gate. + TfLiteTensor* forget_gate_bias = is_layer_norm + ? nullptr + : micro_context->AllocateTempInputTensor( + node, kLstmForgetGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_forget_weights, forget_gate_bias, + &(integer_lstm_params->input_to_forget_effective_bias))); + + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_forget_weights, + nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias))); + + // Modulation gate. + TfLiteTensor* cell_gate_bias = is_layer_norm + ? nullptr + : micro_context->AllocateTempInputTensor( + node, kLstmCellGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_cell_weights, cell_gate_bias, + &(integer_lstm_params->input_to_cell_effective_bias))); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_cell_weights, nullptr, + &(integer_lstm_params->recurrent_to_cell_effective_bias))); + + // Output gate. + TfLiteTensor* output_gate_bias = is_layer_norm + ? 
nullptr + : micro_context->AllocateTempInputTensor( + node, kLstmOutputGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_output_weights, output_gate_bias, + &(integer_lstm_params->input_to_output_effective_bias))); + + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_output_weights, + nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias))); + + // Input gate. The calculation is only meaningful for non-cifg case. + TfLiteTensor* input_gate_bias = is_layer_norm + ? nullptr + : micro_context->AllocateTempInputTensor( + node, kLstmInputGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_input_weights, input_gate_bias, + &(integer_lstm_params->input_to_input_effective_bias))); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_input_weights, nullptr, + &(integer_lstm_params->recurrent_to_input_effective_bias))); + + // Projection bias. The calculation is only meaningful for with projection. + TF_LITE_ENSURE_OK(context, + PrecomputeZeroPointTimesWeightWithBias( + context, hidden_zp, projection_weights, projection_bias, + &(integer_lstm_params->projection_effective_bias))); + + if (input != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input); + } + if (output_state != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_state); + } + if (input_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_input_weights); + } + if (input_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_forget_weights); + } + if (input_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_cell_weights); + } + if (input_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_output_weights); + } + if (recurrent_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_input_weights); + } + if (recurrent_to_forget_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_forget_weights); + } + if (recurrent_to_cell_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_cell_weights); + } + if (recurrent_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_output_weights); + } + if (projection_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_weights); + } + if (projection_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_bias); + } + if (forget_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(forget_gate_bias); + } + if (cell_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_gate_bias); + } + if (output_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_gate_bias); + } + if (input_gate_bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_gate_bias); + } + + if (intermediate != nullptr) { + micro_context->DeallocateTempTfLiteTensor(intermediate); + } + + return kTfLiteOk; +} + +// Resize the output and state tensors based on the sizes of the input tensors. +// Allocate a temporary scratch tensor. Also check that the sizes of the input +// tensors match each other. 
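// (Prepare accepts exactly two node signatures: 24 inputs when the four
// layer-norm coefficient tensors are present, or the deprecated 20-input
// layout without them; any other input count is rejected.)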
+TfLiteStatus UnidirectionalSequenceLstmPrepare(TfLiteContext* context,
+                                               TfLiteNode* node) {
+  UnidirectionalSequenceLstmOpData* op_data =
+      reinterpret_cast<UnidirectionalSequenceLstmOpData*>(node->user_data);
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  // Check we have all the inputs and outputs we need.
+  bool use_layer_norm = false;
+  if (node->inputs->size == 24) {
+    TfLiteTensor* forget_layer_norm_coefficients =
+        micro_context->AllocateTempInputTensor(
+            node, kLstmForgetLayerNormCoefficientsTensor);
+    if (forget_layer_norm_coefficients == nullptr) {
+      use_layer_norm = false;
+    } else {
+      use_layer_norm = true;
+    }
+    if (forget_layer_norm_coefficients != nullptr) {
+      micro_context->DeallocateTempTfLiteTensor(forget_layer_norm_coefficients);
+    }
+  } else if (node->inputs->size == 20) {
+    // This is deprecated and is only kept here for backward compatibility.
+    use_layer_norm = false;
+  } else {
+    MicroPrintf("The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
+                node->inputs->size);
+    return kTfLiteError;
+  }
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+  op_data->use_layer_norm = use_layer_norm;
+
+  // Infer the batch size, number of outputs, sequence length and number of
+  // cells from the input tensors.
+  TfLiteTensor* input =
+      micro_context->AllocateTempInputTensor(node, kLstmInputTensor);
+  op_data->input_zero_point = input->params.zero_point;
+  const bool is_integer = input->type == kTfLiteInt8;
+  TF_LITE_ENSURE(context, input->dims->size > 1);
+  const auto* params =
+      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+          node->builtin_data);
+  const bool time_major = params->time_major;
+  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
+  const int n_input = input->dims->data[2];
+
+  TfLiteTensor* input_to_output_weights =
+      micro_context->AllocateTempInputTensor(node,
+                                             kLstmInputToOutputWeightsTensor);
+  const int n_cell = input_to_output_weights->dims->data[0];
+  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+  TfLiteTensor* recurrent_to_output_weights =
+      micro_context->AllocateTempInputTensor(
+          node, kLstmRecurrentToOutputWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+                    n_cell);
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Check that input tensor dimensions match each other.
+  TF_LITE_ENSURE_OK(
+      context, CheckInputTensorDimensions(context, node, n_input, n_output,
+                                          n_cell, use_layer_norm, is_integer));
+
+  // Get the pointer to output, output_state and cell_state buffer tensors.
+  TfLiteTensor* output =
+      micro_context->AllocateTempOutputTensor(node, kLstmOutputTensor);
+
+  TfLiteTensor* output_state =
+      micro_context->AllocateTempInputTensor(node, kLstmOutputStateTensor);
+  TF_LITE_ENSURE(context, output_state != nullptr);
+  TF_LITE_ENSURE(context, output_state->is_variable);
+  op_data->output_state_zero_point = output_state->params.zero_point;
+  TfLiteTensor* cell_state =
+      micro_context->AllocateTempInputTensor(node, kLstmCellStateTensor);
+  TF_LITE_ENSURE(context, cell_state != nullptr);
+  TF_LITE_ENSURE(context, cell_state->is_variable);
+
+  // Check the shape of input state tensors.
+  // These tensors may be 1D or 2D. It's fine as long as the total size is
+  // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + + // Check the shape of output tensor against that of input tensor + TF_LITE_ENSURE_EQ(context, output->dims->size, 3); + TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]); + TF_LITE_ENSURE_EQ(context, input->dims->data[1], output->dims->data[1]); + TF_LITE_ENSURE_EQ(context, output->dims->data[2], n_output); + + if (is_integer) { + const int num_intermediate_tensors = node->intermediates->size; + TF_LITE_ENSURE(context, num_intermediate_tensors == 5); + } + + TfLiteTensor* input_to_input_weights = micro_context->AllocateTempInputTensor( + node, kLstmInputToInputWeightsTensor); + + const bool use_cifg = (input_to_input_weights == nullptr); + + // Create a primary scratch buffer for hybrid and float + // If is_integer, primary scratch buffer has a different size + if (!is_integer) { + int scratch_buffer_size[2]; + scratch_buffer_size[0] = n_batch; + + if (use_cifg) { + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size[1] = n_cell * 3; + } else { + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size[1] = n_cell * 4; + } + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, + scratch_buffer_size[0] * scratch_buffer_size[1] * + TfLiteTypeGetSize(input->type), + &(op_data->scratch_index[kPrimaryScratchBuffer]))); + } + + if (IsHybridOp(input, input_to_output_weights)) { + TF_LITE_ENSURE(context, kNumHybridTempBuffers <= scratch_index_size); + + TF_LITE_ENSURE_OK(context, SetHybridScales(context, node)); + + op_data->compute_row_sums = true; + + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, + GetTensorShape(input).FlatSize() * + TfLiteTypeGetSize(input_to_output_weights->type), + &(op_data->scratch_index[kInputQuantized]))); + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, + GetTensorShape(output_state).FlatSize() * + TfLiteTypeGetSize(input_to_output_weights->type), + &(op_data->scratch_index[kOutputStateQuantized]))); + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, + GetTensorShape(cell_state).FlatSize() * + TfLiteTypeGetSize(input_to_output_weights->type), + &(op_data->scratch_index[kCellStateQuantized]))); + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, n_batch * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kScales]))); + + // Allocate temporary buffers to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). 
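// ---------------------------------------------------------------------------
// Annotation: the comment above motivates the scratch buffers requested next.
// In the hybrid path each float row is quantized once and its scale is kept
// per batch entry in the scaling-factor buffer; only the scalar scales are
// then combined per weight matrix. An illustrative sketch of that symmetric
// quantization step (the helper name is hypothetical, not the kernel's own
// routine):
#include <cmath>
#include <cstdint>

float SymmetricQuantizeRowSketch(const float* x, int n, int8_t* q) {
  float max_abs = 0.0f;
  for (int i = 0; i < n; ++i) max_abs = std::fmax(max_abs, std::fabs(x[i]));
  const float scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;
  for (int i = 0; i < n; ++i) {
    q[i] = static_cast<int8_t>(std::lround(x[i] / scale));
  }
  // The caller stores `scale` once per row; the "product scaling factor" for
  // a given int8 matrix is then scale * matrix_scale, so the same quantized
  // row can be reused against several matrices.
  return scale;
}
// ---------------------------------------------------------------------------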
+ + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, n_batch * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kInputScalingFactors]))); + + TF_LITE_ENSURE_OK( + context, context->RequestScratchBufferInArena( + context, n_batch * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kOutputStateScalingFactors]))); + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, n_batch * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kProductScalingFactors]))); + + // Allocate a temporary buffer to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, n_cell * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kRecoveredCellWeights]))); + + // Allocate a temporary buffer to store the accumulated int32 values. + TF_LITE_ENSURE_OK( + context, + context->RequestScratchBufferInArena( + context, n_cell * n_batch * TfLiteTypeGetSize(kTfLiteInt32), + &(op_data->scratch_index[kAccumScratch]))); + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, n_batch * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kInputZeroPoints]))); + + TF_LITE_ENSURE_OK(context, + context->RequestScratchBufferInArena( + context, n_batch * TfLiteTypeGetSize(kTfLiteFloat32), + &(op_data->scratch_index[kOutputStateZeroPoints]))); + + int row_sums_rows = use_cifg ? 6 : 8; + TfLiteTensor* projection_weights = micro_context->AllocateTempInputTensor( + node, kLstmProjectionWeightsTensor); + if (projection_weights != nullptr) { + row_sums_rows += ceil(static_cast(n_output) / n_cell); + } + op_data->row_sums_size = row_sums_rows; + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + op_data->row_sums = static_cast(context->AllocatePersistentBuffer( + context, row_sums_rows * n_cell * sizeof(int32_t))); + if (projection_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(projection_weights); + } + } + + if (is_integer) { + // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16. + // This code path needs 5 intermediate tensors per Op. + // Populate quantization parameters. + PopulateQuantizedLstmParams8x8_16(context, node, + &op_data->integer_lstm_param); + // Allocate scratch buffer. Need 4 16-bit buffer with size n_batch * n_cell + // and 1 8-bit buffer with size n_batch * n_cell. For integer + // UnidirectionalSequenceLSTM, we do not need the extra 32-bit buffer. + for (int i = 0; i < 5; ++i) { + TfLiteType buffer_type = kTfLiteInt16; + + if (i == 4) { + buffer_type = kTfLiteInt8; + } + + TF_LITE_ENSURE_OK( + context, + context->RequestScratchBufferInArena( + context, n_batch * n_cell * TfLiteTypeGetSize(buffer_type), + &(op_data->scratch_index[i]))); + } + + // Populate precomputed zp * weight. 
+ TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias( + context, op_data, node)); + } + + if (input != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input); + } + if (input_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_output_weights); + } + if (recurrent_to_output_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(recurrent_to_output_weights); + } + if (output != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output); + } + if (output_state != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output_state); + } + if (cell_state != nullptr) { + micro_context->DeallocateTempTfLiteTensor(cell_state); + } + + if (input_to_input_weights != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input_to_input_weights); + } + return kTfLiteOk; +} + +TfLiteStatus UnidirectionalSequenceLstmEval(TfLiteContext* context, + TfLiteNode* node) { + TFLITE_DCHECK(context->GetScratchBuffer != nullptr); + + const auto* params = + reinterpret_cast( + node->builtin_data); + const UnidirectionalSequenceLstmOpData* op_data = + reinterpret_cast(node->user_data); + const bool use_layer_norm = op_data->use_layer_norm; + const bool time_major = params->time_major; + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kLstmInputTensor); + + const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput( + context, node, kLstmInputToInputWeightsTensor); + + const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput( + context, node, kLstmInputToForgetWeightsTensor); + + const TfLiteEvalTensor* input_to_cell_weights = + tflite::micro::GetEvalInput(context, node, kLstmInputToCellWeightsTensor); + + const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput( + context, node, kLstmInputToOutputWeightsTensor); + + const TfLiteEvalTensor* recurrent_to_input_weights = + tflite::micro::GetEvalInput(context, node, + kLstmRecurrentToInputWeightsTensor); + + const TfLiteEvalTensor* recurrent_to_forget_weights = + tflite::micro::GetEvalInput(context, node, + kLstmRecurrentToForgetWeightsTensor); + + const TfLiteEvalTensor* recurrent_to_cell_weights = + tflite::micro::GetEvalInput(context, node, + kLstmRecurrentToCellWeightsTensor); + + const TfLiteEvalTensor* recurrent_to_output_weights = + tflite::micro::GetEvalInput(context, node, + kLstmRecurrentToOutputWeightsTensor); + + const TfLiteEvalTensor* cell_to_input_weights = + tflite::micro::GetEvalInput(context, node, kLstmCellToInputWeightsTensor); + + const TfLiteEvalTensor* cell_to_forget_weights = tflite::micro::GetEvalInput( + context, node, kLstmCellToForgetWeightsTensor); + + const TfLiteEvalTensor* cell_to_output_weights = tflite::micro::GetEvalInput( + context, node, kLstmCellToOutputWeightsTensor); + + const TfLiteEvalTensor* input_gate_bias = + tflite::micro::GetEvalInput(context, node, kLstmInputGateBiasTensor); + + const TfLiteEvalTensor* forget_gate_bias = + tflite::micro::GetEvalInput(context, node, kLstmForgetGateBiasTensor); + + const TfLiteEvalTensor* cell_gate_bias = + tflite::micro::GetEvalInput(context, node, kLstmCellGateBiasTensor); + + const TfLiteEvalTensor* output_gate_bias = + tflite::micro::GetEvalInput(context, node, kLstmOutputGateBiasTensor); + + const TfLiteEvalTensor* projection_weights = + tflite::micro::GetEvalInput(context, node, kLstmProjectionWeightsTensor); + + const TfLiteEvalTensor* projection_bias = + tflite::micro::GetEvalInput(context, node, 
kLstmProjectionBiasTensor); + + TfLiteEvalTensor* output_state = + tflite::micro::GetMutableEvalInput(context, node, kLstmOutputStateTensor); + + TfLiteEvalTensor* cell_state = + tflite::micro::GetMutableEvalInput(context, node, kLstmCellStateTensor); + + TFLITE_DCHECK(cell_state != nullptr); + + const TfLiteEvalTensor* input_layer_norm_coefficients = + use_layer_norm ? tflite::micro::GetEvalInput( + context, node, kLstmInputLayerNormCoefficientsTensor) + : nullptr; + const TfLiteEvalTensor* forget_layer_norm_coefficients = + use_layer_norm + ? tflite::micro::GetEvalInput(context, node, + kLstmForgetLayerNormCoefficientsTensor) + : nullptr; + const TfLiteEvalTensor* cell_layer_norm_coefficients = + use_layer_norm ? tflite::micro::GetEvalInput( + context, node, kLstmCellLayerNormCoefficientsTensor) + : nullptr; + const TfLiteEvalTensor* output_layer_norm_coefficients = + use_layer_norm + ? tflite::micro::GetEvalInput(context, node, + kLstmOutputLayerNormCoefficientsTensor) + : nullptr; + + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kLstmOutputTensor); + + // Copy out the LSTM specific params so they can be passed in the function. + TfLiteLSTMParams lstm_params; + lstm_params.activation = params->activation; + lstm_params.cell_clip = params->cell_clip; + lstm_params.proj_clip = params->proj_clip; + lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs; + + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + // Index the scratch buffers pointers to the global scratch buffer. + return EvalFloatLstm( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_layer_norm_coefficients, forget_layer_norm_coefficients, + cell_layer_norm_coefficients, output_layer_norm_coefficients, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_gate_bias, output_gate_bias, + projection_weights, projection_bias, &lstm_params, + /*forward_sequence=*/true, time_major, + /*output_offset=*/0, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kPrimaryScratchBuffer])), + output_state, cell_state, output); + } break; + case kTfLiteUInt8: + case kTfLiteInt8: { + const bool is_hybrid = input->type == kTfLiteFloat32; + if (is_hybrid) { + // Index the scratch buffers pointers to the global scratch buffer. 
+ UnidirectionalSequenceLstmOpData* op_data_rw = + reinterpret_cast( + node->user_data); + return EvalHybridLstm( + &(op_data->hybrid_lstm_scales), input, input_to_input_weights, + /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights, + /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights, + /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights, + /*input_to_output_weights_ledger*/ nullptr, + recurrent_to_input_weights, + /*recurrent_to_input_weights_ledger*/ nullptr, + recurrent_to_forget_weights, + /*recurrent_to_forget_weights_ledger*/ nullptr, + recurrent_to_cell_weights, + /*recurrent_to_cell_weights_ledger*/ nullptr, + recurrent_to_output_weights, + /*recurrent_to_output_weights_ledger*/ nullptr, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_layer_norm_coefficients, + forget_layer_norm_coefficients, cell_layer_norm_coefficients, + output_layer_norm_coefficients, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_gate_bias, output_gate_bias, + projection_weights, /*projection_weights_ledger*/ nullptr, + projection_bias, &lstm_params, + /*forward_sequence=*/true, time_major, + /*output_offset=*/0, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kPrimaryScratchBuffer])), + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kInputScalingFactors])), + /*aux_input_sf=*/nullptr, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kOutputStateScalingFactors])), + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kProductScalingFactors])), + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kRecoveredCellWeights])), + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kInputQuantized])), + /*aux_input_quantized=*/nullptr, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kOutputStateQuantized])), + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kCellStateQuantized])), + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kScales])), + output_state, cell_state, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kAccumScratch])), + output, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kInputZeroPoints])), + /*aux_input_zp=*/nullptr, + reinterpret_cast(context->GetScratchBuffer( + context, op_data->scratch_index[kOutputStateZeroPoints])), + op_data_rw->row_sums, op_data_rw->row_sums_size, + &op_data_rw->compute_row_sums); + } else { + return EvalInteger8x8_16Lstm( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_layer_norm_coefficients, + forget_layer_norm_coefficients, cell_layer_norm_coefficients, + output_layer_norm_coefficients, input_gate_bias, forget_gate_bias, + cell_gate_bias, output_gate_bias, projection_weights, + projection_bias, &lstm_params, /*forward_sequence=*/true, + time_major, &op_data->integer_lstm_param, + op_data->output_state_zero_point, output_state, 
cell_state, output, + reinterpret_cast( + context->GetScratchBuffer(context, op_data->scratch_index[0])), + reinterpret_cast( + context->GetScratchBuffer(context, op_data->scratch_index[1])), + reinterpret_cast( + context->GetScratchBuffer(context, op_data->scratch_index[2])), + reinterpret_cast( + context->GetScratchBuffer(context, op_data->scratch_index[3])), + reinterpret_cast( + context->GetScratchBuffer(context, op_data->scratch_index[4])), + nullptr); + } + } break; + default: + MicroPrintf("Type %s is not currently supported.", + TfLiteTypeGetName(input_to_output_weights->type)); + return kTfLiteError; + } +} + +} // namespace + +TfLiteRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { + return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit, + UnidirectionalSequenceLstmPrepare, + UnidirectionalSequenceLstmEval); +} + +} // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test_config.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test_config.h new file mode 100644 index 00000000..e37c0efd --- /dev/null +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test_config.h @@ -0,0 +1,244 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_ + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace testing { + +// TODO(b/230666079) enable below tests for xtensa when the xtensa +// kernel is reconciled with reference kernel +#if !defined(XTENSA) + +typedef struct LstmIntegerTestConfig { + const int n_batch; + const int n_input; + const int n_cell; + const int n_output; + const int sequence_length; + const bool time_major; + const bool use_cifg; + const bool use_peephole; + const bool use_projection_weights; + const bool use_projection_bias; + const bool use_layer_norm; + const bool use_8x8_8_implementation; + float intermediate_scale[5][2]; + int intermediate_zp[5][2]; + TfLiteAffineQuantization* intermediate_qparam; + + const float* input; + int8_t* input_quant; + + const float* input_to_input_weights; + int8_t* lstm_i2i_quant; + const float* input_to_forget_weights; + int8_t* lstm_i2f_quant; + const float* input_to_cell_weights; + int8_t* lstm_i2c_quant; + const float* input_to_output_weights; + int8_t* lstm_i2o_quant; + + const float* recurrent_to_input_weights; + int8_t* lstm_r2i_quant; + const float* recurrent_to_forget_weights; + int8_t* lstm_r2f_quant; + const float* recurrent_to_cell_weights; + int8_t* lstm_r2c_quant; + const float* recurrent_to_output_weights; + int8_t* lstm_r2o_quant; + + const float* cell_to_input_weights; + int16_t* lstm_c2i_quant; + const float* cell_to_forget_weights; + int16_t* lstm_c2f_quant; + const float* cell_to_output_weights; + int16_t* lstm_c2o_quant; + + const float* input_gate_bias; + int32_t* lstm_igate_bias_quant; + const float* forget_gate_bias; + int32_t* lstm_fgate_bias_quant; + const float* cell_gate_bias; + int32_t* lstm_cgate_bias_quant; + const float* output_gate_bias; + int32_t* lstm_ogate_bias_quant; + + const float* projection_weights; + int8_t* lstm_proj_w_quant; + const float* projection_bias; + int32_t* projection_bias_quant; + + int16_t* output_state; + int16_t* cell_state; + + const float* input_layer_norm_coefficients; + int16_t* lstm_input_layer_norm_coeff_quant; + const float* forget_layer_norm_coefficients; + int16_t* lstm_forget_layer_norm_coeff_quant; + const float* cell_layer_norm_coefficients; + int16_t* lstm_cell_layer_norm_coeff_quant; + const float* output_layer_norm_coefficients; + int16_t* lstm_output_layer_norm_coeff_quant; + + int8_t* output; + const int8_t* expected_output; + + bool asymmetric_quantize_inputs; + const float ranges[25][2]; +} LstmIntegerTestConfig; + +typedef struct LstmFloatTestConfig { + const int n_batch; + const int n_input; + const int n_cell; + const int n_output; + const int sequence_length; + const bool time_major; + const bool use_cifg; + const bool use_peephole; + const bool use_projection_weights; + const bool use_projection_bias; + const bool use_layer_norm; + const float cell_clip; + const float proj_clip; + + const float* input_original; + float* input; + + const float* input_to_input_weights; + const float* input_to_forget_weights; + const float* input_to_cell_weights; + const float* input_to_output_weights; + + const float* recurrent_to_input_weights; + const float* recurrent_to_forget_weights; + const float* recurrent_to_cell_weights; + const float* recurrent_to_output_weights; + + const float* cell_to_input_weights; + const float* cell_to_forget_weights; + const 
float* cell_to_output_weights; + + const float* input_gate_bias; + const float* forget_gate_bias; + const float* cell_gate_bias; + const float* output_gate_bias; + + const float* projection_weights; + const float* projection_bias; + + float* output_state; + float* cell_state; + + const float* input_layer_norm_coefficients; + const float* forget_layer_norm_coefficients; + const float* cell_layer_norm_coefficients; + const float* output_layer_norm_coefficients; + + float* output; + const float* expected_output_original; + float* expected_output; +} LstmFloatTestConfig; + +typedef struct LstmWeightQuantizationBuffers { + int8_t* lstm_i2i_quant; + float* lstm_i2i_scale; + int* lstm_i2i_zp; + TfLiteAffineQuantization* lstm_i2i_qparam; + + int8_t* lstm_i2f_quant; + float* lstm_i2f_scale; + int* lstm_i2f_zp; + TfLiteAffineQuantization* lstm_i2f_qparam; + + int8_t* lstm_i2c_quant; + float* lstm_i2c_scale; + int* lstm_i2c_zp; + TfLiteAffineQuantization* lstm_i2c_qparam; + + int8_t* lstm_i2o_quant; + float* lstm_i2o_scale; + int* lstm_i2o_zp; + TfLiteAffineQuantization* lstm_i2o_qparam; + + int8_t* lstm_r2i_quant; + float* lstm_r2i_scale; + int* lstm_r2i_zp; + TfLiteAffineQuantization* lstm_r2i_qparam; + + int8_t* lstm_r2f_quant; + float* lstm_r2f_scale; + int* lstm_r2f_zp; + TfLiteAffineQuantization* lstm_r2f_qparam; + + int8_t* lstm_r2c_quant; + float* lstm_r2c_scale; + int* lstm_r2c_zp; + TfLiteAffineQuantization* lstm_r2c_qparam; + + int8_t* lstm_r2o_quant; + float* lstm_r2o_scale; + int* lstm_r2o_zp; + TfLiteAffineQuantization* lstm_r2o_qparam; + + int8_t* lstm_c2i_quant; + float* lstm_c2i_scale; + int* lstm_c2i_zp; + TfLiteAffineQuantization* lstm_c2i_qparam; + + int8_t* lstm_c2f_quant; + float* lstm_c2f_scale; + int* lstm_c2f_zp; + TfLiteAffineQuantization* lstm_c2f_qparam; + + int8_t* lstm_c2o_quant; + float* lstm_c2o_scale; + int* lstm_c2o_zp; + TfLiteAffineQuantization* lstm_c2o_qparam; + + int8_t* lstm_proj_w_quant; + float* lstm_proj_w_scale; + int* lstm_proj_w_zp; + TfLiteAffineQuantization* lstm_proj_w_qparam; +} LstmWeightQuantizationBuffers; + +extern LstmIntegerTestConfig lstm_integer_no_peephole_config; + +extern LstmIntegerTestConfig lstm_integer_peephole_config; + +extern LstmFloatTestConfig lstm_no_cifg_no_peephole_no_proj_config; + +extern LstmFloatTestConfig lstm_cifg_peephole_no_proj_config; + +extern LstmFloatTestConfig lstm_no_cifg_peephole_proj_config; + +extern LstmFloatTestConfig lstm_no_cifg_peephole_proj_bias_config; + +extern LstmWeightQuantizationBuffers lstm_no_cifg_no_peephole_no_proj_buffers; + +extern LstmWeightQuantizationBuffers lstm_cifg_peephole_no_proj_buffers; + +extern LstmWeightQuantizationBuffers lstm_no_cifg_peephole_proj_buffers; + +extern LstmFloatTestConfig cifg_peephole_no_proj_config_layer_norm; + +#endif // !defined(XTENSA) +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_TEST_CONFIG_H_ diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/unpack.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/unpack.cc index 13bb7dcf..d199add0 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/unpack.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/unpack.cc @@ -103,14 +103,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace unpack TfLiteRegistration Register_UNPACK() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/unpack::Eval, - /*profiling_string=*/nullptr, - 
/*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, nullptr, unpack::Eval); } } // namespace micro diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/var_handle.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/var_handle.cc index 8354c918..db044f3f 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/var_handle.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/var_handle.cc @@ -87,14 +87,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. TfLiteRegistration Register_VAR_HANDLE() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/while.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/while.cc index 576c19b0..811c9eae 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/while.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/while.cc @@ -127,14 +127,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace. TfLiteRegistration Register_WHILE() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(Init, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/zeros_like.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/zeros_like.cc index 733564c9..fd6e6612 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/zeros_like.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/zeros_like.cc @@ -81,14 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_ZEROS_LIKE() { - return {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.cc b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.cc index ab313e66..edab2b83 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.cc @@ -16,8 +16,10 @@ limitations under the License. 
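// ---------------------------------------------------------------------------
// Annotation: the unpack/var_handle/while/zeros_like hunks above all replace
// hand-written TfLiteRegistration initializer lists with
// tflite::micro::RegisterOp(init, prepare, invoke). A sketch of the same
// pattern for a hypothetical kernel; the header path for RegisterOp is
// assumed from this tree's kernel_util usage.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;  // shape/type checks would go here
}

TfLiteStatus MyOpEval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;  // the actual computation would go here
}

}  // namespace

TfLiteRegistration Register_MY_OP() {
  // The helper fills in the fields the old brace initializers spelled out
  // (profiling_string, builtin_code, custom_name, version), which is why
  // every call site in this diff shrinks to a single line.
  return tflite::micro::RegisterOp(/*init=*/nullptr, MyOpPrepare, MyOpEval);
}

}  // namespace tflite
// ---------------------------------------------------------------------------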
#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h" +#include "tensorflow/lite/micro/micro_error_reporter.h" namespace tflite { @@ -148,6 +150,30 @@ TfLiteStatus AllocationInfoBuilder::FreeAllocationInfo() { return kTfLiteOk; } +TfLiteStatus AllocationInfoBuilder::ValidateSubgraph( + const SubGraph* subgraph, TfLiteEvalTensor* eval_tensors) { + uint32_t operators_size = NumSubgraphOperators(subgraph); + + for (uint32_t i = 0; i < operators_size; i++) { + const auto op = subgraph->operators()->Get(i); + for (size_t n = 0; + op->intermediates() != nullptr && n < op->intermediates()->size(); + n++) { + const int tensor_index = op->intermediates()->Get(n); + size_t tensor_size = -1; + TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength( + &eval_tensors[tensor_index], &tensor_size)); + if (tensor_size != 0) { + MicroPrintf( + "Does not support intermediate tensor with non-zero size: %d", + tensor_size); + return kTfLiteError; + } + } + } + return kTfLiteOk; +} + TfLiteStatus AllocationInfoBuilder::InitializeAllocationInfo( const int32_t* offline_offsets, SubgraphAllocations* allocations) { AllocationInfo* allocation_info = info_.allocation_info; @@ -158,6 +184,10 @@ TfLiteStatus AllocationInfoBuilder::InitializeAllocationInfo( TfLiteEvalTensor* eval_tensors = allocations[subgraph_idx].tensors; AllocationInfo* subgraph_allocation_info = &allocation_info[info_.subgraph_offsets[subgraph_idx]]; + + // Ensure constraints are met. + TF_LITE_ENSURE_STATUS(ValidateSubgraph(subgraph, eval_tensors)); + for (size_t i = 0; i < subgraph->tensors()->size(); ++i) { AllocationInfo* current = &subgraph_allocation_info[i]; current->output_ptr = &(eval_tensors[i].data.data); @@ -167,8 +197,10 @@ TfLiteStatus AllocationInfoBuilder::InitializeAllocationInfo( current->first_created = kUninitializedLifetime; current->last_used = kUninitializedLifetime; - current->needs_allocating = (eval_tensors[i].data.data == nullptr) && - (!subgraph->tensors()->Get(i)->is_variable()); + current->needs_allocating = + (eval_tensors[i].data.data == nullptr) && + (!subgraph->tensors()->Get(i)->is_variable()) && + (current->bytes != 0); if (offline_offsets) { current->offline_offset = offline_offsets[i]; } else { @@ -181,8 +213,8 @@ TfLiteStatus AllocationInfoBuilder::InitializeAllocationInfo( &allocation_info[info_.scratch_offset]; for (size_t i = 0; i < info_.scratch_buffer_count; i++) { AllocationInfo* current = &scratch_allocation_info[i]; - current->first_created = -1; - current->last_used = -1; + current->first_created = kUninitializedLifetime; + current->last_used = kUninitializedLifetime; current->needs_allocating = true; current->offline_offset = kOnlinePlannedBuffer; } diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.h b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.h index af303307..bc6825ef 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocation_info.h @@ -128,6 +128,10 @@ class AllocationInfoBuilder { // count monotonically increases through the lifetime marking process. void UpdateLastUsed(AllocationInfo* current, int allocation_scope_count); + // Validate if a subgraph satisfies assumptions. 
+ TfLiteStatus ValidateSubgraph(const SubGraph* subgraph, + TfLiteEvalTensor* eval_tensors); + const tflite::Model* model_ = nullptr; INonPersistentBufferAllocator* non_persistent_allocator_ = nullptr; ErrorReporter* reporter_ = nullptr; diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.cc index b71c7502..7e5192cf 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/flatbuffer_utils.h" #include "tensorflow/lite/micro/memory_helpers.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/lite/micro/micro_allocation_info.h" #include "tensorflow/lite/micro/micro_arena_constants.h" #include "tensorflow/lite/micro/micro_error_reporter.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.h index d2967c21..35b07f16 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_allocator.h @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h" #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/flatbuffer_utils.h" #include "tensorflow/lite/micro/memory_planner/micro_memory_planner.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_context.cc b/code/components/tflite-lib/tensorflow/lite/micro/micro_context.cc index 1526b976..9ec694b8 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/micro_context.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_context.cc @@ -80,6 +80,16 @@ TfLiteTensor* MicroContext::AllocateTempOutputTensor(const TfLiteNode* node, return AllocateTempTfLiteTensor(tensor_index); } +TfLiteTensor* MicroContext::AllocateTempIntermediateTensor( + const TfLiteNode* node, int index) { + const int tensor_index = GetTensorIndex(index, node->intermediates->size, + node->intermediates->data); + if (tensor_index < 0) { + return nullptr; + } + return AllocateTempTfLiteTensor(tensor_index); +} + void MicroContext::DeallocateTempTfLiteTensor(TfLiteTensor* tensor) { return allocator_.DeallocateTempTfLiteTensor(tensor); } diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_context.h b/code/components/tflite-lib/tensorflow/lite/micro/micro_context.h index 1db2575e..e7be6544 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/micro_context.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_context.h @@ -73,6 +73,13 @@ class 
MicroContext {
   virtual TfLiteTensor* AllocateTempOutputTensor(const TfLiteNode* node,
                                                  int index);
 
+  // Returns a temporary TfLiteTensor struct for the specified intermediate
+  // tensor of a given node. This is the recommended API over the deprecated
+  // GetIntermediates/GetIntermediatesSafe to get a temp intermediate tensor.
+  // The returned tensor shall be freed via calling DeallocateTempTfLiteTensor.
+  virtual TfLiteTensor* AllocateTempIntermediateTensor(const TfLiteNode* node,
+                                                       int index);
+
   // Deallocates a temp TfLiteTensor.
   // Virtual so that it can be faked for kernel tests.
   virtual void DeallocateTempTfLiteTensor(TfLiteTensor* tensor);
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_mutable_op_resolver.h b/code/components/tflite-lib/tensorflow/lite/micro/micro_mutable_op_resolver.h
index 8676189d..237bd595 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -25,9 +25,11 @@ limitations under the License.
 #include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/compatibility.h"
 #include "tensorflow/lite/micro/kernels/conv.h"
+#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
 #include "tensorflow/lite/micro/kernels/ethosu.h"
 #include "tensorflow/lite/micro/kernels/fully_connected.h"
 #include "tensorflow/lite/micro/kernels/micro_ops.h"
+#include "tensorflow/lite/micro/kernels/reduce.h"
 #include "tensorflow/lite/micro/kernels/softmax.h"
 #include "tensorflow/lite/micro/micro_op_resolver.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -119,8 +121,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
                       ParseAbs);
   }
 
-  TfLiteStatus AddAdd() {
-    return AddBuiltin(BuiltinOperator_ADD, tflite::Register_ADD(), ParseAdd);
+  TfLiteStatus AddAdd(const TfLiteRegistration& registration = Register_ADD()) {
+    return AddBuiltin(BuiltinOperator_ADD, registration, ParseAdd);
   }
 
   TfLiteStatus AddAddN() {
@@ -207,9 +209,10 @@ class MicroMutableOpResolver : public MicroOpResolver {
                       tflite::Register_DEPTH_TO_SPACE(), ParseDepthToSpace);
   }
 
-  TfLiteStatus AddDepthwiseConv2D() {
-    return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
-                      Register_DEPTHWISE_CONV_2D(), ParseDepthwiseConv2D);
+  TfLiteStatus AddDepthwiseConv2D(
+      const TfLiteRegistration& registration = Register_DEPTHWISE_CONV_2D()) {
+    return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, registration,
+                      ParseDepthwiseConv2D);
   }
 
   TfLiteStatus AddDequantize() {
@@ -372,8 +375,7 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }
 
   TfLiteStatus AddMean() {
-    return AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN(),
-                      ParseReducer);
+    return AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(), ParseReducer);
   }
 
   TfLiteStatus AddMinimum() {
@@ -426,8 +428,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }
 
   TfLiteStatus AddReduceMax() {
-    return AddBuiltin(BuiltinOperator_REDUCE_MAX,
-                      tflite::ops::micro::Register_REDUCE_MAX(), ParseReducer);
+    return AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX(),
+                      ParseReducer);
   }
 
   TfLiteStatus AddRelu() {
@@ -554,10 +556,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }
 
   TfLiteStatus
AddUnidirectionalSequenceLSTM() {
-    return AddBuiltin(
-        BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
-        tflite::ops::micro::Register_UNIDIRECTIONAL_SEQUENCE_LSTM(),
-        ParseUnidirectionalSequenceLSTM);
+    return AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+                      Register_UNIDIRECTIONAL_SEQUENCE_LSTM(),
+                      ParseUnidirectionalSequenceLSTM);
   }
 
   TfLiteStatus AddVarHandle() {
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.cc b/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.cc
index d8a86c6b..72f3d37f 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/micro/micro_profiler.h"
 
+#include <cinttypes>
 #include <cstring>
 
 #include "tensorflow/lite/kernels/internal/compatibility.h"
@@ -38,7 +39,7 @@ void MicroProfiler::EndEvent(uint32_t event_handle) {
   end_ticks_[event_handle] = GetCurrentTimeTicks();
 }
 
-int32_t MicroProfiler::GetTotalTicks() const {
+uint32_t MicroProfiler::GetTotalTicks() const {
   int32_t ticks = 0;
   for (int i = 0; i < num_events_; ++i) {
     ticks += end_ticks_[i] - start_ticks_[i];
@@ -49,8 +50,9 @@ int32_t MicroProfiler::GetTotalTicks() const {
 void MicroProfiler::Log() const {
 #if !defined(TF_LITE_STRIP_ERROR_STRINGS)
   for (int i = 0; i < num_events_; ++i) {
-    int32_t ticks = end_ticks_[i] - start_ticks_[i];
-    MicroPrintf("%s took %d ticks (%d ms).", tags_[i], ticks, TicksToMs(ticks));
+    uint32_t ticks = end_ticks_[i] - start_ticks_[i];
+    MicroPrintf("%s took %" PRIu32 " ticks (%d ms).", tags_[i], ticks,
+                TicksToMs(ticks));
   }
 #endif
 }
@@ -59,8 +61,8 @@ void MicroProfiler::LogCsv() const {
 #if !defined(TF_LITE_STRIP_ERROR_STRINGS)
   MicroPrintf("\"Event\",\"Tag\",\"Ticks\"");
   for (int i = 0; i < num_events_; ++i) {
-    int32_t ticks = end_ticks_[i] - start_ticks_[i];
-    MicroPrintf("%d,%s,%d", i, tags_[i], ticks);
+    uint32_t ticks = end_ticks_[i] - start_ticks_[i];
+    MicroPrintf("%d,%s,%" PRIu32, i, tags_[i], ticks);
   }
 #endif
 }
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.h b/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.h
index 8a1ba5de..41f41a35 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_profiler.h
@@ -51,7 +51,7 @@ class MicroProfiler {
   // Returns the sum of the ticks taken across all the events. This number
   // is only meaningful if all of the events are disjoint (the end time of
   // event[i] <= start time of event[i+1]).
-  int32_t GetTotalTicks() const;
+  uint32_t GetTotalTicks() const;
 
   // Prints the profiling information of each of the events in human readable
   // form.
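// ---------------------------------------------------------------------------
// Annotation: with the tick types above widened to uint32_t, typical profiler
// usage is unchanged. A sketch using the public MicroProfiler calls shown in
// this patch (BeginEvent/EndEvent/Log/LogCsv); the surrounding wiring is
// schematic:
#include "tensorflow/lite/micro/micro_profiler.h"

void ProfileOneInvokeSketch(tflite::MicroProfiler& profiler) {
  const uint32_t handle = profiler.BeginEvent("Invoke");
  // ... interpreter.Invoke() or other timed work goes here ...
  profiler.EndEvent(handle);
  profiler.Log();     // per event: "<tag> took N ticks (M ms)."
  profiler.LogCsv();  // "Event","Tag","Ticks" rows
}
// ---------------------------------------------------------------------------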
@@ -68,8 +68,8 @@ class MicroProfiler {
   static constexpr int kMaxEvents = 1024;
 
   const char* tags_[kMaxEvents];
-  int32_t start_ticks_[kMaxEvents];
-  int32_t end_ticks_[kMaxEvents];
+  uint32_t start_ticks_[kMaxEvents];
+  uint32_t end_ticks_[kMaxEvents];
   int num_events_ = 0;
 
   TF_LITE_REMOVE_VIRTUAL_DELETE;
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_time.cc b/code/components/tflite-lib/tensorflow/lite/micro/micro_time.cc
index bbe3f1a8..2d74fdba 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/micro_time.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_time.cc
@@ -38,21 +38,21 @@ namespace tflite {
 // for a platform to support Tensorflow Lite for Microcontrollers profiling.
 // This returns 0 by default because timing is an optional feature that builds
 // without errors on platforms that do not need it.
-int32_t ticks_per_second() { return 0; }
+uint32_t ticks_per_second() { return 0; }
 
 // Reference implementation of the GetCurrentTimeTicks() function that's
 // required for a platform to support Tensorflow Lite for Microcontrollers
 // profiling. This returns 0 by default because timing is an optional feature
 // that builds without errors on platforms that do not need it.
-int32_t GetCurrentTimeTicks() { return 0; }
+uint32_t GetCurrentTimeTicks() { return 0; }
 
 #else  // defined(TF_LITE_USE_CTIME)
 
 // For platforms that support ctime, we implement the micro_time interface in
 // this central location.
-int32_t ticks_per_second() { return CLOCKS_PER_SEC; }
+uint32_t ticks_per_second() { return CLOCKS_PER_SEC; }
 
-int32_t GetCurrentTimeTicks() { return clock(); }
+uint32_t GetCurrentTimeTicks() { return clock(); }
 
 #endif
 
 }  // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/micro_time.h b/code/components/tflite-lib/tensorflow/lite/micro/micro_time.h
index fac9069b..7a8ab455 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/micro_time.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/micro_time.h
@@ -21,14 +21,14 @@ namespace tflite {
 // These functions should be implemented by each target platform, and provide an
 // accurate tick count along with how many ticks there are per second.
-int32_t ticks_per_second();
+uint32_t ticks_per_second();
 
 // Return time in ticks. The meaning of a tick varies per platform.
-int32_t GetCurrentTimeTicks();
+uint32_t GetCurrentTimeTicks();
 
-inline int32_t TicksToMs(int32_t ticks) {
-  return static_cast<int32_t>(1000.0f * static_cast<float>(ticks) /
-                              static_cast<float>(ticks_per_second()));
+inline uint32_t TicksToMs(int32_t ticks) {
+  return static_cast<uint32_t>(1000.0f * static_cast<float>(ticks) /
+                               static_cast<float>(ticks_per_second()));
 }
 
 }  // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.cc
index 53b3806d..fd84370a 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.cc
@@ -17,12 +17,12 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h" #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h" #include "tensorflow/lite/micro/micro_allocator.h" #include "tensorflow/lite/micro/micro_error_reporter.h" -#include "tensorflow/lite/micro/recording_simple_memory_allocator.h" namespace tflite { diff --git a/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.h index 6b039c03..0667287f 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.h +++ b/code/components/tflite-lib/tensorflow/lite/micro/recording_micro_allocator.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_RECORDING_MICRO_ALLOCATOR_H_ #define TENSORFLOW_LITE_MICRO_RECORDING_MICRO_ALLOCATOR_H_ +#include "tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h" #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/micro_allocator.h" -#include "tensorflow/lite/micro/recording_simple_memory_allocator.h" namespace tflite { diff --git a/code/components/tflite-lib/tensorflow/lite/micro/test_helpers.cc b/code/components/tflite-lib/tensorflow/lite/micro/test_helpers.cc index 2adea777..2411bbf8 100644 --- a/code/components/tflite-lib/tensorflow/lite/micro/test_helpers.cc +++ b/code/components/tflite-lib/tensorflow/lite/micro/test_helpers.cc @@ -110,7 +110,9 @@ class ModelBuilder { // Adds a node to the model with given input and output Tensors. 
Node AddNode(Operator op, std::initializer_list inputs, - std::initializer_list outputs); + std::initializer_list outputs, + std::initializer_list intermediates = + std::initializer_list{}); void AddMetadata(const char* description_string, const int32_t* metadata_buffer_data, size_t num_elements); @@ -165,12 +167,17 @@ ModelBuilder::Operator ModelBuilder::RegisterOp(BuiltinOperator op, ModelBuilder::Node ModelBuilder::AddNode( ModelBuilder::Operator op, std::initializer_list inputs, - std::initializer_list outputs) { + std::initializer_list outputs, + std::initializer_list intermediates) { TFLITE_DCHECK(next_operator_id_ <= kMaxOperators); operators_[next_operator_id_] = tflite::CreateOperator( *builder_, op, builder_->CreateVector(inputs.begin(), inputs.size()), builder_->CreateVector(outputs.begin(), outputs.size()), - BuiltinOptions_NONE); + BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs =*/0, + builder_->CreateVector(intermediates.begin(), intermediates.size())); next_operator_id_++; return next_operator_id_ - 1; } @@ -274,9 +281,12 @@ const Model* BuildSimpleStatefulModel() { const int median_tensor = model_builder.AddTensor(TensorType_INT8, {3}); const int invoke_count_tensor = model_builder.AddTensor(TensorType_INT32, {1}); + const int intermediate_tensor = + model_builder.AddTensor(TensorType_FLOAT32, {0}); model_builder.AddNode(op_id, {input_tensor}, - {median_tensor, invoke_count_tensor}); + {median_tensor, invoke_count_tensor}, + {intermediate_tensor}); return model_builder.BuildModel({input_tensor}, {median_tensor, invoke_count_tensor}); } diff --git a/code/components/tflite-lib/tensorflow/lite/schema/schema_generated.h b/code/components/tflite-lib/tensorflow/lite/schema/schema_generated.h index e5ce189f..d30dbfe8 100644 --- a/code/components/tflite-lib/tensorflow/lite/schema/schema_generated.h +++ b/code/components/tflite-lib/tensorflow/lite/schema/schema_generated.h @@ -397,6 +397,9 @@ struct GeluOptionsT; struct DynamicUpdateSliceOptions; struct DynamicUpdateSliceOptionsT; +struct UnsortedSegmentProdOptions; +struct UnsortedSegmentProdOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -875,11 +878,13 @@ enum BuiltinOperator { BuiltinOperator_MULTINOMIAL = 149, BuiltinOperator_GELU = 150, BuiltinOperator_DYNAMIC_UPDATE_SLICE = 151, + BuiltinOperator_RELU_0_TO_1 = 152, + BuiltinOperator_UNSORTED_SEGMENT_PROD = 153, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_DYNAMIC_UPDATE_SLICE + BuiltinOperator_MAX = BuiltinOperator_UNSORTED_SEGMENT_PROD }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[152] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[154] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -1032,13 +1037,15 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[152] { BuiltinOperator_RANDOM_UNIFORM, BuiltinOperator_MULTINOMIAL, BuiltinOperator_GELU, - BuiltinOperator_DYNAMIC_UPDATE_SLICE + BuiltinOperator_DYNAMIC_UPDATE_SLICE, + BuiltinOperator_RELU_0_TO_1, + BuiltinOperator_UNSORTED_SEGMENT_PROD }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[153] = { + static const char * const names[155] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1191,13 +1198,15 @@ inline const char * const *EnumNamesBuiltinOperator() { "MULTINOMIAL", "GELU", "DYNAMIC_UPDATE_SLICE", + "RELU_0_TO_1", + 
"UNSORTED_SEGMENT_PROD", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_DYNAMIC_UPDATE_SLICE)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_UNSORTED_SEGMENT_PROD)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1321,11 +1330,12 @@ enum BuiltinOptions { BuiltinOptions_BucketizeOptions = 115, BuiltinOptions_GeluOptions = 116, BuiltinOptions_DynamicUpdateSliceOptions = 117, + BuiltinOptions_UnsortedSegmentProdOptions = 118, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_DynamicUpdateSliceOptions + BuiltinOptions_MAX = BuiltinOptions_UnsortedSegmentProdOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[118] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[119] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1444,13 +1454,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[118] { BuiltinOptions_RandomOptions, BuiltinOptions_BucketizeOptions, BuiltinOptions_GeluOptions, - BuiltinOptions_DynamicUpdateSliceOptions + BuiltinOptions_DynamicUpdateSliceOptions, + BuiltinOptions_UnsortedSegmentProdOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[119] = { + static const char * const names[120] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1569,13 +1580,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "BucketizeOptions", "GeluOptions", "DynamicUpdateSliceOptions", + "UnsortedSegmentProdOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_DynamicUpdateSliceOptions)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_UnsortedSegmentProdOptions)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -2052,6 +2064,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -3020,6 +3036,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_DynamicUpdateSliceOptions ? reinterpret_cast(value) : nullptr; } + tflite::UnsortedSegmentProdOptionsT *AsUnsortedSegmentProdOptions() { + return type == BuiltinOptions_UnsortedSegmentProdOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UnsortedSegmentProdOptionsT *AsUnsortedSegmentProdOptions() const { + return type == BuiltinOptions_UnsortedSegmentProdOptions ? 
+ reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -10659,6 +10683,60 @@ inline flatbuffers::Offset CreateDynamicUpdateSliceOp flatbuffers::Offset CreateDynamicUpdateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const DynamicUpdateSliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct UnsortedSegmentProdOptionsT : public flatbuffers::NativeTable { + typedef UnsortedSegmentProdOptions TableType; + int32_t num_segments; + UnsortedSegmentProdOptionsT() + : num_segments(0) { + } +}; + +struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UnsortedSegmentProdOptionsT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NUM_SEGMENTS = 4 + }; + int32_t num_segments() const { + return GetField(VT_NUM_SEGMENTS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_SEGMENTS) && + verifier.EndTable(); + } + UnsortedSegmentProdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnsortedSegmentProdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnsortedSegmentProdOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_num_segments(int32_t num_segments) { + fbb_.AddElement(UnsortedSegmentProdOptions::VT_NUM_SEGMENTS, num_segments, 0); + } + explicit UnsortedSegmentProdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UnsortedSegmentProdOptionsBuilder &operator=(const UnsortedSegmentProdOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUnsortedSegmentProdOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_segments = 0) { + UnsortedSegmentProdOptionsBuilder builder_(_fbb); + builder_.add_num_segments(num_segments); + return builder_.Finish(); +} + +flatbuffers::Offset CreateUnsortedSegmentProdOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; int8_t deprecated_builtin_code; @@ -11160,6 +11238,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::DynamicUpdateSliceOptions *builtin_options_as_DynamicUpdateSliceOptions() const { return builtin_options_type() == tflite::BuiltinOptions_DynamicUpdateSliceOptions ? static_cast(builtin_options()) : nullptr; } + const tflite::UnsortedSegmentProdOptions *builtin_options_as_UnsortedSegmentProdOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnsortedSegmentProdOptions ? 
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -11664,6 +11745,10 @@ template<> inline const tflite::DynamicUpdateSliceOptions *Operator::builtin_opt
   return builtin_options_as_DynamicUpdateSliceOptions();
 }
 
+template<> inline const tflite::UnsortedSegmentProdOptions *Operator::builtin_options_as<tflite::UnsortedSegmentProdOptions>() const {
+  return builtin_options_as_UnsortedSegmentProdOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -15773,6 +15858,32 @@ inline flatbuffers::Offset<DynamicUpdateSliceOptions> CreateDynamicUpdateSliceOp
       _fbb);
 }
 
+inline UnsortedSegmentProdOptionsT *UnsortedSegmentProdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new UnsortedSegmentProdOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void UnsortedSegmentProdOptions::UnPackTo(UnsortedSegmentProdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = num_segments(); _o->num_segments = _e; }
+}
+
+inline flatbuffers::Offset<UnsortedSegmentProdOptions> UnsortedSegmentProdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateUnsortedSegmentProdOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<UnsortedSegmentProdOptions> CreateUnsortedSegmentProdOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnsortedSegmentProdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _num_segments = _o->num_segments;
+  return tflite::CreateUnsortedSegmentProdOptions(
+      _fbb,
+      _num_segments);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -16716,6 +16827,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
       auto ptr = reinterpret_cast<const tflite::DynamicUpdateSliceOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_UnsortedSegmentProdOptions: {
+      auto ptr = reinterpret_cast<const tflite::UnsortedSegmentProdOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return true;
   }
 }
@@ -17202,6 +17317,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
       auto ptr = reinterpret_cast<const tflite::DynamicUpdateSliceOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_UnsortedSegmentProdOptions: {
+      auto ptr = reinterpret_cast<const tflite::UnsortedSegmentProdOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -17676,6 +17795,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
       auto ptr = reinterpret_cast<const tflite::DynamicUpdateSliceOptionsT *>(value);
       return CreateDynamicUpdateSliceOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_UnsortedSegmentProdOptions: {
+      auto ptr = reinterpret_cast<const tflite::UnsortedSegmentProdOptionsT *>(value);
+      return CreateUnsortedSegmentProdOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -18150,6 +18273,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
       value = new tflite::DynamicUpdateSliceOptionsT(*reinterpret_cast<tflite::DynamicUpdateSliceOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_UnsortedSegmentProdOptions: {
+      value = new tflite::UnsortedSegmentProdOptionsT(*reinterpret_cast<tflite::UnsortedSegmentProdOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -18742,6 +18869,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_UnsortedSegmentProdOptions: {
+      auto ptr = reinterpret_cast<tflite::UnsortedSegmentProdOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
diff --git a/code/components/tflite-lib_20220716.zip b/code/components/tflite-lib_20220716.zip
new file mode 100644
index 00000000..38814b7a
Binary files /dev/null and b/code/components/tflite-lib_20220716.zip differ
diff --git a/code/main/version.cpp b/code/main/version.cpp
index a6add0f6..34b61171 100644
--- a/code/main/version.cpp
+++ b/code/main/version.cpp
@@ -1,4 +1,4 @@
-const char* GIT_REV="058e943";
+const char* GIT_REV="0b039e8";
 const char* GIT_TAG="";
-const char* GIT_BRANCH="rolling";
-const char* BUILD_TIME="2022-07-16 07:55";
\ No newline at end of file
+const char* GIT_BRANCH="espressif-latest";
+const char* BUILD_TIME="2022-07-16 20:42";
\ No newline at end of file
diff --git a/code/platformio.ini b/code/platformio.ini
index abf5731f..c26b91de 100644
--- a/code/platformio.ini
+++ b/code/platformio.ini
@@ -14,7 +14,7 @@
 src_dir = main
 
 [env:esp32cam]
-platform = espressif32@4.4.0
+platform = espressif32@4.4
 ;platform = espressif32
 board = esp32cam
 framework = espidf
diff --git a/code/version.cpp b/code/version.cpp
index 0b90033d..34b61171 100644
--- a/code/version.cpp
+++ b/code/version.cpp
@@ -1,4 +1,4 @@
-const char* GIT_REV="058e943";
+const char* GIT_REV="0b039e8";
 const char* GIT_TAG="";
-const char* GIT_BRANCH="rolling";
-const char* BUILD_TIME="2022-07-16 07:54";
\ No newline at end of file
+const char* GIT_BRANCH="espressif-latest";
+const char* BUILD_TIME="2022-07-16 20:42";
\ No newline at end of file
diff --git a/firmware/bootloader.bin b/firmware/bootloader.bin
index d13bbe16..3705e66f 100644
Binary files a/firmware/bootloader.bin and b/firmware/bootloader.bin differ
diff --git a/firmware/firmware.bin b/firmware/firmware.bin
index 149d64f4..edb42ff0 100644
Binary files a/firmware/firmware.bin and b/firmware/firmware.bin differ
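
All of the generated plumbing added above reduces to a single int32 field, num_segments, plus the usual union/type-tag bookkeeping. A minimal usage sketch (not part of the patch; the include path is an assumption and depends on how the tflite-lib component lays out schema_generated.h):

// Sketch only: exercising the generated UnsortedSegmentProdOptions API
// added by this patch. The include path below is hypothetical.
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Serialise an options table with num_segments = 8.
flatbuffers::Offset<tflite::UnsortedSegmentProdOptions>
BuildOptions(flatbuffers::FlatBufferBuilder &fbb) {
  return tflite::CreateUnsortedSegmentProdOptions(fbb, /*num_segments=*/8);
}

// Read the field back from an operator in a loaded model. The union
// accessor returns nullptr unless builtin_options_type() is
// BuiltinOptions_UnsortedSegmentProdOptions, which is the same type check
// VerifyBuiltinOptions performs during model verification.
int32_t ReadNumSegments(const tflite::Operator *op) {
  const auto *opts = op->builtin_options_as_UnsortedSegmentProdOptions();
  return opts ? opts->num_segments() : 0;
}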