mirror of https://github.com/jomjol/AI-on-the-edge-device.git
Commit: Rolling 20220716_2
@@ -54,8 +54,10 @@ In other cases you can contact the developer via email: <img src="https://raw.gi

 ##### Rolling (2022-07-16)

-- Updated esp32cam
+- TFMicro/Lite: Update (espressif Version 20220716)
+- Updated esp32cam (v20220716)
 - Integrated new analog classificational CNN (from @haverland)
+- Bugfix: Postprocessing

 ##### Rolling (2022-07-01)

@@ -79,7 +81,7 @@ Rolling (2022-04-26)
 - Extended MQTT with absolute Change (in addition to rate)
 - Internal optimization, removal of modelfile from `config.ini` (is now read out of the cnn file directly)

-- TFMicro/Lite: Update (espressif Verision 20220417)
+- TFMicro/Lite: Update (espressif Version 20220417)
 - ESP-IDF: Update to 4.3.0

 Rolling (2022-04-17)
@@ -5,7 +5,9 @@ set(c_srcs
     "src/basic_math/esp_nn_add_ansi.c"
     "src/basic_math/esp_nn_mul_ansi.c"
     "src/convolution/esp_nn_conv_ansi.c"
+    "src/convolution/esp_nn_conv_opt.c"
     "src/convolution/esp_nn_depthwise_conv_ansi.c"
+    "src/convolution/esp_nn_depthwise_conv_opt.c"
     "src/fully_connected/esp_nn_fully_connected_ansi.c"
     "src/softmax/esp_nn_softmax_ansi.c"
     "src/softmax/esp_nn_softmax_opt.c"
@@ -23,7 +25,7 @@ if(CONFIG_IDF_TARGET_ESP32S3)
     "src/convolution/esp_nn_conv_esp32s3.c"
     "src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c"
     "src/convolution/esp_nn_conv_s16_mult8_esp32s3.S"
-    "src/convolution/esp_nn_conv_s16_mult8_1x1_esp32s3.S"
+    "src/convolution/esp_nn_conv_s8_mult8_1x1_esp32s3.S"
     "src/convolution/esp_nn_conv_s16_mult4_1x1_esp32s3.S"
     "src/convolution/esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3.S"
     "src/convolution/esp_nn_depthwise_conv_s16_mult1_esp32s3.S"
@@ -6,8 +6,8 @@ choice NN_OPTIMIZATIONS
     help
         Use ANSI-C versions for verification and debug purpose.
         Optimisations are automatically picked up for a chipset.
-        For ESP32-S3, assembly Optimisations are selected.
-        For ESP32, just the ANSI C versions are selected for now.
+        For ESP32-S3, assembly optimisations are selected.
+        For other platforms(viz., ESP32, ESP32-C3), generic optimisations are used.

 config NN_ANSI_C
     bool "ANSI C"
@@ -17,8 +17,8 @@ config NN_OPTIMIZED
     bool "Optimized versions"
     help
         Optimisations are automatically picked up for a chipset.
-        For ESP32-S3, assembly Optimisations are selected.
-        For ESP32, just the ANSI C versions are selected for now.
+        For ESP32-S3, assembly optimisations are selected.
+        For other platforms(viz., ESP32, ESP32-C3), generic optimisations are used.
 endchoice

 config NN_OPTIMIZATIONS
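Since the `NN_OPTIMIZATIONS` choice defaults to the optimized variant, a regular build needs no action to pick up the new generic optimisations; the pick simply surfaces in the generated sdkconfig. A minimal sketch (symbol names assume Kconfig's standard `CONFIG_` prefix):

    # generated sdkconfig (illustrative): default pick of the NN_OPTIMIZATIONS choice
    CONFIG_NN_OPTIMIZED=y
    # CONFIG_NN_ANSI_C is not set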
@@ -7,7 +7,8 @@ The library contains optimised NN (Neural Network) functions for various Espress

 * Supported ESP chipsets include:
     * ESP32-S3 (Assembly versions optimised to benefit from vector instructions of ESP32-S3)
-    * ESP32 (ANSI C versions)
+    * ESP32 (Generic optimisations)
+    * ESP32-C3 (Generic optimisations)

 ## Performance

@@ -39,8 +40,8 @@ The library contains optimised NN (Neural Network) functions for various Espress
     * Optimized versions
     * ANSI C

-* Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for ESP32, ANSI-C versions are selected by default.
-* For debugging purposes, you may want to select `ANSI C`
+* Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for other chipsets (viz., ESP32, ESP32-C3), generic optimisations are selected.
+* For debugging purposes, you may want to select `ANSI C` reference versions.


 ## Contributing
@@ -15,6 +15,7 @@
 #pragma once

 #if defined(CONFIG_NN_OPTIMIZED)
+// select apt optimisations
 #ifdef CONFIG_IDF_TARGET_ESP32S3
 #define ARCH_ESP32_S3 1
 #endif
@@ -31,12 +32,11 @@ extern "C" {
 #include "esp_nn_ansi_headers.h"

 #if defined(CONFIG_NN_OPTIMIZED)
-#ifdef ARCH_ESP32_S3
+#if defined(ARCH_ESP32_S3)
 #include "esp_nn_esp32s3.h"
-#endif
-#ifdef ARCH_ESP32
-#include "esp_nn_esp32.h"
-#endif
+#else // for other platforms use generic optimisations
+#include "esp_nn_generic_opt.h"
+#endif // #if defined(ARCH_ESP32_S3)
 #else
 #include "esp_nn_ansi_c.h"
 #endif
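The include logic above now has exactly three outcomes. A compact summary sketch (illustrative comments only; `CONFIG_NN_OPTIMIZED` comes from the Kconfig choice earlier in this commit, `ARCH_ESP32_S3` from the `CONFIG_IDF_TARGET_ESP32S3` block above):

    /* illustrative summary of the header selection after this change */
    #if defined(CONFIG_NN_OPTIMIZED)
        #if defined(ARCH_ESP32_S3)
            /* -> esp_nn_esp32s3.h: assembly-backed kernels */
        #else
            /* -> esp_nn_generic_opt.h: generic C optimisations (ESP32, ESP32-C3, ...) */
        #endif
    #else
        /* -> esp_nn_ansi_c.h: plain reference implementations, for debug */
    #endif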
@@ -19,6 +19,7 @@

 #pragma once

+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"

 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
@@ -18,8 +18,7 @@
  * @file Header definitions to include for esp_nn reference functions
  */

-#include <stdint.h>
+#include "esp_nn_defs.h"

 /************************** Basic math functions ****************************/

 /**
@@ -81,28 +80,15 @@ void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
  * optimization notes: Though input_offset is int32 type,
  * offset values are contained in 8 bits [-128, 127]
  */
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
-                                   const uint16_t input_wd,
-                                   const uint16_t input_ht,
-                                   const uint16_t channels,
-                                   const int32_t input_offset,
-                                   const uint16_t pad_wd,
-                                   const uint16_t pad_ht,
-                                   const uint16_t stride_wd,
-                                   const uint16_t stride_ht,
-                                   const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+                                   const int8_t *input_data,
+                                   const data_dims_t *filter_dims,
                                    const int8_t *filter_data,
-                                   const uint16_t filter_wd,
-                                   const uint16_t filter_ht,
                                    const int32_t *bias,
+                                   const data_dims_t *output_dims,
                                    int8_t *out_data,
-                                   const uint16_t out_wd,
-                                   const uint16_t out_ht,
-                                   const int32_t out_offset,
-                                   const int32_t *out_shift,
-                                   const int32_t *out_mult,
-                                   const int32_t activation_min,
-                                   const int32_t activation_max);
+                                   const dw_conv_params_t *conv_params,
+                                   const quant_data_t *quant_data);

 /**
  * @brief 2d-convolution channelwise
@@ -112,43 +98,26 @@ void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
  * inputs type: int8_t, output: int8_t
  * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  */
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
-                         const uint16_t input_wd,
-                         const uint16_t input_ht,
-                         const uint16_t in_channels,
-                         const int32_t input_offset,
-                         const uint16_t pad_wd,
-                         const uint16_t pad_ht,
-                         const uint16_t stride_wd,
-                         const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+                         const int8_t *input_data,
+                         const data_dims_t *filter_dims,
                          const int8_t *filter_data,
-                         const uint16_t filter_wd,
-                         const uint16_t filter_ht,
                          const int32_t *bias,
+                         const data_dims_t *output_dims,
                          int8_t *out_data,
-                         const uint16_t out_wd,
-                         const uint16_t out_ht,
-                         const uint16_t out_channels,
-                         const int32_t out_offset,
-                         const int32_t *out_shift,
-                         const int32_t *out_mult,
-                         const int32_t activation_min,
-                         const int32_t activation_max);
+                         const conv_params_t *conv_params,
+                         const quant_data_t *quant_data);

-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t in_ch,
-                                      const uint16_t out_ch,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht);
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                      const data_dims_t *filter_dims,
+                                      const data_dims_t *output_dims,
+                                      const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_ansi(const void *buf);

-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
-                                                const uint16_t input_ht,
-                                                const uint16_t channels,
-                                                const uint16_t ch_mult,
-                                                const uint16_t filter_wd,
-                                                const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                                const data_dims_t *filter_dims,
+                                                const data_dims_t *output_dims,
+                                                const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);

 /************************** Activation functions *****************************/
@@ -252,9 +221,6 @@ int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t h
  */
 void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);

-/* ANSI C function to be hooked up when optimised version needed */
-void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
-
 /**
  * @brief reference softmax function
  *
@@ -268,6 +234,66 @@ void esp_nn_softmax_s8_ansi(const int8_t *input_data,
                             const int32_t diff_min,
                             int8_t *output_data);

+
+//////////////////////////// Generic optimisations /////////////////////////////
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief 2d-convolution channelwise optimized version
+ *
+ * @note operation: result += (input + offset) * filter
+ *
+ * inputs type: int8_t, output: int8_t
+ * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data);
+
+/**
+ * @brief depthwise convolution per channel optimized version
+ *
+ * @note inputs type: int8_t, output: int8_t
+ * Version used in tflite is per channel.
+ * This version follows the same footsprints.
+ * Meaning, it has per out_channel shift and multiplier for
+ * requantization
+ *
+ * optimization notes: Though input_offset is int32 type,
+ * offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params);
+void esp_nn_set_conv_scratch_buf_opt(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params);
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
+
+/* ANSI C function to be hooked up when optimised version needed */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
+
 /**
  * @brief optimised version of softmax function
  *
code/components/esp-nn/include/esp_nn_defs.h (new file, 83 lines)
@@ -0,0 +1,83 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+/**
+ * @brief structure to club data dims
+ * this structure can be used for input, output and filter
+ */
+typedef struct data_dims {
+    int32_t width;
+    int32_t height;
+    int32_t channels;
+
+    int32_t extra; // can be used as batch or any other param
+} data_dims_t;
+
+/**
+ * @brief 2d data structure (width, height)
+ *
+ */
+typedef struct data_2d {
+    int32_t width;
+    int32_t height;
+} data_2d_t;
+
+/**
+ * @brief min/max activation
+ */
+typedef struct act_params {
+    int32_t min;
+    int32_t max;
+} act_params_t;
+
+/**
+ * @brief per channel quant data
+ *
+ * @note number of shift and mult elements are equal to output channels
+ */
+typedef struct quant_data {
+    int32_t *shift;
+    int32_t *mult;
+} quant_data_t;
+
+/**
+ * @brief params specific to convolution 2d
+ *
+ */
+typedef struct conv_params {
+    int32_t in_offset;
+    int32_t out_offset;
+    data_2d_t stride;
+    data_2d_t padding;
+    data_2d_t dilation;
+    act_params_t activation;
+} conv_params_t;
+
+/**
+ * @brief params specific to depthwise convolution 2d
+ *
+ */
+typedef struct dw_conv_params {
+    int32_t in_offset;
+    int32_t out_offset;
+    int32_t ch_mult; // channel multiplier. (in_ch * ch_mult = out_ch)
+    data_2d_t stride;
+    data_2d_t padding;
+    data_2d_t dilation;
+    act_params_t activation;
+} dw_conv_params_t;
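These structs replace the long scalar parameter lists throughout this commit. A minimal call-site sketch (not part of the commit; the `run_conv` wrapper, shapes and offsets are illustrative) showing how a caller would fill them and invoke the dispatching `esp_nn_conv_s8` macro:

    #include "esp_nn.h"   /* resolves esp_nn_conv_s8 per the target headers above */

    /* Illustrative shapes: 16x16x8 input, 3x3 filter, 16 output channels. */
    static void run_conv(const int8_t *in, const int8_t *filt, const int32_t *bias,
                         int8_t *out, int32_t *shift, int32_t *mult)
    {
        const data_dims_t input_dims  = { .width = 16, .height = 16, .channels = 8 };
        const data_dims_t filter_dims = { .width = 3,  .height = 3 };
        const data_dims_t output_dims = { .width = 16, .height = 16, .channels = 16 };
        const conv_params_t conv_params = {
            .in_offset  = 128,              /* illustrative quantization offsets */
            .out_offset = -128,
            .stride     = { .width = 1, .height = 1 },
            .padding    = { .width = 1, .height = 1 },
            .dilation   = { .width = 1, .height = 1 },
            .activation = { .min = -128, .max = 127 },
        };
        /* per-channel requantization arrays, one entry per output channel */
        const quant_data_t quant_data = { .shift = shift, .mult = mult };

        esp_nn_conv_s8(&input_dims, in, &filter_dims, filt, bias,
                       &output_dims, out, &conv_params, &quant_data);
    }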
@@ -19,7 +19,7 @@

 #pragma once

-#include <stdint.h>
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"

 /************************** Basic math functions *****************************/
@@ -85,28 +85,15 @@ void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data,
  * optimization notes: Though input_offset is int32 type,
  * offset values are contained in 8 bits [-128, 127]
  */
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
-                                      const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t channels,
-                                      const int32_t input_offset,
-                                      const uint16_t pad_wd,
-                                      const uint16_t pad_ht,
-                                      const uint16_t stride_wd,
-                                      const uint16_t stride_ht,
-                                      const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+                                      const int8_t *input_data,
+                                      const data_dims_t *filter_dims,
                                       const int8_t *filter_data,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht,
                                       const int32_t *bias,
-                                      int8_t *out_data,
-                                      const uint16_t out_wd,
-                                      const uint16_t out_ht,
-                                      const int32_t out_offset,
-                                      const int32_t *out_shift,
-                                      const int32_t *out_mult,
-                                      const int32_t activation_min,
-                                      const int32_t activation_max);
+                                      const data_dims_t *output_dims,
+                                      int8_t *output_data,
+                                      const dw_conv_params_t *conv_params,
+                                      const quant_data_t *quant_data);

 /**
  * @brief 2d - convolution channelwise
@@ -116,43 +103,26 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
  * inputs type: int8_t, output: int8_t
  * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  */
-void esp_nn_conv_s8_esp32s3(const int8_t *input_data,
-                            const uint16_t input_wd,
-                            const uint16_t input_ht,
-                            const uint16_t in_channels,
-                            const int32_t input_offset,
-                            const uint16_t pad_wd,
-                            const uint16_t pad_ht,
-                            const uint16_t stride_wd,
-                            const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+                            const int8_t *input_data,
+                            const data_dims_t *filter_dims,
                             const int8_t *filter_data,
-                            const uint16_t filter_wd,
-                            const uint16_t filter_ht,
                             const int32_t *bias,
-                            int8_t *out_data,
-                            const uint16_t out_wd,
-                            const uint16_t out_ht,
-                            const uint16_t out_channels,
-                            const int32_t out_offset,
-                            const int32_t *out_shift,
-                            const int32_t *out_mult,
-                            const int32_t activation_min,
-                            const int32_t activation_max);
+                            const data_dims_t *output_dims,
+                            int8_t *output_data,
+                            const conv_params_t *conv_params,
+                            const quant_data_t *quant_data);

-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht);
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);

-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);

 /************************** Pooling functions *****************************/
@@ -13,28 +13,27 @@
 // limitations under the License.

 /**
- * @file Header definitions to include for esp_nn optimized functions for
- * the ESP32 platform.
- * We are hooking up just the C versions for now.
- * The file hence is exactly same as `esp_nn_ansi_c.h`
+ * @file Header definitions to include for esp_nn generic optimisations
+ * For functions which not having optimisations, _ansi versions are picked.
 */

 #pragma once

+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"

 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
 #define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi

-#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_opt

-#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+#define esp_nn_conv_s8 esp_nn_conv_s8_opt

-#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
-#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_opt
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_opt

-#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
-#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_opt
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_opt

 #define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi

@@ -41,15 +41,39 @@

 __NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("nsau %0, %0" : "+r" (in));
     return in;
-}
-
-__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
-{
-    int32_t sign = (int32_t) (val64 >> 63);
-    int32_t to_add = sign & ((1ul << 31) - 1);
-    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
+#elif defined(__GNUC__)
+    return __builtin_clz(in);
+#else
+    int32_t count = 32;
+    uint32_t x = in, y = in >> 16;
+    if (y != 0) {
+        count -= 16;
+        x = y;
+    }
+    y = x >> 8;
+    if (y != 0) {
+        count -= 8;
+        x = y;
+    }
+    y = x >> 4;
+    if (y != 0) {
+        count -= 4;
+        x = y;
+    }
+    y = x >> 2;
+    if (y != 0) {
+        count -= 2;
+        x = y;
+    }
+    y = x >> 1;
+    if (y != 0) {
+        return count - 2;
+    }
+    return count - x;
+#endif
 }

 /**
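The new fallback counts leading zeros by binary halving. A quick host-side check of that branch against the GCC builtin (a standalone sketch, not part of the commit):

    #include <assert.h>
    #include <stdint.h>

    /* Portable leading-zero count, same halving scheme as the #else branch above. */
    static int32_t clz32_portable(uint32_t in)
    {
        int32_t count = 32;
        uint32_t x = in, y = in >> 16;
        if (y != 0) { count -= 16; x = y; }
        y = x >> 8;
        if (y != 0) { count -= 8; x = y; }
        y = x >> 4;
        if (y != 0) { count -= 4; x = y; }
        y = x >> 2;
        if (y != 0) { count -= 2; x = y; }
        y = x >> 1;
        if (y != 0) { return count - 2; }
        return count - x;
    }

    int main(void)
    {
        /* Spot-check against the builtin (undefined for 0, so skip it). */
        for (uint32_t v = 1; v < (1u << 20); v += 7919) {
            assert(clz32_portable(v) == __builtin_clz(v));
        }
        assert(clz32_portable(1) == 31 && clz32_portable(0x80000000u) == 0);
        return 0;
    }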
@@ -57,8 +81,19 @@ __NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
  */
 __NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("clamps %0, %0, 7" : "+a"(in));
     return in;
+#else
+    return max(INT8_MIN, min(in, INT8_MAX));
+#endif
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
+{
+    int32_t sign = (int32_t) (val64 >> 63);
+    int32_t to_add = sign & ((1ul << 31) - 1);
+    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
 }

 __NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
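For reference, the relocated `esp_nn_pick_sat_high32_of64` keeps its rounding trick: adding 2^31 - 1 before the shift makes the arithmetic `>> 31` round negative values toward zero instead of toward minus infinity. A small host-side check (illustrative, not part of the commit):

    #include <assert.h>
    #include <stdint.h>

    static int32_t pick_sat_high32_of64(int64_t val64)
    {
        int32_t sign = (int32_t) (val64 >> 63);        /* 0 or -1 */
        int32_t to_add = sign & ((1ul << 31) - 1);     /* 0 or 2^31 - 1 */
        return (int32_t) ((int64_t) (val64 + to_add) >> 31);
    }

    int main(void)
    {
        assert(pick_sat_high32_of64(-1) == 0);               /* plain >> 31 would give -1 */
        assert(pick_sat_high32_of64(5 * (1LL << 31)) == 5);  /* exact multiples unchanged */
        assert(pick_sat_high32_of64(-5 * (1LL << 31)) == -5);
        assert(pick_sat_high32_of64(-(1LL << 31) - 1) == -1); /* rounds toward zero, not -2 */
        return 0;
    }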
@@ -144,7 +179,7 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
                                              const uint16_t pad_ht)
 {
     /* memset with pad_val */
-    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels * 2);
+    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
     dst += (pad_wd + input_wd + pad_wd) * channels;

     for (int i = 0; i < input_ht; i++) {
@@ -156,7 +191,6 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
     }
 }

-#if 0
 static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
                                                  const uint16_t input_wd,
                                                  const uint16_t input_ht,
@@ -169,13 +203,16 @@ static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
         for (int j = 0; j < input_wd * channels; j++) {
             *dst++ = *src++;
         }
-        memset(dst, pad_val, pad_wd * channels);
-        dst += pad_wd * channels;
+        if (pad_wd) {
+            memset(dst, pad_val, pad_wd * channels);
+            dst += pad_wd * channels;
+        }
     }
     /* pad end `pad_ht` lines at end */
-    memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+    if (pad_ht) {
+        memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+    }
 }
-#endif

 /**
  * @brief convert 8 bit input data to 16 bit
@@ -12,16 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <stdint.h>
+#include <esp_nn_defs.h>

 #include <common_functions.h>

-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t in_ch,
-                                      const uint16_t out_ch,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                      const data_dims_t *filter_dims,
+                                      const data_dims_t *output_dims,
+                                      const conv_params_t *conv_params)
 {
     return 0;
 }
@@ -108,29 +106,35 @@ void esp_nn_conv_u8_ansi(const uint8_t *input_data,
  * Assumption 2: Pointers are valid
  * Assumption 3: dialation width = 1
  */
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
-                         const uint16_t input_wd,
-                         const uint16_t input_ht,
-                         const uint16_t in_channels,
-                         const int32_t input_offset,
-                         const uint16_t pad_wd,
-                         const uint16_t pad_ht,
-                         const uint16_t stride_wd,
-                         const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+                         const int8_t *input_data,
+                         const data_dims_t *filter_dims,
                          const int8_t *filter_data,
-                         const uint16_t filter_wd,
-                         const uint16_t filter_ht,
                          const int32_t *bias,
+                         const data_dims_t *output_dims,
                          int8_t *out_data,
-                         const uint16_t out_wd,
-                         const uint16_t out_ht,
-                         const uint16_t out_channels,
-                         const int32_t out_offset,
-                         const int32_t *out_shift,
-                         const int32_t *out_mult,
-                         const int32_t activation_min,
-                         const int32_t activation_max)
+                         const conv_params_t *conv_params,
+                         const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;

     for (out_y = 0; out_y < out_ht; out_y++) {
@@ -12,30 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <stdint.h>
 #include <stdio.h>
+#include <esp_nn_defs.h>

 #include <common_functions.h>

 static int16_t *scratch_buffer = NULL;

-extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data,
+extern void esp_nn_conv_s8_mult8_1x1_esp32s3(const int8_t *input_data,
                                               const uint16_t input_wd,
                                               const uint16_t input_ht,
                                               const uint16_t in_channels,
                                               const int32_t input_offset,
-                                              const int16_t *filter_data,
+                                              const int8_t *filter_aligned,
                                               const int32_t *bias,
                                               int8_t *out_data,
                                               const uint16_t out_wd,
                                               const uint16_t out_ht,
                                               const uint16_t out_channels,
                                               const int32_t out_offset,
                                               const int32_t *out_shift,
                                               const int32_t *out_mult,
                                               const int32_t activation_min,
                                               const int32_t activation_max,
                                               void *buffer /* scratch buffer */);

 extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
                                               const uint16_t input_wd,
@@ -81,34 +81,40 @@ extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int1

 extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);

-static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
-                                    const uint16_t input_wd,
-                                    const uint16_t input_ht,
-                                    const uint16_t in_channels,
-                                    const int32_t input_offset,
-                                    const uint16_t pad_wd,
-                                    const uint16_t pad_ht,
-                                    const uint16_t stride_wd,
-                                    const uint16_t stride_ht,
+static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims,
+                                    const int8_t *input_data,
+                                    const data_dims_t *filter_dims,
                                     const int8_t *filter_data,
-                                    const uint16_t filter_wd,
-                                    const uint16_t filter_ht,
                                     const int32_t *bias,
+                                    const data_dims_t *output_dims,
                                     int8_t *out_data,
-                                    const uint16_t out_wd,
-                                    const uint16_t out_ht,
-                                    const uint16_t out_channels,
-                                    const int32_t out_offset,
-                                    const int32_t *out_shift,
-                                    const int32_t *out_mult,
-                                    const int32_t activation_min,
-                                    const int32_t activation_max)
+                                    const conv_params_t *conv_params,
+                                    const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_ch = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_ch = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;

     for (out_y = 0; out_y < out_ht; out_y++) {
         for (out_x = 0; out_x < out_wd; out_x++) {
-            for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+            for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) {
                 int32_t conv_out = 0;

                 const int32_t base_y = stride_ht * out_y - pad_ht;
@@ -124,10 +130,10 @@ static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
                 for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                     const int32_t in_row = base_y + filter_y_idx;
                     const int32_t in_col = base_x + filter_x_idx;
-                    int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
-                    int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
-                                                 (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
-                    for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+                    int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch;
+                    int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd +
+                                                 (filter_y_idx * filter_wd + filter_x_idx) * in_ch;
+                    for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) {
                         conv_out +=
                             (input_data[input_base_offset + in_ch_idx] + input_offset) *
                             filter_data[filter_base_offset + in_ch_idx];
@@ -332,18 +338,35 @@ static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
     }
 }

-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_ch = input_dims->channels;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_ch = output_dims->channels;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+
     int filter_size = filter_wd * filter_ht * in_ch * out_ch;
     int input_size = input_wd * input_ht * in_ch;
-    int transpose_buf_size = 8 * in_ch; /* to store intermediate data */
+
+    int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */
+    if (input_wd * input_ht < 8) {
+        transpose_buf_size = 0; // not using this for leftover
+    }
     int align_buf_size = 32; /* extra buffer for alignment */
-    return 2 * (filter_size + input_size + transpose_buf_size) + align_buf_size;
+    if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
+        pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
+        return filter_size + transpose_buf_size + align_buf_size;
+    }
+    return 2 * (filter_size + input_size) + transpose_buf_size + align_buf_size;
 }

 void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
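The reworked scratch computation gives two regimes; a quick numeric sketch (shapes illustrative). For the 1x1 fast path with in_ch = 8, out_ch = 16 and a 16x16 input: filter_size = 1 * 1 * 8 * 16 = 128, transpose_buf_size = 2 * (8 * 8) = 128, so the function returns 128 + 128 + 32 = 288 bytes. The same shapes with a 3x3 filter fall into the general branch: filter_size = 3 * 3 * 8 * 16 = 1152, input_size = 16 * 16 * 8 = 2048, returning 2 * (1152 + 2048) + 128 + 32 = 6560 bytes. The fast path no longer doubles the input and filter sizes because it copies the s8 filter as-is instead of widening everything to s16.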
@@ -351,29 +374,35 @@ void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
     scratch_buffer = (int16_t *) buf;
 }

-void esp_nn_conv_s8_esp32s3(const int8_t *input,
-                            const uint16_t input_wd,
-                            const uint16_t input_ht,
-                            const uint16_t channels,
-                            const int32_t input_offset,
-                            const uint16_t pad_wd,
-                            const uint16_t pad_ht,
-                            const uint16_t stride_wd,
-                            const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+                            const int8_t *input,
+                            const data_dims_t *filter_dims,
                             const int8_t *filter_data,
-                            const uint16_t filter_wd,
-                            const uint16_t filter_ht,
                             const int32_t *bias,
+                            const data_dims_t *output_dims,
                             int8_t *out_data,
-                            const uint16_t out_wd,
-                            const uint16_t out_ht,
-                            const uint16_t out_channels,
-                            const int32_t out_offset,
-                            const int32_t *out_shift,
-                            const int32_t *out_mult,
-                            const int32_t activation_min,
-                            const int32_t activation_max)
+                            const conv_params_t *conv_params,
+                            const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int filter_size = filter_wd * filter_ht * channels * out_channels;
     int input_size = input_wd * input_ht * channels;
     int align_len = 16 - (filter_size & 15);
@@ -387,15 +416,16 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,

     if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
         pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
-        int scratch_offset = (int) (filter_data16 + filter_size);
+        int8_t *filter_aligned = (int8_t *) scratch_buffer;
+        int scratch_offset = (int) (filter_aligned + filter_size);
         void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
-        esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
-        esp_nn_conv_s16_mult8_1x1_esp32s3(
-            input, input_wd, input_ht, channels, input_offset, filter_data16,
+        memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address
+        esp_nn_conv_s8_mult8_1x1_esp32s3(
+            input, input_wd, input_ht, channels, input_offset, filter_aligned,
             bias, out_data, out_wd, out_ht, out_channels, out_offset,
             out_shift, out_mult, activation_min, activation_max, scratch_buf);
     } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
-               (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */
+               (input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */
               pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
         int scratch_offset = (int) (input_data16 + input_size);
         void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
@@ -427,10 +457,7 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
         }
     } else {
         /* Basic unrolled version */
-        esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset,
-                                pad_wd, pad_ht, stride_wd, stride_ht,
-                                filter_data, filter_wd, filter_ht, bias,
-                                out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
-                                out_mult, activation_min, activation_max);
+        esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data,
+                                bias, output_dims, out_data, conv_params, quant_data);
     }
 }
code/components/esp-nn/src/convolution/esp_nn_conv_opt.c (new file, 179 lines)
@@ -0,0 +1,179 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+
+#include <common_functions.h>
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params)
+{
+    return 0;
+}
+
+void esp_nn_set_conv_scratch_buf_opt(const void *buf)
+{
+
+}
+
+__attribute__ ((noinline))
+static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims,
+                               const int8_t *input_data,
+                               const int8_t *filter_data,
+                               const int32_t *bias,
+                               const data_dims_t *output_dims,
+                               int8_t *out_data,
+                               const conv_params_t *conv_params,
+                               const quant_data_t *quant_data)
+{
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) {
+        for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) {
+            const int32_t *out_mult = quant_data->mult;
+            const int32_t *out_shift = quant_data->shift;
+            const int8_t *filter_ptr = filter_data;
+            const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels;
+            int32_t out_ch_idx = 0;
+            for (; out_ch_idx < out_channels; out_ch_idx++) {
+                int32_t conv_out = 0;
+
+                const int8_t *input_ptr = input_base_ptr;
+
+                int32_t in_ch_idx = 0;
+                for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                }
+                for (; in_ch_idx < in_channels; in_ch_idx ++) {
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                }
+                if (bias) {
+                    conv_out += bias[out_ch_idx];
+                }
+                conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+                conv_out += out_offset;
+                conv_out = max(conv_out, activation_min);
+                conv_out = min(conv_out, activation_max);
+                *out_data++ = (int8_t) conv_out;
+            }
+        }
+    }
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dialation width = 1
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data)
+{
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+
+    if (filter_wd == 1 && filter_ht == 1) {
+        esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias,
+                           output_dims, out_data, conv_params, quant_data);
+        return;
+    }
+
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx;
+
+    for (out_y = 0; out_y < out_ht; out_y++) {
+        for (out_x = 0; out_x < out_wd; out_x++) {
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+            for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+                int32_t conv_out = 0;
+
+                const int32_t base_y = stride_ht * out_y - pad_ht;
+                const int32_t base_x = stride_wd * out_x - pad_wd;
+
+                const int32_t filter_y_start = max(0, -base_y);
+                const int32_t filter_x_start = max(0, -base_x);
+
+                const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+                const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+                for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t in_row = base_y + filter_y_idx;
+                        const int32_t in_col = base_x + filter_x_idx;
+
+                        const int8_t *input_ptr = input_data +
+                            (in_row * input_wd + in_col) * in_channels;
+                        const int8_t *filter_ptr = filter_data +
+                            out_ch_idx * in_channels * filter_ht * filter_wd +
+                            (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+                        int32_t in_ch_idx = 0;
+                        for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                        }
+                        for (; in_ch_idx < in_channels; in_ch_idx ++) {
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                        }
+                    }
+                }
+                if (bias) {
+                    conv_out += bias[out_ch_idx];
+                }
+                conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+                conv_out += out_offset;
+                conv_out = max(conv_out, activation_min);
+                conv_out = min(conv_out, activation_max);
+                *out_data++ = (int8_t) conv_out;
+            }
+        }
+    }
+}
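One detail worth noting in the new `esp_nn_conv_s8_opt` / `esp_nn_conv_s8_1x1` pair: a 1x1 filter is dispatched to the pointer-walking fast path, and in both kernels the channel loop is unrolled by four with a scalar tail, so arbitrary channel counts stay correct. For example, with in_channels = 6 the unrolled loop runs once (the guard `in_ch_idx < in_channels - 3` admits only `in_ch_idx = 0`, covering channels 0-3) and the tail loop picks up channels 4 and 5. On ESP32 and ESP32-C3 builds, `esp_nn_conv_s8` now resolves to this function via the `esp_nn_generic_opt.h` mapping above.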
@@ -12,16 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <stdint.h>
+#include <esp_nn_defs.h>

 #include <common_functions.h>

-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
-                                                const uint16_t input_ht,
-                                                const uint16_t channels,
-                                                const uint16_t ch_mult,
-                                                const uint16_t filter_wd,
-                                                const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                                const data_dims_t *filter_dims,
+                                                const data_dims_t *output_dims,
+                                                const dw_conv_params_t *conv_params)
 {
     return 0;
 }
@@ -31,29 +28,35 @@ void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf)
 
 }
 
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
-                                   const uint16_t input_wd,
-                                   const uint16_t input_ht,
-                                   const uint16_t channels,
-                                   const int32_t input_offset,
-                                   const uint16_t pad_wd,
-                                   const uint16_t pad_ht,
-                                   const uint16_t stride_wd,
-                                   const uint16_t stride_ht,
-                                   const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+                                   const int8_t *input_data,
+                                   const data_dims_t *filter_dims,
                                    const int8_t *filter_data,
-                                   const uint16_t filter_wd,
-                                   const uint16_t filter_ht,
                                    const int32_t *bias,
+                                   const data_dims_t *output_dims,
                                    int8_t *out_data,
-                                   const uint16_t out_wd,
-                                   const uint16_t out_ht,
-                                   const int32_t out_offset,
-                                   const int32_t *out_shift,
-                                   const int32_t *out_mult,
-                                   const int32_t activation_min,
-                                   const int32_t activation_max)
+                                   const dw_conv_params_t *conv_params,
+                                   const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+    const uint16_t ch_mult = conv_params->ch_mult;
+
     int out_idx = 0;
     for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
         const int16_t base_y = (out_y * stride_ht) - pad_ht;
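The refactor above collapses the long scalar argument list into four small structs. Following the updated test code later in this commit, a call now looks like this (values illustrative):

```c
/* Sketch of the new struct-based calling convention, mirroring the
 * updated tests in this commit; variable values are placeholders. */
data_dims_t input_dims  = {.width = input_wd,  .height = input_ht,  .channels = channels, 1};
data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = channels * ch_mult, 1};
dw_conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset,
                                .ch_mult = ch_mult,
                                .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
                                .dilation = {0, 0}, .activation = {activation_min, activation_max}};
quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};

esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter,
                              bias, &output_dims, output, &conv_params, &quant_data);
```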
@@ -0,0 +1,291 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+#include <common_functions.h>
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params)
+{
+    return 0;
+}
+
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf)
+{
+
+}
+
+/* common channel multiplier == 1 case */
+__attribute__ ((noinline))
+static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims,
+                                               const int8_t *input_data,
+                                               const data_dims_t *filter_dims,
+                                               const int8_t *filter_data,
+                                               const int32_t *bias,
+                                               const data_dims_t *output_dims,
+                                               int8_t *out_data,
+                                               const dw_conv_params_t *conv_params,
+                                               const quant_data_t *quant_data)
+{
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int out_idx = 0;
+    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+        const int16_t base_y = (out_y * stride_ht) - pad_ht;
+        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+            const int16_t base_x = (out_x * stride_wd) - pad_wd;
+
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+
+            /* Select filter so as the point doesn't lie outside block */
+            int filter_y_start = max(0, -base_y);
+            int filter_x_start = max(0, -base_x);
+            int filter_y_end = min(filter_ht, input_ht - base_y);
+            int filter_x_end = min(filter_wd, input_wd - base_x);
+
+            int ch_idx = 0;
+            for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop
+                int32_t result0 = 0;
+                int32_t result1 = 0;
+                int32_t result2 = 0;
+                int32_t result3 = 0;
+
+                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    const int32_t idx_y = base_y + filter_y_idx;
+                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t idx_x = base_x + filter_x_idx;
+                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
+                        int32_t input_val0 = input_data[input_index + 0] + input_offset;
+                        int32_t input_val1 = input_data[input_index + 1] + input_offset;
+                        int32_t input_val2 = input_data[input_index + 2] + input_offset;
+                        int32_t input_val3 = input_data[input_index + 3] + input_offset;
+                        int32_t filter_val0 = filter_data[filter_index + 0];
+                        int32_t filter_val1 = filter_data[filter_index + 1];
+                        int32_t filter_val2 = filter_data[filter_index + 2];
+                        int32_t filter_val3 = filter_data[filter_index + 3];
+                        result0 += input_val0 * filter_val0;
+                        result1 += input_val1 * filter_val1;
+                        result2 += input_val2 * filter_val2;
+                        result3 += input_val3 * filter_val3;
+                    }
+                }
+                if (bias) {
+                    result0 += bias[ch_idx + 0];
+                    result1 += bias[ch_idx + 1];
+                    result2 += bias[ch_idx + 2];
+                    result3 += bias[ch_idx + 3];
+                }
+                result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
+                result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
+                result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
+                result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
+
+                result0 += out_offset;
+                result1 += out_offset;
+                result2 += out_offset;
+                result3 += out_offset;
+
+                result0 = max(result0, activation_min);
+                result1 = max(result1, activation_min);
+                result2 = max(result2, activation_min);
+                result3 = max(result3, activation_min);
+
+                result0 = min(result0, activation_max);
+                result1 = min(result1, activation_max);
+                result2 = min(result2, activation_max);
+                result3 = min(result3, activation_max);
+
+                out_data[out_idx++] = result0;
+                out_data[out_idx++] = result1;
+                out_data[out_idx++] = result2;
+                out_data[out_idx++] = result3;
+            }
+            for (; ch_idx < channels; ch_idx++) {//channel_loop
+                int32_t result = 0;
+
+                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    const int32_t idx_y = base_y + filter_y_idx;
+                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t idx_x = base_x + filter_x_idx;
+                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
+                        int32_t input_val = input_data[input_index] + input_offset;
+                        int32_t filter_val = filter_data[filter_index];
+                        result += input_val * filter_val;
+                    }
+                }
+                if (bias) {
+                    result += bias[ch_idx];
+                }
+                result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
+                result += out_offset;
+                result = max(result, activation_min);
+                result = min(result, activation_max);
+
+                out_data[out_idx++] = result;
+            }
+        }
+    }
+}
+
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data)
+{
+    const uint16_t ch_mult = conv_params->ch_mult;
+    if (ch_mult == 1) {
+        esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data,
+                                           bias, output_dims, out_data, conv_params, quant_data);
+        return;
+    }
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int out_idx = 0;
+    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+        const int16_t base_y = (out_y * stride_ht) - pad_ht;
+        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+            const int16_t base_x = (out_x * stride_wd) - pad_wd;
+
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+
+            /* Select filter so as the point doesn't lie outside block */
+            int filter_y_start = max(0, -base_y);
+            int filter_x_start = max(0, -base_x);
+            int filter_y_end = min(filter_ht, input_ht - base_y);
+            int filter_x_end = min(filter_wd, input_wd - base_x);
+
+            for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
+                int ch_mult_idx = 0;
+                for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
+                    int32_t result0 = 0;
+                    int32_t result1 = 0;
+                    int32_t result2 = 0;
+                    int32_t result3 = 0;
+                    const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
+
+                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                        const int32_t idx_y = base_y + filter_y_idx;
+                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                            const int32_t idx_x = base_x + filter_x_idx;
+                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+                            int32_t input_val = input_data[input_index] + input_offset;
+                            int32_t filter_val0 = filter_data[filter_index + 0];
+                            int32_t filter_val1 = filter_data[filter_index + 1];
+                            int32_t filter_val2 = filter_data[filter_index + 2];
+                            int32_t filter_val3 = filter_data[filter_index + 3];
+                            result0 += input_val * filter_val0;
+                            result1 += input_val * filter_val1;
+                            result2 += input_val * filter_val2;
+                            result3 += input_val * filter_val3;
+                        }
+                    }
+                    if (bias) {
+                        result0 += bias[out_ch_idx + 0];
+                        result1 += bias[out_ch_idx + 1];
+                        result2 += bias[out_ch_idx + 2];
+                        result3 += bias[out_ch_idx + 3];
+                    }
+                    result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
+                    result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
+                    result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
+                    result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
+
+                    result0 += out_offset;
+                    result1 += out_offset;
+                    result2 += out_offset;
+                    result3 += out_offset;
+
+                    result0 = max(result0, activation_min);
+                    result1 = max(result1, activation_min);
+                    result2 = max(result2, activation_min);
+                    result3 = max(result3, activation_min);
+                    result0 = min(result0, activation_max);
+                    result1 = min(result1, activation_max);
+                    result2 = min(result2, activation_max);
+                    result3 = min(result3, activation_max);
+
+                    out_data[out_idx++] = result0;
+                    out_data[out_idx++] = result1;
+                    out_data[out_idx++] = result2;
+                    out_data[out_idx++] = result3;
+                }
+                for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
+                    int32_t result = 0;
+                    const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
+
+                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                        const int32_t idx_y = base_y + filter_y_idx;
+                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                            const int32_t idx_x = base_x + filter_x_idx;
+                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+                            int32_t input_val = input_data[input_index] + input_offset;
+                            int32_t filter_val = filter_data[filter_index];
+                            result += input_val * filter_val;
+                        }
+                    }
+                    if (bias) {
+                        result += bias[out_ch_idx];
+                    }
+                    result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
+                    result += out_offset;
+                    result = max(result, activation_min);
+                    result = min(result, activation_max);
+
+                    out_data[out_idx++] = result;
+                }
+            }
+        }
+    }
+}
||||||
@@ -12,8 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
#include <esp_nn_defs.h>
|
||||||
|
|
||||||
#include <common_functions.h>
|
#include <common_functions.h>
|
||||||
|
|
||||||
@@ -353,17 +353,59 @@ void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data,
     }
 }
 
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t ch_mult = conv_params->ch_mult;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+
     int filter_size = filter_wd * filter_ht * channels * ch_mult;
-    int padding_used = ((filter_wd == 3) && (filter_ht == 3)) * 2;
-    int input_size = (input_wd + padding_used) * (input_ht + padding_used) * channels;
-    return 2 * (filter_size + input_size) + 16; //16 for alignment
+    int pad_width = 0, pad_height = 0;
+
+    if ((ch_mult == 1) && (channels % 8 == 0) && (filter_wd == 3) && (filter_ht == 3)) {
+        if (channels % 16 == 0) {
+            if (pad_wd || pad_ht) {
+                pad_width = pad_wd * 2;
+                pad_height = pad_ht * 2;
+            } else {
+                // check if we need to pad additionally
+                pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+                pad_height = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+                // printf("in(%d %d %d), out(%d %d), filter (%d %d) stride (%d %d), pad (%d %d)",
+                //        input_wd, input_ht, channels, out_wd, out_ht, filter_wd, filter_ht,
+                //        stride_wd, stride_ht, pad_wd, pad_ht);
+            }
+            if (pad_width || pad_height) {
+                int input_size = (input_wd + pad_width) * (input_ht + pad_height) * channels;
+                // printf("ask1 %d\n", filter_size + input_size + 16);
+                return filter_size + input_size + 16; // 16 for alignment
+            } else {
+                // printf("ask2 %d\n", filter_size + 16);
+                return filter_size + 16; // 16 for alignment
+            }
+        } else {
+            int input_size = input_wd * input_ht * channels;
+            // printf("ask3 %d\n", 2 * (filter_size + input_size) + 16);
+            return 2 * (filter_size + input_size) + 16; // 16 for alignment
+        }
+    } else if (ch_mult % 4 == 0) {
+        int input_size = input_wd * input_ht * channels;
+        // printf("ask4 %d\n", 2 * (filter_size + input_size) + 16);
+        return 2 * (filter_size + input_size) + 16; // 16 for alignment
+    }
+    return 32; // just few bytes
 }
 
 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
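The new scratch sizing mirrors the runtime path further down: when the caller requests more output than a "valid" convolution yields, the kernel pads the input on the right and bottom with the quantized zero point (`-input_offset`), and the scratch buffer must hold that padded copy. A worked example with the numbers from test case 9 added later in this diff:

```c
/* input 6x6, filter 3x3, stride 2, requested output 3x3 (one more
 * column/row than a "valid" convolution yields) */
int out_wd = 3, stride_wd = 2, filter_wd = 3, input_wd = 6;
int pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd;  /* = 2 */
```

The last 3-wide window starts at column `(out_wd - 1) * stride_wd = 4` and only needs the input to be 7 columns wide, so this bound appears to overshoot the strict minimum by up to `stride_wd - 1` columns, which is harmless when sizing a scratch buffer.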
@@ -376,29 +418,38 @@ void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
  * Assumption 2: Pointers are valid
  * Assumption 3: dilation width = 1
  */
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
-                                      const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t channels,
-                                      const int32_t input_offset,
-                                      const uint16_t pad_wd,
-                                      const uint16_t pad_ht,
-                                      const uint16_t stride_wd,
-                                      const uint16_t stride_ht,
-                                      const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+                                      const int8_t *input_data,
+                                      const data_dims_t *filter_dims,
                                       const int8_t *filter_data,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht,
                                       const int32_t *bias,
+                                      const data_dims_t *output_dims,
                                       int8_t *out_data,
-                                      const uint16_t out_wd,
-                                      const uint16_t out_ht,
-                                      const int32_t out_offset,
-                                      const int32_t *out_shift,
-                                      const int32_t *out_mult,
-                                      const int32_t activation_min,
-                                      const int32_t activation_max)
+                                      const dw_conv_params_t *conv_params,
+                                      const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+    const uint16_t ch_mult = conv_params->ch_mult;
+
     int filter_size = filter_wd * filter_ht * channels * ch_mult;
     int align_len = 16 - (filter_size & 15);
     int input_size = input_wd * input_ht * channels;
@@ -423,18 +474,27 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
                                                        stride_wd, stride_ht, filter_aligned, bias,
                                                        out_data, out_wd, out_ht, out_offset, out_shift,
                                                        out_mult, activation_min, activation_max);
-    } else if ((pad_wd == 0) && (pad_ht == 0) &&
-               // because this does not handle padding offset cases yet, run just for stride (1, 1).
-               // end padding of input with `-input_offset` should solve this
-               (stride_wd == 1) && (stride_ht == 1)) {
+    } else if ((channels % 16 == 0) && (pad_wd == 0) && (pad_ht == 0)) {
         /* process in 8 bits */
         int8_t *filter_aligned = (int8_t *) scratch_buffer;
+        int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
+
+        // check if we need to pad additionally
+        int pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+        int pad_bottom = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+        if (pad_right || pad_bottom) { // pad right and bottom
+            esp_nn_aligned_s8_pad_end_with_value(input_data, input_padded, input_wd, input_ht,
+                                                 channels, -input_offset, pad_right, pad_bottom);
+        } else {
+            input_padded = (int8_t *) input_data;
+        }
         memcpy(filter_aligned, filter_data, filter_size);
-        esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_data, input_wd, input_ht, channels, input_offset,
-                                                          stride_wd, stride_ht, filter_aligned,
-                                                          bias, out_data, out_wd, out_ht, out_offset, out_shift,
-                                                          out_mult, activation_min, activation_max);
+        esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + pad_right,
+                                                          input_ht + pad_bottom, channels, input_offset,
+                                                          stride_wd, stride_ht, filter_aligned, bias,
+                                                          out_data, out_wd, out_ht, out_offset, out_shift,
+                                                          out_mult, activation_min, activation_max);
-    } else { /* (channels % 8) == 0 && pad_wd == 1 && pad_ht == 1 */
+    } else { /* (channels % 8) == 0 */
         esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
         esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
         esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels,
@@ -0,0 +1,8 @@
+# Default configurations for ESP32-S3
+
+CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y
+CONFIG_ESP32S3_SPIRAM_SUPPORT=y
+
+CONFIG_ESP32S3_DATA_CACHE_64KB=y
+CONFIG_ESP32S3_DATA_CACHE_8WAYS=y
+CONFIG_ESP32S3_DATA_CACHE_LINE_64B=y
@@ -23,7 +23,9 @@
 #include "test_utils.h"
 
 #if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
 #define IDF_HEAP_CAPS 1
+#endif
+
 #if IDF_HEAP_CAPS
 #include "esp_heap_caps.h"
@@ -138,6 +140,11 @@ void esp_nn_add_elementwise_s8_test()
     out_c_orig = out_data_c;
     out_opt_orig = out_data_opt;
 #endif
+    if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+            out_opt_orig == NULL) {
+        printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+        goto elementwise_add_test_cleanup;
+    }
+
     for (int i = 0; i < size; ++i) {
         input1[i] = rand() % 256 - 128;
@@ -194,10 +201,10 @@ elementwise_add_test_cleanup:
     if (input2_orig) {
         free(input2_orig);
     }
-    if (out_data_c) {
+    if (out_c_orig) {
        free(out_c_orig);
     }
-    if (out_data_opt) {
+    if (out_opt_orig) {
        free(out_opt_orig);
     }
 }
@@ -282,6 +289,11 @@ void esp_nn_mul_elementwise_s8_test()
     out_c_orig = out_data_c;
     out_opt_orig = out_data_opt;
 #endif
+    if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+            out_opt_orig == NULL) {
+        printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+        goto elementwise_mult_test_cleanup;
+    }
+
     for (int i = 0; i < size; ++i) {
         input1[i] = rand() % 256 - 128;
@@ -333,10 +345,10 @@ elementwise_mult_test_cleanup:
     if (input2_orig) {
         free(input2_orig);
     }
-    if (out_data_c) {
+    if (out_c_orig) {
        free(out_c_orig);
     }
-    if (out_data_opt) {
+    if (out_opt_orig) {
        free(out_opt_orig);
     }
 }
@@ -22,8 +22,9 @@
 #include "test_utils.h"
 
 #if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
 #define IDF_HEAP_CAPS 1
+#endif
 #if IDF_HEAP_CAPS
 #include "esp_heap_caps.h"
 #endif
@@ -44,8 +45,8 @@ void esp_nn_depthwise_conv_s8_test()
     uint16_t filter_ht, filter_wd, ch_mult;
     uint16_t pad_wd, pad_ht, stride_wd, stride_ht;
 
-    // run for 10 iterations
-    for (int itr = 0; itr < 10; itr++) {
+    // run for 15 iterations
+    for (int itr = 0; itr < 15; itr++) {
         /* prepare data */
         switch (itr) {
         case 0: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)
@@ -144,22 +145,52 @@ void esp_nn_depthwise_conv_s8_test()
             stride_wd = 2;
             stride_ht = 2;
             break;
-        default:
-            input_wd = 4;
-            input_ht = 4;
+        case 8: // same as case 7, with large parameters
+            input_wd = 58;
+            input_ht = 58;
             filter_ht = 3;
             filter_wd = 3;
-            ch_mult = 4;
-            channels = 4;
-            pad_wd = 1;
-            pad_ht = 1;
-            stride_wd = 1;
-            stride_ht = 1;
+            ch_mult = 1;
+            channels = 128;
+            pad_wd = 0;
+            pad_ht = 0;
+            stride_wd = 2;
+            stride_ht = 2;
+            break;
+        case 9: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0) stride (2,2)
+            input_wd = 6;
+            input_ht = 6;
+            filter_ht = 3;
+            filter_wd = 3;
+            ch_mult = 1;
+            channels = 16;
+            pad_wd = 0;
+            pad_ht = 0;
+            stride_wd = 2;
+            stride_ht = 2;
+            break;
+        default:
+            input_wd = 6;
+            input_ht = 6;
+            filter_ht = 3;
+            filter_wd = 3;
+            ch_mult = 1;
+            channels = 16;
+            stride_wd = rand() % 2 + 1;
+            stride_ht = stride_wd;
+            pad_wd = stride_wd == 1 ? 0 : rand() % 2;
+            pad_ht = pad_wd;
+            printf("stride(%d), pad (%d)\t", stride_wd, pad_wd);
             break;
         }
 
     uint16_t out_wd = (input_wd - filter_wd + 1) / stride_wd;
     uint16_t out_ht = (input_ht - filter_ht + 1) / stride_ht;
+    if (itr == 9) {
+        // expect the function to handle this gracefully
+        out_wd += 1;
+        out_ht += 1;
+    }
     int in_size = input_wd * input_ht * channels;
     int out_size = out_wd * out_ht * channels * ch_mult;
     int filter_size = filter_wd * filter_ht * channels * ch_mult + 4;
@@ -210,9 +241,16 @@ void esp_nn_depthwise_conv_s8_test()
         out_mult[i] = 0x7eb0e200 + rand() % 50;
     }
 
-    int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(input_wd, input_ht,
-                                                                  channels, ch_mult,
-                                                                  filter_wd, filter_ht);
+    data_dims_t input_dims = {.width = input_wd, .height = input_ht, .channels = channels, 1};
+    data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = channels * ch_mult, 1};
+    data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+    dw_conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset, .ch_mult = ch_mult,
+                                    .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+                                    .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+    quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+    int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(&input_dims, &filter_dims,
+                                                                  &output_dims, &conv_params);
     if (scratch_buf_size > 0) {
 #if IDF_HEAP_CAPS
         scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
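One subtlety in the initializers above: they mix designated and positional initializers. In C, a positional initializer that follows a designator resumes with the member declared after the last designated one, so the trailing bare `1` fills whatever field `data_dims_t` declares after `.channels` (that field's name is not visible in this diff). A generic illustration of the rule:

```c
/* C11 6.7.9: after a designator, positional initializers continue
 * with the next member in declaration order. */
struct demo { int a, b, c, d; };
struct demo d1 = {.b = 2, 3};   /* a = 0, b = 2, c = 3, d = 0 */
```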
@@ -234,11 +272,8 @@ void esp_nn_depthwise_conv_s8_test()
     }
 
     /* C function */
-    esp_nn_depthwise_conv_s8_ansi(input, input_wd, input_ht, channels, input_offset,
-                                  pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
-                                  filter_data + 4, filter_wd, filter_ht,
-                                  bias + 1, out_data_c, out_wd, out_ht, out_offset, out_shift,
-                                  out_mult, activation_min, activation_max);
+    esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 4,
+                                  bias + 1, &output_dims, out_data_c, &conv_params, &quant_data);
 
     if (itr == 0) {
         profile_c_end();
@@ -246,11 +281,8 @@ void esp_nn_depthwise_conv_s8_test()
     }
 
     /* Optimized function */
-    esp_nn_depthwise_conv_s8(input, input_wd, input_ht, channels, input_offset,
-                             pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
-                             filter_data + 4, filter_wd, filter_ht,
-                             bias + 1, out_data_opt, out_wd, out_ht, out_offset, out_shift,
-                             out_mult, activation_min, activation_max);
+    esp_nn_depthwise_conv_s8(&input_dims, input, &filter_dims, filter_data + 4,
+                             bias + 1, &output_dims, out_data_opt, &conv_params, &quant_data);
 
     if (itr == 0) {
         /* disable profiler */
@@ -479,8 +511,16 @@ void esp_nn_conv_s8_test()
         out_mult[i] = 0x7f67f4f8 + rand() % 50;
     }
 
-    int scratch_buf_size = esp_nn_get_conv_scratch_size(in_wd, in_ht, in_channels,
-                                                        out_channels, filter_wd, filter_ht);
+    data_dims_t input_dims = {.width = in_wd, .height = in_ht, .channels = in_channels, 1};
+    data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = out_channels, 1};
+    data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+    conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset,
+                                 .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+                                 .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+    quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+    int scratch_buf_size = esp_nn_get_conv_scratch_size(&input_dims, &filter_dims,
+                                                        &output_dims, &conv_params);
     if (scratch_buf_size > 0) {
 #if IDF_HEAP_CAPS
         void *scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -502,11 +542,8 @@ void esp_nn_conv_s8_test()
     }
 
     /* C function */
-    esp_nn_conv_s8_ansi(input, in_wd, in_ht, in_channels, input_offset,
-                        pad_wd, pad_ht, stride_wd, stride_ht,
-                        filter_data + 2, filter_wd, filter_ht, bias,
-                        out_data_c, out_wd, out_ht, out_channels, out_offset, out_shift,
-                        out_mult, activation_min, activation_max);
+    esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 2,
+                        bias, &output_dims, out_data_c, &conv_params, &quant_data);
 
     if (itr == 0) {
         profile_c_end();
@@ -514,11 +551,8 @@ void esp_nn_conv_s8_test()
     }
 
     /* Optimized function */
-    esp_nn_conv_s8(input, in_wd, in_ht, in_channels, input_offset,
-                   pad_wd, pad_ht, stride_wd, stride_ht,
-                   filter_data + 2, filter_wd, filter_ht, bias,
-                   out_data_opt, out_wd, out_ht, out_channels, out_offset, out_shift,
-                   out_mult, activation_min, activation_max);
+    esp_nn_conv_s8(&input_dims, input, &filter_dims, filter_data + 2,
+                   bias, &output_dims, out_data_opt, &conv_params, &quant_data);
 
     if (itr == 0) {
         /* disable profiler */
BIN  code/components/esp-nn_20220716.zip (normal file)
Binary file not shown.
@@ -756,7 +756,7 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
             _fit = _val + _valminus;
 
         }
-        if (result > 10)
+        if (result >= 10)
             result = result - 10;
         if (result < 0)
             result = result + 10;
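This one-character change is the "Bugfix: Postprocessing" noted in the changelog: the recognized digit is folded back into the 0..9 range, and the old strict comparison left a computed value of exactly 10 unwrapped. Illustration with a hypothetical value:

```c
/* With the old `result > 10`, a computed digit of exactly 10 was left
 * as-is instead of wrapping to 0; `>= 10` folds it back into 0..9. */
int result = 10;
if (result >= 10)   /* old code: result > 10, which skipped this case */
    result = result - 10;
if (result < 0)
    result = result + 10;
/* result == 0 */
```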
@@ -25,7 +25,8 @@ list(REMOVE_ITEM srcs_kernels
     "${tfmicro_kernels_dir}/depthwise_conv.cc"
     "${tfmicro_kernels_dir}/fully_connected.cc"
     "${tfmicro_kernels_dir}/mul.cc"
-    "${tfmicro_kernels_dir}/pooling.cc")
+    "${tfmicro_kernels_dir}/pooling.cc"
+    "${tfmicro_kernels_dir}/softmax.cc")
 
 FILE(GLOB esp_nn_kernels
     "${tfmicro_kernels_dir}/esp_nn/*.cc")
@@ -38,6 +39,8 @@ set(lib_srcs
     "${tflite_dir}/kernels/kernel_util.cc"
     "${tflite_dir}/micro/memory_planner/greedy_memory_planner.cc"
     "${tflite_dir}/micro/memory_planner/linear_memory_planner.cc"
+    "${tflite_dir}/micro/arena_allocator/recording_simple_memory_allocator.cc"
+    "${tflite_dir}/micro/arena_allocator/simple_memory_allocator.cc"
     "${tflite_dir}/c/common.cc"
     "${tflite_dir}/core/api/error_reporter.cc"
     "${tflite_dir}/core/api/flatbuffer_conversions.cc"
@@ -179,6 +179,8 @@ typedef enum {
   kTfLiteBuiltinMultinomial = 149,
   kTfLiteBuiltinGelu = 150,
   kTfLiteBuiltinDynamicUpdateSlice = 151,
+  kTfLiteBuiltinRelu0To1 = 152,
+  kTfLiteBuiltinUnsortedSegmentProd = 153,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
@@ -518,6 +518,9 @@ typedef struct {
   bool approximate;
 } TfLiteGeluParams;
 
+typedef struct {
+  int num_segments;
+} TfLiteUnsortedSegmentProdParams;
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
@@ -113,7 +113,13 @@ typedef struct TfLiteQuantizationParams {
 } TfLiteQuantizationParams;
 
 // --------------------------------------------------------------------------
-// Opaque types used by c_api_opaque.h.
+// Opaque types used by c_api.h, c_api_opaque.h and common.h.
+
+// TfLiteOpaqueContext is an opaque version of TfLiteContext;
+typedef struct TfLiteOpaqueContext TfLiteOpaqueContext;
+
+// TfLiteOpaqueNode is an opaque version of TfLiteNode;
+typedef struct TfLiteOpaqueNode TfLiteOpaqueNode;
 
 // TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
 typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;
@@ -14,13 +14,33 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/lite/c/common.h"
 
 #include "tensorflow/lite/c/c_api_types.h"
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+#include <string>
+
+#include "tensorflow/lite/core/macros.h"
+#include "tensorflow/lite/tensorflow_profiler_logger.h"
+#endif
 
 #ifndef TF_LITE_STATIC_MEMORY
 #include <stdlib.h>
 #include <string.h>
 #endif  // TF_LITE_STATIC_MEMORY
 
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+namespace tflite {
+// Use weak symbols here (even though they are guarded by macros) to avoid
+// build breakage when building a benchmark requires TFLite runs. The main
+// benchmark library should have tensor_profiler_logger dependency.
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor,
+                                               size_t num_bytes);
+
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor);
+}  // namespace tflite
+
+#endif  // TF_LITE_TENSORFLOW_PROFILER
+
 extern "C" {
 
 size_t TfLiteIntArrayGetSizeInBytes(int size) {
@@ -99,7 +119,12 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
 void TfLiteTensorDataFree(TfLiteTensor* t) {
   if (t->allocation_type == kTfLiteDynamic ||
       t->allocation_type == kTfLitePersistentRo) {
-    free(t->data.raw);
+    if (t->data.raw) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+      tflite::OnTfLiteTensorDealloc(t);
+#endif
+      free(t->data.raw);
+    }
   }
   t->data.raw = nullptr;
 }
@@ -161,7 +186,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
   t->dims = nullptr;
 
   if (t->dims_signature) {
-    TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
+    TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature);
   }
   t->dims_signature = nullptr;
 
@@ -191,16 +216,12 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
 }
 
 TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
-  if (!src || !dst)
-    return kTfLiteOk;
-  if (src->bytes != dst->bytes)
-    return kTfLiteError;
-  if (src == dst)
-    return kTfLiteOk;
+  if (!src || !dst) return kTfLiteOk;
+  if (src->bytes != dst->bytes) return kTfLiteError;
+  if (src == dst) return kTfLiteOk;
 
   dst->type = src->type;
-  if (dst->dims)
-    TfLiteIntArrayFree(dst->dims);
+  if (dst->dims) TfLiteIntArrayFree(dst->dims);
   dst->dims = TfLiteIntArrayCopy(src->dims);
   memcpy(dst->data.raw, src->data.raw, src->bytes);
   dst->buffer_handle = src->buffer_handle;
@@ -218,8 +239,17 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
   // TODO(b/145340303): Tensor data should be aligned.
   if (!tensor->data.raw) {
     tensor->data.raw = (char*)malloc(num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
   } else if (num_bytes > tensor->bytes) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorDealloc(tensor);
+#endif
    tensor->data.raw = (char*)realloc(tensor->data.raw, num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
   }
   tensor->bytes = num_bytes;
 }
@@ -173,9 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
     }                                                    \
   } while (false)
 #else  // TF_LITE_STRIP_ERROR_STRINGS
-#define UNUSED(...) (void)sizeof(#__VA_ARGS__)
-#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
-#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
+#define ARGS_UNUSED(...) (void)sizeof(#__VA_ARGS__)
+#define TF_LITE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
+#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
 #endif  // TF_LITE_STRIP_ERROR_STRINGS
 
 // Check whether value is true, and if not return kTfLiteError from
@@ -842,6 +842,32 @@ typedef struct TfLiteContext {
                                       size_t* bytes);
 } TfLiteContext;
 
+// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
+// for C API which doesn't use internal types (such as `TfLiteContext`) but only
+// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each
+// field is exactly the same as with `TfLiteRegistration`.
+typedef struct TfLiteRegistrationExternal {
+  // Custom op name.
+  const char* custom_name;
+
+  // The version of the op. The version should be higher than 0.
+  const int version;
+
+  // Initializes the op from serialized data.
+  void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
+                size_t length);
+
+  // The pointer `buffer` is the data previously returned by an init invocation.
+  void (*free)(TfLiteOpaqueContext* context, void* buffer);
+
+  // Called when the inputs that this node depends on have been resized.
+  TfLiteStatus (*prepare)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
+
+  // Called when the node is executed. (should read node->inputs and output to
+  // node->outputs).
+  TfLiteStatus (*invoke)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
+} TfLiteRegistrationExternal;
+
 typedef struct TfLiteRegistration {
   // Initializes the op from serialized data.
   // Called only *once* for the lifetime of the op, so any one-time allocations
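For orientation, a custom op implemented against the new stable types would populate this struct roughly as follows; all `my_*` names are hypothetical placeholders, not APIs from this commit:

```c
/* Hypothetical sketch of filling in TfLiteRegistrationExternal. */
static void* my_init(TfLiteOpaqueContext* context, const char* buffer,
                     size_t length) { return NULL; }
static void my_free(TfLiteOpaqueContext* context, void* buffer) {}
static TfLiteStatus my_prepare(TfLiteOpaqueContext* context,
                               TfLiteOpaqueNode* node) { return kTfLiteOk; }
static TfLiteStatus my_invoke(TfLiteOpaqueContext* context,
                              TfLiteOpaqueNode* node) { return kTfLiteOk; }

static TfLiteRegistrationExternal my_custom_op = {
    .custom_name = "MY_CUSTOM_OP",
    .version = 1,
    .init = my_init,
    .free = my_free,
    .prepare = my_prepare,
    .invoke = my_invoke,
};
```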
@@ -903,8 +929,31 @@ typedef struct TfLiteRegistration {
   // Note: It is the responsibility of the registration binder to set this
   // properly.
   int version;
+
+  // The external version of `TfLiteRegistration`. Since we can't use internal
+  // types (such as `TfLiteContext`) for C API to maintain ABI stability.
+  // C API user will provide `TfLiteRegistrationExternal` to implement custom
+  // ops. We keep it inside of `TfLiteRegistration` and use it to route
+  // callbacks properly.
+  TfLiteRegistrationExternal* registration_external;
 } TfLiteRegistration;
 
+// Old version of `TfLiteRegistration` to maintain binary backward
+// compatibility.
+// WARNING: This structure is deprecated / not an official part of the API.
+// It should be only used for binary backward compatibility.
+typedef struct TfLiteRegistration_V1 {
+  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+  void (*free)(TfLiteContext* context, void* buffer);
+  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+  const char* (*profiling_string)(const TfLiteContext* context,
+                                  const TfLiteNode* node);
+  int32_t builtin_code;
+  const char* custom_name;
+  int version;
+} TfLiteRegistration_V1;
+
 // The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
 // values should be 1, 2, 4, 8, ...etc.
 typedef enum TfLiteDelegateFlags {
@@ -836,6 +836,16 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
+    case BuiltinOperator_UNSORTED_SEGMENT_PROD: {
+      auto params = safe_allocator.Allocate<TfLiteUnsortedSegmentProdParams>();
+      TF_LITE_ENSURE(error_reporter, params != nullptr);
+      if (const auto* unsorted_segment_prod_params =
+              op->builtin_options_as_UnsortedSegmentProdOptions()) {
+        params->num_segments = unsorted_segment_prod_params->num_segments();
+      }
+      *builtin_data = params.release();
+      return kTfLiteOk;
+    }
     // Below are the ops with no builtin_data structure.
     // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
     // ok for now, since there is no call implementation either.
@@ -848,6 +858,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_MATRIX_DIAG:
     case BuiltinOperator_MATRIX_SET_DIAG:
     case BuiltinOperator_RELU_N1_TO_1:
+    case BuiltinOperator_RELU_0_TO_1:
     case BuiltinOperator_SELECT:
     case BuiltinOperator_SELECT_V2:
     case BuiltinOperator_SLICE:
@@ -23,6 +23,16 @@ limitations under the License.
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
+// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate.
+// This is used for cases (e.g. when using "TF Lite with Google Play Services")
+// where the TF Lite runtime might be built using a newer (or older)
+// version of the TF Lite sources than the app, and hence might have a
+// different definition of the TfLiteDelegate type. TF Lite APIs use
+// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to
+// refer to a delegate defined with that potentially different version
+// of the TfLiteDelegate type.
+struct TfLiteOpaqueDelegateStruct;
+
 namespace tflite {
 
 /// Abstract interface that returns TfLiteRegistrations given op codes or custom
@@ -37,8 +47,10 @@ class OpResolver {
   virtual const TfLiteRegistration* FindOp(const char* op,
                                            int version) const = 0;
 
+  // Represents a sequence of delegates.
   using TfLiteDelegatePtrVector =
       std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
 
   // Returns optional delegates for resolving and handling ops in the flatbuffer
   // model. This may be used in addition to the standard TfLiteRegistration
   // lookup for graph resolution.
@@ -47,16 +59,55 @@ class OpResolver {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Represent a function that creates a TfLite delegate instance.
|
// Represents a function that creates a TfLite delegate instance.
|
||||||
using TfLiteDelegateCreator =
|
using TfLiteDelegateCreator =
|
||||||
std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
|
std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
|
||||||
int /*num_threads*/)>;
|
int /*num_threads*/)>;
|
||||||
|
|
||||||
|
// Represents a sequence of delegate creator functions.
|
||||||
using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
|
using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
|
||||||
|
|
||||||
// Returns a vector of delegate creators to create optional delegates for
|
// Returns a vector of delegate creators to create optional delegates for
|
||||||
// resolving and handling ops in the flatbuffer model. This may be used in
|
// resolving and handling ops in the flatbuffer model. This may be used in
|
||||||
// addition to the standard TfLiteRegistration lookup for graph resolution.
|
// addition to the standard TfLiteRegistration lookup for graph resolution.
|
||||||
|
//
|
||||||
|
// Note that this method is not used (will not be called) if you are using
|
||||||
|
// TF Lite in Google Play Services; the GetOpaqueDelegateCreators method
|
||||||
|
// (see below) is used for that case.
|
||||||
virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
|
virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
|
||||||
|
|
||||||
|
// TODO(b/202712825): it would be nice if we could avoid the need for separate
|
||||||
|
// "opaque" types & methods for use only with TF Lite in Google Play Services.
|
||||||
|
|
||||||
|
// Represents an opaque delegate instance.
|
||||||
|
// WARNING: Experimental interface, subject to change.
|
||||||
|
using TfLiteOpaqueDelegatePtr =
|
||||||
|
std::unique_ptr<TfLiteOpaqueDelegateStruct,
|
||||||
|
void (*)(TfLiteOpaqueDelegateStruct*)>;
|
||||||
|
|
||||||
|
// Represents a function that creates an opaque delegate instance.
|
||||||
|
// WARNING: Experimental interface, subject to change.
|
||||||
|
using TfLiteOpaqueDelegateCreator =
|
||||||
|
std::function<TfLiteOpaqueDelegatePtr(int /*num_threads*/)>;
|
||||||
|
|
||||||
|
// Represents a sequence of opaque delegate creator functions.
|
||||||
|
// WARNING: Experimental interface, subject to change.
|
||||||
|
using TfLiteOpaqueDelegateCreators = std::vector<TfLiteOpaqueDelegateCreator>;
|
||||||
|
|
||||||
|
// Returns a vector of opaque delegate creators to create optional opaque
|
||||||
|
// delegates for resolving and handling ops in the flatbuffer model. This may
|
||||||
|
// be used in addition to the standard TfLiteRegistration lookup for graph
|
||||||
|
// resolution.
|
||||||
|
//
|
||||||
|
// Note that this method will be called only if you are using TF Lite in
|
||||||
|
// Google Play Services; if you are using regular TF Lite, GetDelegateCreators
|
||||||
|
// (see above) is used instead.
|
||||||
|
//
|
||||||
|
// WARNING: Experimental interface, subject to change.
|
||||||
|
virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
virtual ~OpResolver() {}
|
virtual ~OpResolver() {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
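The delegate-creator hooks above let a resolver describe how to build its delegates instead of owning live instances. A minimal sketch of an override, assuming a hypothetical `MyDelegateCreate`/`MyDelegateDelete` pair and some existing resolver base `MyBaseResolver` (none of these names appear in the diff):

```cpp
#include <functional>
#include <memory>

// Hypothetical delegate factory functions; a real delegate shipped with the
// app would provide its own equivalents.
TfLiteDelegate* MyDelegateCreate(int num_threads);
void MyDelegateDelete(TfLiteDelegate* delegate);

class MyOpResolver : public MyBaseResolver {  // MyBaseResolver is assumed
 public:
  TfLiteDelegateCreators GetDelegateCreators() const override {
    // One creator per delegate; the runtime decides when (and whether) to
    // invoke it, which is what makes this work across runtime versions.
    return {[](int num_threads) {
      return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
          MyDelegateCreate(num_threads), MyDelegateDelete);
    }};
  }
};
```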
@@ -23,9 +23,9 @@ namespace tflite {
 namespace reference_ops {
 
 inline int16_t SaturatingLeftShift(int16_t value, int amount) {
-  int32_t result = static_cast<int32_t>(value) * (1 << amount);
-  result = std::min<int32_t>(result, std::numeric_limits<int16_t>::max());
-  result = std::max<int32_t>(result, std::numeric_limits<int16_t>::min());
+  int64_t result = static_cast<int64_t>(value) * (1 << amount);
+  result = std::min<int64_t>(result, std::numeric_limits<int16_t>::max());
+  result = std::max<int64_t>(result, std::numeric_limits<int16_t>::min());
   return result;
 }
 
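The widening to `int64_t` matters because the old intermediate could overflow before the clamp ever ran: for shift amounts above 16, `value * (1 << amount)` can exceed `INT32_MAX`, which is undefined behavior for signed arithmetic. A standalone demonstration of the fixed helper:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

// Same shape as the fixed kernel helper above; a standalone copy for
// illustration only.
int16_t SaturatingLeftShift(int16_t value, int amount) {
  int64_t result = static_cast<int64_t>(value) * (1 << amount);
  result = std::min<int64_t>(result, std::numeric_limits<int16_t>::max());
  result = std::max<int64_t>(result, std::numeric_limits<int16_t>::min());
  return static_cast<int16_t>(result);
}

int main() {
  // 32767 * (1 << 17) = 4'294'836'224, which does not fit in int32_t.
  // The old int32_t intermediate made this signed overflow (UB); the
  // int64_t intermediate computes it exactly and then clamps.
  printf("%d\n", SaturatingLeftShift(32767, 17));   // prints 32767
  printf("%d\n", SaturatingLeftShift(-32768, 17));  // prints -32768
}
```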
@@ -27,6 +27,11 @@ class RuntimeShape {
  public:
   RuntimeShape& operator=(RuntimeShape const&) = delete;
 
+  // RuntimeShape in TFLM supports up to 5 dimensions.
+  // The name kMaxSmallSize comes from the same file of the upstream
+  // tensorflow lite repo and needs to be kept the same for max reuse.
+  static constexpr int kMaxSmallSize = 5;
+
   RuntimeShape() : size_(0) {}
 
   explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {}
@@ -104,11 +109,9 @@ class RuntimeShape {
            sizeof(int32_t) * shape.DimensionsCount());
   }
 
-  // A maximum of 4 dimensions are supported on TFLM.
-  static constexpr int kMaxSize = 5;
   int32_t size_;
   union {
-    int32_t dims_[kMaxSize];
+    int32_t dims_[kMaxSmallSize];
   };
 };
 
@@ -974,11 +974,11 @@ struct StridedSliceParams {
   int8_t strides_count;
   int32_t strides[5];
 
-  int16_t begin_mask;
-  int16_t ellipsis_mask;
-  int16_t end_mask;
-  int16_t new_axis_mask;
-  int16_t shrink_axis_mask;
+  uint16_t begin_mask;
+  uint16_t ellipsis_mask;
+  uint16_t end_mask;
+  uint16_t new_axis_mask;
+  uint16_t shrink_axis_mask;
 };
 
 struct TanhParams {
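Switching the mask fields to `uint16_t` is the usual fix for bit-mask fields: with a signed 16-bit type, storing bit 15 produces a negative value, and any widening (argument passing, comparison against a wider mask) sign-extends it. A small demonstration of the hazard:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // With a signed 16-bit mask, setting the top bit stores a negative value,
  // and widening to 32 bits sign-extends it, turning on 16 extra high bits.
  int16_t signed_mask = static_cast<int16_t>(1u << 15);
  printf("%08x\n", static_cast<uint32_t>(signed_mask));    // ffff8000

  // With uint16_t the widened value keeps exactly the bits that were set.
  uint16_t unsigned_mask = 1u << 15;
  printf("%08x\n", static_cast<uint32_t>(unsigned_mask));  // 00008000
}
```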
@@ -308,7 +308,7 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                         const TfLiteTensor* input3,
                                         TfLiteIntArray** output_shape);
 
-// Return the size of given type in bytes. Return 0 in in case of string.
+// Return the size of given type in bytes. Return 0 in case of string.
 int TfLiteTypeGetSize(TfLiteType type);
 
 // Whether the current platform is mobile (Android or iOS).
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
 
 #include <cstddef>
 #include <cstdint>
@@ -97,4 +97,4 @@ class INonPersistentBufferAllocator {
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
@@ -0,0 +1,165 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+NonPersistentArenaBufferAllocator::NonPersistentArenaBufferAllocator(
+    uint8_t* buffer, size_t buffer_size)
+    : buffer_head_(buffer),
+      buffer_tail_(buffer + buffer_size),
+      head_temp_(buffer),
+      next_temp_(buffer) {}
+
+NonPersistentArenaBufferAllocator::~NonPersistentArenaBufferAllocator() {}
+
+// Allocates a temporary buffer. This buffer is not resizable.
+uint8_t* NonPersistentArenaBufferAllocator::AllocateTemp(size_t size,
+                                                         size_t alignment) {
+  uint8_t* const aligned_result = AlignPointerUp(next_temp_, alignment);
+  const size_t available_memory = buffer_tail_ - aligned_result;
+  if (available_memory < size) {
+    MicroPrintf(
+        "Failed to allocate temp memory. Requested: %u, "
+        "available %u, missing: %u",
+        size, available_memory, size - available_memory);
+    return nullptr;
+  }
+  next_temp_ = aligned_result + size;
+  temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(aligned_result);
+  temp_buffer_count_++;
+  return aligned_result;
+}
+
+// Signals that a temporary buffer is no longer needed.
+void NonPersistentArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
+  temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(temp_buf);
+  temp_buffer_count_--;
+}
+
+// Returns true if all temporary buffers are already deallocated.
+bool NonPersistentArenaBufferAllocator::IsAllTempDeallocated() {
+  if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
+    MicroPrintf(
+        "Number of allocated temp buffers: %d. Checksum passing status: %d",
+        temp_buffer_count_, !temp_buffer_ptr_check_sum_);
+    return false;
+  }
+  return true;
+}
+
+// Signals that all temporary allocations can be reclaimed. TFLM calls this
+// API when it knows that all temporary buffers that it requested have been
+// deallocated. The goal of the API is to let implementations of
+// INonPersistentBufferAllocator reuse the buffer with some reasonable
+// complexity.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResetTempAllocations() {
+  if (!IsAllTempDeallocated()) {
+    MicroPrintf(
+        "All temp buffers must be freed before calling ResetTempAllocations()");
+    return kTfLiteError;
+  }
+  next_temp_ = head_temp_;
+  return kTfLiteOk;
+}
+
+// Returns a buffer that is resizable via ResizeBuffer().
+uint8_t* NonPersistentArenaBufferAllocator::AllocateResizableBuffer(
+    size_t size, size_t alignment) {
+  // Only supports one resizable buffer, which starts at the buffer head.
+  uint8_t* expected_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+
+  if (head_temp_ != expected_resizable_buf) {
+    MicroPrintf(
+        "Cannot allocate a new resizable buffer when one is already allocated");
+    return nullptr;
+  }
+
+  if (ResizeBuffer(expected_resizable_buf, size, alignment) == kTfLiteOk) {
+    return expected_resizable_buf;
+  }
+  return nullptr;
+}
+
+// Resizes a buffer that is previously returned by the AllocateResizableBuffer.
+// Note that ResizeBuffer(old_resizable_buf, 0, 1) effectively deallocates
+// a previously allocated resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResizeBuffer(
+    uint8_t* resizable_buf, size_t size, size_t alignment) {
+  // Only supports one resizable buffer, which starts at the buffer head.
+  uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+  if (resizable_buf != expect_resizable_buf) {
+    MicroPrintf("Internal error: buffer is not resizable");
+    return kTfLiteError;
+  }
+  if (head_temp_ != next_temp_) {
+    MicroPrintf("ResetTempAllocations() is not called before ResizeBuffer().");
+    return kTfLiteError;
+  }
+
+  const size_t available_memory = buffer_tail_ - expect_resizable_buf;
+  if (available_memory < size) {
+    MicroPrintf(
+        "Failed to resize buffer. Requested: %u, available %u, missing: %u",
+        size, available_memory, size - available_memory);
+    return kTfLiteError;
+  }
+  head_temp_ = expect_resizable_buf + size;
+  next_temp_ = head_temp_;
+
+  return kTfLiteOk;
+}
+
+// Frees up the memory occupied by the resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::DeallocateResizableBuffer(
+    uint8_t* resizable_buf) {
+  return ResizeBuffer(resizable_buf, 0, 1);
+}
+
+// Returns a pointer pointing to the start of the overlay memory, which is
+// used for activation tensors and scratch buffers by kernels at Invoke stage.
+uint8_t* NonPersistentArenaBufferAllocator::GetOverlayMemoryAddress() const {
+  return buffer_head_;
+}
+
+// Reserves the size of the overlay memory. This overlay is reserved for the
+// kernels at Invoke stage. This is referred to as the overlay because before
+// the Invoke stage, the same memory can be used for temp buffers. The layout
+// of the memory is planned by the memory planner separately at Invoke stage.
+TfLiteStatus
+NonPersistentArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
+    size_t size, size_t alignment) {
+  uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+  return ResizeBuffer(expect_resizable_buf, size, alignment);
+}
+
+// Returns the size of the non-persistent buffer in use.
+size_t NonPersistentArenaBufferAllocator::GetNonPersistentUsedBytes() const {
+  return (next_temp_ - buffer_head_);
+}
+
+// Returns the number of bytes available with a given alignment. This number
+// takes into account any temporary allocations.
+size_t NonPersistentArenaBufferAllocator::GetAvailableMemory(
+    size_t alignment) const {
+  uint8_t* const aligned_temp = AlignPointerUp(next_temp_, alignment);
+  uint8_t* const aligned_tail = AlignPointerDown(buffer_tail_, alignment);
+  return aligned_tail - aligned_temp;
+}
+
+}  // namespace tflite
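The allocator above tracks outstanding temp buffers with a counter plus an XOR checksum of their addresses, so leak detection costs two words instead of a list of live pointers. XOR can in principle collide, so paired with the count it is a cheap sanity check rather than a proof; a standalone sketch of the bookkeeping:

```cpp
#include <cstdint>
#include <cstdio>

// Every allocated pointer is XOR-ed into the checksum, every deallocated
// pointer XOR-ed out. If exactly the allocated set is freed, the checksum
// returns to zero.
int main() {
  intptr_t check_sum = 0;
  uint8_t arena[64];
  uint8_t* a = arena;
  uint8_t* b = arena + 16;

  check_sum ^= reinterpret_cast<intptr_t>(a);  // allocate a
  check_sum ^= reinterpret_cast<intptr_t>(b);  // allocate b
  check_sum ^= reinterpret_cast<intptr_t>(b);  // free b
  check_sum ^= reinterpret_cast<intptr_t>(a);  // free a
  printf("all freed: %d\n", check_sum == 0);   // prints 1

  check_sum ^= reinterpret_cast<intptr_t>(a);  // allocate a ...
  check_sum ^= reinterpret_cast<intptr_t>(b);  // ... but free b by mistake
  printf("mismatch caught: %d\n", check_sum != 0);  // prints 1
}
```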
@@ -0,0 +1,104 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// Implements INonPersistentBufferAllocator on an arena that is dedicated to
+// non-persistent buffers.
+class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator {
+ public:
+  NonPersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+  virtual ~NonPersistentArenaBufferAllocator();
+
+  // Allocates a temporary buffer. This buffer is not resizable.
+  uint8_t* AllocateTemp(size_t size, size_t alignment) override;
+
+  // Signals that a temporary buffer is no longer needed.
+  void DeallocateTemp(uint8_t* buf) override;
+
+  // Returns true if all temporary buffers are already deallocated.
+  bool IsAllTempDeallocated() override;
+
+  // Signals that all temporary allocations can be reclaimed. TFLM calls this
+  // API when it knows that all temporary buffers that it requested have been
+  // deallocated.
+  TfLiteStatus ResetTempAllocations() override;
+
+  // Returns a buffer that is resizable via ResizeBuffer().
+  uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) override;
+
+  // Resizes a buffer that is previously returned by the
+  // AllocateResizableBuffer.
+  TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size,
+                            size_t alignment) override;
+
+  // Frees up the memory occupied by the resizable buffer.
+  TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) override;
+
+  // Returns a pointer pointing to the start of the overlay memory, which is
+  // used for activation tensors and scratch buffers by kernels at Invoke stage.
+  uint8_t* GetOverlayMemoryAddress() const override;
+
+  // Reserves the size of the overlay memory. This overlay is reserved for the
+  // kernels at Invoke stage. This is referred to as the overlay because before
+  // the Invoke stage, the same memory can be used for temp buffers. The layout
+  // of the memory is planned by the memory planner separately at Invoke stage.
+  TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size,
+                                                 size_t alignment) override;
+
+  // Returns the size of the non-persistent buffer in use.
+  size_t GetNonPersistentUsedBytes() const override;
+
+  // Returns the number of bytes available with a given alignment. This number
+  // takes into account any temporary allocations.
+  size_t GetAvailableMemory(size_t alignment) const override;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+
+ private:
+  // The memory arena that this allocator manages.
+  uint8_t* const buffer_head_;
+  uint8_t* const buffer_tail_;
+
+  // The whole region is split into two parts:
+  // buffer_head_ to head_temp_ - 1 belongs to the only resizable buffer.
+  // head_temp_ to buffer_tail_ can be used for (non-resizable) temp buffers.
+  uint8_t* head_temp_;
+
+  // next_temp_ points to the next available temp buffer allocation address and
+  // its range is between head_temp_ and buffer_tail_.
+  uint8_t* next_temp_;
+
+  // XOR check sum for outstanding temp buffers.
+  // If all temp buffers are deallocated OR no temp buffers are allocated,
+  // temp_buffer_ptr_check_sum_ == 0.
+  intptr_t temp_buffer_ptr_check_sum_ = 0;
+  // Count of outstanding temp buffers.
+  int temp_buffer_count_ = 0;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
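A minimal sketch of the intended lifecycle of the class above, with illustrative sizes; only the API declared in this header is used:

```cpp
#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"

void Sketch() {
  static uint8_t arena[1024];
  tflite::NonPersistentArenaBufferAllocator allocator(arena, sizeof(arena));

  // Preparation phase: temp buffers come and go at the low end of the arena.
  uint8_t* scratch = allocator.AllocateTemp(/*size=*/64, /*alignment=*/4);
  allocator.DeallocateTemp(scratch);
  allocator.ResetTempAllocations();  // only legal once all temps are freed

  // Invoke phase: reserve the overlay region used by activation tensors.
  allocator.ReserveNonPersistentOverlayMemory(/*size=*/512, /*alignment=*/16);
  uint8_t* overlay = allocator.GetOverlayMemoryAddress();
  (void)overlay;
}
```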
@@ -0,0 +1,52 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+PersistentArenaBufferAllocator::PersistentArenaBufferAllocator(
+    uint8_t* buffer, size_t buffer_size)
+    : buffer_head_(buffer),
+      buffer_tail_(buffer + buffer_size),
+      tail_temp_(buffer_tail_) {}
+
+PersistentArenaBufferAllocator::~PersistentArenaBufferAllocator() {}
+
+uint8_t* PersistentArenaBufferAllocator::AllocatePersistentBuffer(
+    size_t size, size_t alignment) {
+  uint8_t* const aligned_result =
+      AlignPointerDown(tail_temp_ - size, alignment);
+  if (aligned_result < buffer_head_) {
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+    const size_t missing_memory = buffer_head_ - aligned_result;
+    MicroPrintf(
+        "Failed to allocate tail memory. Requested: %u, "
+        "available %u, missing: %u",
+        size, size - missing_memory, missing_memory);
+#endif
+    return nullptr;
+  }
+  tail_temp_ = aligned_result;
+  return aligned_result;
+}
+
+size_t PersistentArenaBufferAllocator::GetPersistentUsedBytes() const {
+  return buffer_tail_ - tail_temp_;
+}
+
+}  // namespace tflite
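Persistent allocations carve from the arena tail downward: aligning the candidate `tail - size` down (rather than up) yields an aligned result whose extent still covers the request, with no padding bookkeeping. The arithmetic, worked on illustrative addresses:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assume a 16-byte arena ending at address 0x1010 for readability.
  uintptr_t tail = 0x1010;  // one past the end of the arena

  // Allocate 5 bytes with 4-byte alignment: 0x1010 - 5 = 0x100b,
  // aligned down to 0x1008; the bytes [0x1008, 0x1010) cover the request.
  uintptr_t alloc = (tail - 5) & ~uintptr_t{3};
  printf("alloc at %#lx, used %lu bytes\n", (unsigned long)alloc,
         (unsigned long)(tail - alloc));  // alloc at 0x1008, used 8 bytes
  tail = alloc;                           // the new tail for the next request
}
```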
@@ -0,0 +1,59 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// PersistentArenaBufferAllocator is an implementation of the
+// IPersistentBufferAllocator interface on an arena that is dedicated to
+// persistent buffers.
+class PersistentArenaBufferAllocator : public IPersistentBufferAllocator {
+ public:
+  PersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+  virtual ~PersistentArenaBufferAllocator();
+
+  // Allocates persistent memory. The persistent buffer is never freed.
+  // Returns nullptr if an error occurred.
+  uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) override;
+
+  // Returns the size of all persistent allocations in bytes.
+  size_t GetPersistentUsedBytes() const override;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+ private:
+  // The memory arena that this allocator manages.
+  uint8_t* const buffer_head_;
+  uint8_t* const buffer_tail_;
+
+  // The whole region is split into two parts:
+  // tail_temp_ to buffer_tail_ contains allocated buffers;
+  // buffer_head_ to tail_temp_ - 1 belongs to still-available space.
+  // So in essence, the allocated region grows down from the tail and emulates
+  // SimpleMemoryAllocator's persistent part.
+  uint8_t* tail_temp_;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
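Together, the two new classes split the combined SimpleMemoryAllocator's head/tail roles across two dedicated arenas. A sketch of how they pair up, assuming both headers above are available; sizes are illustrative only:

```cpp
#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h"

void TwoArenaSketch() {
  static uint8_t persistent_arena[512];
  static uint8_t non_persistent_arena[1024];

  // Persistent data (tensor metadata, quantization params, ...) never moves.
  tflite::PersistentArenaBufferAllocator persistent(
      persistent_arena, sizeof(persistent_arena));
  uint8_t* params = persistent.AllocatePersistentBuffer(/*size=*/32,
                                                        /*alignment=*/8);

  // Non-persistent data (temps, activation overlay) is recycled per phase.
  tflite::NonPersistentArenaBufferAllocator non_persistent(
      non_persistent_arena, sizeof(non_persistent_arena));
  uint8_t* temp = non_persistent.AllocateTemp(/*size=*/64, /*alignment=*/8);
  non_persistent.DeallocateTemp(temp);
  (void)params;
}
```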
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/micro/recording_simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h"
 
 #include <new>
 
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
 
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 namespace tflite {
 
@@ -62,4 +62,4 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 
 #include <cstddef>
 #include <cstdint>
@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
 
 #include <cstddef>
 #include <cstdint>
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
 #include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/ibuffer_allocator.h"
 
 namespace tflite {
 
@@ -147,4 +147,4 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
@@ -16,10 +16,10 @@ limitations under the License.
 #include "tensorflow/lite/micro/fake_micro_context.h"
 
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/micro_allocator.h"
 #include "tensorflow/lite/micro/micro_arena_constants.h"
 #include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 namespace tflite {
 namespace {
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     default: {
-      TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
-                         TfLiteTypeGetName(input->type));
+      MicroPrintf("Only float32 is supported currently, got %s",
+                  TfLiteTypeGetName(input->type));
       return kTfLiteError;
     }
   }
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     default: {
-      TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
-                         TfLiteTypeGetName(input->type));
+      MicroPrintf("Only float32 is supported currently, got %s",
+                  TfLiteTypeGetName(input->type));
      return kTfLiteError;
    }
  }
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_RELU() {
-  return {/*init=*/ReluInit,
-          /*free=*/nullptr,
-          /*prepare=*/ReluPrepare,
-          /*invoke=*/ReluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
 }
 
 TfLiteRegistration Register_RELU6() {
-  return {/*init=*/Relu6Init,
-          /*free=*/nullptr,
-          /*prepare=*/Relu6Prepare,
-          /*invoke=*/Relu6Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
 }
 
 }  // namespace tflite
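The remaining kernel hunks repeat one mechanical change: every `Register_*` function drops its eight-field brace initializer in favor of a `tflite::micro::RegisterOp(init, prepare, invoke)` helper. The helper's definition is not part of this diff; a plausible shape, inferred from how the call sites use it:

```cpp
// Not part of this diff: centralizes the TfLiteRegistration boilerplate so
// each kernel names only its init/prepare/invoke functions.
namespace tflite {
namespace micro {

inline TfLiteRegistration RegisterOp(
    void* (*init)(TfLiteContext*, const char*, size_t),
    TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*),
    TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*)) {
  return {/*init=*/init,
          /*free=*/nullptr,
          /*prepare=*/prepare,
          /*invoke=*/invoke,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace tflite
```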
@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_ADD() {
-  return {/*init=*/AddInit,
-          /*free=*/nullptr,
-          /*prepare=*/AddPrepare,
-          /*invoke=*/AddEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
 }
 
 }  // namespace tflite
@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_ADD_N() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace arg_min_max
 
 TfLiteRegistration Register_ARG_MAX() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/arg_min_max::ArgMaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
 }
 
 TfLiteRegistration Register_ARG_MIN() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/arg_min_max::ArgMinEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
 }
 
 }  // namespace micro
@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_ASSIGN_VARIABLE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_BROADCAST_ARGS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/BroadcastArgsPrepare,
-          /*invoke=*/BroadcastArgsEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
+                                   BroadcastArgsEval);
 }
 
 }  // namespace tflite
@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_BROADCAST_TO() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/BroadcastToPrepare,
-          /*invoke=*/BroadcastToEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
+                                   BroadcastToEval);
 }
 
 }  // namespace tflite
@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_CALL_ONCE() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_CAST() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace ceil
 
 TfLiteRegistration Register_CEIL() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/ceil::Prepare,
-          /*invoke=*/ceil::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
 }
 
 }  // namespace micro
@@ -108,14 +108,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
-  static TfLiteRegistration r = {/*init=*/CircularBufferInit,
-                                 /*free=*/nullptr,
-                                 /*prepare=*/CircularBufferPrepare,
-                                 /*invoke=*/CircularBufferEval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
+  static TfLiteRegistration r = tflite::micro::RegisterOp(CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
   return &r;
 }
 
@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace comparisons
 
 TfLiteRegistration Register_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::EqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::EqualEval);
 }
 
 TfLiteRegistration Register_NOT_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::NotEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::NotEqualEval);
 }
 
 TfLiteRegistration Register_GREATER() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::GreaterEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::GreaterEval);
 }
 
 TfLiteRegistration Register_GREATER_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::GreaterEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::GreaterEqualEval);
 }
 
 TfLiteRegistration Register_LESS() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::LessEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::LessEval);
 }
 
 TfLiteRegistration Register_LESS_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::LessEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::LessEqualEval);
 }
 
 }  // namespace micro
@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, input != nullptr);
   int num_dimensions = NumDimensions(input);
 
-  if (num_dimensions > 4) {
+  if (num_dimensions > RuntimeShape::kMaxSmallSize) {
     TF_LITE_KERNEL_LOG(
         context,
-        "Op Concatenation does not currently support num dimensions >4 "
+        "Op Concatenation does not currently support num dimensions > %d "
         "Tensor has %d dimensions.",
-        num_dimensions);
+        RuntimeShape::kMaxSmallSize, num_dimensions);
     return kTfLiteError;
   }
   micro_context->DeallocateTempTfLiteTensor(input);
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace concatenation
 
 TfLiteRegistration Register_CONCATENATION() {
-  return {/*init=*/concatenation::Init,
-          /*free=*/nullptr,
-          /*prepare=*/concatenation::Prepare,
-          /*invoke=*/concatenation::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
+                                   concatenation::Eval);
 }
 
 }  // namespace micro
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 
 namespace tflite {
 namespace {
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output),
           tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
     }
     case kTfLiteInt16: {
-      reference_integer_ops::ConvPerChannel(
-          ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
-          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-          tflite::micro::GetTensorData<int16_t>(input),
-          tflite::micro::GetTensorShape(filter),
-          tflite::micro::GetTensorData<int8_t>(filter),
-          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<std::int64_t>(bias),
-          tflite::micro::GetTensorShape(output),
-          tflite::micro::GetTensorData<int16_t>(output));
+      switch (bias->type) {
+        case kTfLiteInt32: {
+          reference_integer_ops::ConvPerChannel(
+              ConvParamsQuantized(params, data),
+              data.per_channel_output_multiplier, data.per_channel_output_shift,
+              tflite::micro::GetTensorShape(input),
+              tflite::micro::GetTensorData<int16_t>(input),
+              tflite::micro::GetTensorShape(filter),
+              tflite::micro::GetTensorData<int8_t>(filter),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
+              tflite::micro::GetTensorShape(output),
+              tflite::micro::GetTensorData<int16_t>(output));
+          break;
+        }
+        case kTfLiteInt64: {
+          reference_integer_ops::ConvPerChannel(
+              ConvParamsQuantized(params, data),
+              data.per_channel_output_multiplier, data.per_channel_output_shift,
+              tflite::micro::GetTensorShape(input),
+              tflite::micro::GetTensorData<int16_t>(input),
+              tflite::micro::GetTensorShape(filter),
+              tflite::micro::GetTensorData<int8_t>(filter),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
+              tflite::micro::GetTensorShape(output),
+              tflite::micro::GetTensorData<int16_t>(output));
+          break;
+        }
+        default:
+          MicroPrintf("Bias type %s (%d) not supported.",
+                      TfLiteTypeGetName(bias->type), bias->type);
+          return kTfLiteError;
+      }
       break;
     }
     case kTfLiteInt8: {
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output));
      break;
    }
    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(input->type), input->type);
+      MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
+                  input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/ConvPrepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
 }
 
 }  // namespace tflite
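The int16 path now dispatches on the bias tensor's type because fully-quantized int16 models may carry either 32- or 64-bit biases: the accumulator for int16 × int8 products outgrows 32 bits once a filter sums more than a few thousand taps. Worst-case arithmetic, for illustration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Worst-case magnitudes, not a real model:
  const int64_t max_product = 32767LL * 127;       // one int16*int8 product
  const int64_t taps = 4096;                       // filter elements summed
  const int64_t worst_case = max_product * taps;   // about 1.7e10
  printf("worst-case accumulator: %lld (int32 max: %d)\n",
         (long long)worst_case, INT32_MAX);        // far beyond int32 range
}
```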
@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
     float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
     TfLiteRegistration registration, int16_t* output_data);
 
+TfLiteStatus TestConvQuantizedPerChannel(
+    int* input_dims_data, const float* input_data, int16_t* input_quantized,
+    float input_scale, int input_zero_point, int* filter_dims_data,
+    const float* filter_data, int8_t* filter_data_quantized,
+    int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
+    float* bias_scales, int* bias_zero_points, int* output_dims_data,
+    const float* expected_output_data, int16_t* expected_output_data_quantized,
+    float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
+    TfLiteRegistration registration, int16_t* output_data);
+
 }  // namespace testing
 }  // namespace tflite
 
@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_CUMSUM() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_DEPTH_TO_SPACE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output));
      break;
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/DepthwiseConvPrepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
 }
 
 }  // namespace tflite
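Several bias arguments above switch from `GetTensorData` to `GetOptionalTensorData`, reflecting that bias tensors are optional in these ops. The helper is defined elsewhere in the tree; a plausible shape, inferred from how the call sites use it:

```cpp
// Null-safe accessor: an absent bias tensor yields a null data pointer
// instead of dereferencing a null TfLiteEvalTensor.
template <typename T>
const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
  return tensor == nullptr ? nullptr
                           : reinterpret_cast<const T*>(tensor->data.raw);
}
```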
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
 
 TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
 
+// This is the most generic TfLiteRegistration. The actual supported types may
+// still be target dependent. The only requirement is that every implementation
+// (reference or optimized) must define this function.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D();
+
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int8 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
+
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
+
+#else
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+#endif
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
|
|||||||
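With the fallback pattern above, application code can always ask for the specialized registration; on targets without CMSIS_NN the inline stub simply forwards to the generic kernel. A hedged usage sketch:

    // Compiles on every target: resolves to the latency-optimized int8 kernel
    // under CMSIS_NN, otherwise to the generic Register_DEPTHWISE_CONV_2D().
    TfLiteRegistration dw_reg = tflite::Register_DEPTHWISE_CONV_2D_INT8();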
@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
+    case kTfLiteUInt8:
+      reference_ops::Dequantize(data->quantization_params,
+                                tflite::micro::GetTensorShape(input),
+                                tflite::micro::GetTensorData<uint8_t>(input),
+                                tflite::micro::GetTensorShape(output),
+                                tflite::micro::GetTensorData<float>(output));
+      break;
     default:
       MicroPrintf("Input %s, output %s not supported.",
                   TfLiteTypeGetName(input->type),
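The new kTfLiteUInt8 case reuses the same affine dequantization as the existing int8/int16 paths. As a reminder of what reference_ops::Dequantize computes (a sketch with placeholder names, not the library source):

    // out[i] = scale * (q[i] - zero_point), element-wise.
    for (int i = 0; i < num_elements; ++i) {
      output[i] = params.scale *
                  (static_cast<int32_t>(input[i]) - params.zero_point);
    }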
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_DEQUANTIZE() {
-  return {/*init=*/DequantizeInit,
-          /*free=*/nullptr,
-          /*prepare=*/DequantizePrepare,
-          /*invoke=*/DequantizeEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
+                                   DequantizeEval);
 }
 
 }  // namespace tflite
@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
   TF_LITE_ENSURE(context, output != nullptr);
 
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
+                              input->type == kTfLiteInt16 ||
+                              input->type == kTfLiteUInt8);
   TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
 
   if (output->type == kTfLiteInt32) {
@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   return op_data;
 }
 
-void Free(TfLiteContext* context, void* buffer) {}
-
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* op_data = static_cast<OpData*>(node->user_data);
 
@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
-  static TfLiteRegistration r = {/*init=*/Init,
-                                 /*free=*/Free,
-                                 /*prepare=*/Prepare,
-                                 /*invoke=*/Eval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
+  static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
   return &r;
 }
 
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@ limitations under the License.
 #include <cmath>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -27,6 +29,22 @@ namespace micro {
 namespace elementwise {
 namespace {
 
+constexpr int kAbsNameId = 0;
+constexpr int kRsrqtNameId = 1;
+
+const int kElementwiseInputTensor = 0;
+const int kElementwiseOutputTensor = 0;
+
+struct OpDataAbsRsqrt {
+  int32_t multiplier;
+  int shift;
+  int input_offset;
+  int output_offset;
+  bool needs_rescale;
+  TfLiteQuantizationType input_quantization_type;
+  TfLiteType input_type;
+};
+
 bool IsNumericSupportedType(const TfLiteType type) {
   return type == kTfLiteFloat32;
 }
@@ -35,16 +53,40 @@ bool IsLogicalSupportedType(const TfLiteType type) {
   return type == kTfLiteBool;
 }
 
+bool IsAbsSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
+}
+
+bool IsRsqrtSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteInt8;
+}
+
+inline void SetAbsOutputMultiplier(const float input_scale,
+                                   const float output_scale,
+                                   int32_t* multiplier, int* shift) {
+  QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
+                     multiplier, shift);
+}
+
+inline void SetRsqrtOutputMultiplier(const float input_scale,
+                                     const float output_scale,
+                                     int32_t* multiplier, int* shift) {
+  const double scale =
+      1. / static_cast<double>((std::sqrt(input_scale) * output_scale));
+  QuantizeMultiplier(scale, multiplier, shift);
+}
+
 typedef bool (*IsSupportedType)(TfLiteType);
 template <IsSupportedType>
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
 
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TfLiteTensor* input =
+      micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
   TF_LITE_ENSURE(context, input != nullptr);
-  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  TfLiteTensor* output =
+      micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (!IsSupportedType(input->type)) {
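Both multiplier setters lean on QuantizeMultiplier, which decomposes a real-valued scale into a Q31 fixed-point multiplier plus a power-of-two shift (assumed behavior, consistent with its use throughout these kernels):

    // real_scale ≈ (multiplier / 2^31) * 2^shift, multiplier in [2^30, 2^31).
    int32_t multiplier;
    int shift;
    QuantizeMultiplier(0.75, &multiplier, &shift);
    // e.g. 0.75 -> multiplier ≈ round(0.75 * 2^31), shift == 0.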
@@ -58,9 +100,79 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+typedef bool (*IsSupportedType)(TfLiteType);
+template <IsSupportedType, const int op_nameid>
+TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
+  MicroContext* micro_context = GetMicroContext(context);
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  if (!IsSupportedType(input->type)) {
+    TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
+                       TfLiteTypeGetName(input->type), input->type);
+    return kTfLiteError;
+  }
+
+  auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+  op_data->input_type = input->type;
+
+  // For int16 type input, we support both quantized and non-quantized
+  // evaluation.
+  if (op_nameid == kAbsNameId) {
+    op_data->input_quantization_type = input->quantization.type;
+  }
+
+  if (input->type == kTfLiteInt8 ||
+      (input->type == kTfLiteInt16 &&
+       input->quantization.type != kTfLiteNoQuantization)) {
+    TF_LITE_ENSURE_EQ(context, input->quantization.type,
+                      kTfLiteAffineQuantization);
+    TF_LITE_ENSURE_EQ(context, output->quantization.type,
+                      kTfLiteAffineQuantization);
+    const auto* input_params =
+        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
+    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
+        output->quantization.params);
+    TF_LITE_ENSURE(context, input_params != nullptr);
+    TF_LITE_ENSURE(context, input_params->scale != nullptr);
+    TF_LITE_ENSURE(context, input_params->scale->size > 0);
+    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
+    TF_LITE_ENSURE(context, output_params != nullptr);
+    TF_LITE_ENSURE(context, output_params->scale != nullptr);
+    TF_LITE_ENSURE(context, output_params->scale->size > 0);
+    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
+    op_data->input_offset = input_params->zero_point->data[0];
+    op_data->output_offset = output_params->zero_point->data[0];
+    if (input->type == kTfLiteInt16) {
+      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
+      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
+    }
+    const float input_scale = input_params->scale->data[0];
+    const float output_scale = output_params->scale->data[0];
+    op_data->needs_rescale = input_scale != output_scale;
+    if (op_nameid == kAbsNameId && op_data->needs_rescale) {
+      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+                             &op_data->shift);
+    } else if (op_nameid == kRsrqtNameId) {
+      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+                               &op_data->shift);
+    }
+  }
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return kTfLiteOk;
+}
+
 template <typename T>
-inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
-                             T func(T), TfLiteType expected_type) {
+inline TfLiteStatus EvalImplQuantized(
+    TfLiteContext* context, TfLiteNode* node,
+    T func(TfLiteContext*, TfLiteNode*, T),
+    TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
+    TfLiteType expected_type) {
   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
@@ -68,6 +180,34 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
   const T* in_data = tflite::micro::GetTensorData<T>(input);
   T* out_data = tflite::micro::GetTensorData<T>(output);
   for (size_t i = 0; i < num_elements; ++i) {
+    if (validate_input_func) {
+      TF_LITE_ENSURE_OK(context,
+                        validate_input_func(context, node, in_data[i]));
+    }
+    out_data[i] = func(context, node, in_data[i]);
+  }
+  return kTfLiteOk;
+}
+
+template <typename T>
+inline T AbsHelper(T i) {
+  return std::abs(i);
+}
+
+template <typename T>
+inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
+                             T func(T), TfLiteStatus validate_input_func(T),
+                             TfLiteType expected_type) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
+  const size_t num_elements = ElementCount(*input->dims);
+  const T* in_data = tflite::micro::GetTensorData<T>(input);
+  T* out_data = tflite::micro::GetTensorData<T>(output);
+  for (size_t i = 0; i < num_elements; ++i) {
+    if (validate_input_func) {
+      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
+    }
     out_data[i] = func(in_data[i]);
   }
   return kTfLiteOk;
@@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
 
 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                 float float_func(float)) {
-  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
+  return EvalImpl<float>(context, node, float_func,
+                         /*validate_input_func=*/nullptr, kTfLiteFloat32);
 }
 
 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
                                 bool bool_func(bool)) {
-  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
+  return EvalImpl<bool>(context, node, bool_func,
+                        /*validate_input_func=*/nullptr, kTfLiteBool);
+}
+
+void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
+                              size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
+}
+
+template <typename T>
+inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  const int kMin = std::numeric_limits<T>::min();
+  const int kMax = std::numeric_limits<T>::max();
+
+  const int32_t value = std::abs(i - op_data->input_offset);
+  if (!op_data->needs_rescale) {
+    return static_cast<T>(
+        std::min(std::max(static_cast<long int>(value + op_data->output_offset),
+                          static_cast<long int>(kMin)),
+                 static_cast<long int>(kMax)));
+  }
+
+  const int32_t output = tflite::MultiplyByQuantizedMultiplier(
+                             value, op_data->multiplier, op_data->shift) +
+                         op_data->output_offset;
+  return static_cast<T>(std::min(
+      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
+      static_cast<long int>(kMax)));
+}
+
+template <typename T>
+inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  const int kMin = std::numeric_limits<T>::min();
+  const int kMax = std::numeric_limits<T>::max();
+
+  const int32_t value = (i - op_data->input_offset);
+  const int32_t kShift = 20;  // Shift to keep value integer.
+  if (value == 0) {
+    // Assume that any value close to 0 represents the max output value.
+    return static_cast<T>(kMax);
+  }
+  int32_t inv_sqrt_multiplier;
+  int inv_sqrt_shift;
+  GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
+                                   &inv_sqrt_shift);
+  const int32_t data = tflite::MultiplyByQuantizedMultiplier(
+      static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
+  const int32_t output =
+      tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
+                                            op_data->shift - kShift) +
+      op_data->output_offset;
+  return static_cast<T>(std::min(
+      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
+      static_cast<long int>(kMax)));
+}
+
+template <typename T>
+TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
+                                 T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+
+  TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
+                     "Rsqrt is only defined for positive values");
+  return static_cast<TfLiteStatus>(kTfLiteOk);
 }
 
 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
-  return EvalNumeric(context, node, std::abs);
+  OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
+  TfLiteType type = op_data->input_type;
+  TfLiteQuantizationType input_quantization_type =
+      op_data->input_quantization_type;
+  TfLiteStatus eval_result;
+
+  switch (type) {
+    case kTfLiteFloat32:
+      eval_result = EvalNumeric(context, node, std::abs);
+      break;
+    case kTfLiteInt8:
+      eval_result =
+          EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
+                                    /*validate_input_func=*/nullptr, type);
+      break;
+    case kTfLiteInt16:
+      eval_result =
+          input_quantization_type == kTfLiteNoQuantization
+              ? EvalImpl<int16_t>(context, node, AbsHelper,
+                                  /*validate_input_func=*/nullptr, type)
+              : EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
+                                           /*validate_input_func=*/nullptr,
+                                           type);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+                         TfLiteTypeGetName(type));
+      return kTfLiteError;
+      break;
+  }
+  return eval_result;
 }
 
 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
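Quantized Abs, as wired above, works entirely in the integer domain: subtract the input zero point, take the absolute value, then rescale by input_scale/output_scale when the two scales differ. A worked example under assumed quantization parameters:

    // Assumed params: input scale 0.5, zero point -3; output scale 0.25,
    // zero point 0. For the int8 input q = -7:
    //   real input = 0.5 * (-7 - (-3)) = -2.0
    //   |real|     = 2.0
    //   output q   = 2.0 / 0.25 + 0 = 8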
@@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
-  return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  TfLiteType type = op_data->input_type;
+  switch (type) {
+    case kTfLiteFloat32:
+      return EvalImpl<float>(
+          context, node, [](float f) { return 1.f / std::sqrt(f); },
+          /*validate_input_func=*/nullptr, type);
+    case kTfLiteInt8:
+      return EvalImplQuantized<int8_t>(context, node,
+                                       elementwise::RsqrtEvalQuantized,
+                                       elementwise::validate_input_func, type);
+
+    default:
+      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+                         TfLiteTypeGetName(type));
+      return kTfLiteError;
+  }
 }
 
 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
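One reading of the kShift = 20 trick in RsqrtEvalQuantized (an interpretation, not upstream documentation): GetInvSqrtQuantizedMultiplierExp yields 1/sqrt(value) as a Q31 multiplier plus exponent, and scaling the intermediate up by 2^20 preserves precision through integer rounding before op_data->shift - kShift removes the factor again:

    // intermediate = (1/sqrt(value)) * 2^kShift            (kept as integer)
    // output       = intermediate * out_mult * 2^(out_shift - kShift) + zp

The validate_input_func hook rejects inputs below the zero point, since rsqrt is undefined for negative reals.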
@@ -119,99 +373,55 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace elementwise
 
 TfLiteRegistration Register_ABS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::AbsEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      elementwise::ElementWiseAbsRsqrtInit,
+      elementwise::PrepareAbsRsqrt<elementwise::IsAbsSupportedType,
+                                   elementwise::kAbsNameId>,
+      elementwise::AbsEval);
 }
 
 TfLiteRegistration Register_SIN() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SinEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SinEval);
 }
 
 TfLiteRegistration Register_COS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::CosEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::CosEval);
 }
 
 TfLiteRegistration Register_LOG() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::LogEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval);
 }
 
 TfLiteRegistration Register_SQRT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SqrtEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SqrtEval);
 }
 
 TfLiteRegistration Register_RSQRT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::RsqrtEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      elementwise::ElementWiseAbsRsqrtInit,
+      elementwise::PrepareAbsRsqrt<elementwise::IsRsqrtSupportedType,
+                                   elementwise::kRsrqtNameId>,
+      elementwise::RsqrtEval);
 }
 
 TfLiteRegistration Register_SQUARE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SquareEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SquareEval);
 }
 
 TfLiteRegistration Register_LOGICAL_NOT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
-          /*invoke=*/elementwise::LogicalNotEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+      elementwise::LogicalNotEval);
 }
 
 }  // namespace micro
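Taken together, the ABS and RSQRT registrations now carry all three callbacks, which is what lets the quantized paths keep per-node state. A sketch of the assumed interpreter flow (namespace per the file above):

    TfLiteRegistration abs_reg = tflite::ops::micro::Register_ABS();
    // init    -> ElementWiseAbsRsqrtInit allocates OpDataAbsRsqrt in the arena
    // prepare -> PrepareAbsRsqrt caches offsets/multiplier from quantization
    // invoke  -> AbsEval dispatches on the cached input type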
@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_ELU() {
-  return {/*init=*/EluInit,
-          /*free=*/nullptr,
-          /*prepare=*/EluPrepare,
-          /*invoke=*/EluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
 }
 
 }  // namespace tflite
@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_ADD() {
-  return {/*init=*/AddInit,
-          /*free=*/nullptr,
-          /*prepare=*/AddPrepare,
-          /*invoke=*/AddEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
 }
 
 }  // namespace tflite
@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
 #if ESP_NN
   if (input->type == kTfLiteInt8) {
+    data_dims_t input_dims = {
+        .width = input_width, .height = input_height,
+        .channels = input->dims->data[3], 1
+    };
+    data_dims_t output_dims = {
+        .width = output_width, .height = output_height,
+        .channels = output->dims->data[3], 1
+    };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    conv_params_t conv_params = {
+        .in_offset = 0, .out_offset = 0,
+        .stride = {params.stride_width, params.stride_height},
+        .padding = {data->op_data.padding.width, data->op_data.padding.height},
+        .dilation = {0, 0}, .activation = {-128, 127}
+    };
+
     int scratch_buf_size = esp_nn_get_conv_scratch_size(
-        input_width, input_height, input->dims->data[3],
-        output->dims->data[3], filter_width, filter_height);
+        &input_dims, &filter_dims, &output_dims, &conv_params);
    if (scratch_buf_size > 0) {
       TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
           context, scratch_buf_size, &data->buffer_idx));
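The esp-nn update replaces the long flat argument lists with parameter structs (data_dims_t, conv_params_t, quant_data_t, and the depthwise dw_conv_params_t below), so the scratch-size query and the compute call now take the same descriptors. One caveat worth noting: initializers such as {.width = w, .height = h, .channels = c, 1} mix designated and positional members, which GCC accepts as an extension. A fully designated form would presumably look like this (the fourth field's name is an assumption; see esp-nn's headers for the authoritative layout):

    data_dims_t input_dims = {
        .width = input_width, .height = input_height,
        .channels = input->dims->data[3], .extra = 1  // field name assumed
    };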
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
   const int input_size = input_width * input_height * input_depth;
   const int output_size = output_width * output_height * output_depth;
 
+  data_dims_t input_dims = {
+      .width = input_width, .height = input_height,
+      .channels = input_depth, 1
+  };
+  data_dims_t output_dims = {
+      .width = output_width, .height = output_height,
+      .channels = output_depth, 1
+  };
+  data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+  conv_params_t conv_params = {
+      .in_offset = input_offset, .out_offset = output_offset,
+      .stride = {stride_width, stride_height},
+      .padding = {pad_width, pad_height},
+      .dilation = {0, 0},
+      .activation = {activation_min, activation_max}
+  };
+  quant_data_t quant_data = {
+      .shift = data.op_data.per_channel_output_shift,
+      .mult = data.op_data.per_channel_output_multiplier
+  };
+
   for (int i_batch = 0; i_batch < batch_size; i_batch++) {
-    esp_nn_conv_s8(input_data + i_batch * input_size,
-                   input_width, input_height, input_depth, input_offset,
-                   pad_width, pad_height, stride_width, stride_height,
-                   tflite::micro::GetTensorData<int8_t>(filter),
-                   filter_width, filter_height,
+    esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
+                   &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
                    tflite::micro::GetTensorData<int32_t>(bias),
-                   output_data + i_batch * output_size,
-                   output_width, output_height, output_depth, output_offset,
-                   data.op_data.per_channel_output_shift,
-                   data.op_data.per_channel_output_multiplier,
-                   activation_min, activation_max);
+                   &output_dims, output_data + i_batch * output_size,
+                   &conv_params, &quant_data);
   }
 } else {
   reference_integer_ops::ConvPerChannel(
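The per-batch pointer stepping is unchanged by the API migration: tensors are NHWC, so one batch occupies width * height * channels contiguous elements. Sketch:

    // input_size = input_width * input_height * input_depth;
    // batch b starts at input_data + b * input_size, likewise for output_data.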
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                        TfLiteTypeGetName(input->type), input->type);
     return kTfLiteError;
   }
-  conv_total_time += esp_timer_get_time() - start_time;
+  long long time_this_instance = esp_timer_get_time() - start_time;
+  conv_total_time += time_this_instance;
+  //printf("time this instance: %llu\n", time_this_instance / 1000);
   return kTfLiteOk;
 }
 
 }  // namespace
 
 TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
   if (data.buffer_idx > -1) {
     scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
   }
 
   esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
 
+  data_dims_t input_dims = {
+      .width = input_width, .height = input_height,
+      .channels = input_depth, 1
+  };
+  data_dims_t output_dims = {
+      .width = output_width, .height = output_height,
+      .channels = output_depth, 1
+  };
+  data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+  dw_conv_params_t conv_params = {
+      .in_offset = input_offset, .out_offset = output_offset,
+      .ch_mult = depth_multiplier,
+      .stride = {stride_width, stride_height},
+      .padding = {pad_width, pad_height}, .dilation = {0, 0},
+      .activation = {activation_min, activation_max}
+  };
+  quant_data_t quant_data = {
+      .shift = data.op_data.per_channel_output_shift,
+      .mult = data.op_data.per_channel_output_multiplier
+  };
+
   for (int i_batch = 0; i_batch < batch_size; i_batch++) {
-    esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
-                             input_height, input_depth, input_offset,
-                             pad_width, pad_height,
-                             stride_width, stride_height, depth_multiplier,
-                             tflite::micro::GetTensorData<int8_t>(filter),
-                             filter_width, filter_height,
+    esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
+                             &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
                              tflite::micro::GetTensorData<int32_t>(bias),
-                             output_data + i_batch * output_size,
-                             output_width, output_height, output_offset,
-                             data.op_data.per_channel_output_shift,
-                             data.op_data.per_channel_output_multiplier,
-                             activation_min, activation_max);
+                             &output_dims, output_data + i_batch * output_size,
+                             &conv_params, &quant_data);
   }
 } else {
   reference_integer_ops::DepthwiseConvPerChannel(
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
 #if ESP_NN
   if (input->type == kTfLiteInt8) {
+    data_dims_t input_dims = {
+        .width = input_width, .height = input_height,
+        .channels = input->dims->data[3], 1
+    };
+    data_dims_t output_dims = {
+        .width = output_width, .height = output_height,
+        .channels = output->dims->data[3], 1
+    };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    dw_conv_params_t conv_params = {
+        .in_offset = 0, .out_offset = 0,
+        .ch_mult = params.depth_multiplier,
+        .stride = {params.stride_width, params.stride_height},
+        .padding = {data->op_data.padding.width, data->op_data.padding.height},
+        .dilation = {0, 0}, .activation = {-128, 127}
+    };
+
     int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
-        input_width, input_height, input->dims->data[3],
-        params.depth_multiplier, filter_width, filter_height);
+        &input_dims, &filter_dims, &output_dims, &conv_params);
     if (scratch_buf_size > 0) {
       TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
           context, scratch_buf_size, &data->buffer_idx));
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                        TfLiteTypeGetName(input->type), input->type);
     return kTfLiteError;
   }
-  dc_total_time += esp_timer_get_time() - start_time;
+  long long time_this_instance = esp_timer_get_time() - start_time;
+  dc_total_time += time_this_instance;
+  // printf("time this instance: %llu\n", time_this_instance / 1000);
+
   return kTfLiteOk;
 }
 
 }  // namespace
 
 TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_MUL() {
-  return {/*init=*/MulInit,
-          /*free=*/nullptr,
-          /*prepare=*/MulPrepare,
-          /*invoke=*/MulEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
 }
 
 }  // namespace tflite
@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 }  // namespace
 
 TfLiteRegistration Register_AVERAGE_POOL_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/PoolingPrepare,
-          /*invoke=*/AverageEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
 }
 
 TfLiteRegistration Register_MAX_POOL_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/PoolingPrepare,
-          /*invoke=*/MaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
 }
 
 }  // namespace tflite
@@ -0,0 +1,208 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/softmax.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#include "freertos/FreeRTOS.h"
+#include <esp_timer.h>
+
+#if ESP_NN
+#include <esp_nn.h>
+#endif
+
+long long softmax_total_time = 0;
+
+namespace tflite {
+namespace {
+// Softmax parameter data that persists in user_data
+const int kInt16LUTArraySize = 513;
+
+struct NodeData {
+  SoftmaxParams op_data;
+#if ESP_NN
+  int buffer_idx;
+#endif
+};
+
+static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(NodeData));
+}
+
+void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
+                      TfLiteEvalTensor* output, const NodeData* data) {
+  if (input->type == kTfLiteInt8) {
+    if (output->type == kTfLiteInt16) {
+      tflite::reference_ops::Softmax(
+          data->op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+    } else {
+#if ESP_NN
+      const int32_t input_beta_multiplier = data->op_data.input_multiplier;
+      const int32_t input_beta_left_shift = data->op_data.input_left_shift;
+      const int diff_min = data->op_data.diff_min;
+      const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+      const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+      const int trailing_dim = input_shape.DimensionsCount() - 1;
+      const int outer_size =
+          MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+      const int depth =
+          MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+      const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
+      int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
+      void *scratch_buf = NULL;
+      if (data->buffer_idx > -1) {
+        scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
+      }
+      esp_nn_set_softmax_scratch_buf(scratch_buf);
+      esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
+                        input_beta_left_shift, diff_min, out_ptr);
+#else
+      tflite::reference_ops::Softmax(
+          data->op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+#endif
+    }
+  } else {
+    tflite::reference_ops::SoftmaxInt16(
+        data->op_data, tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int16_t>(input),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int16_t>(output));
+  }
+}
+
+static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  NodeData data = *static_cast<NodeData*>(node->user_data);
+
+  long long start_time = esp_timer_get_time();
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Softmax(
+          data.op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+    }
+    break;
+    case kTfLiteInt8:
+    case kTfLiteInt16: {
+      SoftmaxQuantized(context, input, output, &data);
+    }
+    break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  softmax_total_time += esp_timer_get_time() - start_time;
+  return kTfLiteOk;
+}
+
+static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  MicroContext* micro_context = GetMicroContext(context);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
+
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  NodeData* data = static_cast<NodeData*>(node->user_data);
+  // Only allocate LUTs for KTfLiteInt16 data type
+  if (input->type == kTfLiteInt16) {
+    void* raw_exp_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
+    data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
+    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
+    data->op_data.one_over_one_plus_x_lut =
+        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
+  }
+
+  if (output->type == kTfLiteInt16) {
+    TF_LITE_ENSURE(context,
+                   input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  } else {
+    TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  }
+
+  // Populate LUT if required
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    // exp LUT only used on negative values
+    // we consider exp(-10.0) is insignificant to accumulation
+    gen_lut<float, int16_t, int16_t>(
+        [](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
+        data->op_data.exp_lut);
+    gen_lut<float, int16_t, int16_t>(
+        [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
+        1.0f, data->op_data.one_over_one_plus_x_lut);
+    data->op_data.zero_point = output->params.zero_point;
+    data->op_data.scale = output->params.scale;
+  }
+
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto ret_val =
+      CalculateSoftmaxParams(context, input, output, params, &data->op_data);
+
+#if ESP_NN
+  if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
+    const int32_t input_width = input->dims->data[1];
+    const int32_t input_height = input->dims->data[2];
+    int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
+                                                           input_height);
+    if (scratch_buf_size > 0) {
+      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+          context, scratch_buf_size, &data->buffer_idx));
+    }
+  }
+#endif
+
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return ret_val;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_SOFTMAX() {
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
+}
+
+}  // namespace tflite
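Like conv_total_time and dc_total_time above, the new softmax kernel accumulates its wall-clock time into an exported counter. A hedged usage sketch for firmware-side profiling (esp_timer_get_time counts microseconds):

    extern long long softmax_total_time;  // defined in the kernel above
    extern long long conv_total_time;     // assumed, from the conv kernel
    printf("softmax: %lld ms, conv: %lld ms\n",
           softmax_total_time / 1000, conv_total_time / 1000);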
@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_EXP() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_EXPAND_DIMS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FILL() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace floor
 
 TfLiteRegistration Register_FLOOR() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/floor::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
 }
 
 }  // namespace micro
@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FLOOR_DIV() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FLOOR_MOD() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
       node, kFullyConnectedOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
-
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
-  TF_LITE_ENSURE_MSG(context, input->type == filter->type,
-                     "Hybrid models are not supported on TFLite Micro.");
 
   TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
                                  context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
 
+    case kTfLiteInt16: {
+      const int64_t* bias_data =
+          nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
+                          : nullptr;
+
+      tflite::reference_integer_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int16_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias), bias_data,
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+      break;
+    }
+
     default: {
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
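The new `kTfLiteInt16` case above wires up fully connected inference with 16-bit activations, 8-bit weights and an optional 64-bit bias; the wide bias/accumulator type matters because sums of int16 × int8 products can overflow 32 bits at larger depths. A standalone sketch (not commit code) of that accumulation, including the same `nullptr` bias check:

```cpp
#include <cstdint>
#include <cstdio>

// Accumulates int16 activations against int8 weights into an int64, seeded
// with an optional bias, mirroring the types in the kTfLiteInt16 case above.
int64_t DotPlusBias(const int16_t* input, const int8_t* filter, int len,
                    const int64_t* bias /* may be nullptr */) {
  int64_t acc = (bias != nullptr) ? *bias : 0;
  for (int i = 0; i < len; ++i) {
    acc += static_cast<int64_t>(input[i]) * filter[i];
  }
  return acc;  // the real kernel then rescales and saturates this to int16
}

int main() {
  const int16_t input[3] = {1000, -2000, 3000};
  const int8_t filter[3] = {127, -128, 64};
  const int64_t bias = 50;
  std::printf("%lld\n",
              static_cast<long long>(DotPlusBias(input, filter, 3, &bias)));
  return 0;
}
```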
@@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
 }
 
 #endif
 
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16.
+TfLiteRegistration Register_FULLY_CONNECTED_INT16();
+
+#else
+// Note that while this block gets used for both reference and optimized kernels
+// that do not have any specialized implementations, the only goal here is to
+// define fallback implementation that allow reference kernels to still be used
+// from applications that call a more specific kernel variant.
+
+inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
+  return Register_FULLY_CONNECTED();
+}
+
+#endif
 
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
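With this hunk, applications can request the int16-only fully connected variant regardless of build flavor: under CMSIS-NN the declaration resolves to the optimized kernel, elsewhere the inline fallback simply returns the generic registration. A hedged usage sketch (assuming the TFLM resolver's `AddFullyConnected` overload that accepts a registration, as in this snapshot's `micro_mutable_op_resolver.h`; verify against your tree):

```cpp
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Registers the int16-specific FULLY_CONNECTED kernel variant; on non-CMSIS
// builds this transparently falls back to Register_FULLY_CONNECTED().
tflite::MicroMutableOpResolver<1> BuildResolver() {
  tflite::MicroMutableOpResolver<1> resolver;
  resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT16());
  return resolver;
}
```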
@@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_GATHER() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -195,14 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_GATHER_ND() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_HARD_SWISH() {
-  return {/*init=*/HardSwishInit,
-          /*free=*/nullptr,
-          /*prepare=*/tflite::HardSwishPrepare,
-          /*invoke=*/HardSwishEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare,
+                                   HardSwishEval);
 }
 
 }  // namespace tflite
@@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_IF() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite
@@ -15,9 +15,9 @@ limitations under the License.
 
 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
 
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/micro_arena_constants.h"
 #include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/test_helpers.h"
 
 namespace tflite {
@@ -30,7 +30,7 @@ uint8_t KernelRunner::kKernelRunnerBuffer_[];
 KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                            TfLiteTensor* tensors, int tensors_size,
                            TfLiteIntArray* inputs, TfLiteIntArray* outputs,
-                           void* builtin_data)
+                           void* builtin_data, TfLiteIntArray* intermediates)
     : registration_(registration),
       allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
                                                kKernelRunnerBuffer_,
@@ -54,6 +54,7 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
   node_.inputs = inputs;
   node_.outputs = outputs;
   node_.builtin_data = builtin_data;
+  node_.intermediates = intermediates;
 }
 
 bool KernelRunner::ValidateTempBufferDeallocated() {
@@ -18,9 +18,9 @@ limitations under the License.
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/fake_micro_context.h"
 #include "tensorflow/lite/micro/mock_micro_graph.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 namespace tflite {
 namespace micro {
@@ -35,7 +35,8 @@ class KernelRunner {
  public:
   KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
                int tensors_size, TfLiteIntArray* inputs,
-               TfLiteIntArray* outputs, void* builtin_data);
+               TfLiteIntArray* outputs, void* builtin_data,
+               TfLiteIntArray* intermediates = nullptr);
 
   // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
   // exceptions will be DebugLog'd and returned as a status code.
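The new trailing `intermediates` parameter defaults to `nullptr`, so every existing kernel test compiles unchanged; tests for ops that declare intermediate tensors (such as quantized LSTM) can now pass an index array through to `node_.intermediates`. A hedged sketch of a test-side call, where `registration`, `tensors`, `inputs` and `outputs` are assumed to be set up as in the existing TFLM kernel tests:

```cpp
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"

TfLiteStatus RunWithIntermediates(const TfLiteRegistration& registration,
                                  TfLiteTensor* tensors, int tensors_size,
                                  TfLiteIntArray* inputs,
                                  TfLiteIntArray* outputs) {
  // TfLiteIntArray layout: element count first, then the tensor indices.
  int intermediates_data[] = {1, 0};  // one intermediate tensor, at index 0
  TfLiteIntArray* intermediates =
      tflite::testing::IntArrayFromInts(intermediates_data);
  tflite::micro::KernelRunner runner(registration, tensors, tensors_size,
                                     inputs, outputs,
                                     /*builtin_data=*/nullptr, intermediates);
  TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
  return runner.Invoke();
}
```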
@@ -36,6 +36,21 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
 
 }  // namespace
 
+TfLiteRegistration RegisterOp(
+    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
+    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
+  return {/*init=*/init,
+          /*free=*/nullptr,
+          /*prepare=*/prepare,
+          /*invoke=*/invoke,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0,
+          /*registration_external=*/nullptr};
+}
+
 // Returns a mutable tensor for a given input index. is_variable must be checked
 // during prepare when the full TfLiteTensor is available.
 TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
@@ -27,6 +27,11 @@ limitations under the License.
 namespace tflite {
 namespace micro {
 
+TfLiteRegistration RegisterOp(
+    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
+    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
+
 // Returns a mutable tensor for a given input index. is_variable must be checked
 // during prepare when the full TfLiteTensor is available.
 TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
@@ -40,19 +45,33 @@ const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
 TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
                                 const TfLiteNode* node, int index);
 
-// Returns data for a TfLiteEvalTensor struct.
+// Returns data for a TfLiteEvalTensor struct that are expected to exist.
 template <typename T>
 T* GetTensorData(TfLiteEvalTensor* tensor) {
-  return tensor != nullptr ? reinterpret_cast<T*>(tensor->data.raw) : nullptr;
+  TFLITE_DCHECK(tensor != nullptr);
+  return reinterpret_cast<T*>(tensor->data.raw);
 }
 
-// Returns const data for a TfLiteEvalTensor struct.
+// Returns const data for a TfLiteEvalTensor struct that are expected to exist.
 template <typename T>
 const T* GetTensorData(const TfLiteEvalTensor* tensor) {
   TFLITE_DCHECK(tensor != nullptr);
   return reinterpret_cast<const T*>(tensor->data.raw);
 }
 
+// Returns data for a TfLiteEvalTensor struct that could be null.
+template <typename T>
+T* GetOptionalTensorData(TfLiteEvalTensor* tensor) {
+  return tensor == nullptr ? nullptr : reinterpret_cast<T*>(tensor->data.raw);
+}
+
+// Returns const data for a TfLiteEvalTensor struct that could be null.
+template <typename T>
+const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
+  return tensor == nullptr ? nullptr
+                           : reinterpret_cast<const T*>(tensor->data.raw);
+}
+
 // Returns the shape of a TfLiteEvalTensor struct.
 const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
 
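The tightened contract above splits tensor access in two: `GetTensorData` now asserts the tensor exists, while the new `GetOptionalTensorData` tolerates a missing (`nullptr`) tensor such as an omitted bias, which is exactly the check the fully_connected int16 hunk performs by hand. A standalone sketch of the distinction, using a stand-in struct rather than the real `TfLiteEvalTensor`:

```cpp
#include <cassert>
#include <cstdint>

struct FakeEvalTensor {  // stand-in for TfLiteEvalTensor's data union
  union { void* raw; } data;
};

// Mirrors the new GetOptionalTensorData template: a null tensor is legal.
template <typename T>
const T* GetOptionalTensorData(const FakeEvalTensor* tensor) {
  return tensor == nullptr ? nullptr
                           : reinterpret_cast<const T*>(tensor->data.raw);
}

// Mirrors the tightened GetTensorData template: a null tensor is a bug.
template <typename T>
const T* GetTensorData(const FakeEvalTensor* tensor) {
  assert(tensor != nullptr);  // TFLITE_DCHECK in the real header
  return reinterpret_cast<const T*>(tensor->data.raw);
}

int main() {
  const FakeEvalTensor* bias = nullptr;  // model built without a bias tensor
  const int64_t* bias_data = GetOptionalTensorData<int64_t>(bias);
  return (bias_data == nullptr) ? 0 : 1;  // safely yields nullptr
}
```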
@@ -136,14 +136,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_L2_POOL_2D() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/L2Prepare,
-          /*invoke=*/L2Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, L2Prepare, L2Eval);
 }
 
 }  // namespace tflite
@@ -137,14 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace l2norm
 
 TfLiteRegistration Register_L2NORM_REF() {
-  return {/*init=*/l2norm::Init,
-          /*free=*/nullptr,
-          /*prepare=*/l2norm::Prepare,
-          /*invoke=*/l2norm::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(l2norm::Init, l2norm::Prepare, l2norm::Eval);
 }
 
 TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }
@@ -88,14 +88,8 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_LEAKY_RELU() {
-  return {/*init=*/LeakyReluInit,
-          /*free=*/nullptr,
-          /*prepare=*/LeakyReluPrepare,
-          /*invoke=*/LeakyReluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(LeakyReluInit, LeakyReluPrepare,
+                                   LeakyReluEval);
 }
 
 }  // namespace tflite
@@ -142,14 +142,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_LOG_SOFTMAX() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/LogSoftmaxPrepare,
-          /*invoke=*/LogSoftmaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, LogSoftmaxPrepare, LogSoftmaxEval);
 }
 
 }  // namespace tflite
@@ -34,29 +34,11 @@ TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_LOGICAL_OR() {
-  // Init, Free, Prepare, Eval are satisfying the Interface required by
-  // TfLiteRegistration.
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/LogicalOrEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, LogicalOrEval);
 }
 
 TfLiteRegistration Register_LOGICAL_AND() {
-  // Init, Free, Prepare, Eval are satisfying the Interface required by
-  // TfLiteRegistration.
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/LogicalAndEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, LogicalAndEval);
 }
 
 }  // namespace tflite
@@ -106,13 +106,6 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_LOGISTIC() {
-  return {/*init=*/LogisticInit,
-          /*free=*/nullptr,
-          /*prepare=*/LogisticPrepare,
-          /*invoke=*/LogisticEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(LogisticInit, LogisticPrepare, LogisticEval);
 }
 }  // namespace tflite
File diff suppressed because it is too large
@@ -0,0 +1,250 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Pamameters for integer LSTM.
+// Consider split this into two Integer Parameters if more fields are added.
+struct IntegerLstmParameter {
+  int32_t effective_input_to_input_scale_a;
+  int32_t effective_input_to_input_scale_b;
+  int32_t effective_recurrent_to_input_scale_a;
+  int32_t effective_recurrent_to_input_scale_b;
+  int32_t effective_cell_to_input_scale_a;
+  int32_t effective_cell_to_input_scale_b;
+  int32_t effective_input_to_forget_scale_a;
+  int32_t effective_input_to_forget_scale_b;
+  int32_t effective_recurrent_to_forget_scale_a;
+  int32_t effective_recurrent_to_forget_scale_b;
+  int32_t effective_cell_to_forget_scale_a;
+  int32_t effective_cell_to_forget_scale_b;
+  int32_t effective_input_to_cell_scale_a;
+  int32_t effective_input_to_cell_scale_b;
+  int32_t effective_recurrent_to_cell_scale_a;
+  int32_t effective_recurrent_to_cell_scale_b;
+  int32_t effective_input_to_output_scale_a;
+  int32_t effective_input_to_output_scale_b;
+  int32_t effective_recurrent_to_output_scale_a;
+  int32_t effective_recurrent_to_output_scale_b;
+  int32_t effective_cell_to_output_scale_a;
+  int32_t effective_cell_to_output_scale_b;
+  int32_t effective_proj_scale_a;
+  int32_t effective_proj_scale_b;
+  int32_t effective_hidden_scale_a;
+  int32_t effective_hidden_scale_b;
+  int32_t layer_norm_input_scale_a;
+  int32_t layer_norm_input_scale_b;
+  int32_t layer_norm_forget_scale_a;
+  int32_t layer_norm_forget_scale_b;
+  int32_t layer_norm_cell_scale_a;
+  int32_t layer_norm_cell_scale_b;
+  int32_t layer_norm_output_scale_a;
+  int32_t layer_norm_output_scale_b;
+  // Quantized clip value for cell and projection. Zero value means no clipping.
+  int16_t quantized_cell_clip;
+  int8_t quantized_proj_clip;
+  int32_t hidden_zp;
+  int32_t cell_scale;
+
+  int32_t input_variance_guard;
+  int32_t forget_variance_guard;
+  int32_t cell_variance_guard;
+  int32_t output_variance_guard;
+
+  // Pre-calculate bias + zero_point * weight.
+  int32_t* input_to_forget_effective_bias;
+  int32_t* recurrent_to_forget_effective_bias;
+  int32_t* input_to_cell_effective_bias;
+  int32_t* recurrent_to_cell_effective_bias;
+  int32_t* input_to_output_effective_bias;
+  int32_t* recurrent_to_output_effective_bias;
+  int32_t* input_to_input_effective_bias;
+  int32_t* recurrent_to_input_effective_bias;
+  int32_t* projection_effective_bias;
+
+  // Scale and zero point for intermediate tensors.
+  // Used only in the 8x8_8 case.
+  int32_t intermediate_scale_a[8];
+  int32_t intermediate_scale_b[8];
+  int32_t intermediate_zp[12];
+};
+
+// Scales for hybrid op with integer inputs and float weights
+struct HybridLstmScales {
+  float input_to_input_weights_scale;
+  float input_to_forget_weights_scale;
+  float input_to_cell_weights_scale;
+  float input_to_output_weights_scale;
+  float aux_input_to_input_weights_scale;
+  float aux_input_to_forget_weights_scale;
+  float aux_input_to_cell_weights_scale;
+  float aux_input_to_output_weights_scale;
+  float recurrent_to_input_weights_scale;
+  float recurrent_to_forget_weights_scale;
+  float recurrent_to_cell_weights_scale;
+  float recurrent_to_output_weights_scale;
+  float cell_to_input_weights_scale;
+  float cell_to_forget_weights_scale;
+  float cell_to_output_weights_scale;
+  float projection_weights_scale;
+};
+
+TfLiteStatus EvalFloatLstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
+
+TfLiteStatus EvalHybridLstm(
+    const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_input_weights_ledger,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_forget_weights_ledger,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_cell_weights_ledger,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* input_to_output_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_weights_ledger,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* prod_scaling_factors,
+    float* recovered_cell_weights, int8_t* input_quantized,
+    int8_t* aux_input_quantized, int8_t* output_state_quantized,
+    int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
+    TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums);
+
+TfLiteStatus EvalInteger8x8_16Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major,
+    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
+    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5);
+
+TfLiteStatus EvalInteger8x8_8Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
+    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7);
+
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
Some files were not shown because too many files have changed in this diff