diff --git a/FeatureRequest.md b/FeatureRequest.md
index e2e97fd1..9f7f29c4 100644
--- a/FeatureRequest.md
+++ b/FeatureRequest.md
@@ -11,6 +11,10 @@
____
+#### #26 Changes behaviour for "N" replacement
+
+* in case the higher digits has already increased by minium 1 - don't set the "N" to the last value, but to "0"
+* https://github.com/jomjol/AI-on-the-edge-device/issues/792
#### #25 Trigger Measurement via MQTT
diff --git a/README.md b/README.md
index fc27496a..a36a7b8d 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,15 @@ In other cases you can contact the developer via email:
~/.ssh/id_rsa_base64
+ - base64 --decode --ignore-garbage ~/.ssh/id_rsa_base64 > ~/.ssh/id_rsa
+ - chmod 600 ~/.ssh/id_rsa
+ - echo -e "Host gitlab.espressif.cn\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config
+
+before_script:
+ # Add gitlab ssh key
+ - *add_ssh_key
+ # Set git config
+ - *set_git_config
+
+.build_esp32s3: &build_esp32s3
+ - idf.py set-target esp32s3 build
+
+.build_esp32: &build_esp32
+ - idf.py set-target esp32 build
+
+build_demo:
+ stage: build
+ image: $CI_DOCKER_REGISTRY/esp32-ci-env:esp-nn
+ tags:
+ - build
+ script:
+ # Clone IDF
+ - git clone --recursive --single-branch -b release/v4.4 --reference-if-able /local_references/gitlab/ https://gitlab-ci-token:${BOT_TOKEN}@gitlab.espressif.cn:6688/espressif/esp-idf.git
+ - cd esp-idf
+ - ./install.sh
+ - . ./export.sh
+ - cd ..
+ # Build examples now
+ - cd test_app
+ # Build esp32s3
+ - *build_esp32s3
+ # Build esp32
+ - *build_esp32
+ - cd -
diff --git a/code/components/esp-nn/CMakeLists.txt b/code/components/esp-nn/CMakeLists.txt
new file mode 100644
index 00000000..da463779
--- /dev/null
+++ b/code/components/esp-nn/CMakeLists.txt
@@ -0,0 +1,48 @@
+idf_build_get_property(idf_target IDF_TARGET)
+
+set(c_srcs
+ "src/activation_functions/esp_nn_relu_ansi.c"
+ "src/basic_math/esp_nn_add_ansi.c"
+ "src/basic_math/esp_nn_mul_ansi.c"
+ "src/convolution/esp_nn_conv_ansi.c"
+ "src/convolution/esp_nn_depthwise_conv_ansi.c"
+ "src/fully_connected/esp_nn_fully_connected_ansi.c"
+ "src/softmax/esp_nn_softmax_ansi.c"
+ "src/softmax/esp_nn_softmax_opt.c"
+ "src/pooling/esp_nn_avg_pool_ansi.c"
+ "src/pooling/esp_nn_max_pool_ansi.c")
+
+if(CONFIG_IDF_TARGET_ESP32S3)
+ set(s3_srcs
+ "src/common/esp_nn_common_functions_esp32s3.S"
+ "src/common/esp_nn_multiply_by_quantized_mult_esp32s3.S"
+ "src/common/esp_nn_multiply_by_quantized_mult_ver1_esp32s3.S"
+ "src/activation_functions/esp_nn_relu_s8_esp32s3.S"
+ "src/basic_math/esp_nn_add_s8_esp32s3.S"
+ "src/basic_math/esp_nn_mul_s8_esp32s3.S"
+ "src/convolution/esp_nn_conv_esp32s3.c"
+ "src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c"
+ "src/convolution/esp_nn_conv_s16_mult8_esp32s3.S"
+ "src/convolution/esp_nn_conv_s16_mult8_1x1_esp32s3.S"
+ "src/convolution/esp_nn_conv_s16_mult4_1x1_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s16_mult1_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s16_mult1_3x3_no_pad_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s16_mult8_3x3_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s16_mult4_esp32s3.S"
+ "src/convolution/esp_nn_depthwise_conv_s16_mult8_esp32s3.S"
+ "src/fully_connected/esp_nn_fully_connected_s8_esp32s3.S"
+ "src/pooling/esp_nn_max_pool_s8_esp32s3.S"
+ "src/pooling/esp_nn_avg_pool_s8_esp32s3.S")
+endif()
+
+idf_component_register(SRCS "${c_srcs}"
+ "${s3_srcs}"
+ INCLUDE_DIRS "include" "src/common")
+
+if(CONFIG_IDF_TARGET_ESP32S3)
+ target_compile_options(${COMPONENT_LIB} PRIVATE -mlongcalls -fno-unroll-loops -O2 -Wno-unused-function)
+else()
+ target_compile_options(${COMPONENT_LIB} PRIVATE -Wno-unused-function)
+endif()
\ No newline at end of file
diff --git a/code/components/esp-nn/Kconfig.projbuild b/code/components/esp-nn/Kconfig.projbuild
new file mode 100644
index 00000000..3bd683fc
--- /dev/null
+++ b/code/components/esp-nn/Kconfig.projbuild
@@ -0,0 +1,29 @@
+menu "ESP-NN"
+
+choice NN_OPTIMIZATIONS
+ bool "Optimization for nn functions"
+ default NN_OPTIMIZED
+ help
+ Use ANSI-C versions for verification and debug purpose.
+ Optimisations are automatically picked up for a chipset.
+ For ESP32-S3, assembly Optimisations are selected.
+ For ESP32, just the ANSI C versions are selected for now.
+
+config NN_ANSI_C
+ bool "ANSI C"
+ help
+ ANSI C versions for verification and debug purposes.
+config NN_OPTIMIZED
+ bool "Optimized versions"
+ help
+ Optimisations are automatically picked up for a chipset.
+ For ESP32-S3, assembly Optimisations are selected.
+ For ESP32, just the ANSI C versions are selected for now.
+endchoice
+
+config NN_OPTIMIZATIONS
+ int
+ default 0 if NN_ANSI_C
+ default 1 if NN_OPTIMIZED
+
+endmenu
diff --git a/code/components/esp-nn/LICENSE b/code/components/esp-nn/LICENSE
new file mode 100644
index 00000000..d6456956
--- /dev/null
+++ b/code/components/esp-nn/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/code/components/esp-nn/README.md b/code/components/esp-nn/README.md
new file mode 100644
index 00000000..0d988c55
--- /dev/null
+++ b/code/components/esp-nn/README.md
@@ -0,0 +1,54 @@
+# ESP-NN
+
+The library contains optimised NN (Neural Network) functions for various Espressif chipsets.
+
+* Supported platforms:
+ * TensorFlow Lite Micro (TFLite Micro). Repo can be found [here](https://github.com/espressif/tflite-micro-esp-examples)
+
+* Supported ESP chipsets include:
+ * ESP32-S3 (Assembly versions optimised to benefit from vector instructions of ESP32-S3)
+ * ESP32 (ANSI C versions)
+
+## Performance
+
+### Kernelwise performance for s8 versions:
+
+ * Kernelwise performance on ESP32-S3 chip
+ * Numbers are ticks taken for kernel to execute
+ * Chip config: 240MHz, SPI: QPI 80MHz, Data cache: 64KB
+
+ | Function | ANSI C | ESP32-S3 Opt | Opt Ratio | Data info | Memory |
+ | ----------------| --------|---------|---------|-------------|-----------|
+ | elementwise_add | 320397 | 87119 | 3.68 | size = 1615 | External |
+ | elementwise_mul | 125958 | 44239 | 2.85 | size = 1615 | External |
+ | convolution | 4663012 | 428675 | 10.88 | input(10,10), filter(64x1x1x64) | External |
+ | convolution | 301014 | 32433 | 9.28 | input(8,8), filter(16x1x1x16) | External |
+ | convolution | 2115418 | 1020923 | 2.07 | input(10,10), filter(64x3x3x3) | External |
+ | depthwise conv | 1190062 | 203278 | 5.85 | input (18, 18), pad(0,0), stride(1,1) filter: 1x3x3x16 | External |
+ | depthwise conv | 837072 | 182335 | 4.59 | input (12, 12), pad(1,1), stride(1,1) filter: 8x5x5x4 | External |
+ | max pool | 485714 | 76747 | 6.33 | input(16,16), filter (1x3x3x16) | Internal |
+ | avg pool | 541462 | 160580 | 3.37 | input(16,16), filter (1x3x3x16) | Internal |
+ | fully connected | 15853 | 9547 | 1.66 | len: 265, ch = 3 | Internal |
+ | prelu (relu6) | 19472 | 2734 | 7.12 | size, 1615 | Internal |
+
+
+## Configuration
+
+ * To configure, please use `idf.py menuconfig` and under `ESP-NN` select `NN_OPTIMIZATIONS`
+ * There are two options presented:
+ * Optimized versions
+ * ANSI C
+
+ * Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for ESP32, ANSI-C versions are selected by default.
+ * For debugging purposes, you may want to select `ANSI C`
+
+
+## Contributing
+
+If you encounter an issue with ESP-NN, or wish to submit a feature request, please use the Issues section on the Github.
+
+For general questions related to this library, please use the esp32.com forum.
+
+## Copyrights and License
+
+All original source code in this repository is Copyright (C) 2020-2021 Espressif Systems. This source code is licensed under the Apache License 2.0 as described in the file LICENSE.
diff --git a/code/components/esp-nn/include/esp_nn.h b/code/components/esp-nn/include/esp_nn.h
new file mode 100644
index 00000000..a4081871
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn.h
@@ -0,0 +1,46 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if defined(CONFIG_NN_OPTIMIZED)
+#ifdef CONFIG_IDF_TARGET_ESP32S3
+#define ARCH_ESP32_S3 1
+#endif
+#ifdef CONFIG_IDF_TARGET_ESP32
+#define ARCH_ESP32 1
+#endif
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* reference kernels included by default */
+#include "esp_nn_ansi_headers.h"
+
+#if defined(CONFIG_NN_OPTIMIZED)
+#ifdef ARCH_ESP32_S3
+#include "esp_nn_esp32s3.h"
+#endif
+#ifdef ARCH_ESP32
+#include "esp_nn_esp32.h"
+#endif
+#else
+#include "esp_nn_ansi_c.h"
+#endif
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/code/components/esp-nn/include/esp_nn_ansi_c.h b/code/components/esp-nn/include/esp_nn_ansi_c.h
new file mode 100644
index 00000000..1612228c
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn_ansi_c.h
@@ -0,0 +1,46 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * @file Header definitions to include for ANSI C versions.
+ * These are just typedefs to pick up ANSI versions.
+ */
+
+#pragma once
+
+#include "esp_nn_ansi_headers.h"
+
+#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
+#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
+
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+
+#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+
+#define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
+
+#define esp_nn_avg_pool_s8 esp_nn_avg_pool_s8_ansi
+#define esp_nn_max_pool_s8 esp_nn_max_pool_s8_ansi
+
+#define esp_nn_fully_connected_s8 esp_nn_fully_connected_s8_ansi
+
+#define esp_nn_get_softmax_scratch_size esp_nn_get_softmax_scratch_size_ansi
+#define esp_nn_set_softmax_scratch_buf esp_nn_set_softmax_scratch_buf_ansi
+#define esp_nn_softmax_s8 esp_nn_softmax_s8_ansi
diff --git a/code/components/esp-nn/include/esp_nn_ansi_headers.h b/code/components/esp-nn/include/esp_nn_ansi_headers.h
new file mode 100644
index 00000000..f871537d
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn_ansi_headers.h
@@ -0,0 +1,283 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+/**
+ * @file Header definitions to include for esp_nn reference functions
+ */
+
+#include
+
+/************************** Basic math functions ****************************/
+
+/**
+ * @brief elementwise addition
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ *
+ * shift values are expected to be <= 0
+ */
+void esp_nn_add_elementwise_s8_ansi(const int8_t *input1_data,
+ const int8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ const int32_t input1_mult,
+ const int32_t input2_mult,
+ const int32_t input1_shift,
+ const int32_t input2_shift,
+ const int32_t left_shift,
+ int8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size);
+/**
+ * @brief elementwise multiplication
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ *
+ * output shift is expected to be <= 0
+ */
+void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
+ const int8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ int8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size);
+
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief depthwise convolution per channel
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * Version used in tflite is per channel.
+ * This version follows the same footsprints.
+ * Meaning, it has per out_channel shift and multiplier for
+ * requantization
+ *
+ * optimization notes: Though input_offset is int32 type,
+ * offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+/**
+ * @brief 2d-convolution channelwise
+ *
+ * @note operation: result += (input + offset) * filter
+ *
+ * inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_ansi(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_ch,
+ const uint16_t out_ch,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht);
+void esp_nn_set_conv_scratch_buf_ansi(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t ch_mult,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht);
+void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);
+
+/************************** Activation functions *****************************/
+
+/**
+ * @brief relu6
+ *
+ * @note inout: int8_t
+ */
+void esp_nn_relu6_s8_ansi(int8_t *data, uint16_t size);
+
+/************************** Pooling functions *****************************/
+
+
+/**
+ * @brief max_pool
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_max_pool_s8_ansi(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ int8_t *output,
+ const uint16_t output_wd,
+ const uint16_t output_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const uint16_t channels);
+
+/**
+ * @brief avg_pool
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_avg_pool_s8_ansi(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ int8_t *output,
+ const uint16_t output_wd,
+ const uint16_t output_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const uint16_t channels);
+
+
+/************************** Fully connected functions ***********************/
+
+/**
+ * @brief fully connected
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_fully_connected_s8_ansi(const int8_t *input_data,
+ const int32_t input_offset,
+ const uint16_t row_len,
+ const int8_t *filter_data,
+ const int32_t filter_offset,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t out_shift,
+ const int32_t out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+/**
+ * @brief Get scratch buffer size needed by softmax function
+ *
+ * @param width
+ * @param height
+ * @return size in bytes
+ *
+ * @note buffer must be 4 byte aligned
+ */
+int32_t esp_nn_get_softmax_scratch_size_ansi(const int32_t width, const int32_t height);
+
+/* ANSI C function to be hooked up when optimised version needed */
+int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t height);
+
+/**
+ * @brief Set scratch buffer to be used by softmax function
+ *
+ * @param buffer this can be NULL if one needs to unset it
+ * must be aligned to 4 bytes
+ */
+void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);
+
+/* ANSI C function to be hooked up when optimised version needed */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
+
+/**
+ * @brief reference softmax function
+ *
+ * @note inputs type: int8_t, output: int8_t
+ */
+void esp_nn_softmax_s8_ansi(const int8_t *input_data,
+ const int32_t height,
+ const int32_t width,
+ const int32_t mult,
+ const int32_t shift,
+ const int32_t diff_min,
+ int8_t *output_data);
+
+/**
+ * @brief optimised version of softmax function
+ *
+ * @note the function uses extra buffer (4 * width bytes)
+ * hence, scratch buffers must be set before calling this.
+ */
+void esp_nn_softmax_s8_opt(const int8_t *input_data,
+ const int32_t height,
+ const int32_t width,
+ const int32_t mult,
+ const int32_t shift,
+ const int32_t diff_min,
+ int8_t *output_data);
diff --git a/code/components/esp-nn/include/esp_nn_esp32.h b/code/components/esp-nn/include/esp_nn_esp32.h
new file mode 100644
index 00000000..03fd8216
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn_esp32.h
@@ -0,0 +1,48 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * @file Header definitions to include for esp_nn optimized functions for
+ * the ESP32 platform.
+ * We are hooking up just the C versions for now.
+ * The file hence is exactly same as `esp_nn_ansi_c.h`
+ */
+
+#pragma once
+
+#include "esp_nn_ansi_headers.h"
+
+#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
+#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
+
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+
+#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+
+#define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
+
+#define esp_nn_avg_pool_s8 esp_nn_avg_pool_s8_ansi
+#define esp_nn_max_pool_s8 esp_nn_max_pool_s8_ansi
+
+#define esp_nn_fully_connected_s8 esp_nn_fully_connected_s8_ansi
+
+#define esp_nn_get_softmax_scratch_size esp_nn_get_softmax_scratch_size_opt
+#define esp_nn_set_softmax_scratch_buf esp_nn_set_softmax_scratch_buf_opt
+#define esp_nn_softmax_s8 esp_nn_softmax_s8_opt
diff --git a/code/components/esp-nn/include/esp_nn_esp32s3.h b/code/components/esp-nn/include/esp_nn_esp32s3.h
new file mode 100644
index 00000000..58b544e4
--- /dev/null
+++ b/code/components/esp-nn/include/esp_nn_esp32s3.h
@@ -0,0 +1,261 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * @file Header definitions to include for esp_nn optimized functions for
+ * the ESP32-S3 platform
+ */
+
+#pragma once
+
+#include
+#include "esp_nn_ansi_headers.h"
+
+/************************** Basic math functions *****************************/
+
+
+/**
+ * @brief elementwise addition
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ *
+ * shift values are expected to be <= 0
+ */
+void esp_nn_add_elementwise_s8_esp32s3(const int8_t *input1_data,
+ const int8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ const int32_t input1_mult,
+ const int32_t input2_mult,
+ const int32_t input1_shift,
+ const int32_t input2_shift,
+ const int32_t left_shift,
+ int8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size);
+
+/**
+ * @brief elementwise multiplication
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ *
+ * output shift is expected to be <= 0
+ */
+void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data,
+ const int8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ int8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size);
+
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief depthwise convolution per channel
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * Version used in tflite is per channel.
+ * This version follows the same footsprints.
+ * Meaning, it has per out_channel shift and multiplier for
+ * requantization
+ *
+ * optimization notes: Though input_offset is int32 type,
+ * offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+/**
+ * @brief 2d - convolution channelwise
+ *
+ * @note operation: result += (input + offset) * filter
+ *
+ * inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_esp32s3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_ch,
+ const uint16_t out_ch,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht);
+void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t ch_mult,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht);
+void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
+
+/************************** Pooling functions *****************************/
+
+/**
+ * @brief max_pool
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_max_pool_s8_esp32s3(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ int8_t *output,
+ const uint16_t output_wd,
+ const uint16_t output_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const uint16_t channels);
+
+/**
+ * @brief avg_pool
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_avg_pool_s8_esp32s3(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ int8_t *output,
+ const uint16_t output_wd,
+ const uint16_t output_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const uint16_t channels);
+
+
+/************************** Fully connected functions *****************************/
+
+/**
+ * @brief fully connected
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ *
+ * Current version works only on aligned input.
+ * row_len and channels should both be multiple of 8.
+ */
+void esp_nn_fully_connected_s8_esp32s3(const int8_t *input_data,
+ const int32_t input_offset,
+ const uint16_t row_len,
+ const int8_t *filter_data,
+ const int32_t filter_offset,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t out_shift,
+ const int32_t out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+/**
+ * @brief relu6
+ *
+ * @note inout: int8_t
+ */
+void esp_nn_relu6_s8_esp32s3(int8_t *data, uint16_t size);
+
+/********************** function defines ***************************/
+
+#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_esp32s3
+#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_esp32s3
+
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_esp32s3
+
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_esp32s3
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_esp32s3
+
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_esp32s3
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_esp32s3
+
+#define esp_nn_conv_s8 esp_nn_conv_s8_esp32s3
+
+#define esp_nn_relu6_s8 esp_nn_relu6_s8_esp32s3
+
+#define esp_nn_avg_pool_s8 esp_nn_avg_pool_s8_esp32s3
+#define esp_nn_max_pool_s8 esp_nn_max_pool_s8_esp32s3
+
+#define esp_nn_fully_connected_s8 esp_nn_fully_connected_s8_esp32s3
+
+#define esp_nn_get_softmax_scratch_size esp_nn_get_softmax_scratch_size_opt
+#define esp_nn_set_softmax_scratch_buf esp_nn_set_softmax_scratch_buf_opt
+#define esp_nn_softmax_s8 esp_nn_softmax_s8_opt
diff --git a/code/components/esp-nn/src/activation_functions/esp_nn_relu_ansi.c b/code/components/esp-nn/src/activation_functions/esp_nn_relu_ansi.c
new file mode 100644
index 00000000..1d4c3d11
--- /dev/null
+++ b/code/components/esp-nn/src/activation_functions/esp_nn_relu_ansi.c
@@ -0,0 +1,30 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include
+
+void esp_nn_relu6_s8_ansi(int8_t *data, uint16_t size)
+{
+ int32_t i;
+
+ for (i = 0; i < size; i++) {
+ int32_t ip = data[i];
+
+ ip = max(ip, 0);
+ data[i] = min(ip, 6);
+ }
+}
diff --git a/code/components/esp-nn/src/basic_math/esp_nn_add_ansi.c b/code/components/esp-nn/src/basic_math/esp_nn_add_ansi.c
new file mode 100644
index 00000000..617386cf
--- /dev/null
+++ b/code/components/esp-nn/src/basic_math/esp_nn_add_ansi.c
@@ -0,0 +1,97 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+void esp_nn_add_elementwise_u8_ansi(const uint8_t *input1_data,
+ const uint8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ const int32_t input1_mult,
+ const int32_t input2_mult,
+ const int32_t input1_shift,
+ const int32_t input2_shift,
+ const int32_t left_shift,
+ uint8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size)
+{
+ for (int i = 0; i < size; i++) {
+ int32_t tmp1 = input1_data[i] + input1_offset;
+ int32_t tmp2 = input2_data[i] + input2_offset;
+
+ tmp1 <<= left_shift;
+ tmp2 <<= left_shift;
+
+ tmp1 = esp_nn_sat_round_doubling_high_mul(tmp1, input1_mult);
+ tmp2 = esp_nn_sat_round_doubling_high_mul(tmp2, input2_mult);
+
+ tmp1 = esp_nn_div_by_power_of_two(tmp1, -input1_shift);
+ tmp2 = esp_nn_div_by_power_of_two(tmp2, -input2_shift);
+
+ int32_t out = tmp1 + tmp2;
+ out = esp_nn_sat_round_doubling_high_mul(out, out_mult);
+ out = esp_nn_div_by_power_of_two(out, -out_shift);
+ out = out + out_offset;
+
+ out = max(activation_min, min(out, activation_max));
+ output[i] = (uint8_t) out;
+ }
+}
+
+void esp_nn_add_elementwise_s8_ansi(const int8_t *input1_data,
+ const int8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ const int32_t input1_mult,
+ const int32_t input2_mult,
+ const int32_t input1_shift,
+ const int32_t input2_shift,
+ const int32_t left_shift,
+ int8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size)
+{
+ for (int i = 0; i < size; i++) {
+ int32_t tmp1 = input1_data[i] + input1_offset;
+ int32_t tmp2 = input2_data[i] + input2_offset;
+
+ tmp1 <<= left_shift;
+ tmp2 <<= left_shift;
+
+ tmp1 = esp_nn_sat_round_doubling_high_mul(tmp1, input1_mult);
+ tmp2 = esp_nn_sat_round_doubling_high_mul(tmp2, input2_mult);
+
+ tmp1 = esp_nn_div_by_power_of_two(tmp1, -input1_shift);
+ tmp2 = esp_nn_div_by_power_of_two(tmp2, -input2_shift);
+
+ int32_t out = tmp1 + tmp2;
+ out = esp_nn_sat_round_doubling_high_mul(out, out_mult);
+ out = esp_nn_div_by_power_of_two(out, -out_shift);
+ out = out + out_offset;
+
+ out = max(activation_min, min(out, activation_max));
+ output[i] = (int8_t) out;
+ }
+}
diff --git a/code/components/esp-nn/src/basic_math/esp_nn_mul_ansi.c b/code/components/esp-nn/src/basic_math/esp_nn_mul_ansi.c
new file mode 100644
index 00000000..db8e8cc0
--- /dev/null
+++ b/code/components/esp-nn/src/basic_math/esp_nn_mul_ansi.c
@@ -0,0 +1,42 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
+ const int8_t *input2_data,
+ const int32_t input1_offset,
+ const int32_t input2_offset,
+ int8_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const int32_t size)
+{
+ for (int i = 0; i < size; i++) {
+ int32_t tmp1 = input1_data[i] + input1_offset;
+ int32_t tmp2 = input2_data[i] + input2_offset;
+
+ int32_t out = tmp1 * tmp2;
+ out = esp_nn_multiply_by_quantized_mult(out, out_mult, out_shift);
+ out = out + out_offset;
+
+ out = max(activation_min, min(out, activation_max));
+ output[i] = (int8_t) out;
+ }
+}
diff --git a/code/components/esp-nn/src/common/common_functions.h b/code/components/esp-nn/src/common/common_functions.h
new file mode 100644
index 00000000..9a5f0dcc
--- /dev/null
+++ b/code/components/esp-nn/src/common/common_functions.h
@@ -0,0 +1,218 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+
+/**
+ * c99 standard still doesn't strictly inline functions
+ * We need to use attribute as well to do this.
+ */
+#define __NN_FORCE_INLINE__ __attribute((always_inline)) static inline
+
+/* min/max macros */
+#ifndef max
+#define max(a, b) ({ \
+ __typeof__ (a) _a = (a); \
+ __typeof__ (b) _b = (b); \
+ _a > _b ? _a : _b; \
+})
+
+#define min(a, b) ({ \
+ __typeof__ (a) _a = (a); \
+ __typeof__ (b) _b = (b); \
+ _a < _b ? _a : _b; \
+})
+#endif
+
+__NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
+{
+ __asm__ volatile("nsau %0, %0" : "+r" (in));
+ return in;
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
+{
+ int32_t sign = (int32_t) (val64 >> 63);
+ int32_t to_add = sign & ((1ul << 31) - 1);
+ return (int32_t) ((int64_t) (val64 + to_add) >> 31);
+}
+
+/**
+ * Signed saturate a 32 bit value to 8 bits keeping output in 32 bit variable.
+ */
+__NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
+{
+ __asm__ volatile("clamps %0, %0, 7" : "+a"(in));
+ return in;
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
+{
+ int32_t result;
+ int64_t in0_64 = (int64_t) in0;
+ bool overflow = (in0 == in1) && (in0 == (int32_t) INT32_MIN);
+
+ /* Nudge value */
+ int64_t nudge_val = 1 << 30;
+ if ((in0 < 0) ^ (in1 < 0)) {
+ nudge_val = 1 - nudge_val;
+ }
+
+ /* Multiply and add nudge */
+ int64_t mult = in0_64 * in1 + nudge_val;
+
+ /* Round and pickup 32 bits */
+ result = esp_nn_pick_sat_high32_of64(mult);
+
+ return overflow ? INT32_MAX : result;
+}
+
+/**
+ * fast version
+ * this will fail for values closer to INT32_MAX and INT32_MIN by `1 << (exponent - 1)`.
+ * We can afford to do this because we are at the very last stage of filter.
+ * Also it is pretty rare condition as our output is going to be 8 bit.
+ */
+__NN_FORCE_INLINE__ int32_t esp_nn_div_by_power_of_two_fast(int32_t val, int32_t exponent)
+{
+ int32_t to_add = (1 << (exponent - 1)) - (val < 0);
+ return (int32_t) ((val + to_add) >> exponent);
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_div_by_power_of_two(int32_t val, int32_t exponent)
+{
+ int32_t result;
+
+ const int32_t mask = (1 << exponent) - 1;
+ const int32_t remainder = val & mask;
+
+ result = val >> exponent;
+ int32_t threshold = (mask >> 1) + (result < 0);
+
+ if (remainder > threshold) {
+ result += 1;
+ }
+ return result;
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_multiply_by_quantized_mult(int32_t x, int32_t mult, int32_t shift)
+{
+ int32_t left_shift = shift > 0 ? shift : 0;
+ int32_t right_shift = shift > 0 ? 0 : -shift;
+ int32_t result = esp_nn_sat_round_doubling_high_mul(x * (1 << left_shift), mult);
+ return esp_nn_div_by_power_of_two(result, right_shift);
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_multiply_by_quantized_mult_fast(int32_t x, int32_t mult, int32_t shift)
+{
+ int32_t left_shift = max(shift, 0);
+ int32_t right_shift = left_shift - shift;
+
+ int64_t nudge_val = 1 << 30;
+ int64_t in0_64 = (int64_t) (x << left_shift);
+
+ /* Multiply and add nudge */
+ int64_t mult_64 = in0_64 * mult + nudge_val;
+ int32_t result = (int32_t) (mult_64 >> 31);
+ if (right_shift) {
+ result = esp_nn_div_by_power_of_two_fast(result, right_shift);
+ }
+ return result;
+}
+
+static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t pad_val,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht)
+{
+ /* memset with pad_val */
+ memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels * 2);
+ dst += (pad_wd + input_wd + pad_wd) * channels;
+
+ for (int i = 0; i < input_ht; i++) {
+ dst += pad_wd * channels;
+ for (int j = 0; j < input_wd * channels; j++) {
+ *dst++ = *src++;
+ }
+ dst += pad_wd * channels;
+ }
+}
+
+#if 0
+static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t pad_val,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht)
+{
+ for (int i = 0; i < input_ht; i++) {
+ for (int j = 0; j < input_wd * channels; j++) {
+ *dst++ = *src++;
+ }
+ memset(dst, pad_val, pad_wd * channels);
+ dst += pad_wd * channels;
+ }
+ /* pad end `pad_ht` lines at end */
+ memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+}
+#endif
+
+/**
+ * @brief convert 8 bit input data to 16 bit
+ *
+ * @param src int8_t source data
+ * @param dst int16_t dst data
+ * @param size length of data
+ * @param offset offset to be added to src data. Range: [-128, 127]
+ */
+__NN_FORCE_INLINE__ void esp_nn_s8_to_s16_with_offset(const int8_t *src, int16_t *dst,
+ const int size, const int32_t offset)
+{
+ int i = 0;
+ for (; i < size; i += 2) {
+ dst[i + 0] = src[i + 0] + offset;
+ dst[i + 1] = src[i + 1] + offset;
+ }
+ if(i < size) {
+ dst[i] = src[i] + offset;
+ }
+}
+
+/**
+ * @brief convert 8 bit input data to 16 bit
+ *
+ * @param src int8_t source data
+ * @param dst int16_t dst data
+ * @param size length of data
+ */
+__NN_FORCE_INLINE__ void esp_nn_s8_to_s16(const int8_t *src, int16_t *dst, const int size)
+{
+ int i = 0;
+ for (; i < size; i += 2) {
+ dst[i + 0] = src[i + 0];
+ dst[i + 1] = src[i + 1];
+ }
+ if(i < size) {
+ dst[i] = src[i];
+ }
+}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c b/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
new file mode 100644
index 00000000..d04f78e1
--- /dev/null
+++ b/code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
@@ -0,0 +1,175 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_ch,
+ const uint16_t out_ch,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht)
+{
+ return 0;
+}
+
+void esp_nn_set_conv_scratch_buf_ansi(const void *buf)
+{
+
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dialation width = 1
+ */
+void esp_nn_conv_u8_ansi(const uint8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t filter_offset,
+ const int32_t *bias,
+ uint8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t out_shift,
+ const int32_t out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+ const int16_t base_y = (out_y * stride_ht) - pad_ht;
+ for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+ const int16_t base_x = (out_x * stride_wd) - pad_wd;
+ for (int out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {//channel_loop
+ int32_t result = 0;
+
+ /* Select filter so as the point doesn't lie outside block */
+ int filter_y_start = max(0, -base_y);
+ int filter_x_start = max(0, -base_x);
+ int filter_y_end = min(filter_ht, input_ht - base_y);
+ int filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ const int32_t idx_y = base_y + filter_y_idx;
+ for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t idx_x = base_x + filter_x_idx;
+ for (int in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+ int32_t input_index = (idx_y * input_wd + idx_x) * in_channels + in_ch_idx;
+ int32_t filter_index = ((out_ch_idx * filter_ht + filter_y_idx)
+ * filter_wd + filter_x_idx) * in_channels
+ + in_ch_idx;
+ int32_t input_val = input_data[input_index] + input_offset;
+ int32_t filter_val = filter_data[filter_index] + filter_offset;
+ result += input_val * filter_val;
+ }
+ }
+ }
+ if (bias) {
+ result += bias[out_ch_idx];
+ }
+ result = esp_nn_multiply_by_quantized_mult(result, out_mult, out_shift);
+ result += out_offset;
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+
+ int out_index = (out_y * out_wd + out_x) * out_channels + out_ch_idx;
+ out_data[out_index] = (uint8_t) result;
+ }
+ }
+ }
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dialation width = 1
+ */
+void esp_nn_conv_s8_ansi(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
+
+ for (out_y = 0; out_y < out_ht; out_y++) {
+ for (out_x = 0; out_x < out_wd; out_x++) {
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+
+ const int32_t base_y = stride_ht * out_y - pad_ht;
+ const int32_t base_x = stride_wd * out_x - pad_wd;
+
+ const int32_t filter_y_start = max(0, -base_y);
+ const int32_t filter_x_start = max(0, -base_x);
+
+ const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+ const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t in_row = base_y + filter_y_idx;
+ const int32_t in_col = base_x + filter_x_idx;
+ int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
+ int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
+ (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+ for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+ conv_out +=
+ (input_data[input_base_offset + in_ch_idx] + input_offset) *
+ filter_data[filter_base_offset + in_ch_idx];
+ }
+ }
+ }
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c b/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
new file mode 100644
index 00000000..ea8fdfa5
--- /dev/null
+++ b/code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
@@ -0,0 +1,436 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include
+
+static int16_t *scratch_buffer = NULL;
+
+extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const int16_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ void *buffer /* scratch buffer */);
+
+extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int16_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ void *buffer /* scratch buffer */);
+
+extern void esp_nn_conv_s16_mult8_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int16_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int16_t *dst,
+ const int size, const int32_t offset);
+
+extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
+
+static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
+
+ for (out_y = 0; out_y < out_ht; out_y++) {
+ for (out_x = 0; out_x < out_wd; out_x++) {
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+
+ const int32_t base_y = stride_ht * out_y - pad_ht;
+ const int32_t base_x = stride_wd * out_x - pad_wd;
+
+ const int32_t filter_y_start = max(0, -base_y);
+ const int32_t filter_x_start = max(0, -base_x);
+
+ const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+ const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t in_row = base_y + filter_y_idx;
+ const int32_t in_col = base_x + filter_x_idx;
+ int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
+ int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
+ (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+ for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+ conv_out +=
+ (input_data[input_base_offset + in_ch_idx] + input_offset) *
+ filter_data[filter_base_offset + in_ch_idx];
+ }
+ }
+ }
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
+
+static void esp_nn_conv_s8_pad_valid(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
+
+ for (out_y = 0; out_y < out_ht; out_y++) {
+ for (out_x = 0; out_x < out_wd; out_x++) {
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+
+ const int32_t base_y = stride_ht * out_y;
+ const int32_t base_x = stride_wd * out_x;
+
+ for (filter_y_idx = 0; filter_y_idx < filter_ht; filter_y_idx++) {
+ for (filter_x_idx = 0; filter_x_idx < filter_wd; filter_x_idx++) {
+ const int32_t in_row = base_y + filter_y_idx;
+ const int32_t in_col = base_x + filter_x_idx;
+ int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
+ int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
+ (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+ const int8_t *input_data_ptr = input_data + input_base_offset;
+ const int8_t *filter_data_ptr = filter_data + filter_base_offset;
+ for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+ conv_out += (*input_data_ptr++ + input_offset) * *filter_data_ptr++;
+ }
+ }
+ }
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
+
+static void esp_nn_conv_s8_pad_valid_3x3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_channels,
+ const int32_t input_offset,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
+
+ for (out_y = 0; out_y < out_ht; out_y++) {
+ for (out_x = 0; out_x < out_wd; out_x++) {
+ const int32_t base_y = stride_ht * out_y;
+ const int32_t base_x = stride_wd * out_x;
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+ for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
+ for (filter_x_idx = 0; filter_x_idx < 3; filter_x_idx++) {
+ const int32_t in_row = base_y + filter_y_idx;
+ const int32_t in_col = base_x + filter_x_idx;
+ int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
+ int32_t filter_base_offset = out_ch_idx * in_channels * 3 * 3 +
+ (filter_y_idx * 3 + filter_x_idx) * in_channels;
+ const int8_t *input_data_ptr = input_data + input_base_offset;
+ const int8_t *filter_data_ptr = filter_data + filter_base_offset;
+ for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+ conv_out += (*input_data_ptr++ + input_offset) * *filter_data_ptr++;
+ }
+ }
+ }
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
+
+static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const int32_t input_offset,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int32_t out_ch_idx, out_y, out_x, filter_y_idx;
+
+ /* use scratch_buffer to pre-compute offset factor */
+ int16_t *filter_sum = (int16_t *) scratch_buffer;
+ const int8_t *filter_ptr = filter_data;
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int16_t sum_val = 0;
+ for (int i = 0; i < 9; i++) {
+ sum_val += *filter_ptr++;
+ sum_val += *filter_ptr++;
+ sum_val += *filter_ptr++;
+ }
+ *filter_sum++ = sum_val;
+ }
+
+ for (out_y = 0; out_y < out_ht; out_y++) {
+ for (out_x = 0; out_x < out_wd; out_x++) {
+ const int8_t *filter_data_ptr = filter_data;
+ const int32_t base_y = stride_ht * out_y;
+ const int32_t base_x = stride_wd * out_x;
+ const int8_t *input_base_ptr = input_data + (base_y * input_wd + base_x) * 3;
+ int16_t *filter_sum = (int16_t *) scratch_buffer;
+ for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+ int32_t conv_out = 0;
+
+ for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
+ const int8_t *input_data_ptr = input_base_ptr + (filter_y_idx * input_wd) * 3;
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
+ }
+
+ conv_out += *filter_sum++ * input_offset;
+
+ if (bias) {
+ conv_out += bias[out_ch_idx];
+ }
+ conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ conv_out += out_offset;
+ conv_out = max(conv_out, activation_min);
+ conv_out = min(conv_out, activation_max);
+ *out_data++ = (int8_t) conv_out;
+ }
+ }
+ }
+}
+
+int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t in_ch,
+ const uint16_t out_ch,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht)
+{
+ int filter_size = filter_wd * filter_ht * in_ch * out_ch;
+ int input_size = input_wd * input_ht * in_ch;
+ int transpose_buf_size = 8 * in_ch; /* to store intermediate data */
+ int align_buf_size = 32; /* extra buffer for alignment */
+ return 2 * (filter_size + input_size + transpose_buf_size) + align_buf_size;
+}
+
+void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
+{
+ scratch_buffer = (int16_t *) buf;
+}
+
+void esp_nn_conv_s8_esp32s3(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int filter_size = filter_wd * filter_ht * channels * out_channels;
+ int input_size = input_wd * input_ht * channels;
+ int align_len = 16 - (filter_size & 15);
+ int16_t *filter_data16 = scratch_buffer;
+ int16_t *input_data16 = scratch_buffer + filter_size + align_len;
+
+ if (scratch_buffer == NULL) {
+ printf("esp_nn_conv error! scratch_buffer not set!\n");
+ return;
+ }
+
+ if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
+ pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
+ int scratch_offset = (int) (filter_data16 + filter_size);
+ void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
+ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
+ esp_nn_conv_s16_mult8_1x1_esp32s3(
+ input, input_wd, input_ht, channels, input_offset, filter_data16,
+ bias, out_data, out_wd, out_ht, out_channels, out_offset,
+ out_shift, out_mult, activation_min, activation_max, scratch_buf);
+ } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
+ (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */
+ pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
+ int scratch_offset = (int) (input_data16 + input_size);
+ void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
+ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
+ esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
+ esp_nn_conv_s16_mult4_1x1_esp32s3(
+ input_data16, input_wd, input_ht, channels, filter_data16,
+ bias, out_data, out_wd, out_ht, out_channels, out_offset,
+ out_shift, out_mult, activation_min, activation_max, scratch_buf);
+ } else if (channels % 8 == 0) {
+ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
+ esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
+ esp_nn_conv_s16_mult8_esp32s3(
+ input_data16, input_wd, input_ht, channels, pad_wd, pad_ht,
+ stride_wd, stride_ht, filter_data16, filter_wd, filter_ht, bias,
+ out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ } else if (pad_wd == 0 && pad_ht == 0) {
+ if (filter_wd == 3 && filter_ht == 3 && channels == 3) {
+ esp_nn_conv_s8_pad_valid_ch3_3x3(input, input_wd, input_ht, input_offset,
+ stride_wd, stride_ht, filter_data, bias,
+ out_data, out_wd, out_ht, out_channels, out_offset,
+ out_shift, out_mult, activation_min, activation_max);
+ } else {
+ esp_nn_conv_s8_pad_valid(input, input_wd, input_ht, channels, input_offset,
+ stride_wd, stride_ht, filter_data, filter_wd, filter_ht, bias,
+ out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ }
+ } else {
+ /* Basic unrolled version */
+ esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht,
+ filter_data, filter_wd, filter_ht, bias,
+ out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ }
+}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
new file mode 100644
index 00000000..9cac6cef
--- /dev/null
+++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
@@ -0,0 +1,97 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t ch_mult,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht)
+{
+ return 0;
+}
+
+void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf)
+{
+
+}
+
+void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int out_idx = 0;
+ for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+ const int16_t base_y = (out_y * stride_ht) - pad_ht;
+ for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+ const int16_t base_x = (out_x * stride_wd) - pad_wd;
+ for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
+ for (int ch_mult_idx = 0; ch_mult_idx < ch_mult; ch_mult_idx++) {
+ int32_t result = 0;
+ const int out_ch_idx = ch_mult_idx + ch_idx * ch_mult;
+
+ /* Select filter so as the point doesn't lie outside block */
+ int filter_y_start = max(0, -base_y);
+ int filter_x_start = max(0, -base_x);
+ int filter_y_end = min(filter_ht, input_ht - base_y);
+ int filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ const int32_t idx_y = base_y + filter_y_idx;
+ for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t idx_x = base_x + filter_x_idx;
+ int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+ int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+ int32_t input_val = input_data[input_index] + input_offset;
+ int32_t filter_val = filter_data[filter_index];
+ result += input_val * filter_val;
+ }
+ }
+ if (bias) {
+ result += bias[out_ch_idx];
+ }
+ result = esp_nn_multiply_by_quantized_mult(result, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ result += out_offset;
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+
+ out_data[out_idx++] = result;
+ }
+ }
+ }
+ }
+}
diff --git a/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
new file mode 100644
index 00000000..c588c48f
--- /dev/null
+++ b/code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
@@ -0,0 +1,483 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include
+
+static int16_t *scratch_buffer = NULL;
+
+extern void esp_nn_depthwise_conv_s16_mult8_3x3_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int16_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_depthwise_conv_s16_mult1_3x3_no_pad_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int16_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_depthwise_conv_s16_mult8_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int16_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_depthwise_conv_s16_mult4_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int16_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int16_t *filter_data,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_depthwise_conv_s16_mult1_esp32s3(const int16_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int16_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max);
+
+extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
+
+extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int16_t *dst,
+ const int size, const int32_t offset);
+
+static void esp_nn_depthwise_conv_s8_unrolled(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int out_idx = 0;
+ for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+ const int16_t base_y = (out_y * stride_ht) - pad_ht;
+ for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+ const int16_t base_x = (out_x * stride_wd) - pad_wd;
+ for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
+ int ch_mult_idx = 0;
+ for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
+ int32_t result0 = 0, result1 = 0, result2 = 0, result3 = 0;
+ const int out_ch_idx = ch_mult_idx + ch_idx * ch_mult;
+
+ /* Select filter so as the point doesn't lie outside block */
+ int filter_y_start = max(0, -base_y);
+ int filter_x_start = max(0, -base_x);
+ int filter_y_end = min(filter_ht, input_ht - base_y);
+ int filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ const int32_t idx_y = base_y + filter_y_idx;
+ for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t idx_x = base_x + filter_x_idx;
+ int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+ int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+ int32_t input_val = input_data[input_index] + input_offset;
+ int32_t filter_val0 = filter_data[filter_index + 0];
+ int32_t filter_val1 = filter_data[filter_index + 1];
+ int32_t filter_val2 = filter_data[filter_index + 2];
+ int32_t filter_val3 = filter_data[filter_index + 3];
+ result0 += input_val * filter_val0;
+ result1 += input_val * filter_val1;
+ result2 += input_val * filter_val2;
+ result3 += input_val * filter_val3;
+ }
+ }
+ if (bias) {
+ result0 += bias[out_ch_idx + 0];
+ result1 += bias[out_ch_idx + 1];
+ result2 += bias[out_ch_idx + 2];
+ result3 += bias[out_ch_idx + 3];
+ }
+ result0 = esp_nn_multiply_by_quantized_mult(result0,
+ out_mult[out_ch_idx + 0], out_shift[out_ch_idx + 0]);
+ result1 = esp_nn_multiply_by_quantized_mult(result1,
+ out_mult[out_ch_idx + 1], out_shift[out_ch_idx + 1]);
+ result2 = esp_nn_multiply_by_quantized_mult(result2,
+ out_mult[out_ch_idx + 2], out_shift[out_ch_idx + 2]);
+ result3 = esp_nn_multiply_by_quantized_mult(result3,
+ out_mult[out_ch_idx + 3], out_shift[out_ch_idx + 3]);
+
+ result0 += out_offset;
+ result1 += out_offset;
+ result2 += out_offset;
+ result3 += out_offset;
+
+ result0 = max(result0, activation_min);
+ result1 = max(result1, activation_min);
+ result2 = max(result2, activation_min);
+ result3 = max(result3, activation_min);
+
+ result0 = min(result0, activation_max);
+ result1 = min(result1, activation_max);
+ result2 = min(result2, activation_max);
+ result3 = min(result3, activation_max);
+
+ out_data[out_idx++] = result0;
+ out_data[out_idx++] = result1;
+ out_data[out_idx++] = result2;
+ out_data[out_idx++] = result3;
+ }
+
+ /* left-over */
+ for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
+ int32_t result = 0;
+ const int out_ch_idx = ch_mult_idx + ch_idx * ch_mult;
+
+ /* Select filter so as the point doesn't lie outside block */
+ int filter_y_start = max(0, -base_y);
+ int filter_x_start = max(0, -base_x);
+ int filter_y_end = min(filter_ht, input_ht - base_y);
+ int filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ const int32_t idx_y = base_y + filter_y_idx;
+ for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t idx_x = base_x + filter_x_idx;
+ int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+ int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+ int32_t input_val = input_data[input_index] + input_offset;
+ int32_t filter_val = filter_data[filter_index];
+ result += input_val * filter_val;
+ }
+ }
+ if (bias) {
+ result += bias[out_ch_idx];
+ }
+ result = esp_nn_multiply_by_quantized_mult(result, out_mult[out_ch_idx], out_shift[out_ch_idx]);
+ result += out_offset;
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+
+ out_data[out_idx++] = result;
+ }
+ }
+ }
+ }
+}
+
+void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int out_idx = 0;
+ for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+ const int16_t base_y = (out_y * stride_ht) - pad_ht;
+ for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+ const int16_t base_x = (out_x * stride_wd) - pad_wd;
+ for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
+ int32_t result = 0;
+ /* Select filter so as the point doesn't lie outside block */
+ int filter_y_start = max(0, -base_y);
+ int filter_x_start = max(0, -base_x);
+ int filter_y_end = min(filter_ht, input_ht - base_y);
+ int filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+ const int32_t idx_y = base_y + filter_y_idx;
+ for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+ const int32_t idx_x = base_x + filter_x_idx;
+ int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+ int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * channels + ch_idx;
+ int32_t input_val = input_data[input_index] + input_offset;
+ int32_t filter_val = filter_data[filter_index];
+ result += input_val * filter_val;
+ }
+ }
+ if (bias) {
+ result += bias[ch_idx];
+ }
+ result = esp_nn_multiply_by_quantized_mult(result, out_mult[ch_idx], out_shift[ch_idx]);
+ result += out_offset;
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+
+ out_data[out_idx++] = result;
+ }
+ }
+ }
+}
+
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const uint16_t ch_mult,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht)
+{
+ int filter_size = filter_wd * filter_ht * channels * ch_mult;
+ int padding_used = ((filter_wd == 3) && (filter_ht == 3)) * 2;
+ int input_size = (input_wd + padding_used) * (input_ht + padding_used) * channels;
+ return 2 * (filter_size + input_size) + 16; //16 for alignment
+}
+
+void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
+{
+ scratch_buffer = (int16_t *) buf;
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dialation width = 1
+ */
+void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ const uint16_t channels,
+ const int32_t input_offset,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t ch_mult,
+ const int8_t *filter_data,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_wd,
+ const uint16_t out_ht,
+ const int32_t out_offset,
+ const int32_t *out_shift,
+ const int32_t *out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ int filter_size = filter_wd * filter_ht * channels * ch_mult;
+ int align_len = 16 - (filter_size & 15);
+ int input_size = input_wd * input_ht * channels;
+ int16_t *filter_data16 = scratch_buffer;
+ int16_t *input_data16 = scratch_buffer + filter_size + align_len;
+ if (scratch_buffer == NULL) {
+ printf("esp_nn_depthwise_conv error! scratch_buffer not set!\n");
+ return;
+ }
+
+ if ((ch_mult == 1) && (channels % 8 == 0)) {
+ if ((filter_wd == 3) && (filter_ht == 3)) {
+ if ((channels % 16 == 0) && (pad_wd == 1) && (pad_ht == 1)) {
+ /* process in 8 bits */
+ int8_t *filter_aligned = (int8_t *) scratch_buffer;
+ int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
+ memcpy(filter_aligned, filter_data, filter_size);
+ esp_nn_aligned_s8_pad_with_value(input_data, input_padded, input_wd, input_ht, channels,
+ -input_offset, pad_wd, pad_ht);
+ esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + 2 * pad_wd,
+ input_ht + 2 * pad_ht, channels, input_offset,
+ stride_wd, stride_ht, filter_aligned, bias,
+ out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ } else if ((pad_wd == 0) && (pad_ht == 0) &&
+ // because this does not handle padding offset cases yet, run just for stride (1, 1).
+ // end padding of input with `-input_offset` should solve this
+ (stride_wd == 1) && (stride_ht == 1)) {
+ /* process in 8 bits */
+ int8_t *filter_aligned = (int8_t *) scratch_buffer;
+ memcpy(filter_aligned, filter_data, filter_size);
+ esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_data, input_wd, input_ht, channels, input_offset,
+ stride_wd, stride_ht, filter_aligned,
+ bias, out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ } else { /* (channels % 8) == 0 && pad_wd == 1 && pad_ht == 1 */
+ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
+ esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
+ esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels,
+ pad_wd, pad_ht, stride_wd, stride_ht, filter_data16,
+ bias, out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ }
+ } else { // all other ch_mult == 1, `channels % 8 == 0`
+ esp_nn_depthwise_conv_s8_ch_mult1(input_data, input_wd, input_ht, channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht,
+ filter_data, filter_wd, filter_ht,
+ bias, out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ }
+ } else if (ch_mult % 8 == 0) {
+ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
+ esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
+ if (filter_wd == 3 && filter_ht == 3) {
+ esp_nn_depthwise_conv_s16_mult8_3x3_esp32s3(input_data16, input_wd, input_ht, channels,
+ pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
+ filter_data16, bias,
+ out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ } else {
+ esp_nn_depthwise_conv_s16_mult8_esp32s3(input_data16, input_wd, input_ht, channels,
+ pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
+ filter_data16, filter_wd, filter_ht, bias,
+ out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ }
+ } else if (ch_mult % 4 == 0) {
+ esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
+ esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
+ esp_nn_depthwise_conv_s16_mult4_esp32s3(input_data16, input_wd, input_ht, channels,
+ pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
+ filter_data16, filter_wd, filter_ht, bias,
+ out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ } else {
+ esp_nn_depthwise_conv_s8_unrolled(input_data, input_wd, input_ht, channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
+ filter_data, filter_wd, filter_ht,
+ bias, out_data, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+ }
+}
diff --git a/code/components/esp-nn/src/fully_connected/esp_nn_fully_connected_ansi.c b/code/components/esp-nn/src/fully_connected/esp_nn_fully_connected_ansi.c
new file mode 100644
index 00000000..6d800bc5
--- /dev/null
+++ b/code/components/esp-nn/src/fully_connected/esp_nn_fully_connected_ansi.c
@@ -0,0 +1,50 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+void esp_nn_fully_connected_s8_ansi(const int8_t *input_data,
+ const int32_t input_offset,
+ const uint16_t row_len,
+ const int8_t *filter_data,
+ const int32_t filter_offset,
+ const int32_t *bias,
+ int8_t *out_data,
+ const uint16_t out_channels,
+ const int32_t out_offset,
+ const int32_t out_shift,
+ const int32_t out_mult,
+ const int32_t activation_min,
+ const int32_t activation_max)
+{
+ for (int32_t out_c = 0; out_c < out_channels; ++out_c) {
+ int32_t result = 0;
+ for (int32_t data_idx = 0; data_idx < row_len; data_idx++) {
+ int32_t filter_index = row_len * out_c + data_idx;
+ int32_t input_val = input_data[data_idx];
+ int32_t filter_val = filter_data[filter_index];
+ result += (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ if (bias) {
+ result += bias[out_c];
+ }
+ result = esp_nn_multiply_by_quantized_mult(result, out_mult, out_shift);
+ result += out_offset;
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+ out_data[out_c] = (int8_t) result;
+ }
+}
diff --git a/code/components/esp-nn/src/pooling/esp_nn_avg_pool_ansi.c b/code/components/esp-nn/src/pooling/esp_nn_avg_pool_ansi.c
new file mode 100644
index 00000000..03846aa0
--- /dev/null
+++ b/code/components/esp-nn/src/pooling/esp_nn_avg_pool_ansi.c
@@ -0,0 +1,72 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+void esp_nn_avg_pool_s8_ansi(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ int8_t *output,
+ const uint16_t output_wd,
+ const uint16_t output_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const uint16_t channels)
+{
+ int32_t base_y = -pad_ht;
+ for (int32_t out_y = 0; out_y < output_ht; out_y++, base_y += stride_ht) {
+ int32_t base_x = -pad_wd;
+ for (int32_t out_x = 0; out_x < output_wd; out_x++, base_x += stride_wd) {
+ for (int32_t ch_idx = 0; ch_idx < channels; ch_idx++) {
+ int32_t result = 0;
+ int32_t filter_cnt = 0;
+ /* Make sure filter does not cross the input box */
+ int32_t filter_y_start = max(0, -base_y);
+ int32_t filter_x_start = max(0, -base_x);
+
+ int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+ int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int32_t filter_y = filter_y_start; filter_y < filter_y_end; filter_y++) {
+ for (int32_t filter_x = filter_x_start; filter_x < filter_x_end; filter_x++) {
+ int32_t in_x_idx = base_x + filter_x;
+ int32_t in_y_idx = base_y + filter_y;
+ int32_t input_index = (in_y_idx * input_wd + in_x_idx) * channels + ch_idx;
+ result += input[input_index];
+ filter_cnt++;
+ }
+ }
+
+ /* Rounded average */
+ result = result > 0 ? (result + filter_cnt / 2) / filter_cnt
+ : (result - filter_cnt / 2) / filter_cnt;
+
+ /* Activation function */
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+
+ int32_t output_index = (out_y * output_wd + out_x) * channels + ch_idx;
+ output[output_index] = (int8_t) result;
+ }
+ }
+ }
+}
diff --git a/code/components/esp-nn/src/pooling/esp_nn_max_pool_ansi.c b/code/components/esp-nn/src/pooling/esp_nn_max_pool_ansi.c
new file mode 100644
index 00000000..4ca5c42d
--- /dev/null
+++ b/code/components/esp-nn/src/pooling/esp_nn_max_pool_ansi.c
@@ -0,0 +1,66 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+
+void esp_nn_max_pool_s8_ansi(const int8_t *input,
+ const uint16_t input_wd,
+ const uint16_t input_ht,
+ int8_t *output,
+ const uint16_t output_wd,
+ const uint16_t output_ht,
+ const uint16_t stride_wd,
+ const uint16_t stride_ht,
+ const uint16_t filter_wd,
+ const uint16_t filter_ht,
+ const uint16_t pad_wd,
+ const uint16_t pad_ht,
+ const int32_t activation_min,
+ const int32_t activation_max,
+ const uint16_t channels)
+{
+ int32_t base_y = -pad_ht;
+ for (int32_t out_y = 0; out_y < output_ht; out_y++, base_y += stride_ht) {
+ int32_t base_x = -pad_wd;
+ for (int32_t out_x = 0; out_x < output_wd; out_x++, base_x += stride_wd) {
+ /* Make sure filter does not cross the input box */
+ int32_t filter_y_start = max(0, -base_y);
+ int32_t filter_x_start = max(0, -base_x);
+ int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+ int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+ for (int32_t ch_idx = 0; ch_idx < channels; ch_idx++) {
+ int8_t result = INT8_MIN;
+
+ for (int32_t filter_y = filter_y_start; filter_y < filter_y_end; filter_y++) {
+ for (int32_t filter_x = filter_x_start; filter_x < filter_x_end; filter_x++) {
+ int32_t in_x_idx = base_x + filter_x;
+ int32_t in_y_idx = base_y + filter_y;
+ int32_t input_index = (in_y_idx * input_wd + in_x_idx) * channels + ch_idx;
+ result = max(input[input_index], result);
+ }
+ }
+
+ /* Activation function */
+ result = max(result, activation_min);
+ result = min(result, activation_max);
+
+ int32_t output_index = (out_y * output_wd + out_x) * channels + ch_idx;
+ output[output_index] = result;
+ }
+ }
+ }
+}
diff --git a/code/components/esp-nn/src/softmax/esp_nn_softmax_ansi.c b/code/components/esp-nn/src/softmax/esp_nn_softmax_ansi.c
new file mode 100644
index 00000000..d71a8616
--- /dev/null
+++ b/code/components/esp-nn/src/softmax/esp_nn_softmax_ansi.c
@@ -0,0 +1,88 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "softmax_common.h"
+
+int32_t esp_nn_get_softmax_scratch_size_ansi(const int32_t width, const int32_t height)
+{
+ (void) width;
+ (void) height;
+ return 0;
+}
+
+void esp_nn_set_softmax_scratch_buf_ansi(void *buffer)
+{
+ (void) buffer;
+ return;
+}
+
+void esp_nn_softmax_s8_ansi(const int8_t *input_data,
+ const int32_t height,
+ const int32_t width,
+ const int32_t mult,
+ const int32_t shift,
+ const int32_t diff_min,
+ int8_t *output_data)
+{
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input mult, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+#define ACCUM_BITS 12
+#define DIFF_BITS 5
+
+ const int32_t mask = (1 << shift);
+ int32_t col = 0;
+ const int8_t *in_ptr = input_data;
+ int8_t *out_ptr = output_data;
+
+ for (int row_idx = 0; row_idx < height; row_idx++) {
+ int8_t max_in_row = in_ptr[0];
+ for (col = 1; col < width; col++) {
+ max_in_row = max(max_in_row, in_ptr[col]);
+ }
+
+ int32_t input_diff = 0;
+ int32_t sum_of_exps = 0;
+
+ for (col = 0; col < width; col++) {
+ input_diff = in_ptr[col] - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32_t input_diff_rescaled = SAT_HIGH_MUL(input_diff * mask, mult);
+ const int32_t exp_raw = esp_nn_exp_on_negative_values(input_diff_rescaled);
+ sum_of_exps += DIV_POW2(exp_raw, ACCUM_BITS);
+ }
+ }
+
+ const int32_t headroom_plus1 = esp_nn_clz32((uint32_t) sum_of_exps);
+ const int32_t shifted_scale = ONE_OVER_ONE_X((sum_of_exps << headroom_plus1) - (1 << 31));
+ const int32_t bits_over_unit = ACCUM_BITS - headroom_plus1 + 31 - sizeof(int8_t) * 8;
+
+ for (col = 0; col < width; col++) {
+ input_diff = in_ptr[col] - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32_t input_diff_rescaled = SAT_HIGH_MUL(input_diff * mask, mult);
+ const int32_t exp_raw = esp_nn_exp_on_negative_values(input_diff_rescaled);
+ const int32_t shifted_output = SAT_HIGH_MUL(shifted_scale, exp_raw);
+ const int32_t result = DIV_POW2(shifted_output, bits_over_unit) - 128;
+ out_ptr[col] = (int8_t) esp_nn_saturate8(result);
+ } else {
+ out_ptr[col] = -128;
+ }
+ }
+ in_ptr += width;
+ out_ptr += width;
+ }
+}
diff --git a/code/components/esp-nn/src/softmax/esp_nn_softmax_opt.c b/code/components/esp-nn/src/softmax/esp_nn_softmax_opt.c
new file mode 100644
index 00000000..93337d32
--- /dev/null
+++ b/code/components/esp-nn/src/softmax/esp_nn_softmax_opt.c
@@ -0,0 +1,108 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "softmax_common.h"
+#include
+
+static int32_t *scratch_buf = NULL;
+
+/**
+ * @brief Get scratch buffer size needed by softmax function
+ *
+ * @param width
+ * @param height
+ * @return size in bytes
+ *
+ * @note buffer must be 4 byte aligned
+ */
+int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t height)
+{
+ (void) height;
+ return width * 4;
+}
+
+/**
+ * @brief Set scratch buffer to be used by softmax function
+ *
+ * @param buffer this can be NULL if one needs to unset it
+ * must be aligned to 4 bytes
+ */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer)
+{
+ scratch_buf = (int32_t *) buffer;
+}
+
+void esp_nn_softmax_s8_opt(const int8_t *input_data,
+ const int32_t height,
+ const int32_t width,
+ const int32_t mult,
+ const int32_t shift,
+ const int32_t diff_min,
+ int8_t *output_data)
+{
+ if (scratch_buf == NULL) {
+ printf("%s error! scratch buffer not set\n", __FUNCTION__);
+ return;
+ }
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input mult, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+#define ACCUM_BITS 12
+#define DIFF_BITS 5
+
+ const int32_t mask = (1 << shift);
+ int32_t col = 0;
+ const int8_t *in_ptr = input_data;
+ int8_t *out_ptr = output_data;
+
+ for (int row_idx = 0; row_idx < height; row_idx++) {
+ int8_t max_in_row = in_ptr[0];
+ for (col = 1; col < width; col++) {
+ max_in_row = max(max_in_row, in_ptr[col]);
+ }
+
+ int32_t input_diff = 0;
+ int32_t sum_of_exps = 0;
+
+ for (col = 0; col < width; col++) {
+ input_diff = in_ptr[col] - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32_t input_diff_rescaled = SAT_HIGH_MUL(input_diff * mask, mult);
+ const int32_t exp_raw = esp_nn_exp_on_negative_values(input_diff_rescaled);
+ scratch_buf[col] = exp_raw; // store to avoid duplicate calculation later
+ sum_of_exps += DIV_POW2(exp_raw, ACCUM_BITS);
+ }
+ }
+
+ const int32_t headroom_plus1 = esp_nn_clz32((uint32_t) sum_of_exps);
+ const int32_t shifted_scale = ONE_OVER_ONE_X((sum_of_exps << headroom_plus1) - (1 << 31));
+ const int32_t bits_over_unit = ACCUM_BITS - headroom_plus1 + 31 - sizeof(int8_t) * 8;
+
+ for (col = 0; col < width; col++) {
+ input_diff = in_ptr[col] - max_in_row;
+ if (input_diff >= diff_min) {
+ int32_t exp_raw = scratch_buf[col];
+ const int32_t shifted_output = SAT_HIGH_MUL(shifted_scale, exp_raw);
+ const int32_t result = DIV_POW2(shifted_output, bits_over_unit) - 128;
+ out_ptr[col] = (int8_t) esp_nn_saturate8(result);
+ } else {
+ out_ptr[col] = -128;
+ }
+ }
+ in_ptr += width;
+ out_ptr += width;
+ }
+}
diff --git a/code/components/esp-nn/src/softmax/softmax_common.h b/code/components/esp-nn/src/softmax/softmax_common.h
new file mode 100644
index 00000000..254d6ace
--- /dev/null
+++ b/code/components/esp-nn/src/softmax/softmax_common.h
@@ -0,0 +1,104 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#define MASK_IF_ZERO(x) (x) == 0 ? ~0 : 0
+#define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
+#define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
+#define SAT_HIGH_MUL(x, y) esp_nn_sat_round_doubling_high_mul((x), (y))
+#define DIV_POW2(x,y) esp_nn_div_by_power_of_two((x), (y))
+
+__NN_FORCE_INLINE__ int32_t mul_power_of_2(int val, int exp)
+{
+ const int32_t thresh = ((1 << (31 - exp)) - 1);
+ int32_t result = val << exp;
+ result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), INT32_MAX, result);
+ result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), INT32_MIN, result);
+ return result;
+}
+
+/**
+ * @brief Calculate `1 / (1 + x)` for x in [0, 1]
+ *
+ * @param val input value to calculate `1/(1+x)` for
+ * @return `int32_t` result
+ * @note Newton-Raphson division
+ *
+ * https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
+ * Refer to that page for the logic behind the 48/17 and 32/17 constants.
+ * Pseudocode: https://en.wikipedia.org/wiki/Division_algorithm#Pseudocode
+ */
+__NN_FORCE_INLINE__ int32_t esp_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
+{
+ const int64_t sum = (int64_t) val + INT32_MAX;
+ const int32_t half_denominator = (int32_t) ((sum + (sum >= 0 ? 1 : -1)) / 2L);
+ int32_t constant_48_over_17 = 1515870810;
+ int32_t constant_neg_32_over_17 = -1010580540;
+ int32_t x = constant_48_over_17 + SAT_HIGH_MUL(half_denominator, constant_neg_32_over_17);
+ const int32_t fixed_2_one = (1 << 29);
+
+ x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
+ x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
+ x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
+
+ return mul_power_of_2(x, 1);
+}
+
+#define ONE_OVER_ONE_X(x) esp_nn_one_over_one_plus_x_for_x_in_0_1((x))
+
+/**
+ * @brief Return exp(x) for x < 0.
+ *
+ */
+__NN_FORCE_INLINE__ int32_t esp_nn_exp_on_negative_values(int32_t val)
+{
+ int32_t shift = 24;
+
+ const int32_t one_quarter = (1 << shift);
+ int32_t mask = one_quarter - 1;
+ const int32_t val_mod_minus_quarter = (val & mask) - one_quarter;
+ const int32_t remainder = val_mod_minus_quarter - val;
+
+ // calculate exponent for x in [-1/4, 0) in `result`
+ const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
+ const int32_t x2 = SAT_HIGH_MUL(x, x);
+ const int32_t x3 = SAT_HIGH_MUL(x2, x);
+ const int32_t x4 = SAT_HIGH_MUL(x2, x2);
+ const int32_t one_over_3 = 715827883;
+ const int32_t one_over_8 = 1895147668;
+
+ const int32_t x4_over_4 = DIV_POW2(x4, 2);
+ const int32_t x4_over_4_plus_x3_over_6_plus_x2_over_2 = DIV_POW2(SAT_HIGH_MUL(x4_over_4 + x3, one_over_3) + x2, 1);
+ int32_t result = one_over_8 + SAT_HIGH_MUL(one_over_8, x + x4_over_4_plus_x3_over_6_plus_x2_over_2);
+
+#define SELECT_IF_NON_ZERO(x) { \
+ mask = MASK_IF_NON_ZERO(remainder & (1 << shift++)); \
+ result = SELECT_USING_MASK(mask, SAT_HIGH_MUL(result, x), result); \
+}
+
+ SELECT_IF_NON_ZERO(1672461947)
+ SELECT_IF_NON_ZERO(1302514674)
+ SELECT_IF_NON_ZERO(790015084)
+ SELECT_IF_NON_ZERO(290630308)
+ SELECT_IF_NON_ZERO(39332535)
+ SELECT_IF_NON_ZERO(720401)
+ SELECT_IF_NON_ZERO(242)
+
+#undef SELECT_IF_NON_ZERO
+
+ mask = MASK_IF_ZERO(val);
+ return SELECT_USING_MASK(mask, INT32_MAX, result);
+}
\ No newline at end of file
diff --git a/code/components/esp-nn/test_app/CMakeLists.txt b/code/components/esp-nn/test_app/CMakeLists.txt
new file mode 100644
index 00000000..8d332768
--- /dev/null
+++ b/code/components/esp-nn/test_app/CMakeLists.txt
@@ -0,0 +1,9 @@
+# The following lines of boilerplate have to be in your project's
+# CMakeLists in this exact order for cmake to work correctly
+cmake_minimum_required(VERSION 3.5)
+
+set(EXTRA_COMPONENT_DIRS "../" "../tests/")
+set(IDF_EXCLUDE_COMPONENTS test test_app)
+
+include($ENV{IDF_PATH}/tools/cmake/project.cmake)
+project(test_app)
diff --git a/code/components/esp-nn/test_app/main/CMakeLists.txt b/code/components/esp-nn/test_app/main/CMakeLists.txt
new file mode 100644
index 00000000..04161254
--- /dev/null
+++ b/code/components/esp-nn/test_app/main/CMakeLists.txt
@@ -0,0 +1,7 @@
+
+set(COMPONENT_SRCS "main.c")
+set(COMPONENT_ADD_INCLUDEDIRS "")
+
+set(COMPONENT_PRIV_REQUIRES tests)
+
+register_component()
diff --git a/code/components/esp-nn/test_app/main/component.mk b/code/components/esp-nn/test_app/main/component.mk
new file mode 100644
index 00000000..5d85ad38
--- /dev/null
+++ b/code/components/esp-nn/test_app/main/component.mk
@@ -0,0 +1,8 @@
+#
+# Main component makefile.
+#
+# This Makefile can be left empty. By default, it will take the sources in the
+# src/ directory, compile them and link them into lib(subdirectory_name).a
+# in the build directory. This behaviour is entirely configurable,
+# please read the ESP-IDF documents if you need to do this.
+#
diff --git a/code/components/esp-nn/test_app/main/main.c b/code/components/esp-nn/test_app/main/main.c
new file mode 100644
index 00000000..267e35f2
--- /dev/null
+++ b/code/components/esp-nn/test_app/main/main.c
@@ -0,0 +1,87 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+static const char *TAG = "test_app";
+static uint32_t start_c, start_opt, total_c, total_opt;
+
+void profile_c_start()
+{
+ /* initiate profiling */
+ start_c = esp_cpu_get_ccount();
+}
+
+void profile_c_end()
+{
+ /* record profile number */
+ total_c = esp_cpu_get_ccount() - start_c;
+}
+
+void profile_opt_start()
+{
+ /* initiate profiling */
+ start_opt = esp_cpu_get_ccount();
+}
+
+void profile_opt_end()
+{
+ /* record profile number */
+ total_opt = esp_cpu_get_ccount() - start_opt;
+}
+
+void app_main()
+{
+ /* s8 tests */
+ ESP_LOGI(TAG, "Running s8 tests...");
+ esp_nn_add_elementwise_s8_test();
+ printf("add, c %u opt %u\n", total_c, total_opt);
+ esp_nn_mul_elementwise_s8_test();
+ printf("mul, c %u opt %u\n", total_c, total_opt);
+ esp_nn_depthwise_conv_s8_test();
+ printf("depthwise, c %u opt %u\n", total_c, total_opt);
+ esp_nn_conv_s8_test();
+ printf("conv2d, c %u opt %u\n", total_c, total_opt);
+
+ esp_nn_relu6_s8_test();
+ printf("relu, c %u opt %u\n", total_c, total_opt);
+ esp_nn_avg_pool_s8_test();
+ printf("avg_pool, c %u opt %u\n", total_c, total_opt);
+ esp_nn_max_pool_s8_test();
+ printf("max_pool, c %u opt %u\n", total_c, total_opt);
+ esp_nn_fully_connected_s8_test();
+ printf("fully_connected, c %u opt %u\n", total_c, total_opt);
+ esp_nn_softmax_s8_test();
+ printf("softmax, c %u opt %u\n", total_c, total_opt);
+ ESP_LOGI(TAG, "s8 tests done!\n");
+
+ /* u8 tests */
+ //ESP_LOGI(TAG, "Running u8 tests...");
+ //esp_nn_add_elementwise_u8_test();
+ //esp_nn_depthwise_conv_u8_test();
+ //esp_nn_conv_u8_test();
+ //esp_nn_avg_pool_u8_test();
+ //esp_nn_max_pool_u8_test();
+ //esp_nn_fully_connected_u8_test();
+ //ESP_LOGI(TAG, "u8 tests done!\n");
+}
diff --git a/code/components/esp-nn/test_app/sdkconfig.defaults b/code/components/esp-nn/test_app/sdkconfig.defaults
new file mode 100644
index 00000000..bb37aac5
--- /dev/null
+++ b/code/components/esp-nn/test_app/sdkconfig.defaults
@@ -0,0 +1,5 @@
+
+#
+# esp-nn
+#
+CONFIG_NN_ESP32=y
diff --git a/code/components/esp-nn/tests/CMakeLists.txt b/code/components/esp-nn/tests/CMakeLists.txt
new file mode 100644
index 00000000..97ec946f
--- /dev/null
+++ b/code/components/esp-nn/tests/CMakeLists.txt
@@ -0,0 +1,15 @@
+
+set(COMPONENT_ADD_INCLUDEDIRS ./include/)
+set(COMPONENT_SRCS "src/basic_math_test.c"
+ "src/convolution_test.c"
+ "src/fully_connected_test.c"
+ "src/pooling_test.c"
+ "src/relu_test.c"
+ "src/softmax_test.c")
+
+set(COMPONENT_REQUIRES )
+set(COMPONENT_PRIV_REQUIRES esp-nn)
+
+register_component()
+
+target_compile_options(${COMPONENT_LIB} PRIVATE -Wno-unused-function)
diff --git a/code/components/esp-nn/tests/README.md b/code/components/esp-nn/tests/README.md
new file mode 100644
index 00000000..41c94235
--- /dev/null
+++ b/code/components/esp-nn/tests/README.md
@@ -0,0 +1,4 @@
+# Tests for esp_nn library
+
+- Include these in your test framework and run the framework.
+- For IDF test please refer `test_app`
diff --git a/code/components/esp-nn/tests/component.mk b/code/components/esp-nn/tests/component.mk
new file mode 100644
index 00000000..2860f3ff
--- /dev/null
+++ b/code/components/esp-nn/tests/component.mk
@@ -0,0 +1,5 @@
+#FIXME
+
+COMPONENT_ADD_INCLUDEDIRS := include/
+
+COMPONENT_SRCDIRS := src/
diff --git a/code/components/esp-nn/tests/include/test_functions.h b/code/components/esp-nn/tests/include/test_functions.h
new file mode 100644
index 00000000..3e882efa
--- /dev/null
+++ b/code/components/esp-nn/tests/include/test_functions.h
@@ -0,0 +1,48 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+/* int8_t ops tests */
+void esp_nn_add_elementwise_s8_test();
+void esp_nn_mul_elementwise_s8_test();
+
+void esp_nn_depthwise_conv_s8_test();
+void esp_nn_conv_s8_test();
+
+void esp_nn_avg_pool_s8_test();
+void esp_nn_max_pool_s8_test();
+
+void esp_nn_fully_connected_s8_test();
+
+void esp_nn_relu6_s8_test();
+
+void esp_nn_softmax_s8_test();
+
+/* uint8_t ops tests */
+void esp_nn_add_elementwise_u8_test();
+
+void esp_nn_depthwise_conv_u8_test();
+void esp_nn_conv_u8_test();
+
+void esp_nn_avg_pool_u8_test();
+void esp_nn_max_pool_u8_test();
+
+void esp_nn_fully_connected_u8_test();
+
+/* instructions test functions */
+void compare_instructions_test();
+void arith_instructions_test();
+void min_max_instructions_test();
+void bitwise_instructions_test();
+void load_store_instructions_test();
diff --git a/code/components/esp-nn/tests/include/test_utils.h b/code/components/esp-nn/tests/include/test_utils.h
new file mode 100644
index 00000000..a152549b
--- /dev/null
+++ b/code/components/esp-nn/tests/include/test_utils.h
@@ -0,0 +1,87 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+
+/* mult value range */
+#define MULT_MAX INT32_MAX
+#define MULT_MIN 0
+
+/* shift value range */
+#define SHIFT_MIN -31
+#define SHIFT_MAX 30
+
+/**
+ * @brief callback function to run before C function
+ */
+void profile_c_start();
+
+/**
+ * @brief callback function to run after C function
+ */
+void profile_c_end();
+
+/**
+ * @brief callback function to run before optimized function
+ */
+void profile_opt_start();
+
+/**
+ * @brief callback function to run after optimized function
+ */
+void profile_opt_end();
+
+#define ANSI_COLOR_RED "\x1b[31m"
+#define ANSI_COLOR_GREEN "\x1b[32m"
+#define ANSI_COLOR_YELLOW "\x1b[33m"
+#define ANSI_COLOR_BLUE "\x1b[34m"
+#define ANSI_COLOR_MAGENTA "\x1b[35m"
+#define ANSI_COLOR_CYAN "\x1b[36m"
+#define ANSI_COLOR_RESET "\x1b[0m"
+
+#define CHECK_EQUAL(ARRAY1, ARRAY2, size) ({ \
+ bool res = true; \
+ for (int _i = 0; _i < size; _i++) { \
+ if (ARRAY1[_i] != ARRAY2[_i]) { \
+ res = false; \
+ break; \
+ } \
+ } \
+ res; \
+})
+
+#define PRINT_ARRAY_INT(ARRAY, width, height) ({ \
+ int *_array = (int *) ARRAY; \
+ for (int _j = 0; _j < height; _j++) { \
+ for (int _i = 0; _i < width; _i++) { \
+ printf("%d\t", _array[width * _j + _i]); \
+ } \
+ printf("\n"); \
+ } \
+ printf("\n"); \
+})
+
+#define PRINT_ARRAY_HEX(ARRAY, width, height) ({ \
+ uint8_t *_array = (uint8_t *) ARRAY; \
+ for (int _j = 0; _j < height; _j++) { \
+ for (int _i = 0; _i < width; _i++) { \
+ printf("%02x\t", _array[width * _j + _i]); \
+ } \
+ printf("\n"); \
+ } \
+ printf("\n"); \
+})
diff --git a/code/components/esp-nn/tests/src/basic_math_test.c b/code/components/esp-nn/tests/src/basic_math_test.c
new file mode 100644
index 00000000..5b96b990
--- /dev/null
+++ b/code/components/esp-nn/tests/src/basic_math_test.c
@@ -0,0 +1,343 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include "test_utils.h"
+
+#if CONFIG_IDF_CMAKE
+#define IDF_HEAP_CAPS 1
+
+#if IDF_HEAP_CAPS
+#include "esp_heap_caps.h"
+#endif
+#endif
+
+void esp_nn_add_elementwise_s8_test()
+{
+ /* prepare data */
+ const int size = 1600 + 8 + 7; /* odd len to test leftover */
+ int8_t *input1;
+ int8_t *input2;
+ int8_t *out_data_c;
+ int8_t *out_data_opt;
+ int8_t *input1_orig = NULL;
+ int8_t *input2_orig = NULL;
+ int8_t *out_c_orig = NULL;
+ int8_t *out_opt_orig = NULL;
+ int32_t input1_offset = 34;
+ int32_t input2_offset = 35;
+ int32_t output_offset = 36;
+ int32_t input1_shift = -8; // right_shift amt always <= 0
+ int32_t input2_shift = -8; // right_shift amt always <= 0
+ int32_t output_shift = -9; // right_shift amt always <= 0
+ int32_t left_shift = 15; // always +ve
+ int32_t input1_mult = INT32_MAX;
+ int32_t input2_mult = INT32_MAX;
+ int32_t output_mult = INT32_MAX;
+ int32_t activation_min = -128;
+ int32_t activation_max = 127;
+
+ for (int itr = 0; itr < 10; itr++) {
+ switch (itr) {
+ case 0: // all zeros
+ input1_offset = 0;
+ input2_offset = 0;
+ output_offset = 0;
+ input1_mult = 0;
+ input2_mult = 0;
+ output_mult = 0;
+ input1_shift = 0;
+ input2_shift = 0;
+ output_shift = 0;
+ left_shift = 0;
+ break;
+ case 1: // hit min
+ input1_offset = -127;
+ input2_offset = -127;
+ output_offset = -128;
+ input1_mult = MULT_MIN;
+ input2_mult = MULT_MIN;
+ output_mult = MULT_MIN;
+ input1_shift = 0;
+ input2_shift = 0;
+ output_shift = 0;
+ left_shift = 0;
+ break;
+ case 2: // hit max
+ input1_offset = 128;
+ input2_offset = 128;
+ output_offset = -127;
+ input1_mult = MULT_MAX;
+ input2_mult = MULT_MAX;
+ output_mult = MULT_MAX;
+ input1_shift = SHIFT_MIN;
+ input2_shift = SHIFT_MIN;
+ output_shift = SHIFT_MIN;
+ left_shift = 30 - 8; // since input is 8 bits
+ break;
+ case 3: // hit extreme max
+ input1_offset = 128;
+ input2_offset = 128;
+ output_offset = -127;
+ input1_mult = MULT_MAX;
+ input2_mult = MULT_MAX;
+ output_mult = MULT_MAX;
+ input1_shift = 0;
+ input2_shift = 0;
+ output_shift = 0;
+ left_shift = 30 - 8; // -8 since input is 8 bit
+ break;
+ default: // practical random input
+ input1_offset = rand() % 256 - 127; // range [-127, 128]
+ input2_offset = rand() % 256 - 127; // range [-127, 128]
+ output_offset = rand() % 256 - 128; // range [-128, 127]
+ input1_mult = MULT_MAX / 2 + rand() % INT16_MAX;
+ input2_mult = MULT_MAX / 2 + rand() % INT16_MAX;
+ output_mult = MULT_MAX / 2 + rand() % INT16_MAX;
+ input1_shift = -8 + rand() % 4;
+ input2_shift = -8 + rand() % 4;
+ output_shift = -8 + rand() % 4;
+ left_shift = rand() % 15;
+ }
+#if IDF_HEAP_CAPS
+ input1_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ input2_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ out_c_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ out_opt_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+
+ input1 = 16 + input1_orig - ((uint32_t) input1_orig & 0xf);
+ input2 = 16 + input2_orig - ((uint32_t) input2_orig & 0xf);
+ out_data_c = 16 + out_c_orig - ((uint32_t) out_c_orig & 0xf);
+ out_data_opt = 16 + out_opt_orig - ((uint32_t) out_opt_orig & 0xf);
+#else
+ input1 = memalign(16, size);
+ input2 = memalign(16, size);
+ out_data_c = memalign(16, size);
+ out_data_opt = memalign(16, size);
+
+ input1_orig = input1;
+ input2_orig = input2;
+ out_c_orig = out_data_c;
+ out_opt_orig = out_data_opt;
+#endif
+
+ for (int i = 0; i < size; ++i) {
+ input1[i] = rand() % 256 - 128;
+ input2[i] = rand() % 256 - 128;
+ }
+
+ if (itr == 0) {
+ /* enable profiler */
+ profile_c_start();
+ }
+ /* C function */
+ esp_nn_add_elementwise_s8_ansi(input1, input2, input1_offset, input2_offset,
+ input1_mult, input2_mult, input1_shift, input2_shift,
+ left_shift, out_data_c, output_offset, output_mult,
+ output_shift, activation_min, activation_max, size);
+
+ if (itr == 0) {
+ profile_c_end();
+ profile_opt_start();
+ }
+
+ /* Optimized function */
+ esp_nn_add_elementwise_s8(input1, input2, input1_offset, input2_offset,
+ input1_mult, input2_mult, input1_shift, input2_shift,
+ left_shift, out_data_opt, output_offset, output_mult,
+ output_shift, activation_min, activation_max, size);
+ if (itr == 0) {
+ /* disable profiler */
+ profile_opt_end();
+ }
+
+ bool ret = CHECK_EQUAL(out_data_c, out_data_opt, size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(out_data_opt, size, 1);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(out_data_c, size, 1);
+ printf("Input1:\n");
+ PRINT_ARRAY_HEX(input1, size, 1);
+ printf("Input2:\n");
+ PRINT_ARRAY_HEX(input2, size, 1);
+ printf("in1_shift %d, in2_shift %d, left_shift %d, out_shift %d\n",
+ input1_shift, input2_shift, left_shift, output_shift);
+ printf("in1_mult %d, in2_mult %d, out_mult %d\n", input1_mult, input2_mult, output_mult);
+ goto elementwise_add_test_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+
+elementwise_add_test_cleanup:
+ if (input1_orig) {
+ free(input1_orig);
+ }
+ if (input2_orig) {
+ free(input2_orig);
+ }
+ if (out_data_c) {
+ free(out_c_orig);
+ }
+ if (out_data_opt) {
+ free(out_opt_orig);
+ }
+ }
+}
+
+void esp_nn_mul_elementwise_s8_test()
+{
+ /* prepare data */
+ const int size = 1600 + 8 + 7; /* odd len to test leftover */
+ int8_t *input1;
+ int8_t *input2;
+ int8_t *out_data_c;
+ int8_t *out_data_opt;
+ int32_t input1_offset = 34;
+ int32_t input2_offset = 35;
+ int32_t output_offset = 36;
+ int32_t output_shift = -7;
+ int32_t output_mult = MULT_MAX; // max out_mult
+ int32_t activation_min = -128;
+ int32_t activation_max = 127;
+ int8_t *input1_orig = NULL;
+ int8_t *input2_orig = NULL;
+ int8_t *out_c_orig = NULL;
+ int8_t *out_opt_orig = NULL;
+
+ for (int itr = 0; itr < 10; itr++) {
+ switch (itr) {
+ case 0: // all zeros
+ input1_offset = 0;
+ input2_offset = 0;
+ output_offset = 0;
+ output_mult = 0;
+ output_shift = 0;
+ break;
+ case 1: // hit min
+ input1_offset = -127;
+ input2_offset = -127;
+ output_offset = -128;
+ output_mult = MULT_MIN;
+ output_shift = 0;
+ break;
+ case 2: // hit max
+ input1_offset = 128;
+ input2_offset = 128;
+ output_offset = -127;
+ output_mult = MULT_MAX;
+ output_shift = SHIFT_MIN;
+ break;
+ case 3: // hit extreme max
+ input1_offset = 128;
+ input2_offset = 128;
+ output_offset = -127;
+ output_mult = MULT_MAX;
+ output_shift = 0;
+ break;
+ default: // practical random input
+ input1_offset = rand() % 256 - 127; // range [-127, 128]
+ input2_offset = rand() % 256 - 127; // range [-127, 128]
+ output_offset = rand() % 256 - 128; // range [-128, 127]
+ output_mult = MULT_MAX / 2 + rand() % INT16_MAX;
+ output_shift = -8 + rand() % 4;
+ }
+
+#if IDF_HEAP_CAPS
+ input1_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ input2_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ out_c_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ out_opt_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+
+ input1 = 16 + input1_orig - ((uint32_t) input1_orig & 0xf);
+ input2 = 16 + input2_orig - ((uint32_t) input2_orig & 0xf);
+ out_data_c = 16 + out_c_orig - ((uint32_t) out_c_orig & 0xf);
+ out_data_opt = 16 + out_opt_orig - ((uint32_t) out_opt_orig & 0xf);
+#else
+ input1 = memalign(16, size);
+ input2 = memalign(16, size);
+ out_data_c = memalign(16, size);
+ out_data_opt = memalign(16, size);
+
+ input1_orig = input1;
+ input2_orig = input2;
+ out_c_orig = out_data_c;
+ out_opt_orig = out_data_opt;
+#endif
+
+ for (int i = 0; i < size; ++i) {
+ input1[i] = rand() % 256 - 128;
+ input2[i] = rand() % 256 - 128;
+ }
+
+ if (itr == 0) {
+ /* enable profiler */
+ profile_c_start();
+ }
+ /* C function */
+ esp_nn_mul_elementwise_s8_ansi(input1, input2, input1_offset, input2_offset,
+ out_data_c, output_offset, output_mult, output_shift,
+ activation_min, activation_max, size);
+
+ if (itr == 0) {
+ profile_c_end();
+ profile_opt_start();
+ }
+ /* Optimized function */
+ esp_nn_mul_elementwise_s8(input1, input2, input1_offset, input2_offset,
+ out_data_opt, output_offset, output_mult, output_shift,
+ activation_min, activation_max, size);
+
+ if (itr == 0) {
+ /* disable profiler */
+ profile_opt_end();
+ }
+
+ bool ret = CHECK_EQUAL(out_data_c, out_data_opt, size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(out_data_opt, size, 1);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(out_data_c, size, 1);
+ printf("Input1:\n");
+ PRINT_ARRAY_HEX(input1, size, 1);
+ printf("Input2:\n");
+ PRINT_ARRAY_HEX(input2, size, 1);
+ goto elementwise_mult_test_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+
+elementwise_mult_test_cleanup:
+ if (input1_orig) {
+ free(input1_orig);
+ }
+ if (input2_orig) {
+ free(input2_orig);
+ }
+ if (out_data_c) {
+ free(out_c_orig);
+ }
+ if (out_data_opt) {
+ free(out_opt_orig);
+ }
+ }
+}
diff --git a/code/components/esp-nn/tests/src/convolution_test.c b/code/components/esp-nn/tests/src/convolution_test.c
new file mode 100644
index 00000000..f3802257
--- /dev/null
+++ b/code/components/esp-nn/tests/src/convolution_test.c
@@ -0,0 +1,571 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "test_utils.h"
+
+#if CONFIG_IDF_CMAKE
+#define IDF_HEAP_CAPS 1
+
+#if IDF_HEAP_CAPS
+#include "esp_heap_caps.h"
+#endif
+#endif
+
+void esp_nn_depthwise_conv_s8_test()
+{
+ int8_t *input = NULL, *filter_data = NULL, *out_data_c = NULL, *out_data_opt = NULL;
+ int32_t *bias = NULL;
+ int32_t input_offset = 5; /* some number in [-128, 127] */
+ int32_t out_offset = 7;
+ int32_t activation_min = -125;
+ int32_t activation_max = 120;
+ void *scratch_buf = NULL;
+
+ /* independent variables */
+ int input_wd, input_ht, channels;
+ uint16_t filter_ht, filter_wd, ch_mult;
+ uint16_t pad_wd, pad_ht, stride_wd, stride_ht;
+
+ // run for 10 iterations
+ for (int itr = 0; itr < 10; itr++) {
+ /* prepare data */
+ switch (itr) {
+ case 0: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)
+ input_wd = 18;
+ input_ht = 18;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 16;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 1: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (1,1)
+ input_wd = 10;
+ input_ht = 10;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 16;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 2: // (ch_mult 1, (channels % 8) = 0), filter (3,3), pad (1,1)
+ input_wd = 10;
+ input_ht = 10;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 24;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 3: // other filter sizes (ch_mult 1, (channels % 8) = 0)
+ input_wd = 10;
+ input_ht = 10;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 24;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 4: // other filter sizes (ch_mult 8 = 0)
+ input_wd = 6;
+ input_ht = 6;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 8;
+ channels = 4;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 5: // other filter sizes (ch_mult 8 = 0)
+ input_wd = 12;
+ input_ht = 12;
+ filter_ht = 5;
+ filter_wd = 5;
+ ch_mult = 8;
+ channels = 4;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 6: // other filter sizes (ch_mult 4 = 0)
+ input_wd = 6;
+ input_ht = 6;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 4;
+ channels = 4;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 7: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0) stride (2,2)
+ input_wd = 6;
+ input_ht = 6;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 1;
+ channels = 16;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 2;
+ stride_ht = 2;
+ break;
+ default:
+ input_wd = 4;
+ input_ht = 4;
+ filter_ht = 3;
+ filter_wd = 3;
+ ch_mult = 4;
+ channels = 4;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ }
+
+ uint16_t out_wd = (input_wd - filter_wd + 1) / stride_wd;
+ uint16_t out_ht = (input_ht - filter_ht + 1) / stride_ht;
+ int in_size = input_wd * input_ht * channels;
+ int out_size = out_wd * out_ht * channels * ch_mult;
+ int filter_size = filter_wd * filter_ht * channels * ch_mult + 4;
+ int bias_size = channels * ch_mult + 1;
+ int32_t out_shift[channels * ch_mult];
+ int32_t out_mult[channels * ch_mult];
+
+#if IDF_HEAP_CAPS
+ int8_t *input_orig = (int8_t *) heap_caps_malloc(in_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ int8_t *out_c_orig = (int8_t *) heap_caps_malloc(out_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ int8_t *out_opt_orig = (int8_t *) heap_caps_malloc(out_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ filter_data = (int8_t *) heap_caps_malloc(filter_size, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ bias = (int32_t *) heap_caps_malloc(bias_size * 4, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+
+ input = 16 + input_orig - ((uint32_t) input_orig & 0xf);
+ out_data_c = 16 + out_c_orig - ((uint32_t) out_c_orig & 0xf);
+ out_data_opt = 16 + out_opt_orig - ((uint32_t) out_opt_orig & 0xf);
+#else
+ input = memalign(16, in_size + 16);
+ filter_data = memalign(16, filter_size);
+ out_data_c = memalign(16, out_size + 16);
+ out_data_opt = memalign(16, out_size + 16);
+ bias = memalign(16, bias_size * 4);
+ int8_t *input_orig = input;
+ int8_t *out_c_orig = out_data_c;
+ int8_t *out_opt_orig = out_data_opt;
+#endif
+ if (bias == NULL || input == NULL || filter_data == NULL ||
+ out_data_c == NULL || out_data_opt == NULL || bias == NULL) {
+ printf(ANSI_COLOR_RED"%s[%d] allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ goto dc_s8_cleanup;
+ }
+
+ /* Generate input data */
+ for (int i = 0; i < in_size; ++i) {
+ input[i] = rand() % 128;
+ }
+
+ /* Generate filter data */
+ for (int i = 0; i < filter_size; ++i) {
+ filter_data[i] = rand() % 256 - 128;
+ }
+
+ /* Generate bias data */
+ for (int i = 0; i < channels * ch_mult; ++i) {
+ bias[i + 1] = rand() % INT16_MAX; //0th index left for unalignment
+ out_shift[i] = -8 + rand() % 3;
+ out_mult[i] = 0x7eb0e200 + rand() % 50;
+ }
+
+ int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(input_wd, input_ht,
+ channels, ch_mult,
+ filter_wd, filter_ht);
+ if (scratch_buf_size > 0) {
+#if IDF_HEAP_CAPS
+ scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ int align_sz = 16 - (((int32_t) scratch_buf) & 0xf);
+#else
+ scratch_buf = memalign(16, scratch_buf_size);
+ int align_sz = 0;
+#endif
+ if (scratch_buf == NULL) {
+ printf(ANSI_COLOR_RED"%s[%d] scratch_buf alloc failed size %d\n"ANSI_COLOR_RESET,
+ __FUNCTION__, itr, scratch_buf_size);
+ goto dc_s8_cleanup;
+ }
+ esp_nn_set_depthwise_conv_scratch_buf(scratch_buf + align_sz);
+ }
+ if (itr == 0) {
+ /* enable profiler */
+ profile_c_start();
+ }
+
+ /* C function */
+ esp_nn_depthwise_conv_s8_ansi(input, input_wd, input_ht, channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
+ filter_data + 4, filter_wd, filter_ht,
+ bias + 1, out_data_c, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+
+ if (itr == 0) {
+ profile_c_end();
+ profile_opt_start();
+ }
+
+ /* Optimized function */
+ esp_nn_depthwise_conv_s8(input, input_wd, input_ht, channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
+ filter_data + 4, filter_wd, filter_ht,
+ bias + 1, out_data_opt, out_wd, out_ht, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+
+ if (itr == 0) {
+ /* disable profiler */
+ profile_opt_end();
+ }
+
+ bool ret = CHECK_EQUAL(out_data_c, out_data_opt, out_size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(out_data_opt, out_size / out_ht, out_ht);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(out_data_c, out_size / out_ht, out_ht);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, in_size / input_ht, input_ht);
+ printf("Filter data:\n");
+ PRINT_ARRAY_HEX(filter_data + 4, (filter_size - 4) / filter_ht, filter_ht);
+ printf("bias data:\n");
+ PRINT_ARRAY_INT(bias + 1, ch_mult * channels, 1);
+ goto dc_s8_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+
+ dc_s8_cleanup:
+ if (input) {
+ free(input_orig);
+ }
+ if (filter_data) {
+ free(filter_data);
+ }
+ if (out_data_c) {
+ free(out_c_orig);
+ }
+ if (out_data_opt) {
+ free(out_opt_orig);
+ }
+ if (bias) {
+ free(bias);
+ }
+ if (scratch_buf) {
+ free(scratch_buf);
+ }
+ }
+}
+
+void esp_nn_conv_s8_test()
+{
+ const int32_t input_offset = 5; /* some number in [-128, 127] */
+ const int32_t activation_min = -125;
+ const int32_t activation_max = 122;
+ const int32_t out_offset = 3;
+
+ void *scratch_buf = NULL;
+ int8_t *input_orig;
+ int8_t *out_c_orig;
+ int8_t *out_opt_orig;
+ int8_t *filter_data;
+ int32_t *bias;
+
+ /* independent variable */
+ int in_wd, in_ht, in_channels, out_channels;
+ uint16_t filter_ht, filter_wd;
+ uint16_t pad_wd, pad_ht, stride_wd, stride_ht;
+
+ // run for 10 iterations
+ for (int itr = 0; itr < 10; itr++) {
+ switch (itr) {
+ case 0: // ch % 8 == 0 && filter (1,1), padding (0,0)
+ in_wd = 10;
+ in_ht = 10;
+ in_channels = 64;
+ out_channels = 64;
+ filter_ht = 1;
+ filter_wd = 1;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 1: // ch % 4 == 0 && (in_wd * in_ht) % 16 == 0
+ in_wd = 4;
+ in_ht = 4;
+ in_channels = 20;
+ out_channels = 8;
+ filter_ht = 1;
+ filter_wd = 1;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 2: // ch, filter (3x3x3)
+ in_wd = 10;
+ in_ht = 10;
+ in_channels = 3;
+ out_channels = 64;
+ filter_ht = 3;
+ filter_wd = 3;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 3: // remaining pad (0, 0)
+ in_wd = 10;
+ in_ht = 10;
+ in_channels = 3;
+ out_channels = 64;
+ filter_ht = 1;
+ filter_wd = 1;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 4: // unopt case
+ in_wd = 10;
+ in_ht = 10;
+ in_channels = 12;
+ out_channels = 64;
+ filter_ht = 3;
+ filter_wd = 3;
+ pad_wd = 1;
+ pad_ht = 1;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ case 5: // ch % 8 == 0 & stride (2,2)
+ in_wd = 16;
+ in_ht = 16;
+ in_channels = 16;
+ out_channels = 16;
+ filter_ht = 1;
+ filter_wd = 1;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 2;
+ stride_ht = 2;
+ break;
+ case 6: // ch % 8 == 0 && filter (1,1), padding (0,0)
+ in_wd = 2;
+ in_ht = 2;
+ in_channels = 8;
+ out_channels = 8;
+ filter_ht = 1;
+ filter_wd = 1;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ default: // ch % 8 == 0
+ in_wd = 8;
+ in_ht = 8;
+ in_channels = 16;
+ out_channels = 16;
+ filter_ht = 1;
+ filter_wd = 1;
+ pad_wd = 0;
+ pad_ht = 0;
+ stride_wd = 1;
+ stride_ht = 1;
+ break;
+ }
+
+ /* prepare data */
+ uint16_t out_wd = (in_wd - filter_wd + 1) / stride_wd;
+ uint16_t out_ht = (in_ht - filter_ht + 1) / stride_ht;
+
+ int in_size = in_wd * in_ht * in_channels;
+ int filter_size = filter_wd * filter_ht * in_channels * out_channels + 2;
+ int out_size = out_wd * out_ht * out_channels;
+
+#if IDF_HEAP_CAPS
+ input_orig = (int8_t *) heap_caps_malloc(in_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ out_c_orig = (int8_t *) heap_caps_malloc(out_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ out_opt_orig = (int8_t *) heap_caps_malloc(out_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ filter_data = (int8_t *) heap_caps_malloc(filter_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ bias = (int32_t *) heap_caps_malloc(128 + sizeof (int32_t) * out_channels, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+
+ int8_t *input = 16 + input_orig - ((uint32_t) input_orig & 0xf);
+ int8_t *out_data_c = 16 + out_c_orig - ((uint32_t) out_c_orig & 0xf);
+ int8_t *out_data_opt = 16 + out_opt_orig - ((uint32_t) out_opt_orig & 0xf);
+#else
+ int8_t *input = memalign(16, in_size);
+ int8_t *out_data_c = memalign(16, out_size);
+ int8_t *out_data_opt = memalign(16, out_size);
+ filter_data = memalign(16, filter_size);
+ bias = calloc(1, 128 + sizeof (int32_t) * out_channels);
+ input_orig = input;
+ out_c_orig = out_data_c;
+ out_opt_orig = out_data_opt;
+#endif
+ int32_t *out_shift = calloc(1, 128 + sizeof (int32_t) * out_channels);
+ int32_t *out_mult = calloc(1, 128 + sizeof (int32_t) * out_channels);
+
+ if (input == NULL || filter_data == NULL ||
+ out_data_c == NULL || out_data_opt == NULL) {
+ printf(ANSI_COLOR_RED"%s allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto conv_s8_cleanup;
+ }
+
+ if (bias == NULL || out_shift == NULL || out_mult == NULL) {
+ printf(ANSI_COLOR_RED"%s allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto conv_s8_cleanup;
+ }
+
+ /* Generate input data between -128 -> +127 */
+ for (int i = 0; i < in_size; ++i) {
+ input[i] = rand() % 255 - 128;
+ }
+
+ /* Generate filter data between -128 -> +127 */
+ for (int i = 0; i < filter_size; ++i) {
+ filter_data[i] = rand() % 256 - 128;
+ }
+
+ /* Generate bias data */
+ for (int i = 0; i < out_channels; ++i) {
+ bias[i] = (int32_t)rand() % UINT16_MAX + UINT8_MAX;
+ }
+
+ /* Shift and multiplier */
+ for (int i = 0; i < out_channels; ++i) {
+ out_shift[i] = -10 + rand() % 2;
+ out_mult[i] = 0x7f67f4f8 + rand() % 50;
+ }
+
+ int scratch_buf_size = esp_nn_get_conv_scratch_size(in_wd, in_ht, in_channels,
+ out_channels, filter_wd, filter_ht);
+ if (scratch_buf_size > 0) {
+#if IDF_HEAP_CAPS
+ void *scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
+ int align_sz = 16 - (((int32_t) scratch_buf) & 0xf);
+#else
+ void *scratch_buf = memalign(16, scratch_buf_size);
+ int align_sz = 0;
+#endif
+ if (scratch_buf == NULL) {
+ printf(ANSI_COLOR_RED"%s scratch_buf alloc failed size %d\n"ANSI_COLOR_RESET, __FUNCTION__, scratch_buf_size);
+ goto conv_s8_cleanup;
+ }
+ esp_nn_set_conv_scratch_buf(scratch_buf + align_sz);
+ }
+
+ if (itr == 0) {
+ /* enable profiler */
+ profile_c_start();
+ }
+
+ /* C function */
+ esp_nn_conv_s8_ansi(input, in_wd, in_ht, in_channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht,
+ filter_data + 2, filter_wd, filter_ht, bias,
+ out_data_c, out_wd, out_ht, out_channels, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+
+ if (itr == 0) {
+ profile_c_end();
+ profile_opt_start();
+ }
+
+ /* Optimized function */
+ esp_nn_conv_s8(input, in_wd, in_ht, in_channels, input_offset,
+ pad_wd, pad_ht, stride_wd, stride_ht,
+ filter_data + 2, filter_wd, filter_ht, bias,
+ out_data_opt, out_wd, out_ht, out_channels, out_offset, out_shift,
+ out_mult, activation_min, activation_max);
+
+ if (itr == 0) {
+ /* disable profiler */
+ profile_opt_end();
+ }
+
+ bool ret = CHECK_EQUAL(out_data_c, out_data_opt, out_size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(out_data_opt, out_size / out_ht, out_ht);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(out_data_c, out_size / out_ht, out_ht);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, in_size / in_ht, in_ht);
+ printf("Filter data:\n");
+ PRINT_ARRAY_HEX(filter_data + 2, (filter_size - 2) / filter_ht, filter_ht);
+ printf("bias data:\n");
+ PRINT_ARRAY_INT(bias, out_channels, 1);
+ goto conv_s8_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+
+ conv_s8_cleanup:
+ if (input) {
+ free(input_orig);
+ }
+ if (filter_data) {
+ free(filter_data);
+ }
+ if (out_data_c) {
+ free(out_c_orig);
+ }
+ if (out_data_opt) {
+ free(out_opt_orig);
+ }
+ if (bias) {
+ free(bias);
+ }
+ if (out_shift) {
+ free(out_shift);
+ }
+ if (out_mult) {
+ free(out_mult);
+ }
+ if (scratch_buf) {
+ free(scratch_buf);
+ }
+ }
+}
diff --git a/code/components/esp-nn/tests/src/fully_connected_test.c b/code/components/esp-nn/tests/src/fully_connected_test.c
new file mode 100644
index 00000000..d0210b46
--- /dev/null
+++ b/code/components/esp-nn/tests/src/fully_connected_test.c
@@ -0,0 +1,111 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+
+#include
+#include "test_utils.h"
+
+
+void esp_nn_fully_connected_s8_test()
+{
+ /* prepare data */
+ static uint16_t row_len = 256 + 8 + 7; /* odd len to test unaligned+left-over */
+ static uint16_t out_channels = 3;
+ int8_t input[row_len];
+ int8_t filter_data[row_len * out_channels];
+ int8_t output_c[out_channels], output_opt[out_channels];
+ static int32_t activation_min = -128;
+ static int32_t activation_max = 127;
+ static int32_t input_offset = 0;
+ static int32_t filter_offset = 0;
+ int32_t out_shift = -10;
+ static int32_t out_offset = 127;
+ int32_t out_mult = 0x59e492c4;
+ for (int itr = 0; itr < 5; itr++) {
+ out_mult = INT32_MAX / row_len + rand() % INT16_MAX;
+ switch (itr) {
+ case 0:
+ out_shift = -10;
+ break;
+ case 1:
+ out_shift = SHIFT_MIN;
+ break;
+ case 2:
+ out_shift = SHIFT_MAX;
+ break;
+ case 3:
+ out_shift = 0;
+ break;
+ default:
+ out_shift = -10 + rand() % 5;
+ break;
+ }
+ if (itr == 0) {
+ out_shift = SHIFT_MAX;
+ }
+ /* Generate input and filter data */
+ for (int i = 0; i < row_len; ++i) {
+ input[i] = rand() % 256 - 128;
+ }
+ for (int i = 0; i < row_len * out_channels; ++i) {
+ filter_data[i] = rand() % 256 - 128;
+ }
+
+ if (itr == 0) {
+ /* enable profiler */
+ profile_c_start();
+ }
+
+ /* C function */
+ esp_nn_fully_connected_s8_ansi(input, input_offset, row_len, filter_data, filter_offset,
+ NULL, output_c, out_channels, out_offset, out_shift, out_mult,
+ activation_min, activation_max);
+
+ if (itr == 0) {
+ profile_c_end();
+ profile_opt_start();
+ }
+
+ /* Optimized function */
+ esp_nn_fully_connected_s8(input, input_offset, row_len, filter_data, filter_offset,
+ NULL, output_opt, out_channels, out_offset, out_shift, out_mult,
+ activation_min, activation_max);
+
+ if (itr == 0) {
+ /* disable profiler */
+ profile_opt_end();
+ }
+
+ bool ret = CHECK_EQUAL(output_c, output_opt, out_channels);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(output_opt, out_channels, 1);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(output_c, out_channels, 1);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, row_len, 1);
+ printf("Filter data:\n");
+ PRINT_ARRAY_HEX(filter_data, row_len, out_channels);
+ printf("Out shift: %d\n", out_shift);
+ printf("Out mult: %x\n", out_mult);
+ return;
+ }
+ printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
+ }
+}
diff --git a/code/components/esp-nn/tests/src/pooling_test.c b/code/components/esp-nn/tests/src/pooling_test.c
new file mode 100644
index 00000000..c1c889e1
--- /dev/null
+++ b/code/components/esp-nn/tests/src/pooling_test.c
@@ -0,0 +1,184 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "test_utils.h"
+
+
+void esp_nn_avg_pool_s8_test()
+{
+ /* prepare data */
+ const uint16_t input_wd = 16;
+ const uint16_t input_ht = 16;
+ const uint16_t channels = 16; /* With TFLite example, I have seen it 256 */
+ const int size = input_wd * input_ht * channels;
+ int8_t *input, *output_c, *output_opt;
+ const int32_t activation_min = -128;
+ const int32_t activation_max = 127;
+ const uint16_t pad_wd = 1;
+ const uint16_t pad_ht = 1;
+ const uint16_t stride_wd = 1;
+ const uint16_t stride_ht = 1;
+ const uint16_t filter_ht = 3;
+ const uint16_t filter_wd = 3;
+ const uint16_t out_wd = input_wd / stride_wd;
+ const uint16_t out_ht = input_ht / stride_ht;
+ const int out_size = out_wd * out_ht * channels;
+
+ input = memalign(16, size);
+ output_c = memalign(16, out_size);
+ output_opt = memalign(16, out_size);
+
+ if (input == NULL || output_c == NULL || output_opt == NULL) {
+ printf(ANSI_COLOR_RED"%s allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto avg_pool_s8_cleanup;
+ }
+ /**
+ * width/height, channels etc look suspicious but it it true.
+ * It actually depends upon where in model this is actually placed.
+ * If at the end wd/ht tends to be smaller and depth larger.
+ */
+
+ for (int i = 0; i < size; ++i) {
+ input[i] = rand() % 256 - 128;
+ }
+
+ /* enable profiler */
+ profile_c_start();
+
+ /* C function */
+ esp_nn_avg_pool_s8_ansi(input, input_wd, input_ht, output_c, out_wd, out_ht,
+ stride_wd, stride_ht, filter_wd, filter_ht, pad_wd, pad_ht,
+ activation_min, activation_max, channels);
+
+ profile_c_end();
+ profile_opt_start();
+
+ /* Optimized function */
+ esp_nn_avg_pool_s8(input, input_wd, input_ht, output_opt, out_wd, out_ht,
+ stride_wd, stride_ht, filter_wd, filter_ht, pad_wd, pad_ht,
+ activation_min, activation_max, channels);
+
+ /* disable profiler */
+ profile_opt_end();
+
+
+ bool ret = CHECK_EQUAL(output_c, output_opt, out_size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(output_opt, out_wd * channels, out_ht);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(output_c, out_wd * channels, out_ht);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, input_wd * channels, input_ht);
+ goto avg_pool_s8_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s passed\n"ANSI_COLOR_RESET, __FUNCTION__);
+
+avg_pool_s8_cleanup:
+ if (input) {
+ free(input);
+ }
+ if (output_c) {
+ free(output_c);
+ }
+ if (output_opt) {
+ free(output_opt);
+ }
+}
+
+void esp_nn_max_pool_s8_test()
+{
+ /* prepare data */
+ const uint16_t input_wd = 16;
+ const uint16_t input_ht = 16;
+ const uint16_t channels = 16; /* With TFLite example, I have seen it 256 */
+ int8_t *input, *output_c, *output_opt;
+ const int size = input_wd * input_ht * channels;
+ const int32_t activation_min = -128;
+ const int32_t activation_max = 127;
+ const uint16_t pad_wd = 1;
+ const uint16_t pad_ht = 1;
+ const uint16_t stride_wd = 1;
+ const uint16_t stride_ht = 1;
+ const uint16_t filter_ht = 3;
+ const uint16_t filter_wd = 3;
+ const uint16_t out_wd = input_wd / stride_wd;
+ const uint16_t out_ht = input_ht / stride_ht;
+ const int out_size = out_wd * out_ht * channels;
+
+ input = memalign(16, size);
+ output_c = memalign(16, out_size);
+ output_opt = memalign(16, out_size);
+
+ if (input == NULL || output_c == NULL || output_opt == NULL) {
+ printf(ANSI_COLOR_RED"%s allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto max_pool_s8_cleanup;
+ }
+
+ for (int i = 0; i < size; ++i) {
+ input[i] = rand() % 256 - 128;
+ }
+
+ /* enable profiler */
+ profile_c_start();
+
+ /* C function */
+ esp_nn_max_pool_s8_ansi(input, input_wd, input_ht, output_c, out_wd, out_ht,
+ stride_wd, stride_ht, filter_wd, filter_ht, pad_wd, pad_ht,
+ activation_min, activation_max, channels);
+
+ profile_c_end();
+ profile_opt_start();
+
+ /* Optimized function */
+ esp_nn_max_pool_s8(input, input_wd, input_ht, output_opt, out_wd, out_ht,
+ stride_wd, stride_ht, filter_wd, filter_ht, pad_wd, pad_ht,
+ activation_min, activation_max, channels);
+
+ /* disable profiler */
+ profile_opt_end();
+
+
+ bool ret = CHECK_EQUAL(output_c, output_opt, out_wd * out_ht * channels);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(output_opt, out_wd * out_ht * channels, 1);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(output_c, out_wd * out_ht * channels, 1);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, 8, size / 8);
+ goto max_pool_s8_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s passed\n"ANSI_COLOR_RESET, __FUNCTION__);
+
+max_pool_s8_cleanup:
+ if (input) {
+ free(input);
+ }
+ if (output_c) {
+ free(output_c);
+ }
+ if (output_opt) {
+ free(output_opt);
+ }
+}
diff --git a/code/components/esp-nn/tests/src/relu_test.c b/code/components/esp-nn/tests/src/relu_test.c
new file mode 100644
index 00000000..ce6f13f1
--- /dev/null
+++ b/code/components/esp-nn/tests/src/relu_test.c
@@ -0,0 +1,83 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "test_utils.h"
+
+void esp_nn_relu6_s8_test()
+{
+ const int size = 1600 + 8 + 7;
+ int8_t *input, *inout_ansi, *inout_opt;
+
+ input = memalign(16, size);
+ inout_ansi = memalign(16, size);
+ inout_opt = memalign(16, size);
+
+ if (input == NULL || inout_ansi == NULL || inout_opt == NULL) {
+ printf(ANSI_COLOR_RED"%s allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto relu6_s8_cleanup;
+ }
+ /* Generate filter data between -128 -> +127 */
+ for (int i = 0; i < size; ++i) {
+ input[i] = rand() % 255 - 128;
+ inout_ansi[i] = input[i];
+ inout_opt[i] = input[i];
+ }
+
+ /* enable profiler */
+ profile_c_start();
+
+ /* C function */
+ esp_nn_relu6_s8_ansi(inout_ansi, size);
+
+ profile_c_end();
+ profile_opt_start();
+
+ /* Optimized function */
+ esp_nn_relu6_s8(inout_opt, size);
+
+ /* disable profiler */
+ profile_opt_end();
+
+ bool ret = CHECK_EQUAL(inout_ansi, inout_opt, size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(inout_opt, size, 1);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(inout_ansi, size, 1);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, size, 1);
+ goto relu6_s8_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s passed\n"ANSI_COLOR_RESET, __FUNCTION__);
+
+relu6_s8_cleanup:
+ if (input) {
+ free (input);
+ }
+ if (inout_ansi) {
+ free (inout_ansi);
+ }
+ if (inout_opt) {
+ free (inout_opt);
+ }
+
+}
diff --git a/code/components/esp-nn/tests/src/softmax_test.c b/code/components/esp-nn/tests/src/softmax_test.c
new file mode 100644
index 00000000..f7c734cd
--- /dev/null
+++ b/code/components/esp-nn/tests/src/softmax_test.c
@@ -0,0 +1,101 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "test_utils.h"
+
+void esp_nn_softmax_s8_test()
+{
+ const int32_t height = 8;
+ const int32_t width = 32;
+ const int32_t diff_min = -128;
+ const int32_t mult = INT32_MAX / 2;
+ const int32_t shift = 7;
+ void *scratch_buf = NULL;
+ const int size = width * height;
+ int8_t *input, *out_ansi, *out_opt;
+
+ input = memalign(16, size);
+ out_ansi = memalign(16, size);
+ out_opt = memalign(16, size);
+
+ if (input == NULL || out_ansi == NULL || out_opt == NULL) {
+ printf(ANSI_COLOR_RED"%s buffer allocations failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ goto softmax_s8_cleanup;
+ }
+
+ /* Generate input data between -128 -> +127 */
+ for (int i = 0; i < size; ++i) {
+ input[i] = rand() % 255 - 128;
+ }
+
+ /* enable profiler */
+ profile_c_start();
+
+ /* C function */
+ esp_nn_softmax_s8_ansi(input, height, width, mult, shift, diff_min, out_ansi);
+
+ profile_c_end();
+
+ int32_t scratch_buf_size = esp_nn_get_softmax_scratch_size(width, height);
+ if (scratch_buf_size) {
+ scratch_buf = memalign(4, scratch_buf_size);
+ if (scratch_buf == NULL) {
+ printf(ANSI_COLOR_RED"%s scratch_buf alloc failed size %d\n"ANSI_COLOR_RESET, __FUNCTION__, scratch_buf_size);
+ goto softmax_s8_cleanup;
+ }
+ esp_nn_set_softmax_scratch_buf(scratch_buf);
+ }
+
+ profile_opt_start();
+
+ /* Optimized function */
+ esp_nn_softmax_s8(input, height, width, mult, shift, diff_min, out_opt);
+
+ /* disable profiler */
+ profile_opt_end();
+
+ bool ret = CHECK_EQUAL(out_ansi, out_opt, size);
+ if (ret == false) {
+ printf(ANSI_COLOR_RED"%s failed\n"ANSI_COLOR_RESET, __FUNCTION__);
+ printf("Output: \n");
+ PRINT_ARRAY_HEX(out_opt, width, height);
+ printf("Expected: \n");
+ PRINT_ARRAY_HEX(out_ansi, width, height);
+ printf("Input:\n");
+ PRINT_ARRAY_HEX(input, width, height);
+ goto softmax_s8_cleanup;
+ }
+ printf(ANSI_COLOR_GREEN"%s passed\n"ANSI_COLOR_RESET, __FUNCTION__);
+
+softmax_s8_cleanup:
+ if (input) {
+ free (input);
+ }
+ if (out_ansi) {
+ free (out_ansi);
+ }
+ if (out_opt) {
+ free (out_opt);
+ }
+ if (scratch_buf) {
+ free (scratch_buf);
+ }
+}
diff --git a/code/components/esp32-camera-master_neu_20220121.zip b/code/components/esp32-camera-master_neu_20220121.zip
deleted file mode 100644
index 3acbcf1a..00000000
Binary files a/code/components/esp32-camera-master_neu_20220121.zip and /dev/null differ
diff --git a/code/components/esp32-camera-master_old_version.zip b/code/components/esp32-camera-master_old_version.zip
deleted file mode 100644
index c0c60f8f..00000000
Binary files a/code/components/esp32-camera-master_old_version.zip and /dev/null differ
diff --git a/code/components/jomjol_flowcontroll/ClassFlow.cpp b/code/components/jomjol_flowcontroll/ClassFlow.cpp
index ff14c1b2..f15844d5 100644
--- a/code/components/jomjol_flowcontroll/ClassFlow.cpp
+++ b/code/components/jomjol_flowcontroll/ClassFlow.cpp
@@ -19,7 +19,6 @@ void ClassFlow::SetInitialParameter(void)
std::vector ClassFlow::ZerlegeZeile(std::string input, std::string delimiter)
{
std::vector Output;
-// std::string delimiter = " =,";
input = trim(input, delimiter);
size_t pos = findDelimiterPos(input, delimiter);
diff --git a/code/components/jomjol_flowcontroll/ClassFlow.h b/code/components/jomjol_flowcontroll/ClassFlow.h
index 4df4777c..92184d32 100644
--- a/code/components/jomjol_flowcontroll/ClassFlow.h
+++ b/code/components/jomjol_flowcontroll/ClassFlow.h
@@ -26,7 +26,6 @@ struct HTMLInfo
class ClassFlow
{
protected:
-// std::vector ZerlegeZeile(string input);
std::vector ZerlegeZeile(string input, string delimiter = " =, \t");
bool isNewParagraph(string input);
bool GetNextParagraph(FILE* pfile, string& aktparamgraph);
diff --git a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
index e3092308..9644b265 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
+++ b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
@@ -197,33 +197,6 @@ int ClassFlowCNNGeneral::ZeigerEvalHybrid(float zahl, float zahl_vorgaenger, int
return ((int) trunc(zahl) + 10) % 10;
}
-/*
-int ClassFlowCNNGeneral::ZeigerEvalHybrid_NEU(float zahl, float zahl_vorgaenger)
-{
- int ergebnis_nachkomma = ((int) floor(zahl * 10) + 10) % 10;
- int ergebnis_vorkomma = ((int) floor(zahl) + 10) % 10;
- int ergebnis, ergebnis_rating;
-
-
- if (zahl_vorgaenger < 0)
- return ergebnis_vorkomma % 10;
-
- ergebnis_rating = ergebnis_nachkomma - zahl_vorgaenger;
- if (ergebnis_nachkomma >= 5)
- ergebnis_rating-=5;
- else
- ergebnis_rating+=5;
- ergebnis = (int) round(zahl);
- if (ergebnis_rating < 0)
- ergebnis-=1;
- if (ergebnis == -1)
- ergebnis+=10;
-
- ergebnis = (ergebnis + 10) % 10;
- return ergebnis;
-
-}
-*/
int ClassFlowCNNGeneral::ZeigerEval(float zahl, int ziffer_vorgaenger)
@@ -309,11 +282,12 @@ bool ClassFlowCNNGeneral::ReadParameter(FILE* pfile, string& aktparamgraph)
{
CNNGoodThreshold = std::stof(zerlegt[1]);
}
- if ((toUpper(zerlegt[0]) == "MODELINPUTSIZE") && (zerlegt.size() > 2))
+/* if ((toUpper(zerlegt[0]) == "MODELINPUTSIZE") && (zerlegt.size() > 2))
{
this->modelxsize = std::stoi(zerlegt[1]);
this->modelysize = std::stoi(zerlegt[2]);
}
+*/
if (zerlegt.size() >= 5)
{
general* _analog = GetGENERAL(zerlegt[0], true);
@@ -334,11 +308,14 @@ bool ClassFlowCNNGeneral::ReadParameter(FILE* pfile, string& aktparamgraph)
}
}
+ if (!getNetworkParameter())
+ return false;
- for (int _ana = 0; _ana < GENERAL.size(); ++_ana)
+
+ for (int _ana = 0; _ana < GENERAL.size(); ++_ana)
for (int i = 0; i < GENERAL[_ana]->ROI.size(); ++i)
{
- GENERAL[_ana]->ROI[i]->image = new CImageBasis(modelxsize, modelysize, 3);
+ GENERAL[_ana]->ROI[i]->image = new CImageBasis(modelxsize, modelysize, modelchannel);
GENERAL[_ana]->ROI[i]->image_org = new CImageBasis(GENERAL[_ana]->ROI[i]->deltax, GENERAL[_ana]->ROI[i]->deltay, 3);
}
@@ -499,13 +476,11 @@ void ClassFlowCNNGeneral::DrawROI(CImageBasis *_zw)
}
}
-bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
+bool ClassFlowCNNGeneral::getNetworkParameter()
{
if (disabled)
return true;
- string logPath = CreateLogFolder(time);
-
CTfLiteClass *tflite = new CTfLiteClass;
string zwcnn = "/sdcard" + cnnmodelfile;
zwcnn = FormatFileName(zwcnn);
@@ -513,7 +488,6 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
if (!tflite->LoadModel(zwcnn)) {
printf("Can't read model file /sdcard%s\n", cnnmodelfile.c_str());
LogFile.WriteToFile("Cannot load model");
-
delete tflite;
return false;
}
@@ -521,6 +495,11 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
if (CNNType == AutoDetect)
{
+ tflite->GetInputDimension(false);
+ modelxsize = tflite->ReadInputDimenstion(0);
+ modelysize = tflite->ReadInputDimenstion(1);
+ modelchannel = tflite->ReadInputDimenstion(2);
+
int _anzoutputdimensions = tflite->GetAnzOutPut();
switch (_anzoutputdimensions)
{
@@ -549,6 +528,30 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
}
}
+ delete tflite;
+ return true;
+}
+
+bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
+{
+ if (disabled)
+ return true;
+
+ string logPath = CreateLogFolder(time);
+
+ CTfLiteClass *tflite = new CTfLiteClass;
+ string zwcnn = "/sdcard" + cnnmodelfile;
+ zwcnn = FormatFileName(zwcnn);
+ printf(zwcnn.c_str());printf("\n");
+ if (!tflite->LoadModel(zwcnn)) {
+ printf("Can't read model file /sdcard%s\n", cnnmodelfile.c_str());
+ LogFile.WriteToFile("Cannot load model");
+
+ delete tflite;
+ return false;
+ }
+ tflite->MakeAllocate();
+
for (int _ana = 0; _ana < GENERAL.size(); ++_ana)
{
for (int i = 0; i < GENERAL[_ana]->ROI.size(); ++i)
@@ -581,14 +584,15 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
if (isLogImage)
{
+ string _imagename = GENERAL[_ana]->name + "_" + GENERAL[_ana]->ROI[i]->name;
if (isLogImageSelect)
{
if (LogImageSelect.find(GENERAL[_ana]->ROI[i]->name) != std::string::npos)
- LogImage(logPath, GENERAL[_ana]->ROI[i]->name, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
+ LogImage(logPath, _imagename, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
}
else
{
- LogImage(logPath, GENERAL[_ana]->ROI[i]->name, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
+ LogImage(logPath, _imagename, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
}
}
} break;
@@ -617,7 +621,18 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
if (debugdetailgeneral) LogFile.WriteToFile(_zwres);
if (isLogImage)
- LogImage(logPath, GENERAL[_ana]->ROI[i]->name, &GENERAL[_ana]->ROI[i]->result_float, NULL, time, GENERAL[_ana]->ROI[i]->image_org);
+ {
+ string _imagename = GENERAL[_ana]->name + "_" + GENERAL[_ana]->ROI[i]->name;
+ if (isLogImageSelect)
+ {
+ if (LogImageSelect.find(GENERAL[_ana]->ROI[i]->name) != std::string::npos)
+ LogImage(logPath, _imagename, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
+ }
+ else
+ {
+ LogImage(logPath, _imagename, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
+ }
+ }
} break;
case DigitalHyprid10:
{
@@ -641,7 +656,18 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
if (debugdetailgeneral) LogFile.WriteToFile(_zwres);
if (isLogImage)
- LogImage(logPath, GENERAL[_ana]->ROI[i]->name, &GENERAL[_ana]->ROI[i]->result_float, NULL, time, GENERAL[_ana]->ROI[i]->image_org);
+ {
+ string _imagename = GENERAL[_ana]->name + "_" + GENERAL[_ana]->ROI[i]->name;
+ if (isLogImageSelect)
+ {
+ if (LogImageSelect.find(GENERAL[_ana]->ROI[i]->name) != std::string::npos)
+ LogImage(logPath, _imagename, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
+ }
+ else
+ {
+ LogImage(logPath, _imagename, NULL, &GENERAL[_ana]->ROI[i]->result_klasse, time, GENERAL[_ana]->ROI[i]->image_org);
+ }
+ }
} break;
case DoubleHyprid10:
@@ -649,6 +675,7 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
int _num, _numplus, _numminus;
float _val, _valplus, _valminus;
float _fit;
+ float _result_save_file;
tflite->LoadInputImageBasis(GENERAL[_ana]->ROI[i]->image);
tflite->Invoke();
@@ -680,10 +707,13 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
if (result < 0)
result = result + 10;
+ _result_save_file = result;
+
if (_fit < CNNGoodThreshold)
{
GENERAL[_ana]->ROI[i]->isReject = true;
result = -1;
+ _result_save_file+= 100; // Für den Fall, dass fit nicht ausreichend, soll trotzdem das Ergebnis mit "-10x.y" abgespeichert werden.
string zw = "Value Rejected due to Threshold (Fit: " + to_string(_fit) + "Threshold: " + to_string(CNNGoodThreshold);
printf("Value Rejected due to Threshold (Fit: %f, Threshold: %f\n", _fit, CNNGoodThreshold);
LogFile.WriteToFile(zw);
@@ -693,9 +723,23 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
GENERAL[_ana]->ROI[i]->isReject = false;
}
+
GENERAL[_ana]->ROI[i]->result_float = result;
printf("Result General(Analog)%i: %f\n", i, GENERAL[_ana]->ROI[i]->result_float);
+ if (isLogImage)
+ {
+ string _imagename = GENERAL[_ana]->name + "_" + GENERAL[_ana]->ROI[i]->name;
+ if (isLogImageSelect)
+ {
+ if (LogImageSelect.find(GENERAL[_ana]->ROI[i]->name) != std::string::npos)
+ LogImage(logPath, _imagename, &_result_save_file, NULL, time, GENERAL[_ana]->ROI[i]->image_org);
+ }
+ else
+ {
+ LogImage(logPath, _imagename, &_result_save_file, NULL, time, GENERAL[_ana]->ROI[i]->image_org);
+ }
+ }
}
break;
diff --git a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h
index e9c5c3ce..32fcf9bd 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h
+++ b/code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.h
@@ -24,7 +24,7 @@ protected:
float CNNGoodThreshold;
string cnnmodelfile;
- int modelxsize, modelysize;
+ int modelxsize, modelysize, modelchannel;
bool isLogImageSelect;
string LogImageSelect;
ClassFlowAlignment* flowpostalignment;
@@ -39,6 +39,8 @@ protected:
bool doNeuralNetwork(string time);
bool doAlignAndCut(string time);
+ bool getNetworkParameter();
+
public:
ClassFlowCNNGeneral(ClassFlowAlignment *_flowalign, t_CNNType _cnntype = AutoDetect);
diff --git a/code/components/jomjol_flowcontroll/ClassFlowDefineTypes.h b/code/components/jomjol_flowcontroll/ClassFlowDefineTypes.h
index 181332d0..98432886 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowDefineTypes.h
+++ b/code/components/jomjol_flowcontroll/ClassFlowDefineTypes.h
@@ -37,6 +37,7 @@ struct NumberPost {
float PreValue; // letzter Wert, der gut ausgelesen wurde
float Value; // letzer ausgelesener Wert, inkl. Korrekturen
string ReturnRateValue; // RückgabewertRate
+ string ReturnChangeAbsolute; // RückgabewertRate
string ReturnRawValue; // Rohwert (mit N & führenden 0)
string ReturnValue; // korrigierter Rückgabewert, ggf. mit Fehlermeldung
string ReturnPreValue; // korrigierter Rückgabewert ohne Fehlermeldung
diff --git a/code/components/jomjol_flowcontroll/ClassFlowMQTT.cpp b/code/components/jomjol_flowcontroll/ClassFlowMQTT.cpp
index ceaf9671..f4e014e9 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowMQTT.cpp
+++ b/code/components/jomjol_flowcontroll/ClassFlowMQTT.cpp
@@ -149,6 +149,7 @@ bool ClassFlowMQTT::doFlow(string zwtime)
std::string resultraw = "";
std::string resultrate = "";
std::string resulttimestamp = "";
+ std::string resultchangabs = "";
string zw = "";
string namenumber = "";
@@ -180,6 +181,7 @@ bool ClassFlowMQTT::doFlow(string zwtime)
resultraw = (*NUMBERS)[i]->ReturnRawValue;
resulterror = (*NUMBERS)[i]->ErrorMessageText;
resultrate = (*NUMBERS)[i]->ReturnRateValue;
+ resultchangabs = (*NUMBERS)[i]->ReturnChangeAbsolute;
resulttimestamp = (*NUMBERS)[i]->timeStamp;
namenumber = (*NUMBERS)[i]->name;
@@ -200,6 +202,10 @@ bool ClassFlowMQTT::doFlow(string zwtime)
if (resultrate.length() > 0)
MQTTPublish(zw, resultrate, SetRetainFlag);
+ zw = namenumber + "changeabsolut";
+ if (resultchangabs.length() > 0)
+ MQTTPublish(zw, resultchangabs, SetRetainFlag);
+
zw = namenumber + "raw";
if (resultraw.length() > 0)
MQTTPublish(zw, resultraw, SetRetainFlag);
diff --git a/code/components/jomjol_flowcontroll/ClassFlowPostProcessing.cpp b/code/components/jomjol_flowcontroll/ClassFlowPostProcessing.cpp
index f61e0140..b3c60fad 100644
--- a/code/components/jomjol_flowcontroll/ClassFlowPostProcessing.cpp
+++ b/code/components/jomjol_flowcontroll/ClassFlowPostProcessing.cpp
@@ -77,6 +77,8 @@ void ClassFlowPostProcessing::SetPreValue(float zw, string _numbers, bool _exter
if (NUMBERS[j]->name == _numbers)
{
NUMBERS[j]->PreValue = zw;
+ NUMBERS[j]->ReturnPreValue = std::to_string(zw);
+ NUMBERS[j]->PreValueOkay = true;
if (_extern)
{
time(&(NUMBERS[j]->lastvalue));
@@ -541,7 +543,6 @@ void ClassFlowPostProcessing::InitNUMBERS()
_number->ReturnRawValue = ""; // Rohwert (mit N & führenden 0)
_number->ReturnValue = ""; // korrigierter Rückgabewert, ggf. mit Fehlermeldung
-// _number->ReturnValueNoError = ""; // korrigierter Rückgabewert ohne Fehlermeldung
_number->ErrorMessageText = ""; // Fehlermeldung bei Consistency Check
_number->ReturnPreValue = "";
_number->PreValueOkay = false;
@@ -560,7 +561,6 @@ void ClassFlowPostProcessing::InitNUMBERS()
_number->Value = 0; // letzer ausgelesener Wert, inkl. Korrekturen
_number->ReturnRawValue = ""; // Rohwert (mit N & führenden 0)
_number->ReturnValue = ""; // korrigierter Rückgabewert, ggf. mit Fehlermeldung
-// _number->ReturnValueNoError = ""; // korrigierter Rückgabewert ohne Fehlermeldung
_number->ErrorMessageText = ""; // Fehlermeldung bei Consistency Check
_number->Nachkomma = _number->AnzahlAnalog;
@@ -722,7 +722,7 @@ bool ClassFlowPostProcessing::doFlow(string zwtime)
if (NUMBERS[j]->useMaxRateValue && PreValueUse && NUMBERS[j]->PreValueOkay)
{
- float _ratedifference;
+ float _ratedifference;
if (NUMBERS[j]->RateType == RateChange)
_ratedifference = NUMBERS[j]->FlowRateAct;
else
@@ -745,6 +745,7 @@ bool ClassFlowPostProcessing::doFlow(string zwtime)
NUMBERS[j]->ReturnValue = RundeOutput(NUMBERS[j]->Value, NUMBERS[j]->Nachkomma);
NUMBERS[j]->ReturnPreValue = RundeOutput(NUMBERS[j]->PreValue, NUMBERS[j]->Nachkomma);
+ NUMBERS[j]->ReturnChangeAbsolute = RundeOutput(NUMBERS[j]->Value - NUMBERS[j]->PreValue, NUMBERS[j]->Nachkomma);
NUMBERS[j]->ErrorMessageText = "no error";
UpdatePreValueINI = true;
diff --git a/code/components/jomjol_tfliteclass/CTfLiteClass.cpp b/code/components/jomjol_tfliteclass/CTfLiteClass.cpp
index df008a1b..15affbc0 100644
--- a/code/components/jomjol_tfliteclass/CTfLiteClass.cpp
+++ b/code/components/jomjol_tfliteclass/CTfLiteClass.cpp
@@ -87,6 +87,19 @@ void CTfLiteClass::GetInputDimension(bool silent = false)
}
}
+int CTfLiteClass::ReadInputDimenstion(int _dim)
+{
+ if (_dim == 0)
+ return im_width;
+ if (_dim == 1)
+ return im_height;
+ if (_dim == 2)
+ return im_channel;
+
+ return -1;
+}
+
+
int CTfLiteClass::GetAnzOutPut(bool silent)
{
diff --git a/code/components/jomjol_tfliteclass/CTfLiteClass.h b/code/components/jomjol_tfliteclass/CTfLiteClass.h
index cab5a0e3..ef98c1fa 100644
--- a/code/components/jomjol_tfliteclass/CTfLiteClass.h
+++ b/code/components/jomjol_tfliteclass/CTfLiteClass.h
@@ -71,5 +71,6 @@ class CTfLiteClass
float GetOutputValue(int nr);
void GetInputDimension(bool silent);
+ int ReadInputDimenstion(int _dim);
};
diff --git a/code/components/tflite-lib/CMakeLists.txt b/code/components/tflite-lib/CMakeLists.txt
index fab7027a..eed31a57 100644
--- a/code/components/tflite-lib/CMakeLists.txt
+++ b/code/components/tflite-lib/CMakeLists.txt
@@ -1,3 +1,5 @@
+## TODO: GLOB is not a good way to collect files. Use explicit file list instead
+
cmake_minimum_required(VERSION 3.5)
set(tflite_dir "${CMAKE_CURRENT_SOURCE_DIR}/tensorflow/lite")
@@ -16,14 +18,27 @@ file(GLOB srcs_kernels
"${tfmicro_kernels_dir}/*.c"
"${tfmicro_kernels_dir}/*.cc")
+# remove sources which will be provided by esp_nn
+list(REMOVE_ITEM srcs_kernels
+ "${tfmicro_kernels_dir}/add.cc"
+ "${tfmicro_kernels_dir}/conv.cc"
+ "${tfmicro_kernels_dir}/depthwise_conv.cc"
+ "${tfmicro_kernels_dir}/fully_connected.cc"
+ "${tfmicro_kernels_dir}/mul.cc"
+ "${tfmicro_kernels_dir}/pooling.cc")
+
+FILE(GLOB esp_nn_kernels
+ "${tfmicro_kernels_dir}/esp_nn/*.cc")
+
set(lib_srcs
"${srcs_micro}"
"${srcs_kernels}"
+ "${esp_nn_kernels}"
"${src_micro_frontend}"
"${tflite_dir}/kernels/kernel_util.cc"
"${tflite_dir}/micro/memory_planner/greedy_memory_planner.cc"
"${tflite_dir}/micro/memory_planner/linear_memory_planner.cc"
- "${tflite_dir}/c/common.c"
+ "${tflite_dir}/c/common.cc"
"${tflite_dir}/core/api/error_reporter.cc"
"${tflite_dir}/core/api/flatbuffer_conversions.cc"
"${tflite_dir}/core/api/op_resolver.cc"
@@ -36,15 +51,17 @@ idf_component_register(
INCLUDE_DIRS "." "third_party/gemmlowp"
"third_party/flatbuffers/include"
"third_party/ruy"
- "third_party/kissfft")
+ "third_party/kissfft"
+ REQUIRES "esp-nn")
# Reduce the level of paranoia to be able to compile TF sources
target_compile_options(${COMPONENT_LIB} PRIVATE
-Wno-maybe-uninitialized
-Wno-missing-field-initializers
+ -DESP_NN # enables ESP-NN optimizations by Espressif
-Wno-type-limits)
-target_compile_options(${COMPONENT_LIB} PRIVATE -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DESP -DESP_NN -Wno-nonnull -Wno-nonnull -Wno-nonnull)
-target_compile_options(${COMPONENT_LIB} PRIVATE $<$: -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DESP -DESP_NN -Wno-return-type -Wno-strict-aliasing -std=gnu++14 -Wno-return-type -Wno-strict-aliasing -std=gnu++14 -Wno-return-type -Wno-strict-aliasing -std=gnu++14 >)
+target_compile_options(${COMPONENT_LIB} PRIVATE -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -Wno-nonnull)
+target_compile_options(${COMPONENT_LIB} PRIVATE $<$: -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -Wno-return-type -Wno-strict-aliasing -std=gnu++14 >)
target_compile_options(${COMPONENT_LIB} INTERFACE $<$>:-DTF_LITE_STATIC_MEMORY>)
target_link_libraries(${COMPONENT_LIB} PRIVATE -lm)
diff --git a/code/components/tflite-lib/tensorflow/lite/builtin_op_data.h b/code/components/tflite-lib/tensorflow/lite/builtin_op_data.h
new file mode 100644
index 00000000..b9d42845
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/builtin_op_data.h
@@ -0,0 +1,22 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Compatibility shim for new location of interface definitions.
+
+#ifndef TENSORFLOW_LITE_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_LITE_BUILTIN_OP_DATA_H_
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+
+#endif // TENSORFLOW_LITE_BUILTIN_OP_DATA_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/builtin_ops.h b/code/components/tflite-lib/tensorflow/lite/builtin_ops.h
new file mode 100644
index 00000000..19ce3e2c
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/builtin_ops.h
@@ -0,0 +1,187 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_BUILTIN_OPS_H_
+#define TENSORFLOW_LITE_BUILTIN_OPS_H_
+
+// DO NOT EDIT MANUALLY: This file is automatically generated by
+// `schema/builtin_ops_header/generator.cc`.
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// The enum for builtin operators.
+// Note: CUSTOM, DELEGATE, and PLACEHOLDER_FOR_GREATER_OP_CODES are 3 special
+// ops which are not real built-in ops.
+typedef enum {
+ kTfLiteBuiltinAdd = 0,
+ kTfLiteBuiltinAveragePool2d = 1,
+ kTfLiteBuiltinConcatenation = 2,
+ kTfLiteBuiltinConv2d = 3,
+ kTfLiteBuiltinDepthwiseConv2d = 4,
+ kTfLiteBuiltinDepthToSpace = 5,
+ kTfLiteBuiltinDequantize = 6,
+ kTfLiteBuiltinEmbeddingLookup = 7,
+ kTfLiteBuiltinFloor = 8,
+ kTfLiteBuiltinFullyConnected = 9,
+ kTfLiteBuiltinHashtableLookup = 10,
+ kTfLiteBuiltinL2Normalization = 11,
+ kTfLiteBuiltinL2Pool2d = 12,
+ kTfLiteBuiltinLocalResponseNormalization = 13,
+ kTfLiteBuiltinLogistic = 14,
+ kTfLiteBuiltinLshProjection = 15,
+ kTfLiteBuiltinLstm = 16,
+ kTfLiteBuiltinMaxPool2d = 17,
+ kTfLiteBuiltinMul = 18,
+ kTfLiteBuiltinRelu = 19,
+ kTfLiteBuiltinReluN1To1 = 20,
+ kTfLiteBuiltinRelu6 = 21,
+ kTfLiteBuiltinReshape = 22,
+ kTfLiteBuiltinResizeBilinear = 23,
+ kTfLiteBuiltinRnn = 24,
+ kTfLiteBuiltinSoftmax = 25,
+ kTfLiteBuiltinSpaceToDepth = 26,
+ kTfLiteBuiltinSvdf = 27,
+ kTfLiteBuiltinTanh = 28,
+ kTfLiteBuiltinConcatEmbeddings = 29,
+ kTfLiteBuiltinSkipGram = 30,
+ kTfLiteBuiltinCall = 31,
+ kTfLiteBuiltinCustom = 32,
+ kTfLiteBuiltinEmbeddingLookupSparse = 33,
+ kTfLiteBuiltinPad = 34,
+ kTfLiteBuiltinUnidirectionalSequenceRnn = 35,
+ kTfLiteBuiltinGather = 36,
+ kTfLiteBuiltinBatchToSpaceNd = 37,
+ kTfLiteBuiltinSpaceToBatchNd = 38,
+ kTfLiteBuiltinTranspose = 39,
+ kTfLiteBuiltinMean = 40,
+ kTfLiteBuiltinSub = 41,
+ kTfLiteBuiltinDiv = 42,
+ kTfLiteBuiltinSqueeze = 43,
+ kTfLiteBuiltinUnidirectionalSequenceLstm = 44,
+ kTfLiteBuiltinStridedSlice = 45,
+ kTfLiteBuiltinBidirectionalSequenceRnn = 46,
+ kTfLiteBuiltinExp = 47,
+ kTfLiteBuiltinTopkV2 = 48,
+ kTfLiteBuiltinSplit = 49,
+ kTfLiteBuiltinLogSoftmax = 50,
+ kTfLiteBuiltinDelegate = 51,
+ kTfLiteBuiltinBidirectionalSequenceLstm = 52,
+ kTfLiteBuiltinCast = 53,
+ kTfLiteBuiltinPrelu = 54,
+ kTfLiteBuiltinMaximum = 55,
+ kTfLiteBuiltinArgMax = 56,
+ kTfLiteBuiltinMinimum = 57,
+ kTfLiteBuiltinLess = 58,
+ kTfLiteBuiltinNeg = 59,
+ kTfLiteBuiltinPadv2 = 60,
+ kTfLiteBuiltinGreater = 61,
+ kTfLiteBuiltinGreaterEqual = 62,
+ kTfLiteBuiltinLessEqual = 63,
+ kTfLiteBuiltinSelect = 64,
+ kTfLiteBuiltinSlice = 65,
+ kTfLiteBuiltinSin = 66,
+ kTfLiteBuiltinTransposeConv = 67,
+ kTfLiteBuiltinSparseToDense = 68,
+ kTfLiteBuiltinTile = 69,
+ kTfLiteBuiltinExpandDims = 70,
+ kTfLiteBuiltinEqual = 71,
+ kTfLiteBuiltinNotEqual = 72,
+ kTfLiteBuiltinLog = 73,
+ kTfLiteBuiltinSum = 74,
+ kTfLiteBuiltinSqrt = 75,
+ kTfLiteBuiltinRsqrt = 76,
+ kTfLiteBuiltinShape = 77,
+ kTfLiteBuiltinPow = 78,
+ kTfLiteBuiltinArgMin = 79,
+ kTfLiteBuiltinFakeQuant = 80,
+ kTfLiteBuiltinReduceProd = 81,
+ kTfLiteBuiltinReduceMax = 82,
+ kTfLiteBuiltinPack = 83,
+ kTfLiteBuiltinLogicalOr = 84,
+ kTfLiteBuiltinOneHot = 85,
+ kTfLiteBuiltinLogicalAnd = 86,
+ kTfLiteBuiltinLogicalNot = 87,
+ kTfLiteBuiltinUnpack = 88,
+ kTfLiteBuiltinReduceMin = 89,
+ kTfLiteBuiltinFloorDiv = 90,
+ kTfLiteBuiltinReduceAny = 91,
+ kTfLiteBuiltinSquare = 92,
+ kTfLiteBuiltinZerosLike = 93,
+ kTfLiteBuiltinFill = 94,
+ kTfLiteBuiltinFloorMod = 95,
+ kTfLiteBuiltinRange = 96,
+ kTfLiteBuiltinResizeNearestNeighbor = 97,
+ kTfLiteBuiltinLeakyRelu = 98,
+ kTfLiteBuiltinSquaredDifference = 99,
+ kTfLiteBuiltinMirrorPad = 100,
+ kTfLiteBuiltinAbs = 101,
+ kTfLiteBuiltinSplitV = 102,
+ kTfLiteBuiltinUnique = 103,
+ kTfLiteBuiltinCeil = 104,
+ kTfLiteBuiltinReverseV2 = 105,
+ kTfLiteBuiltinAddN = 106,
+ kTfLiteBuiltinGatherNd = 107,
+ kTfLiteBuiltinCos = 108,
+ kTfLiteBuiltinWhere = 109,
+ kTfLiteBuiltinRank = 110,
+ kTfLiteBuiltinElu = 111,
+ kTfLiteBuiltinReverseSequence = 112,
+ kTfLiteBuiltinMatrixDiag = 113,
+ kTfLiteBuiltinQuantize = 114,
+ kTfLiteBuiltinMatrixSetDiag = 115,
+ kTfLiteBuiltinRound = 116,
+ kTfLiteBuiltinHardSwish = 117,
+ kTfLiteBuiltinIf = 118,
+ kTfLiteBuiltinWhile = 119,
+ kTfLiteBuiltinNonMaxSuppressionV4 = 120,
+ kTfLiteBuiltinNonMaxSuppressionV5 = 121,
+ kTfLiteBuiltinScatterNd = 122,
+ kTfLiteBuiltinSelectV2 = 123,
+ kTfLiteBuiltinDensify = 124,
+ kTfLiteBuiltinSegmentSum = 125,
+ kTfLiteBuiltinBatchMatmul = 126,
+ kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127,
+ kTfLiteBuiltinCumsum = 128,
+ kTfLiteBuiltinCallOnce = 129,
+ kTfLiteBuiltinBroadcastTo = 130,
+ kTfLiteBuiltinRfft2d = 131,
+ kTfLiteBuiltinConv3d = 132,
+ kTfLiteBuiltinImag = 133,
+ kTfLiteBuiltinReal = 134,
+ kTfLiteBuiltinComplexAbs = 135,
+ kTfLiteBuiltinHashtable = 136,
+ kTfLiteBuiltinHashtableFind = 137,
+ kTfLiteBuiltinHashtableImport = 138,
+ kTfLiteBuiltinHashtableSize = 139,
+ kTfLiteBuiltinReduceAll = 140,
+ kTfLiteBuiltinConv3dTranspose = 141,
+ kTfLiteBuiltinVarHandle = 142,
+ kTfLiteBuiltinReadVariable = 143,
+ kTfLiteBuiltinAssignVariable = 144,
+ kTfLiteBuiltinBroadcastArgs = 145,
+ kTfLiteBuiltinRandomStandardNormal = 146,
+ kTfLiteBuiltinBucketize = 147,
+ kTfLiteBuiltinRandomUniform = 148,
+ kTfLiteBuiltinMultinomial = 149,
+ kTfLiteBuiltinGelu = 150,
+ kTfLiteBuiltinDynamicUpdateSlice = 151,
+} TfLiteBuiltinOperator;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_LITE_BUILTIN_OPS_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h b/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
index 678dfae6..d2524969 100644
--- a/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
+++ b/code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
@@ -98,6 +98,7 @@ typedef enum {
kTfLiteResource = 14,
kTfLiteVariant = 15,
kTfLiteUInt32 = 16,
+ kTfLiteUInt16 = 17,
} TfLiteType;
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
@@ -111,6 +112,12 @@ typedef struct TfLiteQuantizationParams {
int32_t zero_point;
} TfLiteQuantizationParams;
+// --------------------------------------------------------------------------
+// Opaque types used by c_api_opaque.h.
+
+// TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
+typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;
+
#ifdef __cplusplus
} // extern C
#endif
diff --git a/code/components/tflite-lib/tensorflow/lite/c/common.c b/code/components/tflite-lib/tensorflow/lite/c/common.cc
similarity index 89%
rename from code/components/tflite-lib/tensorflow/lite/c/common.c
rename to code/components/tflite-lib/tensorflow/lite/c/common.cc
index d149d22c..956e9d69 100644
--- a/code/components/tflite-lib/tensorflow/lite/c/common.c
+++ b/code/components/tflite-lib/tensorflow/lite/c/common.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include
#endif // TF_LITE_STATIC_MEMORY
+extern "C" {
+
size_t TfLiteIntArrayGetSizeInBytes(int size) {
static TfLiteIntArray dummy;
@@ -34,13 +36,13 @@ size_t TfLiteIntArrayGetSizeInBytes(int size) {
int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b) {
if (a == b) return 1;
- if (a == NULL || b == NULL) return 0;
+ if (a == nullptr || b == nullptr) return 0;
return TfLiteIntArrayEqualsArray(a, b->size, b->data);
}
int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
const int b_data[]) {
- if (a == NULL) return (b_size == 0);
+ if (a == nullptr) return (b_size == 0);
if (a->size != b_size) return 0;
int i = 0;
for (; i < a->size; i++)
@@ -52,7 +54,7 @@ int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
TfLiteIntArray* TfLiteIntArrayCreate(int size) {
size_t alloc_size = TfLiteIntArrayGetSizeInBytes(size);
- if (alloc_size <= 0) return NULL;
+ if (alloc_size <= 0) return nullptr;
TfLiteIntArray* ret = (TfLiteIntArray*)malloc(alloc_size);
if (!ret) return ret;
ret->size = size;
@@ -60,7 +62,7 @@ TfLiteIntArray* TfLiteIntArrayCreate(int size) {
}
TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src) {
- if (!src) return NULL;
+ if (!src) return nullptr;
TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
if (ret) {
memcpy(ret->data, src->data, src->size * sizeof(int));
@@ -99,7 +101,7 @@ void TfLiteTensorDataFree(TfLiteTensor* t) {
t->allocation_type == kTfLitePersistentRo) {
free(t->data.raw);
}
- t->data.raw = NULL;
+ t->data.raw = nullptr;
}
void TfLiteQuantizationFree(TfLiteQuantization* quantization) {
@@ -108,31 +110,31 @@ void TfLiteQuantizationFree(TfLiteQuantization* quantization) {
(TfLiteAffineQuantization*)(quantization->params);
if (q_params->scale) {
TfLiteFloatArrayFree(q_params->scale);
- q_params->scale = NULL;
+ q_params->scale = nullptr;
}
if (q_params->zero_point) {
TfLiteIntArrayFree(q_params->zero_point);
- q_params->zero_point = NULL;
+ q_params->zero_point = nullptr;
}
free(q_params);
}
- quantization->params = NULL;
+ quantization->params = nullptr;
quantization->type = kTfLiteNoQuantization;
}
void TfLiteSparsityFree(TfLiteSparsity* sparsity) {
- if (sparsity == NULL) {
+ if (sparsity == nullptr) {
return;
}
if (sparsity->traversal_order) {
TfLiteIntArrayFree(sparsity->traversal_order);
- sparsity->traversal_order = NULL;
+ sparsity->traversal_order = nullptr;
}
if (sparsity->block_map) {
TfLiteIntArrayFree(sparsity->block_map);
- sparsity->block_map = NULL;
+ sparsity->block_map = nullptr;
}
if (sparsity->dim_metadata) {
@@ -141,13 +143,13 @@ void TfLiteSparsityFree(TfLiteSparsity* sparsity) {
TfLiteDimensionMetadata metadata = sparsity->dim_metadata[i];
if (metadata.format == kTfLiteDimSparseCSR) {
TfLiteIntArrayFree(metadata.array_segments);
- metadata.array_segments = NULL;
+ metadata.array_segments = nullptr;
TfLiteIntArrayFree(metadata.array_indices);
- metadata.array_indices = NULL;
+ metadata.array_indices = nullptr;
}
}
free(sparsity->dim_metadata);
- sparsity->dim_metadata = NULL;
+ sparsity->dim_metadata = nullptr;
}
free(sparsity);
@@ -156,16 +158,16 @@ void TfLiteSparsityFree(TfLiteSparsity* sparsity) {
void TfLiteTensorFree(TfLiteTensor* t) {
TfLiteTensorDataFree(t);
if (t->dims) TfLiteIntArrayFree(t->dims);
- t->dims = NULL;
+ t->dims = nullptr;
if (t->dims_signature) {
TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
}
- t->dims_signature = NULL;
+ t->dims_signature = nullptr;
TfLiteQuantizationFree(&t->quantization);
TfLiteSparsityFree(t->sparsity);
- t->sparsity = NULL;
+ t->sparsity = nullptr;
}
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
@@ -185,7 +187,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
tensor->is_variable = is_variable;
tensor->quantization.type = kTfLiteNoQuantization;
- tensor->quantization.params = NULL;
+ tensor->quantization.params = nullptr;
}
TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
@@ -229,6 +231,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "NOTYPE";
case kTfLiteFloat32:
return "FLOAT32";
+ case kTfLiteUInt16:
+ return "UINT16";
case kTfLiteInt16:
return "INT16";
case kTfLiteInt32:
@@ -263,14 +267,6 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "Unknown type";
}
-TfLiteDelegate TfLiteDelegateCreate(void) {
- TfLiteDelegate d = {
- .data_ = NULL,
- .Prepare = NULL,
- .CopyFromBufferHandle = NULL,
- .CopyToBufferHandle = NULL,
- .FreeBufferHandle = NULL,
- .flags = kTfLiteDelegateFlagsNone,
- };
- return d;
-}
+TfLiteDelegate TfLiteDelegateCreate() { return TfLiteDelegate{}; }
+
+} // extern "C"
diff --git a/code/components/tflite-lib/tensorflow/lite/c/common.h b/code/components/tflite-lib/tensorflow/lite/c/common.h
index 7056d1e2..6a109e1e 100644
--- a/code/components/tflite-lib/tensorflow/lite/c/common.h
+++ b/code/components/tflite-lib/tensorflow/lite/c/common.h
@@ -173,8 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
} \
} while (false)
#else // TF_LITE_STRIP_ERROR_STRINGS
-#define TF_LITE_KERNEL_LOG(context, ...)
-#define TF_LITE_MAYBE_KERNEL_LOG(context, ...)
+#define UNUSED(...) (void)sizeof(#__VA_ARGS__)
+#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
+#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
#endif // TF_LITE_STRIP_ERROR_STRINGS
// Check whether value is true, and if not return kTfLiteError from
@@ -316,6 +317,7 @@ typedef union TfLitePtrUnion {
uint8_t* uint8;
bool* b;
int16_t* i16;
+ uint16_t* ui16;
TfLiteComplex64* c64;
TfLiteComplex128* c128;
int8_t* int8;
@@ -459,7 +461,8 @@ typedef struct TfLiteTensor {
// Optional. Encodes shapes with unknown dimensions with -1. This field is
// only populated when unknown dimensions exist in a read-write tensor (i.e.
// an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
- // `dims_signature` contains [1, -1, -1, 3]).
+ // `dims_signature` contains [1, -1, -1, 3]). Note that this field only
+ // exists when TF_LITE_STATIC_MEMORY is not defined.
const TfLiteIntArray* dims_signature;
} TfLiteTensor;
diff --git a/code/components/tflite-lib/tensorflow/lite/context_util.h b/code/components/tflite-lib/tensorflow/lite/context_util.h
new file mode 100644
index 00000000..7c8a5abd
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/context_util.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This provides a few C++ helpers that are useful for manipulating C structures
+// in C++.
+#ifndef TENSORFLOW_LITE_CONTEXT_UTIL_H_
+#define TENSORFLOW_LITE_CONTEXT_UTIL_H_
+
+#include
+
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite
+// C api uses. Can't use the google array_view, since we can't depend on even
+// absl for embedded device reasons.
+class TfLiteIntArrayView {
+ public:
+ // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null
+ // and this view does not take ownership of it.
+ explicit TfLiteIntArrayView(const TfLiteIntArray* int_array)
+ : int_array_(int_array) {}
+
+ TfLiteIntArrayView(const TfLiteIntArrayView&) = default;
+ TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default;
+
+ typedef const int* const_iterator;
+ const_iterator begin() const { return int_array_->data; }
+ const_iterator end() const { return &int_array_->data[int_array_->size]; }
+ size_t size() const { return end() - begin(); }
+ int operator[](size_t pos) const { return int_array_->data[pos]; }
+
+ private:
+ const TfLiteIntArray* int_array_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_CONTEXT_UTIL_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
index dfa0ccfd..e92d754f 100644
--- a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -208,6 +208,14 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseBatchToSpaceNd(op, error_reporter, allocator, builtin_data);
}
+ case BuiltinOperator_BROADCAST_ARGS: {
+ return ParseBroadcastArgs(op, error_reporter, allocator, builtin_data);
+ }
+
+ case BuiltinOperator_BROADCAST_TO: {
+ return ParseBroadcastTo(op, error_reporter, allocator, builtin_data);
+ }
+
case BuiltinOperator_CALL_ONCE: {
return ParseCallOnce(op, error_reporter, allocator, builtin_data);
}
@@ -336,6 +344,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseLogSoftmax(op, error_reporter, allocator, builtin_data);
}
+ case BuiltinOperator_LSTM: {
+ return ParseLSTM(op, error_reporter, allocator, builtin_data);
+ }
+
case BuiltinOperator_MAXIMUM: {
return ParseMaximum(op, error_reporter, allocator, builtin_data);
}
@@ -605,37 +617,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
- case BuiltinOperator_LSTM: {
- auto params = safe_allocator.Allocate();
- TF_LITE_ENSURE(error_reporter, params != nullptr);
- if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
- params->activation =
- ConvertActivation(lstm_params->fused_activation_function());
- params->cell_clip = lstm_params->cell_clip();
- params->proj_clip = lstm_params->proj_clip();
- switch (lstm_params->kernel_type()) {
- case LSTMKernelType_FULL:
- params->kernel_type = kTfLiteLSTMFullKernel;
- break;
- case LSTMKernelType_BASIC:
- params->kernel_type = kTfLiteLSTMBasicKernel;
- break;
- default:
- TF_LITE_REPORT_ERROR(error_reporter,
- "Unhandled LSTM kernel type: %d",
- lstm_params->kernel_type());
- return kTfLiteError;
- }
- params->asymmetric_quantize_inputs =
- lstm_params->asymmetric_quantize_inputs();
- } else {
- TF_LITE_REPORT_ERROR(error_reporter,
- "No valid LSTM builtin options exist");
- return kTfLiteError;
- }
- *builtin_data = params.release();
- return kTfLiteOk;
- }
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
return ParseUnidirectionalSequenceLSTM(op, error_reporter, allocator,
builtin_data);
@@ -883,7 +864,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
- case BuiltinOperator_BROADCAST_TO:
case BuiltinOperator_RFFT2D:
case BuiltinOperator_IMAG:
case BuiltinOperator_REAL:
@@ -891,7 +871,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_HASHTABLE_FIND:
case BuiltinOperator_HASHTABLE_IMPORT:
case BuiltinOperator_HASHTABLE_SIZE:
- case BuiltinOperator_BROADCAST_ARGS:
+ case BuiltinOperator_DYNAMIC_UPDATE_SLICE:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;
@@ -916,6 +896,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_INT16:
*type = kTfLiteInt16;
return kTfLiteOk;
+ case TensorType_UINT16:
+ *type = kTfLiteUInt16;
+ return kTfLiteOk;
case TensorType_INT32:
*type = kTfLiteInt32;
return kTfLiteOk;
@@ -1085,6 +1068,22 @@ TfLiteStatus ParseBatchToSpaceNd(const Operator*, ErrorReporter*,
return kTfLiteOk;
}
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseBroadcastArgs(const Operator*, ErrorReporter*,
+ BuiltinDataAllocator*, void**) {
+ return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseBroadcastTo(const Operator*, ErrorReporter*,
+ BuiltinDataAllocator*, void**) {
+ return kTfLiteOk;
+}
+
TfLiteStatus ParseCallOnce(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
@@ -1605,6 +1604,40 @@ TfLiteStatus ParseLogSoftmax(const Operator*, ErrorReporter*,
return kTfLiteOk;
}
+TfLiteStatus ParseLSTM(const Operator* op, ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
+ CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+ SafeBuiltinDataAllocator safe_allocator(allocator);
+ auto params = safe_allocator.Allocate();
+ TF_LITE_ENSURE(error_reporter, params != nullptr);
+ if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ ConvertActivation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ default:
+ TF_LITE_REPORT_ERROR(error_reporter, "Unhandled LSTM kernel type: %d",
+ lstm_params->kernel_type());
+ return kTfLiteError;
+ }
+ params->asymmetric_quantize_inputs =
+ lstm_params->asymmetric_quantize_inputs();
+ } else {
+ TF_LITE_REPORT_ERROR(error_reporter, "No valid LSTM builtin options exist");
+ return kTfLiteError;
+ }
+ *builtin_data = params.release();
+ return kTfLiteOk;
+}
+
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
@@ -2337,6 +2370,31 @@ TfLiteStatus ParseVarHandle(const Operator* op, ErrorReporter* error_reporter,
return kTfLiteOk;
}
+TfLiteStatus ParseWhile(const Operator* op, ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
+ CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+ SafeBuiltinDataAllocator safe_allocator(allocator);
+ std::unique_ptr
+ params = safe_allocator.Allocate();
+ TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+ const WhileOptions* schema_params = op->builtin_options_as_WhileOptions();
+
+ if (schema_params != nullptr) {
+ params->cond_subgraph_index = schema_params->cond_subgraph_index();
+ params->body_subgraph_index = schema_params->body_subgraph_index();
+ } else {
+ // TODO(b/157480169): We should either return kTfLiteError or fill in some
+ // reasonable defaults in the params struct. We are not doing so until we
+ // better undertand the ramifications of changing the legacy behavior.
+ }
+
+ *builtin_data = params.release();
+ return kTfLiteOk;
+}
+
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
diff --git a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h
index 8cf889d8..cd6637bc 100644
--- a/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -98,6 +98,15 @@ TfLiteStatus ParseBatchToSpaceNd(const Operator* op,
BuiltinDataAllocator* allocator,
void** builtin_data);
+TfLiteStatus ParseBroadcastArgs(const Operator* op,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator,
+ void** builtin_data);
+
+TfLiteStatus ParseBroadcastTo(const Operator* op, ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator,
+ void** builtin_data);
+
TfLiteStatus ParseCallOnce(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
@@ -232,6 +241,9 @@ TfLiteStatus ParseLogSoftmax(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
+TfLiteStatus ParseLSTM(const Operator* op, ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
+
TfLiteStatus ParseMaximum(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
@@ -379,6 +391,9 @@ TfLiteStatus ParseVarHandle(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
+TfLiteStatus ParseWhile(const Operator* op, ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
+
TfLiteStatus ParseZerosLike(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h
index 0671ce73..ab0c8f96 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/portable_tensor_utils.h
@@ -60,9 +60,8 @@ void VectorBatchVectorAdd(const T* vector, int v_size, int n_batch,
// Cwise product of two vectors.
template
-inline void VectorVectorCwiseProduct(const T* __restrict__ vector1,
- const T* __restrict__ vector2, int v_size,
- T* __restrict__ result) {
+inline void VectorVectorCwiseProduct(const T* vector1, const T* vector2,
+ int v_size, T* result) {
for (int v = 0; v < v_size; v++) {
*result++ = *vector1++ * *vector2++;
}
@@ -117,6 +116,367 @@ void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch,
}
}
+// Checks if all entries of vector are zero for float.
+bool IsZeroVector(const float* vector, int v_size);
+
+// Checks if all entries of vector are zero for int8.
+bool IsZeroVector(const int8_t* vector, int v_size);
+
+// Quantizes a buffer of floating point values using a symmetric quantization
+// (i.e. linear quantization without an offset) to 8-bit signed integers.
+// It also outputs the range (min, max) of the floating point buffer, and the
+// scaling factor used to quantize the values.
+void SymmetricQuantizeFloats(const float* values, const int size,
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor);
+
+// Quantizes a buffer of floating point values using a symmetric quantization
+// (i.e. linear quantization without an offset) to 8-bit signed integers.
+// It uses the range (min, max) provided to the function to calculate the
+// appropriate scaling factor to quantize the values.
+void SymmetricQuantizeFloats(const float* values, const int size,
+ int8_t* quantized_values, float min_value,
+ float max_value, float* scaling_factor);
+
+void AsymmetricQuantizeFloats(const float* values, const int size,
+ int8_t* quantized_values, float* scaling_factor,
+ int32_t* offset);
+
+// Helper function to quantize floats.
+// float_data_ptr input float vectors
+// n_batch number of input vectors
+// n_data size of a single input vector
+// quantized_data_ptr (out) vector with quantized data
+// scaling_factors (out) scaling factors (one per vector)
+// zero_points (out) zero points (one per vector)
+// do_asymmetric controls if the quantization should be asymmetric.
+inline void BatchQuantizeFloats(const float* float_data_ptr, int n_batch,
+ int n_data, int8_t* quantized_data_ptr,
+ float* scaling_factors, int32_t* zero_points,
+ bool do_asymmetric) {
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_data;
+ if (do_asymmetric) {
+ tensor_utils::AsymmetricQuantizeFloats(
+ float_data_ptr + offset, n_data, quantized_data_ptr + offset,
+ &scaling_factors[b], &zero_points[b]);
+ } else {
+ float unused_min, unused_max;
+ tensor_utils::SymmetricQuantizeFloats(
+ float_data_ptr + offset, n_data, quantized_data_ptr + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ }
+}
+
+// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
+// dimension composed by input vectors independent from each other). The result
+// of the multiplication is accumulated to the passed result buffer.
+// More specifically, for a matrix M of shape [n, i] and a batched-vector
+// of shape [i, batch] it will first compute the product of shape [n, batch].
+// This product will be accumulated to the result buffer.
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result);
+
+// Same as the function above, but the matrix is a sparse tensor with block
+// pattern 1x4.
+// This function assumes that m_cols is a multiple of the block size (4 in this
+// case) so that there's no incomplete block.
+void SparseMatrixBatchVectorMultiplyAccumulate1x4(
+ const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+ const int32_t* __restrict__ indices, int m_rows, int m_cols,
+ const float* __restrict__ vector, int n_batch, float* __restrict__ result);
+
+// Same as the function above, but the matrix is stored in block compressed
+// sparse row format with block pattern 1x16 which consists of two arrays:
+// 1. A matrix array stores non-zero blocks of the matrix in row major.
+// 2. A ledger array stores nrows groups, one group per row. Each group starts
+// with an integer representing the number of non-zero blocks for the
+// corresponding row and follows with column indexes of the first element
+// of each non-zero block.
+// This function assumes that
+// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
+// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
+void SparseMatrixBatchVectorMultiplyAccumulate(
+ const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
+ int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
+ float* __restrict__ result);
+
+// Same as the function above, but for values quantized using symmetric
+// quantization (e.g. by calling SymmetricQuantizeFloats).
+// The passed scaling factors is a buffer of the quantization scaling factors
+// that will be used to dequentize the products into the final result buffer.
+// These scaling factors are the multiplication of the matrix scaling factor
+// by the vector's scaling factor, one per batch (i.e. this allows quantizing
+// each batch in the batch-vector matrix independently).
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors,
+ const float* __restrict__ scaling_factors, int n_batch,
+ float* __restrict__ result);
+
+// Same as the function above except that vector values
+// are quantized with asymmetric quantization per-batch and the matrix
+// is quantized per row.
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors,
+ const float* __restrict__ scaling_factors, int n_batch,
+ float* __restrict__ result, const float* __restrict__ per_channel_scale,
+ const int32_t* __restrict__ input_offset);
+
+// Same as the function above, but the matrix is a sparse tensor with block
+// pattern 1x16.
+// This function assumes that m_cols is a multiple of the block size (16 in this
+// case) so that there's no incomplete block. Also, it assumes all offsets of
+// input, output and filter are zero.
+void SparseMatrixBatchVectorMultiplyAccumulate1x16(
+ const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
+ const int32_t* __restrict__ indices, int m_rows, int m_cols,
+ const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
+ int n_batch, const int32_t input_offset, const int32_t output_multiplier,
+ const int32_t output_shift, const int32_t output_offset,
+ const int32_t output_activation_min, const int32_t output_activation_max,
+ int8_t* __restrict__ result);
+
+// Same as the function above, but the matrix is stored in block compressed
+// sparse row format with block pattern 1x16 which consists of two arrays:
+// 1. A matrix array stores non-zero blocks of the matrix in row major.
+// 2. A ledger array stores nrows groups, one group per row. Each group starts
+// with an integer representing the number of non-zero blocks for the
+// corresponding row followed by column index of the first element of
+// each non-zero block.
+// This function assumes that
+// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
+// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
+void SparseMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
+ const int m_rows, const int m_cols, const int8_t* __restrict__ vectors,
+ const float* __restrict__ scaling_factors, int n_batch,
+ float* __restrict__ result);
+
+// Same as the above 8, 8, 8 integer matmul except for the presence of zero
+// point and non-accumulative.
+// TODO(b/148688698): remove this function by folding zero point calculation in
+// prepare() function.
+void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
+ const int8_t* input_to_gate_weights,
+ int32_t input_to_gate_effective_scale_a,
+ int32_t input_to_gate_effective_scale_b,
+ int32_t n_batch, int32_t n_input, int32_t n_cell,
+ int8_t* gate_output, int8_t gate_output_zp);
+
+// Same as above but has 16 bit and 8 bit input and 8 bit output.
+// Used in projection when hidden is 16bit.
+void MatrixBatchVectorMultiply(const int16_t* hidden,
+ const int8_t* hidden_to_output_weights,
+ int32_t proj_effective_scale_a,
+ int32_t proj_effective_scale_b,
+ const int32_t* gate_bias, int32_t n_batch,
+ int32_t n_hidden, int32_t n_output,
+ int32_t output_zp, int8_t* proj_output);
+
+// Apply Layer Normalization (https://arxiv.org/abs/1607.06450) to a Quantized
+// vector.
+// Parameters:
+// - input: batch vector of size n_batch * n_input; 16 bit.
+// - layer_norm_weights: the quantized layer normalization weights.
+// - bias: the bias for the layer normalization.
+// - layer_norm_scale_a: multiplier for scale factor.
+// - layer_norm_scale_b: shift for scale factor.
+// - variance_limit: the guard to make sure the inverse does not overflow.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output
+void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
+ const int32_t* bias, int32_t layer_norm_scale_a,
+ int32_t layer_norm_scale_b, int32_t variance_limit,
+ int n_batch, int n_input, int16_t* output);
+
+// Same as above but the internal calculation is done in float.
+void ApplyLayerNormFloat(const int16_t* input,
+ const int16_t* layer_norm_weights,
+ int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
+ const int32_t* bias, int n_batch, int n_input,
+ int16_t* output);
+
+// Apply Sigmoid to a quantized vector.
+// Parameters:
+// - input: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output
+// The input is in Q3.12 format and the output is in Q0.15 format.
+void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
+ int16_t* output);
+
+// Same as above but the internal calcualtion is float.
+void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
+ int16_t* output);
+
+// Apply Tanh to a quantized vector.
+// Parameters:
+// - integer_bits: the integer bits of the input.
+// Currently supports 0, 1, 2, 3, 4, 5, 6.
+// - input: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 16 bit output
+// The input is in Qm.15-m format and the output is in Q0.15 format.
+void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
+ int32_t n_input, int16_t* output);
+
+// Apply Tanh to a quantized vector. Tbe internal calculation is in float.
+// - Input has 2^(integer_bits) as scale.
+// - Output has Q0.15 as scale.
+void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
+ int32_t integer_bits, int16_t* output);
+
+// Element-wise multiplication of two quantized vectors.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - shift: the shift needed to produce the output.
+// - output: the 16 bit output of size n_batch * n_input.
+// Output does not need to be initialized.
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int shift, int16_t* output);
+
+// Element-wise multiplication of two quantized vectors.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - shift: the shift needed to produce the output.
+// - output: the 8 bit output of size n_batch * n_input.
+// Output does not need to be initialized.
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int shift, int8_t* output);
+
+// Element-wise multiplication of two quantized vectors with rescaling.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - multiplier: the multiplier part of scale.
+// - shift: the shift part of scale.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 8 bit output of size n_batch * n_input.
+// - output_zp: the zero point of output.
+// Output does not need to be initialized.
+// Multiplier ("m") and shift ("s") are connected to scale ("s") with s = m *
+// 2^(s - 31).
+void CwiseMul(const int16_t* input_1, const int16_t* input_2,
+ int32_t multiplier, int32_t shift, int32_t n_batch,
+ int32_t n_input, int32_t output_zp, int8_t* output);
+
+// Element-wise saturating addition of two quantized vectors without rescaling.
+// Parameters:
+// - input_1: batch vector of size n_batch * n_input; 16 bit.
+// - input_2: batch vector of size n_batch * n_input; 16 bit.
+// - n_batch: the number of batches.
+// - n_input: the size for input and output.
+// - output: the 8 bit output of size n_batch * n_input.
+// Output does not need to be initialized.
+void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int16_t* output);
+
+// Element-wise in-place clipping of a vector. Overloaded for float, int16_t,
+// int8_t. Parameters:
+// - vector: vector of size v_size.
+// - v_size: the size of the vector.
+// - clipping_value: the value used for clipping.
+void CwiseClipping(float* vector, const int v_size, const float clipping_value);
+void CwiseClipping(int16_t* vector, const int v_size,
+ const int16_t clipping_value);
+void CwiseClipping(int8_t* vector, const int v_size,
+ const int8_t clipping_value);
+
+// Dot product of two vectors.
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors of size n_batch * v_size:
+// vector1 = [x_1_1, x_1_2, ..., x_1_vsize,
+// x_2_1, x_2_2, ..., x_2_vsize,
+// ...
+// x_nbatch_1,..., x_nbatch_vsize]
+// vector2 = [y_1_1, y_1_2, ..., y_1_vsize,
+// y_2_1, y_2_2, ..., y_2_vsize,
+// ...
+// y_nbatch_1,..., y_nbatch_vsize]
+// Then result will be a vector of n_batch size starting from 'result':
+// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize,
+// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
+// ...
+// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
+template
+inline void BatchVectorBatchVectorDotProduct(const T* vector1, const T* vector2,
+ int v_size, int n_batch,
+ T* result) {
+ for (int b = 0; b < n_batch; b++) {
+ result[b] = VectorVectorDotProduct(vector1, vector2, v_size);
+ vector1 += v_size;
+ vector2 += v_size;
+ }
+}
+
+// Same as above but input is 16bit and output is 32bit.
+void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
+ const int16_t* vector2, int v_size,
+ int n_batch, int32_t* result);
+
+// Same as above, but inputs are 16bit integer and output is 16bit integer.
+void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
+ const int16_t* batch_vector,
+ int n_batch, int32_t multiplier,
+ int shift, int16_t* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void Sub1Vector(const float* vector, int v_size, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG) for int16 input.
+// "vector" has range [0, 32767] because it is the output of sigmoid function.
+void Sub1Vector(const int16_t* vector, int v_size, int16_t* result);
+
+// Multiply all elements of vector with a scalar.
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+// output_vector: float pointer to vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+
+// Same as above but input/output is 32 bit integer.
+void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
+ int output_size, int reduction_size);
+
+// Same as above but input is 8 bit integer.
+void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
+ int output_size, int reduction_size);
+
+// Layer norm for each batch.
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch);
+
+// Saturate Add with rescale on both inputs.
+void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
+ const int8_t* recurrent, int8_t recurrent_zp,
+ int32_t input_effective_scale_a,
+ int32_t input_effective_scale_b,
+ int32_t recurrent_effective_scale_a,
+ int32_t recurrent_effective_scale_b, int32_t n_batch,
+ int32_t n_cell, int16_t* output);
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/batch_matmul.h
index 5fe01da2..767ad6ab 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/batch_matmul.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/batch_matmul.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
-#include "tensorflow/lite/kernels/internal/tensor_utils_common.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/broadcast_args.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/broadcast_args.h
new file mode 100644
index 00000000..d93c316d
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/broadcast_args.h
@@ -0,0 +1,56 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_ARGS_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_ARGS_H_
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+template
+void BroadcastArgs(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Gets data at the backward index i of the shape tensor. Returns 1 if the
+ // index is out of range.
+ auto get_shape_data = [](const RuntimeShape& shape, const T* data,
+ int backward_idx) -> T {
+ int forward_idx = shape.FlatSize() - 1 - backward_idx;
+ if (forward_idx < 0) return 1;
+ return data[forward_idx];
+ };
+
+ int output_num_elements = output_shape.FlatSize();
+ for (int i = 0; i < output_num_elements; ++i) {
+ int backward_i = output_num_elements - 1 - i;
+ int shape1_i = get_shape_data(input1_shape, input1_data, i);
+ int shape2_i = get_shape_data(input2_shape, input2_data, i);
+ if (shape1_i == 1) {
+ output_data[backward_i] = shape2_i;
+ } else if (shape2_i == 1) {
+ output_data[backward_i] = shape1_i;
+ } else {
+ TFLITE_CHECK_EQ(shape1_i, shape2_i);
+ output_data[backward_i] = shape1_i;
+ }
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_ARGS_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/broadcast_to.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/broadcast_to.h
new file mode 100644
index 00000000..f106b2b5
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/broadcast_to.h
@@ -0,0 +1,97 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace reference_ops {
+template
+void BroadcastImpl(const NdArrayDesc& input_desc, const char* input_data,
+ const NdArrayDesc& output_desc, char* output_data,
+ int indexes[N], int dim, const int last_broadcasting_dim,
+ const int type_size) {
+ // Copy data from input to output.
+ if (dim == last_broadcasting_dim) {
+ int copy_size = output_desc.strides[dim] * type_size;
+ const char* data_src =
+ input_data + SubscriptToIndex(input_desc, indexes) * type_size;
+ char* data_dst =
+ output_data + SubscriptToIndex(output_desc, indexes) * type_size;
+ for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
+ memcpy(data_dst, data_src, copy_size);
+ }
+ return;
+ }
+
+ // Recursive call to find the next broadcasting.
+ for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim];
+ ++indexes[dim]) {
+ BroadcastImpl(input_desc, input_data, output_desc, output_data, indexes,
+ dim + 1, last_broadcasting_dim, type_size);
+ }
+
+ // Duplicate data in output tensor.
+ indexes[dim] = 0;
+ if (input_desc.extents[dim] != output_desc.extents[dim]) {
+ int copy_size = output_desc.strides[dim] * type_size;
+ char* data_src =
+ output_data + SubscriptToIndex(output_desc, indexes) * type_size;
+ char* data_dst = data_src + copy_size;
+ for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
+ memcpy(data_dst, data_src, copy_size);
+ }
+ }
+}
+
+template
+inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
+ const char* input_data,
+ const RuntimeShape& unextended_output_shape,
+ char* output_data, TfLiteType data_type) {
+ NdArrayDesc input_desc;
+ NdArrayDesc output_desc;
+ CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
+ &input_desc);
+ CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+ &output_desc);
+
+ // Get the last dimension has broadcasting. At this dimension, the data is
+ // copied from input tensor to output tensor.
+ int last_broadcast_dim = -1;
+ for (int i = N - 1; i >= 0; --i) {
+ if (input_desc.extents[i] != output_desc.extents[i]) {
+ last_broadcast_dim = i;
+ break;
+ }
+ }
+
+ // If non-broadcasting, just copy data from input to output tensor.
+ if (last_broadcast_dim == -1) {
+ memcpy(output_data, input_data,
+ unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
+ return;
+ }
+
+ // Broadcasting using memcpy.
+ int indexes[N] = {0};
+ BroadcastImpl(input_desc, input_data, output_desc, output_data, indexes, 0,
+ last_broadcast_dim, TfLiteTypeGetSize(data_type));
+}
+} // namespace reference_ops
+} // namespace tflite
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h
index 5a6369d8..ac5f04f6 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/conv.h
@@ -43,7 +43,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
(void)im2col_data; // only used in optimized code.
(void)im2col_shape; // only used in optimized code.
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int input_depth = input_shape.Dims(3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -52,14 +52,20 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
for (int out_x = 0; out_x < output_width; ++out_x) {
const int in_x_origin = (out_x * stride_width) - pad_width;
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ auto group = out_channel / filters_per_group;
float total = 0.f;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y = in_y_origin + dilation_height_factor * filter_y;
@@ -74,10 +80,11 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
if (!is_point_inside_image) {
continue;
}
-
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- float input_value = input_data[Offset(input_shape, batch, in_y,
- in_x, in_channel)];
+ for (int in_channel = 0; in_channel < filter_input_depth;
+ ++in_channel) {
+ float input_value =
+ input_data[Offset(input_shape, batch, in_y, in_x,
+ in_channel + group * filter_input_depth)];
float filter_value = filter_data[Offset(
filter_shape, out_channel, filter_y, filter_x, in_channel)];
total += (input_value * filter_value);
@@ -126,7 +133,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int input_depth = input_shape.Dims(3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -135,6 +142,10 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
@@ -143,6 +154,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
for (int out_x = 0; out_x < output_width; ++out_x) {
const int in_x_origin = (out_x * stride_width) - pad_width;
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ auto group = out_channel / filters_per_group;
int32_t acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y = in_y_origin + dilation_height_factor * filter_y;
@@ -158,9 +170,11 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
continue;
}
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- int32_t input_val = input_data[Offset(input_shape, batch, in_y,
- in_x, in_channel)];
+ for (int in_channel = 0; in_channel < filter_input_depth;
+ ++in_channel) {
+ int32_t input_val =
+ input_data[Offset(input_shape, batch, in_y, in_x,
+ in_channel + group * filter_input_depth)];
int32_t filter_val = filter_data[Offset(
filter_shape, out_channel, filter_y, filter_x, in_channel)];
acc +=
@@ -206,7 +220,7 @@ inline void HybridConvPerChannel(
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int input_depth = input_shape.Dims(3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -215,18 +229,24 @@ inline void HybridConvPerChannel(
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ auto group = out_channel / filters_per_group;
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
int32_t acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ for (int in_channel = 0; in_channel < filter_input_depth;
+ ++in_channel) {
const int in_x = in_x_origin + dilation_width_factor * filter_x;
const int in_y =
in_y_origin + dilation_height_factor * filter_y;
@@ -235,7 +255,8 @@ inline void HybridConvPerChannel(
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
int32_t input_val = input_data[Offset(
- input_shape, batch, in_y, in_x, in_channel)];
+ input_shape, batch, in_y, in_x,
+ in_channel + group * filter_input_depth)];
int32_t filter_val =
filter_data[Offset(filter_shape, out_channel, filter_y,
filter_x, in_channel)];
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
index 3a4164d3..3f869a3a 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
@@ -48,7 +48,7 @@ inline void ConvPerChannel(
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int input_depth = input_shape.Dims(3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -59,6 +59,10 @@ inline void ConvPerChannel(
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
@@ -67,6 +71,7 @@ inline void ConvPerChannel(
for (int out_x = 0; out_x < output_width; ++out_x) {
const int in_x_origin = (out_x * stride_width) - pad_width;
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ auto group = out_channel / filters_per_group;
int32_t acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y = in_y_origin + dilation_height_factor * filter_y;
@@ -82,9 +87,11 @@ inline void ConvPerChannel(
continue;
}
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- int32_t input_val = input_data[Offset(input_shape, batch, in_y,
- in_x, in_channel)];
+ for (int in_channel = 0; in_channel < filter_input_depth;
+ ++in_channel) {
+ int32_t input_val =
+ input_data[Offset(input_shape, batch, in_y, in_x,
+ in_channel + group * filter_input_depth)];
int32_t filter_val = filter_data[Offset(
filter_shape, out_channel, filter_y, filter_x, in_channel)];
// Accumulate with 32 bits accumulator.
@@ -126,12 +133,13 @@ inline void ConvPerChannel(
// Fixed-point per-channel-quantization convolution reference kernel.
// 16-bit data and 8-bit filter
+template
inline void ConvPerChannel(
const ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift, const RuntimeShape& input_shape,
const int16_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
- const std::int64_t* bias_data, const RuntimeShape& output_shape,
+ const AccumScalar* bias_data, const RuntimeShape& output_shape,
int16_t* output_data) {
// Get parameters.
const int stride_width = params.stride_width;
@@ -151,7 +159,7 @@ inline void ConvPerChannel(
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int input_depth = input_shape.Dims(3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -162,6 +170,10 @@ inline void ConvPerChannel(
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
@@ -170,7 +182,8 @@ inline void ConvPerChannel(
for (int out_x = 0; out_x < output_width; ++out_x) {
const int in_x_origin = (out_x * stride_width) - pad_width;
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- std::int64_t acc = 0;
+ auto group = out_channel / filters_per_group;
+ AccumScalar acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y = in_y_origin + dilation_height_factor * filter_y;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
@@ -185,9 +198,11 @@ inline void ConvPerChannel(
continue;
}
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- int32_t input_val = input_data[Offset(input_shape, batch, in_y,
- in_x, in_channel)];
+ for (int in_channel = 0; in_channel < filter_input_depth;
+ ++in_channel) {
+ int32_t input_val =
+ input_data[Offset(input_shape, batch, in_y, in_x,
+ in_channel + group * filter_input_depth)];
int32_t filter_val = filter_data[Offset(
filter_shape, out_channel, filter_y, filter_x, in_channel)];
// Accumulate with 64 bits accumulator.
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h
index 1a469fa9..42920d16 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h
@@ -34,12 +34,13 @@ inline void FullyConnected(
const int32_t output_activation_min = params.quantized_activation_min;
const int32_t output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int filter_dim_count = filter_shape.DimensionsCount();
- const int batches = output_shape.Dims(0);
- const int output_depth = output_shape.Dims(1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = output_shape.Dims(output_dim_count - 1);
TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
for (int b = 0; b < batches; ++b) {
@@ -62,11 +63,12 @@ inline void FullyConnected(
}
}
+template
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const int16_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
- const int64_t* bias_data, const RuntimeShape& output_shape,
+ const AccumScalar* bias_data, const RuntimeShape& output_shape,
int16_t* output_data) {
const int32_t filter_offset = params.weights_offset;
const int32_t output_multiplier = params.output_multiplier;
@@ -85,7 +87,7 @@ inline void FullyConnected(
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
for (int b = 0; b < batches; ++b) {
for (int out_c = 0; out_c < output_depth; ++out_c) {
- int64_t acc = 0;
+ AccumScalar acc = 0;
for (int d = 0; d < accum_depth; ++d) {
int32_t input_val = input_data[b * accum_depth + d];
int32_t filter_val = filter_data[out_c * accum_depth + d];
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
index 284c0f21..3397f869 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
@@ -119,15 +119,16 @@ inline void TransposeConv(
}
}
-// int16_t input (zero_point=0), int8_t filter, int64 accumulator
+// int16_t input (zero_point=0), int8_t filter, int32 or int64 accumulator
+template
inline void TransposeConv(
const ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift, const RuntimeShape& input_shape,
const int16_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
- const std::int64_t* bias_data, const RuntimeShape& output_shape,
+ const Scalar* bias_data, const RuntimeShape& output_shape,
int16_t* output_data, const RuntimeShape& im2col_shape, int8_t* im2col_data,
- std::int64_t* scratch_buffer) {
+ Scalar* scratch_buffer) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const int pad_width = params.padding_values.width;
@@ -157,7 +158,7 @@ inline void TransposeConv(
const int num_elements = output_shape.FlatSize();
// We need to initialize scratch_buffer to all 0s, as we apply the same
// 'scatter' based trick as in float version.
- memset(scratch_buffer, 0, num_elements * sizeof(std::int64_t));
+ memset(scratch_buffer, 0, num_elements * sizeof(Scalar));
// Loop through input elements one at a time.
for (int batch = 0; batch < batches; ++batch) {
@@ -198,8 +199,8 @@ inline void TransposeConv(
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- std::int64_t acc = scratch_buffer[Offset(output_shape, batch, out_y,
- out_x, out_channel)];
+ Scalar acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+ out_channel)];
if (bias_data) {
acc += bias_data[out_channel];
}
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/lstm_cell.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/lstm_cell.h
new file mode 100644
index 00000000..17b113eb
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/lstm_cell.h
@@ -0,0 +1,422 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
+
+#include
+#include
+#include
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches =
+ MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
+ const int height =
+ MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
+ const int width =
+ MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ const int intern_activ_depth =
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+
+ // Concatenate prev_activ and input data together
+ float const* concat_input_arrays_data[2] = {input_data, prev_activ_data};
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape, concat_temp_data);
+
+ // Fully connected
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits::lowest();
+ fc_params.float_activation_max = std::numeric_limits::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
+
+ // Memory state update (the LSTM "guts")
+ for (int b = 0; b < batches; ++b) {
+ for (int w = 0; w < width; ++w) {
+ for (int h = 0; h < height; ++h) {
+ for (int c = 0; c < output_depth; ++c) {
+ const float input_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 0 * output_depth + c)]));
+ const float new_input = std::tanh(activ_temp_data[Offset(
+ activ_temp_shape, b, h, w, 1 * output_depth + c)]);
+ const float forget_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 2 * output_depth + c)]));
+ const float output_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 3 * output_depth + c)]));
+ const float new_state =
+ input_gate * new_input +
+ forget_gate *
+ prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+ output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+ output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
+ output_gate * std::tanh(new_state);
+ }
+ }
+ }
+ }
+}
+
+// Quantized LSTM cell implementation.
+// The quantization of the input, output arrays is as follows:
+// - The input activations are quantized as uint8 on the interval
+// [-1, 127/128].
+// The rationale for that is that is the natural interval for output
+// activations (see next point) and these need to be concatenated together.
+// We could accommodate different ranges by re-scaling, but we empirically
+// found that setting the input activations range to be [-1, 127/128] in the
+// first place, removing the need for re-scaling, greatly improves accuracy.
+// - The output activations are quantized as uint8 on the interval
+// [-1, 127/128].
+// The rationale for that is that the definition of a LSTM cell makes them
+// intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
+// makes for simpler, more accurate fixed-point arithmetic.
+// - The output-at-previous-timestep state array is obviously quantized as
+// the output activations.
+// - The internal LSTM memory (not the output-at-previous-timestep, the other
+// internal state array) is int16-quantized and may use any power-of-two,
+// symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
+// StateIntegerBits below, see the below discussion of that template
+// parameter ("The StateIntegerBits template parameter").
+// - The output of the internal fully-connected node is int16-quantized
+// on the interval [-8, 8 * 32767/32768], the rationale for which is
+// explained just below ("Why [-8, 8] for fully-connected output?").
+//
+//
+// === The StateIntegerBits template parameter ===
+//
+// The StateIntegerBits template parameter controls the fixed-point format used
+// to represent the internal memory of the LSTM cell (not the
+// output-at-previous-timestep, the other internal state array). It's currently
+// a template parameter so that the model can control that. The most typical
+// value for StateIntegerBits is 4. Other plausible values are anywhere between
+// 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
+// and drop that template parameter. The reason why it can't be a runtime
+// parameter is that this controls the fixed-point format used, i.e. we need to
+// generate actually different code based on it. In particular, we generate code
+// for a fixed-point tanh() implementation for that format, which internally
+// uses a fixed-point exp() implementation, which internally uses a
+// barrel-shifter with a number of steps that depends on StateIntegerBits.
+// Another consequence of that is that a higher value of StateIntegerBits
+// results in a more expensive implementation (more barrel shifter steps
+// needed).
+//
+//
+// === Why [-8, 8] for fully-connected output? ===
+//
+// This array is only fed to Logistic and Tanh functions, for which
+// the quantized implementation will want to use fixed-point arithmetic,
+// requiring a power-of-two representation interval. Thus, we should right
+// away quantize this array to a power-of-two interval; otherwise,
+// implementation will need to rescale that, losing any benefit that a tighter
+// representation interval might otherwise yield, while introducing some
+// numerical error and computational overhead.
+//
+// Now, Logistic and Tanh
+// are nearly constant (nearly equal to their horizontal asymptotes)
+// outside of a small bounded interval around 0:
+//
+// Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4
+// Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7
+// Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14
+//
+// From this, we see that clamping to [-4, 4] would be too inaccurate
+// (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
+// while clamping to [-16, 16] would make no difference even in float32.
+// However, for a fixed-point implementation in 16-bit integers, using 5
+// integer bits to represent the [-16, 16] range would leave only 11
+// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
+// representable values. Notice that is higher than the
+// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
+// Using [-8, 8] thus seems like the better compromise overall, enjoying
+// an increment of 2.4e-4 between representable values and a worst-case
+// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
+// [-16, 16].
+//
+// Moreover, all other things being equal, it is nice to choose the narrower
+// representation range, as that makes the implementation of fixed-point
+// math functions a little cheaper (each integer bit requires an additional
+// barrel-shifter atep in the implementation of exp(-x)). That is further
+// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
+// sense for 32-bit float or 32-bit fixed-point quantization, but we are
+// aiming for 16-bit fixed-point quantization of these internal nodes here.
+//
+template
+inline void LstmCell(const LstmCellParams& params,
+ const RuntimeShape& unextended_input_shape,
+ const uint8_t* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8_t* prev_activ_data_uint8,
+ const RuntimeShape& weights_shape,
+ const uint8_t* weights_data_uint8,
+ const RuntimeShape& unextended_bias_shape,
+ const int32_t* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16_t* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16_t* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8_t* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8_t* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16_t* activ_temp_data_int16, void* gemmlowp_context) {
+ (void)gemmlowp_context; // only used in optimized code.
+ int32_t weights_zero_point = params.weights_zero_point;
+ int32_t accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ // Gather dimensions information, and perform consistency checks.
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ const int intern_activ_depth =
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
+ const int fc_output_depth =
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
+
+ // Depth-concatenate prev_activ and input data together.
+ uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
+ prev_activ_data_uint8};
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
+
+ // Implementation of the fully connected node inside the LSTM cell.
+ // The operands are 8-bit integers, the accumulators are internally 32bit
+ // integers, and the output is 16-bit fixed-point with 3 integer bits so
+ // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
+ // is explained in the function comment above.
+ for (int b = 0; b < fc_batches; ++b) {
+ for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32_t accum = bias_data_int32[out_c];
+ // Accumulation loop.
+ for (int d = 0; d < fc_accum_depth; ++d) {
+ int16_t input_val =
+ concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
+ int16_t weights_val =
+ weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
+ accum += input_val * weights_val;
+ }
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, using 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ accum =
+ MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
+ // Saturate, cast to int16, and store to the temporary activations array.
+ accum = std::max(-32768, std::min(32767, accum));
+ activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
+ }
+ }
+
+ // Rest of the LSTM cell: tanh and logistic math functions, and some adds
+ // and muls, all done in 16-bit fixed-point.
+ for (int b = 0; b < outer_size; ++b) {
+ for (int c = 0; c < output_depth; ++c) {
+ // Define the fixed-point data types that we will use here. All use
+ // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
+ // They only differ by the number of integral vs. fractional bits,
+ // determining the range of values that they can represent.
+ //
+ // F0 uses 0 integer bits, range [-1, 1].
+ // This is the return type of math functions such as tanh, logistic,
+ // whose range is in [-1, 1].
+ using F0 = gemmlowp::FixedPoint;
+ // F3 uses 3 integer bits, range [-8, 8].
+ // This is the range of the previous fully-connected node's output,
+ // which is our input here.
+ using F3 = gemmlowp::FixedPoint;
+ // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
+ // 2^StateIntegerBits]. It's used to represent the internal state, whose
+ // number of integer bits is currently dictated by the model. See comment
+ // on the StateIntegerBits template parameter above.
+ using FS = gemmlowp::FixedPoint;
+ // Implementation of input gate, using fixed-point logistic function.
+ F3 input_gate_input = F3::FromRaw(
+ activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
+ F0 input_gate_output = gemmlowp::logistic(input_gate_input);
+ // Implementation of input modulation gate, using fixed-point tanh
+ // function.
+ F3 input_modulation_gate_input = F3::FromRaw(
+ activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
+ F0 input_modulation_gate_output =
+ gemmlowp::tanh(input_modulation_gate_input);
+ // Implementation of forget gate, using fixed-point logistic function.
+ F3 forget_gate_input = F3::FromRaw(
+ activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
+ F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
+ // Implementation of output gate, using fixed-point logistic function.
+ F3 output_gate_input = F3::FromRaw(
+ activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
+ F0 output_gate_output = gemmlowp::logistic(output_gate_input);
+ // Implementation of internal multiplication nodes, still in fixed-point.
+ F0 input_times_input_modulation =
+ input_gate_output * input_modulation_gate_output;
+ FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
+ FS prev_state_times_forget_state = forget_gate_output * prev_state;
+ // Implementation of internal addition node, saturating.
+ FS new_state = gemmlowp::SaturatingAdd(
+ gemmlowp::Rescale(input_times_input_modulation),
+ prev_state_times_forget_state);
+ // Implementation of last internal Tanh node, still in fixed-point.
+ // Since a Tanh fixed-point implementation is specialized for a given
+ // number or integer bits, and each specialization can have a substantial
+ // code size, and we already used above a Tanh on an input with 3 integer
+ // bits, and per the table in the above function comment there is no
+ // significant accuracy to be lost by clamping to [-8, +8] for a
+ // 3-integer-bits representation, let us just do that. This helps people
+ // porting this to targets where code footprint must be minimized.
+ F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
+ F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
+ // Store the new internal state back to memory, as 16-bit integers.
+ // Note: here we store the original value with StateIntegerBits, not
+ // the rescaled 3-integer-bits value fed to tanh.
+ output_state_data_int16[b * output_depth + c] = new_state.raw();
+ // Down-scale the output activations to 8-bit integers, saturating,
+ // and store back to memory.
+ int16_t rescaled_output_activ =
+ gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
+ int16_t clamped_output_activ = std::max(
+ -128, std::min(127, rescaled_output_activ));
+ output_activ_data_uint8[b * output_depth + c] =
+ 128 + clamped_output_activ;
+ }
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index 4cc51cb4..4684be64 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -227,6 +227,41 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
}
}
+void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
+ const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
+ const int32_t* __restrict__ indices, int m_rows, int m_cols,
+ const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
+ int n_batch, const int32_t input_offset, const int32_t output_multiplier,
+ const int32_t output_shift, const int32_t output_offset,
+ const int32_t output_activation_min, const int32_t output_activation_max,
+ int8_t* __restrict__ result) {
+ const int kBlockSize = 16;
+ TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
+ for (int batch = 0; batch < n_batch; ++batch) {
+ const int8_t* matrix_ptr = matrix;
+ for (int row = 0; row < m_rows; ++row) {
+ int32_t dot_prod = 0;
+ const int8_t* vector_in_batch = vector + batch * m_cols;
+ for (int i = segments[row]; i < segments[row + 1]; ++i) {
+ const int block_start_index = indices[i] * kBlockSize;
+ const int8_t* vector_block_in_batch_ptr =
+ vector_in_batch + block_start_index;
+ for (int c = 0; c < kBlockSize; c++) {
+ dot_prod += *matrix_ptr * *vector_block_in_batch_ptr++;
+ dot_prod += *matrix_ptr++ * input_offset;
+ }
+ }
+ const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0;
+ dot_prod = MultiplyByQuantizedMultiplier(dot_prod + bias_value,
+ output_multiplier, output_shift);
+ dot_prod += output_offset;
+ result[batch * m_rows + row] =
+ static_cast(ActivationFunctionWithMinMax(
+ dot_prod, output_activation_min, output_activation_max));
+ }
+ }
+}
+
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
new file mode 100644
index 00000000..0416db09
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -0,0 +1,333 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+
+#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"
+
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
+namespace tflite {
+namespace tensor_utils {
+
+// Check if all entries of a vector are zero for float.
+bool IsZeroVector(const float* vector, int v_size) {
+ return PortableIsZeroVector(vector, v_size);
+}
+
+// Check if all entries of a vector are zero for int8_t.
+bool IsZeroVector(const int8_t* vector, int v_size) {
+ return PortableIsZeroVector(vector, v_size);
+}
+
+void SymmetricQuantizeFloats(const float* values, const int size,
+ int8_t* quantized_values, float* min, float* max,
+ float* scaling_factor) {
+ PortableSymmetricQuantizeFloats(values, size, quantized_values, min, max,
+ scaling_factor);
+}
+
+void SymmetricQuantizeFloats(const float* values, const int size,
+ int8_t* quantized_values, float min_value,
+ float max_value, float* scaling_factor) {
+ PortableSymmetricQuantizeFloats(values, size, quantized_values, min_value,
+ max_value, scaling_factor);
+}
+
+void AsymmetricQuantizeFloats(const float* values, const int size,
+ int8_t* quantized_values, float* scaling_factor,
+ int32_t* offset) {
+ PortableAsymmetricQuantizeFloats(values, size, quantized_values,
+ scaling_factor, offset);
+}
+
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result) {
+ PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+ n_batch, result);
+}
+
+void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
+ const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vector,
+ const float* scaling_factors,
+ int n_batch,
+ float* __restrict__ result) {
+ PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+ scaling_factors, n_batch, result);
+}
+
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, const float* per_channel_scale,
+ const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+ bool* compute_row_sums, CpuBackendContext* context) {
+ PortableMatrixBatchVectorMultiplyAccumulate(
+ matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
+ per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
+ context);
+}
+
+void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
+ const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vector,
+ const float* scaling_factors,
+ int n_batch, int32_t* scratch,
+ float* __restrict__ result,
+ CpuBackendContext* context) {
+ PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+ scaling_factors, n_batch, result);
+}
+
+void SparseMatrixBatchVectorMultiplyAccumulate1x4(
+ const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+ const int32_t* __restrict__ indices, int m_rows, int m_cols,
+ const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
+ PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
+ matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
+}
+
+void SparseMatrixBatchVectorMultiplyAccumulate(
+ const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
+ int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
+ float* __restrict__ result) {
+ PortableSparseMatrixBatchVectorMultiplyAccumulate(
+ matrix, ledger, m_rows, m_cols, vector, n_batch, result);
+}
+
+void SparseMatrixBatchVectorMultiplyAccumulate1x16(
+ const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
+ const int32_t* __restrict__ indices, int m_rows, int m_cols,
+ const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
+ int n_batch, const int32_t input_offset, const int32_t output_multiplier,
+ const int32_t output_shift, const int32_t output_offset,
+ const int32_t output_activation_min, const int32_t output_activation_max,
+
+ int8_t* __restrict__ result) {
+ PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
+ matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch,
+ input_offset, output_multiplier, output_shift, output_offset,
+ output_activation_min, output_activation_max, result);
+}
+
+void SparseMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
+ const int m_cols, const int8_t* __restrict__ vectors,
+ const float* scaling_factors, int n_batch, float* __restrict__ result) {
+ PortableSparseMatrixBatchVectorMultiplyAccumulate(
+ matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch,
+ result);
+}
+
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* input, const int32_t* bias,
+ const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+ int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+ int32_t* scratch, int16_t* output, CpuBackendContext* context) {
+ PortableMatrixBatchVectorMultiplyAccumulate(
+ input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
+ n_output, output_zp, scratch, output, context);
+}
+
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* input, const int32_t* bias,
+ const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
+ int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
+ int32_t* scratch, int8_t* output, CpuBackendContext* context) {
+ PortableMatrixBatchVectorMultiplyAccumulate(
+ input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
+ n_output, output_zp, scratch, output, context);
+}
+
+void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
+ int32_t n_row, int32_t n_col,
+ int32_t* output) {
+ PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output);
+}
+
+void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
+ const int8_t* input_to_gate_weights,
+ int32_t input_to_gate_effective_scale_a,
+ int32_t input_to_gate_effective_scale_b,
+ int32_t n_batch, int32_t n_input, int32_t n_cell,
+ int8_t* gate_output, int8_t gate_output_zp) {
+ PortableMatrixBatchVectorMultiply(
+ input, input_zeropoint, input_to_gate_weights,
+ input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
+ n_input, n_cell, gate_output, gate_output_zp);
+}
+
+void MatrixBatchVectorMultiply(const int16_t* hidden,
+ const int8_t* hidden_to_output_weights,
+ int32_t proj_effective_scale_a,
+ int32_t proj_effective_scale_b,
+ const int32_t* gate_bias, int32_t n_batch,
+ int32_t n_hidden, int32_t n_output,
+ int32_t output_zp, int8_t* proj_output) {
+ PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
+ proj_effective_scale_a,
+ proj_effective_scale_b, gate_bias, n_batch,
+ n_hidden, n_output, output_zp, proj_output);
+}
+
+void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
+ const int32_t* bias, int32_t layer_norm_scale_a,
+ int32_t layer_norm_scale_b, int32_t variance_limit,
+ int n_batch, int n_input, int16_t* output) {
+ PortableApplyLayerNorm(input, layer_norm_weights, bias, layer_norm_scale_a,
+ layer_norm_scale_b, variance_limit, n_batch, n_input,
+ output);
+}
+
+void ApplyLayerNormFloat(const int16_t* input,
+ const int16_t* layer_norm_weights,
+ int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
+ const int32_t* bias, int n_batch, int n_input,
+ int16_t* output) {
+ PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
+ layer_norm_scale_b, bias, n_batch, n_input,
+ output);
+}
+
+void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
+ int16_t* output) {
+ PortableApplySigmoid(input, n_batch, n_input, output);
+}
+
+void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
+ int16_t* output) {
+ PortableApplySigmoidFloat(input, n_batch, n_input, output);
+}
+
+void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
+ int32_t n_input, int16_t* output) {
+ PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
+}
+
+void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
+ int32_t integer_bits, int16_t* output) {
+ PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
+}
+
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int shift, int16_t* output) {
+ PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
+}
+
+void CwiseMul(const int16_t* input_1, const int16_t* input_2,
+ int32_t multiplier, int32_t shift, int32_t n_batch,
+ int32_t n_input, int32_t output_zp, int8_t* output) {
+ PortableCwiseMul(input_1, input_2, multiplier, shift, n_batch, n_input,
+ output_zp, output);
+}
+
+void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int16_t* output) {
+ PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
+}
+
+void CwiseClipping(float* vector, const int v_size,
+ const float clipping_value) {
+ PortableCwiseClipping(vector, v_size, clipping_value);
+}
+
+void CwiseClipping(int16_t* vector, const int v_size,
+ const int16_t clipping_value) {
+ PortableCwiseClipping(vector, v_size, clipping_value);
+}
+
+void CwiseClipping(int8_t* vector, const int v_size,
+ const int8_t clipping_value) {
+ PortableCwiseClipping(vector, v_size, clipping_value);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
+ const int16_t* batch_vector,
+ int n_batch, int32_t multiplier,
+ int shift, int16_t* result) {
+ PortableVectorBatchVectorCwiseProductAccumulate(
+ vector, v_size, batch_vector, n_batch, multiplier, shift, result);
+}
+
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ return PortableVectorVectorDotProduct(vector1, vector2, v_size);
+}
+
+void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
+ const int16_t* vector2, int v_size,
+ int n_batch, int32_t* result) {
+ PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
+ result);
+}
+
+void Sub1Vector(const float* vector, int v_size, float* result) {
+ PortableSub1Vector(vector, v_size, result);
+}
+
+void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
+ PortableSub1Vector(vector, v_size, result);
+}
+
+// Multiply all elements of vector with a scalar.
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result) {
+ PortableVectorScalarMultiply(vector, v_size, scale, result);
+}
+
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ PortableReductionSumVector(input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
+ int output_size, int reduction_size) {
+ PortableReductionSumVector(input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
+ int output_size, int reduction_size) {
+ PortableReductionSumVector(input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
+}
+
+void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
+ const int8_t* recurrent, int8_t recurrent_zp,
+ int32_t input_effective_scale_a,
+ int32_t input_effective_scale_b,
+ int32_t recurrent_effective_scale_a,
+ int32_t recurrent_effective_scale_b, int32_t n_batch,
+ int32_t n_cell, int16_t* output) {
+ PortableTwoGateSaturatingAdd(
+ input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
+ input_effective_scale_b, recurrent_effective_scale_a,
+ recurrent_effective_scale_b, n_batch, n_cell, output);
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
index 1e411e16..6c404d5e 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
@@ -87,6 +87,15 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate(
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
float* __restrict__ result);
+void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
+ const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
+ const int32_t* __restrict__ indices, int m_rows, int m_cols,
+ const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
+ int n_batch, const int32_t input_offset, const int32_t output_multiplier,
+ const int32_t output_shift, const int32_t output_offset,
+ const int32_t output_activation_min, const int32_t output_activation_max,
+ int8_t* __restrict__ result);
+
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/sub.h b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/sub.h
index 3fa43ce9..d0ebc95a 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/sub.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/sub.h
@@ -273,6 +273,9 @@ void BroadcastQuantSubSlow(const ArithmeticParams& params,
const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("BroadcastQuantSubSlow/T");
+ TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N);
+ TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N);
+ TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
NdArrayDesc desc1;
NdArrayDesc desc2;
NdArrayDesc output_desc;
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.cc b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.cc
index 75529296..10b37ed3 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.cc
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
@@ -466,10 +467,10 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
const int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
if (!(d1 == d2 || d1 == 1 || d2 == 1)) {
- context->ReportError(context,
- "Given shapes, %s and %s, are not broadcastable.",
- GetShapeDebugString(input1->dims).c_str(),
- GetShapeDebugString(input2->dims).c_str());
+ TF_LITE_KERNEL_LOG(context,
+ "Given shapes, %s and %s, are not broadcastable.",
+ GetShapeDebugString(input1->dims).c_str(),
+ GetShapeDebugString(input2->dims).c_str());
return kTfLiteError;
}
@@ -504,11 +505,11 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
if (min_value == 0) max_value = 0;
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
!(d3 == 1 || d3 == max_value)) {
- context->ReportError(
- context, "Given shapes, %s, %s and %s, are not broadcastable.",
- GetShapeDebugString(input1->dims).c_str(),
- GetShapeDebugString(input2->dims).c_str(),
- GetShapeDebugString(input3->dims).c_str());
+ TF_LITE_KERNEL_LOG(context,
+ "Given shapes, %s, %s and %s, are not broadcastable.",
+ GetShapeDebugString(input1->dims).c_str(),
+ GetShapeDebugString(input2->dims).c_str(),
+ GetShapeDebugString(input3->dims).c_str());
return kTfLiteError;
}
shape->data[out_dims - i - 1] = max_value;
@@ -529,6 +530,9 @@ int TfLiteTypeGetSize(TfLiteType type) {
return 1;
case kTfLiteBool:
return sizeof(bool);
+ case kTfLiteUInt16:
+ static_assert(sizeof(uint16_t) == 2, "");
+ return 2;
case kTfLiteInt16:
static_assert(sizeof(int16_t) == 2, "");
return 2;
@@ -575,4 +579,15 @@ bool IsMobilePlatform() {
return false;
}
+bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
+#ifndef TF_LITE_STATIC_MEMORY
+ if (tensor->dims_signature) {
+ for (int i : TfLiteIntArrayView(tensor->dims_signature)) {
+ if (i == -1) return true;
+ }
+ }
+#endif // TF_LITE_STATIC_MEMORY
+ return false;
+}
+
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
index d082e7b0..22689436 100644
--- a/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
+++ b/code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
@@ -314,6 +314,9 @@ int TfLiteTypeGetSize(TfLiteType type);
// Whether the current platform is mobile (Android or iOS).
bool IsMobilePlatform();
+// Returns whether there is unspecified dimension in the tensor's dim signature.
+bool HasUnspecifiedDimension(const TfLiteTensor* tensor);
+
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc b/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc
index 8777cd28..6fa1b31b 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/all_ops_resolver.cc
@@ -29,8 +29,12 @@ AllOpsResolver::AllOpsResolver() {
AddAssignVariable();
AddAveragePool2D();
AddBatchToSpaceNd();
+ AddBroadcastArgs();
+ AddBroadcastTo();
AddCallOnce();
+ AddCast();
AddCeil();
+ AddCircularBuffer();
AddConcatenation();
AddConv2D();
AddCos();
@@ -49,9 +53,12 @@ AllOpsResolver::AllOpsResolver() {
AddFloorDiv();
AddFloorMod();
AddFullyConnected();
+ AddGather();
+ AddGatherNd();
AddGreater();
AddGreaterEqual();
AddHardSwish();
+ AddIf();
AddL2Normalization();
AddL2Pool2D();
AddLeakyRelu();
@@ -66,6 +73,7 @@ AllOpsResolver::AllOpsResolver() {
AddMaximum();
AddMean();
AddMinimum();
+ AddMirrorPad();
AddMul();
AddNeg();
AddNotEqual();
@@ -85,6 +93,7 @@ AllOpsResolver::AllOpsResolver() {
AddRsqrt();
AddShape();
AddSin();
+ AddSlice();
AddSoftmax();
AddSpaceToBatchNd();
AddSpaceToDepth();
@@ -101,6 +110,8 @@ AllOpsResolver::AllOpsResolver() {
AddTransposeConv();
AddUnpack();
AddVarHandle();
+ AddWhile();
+ AddZerosLike();
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
new file mode 100644
index 00000000..5a5ba9ab
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
@@ -0,0 +1,107 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/fake_micro_context.h"
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/micro_allocator.h"
+#include "tensorflow/lite/micro/micro_arena_constants.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+#include "tensorflow/lite/micro/simple_memory_allocator.h"
+
+namespace tflite {
+namespace {
+// Dummy static variables to allow creation of dummy MicroAllocator.
+// All tests are guarateed to run serially.
+static constexpr int KDummyTensorArenaSize = 256;
+static uint8_t dummy_tensor_arena[KDummyTensorArenaSize];
+} // namespace
+
+FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
+ SimpleMemoryAllocator* allocator,
+ MicroGraph* micro_graph)
+ : MicroContext(
+ MicroAllocator::Create(dummy_tensor_arena, KDummyTensorArenaSize,
+ GetMicroErrorReporter()),
+ nullptr, micro_graph),
+ tensors_(tensors),
+ allocator_(allocator) {}
+
+TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
+ allocated_tensor_count_++;
+ return &tensors_[tensor_index];
+}
+
+void FakeMicroContext::DeallocateTempTfLiteTensor(TfLiteTensor* tensor) {
+ allocated_tensor_count_--;
+}
+
+bool FakeMicroContext::IsAllTempTfLiteTensorDeallocated() {
+ return !allocated_tensor_count_;
+}
+
+TfLiteEvalTensor* FakeMicroContext::GetEvalTensor(int tensor_index) {
+ TfLiteEvalTensor* eval_tensor =
+ reinterpret_cast(allocator_->AllocateTemp(
+ sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
+ TFLITE_DCHECK(eval_tensor != nullptr);
+
+ // In unit tests, the TfLiteTensor pointer contains the source of truth for
+ // buffers and values:
+ eval_tensor->data = tensors_[tensor_index].data;
+ eval_tensor->dims = tensors_[tensor_index].dims;
+ eval_tensor->type = tensors_[tensor_index].type;
+ return eval_tensor;
+}
+
+void* FakeMicroContext::AllocatePersistentBuffer(size_t bytes) {
+ // FakeMicroContext use SimpleMemoryAllocator, which does not automatically
+ // apply the buffer alignment like MicroAllocator.
+ // The buffer alignment is potentially wasteful but allows the
+ // fake_micro_context to work correctly with optimized kernels.
+ return allocator_->AllocatePersistentBuffer(bytes,
+ MicroArenaBufferAlignment());
+}
+
+TfLiteStatus FakeMicroContext::RequestScratchBufferInArena(size_t bytes,
+ int* buffer_index) {
+ TFLITE_DCHECK(buffer_index != nullptr);
+
+ if (scratch_buffer_count_ == kNumScratchBuffers_) {
+ MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
+ kNumScratchBuffers_);
+ return kTfLiteError;
+ }
+
+ // For tests, we allocate scratch buffers from the tail and keep them around
+ // for the lifetime of model. This means that the arena size in the tests will
+ // be more than what we would have if the scratch buffers could share memory.
+ scratch_buffers_[scratch_buffer_count_] =
+ allocator_->AllocatePersistentBuffer(bytes, MicroArenaBufferAlignment());
+ TFLITE_DCHECK(scratch_buffers_[scratch_buffer_count_] != nullptr);
+
+ *buffer_index = scratch_buffer_count_++;
+ return kTfLiteOk;
+}
+
+void* FakeMicroContext::GetScratchBuffer(int buffer_index) {
+ TFLITE_DCHECK(scratch_buffer_count_ <= kNumScratchBuffers_);
+ if (buffer_index >= scratch_buffer_count_) {
+ return nullptr;
+ }
+ return scratch_buffers_[buffer_index];
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h
new file mode 100644
index 00000000..99933c19
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.h
@@ -0,0 +1,56 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_FAKE_MICRO_CONTEXT_H_
+#define TENSORFLOW_LITE_MICRO_FAKE_MICRO_CONTEXT_H_
+
+#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_graph.h"
+
+namespace tflite {
+// A fake of MicroContext for kernel util tests.
+class FakeMicroContext : public MicroContext {
+ public:
+ FakeMicroContext(TfLiteTensor* tensors, SimpleMemoryAllocator* allocator,
+ MicroGraph* micro_graph);
+
+ void* AllocatePersistentBuffer(size_t bytes) override;
+ TfLiteStatus RequestScratchBufferInArena(size_t bytes,
+ int* buffer_index) override;
+ void* GetScratchBuffer(int buffer_index) override;
+
+ TfLiteTensor* AllocateTempTfLiteTensor(int tensor_index) override;
+ void DeallocateTempTfLiteTensor(TfLiteTensor* tensor) override;
+ bool IsAllTempTfLiteTensorDeallocated();
+
+ TfLiteEvalTensor* GetEvalTensor(int tensor_index) override;
+
+ private:
+ static constexpr int kNumScratchBuffers_ = 12;
+
+ int scratch_buffer_count_ = 0;
+ uint8_t* scratch_buffers_[kNumScratchBuffers_];
+
+ TfLiteTensor* tensors_;
+ int allocated_tensor_count_ = 0;
+
+ SimpleMemoryAllocator* allocator_;
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_FAKE_MICRO_CONTEXT_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h b/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h
new file mode 100644
index 00000000..3767cb9f
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h
@@ -0,0 +1,100 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+
+#include
+#include
+
+#include "tensorflow/lite/c/c_api_types.h"
+
+namespace tflite {
+// Interface classes that the TFLM framework relies on to get buffers it needs.
+// There are two types of buffers that the TFLM framework requires: persistent
+// and non-persistent. Persistent buffers, once allocated, are never freed by
+// the TFLM framework. Non-persist buffers can be allocated and deallocated by
+// the TFLM framework. This file defines two interfaces classes that TFLM
+// framework will rely on to manage these buffers.
+
+// Interface class for managing persistent buffers.
+class IPersistentBufferAllocator {
+ public:
+ IPersistentBufferAllocator() {}
+ virtual ~IPersistentBufferAllocator() {}
+
+ // Allocates persistent memory. The persistent buffer is never freed.
+ virtual uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) = 0;
+
+ // Returns the size of all persistent allocations in bytes.
+ virtual size_t GetPersistentUsedBytes() const = 0;
+};
+
+// Interface class for managing non-persistent buffers.
+// The default non-persistent buffers are temp buffers that are not resizable.
+// Support of at least one resizable buffer is required.
+class INonPersistentBufferAllocator {
+ public:
+ INonPersistentBufferAllocator() {}
+ virtual ~INonPersistentBufferAllocator() {}
+
+ // Allocates a temporary buffer. This buffer is not resizable.
+ virtual uint8_t* AllocateTemp(size_t size, size_t alignment) = 0;
+
+ // Signals that a temporary buffer is no longer needed.
+ virtual void DeallocateTemp(uint8_t* buf) = 0;
+
+ // Returns true if all temporary buffers are already deallocated.
+ virtual bool IsAllTempDeallocated() = 0;
+
+ // Signals that all temporary allocations can be reclaimed. TFLM calls this
+ // API when it knows that all temporary buffers that it requested has been
+ // deallocated. The goal of API is to facilitate implementations of
+ // INonPersistentBufferAllocator can reuse buffer with some reasonable
+ // complexity.
+ virtual TfLiteStatus ResetTempAllocations() = 0;
+
+ // Returns a buffer that is resizable viable ResizeBuffer().
+ virtual uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) = 0;
+
+ // Resizes a buffer that is previously returned by the
+ // AllocateResizableBuffer.
+ virtual TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size,
+ size_t alignment) = 0;
+
+ // Frees up the memory occupied by the resizable buffer.
+ virtual TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) = 0;
+
+ // Returns a pointer pointing to the start of the overlay memory, which is
+ // used for activation tensors and scratch buffers by kernels at Invoke stage.
+ virtual uint8_t* GetOverlayMemoryAddress() const = 0;
+
+ // Reserves the size of the overlay memory. This overlay is reserved for the
+ // kernels at Invoke stage. This is referred to as the overlay because before
+ // Invoket state, the same memory can be used for temp buffers. The layout of
+ // the memory is planned by the memory planner separately at Invoke stage.
+ virtual TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size,
+ size_t alignment) = 0;
+
+ // Returns the size of non-persistent buffer in use.
+ virtual size_t GetNonPersistentUsedBytes() const = 0;
+
+ // Returns the number of bytes available with a given alignment. This number
+ // takes in account any temporary allocations.
+ virtual size_t GetAvailableMemory(size_t alignment) const = 0;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations_common.cc
index 90afe832..2ec3a1bf 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/activations_common.cc
@@ -117,15 +117,21 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
ReluOpData* data = static_cast(node->user_data);
- const TfLiteTensor* input = GetInput(context, node, kActivationsInputTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kActivationsInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* output = GetOutput(context, node, kActivationsOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kActivationsOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (input->type == kTfLiteInt8) {
CalculateReluOpData(input, output, data);
}
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
@@ -133,7 +139,9 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
Relu6OpData* data = static_cast(node->user_data);
- const TfLiteTensor* input = GetInput(context, node, kActivationsInputTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kActivationsInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
if (input->type == kTfLiteInt8) {
@@ -142,6 +150,8 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
data->zero_int8 = input->params.zero_point;
}
+ micro_context->DeallocateTempTfLiteTensor(input);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_common.cc
index 3d0c841e..b285b800 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_common.cc
@@ -80,11 +80,15 @@ TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
- const TfLiteTensor* input1 = GetInput(context, node, kAddInputTensor1);
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* input1 =
+ micro_context->AllocateTempInputTensor(node, kAddInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
- const TfLiteTensor* input2 = GetInput(context, node, kAddInputTensor2);
+ TfLiteTensor* input2 =
+ micro_context->AllocateTempInputTensor(node, kAddInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
- TfLiteTensor* output = GetOutput(context, node, kAddOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kAddOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
OpDataAdd* data = static_cast(node->user_data);
@@ -93,6 +97,9 @@ TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(
CalculateOpDataAdd(context, params, input1, input2, output, data));
+ micro_context->DeallocateTempTfLiteTensor(input1);
+ micro_context->DeallocateTempTfLiteTensor(input2);
+ micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
index b57a2ae6..5d0ab724 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
@@ -50,18 +50,19 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, num_inputs >= 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input_tensor_first;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node, kInputTensor0, &input_tensor_first));
- TfLiteTensor* output;
- TF_LITE_ENSURE_OK(context,
- GetOutputSafe(context, node, kOutputTensor, &output));
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* input_tensor_first =
+ micro_context->AllocateTempInputTensor(node, kInputTensor0);
+ TF_LITE_ENSURE(context, input_tensor_first != nullptr);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
// Check that all tensors have the same shape and type.
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_tensor_first->type);
for (int i = kInputTensor0 + 1; i < num_inputs; ++i) {
- const TfLiteTensor* input;
- TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
+ TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
+ TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, HaveSameShapes(input_tensor_first, input));
TF_LITE_ENSURE_TYPES_EQ(context, input_tensor_first->type, input->type);
@@ -72,6 +73,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
input_tensor_first->params.scale == input->params.scale);
}
+
+ micro_context->DeallocateTempTfLiteTensor(input);
}
if (output->type == kTfLiteFloat32) {
@@ -123,6 +126,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
+ micro_context->DeallocateTempTfLiteTensor(input_tensor_first);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
index a583a067..e28ebebb 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
@@ -52,21 +52,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_resource_id_tensor->type == kTfLiteInt32));
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor->dims), 1);
- const TfLiteTensor* input_value = GetInput(context, node, kInputValue);
+ tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
+ TfLiteTensor* input_value =
+ micro_context->AllocateTempInputTensor(node, kInputValue);
TFLITE_DCHECK(input_value != nullptr);
- // Casting to TfliteIntArray is required since we are re-using
- // GetExecutionPlan from TfLiteContext. On TFLM this method returns a
- // MicroGraph.
- // TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
- MicroGraph* graph_info;
- context->GetExecutionPlan(context,
- reinterpret_cast(&graph_info));
- MicroResourceVariables* resources = graph_info->GetResourceVariables();
+ MicroGraph& graph_info = micro_context->graph();
+
+ MicroResourceVariables* resources = graph_info.GetResourceVariables();
TF_LITE_ENSURE_OK(context,
resources->Allocate(input_resource_id_tensor->data.i32[0],
context, input_value));
+ micro_context->DeallocateTempTfLiteTensor(input_value);
return kTfLiteOk;
}
@@ -79,14 +77,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetEvalInput(context, node, kInputValue);
TFLITE_DCHECK(input_value != nullptr);
- // Casting to TfliteIntArray is required since we are re-using
- // GetExecutionPlan from TfLiteContext. On TFLM this method returns a
- // MicroGraph.
- // TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
- MicroGraph* graph_info;
- context->GetExecutionPlan(context,
- reinterpret_cast(&graph_info));
- MicroResourceVariables* resources = graph_info->GetResourceVariables();
+ tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
+ MicroGraph& graph_info = micro_context->graph();
+
+ MicroResourceVariables* resources = graph_info.GetResourceVariables();
if (resources == nullptr) {
MicroPrintf(
"ASSIGN_VARIABLE requires resource variables. Please create "
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
index a6fa0462..07b680df 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
@@ -41,8 +41,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
@@ -51,6 +55,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
new file mode 100644
index 00000000..fa333249
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
@@ -0,0 +1,97 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/kernels/internal/reference/broadcast_args.h"
+
+#include
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_context.h"
+
+namespace tflite {
+namespace {
+constexpr int kShape1Tensor = 0;
+constexpr int kShape2Tensor = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus BroadcastArgsPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE(context, NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* shape1 =
+ micro_context->AllocateTempInputTensor(node, kShape1Tensor);
+ TfLiteTensor* shape2 =
+ micro_context->AllocateTempInputTensor(node, kShape2Tensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+
+ TF_LITE_ENSURE(context,
+ shape1->type == kTfLiteInt32 || shape1->type == kTfLiteInt64);
+ TF_LITE_ENSURE_EQ(context, shape1->type, shape2->type);
+ TF_LITE_ENSURE_EQ(context, shape1->type, output->type);
+
+ // Ensures the shapes are 1D tensor.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(shape1), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(shape2), 1);
+
+ // Ensure the shape of the output tensor is compatible
+ TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
+
+ micro_context->DeallocateTempTfLiteTensor(shape1);
+ micro_context->DeallocateTempTfLiteTensor(shape2);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteEvalTensor* shape1 =
+ micro::GetEvalInput(context, node, kShape1Tensor);
+ const TfLiteEvalTensor* shape2 =
+ micro::GetEvalInput(context, node, kShape2Tensor);
+ TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteInt32) {
+ reference_ops::BroadcastArgs(
+ micro::GetTensorShape(shape1), micro::GetTensorData(shape1),
+ micro::GetTensorShape(shape2), micro::GetTensorData(shape2),
+ micro::GetTensorShape(output), micro::GetTensorData(output));
+ } else {
+ reference_ops::BroadcastArgs(
+ micro::GetTensorShape(shape1), micro::GetTensorData(shape1),
+ micro::GetTensorShape(shape2), micro::GetTensorData(shape2),
+ micro::GetTensorShape(output), micro::GetTensorData(output));
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteRegistration Register_BROADCAST_ARGS() {
+ return {/*init=*/nullptr,
+ /*free=*/nullptr,
+ /*prepare=*/BroadcastArgsPrepare,
+ /*invoke=*/BroadcastArgsEval,
+ /*profiling_string=*/nullptr,
+ /*builtin_code=*/0,
+ /*custom_name=*/nullptr,
+ /*version=*/0};
+}
+
+} // namespace tflite
\ No newline at end of file
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
new file mode 100644
index 00000000..5302faf1
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
@@ -0,0 +1,129 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"
+
+#include
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_context.h"
+
+namespace tflite {
+
+namespace {
+constexpr int kInputTensor = 0;
+constexpr int kShapeTensor = 1;
+constexpr int kOutputTensor = 0;
+// Support a maximum of 5 dimensions in TFLM.
+constexpr int kMaxDims = 5;
+
+TfLiteStatus ValidateOutputTensor(TfLiteContext* context, TfLiteTensor* input,
+ TfLiteTensor* shape, TfLiteTensor* output) {
+ // Ensures the shape is 1D tensor.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
+
+ // Ensure output dims is not less than input dims.
+ int input_num_dims = NumDimensions(input);
+ int output_num_dims = NumDimensions(output);
+ int shape_num_dims = SizeOfDimension(shape, 0);
+ TF_LITE_ENSURE_MSG(context, output_num_dims == shape_num_dims,
+ "Output must match with the expected shape dimension.");
+ TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
+ "Output shape must be broadcastable from input shape.");
+ TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
+ "BroadcastTo only supports 1-5D tensor.");
+
+ // Check if output shape is broadcastable from input shape.
+ auto get_shape_data = [shape](int i) -> int32_t {
+ if (shape->type == kTfLiteInt32) {
+ return GetTensorData(shape)[i];
+ } else {
+ return GetTensorData(shape)[i];
+ }
+ };
+
+ int extending_dims = output_num_dims - input_num_dims;
+ for (int idx = 0; idx < input_num_dims; ++idx) {
+ TF_LITE_ENSURE_MSG(
+ context,
+ (SizeOfDimension(input, idx) == 1 ||
+ SizeOfDimension(input, idx) == get_shape_data(extending_dims + idx)),
+ "Output shape must be broadcastable from input shape.");
+ }
+
+ // Validating the shape of the output tensor.
+ tflite::RuntimeShape output_shape = tflite::GetTensorShape(output);
+ for (int idx = 0; idx < output_num_dims; ++idx) {
+ TF_LITE_ENSURE(context, output_shape.Dims(idx) == get_shape_data(idx));
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus BroadcastToPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE(context, NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ MicroContext* micro_context = GetMicroContext(context);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
+ TfLiteTensor* shape =
+ micro_context->AllocateTempInputTensor(node, kShapeTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+
+ TF_LITE_ENSURE_MSG(context, (NumDimensions(input) <= kMaxDims),
+ "BroadcastTo only supports 1-5D tensor.");
+
+ TF_LITE_ENSURE(context,
+ shape->type == kTfLiteInt32 || shape->type == kTfLiteInt64);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ // Does not support String type due to its variable size. This limitation is
+ // the same as TFLite.
+ TF_LITE_ENSURE(context, input->type != kTfLiteString);
+
+ TF_LITE_ENSURE_STATUS(ValidateOutputTensor(context, input, shape, output));
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(shape);
+ micro_context->DeallocateTempTfLiteTensor(output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteEvalTensor* input =
+ micro::GetEvalInput(context, node, kInputTensor);
+ TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);
+
+ // BroadcastTo op support upto 5 dims, different from 8 dims in TFLite.
+ reference_ops::BroadcastTo(
+ micro::GetTensorShape(input), input->data.raw,
+ micro::GetTensorShape(output), output->data.raw, input->type);
+ return kTfLiteOk;
+}
+} // namespace
+
+TfLiteRegistration Register_BROADCAST_TO() {
+ return {/*init=*/nullptr,
+ /*free=*/nullptr,
+ /*prepare=*/BroadcastToPrepare,
+ /*invoke=*/BroadcastToEval,
+ /*profiling_string=*/nullptr,
+ /*builtin_code=*/0,
+ /*custom_name=*/nullptr,
+ /*version=*/0};
+}
+
+} // namespace tflite
\ No newline at end of file
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
index 97fded0c..4db39f7d 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"
#include "tensorflow/lite/schema/schema_generated.h"
@@ -50,16 +51,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 0);
TF_LITE_ENSURE(context, NumOutputs(node) == 0);
- // Casting to TfliteIntArray is required since we are re-using
- // GetExecutionPlan from TfLiteContext. On TFLM this method returns a
- // MicroGraph.
- // TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
- MicroGraph* graph_info;
- context->GetExecutionPlan(context,
- reinterpret_cast(&graph_info));
+ tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
+ MicroGraph& graph_info = micro_context->graph();
TF_LITE_ENSURE(context,
- op_data->init_subgraph_index < graph_info->NumSubgraphs());
+ op_data->init_subgraph_index < graph_info.NumSubgraphs());
return kTfLiteOk;
}
@@ -72,16 +68,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
- // Casting to TfliteIntArray is required since we are re-using
- // GetExecutionPlan from TfLiteContext. On TFLM this method returns a
- // MicroGraph.
- // TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
- MicroGraph* graph_info;
- context->GetExecutionPlan(context,
- reinterpret_cast(&graph_info));
+ tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
+ MicroGraph& graph_info = micro_context->graph();
TF_LITE_ENSURE_OK(context,
- graph_info->InvokeSubgraph(op_data->init_subgraph_index));
+ graph_info.InvokeSubgraph(op_data->init_subgraph_index));
op_data->has_run = true;
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
index 0314e523..dc651a24 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
@@ -28,11 +28,19 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
@@ -83,6 +91,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32:
return copyToTensor(context, tflite::micro::GetTensorData(input),
output, num_elements);
+ case kTfLiteUInt32:
+ return copyToTensor(context,
+ tflite::micro::GetTensorData(input), output,
+ num_elements);
case kTfLiteFloat32:
return copyToTensor(context, tflite::micro::GetTensorData(input),
output, num_elements);
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
index f929ce62..d0a48f91 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
@@ -29,9 +29,13 @@ constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -42,6 +46,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < output->dims->size; ++i) {
TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
}
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc
index 0bb4d476..682efb43 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer_common.cc
@@ -39,9 +39,13 @@ const int kCircularBufferCyclesMaxIndex = 0; // 'cycles_max'
const TfLiteStatus kTfLiteAbort = static_cast(-9);
TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input =
- GetInput(context, node, kCircularBufferInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kCircularBufferOutputTensor);
+
+ MicroContext * micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context-> AllocateTempInputTensor(node, kCircularBufferInputTensor);
+ TfLiteTensor* output =
+ micro_context-> AllocateTempOutputTensor(node, kCircularBufferOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
OpDataCircularBuffer* op_data =
@@ -85,6 +89,9 @@ TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) {
op_data->cycles_until_run = op_data->cycles_max;
node->user_data = op_data;
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
index eb39d9ea..925c3fb5 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
@@ -540,9 +540,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast(node->user_data);
- const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input1 =
+ micro_context->AllocateTempInputTensor(node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
- const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* input2 =
+ micro_context->AllocateTempInputTensor(node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
if (input1->type == kTfLiteInt8) {
@@ -570,6 +574,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->params.input2_shift = input2_shift;
}
+ micro_context->DeallocateTempTfLiteTensor(input1);
+ micro_context->DeallocateTempTfLiteTensor(input2);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
index 8f45ac6a..d727a0d5 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
@@ -115,13 +115,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteConcatenationParams* params =
reinterpret_cast(node->builtin_data);
- const TfLiteTensor* input_tensor = GetInput(context, node, 0);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input_tensor = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input_tensor != nullptr);
TfLiteType input_type = input_tensor->type;
- const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output_tensor =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
TfLiteType output_type = output_tensor->type;
+ micro_context->DeallocateTempTfLiteTensor(input_tensor);
+ micro_context->DeallocateTempTfLiteTensor(output_tensor);
+
// Check activation and input type
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
TF_LITE_ENSURE(context,
@@ -138,7 +144,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Shapes with dimensions >4 are not yet supported with static allocation.
for (int i = 0; i < num_inputs; ++i) {
- const TfLiteTensor* input = GetInput(context, node, i);
+ TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, input != nullptr);
int num_dimensions = NumDimensions(input);
@@ -150,13 +156,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
num_dimensions);
return kTfLiteError;
}
+ micro_context->DeallocateTempTfLiteTensor(input);
}
// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast(node->user_data);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
switch (output_type) { // Already know in/outtypes are same.
@@ -183,10 +191,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate persistent scale and zeropoint buffers.
// Store input scale and zero point values in OpParams:
for (int i = 0; i < node->inputs->size; ++i) {
- const TfLiteTensor* t = GetInput(context, node, i);
+ TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, t != nullptr);
input_scales[i] = t->params.scale;
input_zero_points[i] = t->params.zero_point;
+ micro_context->DeallocateTempTfLiteTensor(t);
}
data->params.input_scale = input_scales;
@@ -202,6 +211,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.h b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.h
index 4089a965..06b35e1e 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.h
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -79,7 +79,8 @@ TfLiteRegistration Register_CONV_2D();
#if defined(XTENSA)
// Returns a TfLiteRegistration struct for kernel variant that only supports
-// int8 inputs and outputs.
+// int8 activations and int8 weights and always calls the reference
+// implementation.
TfLiteRegistration Register_CONV_2D_INT8REF();
#else
inline TfLiteRegistration Register_CONV_2D_INT8REF() {
@@ -87,6 +88,25 @@ inline TfLiteRegistration Register_CONV_2D_INT8REF() {
}
#endif
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int8 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_CONV_2D_INT8();
+
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_CONV_2D_INT16();
+
+#else
+inline TfLiteRegistration Register_CONV_2D_INT8() { return Register_CONV_2D(); }
+
+inline TfLiteRegistration Register_CONV_2D_INT16() {
+ return Register_CONV_2D();
+}
+#endif
+
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_common.cc
index 6887e423..7115f7ba 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_common.cc
@@ -93,13 +93,18 @@ TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
params.dilation_width_factor, height, width, filter_height, filter_width,
padding, &out_height, &out_width);
- const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+ TfLiteTensor* filter =
+ micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kConvBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+ TfLiteTensor* bias =
+ micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
// Note that quantized inference requires that all tensors have their
@@ -119,6 +124,11 @@ TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(filter);
+ micro_context->DeallocateTempTfLiteTensor(output);
+ micro_context->DeallocateTempTfLiteTensor(bias);
+
return kTfLiteOk;
}
@@ -129,12 +139,16 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
OpDataConv* data = static_cast(node->user_data);
const auto& params =
*(static_cast(node->builtin_data));
+ MicroContext* micro_context = GetMicroContext(context);
- TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
- const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+ TfLiteTensor* filter =
+ micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const int input_width = input->dims->data[2];
@@ -174,6 +188,10 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, data));
+ micro_context->DeallocateTempTfLiteTensor(filter);
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
index 2dc9f98f..61f7af23 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
@@ -47,8 +47,12 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
+ TfLiteTensor* axis =
+ micro_context->AllocateTempInputTensor(node, kAxisTensor);
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
@@ -58,7 +62,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE(context, HaveSameShapes(input, output));
@@ -91,6 +96,10 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
&data->output_activation_max));
}
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(axis);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
index ae42ee1b..cce93c9c 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
@@ -40,11 +40,14 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input;
- TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
- TfLiteTensor* output;
- TF_LITE_ENSURE_OK(context,
- GetOutputSafe(context, node, kOutputTensor, &output));
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -83,6 +86,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
output->dims->data[kWidthRank] = output_width;
output->dims->data[kDepthRank] = output_channels;
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
index 49167f38..3bf07274 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
@@ -94,13 +94,18 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
params.dilation_width_factor, height, width, filter_height, filter_width,
padding, &out_height, &out_width);
- const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+ TfLiteTensor* filter =
+ micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kConvBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+ TfLiteTensor* bias =
+ micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
// Note that quantized inference requires that all tensors have their
@@ -120,6 +125,11 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(filter);
+ micro_context->DeallocateTempTfLiteTensor(bias);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
@@ -130,14 +140,16 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
OpDataConv* data = static_cast(node->user_data);
const auto& params =
*(static_cast(node->builtin_data));
+ MicroContext* micro_context = GetMicroContext(context);
- TfLiteTensor* output = GetOutput(context, node, kDepthwiseConvOutputTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
- const TfLiteTensor* input =
- GetInput(context, node, kDepthwiseConvInputTensor);
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kDepthwiseConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* filter =
- GetInput(context, node, kDepthwiseConvWeightsTensor);
+ TfLiteTensor* filter =
+ micro_context->AllocateTempInputTensor(node, kDepthwiseConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const int input_width = input->dims->data[2];
@@ -180,6 +192,10 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, data));
+ micro_context->DeallocateTempTfLiteTensor(output);
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(filter);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
index 00b47f57..4be5ad89 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
@@ -33,10 +33,12 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ MicroContext* micro_context = GetMicroContext(context);
+
// TODO(b/140515557): Add cached dequant to improve hybrid model performance.
- const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context,
@@ -54,6 +56,10 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
data->quantization_params.zero_point = input->params.zero_point;
data->quantization_params.scale = static_cast(input->params.scale);
data->output_zero_point = output->params.zero_point;
+
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
index 5ac343cf..efe57e2f 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include
#include
+#include
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/c/builtin_op_data.h"
@@ -152,14 +154,17 @@ void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = static_cast(node->user_data);
+ MicroContext* micro_context = GetMicroContext(context);
+
// Inputs: box_encodings, scores, anchors
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
- const TfLiteTensor* input_box_encodings =
- GetInput(context, node, kInputTensorBoxEncodings);
- const TfLiteTensor* input_class_predictions =
- GetInput(context, node, kInputTensorClassPredictions);
- const TfLiteTensor* input_anchors =
- GetInput(context, node, kInputTensorAnchors);
+ TfLiteTensor* input_box_encodings =
+ micro_context->AllocateTempInputTensor(node, kInputTensorBoxEncodings);
+ TfLiteTensor* input_class_predictions =
+ micro_context->AllocateTempInputTensor(node,
+ kInputTensorClassPredictions);
+ TfLiteTensor* input_anchors =
+ micro_context->AllocateTempInputTensor(node, kInputTensorAnchors);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
@@ -217,6 +222,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// num_detections
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
+ micro_context->DeallocateTempTfLiteTensor(input_box_encodings);
+ micro_context->DeallocateTempTfLiteTensor(input_class_predictions);
+ micro_context->DeallocateTempTfLiteTensor(input_anchors);
+
return kTfLiteOk;
}
@@ -313,9 +322,10 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
void DecreasingPartialArgSort(const float* values, int num_values,
int num_to_sort, int* indices) {
std::iota(indices, indices + num_values, 0);
- std::partial_sort(
- indices, indices + num_to_sort, indices + num_values,
- [&values](const int i, const int j) { return values[i] > values[j]; });
+ std::partial_sort(indices, indices + num_to_sort, indices + num_values,
+ [&values](const int i, const int j) {
+ return std::tie(values[i], j) > std::tie(values[j], i);
+ });
}
template
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
index 581e532b..366dd610 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
@@ -38,11 +38,13 @@ bool IsLogicalSupportedType(const TfLiteType type) {
typedef bool (*IsSupportedType)(TfLiteType);
template
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+ MicroContext* micro_context = GetMicroContext(context);
+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
@@ -50,6 +52,9 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
+
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
index 7e785f2f..b2cd19cc 100644
--- a/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
@@ -80,13 +80,16 @@ void EvalUsingLookupTable(const OpData* data, const TfLiteEvalTensor* input,
}
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
+ MicroContext* micro_context = GetMicroContext(context);
+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input;
- TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
- TfLiteTensor* output;
- TF_LITE_ENSURE_OK(context,
- GetOutputSafe(context, node, kOutputTensor, &output));
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
// Use LUT to handle quantized elu path.
@@ -97,7 +100,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
};
PopulateLookupTable(input, output, transform, data);
}
-
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/README.md b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/README.md
new file mode 100644
index 00000000..b0c215fb
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/README.md
@@ -0,0 +1,11 @@
+# Info
+
+These are the Espressif chipset specific replacement kernels.
+The kernels call optimized routines or reference routines depending upon optimization option selected.
+
+By default optimizations are selected if available.
+To change this behaviour, please make the appropriate `ESP-NN` menu selection after running:
+
+```
+idf.py menuconfig
+```
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
new file mode 100644
index 00000000..47a17d9f
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
@@ -0,0 +1,209 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/add.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
+#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/add.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+#include
+
+#if ESP_NN
+#include
+#endif
+
+long long add_total_time = 0;
+
+namespace tflite {
+
+void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
+ const OpDataAdd* data, const TfLiteEvalTensor* input1,
+ const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(data->output_activation_min_f32,
+ data->output_activation_max_f32, &op_params);
+ if (data->requires_broadcast) {
+ reference_ops::BroadcastAdd4DSlow(
+ op_params, tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorData(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ } else {
+ reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorData(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ }
+}
+
+TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, const OpDataAdd* data,
+ const TfLiteEvalTensor* input1,
+ const TfLiteEvalTensor* input2,
+ TfLiteEvalTensor* output) {
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = data->left_shift;
+ op_params.input1_offset = data->input1_offset;
+ op_params.input1_multiplier = data->input1_multiplier;
+ op_params.input1_shift = data->input1_shift;
+ op_params.input2_offset = data->input2_offset;
+ op_params.input2_multiplier = data->input2_multiplier;
+ op_params.input2_shift = data->input2_shift;
+ op_params.output_offset = data->output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = data->output_shift;
+ SetActivationParams(data->output_activation_min, data->output_activation_max,
+ &op_params);
+ bool need_broadcast = reference_ops::ProcessBroadcastShapes(
+ tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorShape(input2), &op_params);
+
+ switch (output->type) {
+ case kTfLiteInt8: {
+ if (need_broadcast) {
+ reference_integer_ops::BroadcastAdd4DSlow(
+ op_params, tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorData(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ } else {
+#if ESP_NN
+ const int8_t *input1_data = tflite::micro::GetTensorData(input1);
+ const int8_t *input2_data = tflite::micro::GetTensorData(input2);
+ int8_t *out_data = tflite::micro::GetTensorData(output);
+
+ esp_nn_add_elementwise_s8(input1_data,
+ input2_data,
+ data->input1_offset,
+ data->input2_offset,
+ data->input1_multiplier,
+ data->input2_multiplier,
+ data->input1_shift,
+ data->input2_shift,
+ data->left_shift,
+ out_data,
+ data->output_offset,
+ data->output_multiplier,
+ data->output_shift,
+ data->output_activation_min,
+ data->output_activation_max,
+ MatchingElementsSize(tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorShape(output))
+ );
+#else
+ reference_integer_ops::Add(
+ op_params, tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorData(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+#endif
+ }
+ break;
+ }
+ case kTfLiteInt16: {
+ if (need_broadcast) {
+ reference_ops::BroadcastAdd4DSlow(
+ op_params, tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorData(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ } else {
+ reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
+ tflite::micro::GetTensorData(input1),
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output),
+ false);
+ }
+ break;
+ }
+ default:
+ MicroPrintf("Type %s (%d) not supported.",
+ TfLiteTypeGetName(output->type), output->type);
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+void* AddInit(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(OpDataAdd));
+}
+
+TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast(node->builtin_data);
+
+ TFLITE_DCHECK(node->user_data != nullptr);
+ const OpDataAdd* data = static_cast(node->user_data);
+
+ const TfLiteEvalTensor* input1 =
+ tflite::micro::GetEvalInput(context, node, kAddInputTensor1);
+ const TfLiteEvalTensor* input2 =
+ tflite::micro::GetEvalInput(context, node, kAddInputTensor2);
+ TfLiteEvalTensor* output =
+ tflite::micro::GetEvalOutput(context, node, kAddOutputTensor);
+
+ long long start_time = esp_timer_get_time();
+
+ if (output->type == kTfLiteFloat32) {
+ EvalAdd(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, data,
+ input1, input2, output));
+ } else {
+ MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(output->type),
+ output->type);
+ return kTfLiteError;
+ }
+ add_total_time += esp_timer_get_time() - start_time;
+
+ return kTfLiteOk;
+}
+
+TfLiteRegistration Register_ADD() {
+ return {/*init=*/AddInit,
+ /*free=*/nullptr,
+ /*prepare=*/AddPrepare,
+ /*invoke=*/AddEval,
+ /*profiling_string=*/nullptr,
+ /*builtin_code=*/0,
+ /*custom_name=*/nullptr,
+ /*version=*/0};
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
new file mode 100644
index 00000000..09260482
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
@@ -0,0 +1,319 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/conv.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#include "freertos/FreeRTOS.h"
+#include
+
+#if ESP_NN
+#include
+#endif
+
+
+long long conv_total_time = 0;
+
+namespace tflite {
+namespace {
+
+struct NodeData {
+ OpDataConv op_data;
+#if ESP_NN
+ int buffer_idx;
+#endif
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(NodeData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->user_data != nullptr);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ NodeData* data = static_cast(node->user_data);
+ const auto& params =
+ *(static_cast(node->builtin_data));
+
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kConvInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TfLiteTensor* filter =
+ micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
+ TF_LITE_ENSURE(context, filter != nullptr);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+
+ const int input_width = input->dims->data[2];
+ const int input_height = input->dims->data[1];
+ const int filter_width = filter->dims->data[2];
+ const int filter_height = filter->dims->data[1];
+ const int output_width = output->dims->data[2];
+ const int output_height = output->dims->data[1];
+
+ // Dynamically allocate per-channel quantization parameters.
+ const int num_channels = filter->dims->data[kConvQuantizedDimension];
+ data->op_data.per_channel_output_multiplier =
+ static_cast(context->AllocatePersistentBuffer(
+ context, num_channels * sizeof(int32_t)));
+ data->op_data.per_channel_output_shift =
+ static_cast(context->AllocatePersistentBuffer(
+ context, num_channels * sizeof(int32_t)));
+
+ // All per-channel quantized tensors need valid zero point and scale arrays.
+ if (input->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+ kTfLiteAffineQuantization);
+
+ const auto* affine_quantization =
+ static_cast(filter->quantization.params);
+ TFLITE_DCHECK(affine_quantization != nullptr);
+ TFLITE_DCHECK(affine_quantization->scale != nullptr);
+ TFLITE_DCHECK(affine_quantization->zero_point != nullptr);
+
+ TF_LITE_ENSURE(context,
+ affine_quantization->scale->size == 1 ||
+ affine_quantization->scale->size ==
+ filter->dims->data[kConvQuantizedDimension]);
+ TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+ affine_quantization->zero_point->size);
+ }
+
+ TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
+ context, node, params, input_width, input_height, filter_width,
+ filter_height, output_width, output_height, input->type, &data->op_data));
+
+#if ESP_NN
+ if (input->type == kTfLiteInt8) {
+ int scratch_buf_size = esp_nn_get_conv_scratch_size(
+ input_width, input_height, input->dims->data[3],
+ output->dims->data[3], filter_width, filter_height);
+ if (scratch_buf_size > 0) {
+ TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+ context, scratch_buf_size, &data->buffer_idx));
+ } else {
+ data->buffer_idx = -1;
+ }
+ }
+#endif
+
+ micro_context->DeallocateTempTfLiteTensor(output);
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(filter);
+
+ return kTfLiteOk;
+}
+
+#if ESP_NN
+// Fixed-point per-channel-quantization convolution Int8 function wrapper.
+inline void EvalQuantizedPerChannel(
+ TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
+ const NodeData& data, const TfLiteEvalTensor* input,
+ const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
+ TfLiteEvalTensor* output) {
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+
+ if (dilation_width_factor == 1 && dilation_height_factor == 1) {
+ // Get parameters.
+ RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
+ RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+ RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+ RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);
+
+ const int8_t *input_data = tflite::micro::GetTensorData(input);
+ int8_t *output_data = tflite::micro::GetTensorData(output);
+
+ const int32_t input_offset = -data.op_data.input_zero_point;
+ const int32_t output_offset = data.op_data.output_zero_point;
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = data.op_data.padding.width;
+ const int pad_height = data.op_data.padding.height;
+
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+
+ // Set min and max value of the output.
+ const int32_t activation_min = data.op_data.output_activation_min;
+ const int32_t activation_max = data.op_data.output_activation_max;
+
+ // Consistency check.
+ TFLITE_DCHECK_LE(activation_min, activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+
+ if (tflite::micro::GetTensorData(bias)) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+
+ void *scratch_buf = NULL;
+ if (data.buffer_idx > -1) {
+ scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
+ }
+ esp_nn_set_conv_scratch_buf(scratch_buf);
+
+ const int input_size = input_width * input_height * input_depth;
+ const int output_size = output_width * output_height * output_depth;
+
+ for (int i_batch = 0; i_batch < batch_size; i_batch++) {
+ esp_nn_conv_s8(input_data + i_batch * input_size,
+ input_width, input_height, input_depth, input_offset,
+ pad_width, pad_height, stride_width, stride_height,
+ tflite::micro::GetTensorData(filter),
+ filter_width, filter_height,
+ tflite::micro::GetTensorData(bias),
+ output_data + i_batch * output_size,
+ output_width, output_height, output_depth, output_offset,
+ data.op_data.per_channel_output_shift,
+ data.op_data.per_channel_output_multiplier,
+ activation_min, activation_max);
+ }
+ } else {
+ reference_integer_ops::ConvPerChannel(
+ ConvParamsQuantized(params, data.op_data),
+ data.op_data.per_channel_output_multiplier,
+ data.op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ }
+}
+#endif
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteEvalTensor* input =
+ tflite::micro::GetEvalInput(context, node, kConvInputTensor);
+ const TfLiteEvalTensor* filter =
+ tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
+ const TfLiteEvalTensor* bias =
+ (NumInputs(node) == 3)
+ ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
+ : nullptr;
+ TfLiteEvalTensor* output =
+ tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
+
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+ const auto& params =
+ *(reinterpret_cast(node->builtin_data));
+ TFLITE_DCHECK(node->user_data != nullptr);
+ const auto& data = *(static_cast(node->user_data));
+
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_MSG(context, input->type == filter->type,
+ "Hybrid models are not supported on TFLite Micro.");
+
+ long long start_time = esp_timer_get_time();
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32: {
+ tflite::reference_ops::Conv(
+ ConvParamsFloat(params, data.op_data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr);
+ break;
+ }
+ case kTfLiteInt8: {
+#if ESP_NN
+ EvalQuantizedPerChannel(context, node, params, data, input, filter,
+ bias, output);
+#else
+ reference_integer_ops::ConvPerChannel(
+ ConvParamsQuantized(params, data.op_data),
+ data.op_data.per_channel_output_multiplier,
+ data.op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+#endif
+ break;
+ }
+ case kTfLiteUInt8: {
+ //EvalQuantized
+ reference_ops::Conv(ConvParamsQuantized(params, data.op_data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr,
+ nullptr);
+ break;
+ }
+ default:
+ TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+ TfLiteTypeGetName(input->type), input->type);
+ return kTfLiteError;
+ }
+ conv_total_time += esp_timer_get_time() - start_time;
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteRegistration Register_CONV_2D() {
+ return {/*init=*/Init,
+ /*free=*/nullptr,
+ /*prepare=*/Prepare,
+ /*invoke=*/Eval,
+ /*profiling_string=*/nullptr,
+ /*builtin_code=*/0,
+ /*custom_name=*/nullptr,
+ /*version=*/0};
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
new file mode 100644
index 00000000..5f2d9d50
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
@@ -0,0 +1,319 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#include "freertos/FreeRTOS.h"
+#include
+
+#if ESP_NN
+#include
+#endif
+
+long long dc_total_time = 0;
+
+namespace tflite {
+namespace {
+
+struct NodeData {
+ OpDataConv op_data;
+#if ESP_NN
+ int buffer_idx;
+#endif
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(NodeData));
+}
+
+#if ESP_NN
+inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteDepthwiseConvParams& params,
+ const NodeData& data,
+ const TfLiteEvalTensor* input,
+ const TfLiteEvalTensor* filter,
+ const TfLiteEvalTensor* bias,
+ TfLiteEvalTensor* output) {
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+
+ if (dilation_width_factor == 1 && dilation_height_factor == 1) {
+ // Get parameters.
+ RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+ RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
+ RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+ RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);
+
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int8_t *input_data = tflite::micro::GetTensorData(input);
+ int8_t *output_data = tflite::micro::GetTensorData(output);
+
+ const int depth_multiplier = params.depth_multiplier;
+ const int32_t input_offset = -data.op_data.input_zero_point;
+ const int32_t output_offset = data.op_data.output_zero_point;
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = data.op_data.padding.width;
+ const int pad_height = data.op_data.padding.height;
+
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+
+ // Set min and max value of the output.
+ const int32_t activation_min = data.op_data.output_activation_min;
+ const int32_t activation_max = data.op_data.output_activation_max;
+
+ // Consistency check.
+ TFLITE_DCHECK_LE(activation_min, activation_max);
+ const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ if (tflite::micro::GetTensorData(bias)) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+
+ const int input_size = input_width * input_height * input_depth;
+ const int output_size = output_width * output_height * output_depth;
+ void *scratch_buf = NULL;
+ if (data.buffer_idx > -1) {
+ scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
+ }
+ esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
+
+ for (int i_batch = 0; i_batch < batch_size; i_batch++) {
+ esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
+ input_height, input_depth, input_offset,
+ pad_width, pad_height,
+ stride_width, stride_height, depth_multiplier,
+ tflite::micro::GetTensorData(filter),
+ filter_width, filter_height,
+ tflite::micro::GetTensorData(bias),
+ output_data + i_batch * output_size,
+ output_width, output_height, output_offset,
+ data.op_data.per_channel_output_shift,
+ data.op_data.per_channel_output_multiplier,
+ activation_min, activation_max);
+ }
+ } else {
+ reference_integer_ops::DepthwiseConvPerChannel(
+ DepthwiseConvParamsQuantized(params, data.op_data),
+ data.op_data.per_channel_output_multiplier,
+ data.op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ }
+}
+#endif
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->user_data != nullptr);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ NodeData* data = static_cast(node->user_data);
+ const TfLiteDepthwiseConvParams& params =
+ *(static_cast(node->builtin_data));
+
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kConvInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TfLiteTensor* filter =
+ micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
+ TF_LITE_ENSURE(context, filter != nullptr);
+ TfLiteTensor* bias =
+ micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
+ TfLiteTensor* output =
+ micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+
+ const int input_width = input->dims->data[2];
+ const int input_height = input->dims->data[1];
+ const int filter_width = filter->dims->data[2];
+ const int filter_height = filter->dims->data[1];
+ const int output_width = output->dims->data[2];
+ const int output_height = output->dims->data[1];
+
+ // Dynamically allocate per-channel quantization parameters.
+ const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
+ data->op_data.per_channel_output_multiplier =
+ static_cast(context->AllocatePersistentBuffer(
+ context, num_channels * sizeof(int32_t)));
+ data->op_data.per_channel_output_shift =
+ static_cast(context->AllocatePersistentBuffer(
+ context, num_channels * sizeof(int32_t)));
+
+ // All per-channel quantized tensors need valid zero point and scale arrays.
+ if (input->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+ kTfLiteAffineQuantization);
+
+ const auto* affine_quantization =
+ static_cast(filter->quantization.params);
+ TFLITE_DCHECK(affine_quantization != nullptr);
+ TFLITE_DCHECK(affine_quantization->scale != nullptr);
+ TFLITE_DCHECK(affine_quantization->zero_point != nullptr);
+
+ TF_LITE_ENSURE(
+ context, affine_quantization->scale->size == 1 ||
+ affine_quantization->scale->size ==
+ filter->dims->data[kDepthwiseConvQuantizedDimension]);
+
+ TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+ affine_quantization->zero_point->size);
+ }
+
+ TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
+ context, node, params, input_width, input_height, filter_width,
+ filter_height, output_width, output_height, input->type, &data->op_data));
+
+#if ESP_NN
+ if (input->type == kTfLiteInt8) {
+ int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
+ input_width, input_height, input->dims->data[3],
+ params.depth_multiplier, filter_width, filter_height);
+ if (scratch_buf_size > 0) {
+ TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+ context, scratch_buf_size, &data->buffer_idx));
+ } else {
+ data->buffer_idx = -1;
+ }
+ }
+#endif
+
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(filter);
+ micro_context->DeallocateTempTfLiteTensor(bias);
+ micro_context->DeallocateTempTfLiteTensor(output);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->user_data != nullptr);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ auto& params =
+ *(reinterpret_cast(node->builtin_data));
+ const NodeData& data = *(static_cast(node->user_data));
+
+ TfLiteEvalTensor* output =
+ tflite::micro::GetEvalOutput(context, node, kDepthwiseConvOutputTensor);
+ const TfLiteEvalTensor* input =
+ tflite::micro::GetEvalInput(context, node, kDepthwiseConvInputTensor);
+ const TfLiteEvalTensor* filter =
+ tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
+ const TfLiteEvalTensor* bias =
+ (NumInputs(node) == 3)
+ ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
+ : nullptr;
+
+ long long start_time = esp_timer_get_time();
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ tflite::reference_ops::DepthwiseConv(
+ DepthwiseConvParamsFloat(params, data.op_data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ break;
+ case kTfLiteInt8:
+#if ESP_NN
+ EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
+ output);
+#else
+ reference_integer_ops::DepthwiseConvPerChannel(
+ DepthwiseConvParamsQuantized(params, data.op_data),
+ data.op_data.per_channel_output_multiplier,
+ data.op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+#endif
+ break;
+ case kTfLiteUInt8:
+ //EvalQuantized(context, node, params, &data, input, filter, bias, output);
+ reference_ops::DepthwiseConv(
+ DepthwiseConvParamsQuantized(params, data.op_data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData(output));
+ break;
+ default:
+ TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+ TfLiteTypeGetName(input->type), input->type);
+ return kTfLiteError;
+ }
+ dc_total_time += esp_timer_get_time() - start_time;
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
+ return {/*init=*/Init,
+ /*free=*/nullptr,
+ /*prepare=*/Prepare,
+ /*invoke=*/Eval,
+ /*profiling_string=*/nullptr,
+ /*builtin_code=*/0,
+ /*custom_name=*/nullptr,
+ /*version=*/0};
+}
+
+} // namespace tflite
diff --git a/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
new file mode 100644
index 00000000..5e1705da
--- /dev/null
+++ b/code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
@@ -0,0 +1,198 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#if ESP_NN
+#include
+#endif
+
+#include
+
+long long fc_total_time = 0;
+
+namespace tflite {
+namespace {
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context,
+ sizeof(OpDataFullyConnected));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ MicroContext* micro_context = GetMicroContext(context);
+
+ TFLITE_DCHECK(node->user_data != nullptr);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ auto* data = static_cast(node->user_data);
+ const auto params =
+ static_cast(node->builtin_data);
+
+ TfLiteTensor* input =
+ micro_context->AllocateTempInputTensor(node, kFullyConnectedInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ TfLiteTensor* filter = micro_context->AllocateTempInputTensor(
+ node, kFullyConnectedWeightsTensor);
+ TF_LITE_ENSURE(context, filter != nullptr);
+ TfLiteTensor* bias =
+ micro_context->AllocateTempInputTensor(node, kFullyConnectedBiasTensor);
+ TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
+ node, kFullyConnectedOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+
+ TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_MSG(context, input->type == filter->type,
+ "Hybrid models are not supported on TFLite Micro.");
+
+ TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
+ context, params->activation, input->type,
+ input, filter, bias, output, data));
+
+ micro_context->DeallocateTempTfLiteTensor(input);
+ micro_context->DeallocateTempTfLiteTensor(filter);
+ if (bias != nullptr) {
+ micro_context->DeallocateTempTfLiteTensor(bias);
+ }
+ micro_context->DeallocateTempTfLiteTensor(output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+ const auto* params =
+ static_cast(node->builtin_data);
+
+ const TfLiteEvalTensor* input =
+ tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
+ const TfLiteEvalTensor* filter =
+ tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
+ const TfLiteEvalTensor* bias =
+ tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
+ TfLiteEvalTensor* output =
+ tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
+
+ TFLITE_DCHECK(node->user_data != nullptr);
+ const auto& data =
+ *(static_cast(node->user_data));
+
+ long long start_time = esp_timer_get_time();
+ // Checks in Prepare ensure input, output and filter types are all the same.
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ tflite::reference_ops::FullyConnected(
+ FullyConnectedParamsFloat(params->activation),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData