diff --git a/README.md b/README.md
index deb6f8a8..c6561a19 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,16 @@ In other cases you can contact the developer via email:
-
+#include "esp_nn_defs.h"
/************************** Basic math functions ****************************/
/**
@@ -81,28 +80,15 @@ void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
* optimization notes: Though input_offset is int32 type,
* offset values are contained in 8 bits [-128, 127]
*/
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
- const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max);
+ const dw_conv_params_t *conv_params,
+ const quant_data_t *quant_data);
/**
* @brief 2d-convolution channelwise
@@ -112,43 +98,26 @@ void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
* inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const uint16_t out_channels,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max);
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data);
-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_ch,
- const uint16_t out_ch,
- const uint16_t filter_wd,
- const uint16_t filter_ht);
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const conv_params_t *conv_params);
void esp_nn_set_conv_scratch_buf_ansi(const void *buf);
-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const uint16_t ch_mult,
- const uint16_t filter_wd,
- const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const dw_conv_params_t *conv_params);
void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);
/************************** Activation functions *****************************/
@@ -252,9 +221,6 @@ int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t h
*/
void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);
-/* ANSI C function to be hooked up when optimised version needed */
-void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
-
/**
* @brief reference softmax function
*
@@ -268,6 +234,66 @@ void esp_nn_softmax_s8_ansi(const int8_t *input_data,
const int32_t diff_min,
int8_t *output_data);
+
+//////////////////////////// Generic optimisations /////////////////////////////
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief 2d-convolution channelwise optimized version
+ *
+ * @note operation: result += (input + offset) * filter
+ *
+ * inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ const data_dims_t *output_dims,
+ int8_t *out_data,
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data);
+
+/**
+ * @brief depthwise convolution per channel optimized version
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * Version used in tflite is per channel.
+ * This version follows the same footprints.
+ * Meaning, it has per out_channel shift and multiplier for
+ * requantization
+ *
+ * optimization notes: Though input_offset is int32 type,
+ * offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ const data_dims_t *output_dims,
+ int8_t *out_data,
+ const dw_conv_params_t *conv_params,
+ const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const conv_params_t *conv_params);
+void esp_nn_set_conv_scratch_buf_opt(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const dw_conv_params_t *conv_params);
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
+
+/* ANSI C function to be hooked up when optimised version needed */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
+
/**
* @brief optimised version of softmax function
*
diff --git a/code/components/esp-nn/include/esp_nn_defs.h b/code/components/esp-nn/include/esp_nn_defs.h
new file mode 100644
index 00000000..756d8e6f
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn_defs.h
@@ -0,0 +1,83 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+/**
+ * @brief structure to club data dims
+ * this structure can be used for input, output and filter
+ */
+typedef struct data_dims {
+ int32_t width;
+ int32_t height;
+ int32_t channels;
+
+ int32_t extra; // can be used as batch or any other param
+} data_dims_t;
+
+/**
+ * @brief 2d data structure (width, height)
+ *
+ */
+typedef struct data_2d {
+ int32_t width;
+ int32_t height;
+} data_2d_t;
+
+/**
+ * @brief min/max activation
+ */
+typedef struct act_params {
+ int32_t min;
+ int32_t max;
+} act_params_t;
+
+/**
+ * @brief per channel quant data
+ *
+ * @note number of shift and mult elements are equal to output channels
+ */
+typedef struct quant_data {
+ int32_t *shift;
+ int32_t *mult;
+} quant_data_t;
+
+/**
+ * @brief params specific to convolution 2d
+ *
+ */
+typedef struct conv_params {
+ int32_t in_offset;
+ int32_t out_offset;
+ data_2d_t stride;
+ data_2d_t padding;
+ data_2d_t dilation;
+ act_params_t activation;
+} conv_params_t;
+
+/**
+ * @brief params specific to depthwise convolution 2d
+ *
+ */
+typedef struct dw_conv_params {
+ int32_t in_offset;
+ int32_t out_offset;
+ int32_t ch_mult; // channel multiplier. (in_ch * ch_mult = out_ch)
+ data_2d_t stride;
+ data_2d_t padding;
+ data_2d_t dilation;
+ act_params_t activation;
+} dw_conv_params_t;
diff --git a/code/components/esp-nn/include/esp_nn_esp32s3.h b/code/components/esp-nn/include/esp_nn_esp32s3.h
index 58b544e4..0f52c943 100644
--- a/code/components/esp-nn/include/esp_nn_esp32s3.h
+++ b/code/components/esp-nn/include/esp_nn_esp32s3.h
@@ -19,7 +19,7 @@
#pragma once
-#include <stdint.h>
+#include "esp_nn_defs.h"
#include "esp_nn_ansi_headers.h"
/************************** Basic math functions *****************************/
@@ -85,28 +85,15 @@ void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data,
* optimization notes: Though input_offset is int32 type,
* offset values are contained in 8 bits [-128, 127]
*/
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
- const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
- int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max);
+ const data_dims_t *output_dims,
+ int8_t *output_data,
+ const dw_conv_params_t *conv_params,
+ const quant_data_t *quant_data);
/**
* @brief 2d - convolution channelwise
@@ -116,43 +103,26 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
* inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
-void esp_nn_conv_s8_esp32s3(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
- int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const uint16_t out_channels,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max);
+ const data_dims_t *output_dims,
+ int8_t *output_data,
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data);
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_ch,
- const uint16_t out_ch,
- const uint16_t filter_wd,
- const uint16_t filter_ht);
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const conv_params_t *conv_params);
void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const uint16_t ch_mult,
- const uint16_t filter_wd,
- const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const dw_conv_params_t *conv_params);
void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
/************************** Pooling functions *****************************/
diff --git a/code/components/esp-nn/include/esp_nn_esp32.h b/code/components/esp-nn/include/esp_nn_generic_opt.h
similarity index 77%
rename from code/components/esp-nn/include/esp_nn_esp32.h
rename to code/components/esp-nn/include/esp_nn_generic_opt.h
index 03fd8216..136cba5d 100644
--- a/code/components/esp-nn/include/esp_nn_esp32.h
+++ b/code/components/esp-nn/include/esp_nn_generic_opt.h
@@ -13,28 +13,27 @@
// limitations under the License.
/**
- * @file Header definitions to include for esp_nn optimized functions for
- * the ESP32 platform.
- * We are hooking up just the C versions for now.
- * The file hence is exactly same as `esp_nn_ansi_c.h`
+ * @file Header definitions to include for esp_nn generic optimisations.
+ * For functions that do not have optimisations, the _ansi versions are picked.
*/
#pragma once
+#include "esp_nn_defs.h"
#include "esp_nn_ansi_headers.h"
#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
-#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_opt
-#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+#define esp_nn_conv_s8 esp_nn_conv_s8_opt
-#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
-#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_opt
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_opt
-#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
-#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_opt
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_opt
#define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
diff --git a/code/components/esp-nn/src/common/common_functions.h b/code/components/esp-nn/src/common/common_functions.h
index 9a5f0dcc..0a74eca4 100644
--- a/code/components/esp-nn/src/common/common_functions.h
+++ b/code/components/esp-nn/src/common/common_functions.h
@@ -41,15 +41,39 @@
__NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
{
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
__asm__ volatile("nsau %0, %0" : "+r" (in));
return in;
-}
-
-__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
-{
- int32_t sign = (int32_t) (val64 >> 63);
- int32_t to_add = sign & ((1ul << 31) - 1);
- return (int32_t) ((int64_t) (val64 + to_add) >> 31);
+#elif defined(__GNUC__)
+ return __builtin_clz(in);
+#else
+ int32_t count = 32;
+ uint32_t x = in, y = in >> 16;
+ if (y != 0) {
+ count -= 16;
+ x = y;
+ }
+ y = x >> 8;
+ if (y != 0) {
+ count -= 8;
+ x = y;
+ }
+ y = x >> 4;
+ if (y != 0) {
+ count -= 4;
+ x = y;
+ }
+ y = x >> 2;
+ if (y != 0) {
+ count -= 2;
+ x = y;
+ }
+ y = x >> 1;
+ if (y != 0) {
+ return count - 2;
+ }
+ return count - x;
+#endif
}
/**
@@ -57,8 +81,19 @@ __NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
*/
__NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
{
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
__asm__ volatile("clamps %0, %0, 7" : "+a"(in));
return in;
+#else
+ return max(INT8_MIN, min(in, INT8_MAX));
+#endif
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
+{
+ int32_t sign = (int32_t) (val64 >> 63);
+ int32_t to_add = sign & ((1ul << 31) - 1);
+ return (int32_t) ((int64_t) (val64 + to_add) >> 31);
}
__NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
@@ -144,7 +179,7 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
const uint16_t pad_ht)
{
/* memset with pad_val */
- memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels * 2);
+ memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
dst += (pad_wd + input_wd + pad_wd) * channels;
for (int i = 0; i < input_ht; i++) {
@@ -156,7 +191,6 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
}
}
-#if 0
static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
const uint16_t input_wd,
const uint16_t input_ht,
@@ -169,13 +203,16 @@ static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
for (int j = 0; j < input_wd * channels; j++) {
*dst++ = *src++;
}
- memset(dst, pad_val, pad_wd * channels);
- dst += pad_wd * channels;
+ if (pad_wd) {
+ memset(dst, pad_val, pad_wd * channels);
+ dst += pad_wd * channels;
+ }
}
/* pad end `pad_ht` lines at end */
- memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+ if (pad_ht) {
+ memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+ }
}
-#endif
/**
* @brief convert 8 bit input data to 16 bit
diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c b/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
index d04f78e1..677c0ad8 100644
--- a/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
+++ b/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
@@ -12,16 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <stdint.h>
+#include <esp_nn_defs.h>
 #include <common_functions.h>
-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_ch,
- const uint16_t out_ch,
- const uint16_t filter_wd,
- const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const conv_params_t *conv_params)
{
return 0;
}
@@ -108,29 +106,35 @@ void esp_nn_conv_u8_ansi(const uint8_t *input_data,
* Assumption 2: Pointers are valid
* Assumption 3: dialation width = 1
*/
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const uint16_t out_channels,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max)
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t in_channels = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const uint16_t out_channels = output_dims->channels;
+ const int32_t *out_shift = quant_data->shift;
+ const int32_t *out_mult = quant_data->mult;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+
int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c b/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
index ea8fdfa5..e13129b2 100644
--- a/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
+++ b/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
@@ -12,30 +12,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <stdint.h>
 #include <stdio.h>
+#include <esp_nn_defs.h>
 #include <common_functions.h>
static int16_t *scratch_buffer = NULL;
-extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_channels,
- const int32_t input_offset,
- const int16_t *filter_data,
- const int32_t *bias,
- int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const uint16_t out_channels,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max,
- void *buffer /* scratch buffer */);
+extern void esp_nn_conv_s8_mult8_1x1_esp32s3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const int8_t *filter_aligned,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ void *buffer /* scratch buffer */);
extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
@@ -81,34 +81,40 @@ extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int1
extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
-static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
+static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const uint16_t out_channels,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max)
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t in_ch = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const uint16_t out_ch = output_dims->channels;
+ const int32_t *out_shift = quant_data->shift;
+ const int32_t *out_mult = quant_data->mult;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+
int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
- for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) {
int32_t conv_out = 0;
const int32_t base_y = stride_ht * out_y - pad_ht;
@@ -124,10 +130,10 @@ static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t in_row = base_y + filter_y_idx;
const int32_t in_col = base_x + filter_x_idx;
- int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
- int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
- (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
- for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+ int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch;
+ int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd +
+ (filter_y_idx * filter_wd + filter_x_idx) * in_ch;
+ for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) {
conv_out +=
(input_data[input_base_offset + in_ch_idx] + input_offset) *
filter_data[filter_base_offset + in_ch_idx];
@@ -332,18 +338,35 @@ static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
}
}
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t in_ch,
- const uint16_t out_ch,
- const uint16_t filter_wd,
- const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const conv_params_t *conv_params)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t in_ch = input_dims->channels;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t out_ch = output_dims->channels;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+
int filter_size = filter_wd * filter_ht * in_ch * out_ch;
int input_size = input_wd * input_ht * in_ch;
- int transpose_buf_size = 8 * in_ch; /* to store intermediate data */
+
+ int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */
+ if (input_wd * input_ht < 8) {
+ transpose_buf_size = 0; // not using this for leftover
+ }
int align_buf_size = 32; /* extra buffer for alignment */
- return 2 * (filter_size + input_size + transpose_buf_size) + align_buf_size;
+ if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
+ pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
+ return filter_size + transpose_buf_size + align_buf_size;
+ }
+ return 2 * (filter_size + input_size) + transpose_buf_size + align_buf_size;
}
void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
@@ -351,29 +374,35 @@ void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
scratch_buffer = (int16_t *) buf;
}
-void esp_nn_conv_s8_esp32s3(const int8_t *input,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+ const int8_t *input,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const uint16_t out_channels,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max)
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t channels = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const uint16_t out_channels = output_dims->channels;
+ const int32_t *out_shift = quant_data->shift;
+ const int32_t *out_mult = quant_data->mult;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+
int filter_size = filter_wd * filter_ht * channels * out_channels;
int input_size = input_wd * input_ht * channels;
int align_len = 16 - (filter_size & 15);
@@ -387,15 +416,16 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
- int scratch_offset = (int) (filter_data16 + filter_size);
+ int8_t *filter_aligned = (int8_t *) scratch_buffer;
+ int scratch_offset = (int) (filter_aligned + filter_size);
void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
- esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
- esp_nn_conv_s16_mult8_1x1_esp32s3(
- input, input_wd, input_ht, channels, input_offset, filter_data16,
+ memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address
+ esp_nn_conv_s8_mult8_1x1_esp32s3(
+ input, input_wd, input_ht, channels, input_offset, filter_aligned,
bias, out_data, out_wd, out_ht, out_channels, out_offset,
out_shift, out_mult, activation_min, activation_max, scratch_buf);
} else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
- (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */
+ (input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */
pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
int scratch_offset = (int) (input_data16 + input_size);
void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
@@ -427,10 +457,7 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
}
} else {
/* Basic unrolled version */
- esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset,
- pad_wd, pad_ht, stride_wd, stride_ht,
- filter_data, filter_wd, filter_ht, bias,
- out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
- out_mult, activation_min, activation_max);
+ esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data,
+ bias, output_dims, out_data, conv_params, quant_data);
}
}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_opt.c b/code/components/esp-nn/src/convolution/esp_nn_conv_opt.c
new file mode 100644
index 00000000..be96430e
--- /dev/null
+++ b/code/components/esp-nn/src/convolution/esp_nn_conv_opt.c
@@ -0,0 +1,179 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+
+#include <common_functions.h>
+
/**
 * @brief   Scratch-size query for the optimized C convolution.
 *
 * The opt implementation computes everything in place, so no scratch
 * buffer is needed and this always reports zero bytes.  Kept with the
 * full dims/params signature to match the common esp-nn API.
 */
int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
                                     const data_dims_t *filter_dims,
                                     const data_dims_t *output_dims,
                                     const conv_params_t *conv_params)
{
    return 0;
}
+
/**
 * @brief   Scratch-buffer setter for the optimized C convolution.
 *
 * Intentionally empty: esp_nn_get_conv_scratch_size_opt() returns 0,
 * so there is no buffer to remember.  Present only to satisfy the
 * common esp-nn dispatch API.
 */
void esp_nn_set_conv_scratch_buf_opt(const void *buf)
{

}
+
+__attribute__ ((noinline))
+static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ const data_dims_t *output_dims,
+ int8_t *out_data,
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data)
+{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t in_channels = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const uint16_t out_channels = output_dims->channels;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+
+ for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) {
+ for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) {
+ const int32_t *out_mult = quant_data->mult;
+ const int32_t *out_shift = quant_data->shift;
+ const int8_t *filter_ptr = filter_data;
+ const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels;
+ int32_t out_ch_idx = 0;
+ for (; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+
+ const int8_t *input_ptr = input_base_ptr;
+
+ int32_t in_ch_idx = 0;
+ for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ }
+ for (; in_ch_idx < in_channels; in_ch_idx ++) {
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ }
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dialation width = 1
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ const data_dims_t *output_dims,
+ int8_t *out_data,
+ const conv_params_t *conv_params,
+ const quant_data_t *quant_data)
+{
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+
+ if (filter_wd == 1 && filter_ht == 1) {
+ esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias,
+ output_dims, out_data, conv_params, quant_data);
+ return;
+ }
+
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t in_channels = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const uint16_t out_channels = output_dims->channels;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+
+ int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx;
+
+ for (out_y = 0; out_y < out_ht; out_y++) {
+ for (out_x = 0; out_x < out_wd; out_x++) {
+ const int32_t *out_shift = quant_data->shift;
+ const int32_t *out_mult = quant_data->mult;
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+
+ const int32_t base_y = stride_ht * out_y - pad_ht;
+ const int32_t base_x = stride_wd * out_x - pad_wd;
+
+ const int32_t filter_y_start = max(0, -base_y);
+ const int32_t filter_x_start = max(0, -base_x);
+
+ const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+ const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t in_row = base_y + filter_y_idx;
+ const int32_t in_col = base_x + filter_x_idx;
+
+ const int8_t *input_ptr = input_data +
+ (in_row * input_wd + in_col) * in_channels;
+ const int8_t *filter_ptr = filter_data +
+ out_ch_idx * in_channels * filter_ht * filter_wd +
+ (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+ int32_t in_ch_idx = 0;
+ for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ }
+ for (; in_ch_idx < in_channels; in_ch_idx ++) {
+ conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+ }
+ }
+ }
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
index 9cac6cef..1cd02e0f 100644
--- a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
+++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
@@ -12,16 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include
-
+#include
#include
-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const uint16_t ch_mult,
- const uint16_t filter_wd,
- const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const dw_conv_params_t *conv_params)
{
return 0;
}
@@ -31,29 +28,35 @@ void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf)
}
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
- const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max)
+ const dw_conv_params_t *conv_params,
+ const quant_data_t *quant_data)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t channels = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const int32_t *out_shift = quant_data->shift;
+ const int32_t *out_mult = quant_data->mult;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+ const uint16_t ch_mult = conv_params->ch_mult;
+
int out_idx = 0;
for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
const int16_t base_y = (out_y * stride_ht) - pad_ht;
diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c
new file mode 100644
index 00000000..4afea3f3
--- /dev/null
+++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c
@@ -0,0 +1,291 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+#include <common_functions.h>
+
/**
 * @brief   Scratch-size query for the optimized C depthwise convolution.
 *
 * This implementation needs no intermediate buffers, so it always
 * reports zero bytes.  Signature matches the common esp-nn API.
 */
int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
                                               const data_dims_t *filter_dims,
                                               const data_dims_t *output_dims,
                                               const dw_conv_params_t *conv_params)
{
    return 0;
}
+
/**
 * @brief   Scratch-buffer setter for the optimized C depthwise convolution.
 *
 * Intentionally empty: the matching scratch-size query returns 0, so
 * there is nothing to store.  Present only for API uniformity.
 */
void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf)
{

}
+
/**
 * @brief   Depthwise convolution specialized for channel multiplier == 1
 *          (the common case: one filter plane per input channel, so output
 *          channel i is produced from input channel i alone).
 *
 * The channel loop is unrolled by 4 with a scalar tail.  The per-channel
 * quantization pointers (mult/shift) are re-walked from the start of
 * quant_data for every output pixel, advancing one entry per channel —
 * the order of these increments must match the channel order exactly.
 */
/* common channel multiplier == 1 case */
__attribute__ ((noinline))
static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims,
                                               const int8_t *input_data,
                                               const data_dims_t *filter_dims,
                                               const int8_t *filter_data,
                                               const int32_t *bias,
                                               const data_dims_t *output_dims,
                                               int8_t *out_data,
                                               const dw_conv_params_t *conv_params,
                                               const quant_data_t *quant_data)
{
    const uint16_t input_wd = input_dims->width;
    const uint16_t input_ht = input_dims->height;
    const uint16_t channels = input_dims->channels;
    const int32_t input_offset = conv_params->in_offset;
    const int32_t out_offset = conv_params->out_offset;
    const uint16_t pad_wd = conv_params->padding.width;
    const uint16_t pad_ht = conv_params->padding.height;
    const uint16_t stride_wd = conv_params->stride.width;
    const uint16_t stride_ht = conv_params->stride.height;
    const uint16_t filter_wd = filter_dims->width;
    const uint16_t filter_ht = filter_dims->height;
    const uint16_t out_wd = output_dims->width;
    const uint16_t out_ht = output_dims->height;
    const int32_t activation_min = conv_params->activation.min;
    const int32_t activation_max = conv_params->activation.max;

    int out_idx = 0;
    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
        const int16_t base_y = (out_y * stride_ht) - pad_ht;
        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
            const int16_t base_x = (out_x * stride_wd) - pad_wd;

            /* quantization data restarts at every output pixel */
            const int32_t *out_shift = quant_data->shift;
            const int32_t *out_mult = quant_data->mult;

            /* Select filter so as the point doesn't lie outside block */
            int filter_y_start = max(0, -base_y);
            int filter_x_start = max(0, -base_x);
            int filter_y_end = min(filter_ht, input_ht - base_y);
            int filter_x_end = min(filter_wd, input_wd - base_x);

            int ch_idx = 0;
            for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop, 4 at a time
                int32_t result0 = 0;
                int32_t result1 = 0;
                int32_t result2 = 0;
                int32_t result3 = 0;

                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
                    const int32_t idx_y = base_y + filter_y_idx;
                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                        const int32_t idx_x = base_x + filter_x_idx;
                        /* both input and filter are channel-last, so the 4
                         * lanes are 4 consecutive bytes */
                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
                        int32_t input_val0 = input_data[input_index + 0] + input_offset;
                        int32_t input_val1 = input_data[input_index + 1] + input_offset;
                        int32_t input_val2 = input_data[input_index + 2] + input_offset;
                        int32_t input_val3 = input_data[input_index + 3] + input_offset;
                        int32_t filter_val0 = filter_data[filter_index + 0];
                        int32_t filter_val1 = filter_data[filter_index + 1];
                        int32_t filter_val2 = filter_data[filter_index + 2];
                        int32_t filter_val3 = filter_data[filter_index + 3];
                        result0 += input_val0 * filter_val0;
                        result1 += input_val1 * filter_val1;
                        result2 += input_val2 * filter_val2;
                        result3 += input_val3 * filter_val3;
                    }
                }
                if (bias) {
                    result0 += bias[ch_idx + 0];
                    result1 += bias[ch_idx + 1];
                    result2 += bias[ch_idx + 2];
                    result3 += bias[ch_idx + 3];
                }
                /* requantize; mult/shift advance one entry per channel */
                result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
                result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
                result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
                result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);

                result0 += out_offset;
                result1 += out_offset;
                result2 += out_offset;
                result3 += out_offset;

                result0 = max(result0, activation_min);
                result1 = max(result1, activation_min);
                result2 = max(result2, activation_min);
                result3 = max(result3, activation_min);

                result0 = min(result0, activation_max);
                result1 = min(result1, activation_max);
                result2 = min(result2, activation_max);
                result3 = min(result3, activation_max);

                out_data[out_idx++] = result0;
                out_data[out_idx++] = result1;
                out_data[out_idx++] = result2;
                out_data[out_idx++] = result3;
            }
            for (; ch_idx < channels; ch_idx++) {//channel_loop, scalar tail
                int32_t result = 0;

                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
                    const int32_t idx_y = base_y + filter_y_idx;
                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                        const int32_t idx_x = base_x + filter_x_idx;
                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
                        int32_t input_val = input_data[input_index] + input_offset;
                        int32_t filter_val = filter_data[filter_index];
                        result += input_val * filter_val;
                    }
                }
                if (bias) {
                    result += bias[ch_idx];
                }
                result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
                result += out_offset;
                result = max(result, activation_min);
                result = min(result, activation_max);

                out_data[out_idx++] = result;
            }
        }
    }
}
+
/**
 * @brief   Depthwise convolution on s8 data, optimized C version,
 *          general channel multiplier.
 *
 * The common ch_mult == 1 case is dispatched to a dedicated helper.
 * Otherwise output channel (ch_idx * ch_mult + ch_mult_idx) is built
 * from input channel ch_idx; the multiplier loop is unrolled by 4 with
 * a scalar tail.  Per-channel quantization pointers restart at every
 * output pixel and advance one entry per output channel.
 */
void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
                                  const int8_t *input_data,
                                  const data_dims_t *filter_dims,
                                  const int8_t *filter_data,
                                  const int32_t *bias,
                                  const data_dims_t *output_dims,
                                  int8_t *out_data,
                                  const dw_conv_params_t *conv_params,
                                  const quant_data_t *quant_data)
{
    const uint16_t ch_mult = conv_params->ch_mult;
    if (ch_mult == 1) {
        /* fast path for the usual one-filter-per-channel case */
        esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data,
                                           bias, output_dims, out_data, conv_params, quant_data);
        return;
    }
    const uint16_t input_wd = input_dims->width;
    const uint16_t input_ht = input_dims->height;
    const uint16_t channels = input_dims->channels;
    const int32_t input_offset = conv_params->in_offset;
    const int32_t out_offset = conv_params->out_offset;
    const uint16_t pad_wd = conv_params->padding.width;
    const uint16_t pad_ht = conv_params->padding.height;
    const uint16_t stride_wd = conv_params->stride.width;
    const uint16_t stride_ht = conv_params->stride.height;
    const uint16_t filter_wd = filter_dims->width;
    const uint16_t filter_ht = filter_dims->height;
    const uint16_t out_wd = output_dims->width;
    const uint16_t out_ht = output_dims->height;
    const int32_t activation_min = conv_params->activation.min;
    const int32_t activation_max = conv_params->activation.max;

    int out_idx = 0;
    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
        const int16_t base_y = (out_y * stride_ht) - pad_ht;
        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
            const int16_t base_x = (out_x * stride_wd) - pad_wd;

            /* quantization data restarts at every output pixel */
            const int32_t *out_shift = quant_data->shift;
            const int32_t *out_mult = quant_data->mult;

            /* Select filter so as the point doesn't lie outside block */
            int filter_y_start = max(0, -base_y);
            int filter_x_start = max(0, -base_x);
            int filter_y_end = min(filter_ht, input_ht - base_y);
            int filter_x_end = min(filter_wd, input_wd - base_x);

            for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
                int ch_mult_idx = 0;
                for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) { /* 4 multipliers at a time */
                    int32_t result0 = 0;
                    int32_t result1 = 0;
                    int32_t result2 = 0;
                    int32_t result3 = 0;
                    const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;

                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
                        const int32_t idx_y = base_y + filter_y_idx;
                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                            const int32_t idx_x = base_x + filter_x_idx;
                            /* filter layout (from indexing):
                             * [filter_ht][filter_wd][channels * ch_mult] —
                             * one input value feeds 4 consecutive filter lanes */
                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
                            int32_t input_val = input_data[input_index] + input_offset;
                            int32_t filter_val0 = filter_data[filter_index + 0];
                            int32_t filter_val1 = filter_data[filter_index + 1];
                            int32_t filter_val2 = filter_data[filter_index + 2];
                            int32_t filter_val3 = filter_data[filter_index + 3];
                            result0 += input_val * filter_val0;
                            result1 += input_val * filter_val1;
                            result2 += input_val * filter_val2;
                            result3 += input_val * filter_val3;
                        }
                    }
                    if (bias) {
                        result0 += bias[out_ch_idx + 0];
                        result1 += bias[out_ch_idx + 1];
                        result2 += bias[out_ch_idx + 2];
                        result3 += bias[out_ch_idx + 3];
                    }
                    /* requantize; mult/shift advance one entry per output channel */
                    result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
                    result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
                    result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
                    result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);

                    result0 += out_offset;
                    result1 += out_offset;
                    result2 += out_offset;
                    result3 += out_offset;

                    result0 = max(result0, activation_min);
                    result1 = max(result1, activation_min);
                    result2 = max(result2, activation_min);
                    result3 = max(result3, activation_min);
                    result0 = min(result0, activation_max);
                    result1 = min(result1, activation_max);
                    result2 = min(result2, activation_max);
                    result3 = min(result3, activation_max);

                    out_data[out_idx++] = result0;
                    out_data[out_idx++] = result1;
                    out_data[out_idx++] = result2;
                    out_data[out_idx++] = result3;
                }
                for (; ch_mult_idx < ch_mult; ch_mult_idx++) { /* scalar tail */
                    int32_t result = 0;
                    const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;

                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
                        const int32_t idx_y = base_y + filter_y_idx;
                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                            const int32_t idx_x = base_x + filter_x_idx;
                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
                            int32_t input_val = input_data[input_index] + input_offset;
                            int32_t filter_val = filter_data[filter_index];
                            result += input_val * filter_val;
                        }
                    }
                    if (bias) {
                        result += bias[out_ch_idx];
                    }
                    result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
                    result += out_offset;
                    result = max(result, activation_min);
                    result = min(result, activation_max);

                    out_data[out_idx++] = result;
                }
            }
        }
    }
}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
index c588c48f..9167a43f 100644
--- a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
+++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include
#include
+#include
#include
@@ -353,17 +353,59 @@ void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data,
}
}
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const uint16_t ch_mult,
- const uint16_t filter_wd,
- const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+ const data_dims_t *filter_dims,
+ const data_dims_t *output_dims,
+ const dw_conv_params_t *conv_params)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t channels = input_dims->channels;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t ch_mult = conv_params->ch_mult;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+
int filter_size = filter_wd * filter_ht * channels * ch_mult;
- int padding_used = ((filter_wd == 3) && (filter_ht == 3)) * 2;
- int input_size = (input_wd + padding_used) * (input_ht + padding_used) * channels;
- return 2 * (filter_size + input_size) + 16; //16 for alignment
+ int pad_width = 0, pad_height = 0;
+
+ if ((ch_mult == 1) && (channels % 8 == 0) && (filter_wd == 3) && (filter_ht == 3)) {
+ if (channels % 16 == 0) {
+ if (pad_wd || pad_ht) {
+ pad_width = pad_wd * 2;
+ pad_height = pad_ht * 2;
+ } else {
+ // check if we need to pad additionally
+ pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+ pad_height = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+ // printf("in(%d %d %d), out(%d %d), filter (%d %d) stride (%d %d), pad (%d %d)",
+ // input_wd, input_ht, channels, out_wd, out_ht, filter_wd, filter_ht,
+ // stride_wd, stride_ht, pad_wd, pad_ht);
+ }
+ if (pad_width || pad_height) {
+ int input_size = (input_wd + pad_width) * (input_ht + pad_height) * channels;
+ // printf("ask1 %d\n", filter_size + input_size + 16);
+ return filter_size + input_size + 16; // 16 for alignment
+ } else {
+ // printf("ask2 %d\n", filter_size + 16);
+ return filter_size + 16; // 16 for alignment
+ }
+ } else {
+ int input_size = input_wd * input_ht * channels;
+ // printf("ask3 %d\n", 2 * (filter_size + input_size) + 16);
+ return 2 * (filter_size + input_size) + 16; // 16 for alignment
+ }
+ } else if (ch_mult % 4 == 0) {
+ int input_size = input_wd * input_ht * channels;
+ // printf("ask4 %d\n", 2 * (filter_size + input_size) + 16);
+ return 2 * (filter_size + input_size) + 16; // 16 for alignment
+ }
+ return 32; // just few bytes
}
void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
@@ -376,29 +418,38 @@ void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
* Assumption 2: Pointers are valid
 * Assumption 3: dilation width = 1
*/
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
- const uint16_t input_wd,
- const uint16_t input_ht,
- const uint16_t channels,
- const int32_t input_offset,
- const uint16_t pad_wd,
- const uint16_t pad_ht,
- const uint16_t stride_wd,
- const uint16_t stride_ht,
- const uint16_t ch_mult,
+
+
+
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+ const int8_t *input_data,
+ const data_dims_t *filter_dims,
const int8_t *filter_data,
- const uint16_t filter_wd,
- const uint16_t filter_ht,
const int32_t *bias,
+ const data_dims_t *output_dims,
int8_t *out_data,
- const uint16_t out_wd,
- const uint16_t out_ht,
- const int32_t out_offset,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t activation_min,
- const int32_t activation_max)
+ const dw_conv_params_t *conv_params,
+ const quant_data_t *quant_data)
{
+ const uint16_t input_wd = input_dims->width;
+ const uint16_t input_ht = input_dims->height;
+ const uint16_t channels = input_dims->channels;
+ const int32_t input_offset = conv_params->in_offset;
+ const int32_t out_offset = conv_params->out_offset;
+ const uint16_t pad_wd = conv_params->padding.width;
+ const uint16_t pad_ht = conv_params->padding.height;
+ const uint16_t stride_wd = conv_params->stride.width;
+ const uint16_t stride_ht = conv_params->stride.height;
+ const uint16_t filter_wd = filter_dims->width;
+ const uint16_t filter_ht = filter_dims->height;
+ const uint16_t out_wd = output_dims->width;
+ const uint16_t out_ht = output_dims->height;
+ const int32_t *out_shift = quant_data->shift;
+ const int32_t *out_mult = quant_data->mult;
+ const int32_t activation_min = conv_params->activation.min;
+ const int32_t activation_max = conv_params->activation.max;
+ const uint16_t ch_mult = conv_params->ch_mult;
+
int filter_size = filter_wd * filter_ht * channels * ch_mult;
int align_len = 16 - (filter_size & 15);
int input_size = input_wd * input_ht * channels;
@@ -423,18 +474,27 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
stride_wd, stride_ht, filter_aligned, bias,
out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
- } else if ((pad_wd == 0) && (pad_ht == 0) &&
- // because this does not handle padding offset cases yet, run just for stride (1, 1).
- // end padding of input with `-input_offset` should solve this
- (stride_wd == 1) && (stride_ht == 1)) {
+ } else if ((channels % 16 == 0) && (pad_wd == 0) && (pad_ht == 0)) {
/* process in 8 bits */
int8_t *filter_aligned = (int8_t *) scratch_buffer;
+ int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
+
+ // check if we need to pad additionally
+ int pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+ int pad_bottom = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+ if (pad_right || pad_bottom) { // pad right and bottom
+ esp_nn_aligned_s8_pad_end_with_value(input_data, input_padded, input_wd, input_ht,
+ channels, -input_offset, pad_right, pad_bottom);
+ } else {
+ input_padded = (int8_t *) input_data;
+ }
memcpy(filter_aligned, filter_data, filter_size);
- esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_data, input_wd, input_ht, channels, input_offset,
- stride_wd, stride_ht, filter_aligned,
- bias, out_data, out_wd, out_ht, out_offset, out_shift,
+ esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + pad_right,
+ input_ht + pad_bottom, channels, input_offset,
+ stride_wd, stride_ht, filter_aligned, bias,
+ out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
- } else { /* (channels % 8) == 0 && pad_wd == 1 && pad_ht == 1 */
+ } else { /* (channels % 8) == 0 */
esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels,
diff --git a/code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3 b/code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3
new file mode 100644
index 00000000..1adc4b01
--- /dev/null
+++ b/code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3
@@ -0,0 +1,8 @@
+# Default configurations for ESP32-S3
+
+CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y
+CONFIG_ESP32S3_SPIRAM_SUPPORT=y
+
+CONFIG_ESP32S3_DATA_CACHE_64KB=y
+CONFIG_ESP32S3_DATA_CACHE_8WAYS=y
+CONFIG_ESP32S3_DATA_CACHE_LINE_64B=y
diff --git a/code/components/esp-nn/tests/src/basic_math_test.c b/code/components/esp-nn/tests/src/basic_math_test.c
index 5b96b990..715d7c78 100644
--- a/code/components/esp-nn/tests/src/basic_math_test.c
+++ b/code/components/esp-nn/tests/src/basic_math_test.c
@@ -23,7 +23,9 @@
#include "test_utils.h"
#if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
#define IDF_HEAP_CAPS 1
+#endif
#if IDF_HEAP_CAPS
#include "esp_heap_caps.h"
@@ -138,6 +140,11 @@ void esp_nn_add_elementwise_s8_test()
out_c_orig = out_data_c;
out_opt_orig = out_data_opt;
#endif
+ if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+ out_opt_orig == NULL) {
+ printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto elementwise_add_test_cleanup;
+ }
for (int i = 0; i < size; ++i) {
input1[i] = rand() % 256 - 128;
@@ -194,10 +201,10 @@ elementwise_add_test_cleanup:
if (input2_orig) {
free(input2_orig);
}
- if (out_data_c) {
+ if (out_c_orig) {
free(out_c_orig);
}
- if (out_data_opt) {
+ if (out_opt_orig) {
free(out_opt_orig);
}
}
@@ -282,6 +289,11 @@ void esp_nn_mul_elementwise_s8_test()
out_c_orig = out_data_c;
out_opt_orig = out_data_opt;
#endif
+ if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+ out_opt_orig == NULL) {
+ printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto elementwise_mult_test_cleanup;
+ }
for (int i = 0; i < size; ++i) {
input1[i] = rand() % 256 - 128;
@@ -333,10 +345,10 @@ elementwise_mult_test_cleanup:
if (input2_orig) {
free(input2_orig);
}
- if (out_data_c) {
+ if (out_c_orig) {
free(out_c_orig);
}
- if (out_data_opt) {
+ if (out_opt_orig) {
free(out_opt_orig);
}
}
diff --git a/code/components/esp-nn/tests/src/convolution_test.c b/code/components/esp-nn/tests/src/convolution_test.c
index f3802257..c86bdbab 100644
--- a/code/components/esp-nn/tests/src/convolution_test.c
+++ b/code/components/esp-nn/tests/src/convolution_test.c
@@ -22,8 +22,9 @@
#include "test_utils.h"
#if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
#define IDF_HEAP_CAPS 1
-
+#endif
#if IDF_HEAP_CAPS
#include "esp_heap_caps.h"
#endif
@@ -44,8 +45,8 @@ void esp_nn_depthwise_conv_s8_test()
uint16_t filter_ht, filter_wd, ch_mult;
uint16_t pad_wd, pad_ht, stride_wd, stride_ht;
- // run for 10 iterations
- for (int itr = 0; itr < 10; itr++) {
+ // run for 15 iterations
+ for (int itr = 0; itr < 15; itr++) {
/* prepare data */
switch (itr) {
case 0: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)
@@ -144,22 +145,52 @@ void esp_nn_depthwise_conv_s8_test()
stride_wd = 2;
stride_ht = 2;
break;
- default:
- input_wd = 4;
- input_ht = 4;
+ case 8: // same as case 7, with large parameters
+ input_wd = 58;
+ input_ht = 58;
filter_ht = 3;
filter_wd = 3;
- ch_mult = 4;
- channels = 4;
- pad_wd = 1;
- pad_ht = 1;
- stride_wd = 1;
- stride_ht = 1;
+ ch_mult = 1;
+ channels = 128;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 2;
+ stride_ht = 2;
+ break;
+ case 9: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0) stride (2,2)
+ input_wd = 6;
+ input_ht = 6;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 16;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 2;
+ stride_ht = 2;
+ break;
+ default:
+ input_wd = 6;
+ input_ht = 6;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 16;
+ stride_wd = rand() % 2 + 1;
+ stride_ht = stride_wd;
+ pad_wd = stride_wd == 1 ? 0 : rand() % 2;
+ pad_ht = pad_wd;
+ printf("stride(%d), pad (%d)\t", stride_wd, pad_wd);
break;
}
uint16_t out_wd = (input_wd - filter_wd + 1) / stride_wd;
uint16_t out_ht = (input_ht - filter_ht + 1) / stride_ht;
+ if (itr == 9) {
+ // expect the function to handle this gracefully
+ out_wd += 1;
+ out_ht += 1;
+ }
int in_size = input_wd * input_ht * channels;
int out_size = out_wd * out_ht * channels * ch_mult;
int filter_size = filter_wd * filter_ht * channels * ch_mult + 4;
@@ -210,9 +241,16 @@ void esp_nn_depthwise_conv_s8_test()
out_mult[i] = 0x7eb0e200 + rand() % 50;
}
- int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(input_wd, input_ht,
- channels, ch_mult,
- filter_wd, filter_ht);
+ data_dims_t input_dims = {.width = input_wd, .height = input_ht, .channels = channels, 1};
+ data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = channels * ch_mult, 1};
+ data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+ dw_conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset, .ch_mult = ch_mult,
+ .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+ .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+ quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+ int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(&input_dims, &filter_dims,
+ &output_dims, &conv_params);
if (scratch_buf_size > 0) {
#if IDF_HEAP_CAPS
scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -234,11 +272,8 @@ void esp_nn_depthwise_conv_s8_test()
}
/* C function */
- esp_nn_depthwise_conv_s8_ansi(input, input_wd, input_ht, channels, input_offset,
- pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
- filter_data + 4, filter_wd, filter_ht,
- bias + 1, out_data_c, out_wd, out_ht, out_offset, out_shift,
- out_mult, activation_min, activation_max);
+ esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 4,
+ bias + 1, &output_dims, out_data_c, &conv_params, &quant_data);
if (itr == 0) {
profile_c_end();
@@ -246,11 +281,8 @@ void esp_nn_depthwise_conv_s8_test()
}
/* Optimized function */
- esp_nn_depthwise_conv_s8(input, input_wd, input_ht, channels, input_offset,
- pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
- filter_data + 4, filter_wd, filter_ht,
- bias + 1, out_data_opt, out_wd, out_ht, out_offset, out_shift,
- out_mult, activation_min, activation_max);
+ esp_nn_depthwise_conv_s8(&input_dims, input, &filter_dims, filter_data + 4,
+ bias + 1, &output_dims, out_data_opt, &conv_params, &quant_data);
if (itr == 0) {
/* disable profiler */
@@ -479,8 +511,16 @@ void esp_nn_conv_s8_test()
out_mult[i] = 0x7f67f4f8 + rand() % 50;
}
- int scratch_buf_size = esp_nn_get_conv_scratch_size(in_wd, in_ht, in_channels,
- out_channels, filter_wd, filter_ht);
+ data_dims_t input_dims = {.width = in_wd, .height = in_ht, .channels = in_channels, 1};
+ data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = out_channels, 1};
+ data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+ conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset,
+ .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+ .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+ quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+ int scratch_buf_size = esp_nn_get_conv_scratch_size(&input_dims, &filter_dims,
+ &output_dims, &conv_params);
if (scratch_buf_size > 0) {
#if IDF_HEAP_CAPS
void *scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -502,11 +542,8 @@ void esp_nn_conv_s8_test()
}
/* C function */
- esp_nn_conv_s8_ansi(input, in_wd, in_ht, in_channels, input_offset,
- pad_wd, pad_ht, stride_wd, stride_ht,
- filter_data + 2, filter_wd, filter_ht, bias,
- out_data_c, out_wd, out_ht, out_channels, out_offset, out_shift,
- out_mult, activation_min, activation_max);
+ esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 2,
+ bias, &output_dims, out_data_c, &conv_params, &quant_data);
if (itr == 0) {
profile_c_end();
@@ -514,11 +551,8 @@ void esp_nn_conv_s8_test()
}
/* Optimized function */
- esp_nn_conv_s8(input, in_wd, in_ht, in_channels, input_offset,
- pad_wd, pad_ht, stride_wd, stride_ht,
- filter_data + 2, filter_wd, filter_ht, bias,
- out_data_opt, out_wd, out_ht, out_channels, out_offset, out_shift,
- out_mult, activation_min, activation_max);
+ esp_nn_conv_s8(&input_dims, input, &filter_dims, filter_data + 2,
+ bias, &output_dims, out_data_opt, &conv_params, &quant_data);
if (itr == 0) {
/* disable profiler */
diff --git a/code/components/esp-nn_20220724.zip b/code/components/esp-nn_20220724.zip
deleted file mode 100644
index 2bac7498..00000000
Binary files a/code/components/esp-nn_20220724.zip and /dev/null differ
diff --git a/code/components/esp-nn_20220716.zip b/code/components/esp-nn_20220827.zip
similarity index 97%
rename from code/components/esp-nn_20220716.zip
rename to code/components/esp-nn_20220827.zip
index 53c7bef2..43f16002 100644
Binary files a/code/components/esp-nn_20220716.zip and b/code/components/esp-nn_20220827.zip differ
diff --git a/code/components/esp32-camera-master.zip b/code/components/esp32-camera-master.zip
deleted file mode 100644
index 8706b3d8..00000000
Binary files a/code/components/esp32-camera-master.zip and /dev/null differ
diff --git a/code/components/esp32-camera-master_20220724.zip b/code/components/esp32-camera-master_20220724.zip
deleted file mode 100644
index 64f2f896..00000000
Binary files a/code/components/esp32-camera-master_20220724.zip and /dev/null differ
diff --git a/code/components/jomjol_controlcamera/ClassControllCamera.cpp b/code/components/jomjol_controlcamera/ClassControllCamera.cpp
index df42a2d7..ff410da3 100644
--- a/code/components/jomjol_controlcamera/ClassControllCamera.cpp
+++ b/code/components/jomjol_controlcamera/ClassControllCamera.cpp
@@ -263,6 +263,9 @@ void CCamera::EnableAutoExposure(int flashdauer)
ESP_LOGE(TAGCAMERACLASS, "Camera Capture Failed");
LEDOnOff(false);
LightOnOff(false);
+ LogFile.SwitchOnOff(true);
+ LogFile.WriteToFile("Camera Capture Failed (Procedure 'EnableAutoExposure') --> Reboot"
+ "Check that your camera module is working and connected properly.");
doReboot();
}
esp_camera_fb_return(fb);
@@ -313,7 +316,7 @@ esp_err_t CCamera::CaptureToBasisImage(CImageBasis *_Image, int delay)
LightOnOff(false);
LogFile.SwitchOnOff(true);
- LogFile.WriteToFile("Camera is not working anymore - most propably hardware problem (instablility, ...). "
+ LogFile.WriteToFile("Camera is not working anymore (CCamera::CaptureToBasisImage) - most propably hardware problem (instablility, ...). "
"System will reboot.");
doReboot();
@@ -410,6 +413,9 @@ esp_err_t CCamera::CaptureToFile(std::string nm, int delay)
ESP_LOGE(TAGCAMERACLASS, "CaptureToFile: Camera Capture Failed");
LEDOnOff(false);
LightOnOff(false);
+ LogFile.SwitchOnOff(true);
+ LogFile.WriteToFile("Camera Capture Failed (CCamera::CaptureToFile) --> Reboot"
+ "Check that your camera module is working and connected properly.");
doReboot();
return ESP_FAIL;
diff --git a/code/components/jomjol_fileserver_ota/server_ota.cpp b/code/components/jomjol_fileserver_ota/server_ota.cpp
index 88d301ae..c9b11b08 100644
--- a/code/components/jomjol_fileserver_ota/server_ota.cpp
+++ b/code/components/jomjol_fileserver_ota/server_ota.cpp
@@ -416,6 +416,8 @@ void task_reboot(void *pvParameter)
}
void doReboot(){
+ LogFile.SwitchOnOff(true);
+ LogFile.WriteToFile("Reboot triggert by Software (5s).");
ESP_LOGI(TAGPARTOTA, "Reboot in 5sec");
LogFile.WriteToFile("Reboot in 5sec");
xTaskCreate(&task_reboot, "reboot", configMINIMAL_STACK_SIZE * 64, NULL, 10, NULL);
@@ -435,7 +437,7 @@ esp_err_t handler_reboot(httpd_req_t *req)
LogFile.WriteToFile("handler_reboot");
ESP_LOGI(TAGPARTOTA, "!!! System will restart within 5 sec!!!");
- const char* resp_str = "!!! System will restart within 5 sec!!!";
+ const char* resp_str = " ";
httpd_resp_send(req, resp_str, strlen(resp_str));
doReboot();
diff --git a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
index 83243e30..da6ddfd5 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
+++ b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
@@ -42,7 +42,7 @@ string ClassFlowCNNGeneral::getReadout(int _analog = 0, bool _extendedResolution
int ergebnis_nachkomma = ((int) floor(zahl * 10) + 10) % 10;
prev = ZeigerEvalAnalogNeu(GENERAL[_analog]->ROI[GENERAL[_analog]->ROI.size() - 1]->result_float, prev);
- if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::getReadout(analog) zahl=" + std::to_string(zahl) + ", ergebnis_nachkomma=" + std::to_string(ergebnis_nachkomma) + ", prev=" + std::to_string(prev));
+// if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::getReadout(analog) zahl=" + std::to_string(zahl) + ", ergebnis_nachkomma=" + std::to_string(ergebnis_nachkomma) + ", prev=" + std::to_string(prev));
result = std::to_string(prev);
if (_extendedResolution && (CNNType != Digital))
@@ -82,8 +82,6 @@ string ClassFlowCNNGeneral::getReadout(int _analog = 0, bool _extendedResolution
result = std::to_string(ergebnis_vorkomma) + std::to_string(ergebnis_nachkomma);
prev = ergebnis_vorkomma;
if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::getReadout(dig100-ext) ergebnis_vorkomma=" + std::to_string(ergebnis_vorkomma) + ", ergebnis_nachkomma=" + std::to_string(ergebnis_nachkomma) + ", prev=" + std::to_string(prev));
-
-
}
else
{
@@ -129,6 +127,7 @@ string ClassFlowCNNGeneral::getReadout(int _analog = 0, bool _extendedResolution
return result;
}
+/*
int ClassFlowCNNGeneral::ZeigerEvalHybrid(float zahl, float zahl_vorgaenger, int eval_vorgaenger)
{
if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalHybrid( " + std::to_string(zahl) + ", " + std::to_string(zahl_vorgaenger) + ", " + std::to_string(eval_vorgaenger) + ")");
@@ -179,6 +178,7 @@ int ClassFlowCNNGeneral::ZeigerEvalHybrid(float zahl, float zahl_vorgaenger, int
return -1;
}
+*/
int ClassFlowCNNGeneral::ZeigerEvalHybridNeu(float zahl, float zahl_vorgaenger, int eval_vorgaenger, bool AnalogerVorgaenger)
{
@@ -200,28 +200,17 @@ int ClassFlowCNNGeneral::ZeigerEvalHybridNeu(float zahl, float zahl_vorgaenger,
if (AnalogerVorgaenger)
{
- if (zahl_vorgaenger <= DigitalAnalogerVorgaengerUebergangsbereich) // Nulldurchgang hat stattgefunden
- {
- result = (int) ((int) round(zahl) + 10) % 10;
- if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalHybridNeu - Analoger Vorgänger, Nulldurchgang stattgefunden = " + std::to_string(result) +
- " zahl: " + std::to_string(zahl) + " zahl_vorgaenger = " + std::to_string(zahl_vorgaenger)+ " eval_vorgaenger = " + std::to_string(eval_vorgaenger) + " DigitalUnschaerfe = " + std::to_string(DigitalUnschaerfe));
- return result;
- }
-
- if ((ergebnis_nachkomma <= 2) || (ergebnis_nachkomma >= 8)) // Band um die Ziffer --> Runden, da Ziffer im Rahmen Ungenauigkeit erreicht
- result = ((int) round(zahl) + 10) % 10;
- else
- result = ((int) trunc(zahl) + 10) % 10;
-
- if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalHybridNeu - Analoger Vorgänger, Nulldurchgang NICHT stattgefunden = " + std::to_string(result) +
+// result = ZeigerEvalAnalogToDigitNeu(zahl, eval_vorgaenger);
+ result = ZeigerEvalAnalogToDigitNeu(zahl, zahl_vorgaenger);
+ if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalHybridNeu - Analoger Vorgänger, Bewertung über ZeigerEvalAnalogNeu = " + std::to_string(result) +
" zahl: " + std::to_string(zahl) + " zahl_vorgaenger = " + std::to_string(zahl_vorgaenger)+ " eval_vorgaenger = " + std::to_string(eval_vorgaenger) + " DigitalUnschaerfe = " + std::to_string(DigitalUnschaerfe));
return result;
}
if ((zahl_vorgaenger >= DigitalUebergangsbereichVorgaenger ) && (zahl_vorgaenger <= (10.0 - DigitalUebergangsbereichVorgaenger)))
{
- // kein Ziffernwechsel, da Vorkomma weit genug weg ist (0+/-DigitalUebergangsbereichVorgaenger) --> zahl wird gerundet
- if ((ergebnis_nachkomma <= 2) || (ergebnis_nachkomma >= 8)) // Band um die Ziffer --> Runden, da Ziffer im Rahmen Ungenauigkeit erreicht
+ // kein Ziffernwechsel, da Vorgänger weit genug weg ist (0+/-DigitalUebergangsbereichVorgaenger) --> zahl wird gerundet
+ if ((ergebnis_nachkomma <= DigitalBand) || (ergebnis_nachkomma >= (10-DigitalBand))) // Band um die Ziffer --> Runden, da Ziffer im Rahmen Ungenauigkeit erreicht
result = ((int) round(zahl) + 10) % 10;
else
result = ((int) trunc(zahl) + 10) % 10;
@@ -256,6 +245,57 @@ int ClassFlowCNNGeneral::ZeigerEvalHybridNeu(float zahl, float zahl_vorgaenger,
}
+int ClassFlowCNNGeneral::ZeigerEvalAnalogToDigitNeu(float zahl, float ziffer_vorgaenger)
+{
+ int result;
+ int ergebnis_nachkomma = ((int) floor(zahl * 10)) % 10;
+ int ergebnis_vorkomma = ((int) floor(zahl) + 10) % 10;
+
+ if (ziffer_vorgaenger < 0)
+ {
+ result = (int) floor(zahl);
+ if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalAnalogToDigitNeu - kein Vorgänger - Ergebnis = " + std::to_string(result) +
+ " zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " AnalogFehler = " + std::to_string(AnalogFehler));
+ return result;
+ }
+
+ if ((ziffer_vorgaenger >= DigitalUebergangsbereichVorgaengerAnalogToDigit ) && (ziffer_vorgaenger <= (10.0 - DigitalUebergangsbereichVorgaengerAnalogToDigit)))
+ {
+ // kein Ziffernwechsel, da Vorgänger weit genug weg ist (0+/-DigitalUebergangsbereichVorgaenger) --> zahl wird gerundet
+ if ((ergebnis_nachkomma <= 2) || (ergebnis_nachkomma >= 8)) // Band um die Ziffer --> Runden, da Ziffer im Rahmen Ungenauigkeit erreicht
+ result = ((int) round(zahl) + 10) % 10;
+ else
+ result = ((int) trunc(zahl) + 10) % 10;
+
+ if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalAnalogToDigitNeu - kein Ziffernwechsel, da Vorkomma weit genug weg = " + std::to_string(result) +
+ " zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " DigitalUnschaerfe = " + std::to_string(DigitalUnschaerfe));
+ return result;
+ }
+
+ if (ziffer_vorgaenger <= 1) // Nulldurchgang hat stattgefunden (!Bewertung über Prev_value und nicht Zahl!) --> hier aufrunden (2.8 --> 3, aber auch 3.1 --> 3)
+ {
+ if (ergebnis_nachkomma > 5)
+ result = (ergebnis_vorkomma + 1) % 10;
+ else
+ result = ergebnis_vorkomma;
+ if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalAnalogToDigitNeu - Nulldurchgang hat stattgefunden = " + std::to_string(result) +
+ " zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " DigitalUnschaerfe = " + std::to_string(DigitalUnschaerfe));
+ return result;
+ }
+
+ // bleibt nur >= 9.5 --> noch kein Nulldurchgang --> 2.8 --> 2, und 3.1 --> 2
+ // hier auf 4 reduziert, da erst ab Vorgänder 9 anfängt umzustellen. Bei 9.5 Vorgänger kann die aktuelle
+ // Zahl noch x.4 - x.5 sein.
+ if (ergebnis_nachkomma >= 4)
+ result = ergebnis_vorkomma;
+ else
+ result = (ergebnis_vorkomma - 1 + 10) % 10;
+
+ if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalAnalogToDigitNeu - 9.0 --> noch kein Nulldurchgang = " + std::to_string(result) +
+ " zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " DigitalUnschaerfe = " + std::to_string(DigitalUnschaerfe));
+ return result;
+}
+
int ClassFlowCNNGeneral::ZeigerEvalAnalogNeu(float zahl, int ziffer_vorgaenger)
{
float zahl_min, zahl_max;
@@ -281,10 +321,13 @@ int ClassFlowCNNGeneral::ZeigerEvalAnalogNeu(float zahl, int ziffer_vorgaenger)
" zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " AnalogFehler = " + std::to_string(AnalogFehler));
return result;
}
- result = ((int) floor(zahl_min) + 10) % 10;
- if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalAnalogNeu - Zahl uneindeutig, Korrektur nach unten - Ergebnis = " + std::to_string(result) +
- " zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " AnalogFehler = " + std::to_string(AnalogFehler));
- return result;
+ if (ziffer_vorgaenger >= 10 - AnalogFehler)
+ {
+ result = ((int) floor(zahl_min) + 10) % 10;
+ if (debugdetailgeneral) LogFile.WriteToFile("ClassFlowCNNGeneral::ZeigerEvalAnalogNeu - Zahl uneindeutig, Korrektur nach unten - Ergebnis = " + std::to_string(result) +
+ " zahl: " + std::to_string(zahl) + " ziffer_vorgaenger = " + std::to_string(ziffer_vorgaenger) + " AnalogFehler = " + std::to_string(AnalogFehler));
+ return result;
+ }
}
@@ -296,7 +339,7 @@ int ClassFlowCNNGeneral::ZeigerEvalAnalogNeu(float zahl, int ziffer_vorgaenger)
}
-
+/*
int ClassFlowCNNGeneral::ZeigerEval(float zahl, int ziffer_vorgaenger)
{
int ergebnis_nachkomma = ((int) floor(zahl * 10) + 10) % 10;
@@ -327,6 +370,7 @@ int ClassFlowCNNGeneral::ZeigerEval(float zahl, int ziffer_vorgaenger)
ergebnis = (ergebnis + 10) % 10;
return ergebnis;
}
+*/
bool ClassFlowCNNGeneral::ReadParameter(FILE* pfile, string& aktparamgraph)
{
diff --git a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h
index 2424ef5f..fd58153c 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h
+++ b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h
@@ -25,9 +25,12 @@ protected:
std::vector GENERAL;
float CNNGoodThreshold;
float AnalogFehler = 3.0;
+ float AnalogToDigtalFehler = 0.8;
float DigitalUnschaerfe = 0.2;
+ int DigitalBand = 3;
float DigitalAnalogerVorgaengerUebergangsbereich = 2;
- float DigitalUebergangsbereichVorgaenger = 0.7;
+ float DigitalUebergangsbereichVorgaengerAnalogToDigit = 1; // war vorher 2
+ float DigitalUebergangsbereichVorgaenger = 0.9;
string cnnmodelfile;
int modelxsize, modelysize, modelchannel;
@@ -38,9 +41,10 @@ protected:
bool SaveAllFiles;
// bool extendedResolution;
- int ZeigerEval(float zahl, int ziffer_vorgaenger);
- int ZeigerEvalHybrid(float zahl, float zahl_vorgaenger, int eval_vorgaenger);
+// int ZeigerEval(float zahl, int ziffer_vorgaenger);
+// int ZeigerEvalHybrid(float zahl, float zahl_vorgaenger, int eval_vorgaenger);
int ZeigerEvalAnalogNeu(float zahl, int ziffer_vorgaenger);
+ int ZeigerEvalAnalogToDigitNeu(float zahl, float ziffer_vorgaenger);
int ZeigerEvalHybridNeu(float zahl, float zahl_vorgaenger, int eval_vorgaenger, bool AnalogerVorgaenger = false);
diff --git a/code/components/jomjol_flowcontroll/ClassFlowControll.cpp b/code/components/jomjol_flowcontroll/ClassFlowControll.cpp
index 0351dc17..db72aedf 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowControll.cpp
+++ b/code/components/jomjol_flowcontroll/ClassFlowControll.cpp
@@ -305,6 +305,7 @@ bool ClassFlowControll::doFlow(string time)
if (i) i -= 1; // vorheriger Schritt muss wiederholt werden (vermutlich Bilder aufnehmen)
result = false;
if (repeat > 5) {
+ LogFile.SwitchOnOff(true);
LogFile.WriteToFile("Wiederholung 5x nicht erfolgreich --> reboot");
doReboot();
// Schritt wurde 5x wiederholt --> reboot
@@ -493,6 +494,8 @@ bool ClassFlowControll::ReadParameter(FILE* pfile, string& aktparamgraph)
// reboot notwendig damit die neue wlan.ini auch benutzt wird !!!
fclose(pfile);
printf("do reboot\n");
+ LogFile.SwitchOnOff(true);
+ LogFile.WriteToFile("Reboot to activate new HOSTNAME.");
esp_restart();
hard_restart();
doReboot();
diff --git a/code/components/jomjol_logfile/ClassLogFile.cpp b/code/components/jomjol_logfile/ClassLogFile.cpp
index 04a4df4d..f9a790fe 100644
--- a/code/components/jomjol_logfile/ClassLogFile.cpp
+++ b/code/components/jomjol_logfile/ClassLogFile.cpp
@@ -73,7 +73,7 @@ void ClassLogFile::WriteToDedicatedFile(std::string _fn, std::string info, bool
// pFile = OpenFileAndWait(_fn.c_str(), "a");
pFile = fopen(_fn.c_str(), "a+");
- printf("Logfile opened: %s\n", _fn.c_str());
+// printf("Logfile opened: %s\n", _fn.c_str());
if (pFile!=NULL) {
if (_time)
diff --git a/code/components/tflite-lib/CMakeLists.txt b/code/components/tflite-lib/CMakeLists.txt
index eed31a57..aaf56231 100644
--- a/code/components/tflite-lib/CMakeLists.txt
+++ b/code/components/tflite-lib/CMakeLists.txt
@@ -25,7 +25,8 @@ list(REMOVE_ITEM srcs_kernels
"${tfmicro_kernels_dir}/depthwise_conv.cc"
"${tfmicro_kernels_dir}/fully_connected.cc"
"${tfmicro_kernels_dir}/mul.cc"
- "${tfmicro_kernels_dir}/pooling.cc")
+ "${tfmicro_kernels_dir}/pooling.cc"
+ "${tfmicro_kernels_dir}/softmax.cc")
FILE(GLOB esp_nn_kernels
"${tfmicro_kernels_dir}/esp_nn/*.cc")
@@ -38,6 +39,10 @@ set(lib_srcs
"${tflite_dir}/kernels/kernel_util.cc"
"${tflite_dir}/micro/memory_planner/greedy_memory_planner.cc"
"${tflite_dir}/micro/memory_planner/linear_memory_planner.cc"
+ "${tflite_dir}/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc"
+ "${tflite_dir}/micro/arena_allocator/persistent_arena_buffer_allocator.cc"
+ "${tflite_dir}/micro/arena_allocator/recording_single_arena_buffer_allocator.cc"
+ "${tflite_dir}/micro/arena_allocator/single_arena_buffer_allocator.cc"
"${tflite_dir}/c/common.cc"
"${tflite_dir}/core/api/error_reporter.cc"
"${tflite_dir}/core/api/flatbuffer_conversions.cc"
diff --git a/code/components/tflite-lib/tensorflow/lite/builtin_ops.h b/code/components/tflite-lib/tensorflow/lite/builtin_ops.h
index 19ce3e2c..01156c39 100644
--- a/code/components/tflite-lib/tensorflow/lite/builtin_ops.h
+++ b/code/components/tflite-lib/tensorflow/lite/builtin_ops.h
@@ -179,6 +179,12 @@ typedef enum {
kTfLiteBuiltinMultinomial = 149,
kTfLiteBuiltinGelu = 150,
kTfLiteBuiltinDynamicUpdateSlice = 151,
+ kTfLiteBuiltinRelu0To1 = 152,
+ kTfLiteBuiltinUnsortedSegmentProd = 153,
+ kTfLiteBuiltinUnsortedSegmentMax = 154,
+ kTfLiteBuiltinUnsortedSegmentSum = 155,
+ kTfLiteBuiltinAtan2 = 156,
+ kTfLiteBuiltinUnsortedSegmentMin = 157,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h b/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
index d2524969..d947213b 100644
--- a/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
+++ b/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
@@ -113,7 +113,13 @@ typedef struct TfLiteQuantizationParams {
} TfLiteQuantizationParams;
// --------------------------------------------------------------------------
-// Opaque types used by c_api_opaque.h.
+// Opaque types used by c_api.h, c_api_opaque.h and common.h.
+
+// TfLiteOpaqueContext is an opaque version of TfLiteContext;
+typedef struct TfLiteOpaqueContext TfLiteOpaqueContext;
+
+// TfLiteOpaqueNode is an opaque version of TfLiteNode;
+typedef struct TfLiteOpaqueNode TfLiteOpaqueNode;
// TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;
diff --git a/code/components/tflite-lib/tensorflow/lite/c/common.cc b/code/components/tflite-lib/tensorflow/lite/c/common.cc
index 956e9d69..ae5c44b5 100644
--- a/code/components/tflite-lib/tensorflow/lite/c/common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/c/common.cc
@@ -14,7 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
+
#include "tensorflow/lite/c/c_api_types.h"
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+#include "tensorflow/lite/tensorflow_profiler_logger.h"
+#endif
#ifndef TF_LITE_STATIC_MEMORY
#include
@@ -99,7 +103,12 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
void TfLiteTensorDataFree(TfLiteTensor* t) {
if (t->allocation_type == kTfLiteDynamic ||
t->allocation_type == kTfLitePersistentRo) {
- free(t->data.raw);
+ if (t->data.raw) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+ tflite::OnTfLiteTensorDealloc(t);
+#endif
+ free(t->data.raw);
+ }
}
t->data.raw = nullptr;
}
@@ -161,7 +170,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
t->dims = nullptr;
if (t->dims_signature) {
- TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
+ TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature);
}
t->dims_signature = nullptr;
@@ -191,16 +200,12 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
}
TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
- if (!src || !dst)
- return kTfLiteOk;
- if (src->bytes != dst->bytes)
- return kTfLiteError;
- if (src == dst)
- return kTfLiteOk;
+ if (!src || !dst) return kTfLiteOk;
+ if (src->bytes != dst->bytes) return kTfLiteError;
+ if (src == dst) return kTfLiteOk;
dst->type = src->type;
- if (dst->dims)
- TfLiteIntArrayFree(dst->dims);
+ if (dst->dims) TfLiteIntArrayFree(dst->dims);
dst->dims = TfLiteIntArrayCopy(src->dims);
memcpy(dst->data.raw, src->data.raw, src->bytes);
dst->buffer_handle = src->buffer_handle;
@@ -218,8 +223,17 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
// TODO(b/145340303): Tensor data should be aligned.
if (!tensor->data.raw) {
tensor->data.raw = (char*)malloc(num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+ tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
} else if (num_bytes > tensor->bytes) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+ tflite::OnTfLiteTensorDealloc(tensor);
+#endif
tensor->data.raw = (char*)realloc(tensor->data.raw, num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+ tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
}
tensor->bytes = num_bytes;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/c/common.h b/code/components/tflite-lib/tensorflow/lite/c/common.h
index 6a109e1e..cc856f9a 100644
--- a/code/components/tflite-lib/tensorflow/lite/c/common.h
+++ b/code/components/tflite-lib/tensorflow/lite/c/common.h
@@ -173,9 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
} \
} while (false)
#else // TF_LITE_STRIP_ERROR_STRINGS
-#define UNUSED(...) (void)sizeof(#__VA_ARGS__)
-#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
-#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
+#define ARGS_UNUSED(...) (void)sizeof(#__VA_ARGS__)
+#define TF_LITE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
+#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
#endif // TF_LITE_STRIP_ERROR_STRINGS
// Check whether value is true, and if not return kTfLiteError from
@@ -842,6 +842,12 @@ typedef struct TfLiteContext {
size_t* bytes);
} TfLiteContext;
+// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
+// for C API which doesn't use internal types (such as `TfLiteContext`) but only
+// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each
+// field is the exactly the same as with `TfLiteRegistration`.
+typedef struct TfLiteRegistrationExternal TfLiteRegistrationExternal;
+
typedef struct TfLiteRegistration {
// Initializes the op from serialized data.
// Called only *once* for the lifetime of the op, so any one-time allocations
@@ -903,8 +909,31 @@ typedef struct TfLiteRegistration {
// Note: It is the responsibility of the registration binder to set this
// properly.
int version;
+
+ // The external version of `TfLiteRegistration`. Since we can't use internal
+ // types (such as `TfLiteContext`) for C API to maintain ABI stability.
+ // C API user will provide `TfLiteRegistrationExternal` to implement custom
+ // ops. We keep it inside of `TfLiteRegistration` and use it to route
+ // callbacks properly.
+ TfLiteRegistrationExternal* registration_external;
} TfLiteRegistration;
+// Old version of `TfLiteRegistration` to maintain binary backward
+// compatibility.
+// WARNING: This structure is deprecated / not an official part of the API.
+// It should be only used for binary backward compatibility.
+typedef struct TfLiteRegistration_V1 {
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+ void (*free)(TfLiteContext* context, void* buffer);
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+ int32_t builtin_code;
+ const char* custom_name;
+ int version;
+} TfLiteRegistration_V1;
+
// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
// values should be 1, 2, 4, 8, ...etc.
typedef enum TfLiteDelegateFlags {
diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
index e92d754f..1ecefa47 100644
--- a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -493,6 +493,11 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseSquare(op, error_reporter, allocator, builtin_data);
}
+ case BuiltinOperator_SQUARED_DIFFERENCE: {
+ return ParseSquaredDifference(op, error_reporter, allocator,
+ builtin_data);
+ }
+
case BuiltinOperator_SQUEEZE: {
return ParseSqueeze(op, error_reporter, allocator, builtin_data);
}
@@ -840,14 +845,25 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
// ok for now, since there is no call implementation either.
case BuiltinOperator_CALL:
+ case BuiltinOperator_COMPLEX_ABS:
case BuiltinOperator_CONCAT_EMBEDDINGS:
case BuiltinOperator_COS:
case BuiltinOperator_CUSTOM:
+ case BuiltinOperator_DENSIFY:
+ case BuiltinOperator_DYNAMIC_UPDATE_SLICE:
case BuiltinOperator_EMBEDDING_LOOKUP:
case BuiltinOperator_EQUAL:
+ case BuiltinOperator_HASHTABLE_FIND:
+ case BuiltinOperator_HASHTABLE_IMPORT:
+ case BuiltinOperator_HASHTABLE_SIZE:
+ case BuiltinOperator_IMAG:
case BuiltinOperator_MATRIX_DIAG:
case BuiltinOperator_MATRIX_SET_DIAG:
+ case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
+ case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
case BuiltinOperator_RELU_N1_TO_1:
+ case BuiltinOperator_RELU_0_TO_1:
+ case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_SELECT:
case BuiltinOperator_SELECT_V2:
case BuiltinOperator_SLICE:
@@ -855,23 +871,17 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
case BuiltinOperator_RANGE:
- case BuiltinOperator_SQUARED_DIFFERENCE:
- case BuiltinOperator_REVERSE_V2:
- case BuiltinOperator_WHERE:
case BuiltinOperator_RANK:
- case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
- case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
- case BuiltinOperator_SCATTER_ND:
- case BuiltinOperator_DENSIFY:
- case BuiltinOperator_SEGMENT_SUM:
- case BuiltinOperator_RFFT2D:
- case BuiltinOperator_IMAG:
case BuiltinOperator_REAL:
- case BuiltinOperator_COMPLEX_ABS:
- case BuiltinOperator_HASHTABLE_FIND:
- case BuiltinOperator_HASHTABLE_IMPORT:
- case BuiltinOperator_HASHTABLE_SIZE:
- case BuiltinOperator_DYNAMIC_UPDATE_SLICE:
+ case BuiltinOperator_RFFT2D:
+ case BuiltinOperator_SEGMENT_SUM:
+ case BuiltinOperator_REVERSE_V2:
+ case BuiltinOperator_UNSORTED_SEGMENT_MAX:
+ case BuiltinOperator_UNSORTED_SEGMENT_MIN:
+ case BuiltinOperator_UNSORTED_SEGMENT_PROD:
+ case BuiltinOperator_UNSORTED_SEGMENT_SUM:
+ case BuiltinOperator_ATAN2:
+ case BuiltinOperator_WHERE:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;
@@ -2189,6 +2199,14 @@ TfLiteStatus ParseSquare(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
return kTfLiteOk;
}
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseSquaredDifference(const Operator*, ErrorReporter*,
+ BuiltinDataAllocator*, void**) {
+ return kTfLiteOk;
+}
+
TfLiteStatus ParseStridedSlice(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h
index cd6637bc..ed317b81 100644
--- a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -356,6 +356,11 @@ TfLiteStatus ParseSqrt(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseSquare(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
+TfLiteStatus ParseSquaredDifference(const Operator* op,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator,
+ void** builtin_data);
+
TfLiteStatus ParseStridedSlice(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h b/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h
index 49ac778e..cec1f2dd 100644
--- a/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h
+++ b/code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h
@@ -23,6 +23,16 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
+// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate.
+// This is used for cases (e.g. when using "TF Lite with Google Play Services")
+// where the TF Lite runtime might be built using a newer (or older)
+// version of the TF Lite sources than the app, and hence might have a
+// different definition of the TfLiteDelegate type. TF Lite APIs use
+// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to
+// refer to a delegate defined with that potentially different version
+// of the TfLiteDelegate type.
+struct TfLiteOpaqueDelegateStruct;
+
namespace tflite {
/// Abstract interface that returns TfLiteRegistrations given op codes or custom
@@ -37,8 +47,10 @@ class OpResolver {
virtual const TfLiteRegistration* FindOp(const char* op,
int version) const = 0;
+ // Represents a sequence of delegates.
  using TfLiteDelegatePtrVector =
      std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
+
// Returns optional delegates for resolving and handling ops in the flatbuffer
// model. This may be used in addition to the standard TfLiteRegistration
// lookup for graph resolution.
@@ -47,16 +59,55 @@ class OpResolver {
return {};
}
- // Represent a function that creates a TfLite delegate instance.
+ // Represents a function that creates a TfLite delegate instance.
  using TfLiteDelegateCreator =
      std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
          int /*num_threads*/)>;
+
+ // Represents a sequence of delegate creator functions.
using TfLiteDelegateCreators = std::vector;
+
// Returns a vector of delegate creators to create optional delegates for
// resolving and handling ops in the flatbuffer model. This may be used in
// addition to the standard TfLiteRegistration lookup for graph resolution.
+ //
+ // Note that this method is not used (will not be called) if you are using
+ // TF Lite in Google Play Services; the GetOpaqueDelegateCreators method
+ // (see below) is used for that case.
virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
+ // TODO(b/202712825): it would be nice if we could avoid the need for separate
+ // "opaque" types & methods for use only with TF Lite in Google Play Services.
+
+ // Represents an opaque delegate instance.
+ // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegatePtr =
+      std::unique_ptr<TfLiteOpaqueDelegateStruct, void (*)(TfLiteOpaqueDelegateStruct*)>;
+
+ // Represents a function that creates an opaque delegate instance.
+ // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegateCreator =
+      std::function<TfLiteOpaqueDelegatePtr(int /*num_threads*/)>;
+
+ // Represents a sequence of opaque delegate creator functions.
+ // WARNING: Experimental interface, subject to change.
+ using TfLiteOpaqueDelegateCreators = std::vector;
+
+ // Returns a vector of opaque delegate creators to create optional opaque
+ // delegates for resolving and handling ops in the flatbuffer model. This may
+ // be used in addition to the standard TfLiteRegistration lookup for graph
+ // resolution.
+ //
+ // Note that this method will be called only if you are using TF Lite in
+ // Google Play Services; if you are using regular TF Lite, GetDelegateCreators
+ // (see above) is used instead.
+ //
+ // WARNING: Experimental interface, subject to change.
+ virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const {
+ return {};
+ }
+
virtual ~OpResolver() {}
private:
diff --git a/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft.cc b/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft.cc
index 62442fba..bcdd9cc0 100644
--- a/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft.cc
+++ b/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/microfrontend/lib/fft.h"
-#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
#include <string.h>
+#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
void FftCompute(struct FftState* state, const int16_t* input,
int input_scale_shift) {
@@ -37,9 +37,9 @@ void FftCompute(struct FftState* state, const int16_t* input,
// Apply the FFT.
kissfft_fixed16::kiss_fftr(
- reinterpret_cast(state->scratch),
- state->input,
- reinterpret_cast(state->output));
+      reinterpret_cast<kissfft_fixed16::kiss_fftr_cfg>(state->scratch),
+      state->input,
+      reinterpret_cast<kissfft_fixed16::kiss_fft_cpx*>(state->output));
}
void FftInit(struct FftState* state) {
diff --git a/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc b/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc
index 81efe14d..ed3dc8fb 100644
--- a/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc
+++ b/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h"
-#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
#include <stdio.h>
+#include "tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h"
+
int FftPopulateState(struct FftState* state, size_t input_size) {
state->input_size = input_size;
state->fft_size = 1;
diff --git a/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h b/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h
index 9abe686b..beee99aa 100644
--- a/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h
+++ b/code/components/tflite-lib/tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h
@@ -31,4 +31,3 @@ namespace kissfft_fixed16 {
#undef KISS_FFT_H
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_KISS_FFT_INT16_H_
-
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/common.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/common.h
index 5e8778f1..205294fd 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/common.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/common.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
+#include <algorithm>
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/compatibility.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/compatibility.h
index 61becad3..7ba66ed8 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/compatibility.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/compatibility.h
@@ -86,6 +86,16 @@ using int32 = std::int32_t;
using uint32 = std::uint32_t;
#endif // !defined(TF_LITE_STATIC_MEMORY)
+// Allow for cross-compiler usage of function signatures - currently used for
+// specifying named RUY profiler regions in templated methods.
+#if defined(_MSC_VER)
+#define TFLITE_PRETTY_FUNCTION __FUNCSIG__
+#elif defined(__GNUC__)
+#define TFLITE_PRETTY_FUNCTION __PRETTY_FUNCTION__
+#else
+#define TFLITE_PRETTY_FUNCTION __func__
+#endif
+
// TFLITE_DEPRECATED()
//
// Duplicated from absl/base/macros.h to avoid pulling in that library.
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h
index ab0c8f96..122a0dc2 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h
@@ -324,7 +324,7 @@ void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
// - n_input: the size for input and output.
// - output: the 16 bit output
// The input is in Qm.15-m format and the output is in Q0.15 format.
-void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
+void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
// Apply Tanh to a quantized vector. Tbe internal calculation is in float.
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/add.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/add.h
index 57fa13d8..1f521316 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/add.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/add.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
+#include <algorithm>
#include
#include "fixedpoint/fixedpoint.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/concatenation.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/concatenation.h
index 998bb093..9d2ecbec 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/concatenation.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/concatenation.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_
+#include
+
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h
index ac5f04f6..3a53e06e 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/div.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/div.h
new file mode 100644
index 00000000..df8da1b1
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/div.h
@@ -0,0 +1,247 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
+
+#include <algorithm>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+template <typename T>
+inline void DivCheckArithmeticParams(const ArithmeticParams& params) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ // Input offset is negative input zero point. Activation tensors are
+ // asymmetric quantized so they span the full int8 range.
+  constexpr int32_t max_value =
+      static_cast<int32_t>(std::numeric_limits<T>::max());
+ TFLITE_DCHECK_GE(params.input1_offset, -max_value);
+ TFLITE_DCHECK_LE(params.input1_offset, max_value);
+ TFLITE_DCHECK_GE(params.input2_offset, -max_value);
+ TFLITE_DCHECK_LE(params.input2_offset, max_value);
+ TFLITE_DCHECK_GE(params.output_offset, -max_value);
+ TFLITE_DCHECK_LE(params.output_offset, max_value);
+}
+
+// Element-wise div that can often be used for inner loop of broadcast Div as
+// well as the non-broadcast Div.
+template <typename T>
+inline void DivElementwise(int size, const ArithmeticParams& params,
+ const T* input1_data, const T* input2_data,
+ T* output_data) {
+ DivCheckArithmeticParams(params);
+
+ for (int i = 0; i < size; ++i) {
+ int32_t input1_val = params.input1_offset + input1_data[i];
+ int32_t input2_val = params.input2_offset + input2_data[i];
+ TFLITE_DCHECK_NE(input2_val, 0);
+ if (input2_val < 0) {
+ // Invert signs to avoid a negative input2_val as input2_inv needs to be
+ // positive to be used as multiplier of MultiplyByQuantizedMultiplier.
+ input1_val = -input1_val;
+ input2_val = -input2_val;
+ }
+ int recip_shift;
+ const int32_t input2_inv = GetReciprocal(input2_val, 31, &recip_shift);
+ const int headroom = CountLeadingSignBits(input1_val);
+ const int32_t unscaled_quotient =
+ MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv,
+ headroom);
+ const int total_shift = params.output_shift - recip_shift - headroom;
+ const int32_t unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ unscaled_quotient, params.output_multiplier, total_shift);
+ const int32_t clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+    output_data[i] = static_cast<T>(clamped_output);
+ }
+}
+
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8_t* input1_data,
+ const RuntimeShape& input2_shape, const uint8_t* input2_data,
+ const RuntimeShape& output_shape, uint8_t* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ const int flat_size =
+ MatchingElementsSize(input1_shape, input2_shape, output_shape);
+
+ DivElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int8_t* input1_data,
+ const RuntimeShape& input2_shape, const int8_t* input2_data,
+ const RuntimeShape& output_shape, int8_t* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ const int flat_size =
+ MatchingElementsSize(input1_shape, input2_shape, output_shape);
+
+ DivElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+template <typename T, int N = 5>
+inline void BroadcastDivSlowQuantized(
+ const ArithmeticParams& params, const RuntimeShape& unextended_input1_shape,
+ const T* input1_data, const RuntimeShape& unextended_input2_shape,
+ const T* input2_data, const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+ CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+ &output_desc);
+
+ DivCheckArithmeticParams(params);
+
+ auto div_func = [&](int indexes[N]) {
+ int32_t input1_val =
+ params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
+ int32_t input2_val =
+ params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
+ TFLITE_DCHECK_NE(input2_val, 0);
+ if (input2_val < 0) {
+ // Invert signs to avoid a negative input2_val as input2_inv needs to be
+ // positive to be used as multiplier of MultiplyByQuantizedMultiplier.
+ input1_val = -input1_val;
+ input2_val = -input2_val;
+ }
+ int recip_shift;
+ const int32_t input2_inv = GetReciprocal(input2_val, 31, &recip_shift);
+ const int headroom = CountLeadingSignBits(input1_val);
+ const int32_t unscaled_quotient =
+ MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv,
+ headroom);
+ const int total_shift = params.output_shift - recip_shift - headroom;
+ const int32_t unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ unscaled_quotient, params.output_multiplier, total_shift);
+ const int32_t clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[SubscriptToIndex(output_desc, indexes)] =
+        static_cast<T>(clamped_output);
+ };
+  NDOpsHelper<N>(output_desc, div_func);
+}
+
+template <int N = 5>
+inline void BroadcastDivSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const uint8_t* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const uint8_t* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ uint8_t* output_data) {
+  BroadcastDivSlowQuantized<uint8_t, N>(
+ params, unextended_input1_shape, input1_data, unextended_input2_shape,
+ input2_data, unextended_output_shape, output_data);
+}
+
+template <int N = 5>
+inline void BroadcastDivSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const int8_t* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const int8_t* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ int8_t* output_data) {
+  BroadcastDivSlowQuantized<int8_t, N>(
+ params, unextended_input1_shape, input1_data, unextended_input2_shape,
+ input2_data, unextended_output_shape, output_data);
+}
+
+// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <typename T, int N = 5>
+void BroadcastDivSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+ CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+ &output_desc);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest
+ // stride, typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+
+ auto div_func = [&](int indexes[N]) {
+ output_data[SubscriptToIndex(output_desc, indexes)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, indexes)] /
+ input2_data[SubscriptToIndex(desc2, indexes)],
+ output_activation_min, output_activation_max);
+ };
+  NDOpsHelper<N>(output_desc, div_func);
+}
+
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingElementsSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] / input2_data[i], output_activation_min,
+ output_activation_max);
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/fully_connected.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/fully_connected.h
index 9bf2e5df..ba51cbcf 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/fully_connected.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/fully_connected.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+#include <algorithm>
+
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h
index cda1b5cf..d9fe32e9 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
+#include <algorithm>
+
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
@@ -23,9 +25,9 @@ namespace tflite {
namespace reference_ops {
inline int16_t SaturatingLeftShift(int16_t value, int amount) {
- int32_t result = static_cast(value) * (1 << amount);
- result = std::min(result, std::numeric_limits::max());
- result = std::max(result, std::numeric_limits::min());
+  int64_t result = static_cast<int64_t>(value) * (1 << amount);
+  result = std::min<int64_t>(result, std::numeric_limits<int16_t>::max());
+  result = std::max<int64_t>(result, std::numeric_limits<int16_t>::min());
return result;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/add.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/add.h
index 10bee904..8d9b318c 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/add.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/add.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_
+#include <algorithm>
#include
#include "tensorflow/lite/kernels/internal/common.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
index 3f869a3a..ffd4978e 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h
index f0ca09c7..7676fce0 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h
index 42920d16..634f0bff 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h
@@ -15,11 +15,101 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_integer_ops {
+// For per-channel functions, since it is defined in quantization spec that
+// weights are symmetric
+// (https://www.tensorflow.org/lite/performance/quantization_spec#symmetric_vs_asymmetric),
+// zero_point (params.weights_offset) is always 0.
+// However, for per-tensor functions, params.weights_offset is still applied for
+// backward compatibility.
+
+inline void FullyConnectedPerChannel(
+ const FullyConnectedParams& params, const int32_t* output_multiplier,
+ const int* output_shift, const RuntimeShape& input_shape,
+ const int8_t* input_data, const RuntimeShape& filter_shape,
+ const int8_t* filter_data, const RuntimeShape& bias_shape,
+ const int32_t* bias_data, const RuntimeShape& output_shape,
+ int8_t* output_data) {
+ const int32_t input_offset = params.input_offset;
+ const int32_t output_offset = params.output_offset;
+ const int32_t output_activation_min = params.quantized_activation_min;
+ const int32_t output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = output_shape.Dims(0);
+ const int output_depth = output_shape.Dims(1);
+ TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ int32_t acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32_t input_val = input_data[b * accum_depth + d];
+ int32_t filter_val = filter_data[out_c * accum_depth + d];
+ acc += filter_val * (input_val + input_offset);
+ }
+ if (bias_data) {
+ acc += bias_data[out_c];
+ }
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier[out_c],
+ output_shift[out_c]);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+      output_data[out_c + output_depth * b] = static_cast<int8_t>(acc);
+ }
+ }
+}
+
+template <typename AccumScalar>
+inline void FullyConnectedPerChannel(
+ const FullyConnectedParams& params, const int32_t* output_multiplier,
+ const int* output_shift, const RuntimeShape& input_shape,
+ const int16_t* input_data, const RuntimeShape& filter_shape,
+ const int8_t* filter_data, const RuntimeShape& bias_shape,
+ const AccumScalar* bias_data, const RuntimeShape& output_shape,
+ int16_t* output_data) {
+ const int32_t output_activation_min = params.quantized_activation_min;
+ const int32_t output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = output_shape.Dims(output_dim_count - 1);
+ TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ AccumScalar acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32_t input_val = input_data[b * accum_depth + d];
+ int32_t filter_val = filter_data[out_c * accum_depth + d];
+ acc += filter_val * input_val;
+ }
+ if (bias_data) {
+ acc += bias_data[out_c];
+ }
+ int32_t acc_scaled = MultiplyByQuantizedMultiplier(
+ acc, output_multiplier[out_c], output_shift[out_c]);
+ acc_scaled = std::max(acc_scaled, output_activation_min);
+ acc_scaled = std::min(acc_scaled, output_activation_max);
+      output_data[out_c + output_depth * b] = static_cast<int16_t>(acc_scaled);
+ }
+ }
+}
+
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& filter_shape,
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h
index 31f2de98..164a8367 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h
index 95697ec9..16eff133 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h
@@ -15,7 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_
+#include <algorithm>
#include <limits>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
index bd484270..09d37b72 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
index b80838aa..22e89740 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_
+#include <algorithm>
+
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h
index 2cb4dada..4dc31d9e 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h
@@ -15,7 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_POOLING_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_POOLING_H_
+#include <algorithm>
#include <limits>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h
index 63e40936..7b1e003b 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_
+#include <algorithm>
#include <limits>
#include "fixedpoint/fixedpoint.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
index 3397f869..92919a71 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/mul.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/mul.h
index 273a8a98..b977104c 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/mul.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/mul.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/pooling.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/pooling.h
index ee30b840..fe17484c 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/pooling.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/pooling.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_POOLING_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_POOLING_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/prelu.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/prelu.h
index 02db5174..aa9901d6 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/prelu.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/prelu.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h
index 40f779c5..bda27693 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/reduce.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/reduce.h
index 348e170e..341b3a08 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/reduce.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/reduce.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_
+#include <algorithm>
+
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/requantize.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/requantize.h
index d1e67785..f35f6fc8 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/requantize.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/requantize.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_
+#include <algorithm>
+
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h
index 95550abc..bf0b757e 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
+#include <algorithm>
#include <cmath>
#include "tensorflow/lite/kernels/internal/cppmath.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/softmax.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/softmax.h
index c2bddcf7..9f4b6398 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/softmax.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/softmax.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+#include <algorithm>
#include <limits>
#include "fixedpoint/fixedpoint.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/transpose_conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/transpose_conv.h
index 6e9cb1f9..ac91f379 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/transpose_conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/transpose_conv.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
+#include <algorithm>
+
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h
index 13693643..c2678b57 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h
@@ -27,6 +27,11 @@ class RuntimeShape {
public:
RuntimeShape& operator=(RuntimeShape const&) = delete;
+ // RuntimeShape in TFLM supports up to 5 dimensions.
+ // The name kMaxSmallSize comes from the same file of the upstream
+ // tensorflow lite repo and need to be kept the same for max reuse.
+ static constexpr int kMaxSmallSize = 5;
+
RuntimeShape() : size_(0) {}
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {}
@@ -104,11 +109,9 @@ class RuntimeShape {
sizeof(int32_t) * shape.DimensionsCount());
}
- // A maximum of 4 dimensions are supported on TFLM.
- static constexpr int kMaxSize = 5;
int32_t size_;
union {
- int32_t dims_[kMaxSize];
+ int32_t dims_[kMaxSmallSize];
};
};
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h
index 77644bc0..c44ba48e 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h
@@ -974,11 +974,11 @@ struct StridedSliceParams {
int8_t strides_count;
int32_t strides[5];
- int16_t begin_mask;
- int16_t ellipsis_mask;
- int16_t end_mask;
- int16_t new_axis_mask;
- int16_t shrink_axis_mask;
+ uint16_t begin_mask;
+ uint16_t ellipsis_mask;
+ uint16_t end_mask;
+ uint16_t new_axis_mask;
+ uint16_t shrink_axis_mask;
};
struct TanhParams {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
index 22689436..06874422 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
@@ -177,6 +177,14 @@ inline int64_t NumElements(const TfLiteTensor* t) {
return NumElements(t->dims);
}
+inline int64_t NumElements(const int* dims, int num_dims) {
+ int64_t count = 1;
+ for (int i = 0; i < num_dims; ++i) {
+ count *= dims[i];
+ }
+ return count;
+}
+
// Determines whether tensor is constant.
// TODO(b/138199592): Introduce new query which checks for constant OR
// persistent-read-only, which would be useful for most tensor kernels that
@@ -308,7 +316,7 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const TfLiteTensor* input3,
TfLiteIntArray** output_shape);
-// Return the size of given type in bytes. Return 0 in in case of string.
+// Return the size of given type in bytes. Return 0 in case of string.
int TfLiteTypeGetSize(TfLiteType type);
// Whether the current platform is mobile (Android or iOS).
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc b/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc
index 6fa1b31b..abbe34e7 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc
@@ -43,6 +43,7 @@ AllOpsResolver::AllOpsResolver() {
AddDepthwiseConv2D();
AddDequantize();
AddDetectionPostprocess();
+ AddDiv();
AddElu();
AddEqual();
AddEthosU();
@@ -104,6 +105,7 @@ AllOpsResolver::AllOpsResolver() {
AddSqueeze();
AddStridedSlice();
AddSub();
+ AddSum();
AddSvdf();
AddTanh();
AddTranspose();
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h
similarity index 95%
rename from code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h
rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h
index 3767cb9f..b92d6b2d 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
@@ -97,4 +97,4 @@ class INonPersistentBufferAllocator {
} // namespace tflite
-#endif // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc
new file mode 100644
index 00000000..6389da40
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc
@@ -0,0 +1,170 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+NonPersistentArenaBufferAllocator::NonPersistentArenaBufferAllocator(
+ uint8_t* buffer, size_t buffer_size)
+ : buffer_head_(buffer),
+ buffer_tail_(buffer + buffer_size),
+ head_temp_(buffer),
+ next_temp_(buffer) {}
+
+NonPersistentArenaBufferAllocator::~NonPersistentArenaBufferAllocator() {}
+
+// Allocates a temporary buffer. This buffer is not resizable.
+uint8_t* NonPersistentArenaBufferAllocator::AllocateTemp(size_t size,
+ size_t alignment) {
+ uint8_t* const aligned_result = AlignPointerUp(next_temp_, alignment);
+ const size_t available_memory = buffer_tail_ - aligned_result;
+ if (available_memory < size) {
+ MicroPrintf(
+ "Failed to allocate temp memory. Requested: %u, "
+ "available %u, missing: %u",
+ size, available_memory, size - available_memory);
+ return nullptr;
+ }
+ next_temp_ = aligned_result + size;
+ temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(aligned_result);
+ temp_buffer_count_++;
+ return aligned_result;
+}
+
+// Signals that a temporary buffer is no longer needed.
+void NonPersistentArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
+ temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(temp_buf);
+ temp_buffer_count_--;
+}
+
+// Returns true if all temporary buffers are already deallocated.
+bool NonPersistentArenaBufferAllocator::IsAllTempDeallocated() {
+ if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
+ MicroPrintf(
+ "Number of allocated temp buffers: %d. Checksum passing status: %d",
+ temp_buffer_count_, !temp_buffer_ptr_check_sum_);
+ return false;
+ }
+ return true;
+}
+
+// Signals that all temporary allocations can be reclaimed. TFLM calls this
+// API when it knows that all temporary buffers that it requested has been
+// deallocated. The goal of API is to facilitate implementations of
+// INonPersistentBufferAllocator can reuse buffer with some reasonable
+// complexity.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResetTempAllocations() {
+ if (!IsAllTempDeallocated()) {
+ MicroPrintf(
+ "All temp buffers must be freed before calling ResetTempAllocations()");
+ return kTfLiteError;
+ }
+ next_temp_ = head_temp_;
+ return kTfLiteOk;
+}
+
+// Returns a buffer that is resizable viable ResizeBuffer().
+uint8_t* NonPersistentArenaBufferAllocator::AllocateResizableBuffer(
+ size_t size, size_t alignment) {
+ // Only supports one resizable buffer, which starts at the buffer head.
+ uint8_t* expected_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+
+ if (resizable_buffer_allocated_) {
+ MicroPrintf(
+ "Cannot allocate a new resizable buffer when one is already allocated");
+ return nullptr;
+ }
+
+ if (ResizeBuffer(expected_resizable_buf, size, alignment) == kTfLiteOk) {
+ resizable_buffer_allocated_ = true;
+ return expected_resizable_buf;
+ }
+ return nullptr;
+}
+
+// Resizes a buffer that is previously returned by the AllocateResizableBuffer.
+// Note that ResizeBuffer(old_resizable_buf, 0, 1) effectively deallocates
+// a previous allocated resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResizeBuffer(
+ uint8_t* resizable_buf, size_t size, size_t alignment) {
+ // Only supports one resizable buffer, which starts at the buffer head.
+ uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+ if (resizable_buf != expect_resizable_buf) {
+ MicroPrintf("Internal error: buffer is not resizable");
+ return kTfLiteError;
+ }
+ if (head_temp_ != next_temp_) {
+ MicroPrintf("ResetTempAllocations() is not called before ResizeBuffer().");
+ return kTfLiteError;
+ }
+
+ const size_t available_memory = buffer_tail_ - expect_resizable_buf;
+ if (available_memory < size) {
+ MicroPrintf(
+ "Failed to resize buffer. Requested: %u, available %u, missing: %u",
+ size, available_memory, size - available_memory);
+ return kTfLiteError;
+ }
+ head_temp_ = expect_resizable_buf + size;
+ next_temp_ = head_temp_;
+
+ return kTfLiteOk;
+}
+
+// Frees up the memory occupied by the resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::DeallocateResizableBuffer(
+ uint8_t* resizable_buf) {
+ TfLiteStatus status = ResizeBuffer(resizable_buf, 0, 1);
+ if (status == kTfLiteOk) {
+ resizable_buffer_allocated_ = false;
+ }
+ return status;
+}
+
+// Returns a pointer pointing to the start of the overlay memory, which is
+// used for activation tensors and scratch buffers by kernels at Invoke stage.
+uint8_t* NonPersistentArenaBufferAllocator::GetOverlayMemoryAddress() const {
+ return buffer_head_;
+}
+
+// Reserves the size of the overlay memory. This overlay is reserved for the
+// kernels at Invoke stage. This is referred to as the overlay because before
+// Invoket state, the same memory can be used for temp buffers. The layout of
+// the memory is planned by the memory planner separately at Invoke stage.
+TfLiteStatus
+NonPersistentArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
+ size_t size, size_t alignment) {
+ uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+ return ResizeBuffer(expect_resizable_buf, size, alignment);
+}
+
+// Returns the size of non-persistent buffer in use.
+size_t NonPersistentArenaBufferAllocator::GetNonPersistentUsedBytes() const {
+ return (next_temp_ - buffer_head_);
+}
+
+// Returns the number of bytes available with a given alignment. This number
+// takes in account any temporary allocations.
+size_t NonPersistentArenaBufferAllocator::GetAvailableMemory(
+ size_t alignment) const {
+ uint8_t* const aligned_temp = AlignPointerUp(next_temp_, alignment);
+ uint8_t* const aligned_tail = AlignPointerDown(buffer_tail_, alignment);
+ return aligned_tail - aligned_temp;
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h
new file mode 100644
index 00000000..9eb4efeb
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h
@@ -0,0 +1,105 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// Implement INonPersistentBufferAllocator on an arena that is dedicated for
+// non-persistent buffers.
+class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator {
+ public:
+ NonPersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+ virtual ~NonPersistentArenaBufferAllocator();
+
+ // Allocates a temporary buffer. This buffer is not resizable.
+ uint8_t* AllocateTemp(size_t size, size_t alignment) override;
+
+ // Signals that a temporary buffer is no longer needed.
+ void DeallocateTemp(uint8_t* buf) override;
+
+ // Returns true if all temporary buffers are already deallocated.
+ bool IsAllTempDeallocated() override;
+
+ // Signals that all temporary allocations can be reclaimed. TFLM calls this
+ // API when it knows that all temporary buffers that it requested has been
+ // deallocated.
+ TfLiteStatus ResetTempAllocations() override;
+
+ // Returns a buffer that is resizable viable ResizeBuffer().
+ uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) override;
+
+ // Resizes a buffer that is previously returned by the
+ // AllocateResizableBuffer.
+ TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size,
+ size_t alignment) override;
+
+ // Frees up the memory occupied by the resizable buffer.
+ TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) override;
+
+ // Returns a pointer pointing to the start of the overlay memory, which is
+ // used for activation tensors and scratch buffers by kernels at Invoke stage.
+ uint8_t* GetOverlayMemoryAddress() const override;
+
+ // Reserves the size of the overlay memory. This overlay is reserved for the
+ // kernels at Invoke stage. This is referred to as the overlay because before
+ // Invoket state, the same memory can be used for temp buffers. The layout of
+ // the memory is planned by the memory planner separately at Invoke stage.
+ TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size,
+ size_t alignment) override;
+
+ // Returns the size of non-persistent buffer in use.
+ size_t GetNonPersistentUsedBytes() const override;
+
+ // Returns the number of bytes available with a given alignment. This number
+ // takes in account any temporary allocations.
+ size_t GetAvailableMemory(size_t alignment) const override;
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+
+ private:
+ // The memory arena that this allocator manages.
+ uint8_t* const buffer_head_;
+ uint8_t* const buffer_tail_;
+
+ // The whole region is split into two parts:
+ // buffer_head_ to head_temp_ - 1 belongs to the only resizable buffer.
+ // head_temp_ to buffer_tail_ can be used for (non-resizable) temp buffers.
+ uint8_t* head_temp_;
+
+ // next_temp_ points to the next available temp buffer allocation address and
+ // its range is between head_temp_ and buffer_tail_
+ uint8_t* next_temp_;
+
+ // XOR Check sum for outstanding temp buffers.
+ // If all temp buffers are deallocated OR no temp buffers are allocated,
+ // temp_buffer_ptr_check_sum_ == nullptr.
+ intptr_t temp_buffer_ptr_check_sum_ = 0;
+ // Count of outstanding temp buffers.
+ int temp_buffer_count_ = 0;
+ bool resizable_buffer_allocated_ = false;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc
new file mode 100644
index 00000000..0ccc8fb1
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc
@@ -0,0 +1,52 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+PersistentArenaBufferAllocator::PersistentArenaBufferAllocator(
+ uint8_t* buffer, size_t buffer_size)
+ : buffer_head_(buffer),
+ buffer_tail_(buffer + buffer_size),
+ tail_temp_(buffer_tail_) {}
+
+PersistentArenaBufferAllocator::~PersistentArenaBufferAllocator() {}
+
+uint8_t* PersistentArenaBufferAllocator::AllocatePersistentBuffer(
+ size_t size, size_t alignment) {
+ uint8_t* const aligned_result =
+ AlignPointerDown(tail_temp_ - size, alignment);
+ if (aligned_result < buffer_head_) {
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+ const size_t missing_memory = buffer_head_ - aligned_result;
+ MicroPrintf(
+ "Failed to allocate tail memory. Requested: %u, "
+ "available %u, missing: %u",
+ size, size - missing_memory, missing_memory);
+#endif
+ return nullptr;
+ }
+ tail_temp_ = aligned_result;
+ return aligned_result;
+}
+
+size_t PersistentArenaBufferAllocator::GetPersistentUsedBytes() const {
+ return buffer_tail_ - tail_temp_;
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h
new file mode 100644
index 00000000..70de408f
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h
@@ -0,0 +1,59 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// PersistentArenaBufferAllocator is an implementatation of
+// IPersistentBufferAllocator interface on an arena that is dedicated for
+// persistent buffers.
+class PersistentArenaBufferAllocator : public IPersistentBufferAllocator {
+ public:
+ PersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+ virtual ~PersistentArenaBufferAllocator();
+
+ // Allocates persistent memory. The persistent buffer is never freed.
+ // Returns nullptr if errors occured.
+ uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) override;
+
+ // Returns the size of all persistent allocations in bytes.
+ size_t GetPersistentUsedBytes() const override;
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+ private:
+ // The memory arena that this allocator manages.
+ uint8_t* const buffer_head_;
+ uint8_t* const buffer_tail_;
+
+ // The whole region is split into two parts:
+ // tail_temp_ to buffer_tail_ contains allocated buffers;
+ // buffer_head_ to tail_temp_ - 1 belongs to still available spaces.
+ // So in essence, the allocated region grows from the bottom and emulates
+ // SingleArenaBufferAllocator's persistent part.
+ uint8_t* tail_temp_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.cc
similarity index 55%
rename from code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc
rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.cc
index 6d3e72bd..0f24a0b5 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/micro/recording_simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h"
#include
@@ -21,47 +21,49 @@ limitations under the License.
namespace tflite {
-RecordingSimpleMemoryAllocator::RecordingSimpleMemoryAllocator(
+RecordingSingleArenaBufferAllocator::RecordingSingleArenaBufferAllocator(
ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size)
- : SimpleMemoryAllocator(error_reporter, buffer_head, buffer_size),
+ : SingleArenaBufferAllocator(error_reporter, buffer_head, buffer_size),
requested_head_bytes_(0),
requested_tail_bytes_(0),
used_bytes_(0),
alloc_count_(0) {}
-RecordingSimpleMemoryAllocator::~RecordingSimpleMemoryAllocator() {}
+RecordingSingleArenaBufferAllocator::~RecordingSingleArenaBufferAllocator() {}
-RecordingSimpleMemoryAllocator* RecordingSimpleMemoryAllocator::Create(
- ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size) {
+RecordingSingleArenaBufferAllocator*
+RecordingSingleArenaBufferAllocator::Create(ErrorReporter* error_reporter,
+ uint8_t* buffer_head,
+ size_t buffer_size) {
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(buffer_head != nullptr);
- RecordingSimpleMemoryAllocator tmp =
- RecordingSimpleMemoryAllocator(error_reporter, buffer_head, buffer_size);
+ RecordingSingleArenaBufferAllocator tmp = RecordingSingleArenaBufferAllocator(
+ error_reporter, buffer_head, buffer_size);
- uint8_t* allocator_buffer =
- tmp.AllocatePersistentBuffer(sizeof(RecordingSimpleMemoryAllocator),
- alignof(RecordingSimpleMemoryAllocator));
+ uint8_t* allocator_buffer = tmp.AllocatePersistentBuffer(
+ sizeof(RecordingSingleArenaBufferAllocator),
+ alignof(RecordingSingleArenaBufferAllocator));
// Use the default copy constructor to populate internal states.
- return new (allocator_buffer) RecordingSimpleMemoryAllocator(tmp);
+ return new (allocator_buffer) RecordingSingleArenaBufferAllocator(tmp);
}
-size_t RecordingSimpleMemoryAllocator::GetRequestedBytes() const {
+size_t RecordingSingleArenaBufferAllocator::GetRequestedBytes() const {
return requested_head_bytes_ + requested_tail_bytes_;
}
-size_t RecordingSimpleMemoryAllocator::GetUsedBytes() const {
+size_t RecordingSingleArenaBufferAllocator::GetUsedBytes() const {
return used_bytes_;
}
-size_t RecordingSimpleMemoryAllocator::GetAllocatedCount() const {
+size_t RecordingSingleArenaBufferAllocator::GetAllocatedCount() const {
return alloc_count_;
}
-TfLiteStatus RecordingSimpleMemoryAllocator::ResizeBuffer(
+TfLiteStatus RecordingSingleArenaBufferAllocator::ResizeBuffer(
uint8_t* resizable_buf, size_t size, size_t alignment) {
const uint8_t* previous_head = head();
TfLiteStatus status =
- SimpleMemoryAllocator::ResizeBuffer(resizable_buf, size, alignment);
+ SingleArenaBufferAllocator::ResizeBuffer(resizable_buf, size, alignment);
if (status == kTfLiteOk) {
used_bytes_ += head() - previous_head;
requested_head_bytes_ = size;
@@ -69,11 +71,11 @@ TfLiteStatus RecordingSimpleMemoryAllocator::ResizeBuffer(
return status;
}
-uint8_t* RecordingSimpleMemoryAllocator::AllocatePersistentBuffer(
+uint8_t* RecordingSingleArenaBufferAllocator::AllocatePersistentBuffer(
size_t size, size_t alignment) {
const uint8_t* previous_tail = tail();
uint8_t* result =
- SimpleMemoryAllocator::AllocatePersistentBuffer(size, alignment);
+ SingleArenaBufferAllocator::AllocatePersistentBuffer(size, alignment);
if (result != nullptr) {
used_bytes_ += previous_tail - tail();
requested_tail_bytes_ += size;
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h
similarity index 63%
rename from code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h
rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h
index a251e940..3cec561e 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h
@@ -13,28 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
+#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
namespace tflite {
-// Utility class used to log allocations of a SimpleMemoryAllocator. Should only
-// be used in debug/evaluation settings or unit tests to evaluate allocation
-// usage.
-class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
+// Utility class used to log allocations of a SingleArenaBufferAllocator. Should
+// only be used in debug/evaluation settings or unit tests to evaluate
+// allocation usage.
+class RecordingSingleArenaBufferAllocator : public SingleArenaBufferAllocator {
public:
- RecordingSimpleMemoryAllocator(ErrorReporter* error_reporter,
- uint8_t* buffer_head, size_t buffer_size);
+ RecordingSingleArenaBufferAllocator(ErrorReporter* error_reporter,
+ uint8_t* buffer_head, size_t buffer_size);
// TODO(b/157615197): Cleanup constructors/destructor and use factory
// functions.
- ~RecordingSimpleMemoryAllocator() override;
+ ~RecordingSingleArenaBufferAllocator() override;
- static RecordingSimpleMemoryAllocator* Create(ErrorReporter* error_reporter,
- uint8_t* buffer_head,
- size_t buffer_size);
+ static RecordingSingleArenaBufferAllocator* Create(
+ ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size);
// Returns the number of bytes requested from the head or tail.
size_t GetRequestedBytes() const;
@@ -62,4 +61,4 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
} // namespace tflite
-#endif // TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.cc
similarity index 66%
rename from code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc
rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.cc
index e5d87afb..15d512bd 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include
#include
@@ -29,42 +29,45 @@ limitations under the License.
namespace tflite {
-SimpleMemoryAllocator::SimpleMemoryAllocator(ErrorReporter* error_reporter,
- uint8_t* buffer_head,
- uint8_t* buffer_tail)
- : error_reporter_(error_reporter),
+SingleArenaBufferAllocator::SingleArenaBufferAllocator(
+ ErrorReporter* error_reporter, uint8_t* buffer_head, uint8_t* buffer_tail)
+ :
+#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
+ error_reporter_(error_reporter),
+#endif
buffer_head_(buffer_head),
buffer_tail_(buffer_tail),
head_(buffer_head),
tail_(buffer_tail),
- temp_(buffer_head_) {}
+ temp_(buffer_head_) {
+}
-SimpleMemoryAllocator::SimpleMemoryAllocator(ErrorReporter* error_reporter,
- uint8_t* buffer,
- size_t buffer_size)
- : SimpleMemoryAllocator(error_reporter, buffer, buffer + buffer_size) {}
+SingleArenaBufferAllocator::SingleArenaBufferAllocator(
+ ErrorReporter* error_reporter, uint8_t* buffer, size_t buffer_size)
+ : SingleArenaBufferAllocator(error_reporter, buffer, buffer + buffer_size) {
+}
/* static */
-SimpleMemoryAllocator* SimpleMemoryAllocator::Create(
+SingleArenaBufferAllocator* SingleArenaBufferAllocator::Create(
ErrorReporter* error_reporter, uint8_t* buffer_head, size_t buffer_size) {
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(buffer_head != nullptr);
- SimpleMemoryAllocator tmp =
- SimpleMemoryAllocator(error_reporter, buffer_head, buffer_size);
+ SingleArenaBufferAllocator tmp =
+ SingleArenaBufferAllocator(error_reporter, buffer_head, buffer_size);
- // Allocate enough bytes from the buffer to create a SimpleMemoryAllocator.
- // The new instance will use the current adjusted tail buffer from the tmp
- // allocator instance.
+ // Allocate enough bytes from the buffer to create a
+ // SingleArenaBufferAllocator. The new instance will use the current adjusted
+ // tail buffer from the tmp allocator instance.
uint8_t* allocator_buffer = tmp.AllocatePersistentBuffer(
- sizeof(SimpleMemoryAllocator), alignof(SimpleMemoryAllocator));
+ sizeof(SingleArenaBufferAllocator), alignof(SingleArenaBufferAllocator));
// Use the default copy constructor to populate internal states.
- return new (allocator_buffer) SimpleMemoryAllocator(tmp);
+ return new (allocator_buffer) SingleArenaBufferAllocator(tmp);
}
-SimpleMemoryAllocator::~SimpleMemoryAllocator() {}
+SingleArenaBufferAllocator::~SingleArenaBufferAllocator() {}
-uint8_t* SimpleMemoryAllocator::AllocateResizableBuffer(size_t size,
- size_t alignment) {
+uint8_t* SingleArenaBufferAllocator::AllocateResizableBuffer(size_t size,
+ size_t alignment) {
// Only supports one resizable buffer, which starts at the buffer head.
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
if (ResizeBuffer(expect_resizable_buf, size, alignment) == kTfLiteOk) {
@@ -73,20 +76,20 @@ uint8_t* SimpleMemoryAllocator::AllocateResizableBuffer(size_t size,
return nullptr;
}
-TfLiteStatus SimpleMemoryAllocator::DeallocateResizableBuffer(
+TfLiteStatus SingleArenaBufferAllocator::DeallocateResizableBuffer(
uint8_t* resizable_buf) {
return ResizeBuffer(resizable_buf, 0, 1);
}
-TfLiteStatus SimpleMemoryAllocator::ReserveNonPersistentOverlayMemory(
+TfLiteStatus SingleArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
size_t size, size_t alignment) {
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
return ResizeBuffer(expect_resizable_buf, size, alignment);
}
-TfLiteStatus SimpleMemoryAllocator::ResizeBuffer(uint8_t* resizable_buf,
- size_t size,
- size_t alignment) {
+TfLiteStatus SingleArenaBufferAllocator::ResizeBuffer(uint8_t* resizable_buf,
+ size_t size,
+ size_t alignment) {
// Only supports one resizable buffer, which starts at the buffer head.
uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
if (head_ != temp_ || resizable_buf != expect_resizable_buf) {
@@ -112,8 +115,8 @@ TfLiteStatus SimpleMemoryAllocator::ResizeBuffer(uint8_t* resizable_buf,
return kTfLiteOk;
}
-uint8_t* SimpleMemoryAllocator::AllocatePersistentBuffer(size_t size,
- size_t alignment) {
+uint8_t* SingleArenaBufferAllocator::AllocatePersistentBuffer(
+ size_t size, size_t alignment) {
uint8_t* const aligned_result = AlignPointerDown(tail_ - size, alignment);
if (aligned_result < head_) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
@@ -129,7 +132,8 @@ uint8_t* SimpleMemoryAllocator::AllocatePersistentBuffer(size_t size,
return aligned_result;
}
-uint8_t* SimpleMemoryAllocator::AllocateTemp(size_t size, size_t alignment) {
+uint8_t* SingleArenaBufferAllocator::AllocateTemp(size_t size,
+ size_t alignment) {
uint8_t* const aligned_result = AlignPointerUp(temp_, alignment);
const size_t available_memory = tail_ - aligned_result;
if (available_memory < size) {
@@ -145,12 +149,12 @@ uint8_t* SimpleMemoryAllocator::AllocateTemp(size_t size, size_t alignment) {
return aligned_result;
}
-void SimpleMemoryAllocator::DeallocateTemp(uint8_t* temp_buf) {
+void SingleArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
temp_buffer_ptr_check_sum_ ^= (reinterpret_cast(temp_buf));
temp_buffer_count_--;
}
-bool SimpleMemoryAllocator::IsAllTempDeallocated() {
+bool SingleArenaBufferAllocator::IsAllTempDeallocated() {
if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
MicroPrintf(
"Number of allocated temp buffers: %d. Checksum passing status: %d",
@@ -160,7 +164,7 @@ bool SimpleMemoryAllocator::IsAllTempDeallocated() {
return true;
}
-TfLiteStatus SimpleMemoryAllocator::ResetTempAllocations() {
+TfLiteStatus SingleArenaBufferAllocator::ResetTempAllocations() {
// TODO(b/209453859): enable error check based on IsAllTempDeallocated after
// all AllocateTemp have been paird with DeallocateTemp
if (!IsAllTempDeallocated()) {
@@ -172,34 +176,34 @@ TfLiteStatus SimpleMemoryAllocator::ResetTempAllocations() {
return kTfLiteOk;
}
-uint8_t* SimpleMemoryAllocator::GetOverlayMemoryAddress() const {
+uint8_t* SingleArenaBufferAllocator::GetOverlayMemoryAddress() const {
return buffer_head_;
}
-size_t SimpleMemoryAllocator::GetNonPersistentUsedBytes() const {
+size_t SingleArenaBufferAllocator::GetNonPersistentUsedBytes() const {
return std::max(head_ - buffer_head_, temp_ - buffer_head_);
}
-size_t SimpleMemoryAllocator::GetPersistentUsedBytes() const {
+size_t SingleArenaBufferAllocator::GetPersistentUsedBytes() const {
return buffer_tail_ - tail_;
}
-size_t SimpleMemoryAllocator::GetAvailableMemory(size_t alignment) const {
+size_t SingleArenaBufferAllocator::GetAvailableMemory(size_t alignment) const {
uint8_t* const aligned_temp = AlignPointerUp(temp_, alignment);
uint8_t* const aligned_tail = AlignPointerDown(tail_, alignment);
return aligned_tail - aligned_temp;
}
-size_t SimpleMemoryAllocator::GetUsedBytes() const {
+size_t SingleArenaBufferAllocator::GetUsedBytes() const {
return GetPersistentUsedBytes() + GetNonPersistentUsedBytes();
}
-size_t SimpleMemoryAllocator::GetBufferSize() const {
+size_t SingleArenaBufferAllocator::GetBufferSize() const {
return buffer_tail_ - buffer_head_;
}
-uint8_t* SimpleMemoryAllocator::head() const { return head_; }
+uint8_t* SingleArenaBufferAllocator::head() const { return head_; }
-uint8_t* SimpleMemoryAllocator::tail() const { return tail_; }
+uint8_t* SingleArenaBufferAllocator::tail() const { return tail_; }
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h
similarity index 83%
rename from code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h
rename to code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h
index d88c4a3d..d3be1f23 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h
@@ -13,37 +13,37 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
#include
#include
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
#include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/ibuffer_allocator.h"
namespace tflite {
// TODO(petewarden): This allocator never frees up or reuses any memory, even
// though we have enough information about lifetimes of the tensors to do so.
// This makes it pretty wasteful, so we should use a more intelligent method.
-class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
- public IPersistentBufferAllocator {
+class SingleArenaBufferAllocator : public INonPersistentBufferAllocator,
+ public IPersistentBufferAllocator {
public:
// TODO(b/157615197): Cleanup constructors/destructor and use factory
// functions.
- SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer_head,
- uint8_t* buffer_tail);
- SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer,
- size_t buffer_size);
- virtual ~SimpleMemoryAllocator();
+ SingleArenaBufferAllocator(ErrorReporter* error_reporter,
+ uint8_t* buffer_head, uint8_t* buffer_tail);
+ SingleArenaBufferAllocator(ErrorReporter* error_reporter, uint8_t* buffer,
+ size_t buffer_size);
+ virtual ~SingleArenaBufferAllocator();
- // Creates a new SimpleMemoryAllocator from a given buffer head and size.
- static SimpleMemoryAllocator* Create(ErrorReporter* error_reporter,
- uint8_t* buffer_head,
- size_t buffer_size);
+ // Creates a new SingleArenaBufferAllocator from a given buffer head and size.
+ static SingleArenaBufferAllocator* Create(ErrorReporter* error_reporter,
+ uint8_t* buffer_head,
+ size_t buffer_size);
// Resizes a buffer that is previously returned by the
// AllocateResizableBuffer. In current implementation, it Adjusts the head
@@ -126,7 +126,9 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
private:
size_t GetBufferSize() const;
+#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
ErrorReporter* error_reporter_;
+#endif
uint8_t* buffer_head_;
uint8_t* buffer_tail_;
uint8_t* head_;
@@ -147,4 +149,4 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
} // namespace tflite
-#endif // TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SINGLE_ARENA_BUFFER_ALLOCATOR_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
index 5a5ba9ab..2403c6b1 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
@@ -16,10 +16,10 @@ limitations under the License.
#include "tensorflow/lite/micro/fake_micro_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
namespace tflite {
namespace {
@@ -30,7 +30,7 @@ static uint8_t dummy_tensor_arena[KDummyTensorArenaSize];
} // namespace
FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
- SimpleMemoryAllocator* allocator,
+ SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph)
: MicroContext(
MicroAllocator::Create(dummy_tensor_arena, KDummyTensorArenaSize,
@@ -67,10 +67,10 @@ TfLiteEvalTensor* FakeMicroContext::GetEvalTensor(int tensor_index) {
}
void* FakeMicroContext::AllocatePersistentBuffer(size_t bytes) {
- // FakeMicroContext use SimpleMemoryAllocator, which does not automatically
- // apply the buffer alignment like MicroAllocator.
- // The buffer alignment is potentially wasteful but allows the
- // fake_micro_context to work correctly with optimized kernels.
+ // FakeMicroContext use SingleArenaBufferAllocator, which does not
+ // automatically apply the buffer alignment like MicroAllocator. The buffer
+ // alignment is potentially wasteful but allows the fake_micro_context to work
+ // correctly with optimized kernels.
return allocator_->AllocatePersistentBuffer(bytes,
MicroArenaBufferAlignment());
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h
index 99933c19..31b39d38 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h
@@ -23,7 +23,7 @@ namespace tflite {
// A fake of MicroContext for kernel util tests.
class FakeMicroContext : public MicroContext {
public:
- FakeMicroContext(TfLiteTensor* tensors, SimpleMemoryAllocator* allocator,
+ FakeMicroContext(TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph);
void* AllocatePersistentBuffer(size_t bytes) override;
@@ -46,7 +46,7 @@ class FakeMicroContext : public MicroContext {
TfLiteTensor* tensors_;
int allocated_tensor_count_ = 0;
- SimpleMemoryAllocator* allocator_;
+ SingleArenaBufferAllocator* allocator_;
TF_LITE_REMOVE_VIRTUAL_DELETE
};
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc
index c556ac64..e0b79631 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default: {
- TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
- TfLiteTypeGetName(input->type));
+ MicroPrintf("Only float32 is supported currently, got %s",
+ TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default: {
- TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
- TfLiteTypeGetName(input->type));
+ MicroPrintf("Only float32 is supported currently, got %s",
+ TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_RELU() {
- return {/*init=*/ReluInit,
- /*free=*/nullptr,
- /*prepare=*/ReluPrepare,
- /*invoke=*/ReluEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
}
TfLiteRegistration Register_RELU6() {
- return {/*init=*/Relu6Init,
- /*free=*/nullptr,
- /*prepare=*/Relu6Prepare,
- /*invoke=*/Relu6Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc
index 75523d14..f75db4e5 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc
@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_ADD() {
- return {/*init=*/AddInit,
- /*free=*/nullptr,
- /*prepare=*/AddPrepare,
- /*invoke=*/AddEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
index 5d0ab724..ce064687 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_ADD_N() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc
index 8217a4a0..a8aa5a48 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc
@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace arg_min_max
TfLiteRegistration Register_ARG_MAX() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/arg_min_max::ArgMaxEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
}
TfLiteRegistration Register_ARG_MIN() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/arg_min_max::ArgMinEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
}
} // namespace micro
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
index e28ebebb..a770d0aa 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_ASSIGN_VARIABLE() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
index 07b680df..be82d942 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
index fa333249..be2672ec 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_BROADCAST_ARGS() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/BroadcastArgsPrepare,
- /*invoke=*/BroadcastArgsEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
+ BroadcastArgsEval);
}
-} // namespace tflite
\ No newline at end of file
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
index 5302faf1..63a14db2 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_BROADCAST_TO() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/BroadcastToPrepare,
- /*invoke=*/BroadcastToEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
+ BroadcastToEval);
}
-} // namespace tflite
\ No newline at end of file
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
index 4db39f7d..200242b2 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.
TfLiteRegistration Register_CALL_ONCE() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
index dc651a24..a1f4516b 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CAST() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
index d0a48f91..a390a735 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace ceil
TfLiteRegistration Register_CEIL() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/ceil::Prepare,
- /*invoke=*/ceil::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
}
} // namespace micro
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc
index bda3e66a..399d1648 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc
@@ -108,14 +108,8 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
- static TfLiteRegistration r = {/*init=*/CircularBufferInit,
- /*free=*/nullptr,
- /*prepare=*/CircularBufferPrepare,
- /*invoke=*/CircularBufferEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ static TfLiteRegistration r = tflite::micro::RegisterOp(
+ CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
return &r;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc
index 682efb43..81db6e65 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc
@@ -39,13 +39,12 @@ const int kCircularBufferCyclesMaxIndex = 0; // 'cycles_max'
const TfLiteStatus kTfLiteAbort = static_cast(-9);
TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
+ MicroContext* micro_context = GetMicroContext(context);
- MicroContext * micro_context = GetMicroContext(context);
-
- TfLiteTensor* input =
- micro_context-> AllocateTempInputTensor(node, kCircularBufferInputTensor);
- TfLiteTensor* output =
- micro_context-> AllocateTempOutputTensor(node, kCircularBufferOutputTensor);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kCircularBufferInputTensor);
+ TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
+ node, kCircularBufferOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
OpDataCircularBuffer* op_data =
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
index 925c3fb5..cff15e4d 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} // namespace comparisons
TfLiteRegistration Register_EQUAL() {
- return {/*init=*/comparisons::Init,
- /*free=*/nullptr,
- /*prepare=*/comparisons::Prepare,
- /*invoke=*/comparisons::EqualEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+ comparisons::EqualEval);
}
TfLiteRegistration Register_NOT_EQUAL() {
- return {/*init=*/comparisons::Init,
- /*free=*/nullptr,
- /*prepare=*/comparisons::Prepare,
- /*invoke=*/comparisons::NotEqualEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+ comparisons::NotEqualEval);
}
TfLiteRegistration Register_GREATER() {
- return {/*init=*/comparisons::Init,
- /*free=*/nullptr,
- /*prepare=*/comparisons::Prepare,
- /*invoke=*/comparisons::GreaterEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+ comparisons::GreaterEval);
}
TfLiteRegistration Register_GREATER_EQUAL() {
- return {/*init=*/comparisons::Init,
- /*free=*/nullptr,
- /*prepare=*/comparisons::Prepare,
- /*invoke=*/comparisons::GreaterEqualEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+ comparisons::GreaterEqualEval);
}
TfLiteRegistration Register_LESS() {
- return {/*init=*/comparisons::Init,
- /*free=*/nullptr,
- /*prepare=*/comparisons::Prepare,
- /*invoke=*/comparisons::LessEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+ comparisons::LessEval);
}
TfLiteRegistration Register_LESS_EQUAL() {
- return {/*init=*/comparisons::Init,
- /*free=*/nullptr,
- /*prepare=*/comparisons::Prepare,
- /*invoke=*/comparisons::LessEqualEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+ comparisons::LessEqualEval);
}
} // namespace micro
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
index d727a0d5..34622c22 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, input != nullptr);
int num_dimensions = NumDimensions(input);
- if (num_dimensions > 4) {
+ if (num_dimensions > RuntimeShape::kMaxSmallSize) {
TF_LITE_KERNEL_LOG(
context,
- "Op Concatenation does not currently support num dimensions >4 "
+ "Op Concatenation does not currently support num dimensions > %d "
"Tensor has %d dimensions.",
- num_dimensions);
+ RuntimeShape::kMaxSmallSize, num_dimensions);
return kTfLiteError;
}
micro_context->DeallocateTempTfLiteTensor(input);
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace concatenation
TfLiteRegistration Register_CONCATENATION() {
- return {/*init=*/concatenation::Init,
- /*free=*/nullptr,
- /*prepare=*/concatenation::Prepare,
- /*invoke=*/concatenation::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
+ concatenation::Eval);
}
} // namespace micro
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc
index 0fed1223..87ea92e6 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
namespace tflite {
namespace {
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData(filter),
tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData(bias),
+ tflite::micro::GetOptionalTensorData(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
break;
}
case kTfLiteInt16: {
- reference_integer_ops::ConvPerChannel(
- ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
- data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData(output));
+ switch (bias->type) {
+ case kTfLiteInt32: {
+ reference_integer_ops::ConvPerChannel(
+ ConvParamsQuantized(params, data),
+ data.per_channel_output_multiplier, data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetOptionalTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ break;
+ }
+ case kTfLiteInt64: {
+ reference_integer_ops::ConvPerChannel(
+ ConvParamsQuantized(params, data),
+ data.per_channel_output_multiplier, data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetOptionalTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ break;
+ }
+ default:
+ MicroPrintf("Bias type %s (%d) not supported.",
+ TfLiteTypeGetName(bias->type), bias->type);
+ return kTfLiteError;
+ }
break;
}
case kTfLiteInt8: {
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData(filter),
tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData(bias),
+ tflite::micro::GetOptionalTensorData(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData(output));
break;
}
default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input->type), input->type);
+ MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CONV_2D() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/ConvPrepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h
index 38b69525..47ba8ac4 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h
@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
TfLiteRegistration registration, int16_t* output_data);
+TfLiteStatus TestConvQuantizedPerChannel(
+ int* input_dims_data, const float* input_data, int16_t* input_quantized,
+ float input_scale, int input_zero_point, int* filter_dims_data,
+ const float* filter_data, int8_t* filter_data_quantized,
+ int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
+ float* bias_scales, int* bias_zero_points, int* output_dims_data,
+ const float* expected_output_data, int16_t* expected_output_data_quantized,
+ float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
+ TfLiteRegistration registration, int16_t* output_data);
+
} // namespace testing
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
index 61f7af23..eedc61fd 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_CUMSUM() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
index cce93c9c..ec000540 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_DEPTH_TO_SPACE() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc
index 8a58433a..d2468ff9 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc
@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData(filter),
tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData(bias),
+ tflite::micro::GetOptionalTensorData(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData(output));
break;
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData(filter),
tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData(bias),
+ tflite::micro::GetOptionalTensorData(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData(output));
break;
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/DepthwiseConvPrepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h
index 7a7eb0ba..562438d7 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
+// This is the most generic TfLiteRegistration. The actual supported types may
+// still be target dependent. The only requirement is that every implementation
+// (reference or optimized) must define this function.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D();
+
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int8 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
+
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
+
+#else
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
+ return Register_DEPTHWISE_CONV_2D();
+}
+
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
+ return Register_DEPTHWISE_CONV_2D();
+}
+#endif
+
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc
index 4438ea33..1cf7f133 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc
@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData(output));
break;
+ case kTfLiteUInt8:
+ reference_ops::Dequantize(data->quantization_params,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ break;
default:
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_DEQUANTIZE() {
- return {/*init=*/DequantizeInit,
- /*free=*/nullptr,
- /*prepare=*/DequantizePrepare,
- /*invoke=*/DequantizeEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
+ DequantizeEval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
index 4be5ad89..438f9cda 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
- TF_LITE_ENSURE(context,
- input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+ TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
+ input->type == kTfLiteInt16 ||
+ input->type == kTfLiteUInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
if (output->type == kTfLiteInt32) {
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
index efe57e2f..326d87b5 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return op_data;
}
-void Free(TfLiteContext* context, void* buffer) {}
-
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = static_cast(node->user_data);
@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
- static TfLiteRegistration r = {/*init=*/Init,
- /*free=*/Free,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/div.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/div.cc
new file mode 100644
index 00000000..099c0225
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/div.cc
@@ -0,0 +1,208 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/div.h"
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+struct OpDataDiv {
+ // Parameters used in the quantized paths where the output is 8bit
+ int32_t input1_zero_point;
+ int32_t input2_zero_point;
+ int32_t output_zero_point;
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+
+ // Parameters used in all quantized paths
+ int32_t output_multiplier;
+ int output_shift;
+};
+
+TfLiteStatus CalculateOpDataDiv(TfLiteContext* context, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output,
+ TfLiteDivParams* params, OpDataDiv* data) {
+ TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
+ TF_LITE_ENSURE_TYPES_EQ(context, input1->type, output->type);
+
+ if (output->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
+ context, params->activation, output, &data->output_activation_min,
+ &data->output_activation_max));
+ const double real_multiplier = static_cast(
+ input1->params.scale / (input2->params.scale * output->params.scale));
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ data->input1_zero_point = input1->params.zero_point;
+ data->input2_zero_point = input2->params.zero_point;
+ data->output_zero_point = output->params.zero_point;
+ }
+
+ return kTfLiteOk;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(OpDataDiv));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->user_data != nullptr);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* input1 =
+ micro_context->AllocateTempInputTensor(node, kInputTensor1);
+ TF_LITE_ENSURE(context, input1 != nullptr);
+ TfLiteTensor* input2 =
+ micro_context->AllocateTempInputTensor(node, kInputTensor2);
+ TF_LITE_ENSURE(context, input2 != nullptr);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+
+ OpDataDiv* data = static_cast(node->user_data);
+ auto* params = reinterpret_cast(node->builtin_data);
+
+ TF_LITE_ENSURE_STATUS(
+ CalculateOpDataDiv(context, input1, input2, output, params, data));
+
+ micro_context->DeallocateTempTfLiteTensor(input1);
+ micro_context->DeallocateTempTfLiteTensor(input2);
+ micro_context->DeallocateTempTfLiteTensor(output);
+ return kTfLiteOk;
+}
+
+void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
+ const OpDataDiv* data, const TfLiteEvalTensor* input1,
+ const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
+ tflite::ArithmeticParams op_params = {};
+
+#define TF_LITE_DIV(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, tflite::micro::GetTensorShape(input1), \
+ tflite::micro::GetTensorData(input1), \
+ tflite::micro::GetTensorShape(input2), \
+ tflite::micro::GetTensorData(input2), \
+ tflite::micro::GetTensorShape(output), \
+ tflite::micro::GetTensorData(output))
+
+ bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
+ tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorShape(input2), &op_params);
+
+ if (requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDivSlow, float);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, float);
+ }
+#undef TF_LITE_DIV
+}
+
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDivParams* params, const OpDataDiv* data,
+ const TfLiteEvalTensor* input1,
+ const TfLiteEvalTensor* input2,
+ TfLiteEvalTensor* output) {
+ tflite::ArithmeticParams op_params = {};
+
+#define TF_LITE_DIV(type, opname, dtype) \
+ type::opname(op_params, tflite::micro::GetTensorShape(input1), \
+ tflite::micro::GetTensorData(input1), \
+ tflite::micro::GetTensorShape(input2), \
+ tflite::micro::GetTensorData(input2), \
+ tflite::micro::GetTensorShape(output), \
+ tflite::micro::GetTensorData(output))
+
+ if (input1->type == kTfLiteInt8 && input2->type == kTfLiteInt8 &&
+ output->type == kTfLiteInt8) {
+ SetActivationParams(data->output_activation_min,
+ data->output_activation_max, &op_params);
+ op_params.input1_offset = -data->input1_zero_point;
+ op_params.input2_offset = -data->input2_zero_point;
+ op_params.output_offset = data->output_zero_point;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = data->output_shift;
+
+ bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
+ tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorShape(input2), &op_params);
+
+ if (requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDivSlow, int8_t);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, int8_t);
+ }
+#undef TF_LITE_DIV
+ } else {
+ TF_LITE_KERNEL_LOG(
+ context, "Unsupported combination of input and output types in DIV.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+ auto* params = static_cast(node->builtin_data);
+ TFLITE_DCHECK(node->user_data != nullptr);
+ auto* data = static_cast(node->user_data);
+
+ const TfLiteEvalTensor* input1 =
+ tflite::micro::GetEvalInput(context, node, kInputTensor1);
+ const TfLiteEvalTensor* input2 =
+ tflite::micro::GetEvalInput(context, node, kInputTensor2);
+ TfLiteEvalTensor* output =
+ tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+ EvalDiv(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
+ input1, input2, output));
+ } else {
+ TF_LITE_KERNEL_LOG(context,
+ "DIV only supports FLOAT32, quantized INT8 "
+ "now, got type %s (%d).",
+ TfLiteTypeGetName(output->type), output->type);
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteRegistration Register_DIV() {
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
index 366dd610..b1cb1dcb 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@ limitations under the License.
#include
#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -27,6 +29,22 @@ namespace micro {
namespace elementwise {
namespace {
+constexpr int kAbsNameId = 0;
+constexpr int kRsrqtNameId = 1;
+
+const int kElementwiseInputTensor = 0;
+const int kElementwiseOutputTensor = 0;
+
+struct OpDataAbsRsqrt {
+ int32_t multiplier;
+ int shift;
+ int input_offset;
+ int output_offset;
+ bool needs_rescale;
+ TfLiteQuantizationType input_quantization_type;
+ TfLiteType input_type;
+};
+
bool IsNumericSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32;
}
@@ -35,16 +53,40 @@ bool IsLogicalSupportedType(const TfLiteType type) {
return type == kTfLiteBool;
}
+bool IsAbsSupportedType(const TfLiteType type) {
+ return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
+}
+
+bool IsRsqrtSupportedType(const TfLiteType type) {
+ return type == kTfLiteFloat32 || type == kTfLiteInt8;
+}
+
+inline void SetAbsOutputMultiplier(const float input_scale,
+ const float output_scale,
+ int32_t* multiplier, int* shift) {
+ QuantizeMultiplier(static_cast(input_scale / output_scale),
+ multiplier, shift);
+}
+
+inline void SetRsqrtOutputMultiplier(const float input_scale,
+ const float output_scale,
+ int32_t* multiplier, int* shift) {
+ const double scale =
+ 1. / static_cast((std::sqrt(input_scale) * output_scale));
+ QuantizeMultiplier(scale, multiplier, shift);
+}
+
typedef bool (*IsSupportedType)(TfLiteType);
template
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
-
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
@@ -58,9 +100,79 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+typedef bool (*IsSupportedType)(TfLiteType);
+template
+TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
+ MicroContext* micro_context = GetMicroContext(context);
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+ TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+ if (!IsSupportedType(input->type)) {
+ TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
+ TfLiteTypeGetName(input->type), input->type);
+ return kTfLiteError;
+ }
+
+ auto* op_data = static_cast(node->user_data);
+ op_data->input_type = input->type;
+
+ // For int16 type input, we support both quantized and non-quantized
+ // evaluation.
+ if (op_nameid == kAbsNameId) {
+ op_data->input_quantization_type = input->quantization.type;
+ }
+
+ if (input->type == kTfLiteInt8 ||
+ (input->type == kTfLiteInt16 &&
+ input->quantization.type != kTfLiteNoQuantization)) {
+ TF_LITE_ENSURE_EQ(context, input->quantization.type,
+ kTfLiteAffineQuantization);
+ TF_LITE_ENSURE_EQ(context, output->quantization.type,
+ kTfLiteAffineQuantization);
+ const auto* input_params =
+ reinterpret_cast(input->quantization.params);
+ const auto* output_params = reinterpret_cast(
+ output->quantization.params);
+ TF_LITE_ENSURE(context, input_params != nullptr);
+ TF_LITE_ENSURE(context, input_params->scale != nullptr);
+ TF_LITE_ENSURE(context, input_params->scale->size > 0);
+ TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
+ TF_LITE_ENSURE(context, output_params != nullptr);
+ TF_LITE_ENSURE(context, output_params->scale != nullptr);
+ TF_LITE_ENSURE(context, output_params->scale->size > 0);
+ TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
+ op_data->input_offset = input_params->zero_point->data[0];
+ op_data->output_offset = output_params->zero_point->data[0];
+ if (input->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
+ TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
+ }
+ const float input_scale = input_params->scale->data[0];
+ const float output_scale = output_params->scale->data[0];
+ op_data->needs_rescale = input_scale != output_scale;
+ if (op_nameid == kAbsNameId && op_data->needs_rescale) {
+ SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+ &op_data->shift);
+ } else if (op_nameid == kRsrqtNameId) {
+ SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+ &op_data->shift);
+ }
+ }
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+ return kTfLiteOk;
+}
+
template <typename T>
-inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
- T func(T), TfLiteType expected_type) {
+inline TfLiteStatus EvalImplQuantized(
+ TfLiteContext* context, TfLiteNode* node,
+ T func(TfLiteContext*, TfLiteNode*, T),
+ TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
+ TfLiteType expected_type) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
@@ -68,6 +180,34 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
const T* in_data = tflite::micro::GetTensorData<T>(input);
T* out_data = tflite::micro::GetTensorData<T>(output);
for (size_t i = 0; i < num_elements; ++i) {
+ if (validate_input_func) {
+ TF_LITE_ENSURE_OK(context,
+ validate_input_func(context, node, in_data[i]));
+ }
+ out_data[i] = func(context, node, in_data[i]);
+ }
+ return kTfLiteOk;
+}
+
+template <typename T>
+inline T AbsHelper(T i) {
+ return std::abs(i);
+}
+
+template <typename T>
+inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
+ T func(T), TfLiteStatus validate_input_func(T),
+ TfLiteType expected_type) {
+ const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+ TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+ TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
+ const size_t num_elements = ElementCount(*input->dims);
+ const T* in_data = tflite::micro::GetTensorData<T>(input);
+ T* out_data = tflite::micro::GetTensorData<T>(output);
+ for (size_t i = 0; i < num_elements; ++i) {
+ if (validate_input_func) {
+ TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
+ }
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
@@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
float float_func(float)) {
- return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
+ return EvalImpl<float>(context, node, float_func,
+ /*validate_input_func=*/nullptr, kTfLiteFloat32);
}
inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
+
bool bool_func(bool)) {
- return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
+ return EvalImpl<bool>(context, node, bool_func,
+ /*validate_input_func=*/nullptr, kTfLiteBool);
+}
+
+void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
+ size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
+}
+
+template <typename T>
+inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+ const auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+ const int kMin = std::numeric_limits<T>::min();
+ const int kMax = std::numeric_limits<T>::max();
+
+ const int32_t value = std::abs(i - op_data->input_offset);
+ if (!op_data->needs_rescale) {
+ return static_cast<T>(
+ std::min(std::max(static_cast<int32_t>(value + op_data->output_offset),
+ static_cast<int32_t>(kMin)),
+ static_cast<int32_t>(kMax)));
+ }
+
+ const int32_t output = tflite::MultiplyByQuantizedMultiplier(
+ value, op_data->multiplier, op_data->shift) +
+ op_data->output_offset;
+ return static_cast<T>(std::min(
+ std::max(static_cast<int32_t>(output), static_cast<int32_t>(kMin)),
+ static_cast<int32_t>(kMax)));
+}
+
+template <typename T>
+inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+ const auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+ const int kMin = std::numeric_limits<T>::min();
+ const int kMax = std::numeric_limits<T>::max();
+
+ const int32_t value = (i - op_data->input_offset);
+ const int32_t kShift = 20; // Shift to keep value integer.
+ if (value == 0) {
+ // Assume that any value close to 0 represents the max output value.
+ return static_cast<T>(kMax);
+ }
+ int32_t inv_sqrt_multiplier;
+ int inv_sqrt_shift;
+ GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
+ &inv_sqrt_shift);
+ const int32_t data = tflite::MultiplyByQuantizedMultiplier(
+ static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
+ const int32_t output =
+ tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
+ op_data->shift - kShift) +
+ op_data->output_offset;
+ return static_cast<T>(std::min(
+ std::max(static_cast<int32_t>(output), static_cast<int32_t>(kMin)),
+ static_cast<int32_t>(kMax)));
+}
+
+template <typename T>
+TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
+ T i) {
+ const auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+
+ TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
+ "Rsqrt is only defined for positive values");
+ return static_cast<TfLiteStatus>(kTfLiteOk);
}
TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, std::abs);
+ OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
+ TfLiteType type = op_data->input_type;
+ TfLiteQuantizationType input_quantization_type =
+ op_data->input_quantization_type;
+ TfLiteStatus eval_result;
+
+ switch (type) {
+ case kTfLiteFloat32:
+ eval_result = EvalNumeric(context, node, std::abs);
+ break;
+ case kTfLiteInt8:
+ eval_result =
+ EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
+ /*validate_input_func=*/nullptr, type);
+ break;
+ case kTfLiteInt16:
+ eval_result =
+ input_quantization_type == kTfLiteNoQuantization
+ ? EvalImpl<int16_t>(context, node, AbsHelper,
+ /*validate_input_func=*/nullptr, type)
+ : EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
+ /*validate_input_func=*/nullptr,
+ type);
+ break;
+ default:
+ TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+ TfLiteTypeGetName(type));
+ return kTfLiteError;
+ break;
+ }
+ return eval_result;
}
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
@@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
+ const auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+ TfLiteType type = op_data->input_type;
+ switch (type) {
+ case kTfLiteFloat32:
+ return EvalImpl<float>(
+ context, node, [](float f) { return 1.f / std::sqrt(f); },
+ /*validate_input_func=*/nullptr, type);
+ case kTfLiteInt8:
+ return EvalImplQuantized<int8_t>(context, node,
+ elementwise::RsqrtEvalQuantized,
+ elementwise::validate_input_func, type);
+
+ default:
+ TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+ TfLiteTypeGetName(type));
+ return kTfLiteError;
+ }
}
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
@@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace elementwise
TfLiteRegistration Register_ABS() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::AbsEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ elementwise::ElementWiseAbsRsqrtInit,
+ elementwise::PrepareAbsRsqrt,
+ elementwise::AbsEval);
}
TfLiteRegistration Register_SIN() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::SinEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SinEval);
}
TfLiteRegistration Register_COS() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::CosEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::CosEval);
}
TfLiteRegistration Register_LOG() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::LogEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::LogEval);
}
TfLiteRegistration Register_SQRT() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::SqrtEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SqrtEval);
}
TfLiteRegistration Register_RSQRT() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::RsqrtEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ elementwise::ElementWiseAbsRsqrtInit,
+ elementwise::PrepareAbsRsqrt,
+ elementwise::RsqrtEval);
}
TfLiteRegistration Register_SQUARE() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::SquareEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SquareEval);
}
TfLiteRegistration Register_LOGICAL_NOT() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
- /*invoke=*/elementwise::LogicalNotEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(
+ nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+ elementwise::LogicalNotEval);
}
} // namespace micro
} // namespace ops
-} // namespace tflite
+} // namespace tflite
\ No newline at end of file
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
index b2cd19cc..0b64e89d 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_ELU() {
- return {/*init=*/EluInit,
- /*free=*/nullptr,
- /*prepare=*/EluPrepare,
- /*invoke=*/EluEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
index 47a17d9f..2f1ac58d 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_ADD() {
- return {/*init=*/AddInit,
- /*free=*/nullptr,
- /*prepare=*/AddPrepare,
- /*invoke=*/AddEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
index 09260482..919dd006 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if ESP_NN
if (input->type == kTfLiteInt8) {
+ data_dims_t input_dims = {
+ .width = input_width, .height = input_height,
+ .channels = input->dims->data[3], 1
+ };
+ data_dims_t output_dims = {
+ .width = output_width, .height = output_height,
+ .channels = output->dims->data[3], 1
+ };
+ data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+ conv_params_t conv_params = {
+ .in_offset = 0, .out_offset = 0,
+ .stride = {params.stride_width, params.stride_height},
+ .padding = {data->op_data.padding.width, data->op_data.padding.height},
+ .dilation = {0, 0}, .activation = {-128, 127}
+ };
+
int scratch_buf_size = esp_nn_get_conv_scratch_size(
- input_width, input_height, input->dims->data[3],
- output->dims->data[3], filter_width, filter_height);
+ &input_dims, &filter_dims, &output_dims, &conv_params);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
const int input_size = input_width * input_height * input_depth;
const int output_size = output_width * output_height * output_depth;
+ data_dims_t input_dims = {
+ .width = input_width, .height = input_height,
+ .channels = input_depth, 1
+ };
+ data_dims_t output_dims = {
+ .width = output_width, .height = output_height,
+ .channels = output_depth, 1
+ };
+ data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+ conv_params_t conv_params = {
+ .in_offset = input_offset, .out_offset = output_offset,
+ .stride = {stride_width, stride_height},
+ .padding = {pad_width, pad_height},
+ .dilation = {0, 0},
+ .activation = {activation_min, activation_max}
+ };
+ quant_data_t quant_data = {
+ .shift = data.op_data.per_channel_output_shift,
+ .mult = data.op_data.per_channel_output_multiplier
+ };
+
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
- esp_nn_conv_s8(input_data + i_batch * input_size,
- input_width, input_height, input_depth, input_offset,
- pad_width, pad_height, stride_width, stride_height,
- tflite::micro::GetTensorData<int8_t>(filter),
- filter_width, filter_height,
+ esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
+ &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorData<int32_t>(bias),
- output_data + i_batch * output_size,
- output_width, output_height, output_depth, output_offset,
- data.op_data.per_channel_output_shift,
- data.op_data.per_channel_output_multiplier,
- activation_min, activation_max);
+ &output_dims, output_data + i_batch * output_size,
+ &conv_params, &quant_data);
}
} else {
reference_integer_ops::ConvPerChannel(
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
- conv_total_time += esp_timer_get_time() - start_time;
+ long long time_this_instance = esp_timer_get_time() - start_time;
+ conv_total_time += time_this_instance;
+ //printf("time this instance: %llu\n", time_this_instance / 1000);
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_CONV_2D() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
index 5f2d9d50..a2460248 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
if (data.buffer_idx > -1) {
scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
}
+
esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
+ data_dims_t input_dims = {
+ .width = input_width, .height = input_height,
+ .channels = input_depth, 1
+ };
+ data_dims_t output_dims = {
+ .width = output_width, .height = output_height,
+ .channels = output_depth, 1
+ };
+ data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+ dw_conv_params_t conv_params = {
+ .in_offset = input_offset, .out_offset = output_offset,
+ .ch_mult = depth_multiplier,
+ .stride = {stride_width, stride_height},
+ .padding = {pad_width, pad_height}, .dilation = {0, 0},
+ .activation = {activation_min, activation_max}
+ };
+ quant_data_t quant_data = {
+ .shift = data.op_data.per_channel_output_shift,
+ .mult = data.op_data.per_channel_output_multiplier
+ };
+
for (int i_batch = 0; i_batch < batch_size; i_batch++) {
- esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
- input_height, input_depth, input_offset,
- pad_width, pad_height,
- stride_width, stride_height, depth_multiplier,
- tflite::micro::GetTensorData<int8_t>(filter),
- filter_width, filter_height,
+ esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
+ &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorData<int32_t>(bias),
- output_data + i_batch * output_size,
- output_width, output_height, output_offset,
- data.op_data.per_channel_output_shift,
- data.op_data.per_channel_output_multiplier,
- activation_min, activation_max);
+ &output_dims, output_data + i_batch * output_size,
+ &conv_params, &quant_data);
}
} else {
reference_integer_ops::DepthwiseConvPerChannel(
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if ESP_NN
if (input->type == kTfLiteInt8) {
+ data_dims_t input_dims = {
+ .width = input_width, .height = input_height,
+ .channels = input->dims->data[3], 1
+ };
+ data_dims_t output_dims = {
+ .width = output_width, .height = output_height,
+ .channels = output->dims->data[3], 1
+ };
+ data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+ dw_conv_params_t conv_params = {
+ .in_offset = 0, .out_offset = 0,
+ .ch_mult = params.depth_multiplier,
+ .stride = {params.stride_width, params.stride_height},
+ .padding = {data->op_data.padding.width, data->op_data.padding.height},
+ .dilation = {0, 0}, .activation = {-128, 127}
+ };
+
int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
- input_width, input_height, input->dims->data[3],
- params.depth_multiplier, filter_width, filter_height);
+ &input_dims, &filter_dims, &output_dims, &conv_params);
if (scratch_buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, scratch_buf_size, &data->buffer_idx));
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
- dc_total_time += esp_timer_get_time() - start_time;
+ long long time_this_instance = esp_timer_get_time() - start_time;
+ dc_total_time += time_this_instance;
+ // printf("time this instance: %llu\n", time_this_instance / 1000);
+
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
index 5e1705da..484cffb6 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc
index 0e8a82f4..02413f5c 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc
@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteRegistration Register_MUL() {
- return {/*init=*/MulInit,
- /*free=*/nullptr,
- /*prepare=*/MulPrepare,
- /*invoke=*/MulEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc
index d55bab82..b450929e 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc
@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} // namespace
TfLiteRegistration Register_AVERAGE_POOL_2D() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/PoolingPrepare,
- /*invoke=*/AverageEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
}
TfLiteRegistration Register_MAX_POOL_2D() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/PoolingPrepare,
- /*invoke=*/MaxEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc
new file mode 100644
index 00000000..9a967839
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc
@@ -0,0 +1,208 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/softmax.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#include "freertos/FreeRTOS.h"
+#include <esp_timer.h>
+
+#if ESP_NN
+#include <esp_nn.h>
+#endif
+
+long long softmax_total_time = 0;
+
+namespace tflite {
+namespace {
+// Softmax parameter data that persists in user_data
+const int kInt16LUTArraySize = 513;
+
+struct NodeData {
+ SoftmaxParams op_data;
+#if ESP_NN
+ int buffer_idx;
+#endif
+};
+
+static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(NodeData));
+}
+
+void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
+ TfLiteEvalTensor* output, const NodeData* data) {
+ if (input->type == kTfLiteInt8) {
+ if (output->type == kTfLiteInt16) {
+ tflite::reference_ops::Softmax(
+ data->op_data, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int16_t>(output));
+ } else {
+#if ESP_NN
+ const int32_t input_beta_multiplier = data->op_data.input_multiplier;
+ const int32_t input_beta_left_shift = data->op_data.input_left_shift;
+ const int diff_min = data->op_data.diff_min;
+ const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+ const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
+ int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
+ void *scratch_buf = NULL;
+ if (data->buffer_idx > -1) {
+ scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
+ }
+ esp_nn_set_softmax_scratch_buf(scratch_buf);
+ esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
+ input_beta_left_shift, diff_min, out_ptr);
+#else
+ tflite::reference_ops::Softmax(
+ data->op_data, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output));
+#endif
+ }
+ } else {
+ tflite::reference_ops::SoftmaxInt16(
+ data->op_data, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int16_t>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int16_t>(output));
+ }
+}
+
+static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+ TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+ TFLITE_DCHECK(node->user_data != nullptr);
+ NodeData data = *static_cast<NodeData*>(node->user_data);
+
+ long long start_time = esp_timer_get_time();
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ tflite::reference_ops::Softmax(
+ data.op_data, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<float>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<float>(output));
+ }
+ break;
+ case kTfLiteInt8:
+ case kTfLiteInt16: {
+ SoftmaxQuantized(context, input, output, &data);
+ }
+ break;
+ default:
+ TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+ TfLiteTypeGetName(input->type), input->type);
+ return kTfLiteError;
+ }
+ softmax_total_time += esp_timer_get_time() - start_time;
+ return kTfLiteOk;
+}
+
+static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+ TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+ TF_LITE_ENSURE(context, output != nullptr);
+
+ TF_LITE_ENSURE(context, node->user_data != nullptr);
+ NodeData* data = static_cast<NodeData*>(node->user_data);
+ // Only allocate LUTs for KTfLiteInt16 data type
+ if (input->type == kTfLiteInt16) {
+ void* raw_exp_lut = context->AllocatePersistentBuffer(
+ context, sizeof(int16_t) * kInt16LUTArraySize);
+ TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
+ data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
+ void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
+ context, sizeof(int16_t) * kInt16LUTArraySize);
+ TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
+ data->op_data.one_over_one_plus_x_lut =
+ reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
+ }
+
+ if (output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE(context,
+ input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ }
+
+ // Populate LUT if required
+ if (input->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ // exp LUT only used on negative values
+ // we consider exp(-10.0) is insignificant to accumulation
+ gen_lut<float, int16_t, kInt16LUTArraySize>(
+ [](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
+ data->op_data.exp_lut);
+ gen_lut<float, int16_t, kInt16LUTArraySize>(
+ [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
+ 1.0f, data->op_data.one_over_one_plus_x_lut);
+ data->op_data.zero_point = output->params.zero_point;
+ data->op_data.scale = output->params.scale;
+ }
+
+ auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+ auto ret_val =
+ CalculateSoftmaxParams(context, input, output, params, &data->op_data);
+
+#if ESP_NN
+ if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
+ const int32_t input_width = input->dims->data[1];
+ const int32_t input_height = input->dims->data[2];
+ int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
+ input_height);
+ if (scratch_buf_size > 0) {
+ TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+ context, scratch_buf_size, &data->buffer_idx));
+ }
+ }
+#endif
+
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+ return ret_val;
+}
+
+} // namespace
+
+TfLiteRegistration Register_SOFTMAX() {
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc
index d1b0f6cb..ae26f636 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc
@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_EXP() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc
index 6dcba4d5..4b105bf6 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc
@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_EXPAND_DIMS() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc
index d8a2b09d..9f438b89 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc
@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FILL() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc
index b8be1cf0..6b2a4cc2 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc
@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace floor
TfLiteRegistration Register_FLOOR() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/floor::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
}
} // namespace micro
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc
index d11e4969..333a1eba 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc
@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FLOOR_DIV() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc
index 083bd5cb..9bb49497 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc
@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteRegistration Register_FLOOR_MOD() {
- return {/*init=*/Init,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
+ return tflite::micro::RegisterOp(Init, Prepare, Eval);
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc
index c0be3814..a083edd7 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
-
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
- TF_LITE_ENSURE_MSG(context, input->type == filter->type,
- "Hybrid models are not supported on TFLite Micro.");
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
+ case kTfLiteInt16: {
+ const int64_t* bias_data =
+ nullptr != bias ? tflite::micro::GetTensorData