Mirror of https://github.com/jomjol/AI-on-the-edge-device.git (synced 2025-12-07 20:16:55 +03:00)
Compare commits
62 Commits
| SHA1 |
|---|
| 87a6445ff6 |
| b7e6d33d48 |
| 52e9cd20ee |
| b34bd5d988 |
| 58d5e7bc58 |
| 1e09bfbb80 |
| 91fa1c066c |
| 10d49b55d1 |
| 67d0bf6a27 |
| d36cbde7aa |
| 016f4088d4 |
| c2d1bbb4be |
| bc6a01444a |
| 24f0902194 |
| 19fd6a10dd |
| a45a5296e4 |
| 1459bb15c1 |
| ce5f3c463b |
| 04ebbf35e7 |
| ba1d6e30e2 |
| e9ac8933f9 |
| ec96b7f878 |
| ba7d429178 |
| 79be2089be |
| ea2305de47 |
| 635b2c35a8 |
| afdc4bb3f1 |
| 3d49ec72ba |
| 520f818adc |
| 20b054472e |
| 21a70c5655 |
| 08270f5d6d |
| 9923be2f1d |
| 5df57c95d4 |
| d8c91466d0 |
| 37b2e370fe |
| 98dfba0640 |
| 574c9084c2 |
| 9862ae8e7a |
| 7bc4e63209 |
| ad40150cfa |
| 970530d99f |
| c6ae989b82 |
| a0ebf354b1 |
| 97ecbc792e |
| 5934a59489 |
| ee18046581 |
| 1e4e38c02f |
| 7a3038eceb |
| 7d2f86b72e |
| 3aaa319505 |
| f4075f0a51 |
| 59643a8d52 |
| baf2a880e4 |
| d71e8320c7 |
| 3b3d924f40 |
| 60701bc007 |
| 5ca3e184e0 |
| 2903d1a0a6 |
| 5f0f1802a4 |
| 5be56d9b00 |
| d3fd1b5045 |
206 Changelog.md

@@ -1,6 +1,208 @@

# Versions

##### 6.7.2 Image Processing in Memory - (2021-05-01)

* NEW 6.7.2: Updated HTML for setup mode (removed reboot on edit configuration)
* NEW 6.7.1: Improved stability of camera (back to v6.6.1) - removes black strips and areas
* Upgrade digital CNN to v8.3.0 (added new type of digits)
* Internal update: TFlite (v2.5), esp32cam, startup sequence
* Rollback to espressif v2.1.0, as v3.2.0 shows unstable reboots
* Bugfix: WLAN passwords, reset of hostname

##### 6.6.1 Image Processing in Memory - (2021-04-05)

* NEW 6.6.1: Failed SD-card initialization is indicated by a fast-blinking LED at startup
* Improved SD-card handling (increased compatibility with more types of cards)

##### 6.5.0 Image Processing in Memory - (2021-03-25)

* Upgrade digital CNN to v8.2.0 (added new type of digits)
* Support for alignment structures in the ROI definition
* Bug fixing: definition of hostname in `config.ini`

##### 6.4.0 Image Processing in Memory - (2021-03-20)

* Additional alignment marks for setting the ROIs (analog and digital)
* Upgrade analog CNN to v7.0.0 (added new type of pointer)

##### 6.3.1 Image Processing in Memory - (2021-03-16)

* NEW 6.3.1: Bug fixes in the initial edit of the reference image and `config.ini` (spelling error in `InitialRotate`)
* Initial setup mode: bug fixing, error correction
* Bug fixing

##### 6.2.2 Image Processing in Memory - (2021-03-10)

* NEW 6.2.2: bug fixing
* NEW 6.2.1: Changed brightness and contrast to defaults if not enabled (resolves too-bright images)
* Determination of fixed illumination settings during startup - speed-up of 5 s in each run
* Update digital CNN to v8.1.1 (additional digital images trained)
* Extended error details in the MQTT error message
* Image brightness is now adjustable
* Bug fixing: minor topics

##### 6.1.0 Image Processing in Memory - (2021-01-20)

* Disabling of analog / digital counters in the configuration
* Improved alignment algorithm (`AlignmentAlgo` = `Default`, `Accurate`, `Fast`)
* Analog counters: `ExtendedResolution` (last digit is extended by the sub-comma value of the CNN)
* `config.ini`: additional parameter `hostname` (in addition to wlan.ini)
* Switching of GPIO12/13 via the HTTP interface: `/GPIO?GPIO=12&Status=high/low`
* Bug fixing: HTML configuration page, WLAN password ("=" now possible)

##### 6.0.0 Image Processing in Memory - (2021-01-02)

* **Major change**: image processing fully in memory - no need for an SD-card buffer anymore
* Camera resolution needs to be limited to VGA (due to memory limits)
* MQTT: Last Will and Testament (LWT) implemented: "connection lost" is published to `TopicError` if the connection is lost
* Disabled `CheckDigitIncreaseConsistency` in the default configuration - must now be explicitly enabled if needed
* Update digital CNN to v7.2.1 (additional digital images trained)
* Setting of an arbitrary time server in `config.ini`
* Option for fixed IP and DNS settings in `wlan.ini`
* Increased stability (internal image and camera handling)
* Bug fixing: edit digits, handling of PreValue, HTML bugs

##### 5.0.0 Setup Modus - (2020-12-06)

* Implementation of an initial setup mode for fresh installations
* Code restructuring (full compatibility between pure ESP-IDF and PlatformIO w/ espressif)

##### 4.1.1 Configuration editor - (2020-12-02)

* Bug fixing: internal improvement of file handling (reduces unresponsiveness)

##### 4.1.0 Configuration editor - (2020-11-30)

* Implementation of a configuration editor (including basic and expert mode)
* Adjustable time zone to match the local time setting (incl. daylight saving time)
* MQTT: additional topic for error reporting
* Standardized access to the current logfile via `http://IP-ADDRESS/logfileact`
* Update digital CNN to v7.2.0, analog CNN to v6.3.0
* Bug fixing: truncation error, CheckDigitConsistency & PreValue implementation

##### 4.0.0 Tflite Core - (2020-11-15)

* Implementation of rolling log files
* Update Tflite core to master@20201108 (v2.4)
* Bug fixing to reduce reboots

##### 3.1.0 MQTT-Client - (2020-10-26)

* Update digital CNN to v6.5.0 and HTML (info on hostname, IP, SSID)
* New implementation of "checkDigitConsistency" also for digits
* MQTT adapter: user and password for signing in to the MQTT broker

##### 3.0.0 MQTT-Client (2020-10-14)

* Implementation of the MQTT client
* Improved version control
* Bug fixing

##### 2.2.1 Version Control (2020-09-27)

* Bug fixing (hostname in wlan.ini and error handling inside the flow)

##### 2.2.0 Version Control (2020-09-27)

* Integrated automated versioning system (menu: SYSTEM --> INFO)
* Update of the build system to PlatformIO - Espressif 32 v2.0.0 (ESP-IDF 4.1)

##### 2.1.0 Decimal Shift, Chrome & Edge (2020-09-25)

* Implementation of decimal shift
* Update default CNN for digits to v6.4.0
* HTML improvements
* Support for Chrome and Edge
* Reduced logging to a minimum - extended logging on demand
* Implementation of hostname in wlan.ini (`hostname = "HOSTNAME"`)
* Bug fixing, code corrections

##### 2.0.0 Layout update (2020-09-12)

* Update to a **new and modern layout**
* Improved support for Chrome
* Improved robustness: better error handling in the auto flow reduces spontaneous reboots
* File server: option for "DELETE ALL"
* WLAN: support of spaces in SSID and password
* Reference image: option for mirroring, option for image update on the fly
* Additional parameter `wasserzaehler.html?noerror=true` to suppress a potential error message
* Bug fixing

##### 1.1.3 (2020-09-09)

* **Bug in configuration of analog ROIs corrected** - the correction in v1.0.2 did not work properly
* Improved update page for the web server (`/html` can be updated via a zip file, which is provided in `/firmware/html.zip`)
* Improved Chrome support

##### 1.1.0 (2020-09-06)

* Implementation of "delete complete directory"

**Attention: besides the `firmware.bin`, the content of `/html` also needs to be updated!**

##### 1.0.2 (2020-09-06)

* Bug in configuration of analog ROIs corrected
* Minor bug corrections

##### 1.0.1 (2020-09-05)

* Bug in preValue.ini corrected
* Minor bug corrections

##### 1.0.0 (2020-09-04)

* **First usable version** - compatible with the previous project (https://github.com/jomjol/water-meter-system-complete)
* NEW:
  * no docker container for the CNN calculation necessary
  * web-based configuration editor on board

##### 0.1.0 (2020-08-07)

* Initial Version
94 FeatureRequest.md (new file)

@@ -0,0 +1,94 @@

## Feature Requests

**There are a lot of ideas for further improvements, but only limited capacity on the side of the developer.** Therefore I have created this page as a collection of ideas.

1. Whoever has a new idea can put it here, so that it is not forgotten.

2. Whoever has the time, capacity and passion to support, can take any of the ideas and implement them.
   I will support and help wherever I can!

____

#### #6 Check for double ROI names

Check during configuration that ROI names are unique.

To do:

* Implementation of ROI name checking in the HTML code before saving analog or digital ROIs

#### #5 Configurable decimal separator (point or comma)

Make the decimal separator configurable for different systems.

To do:

* Implementation of the decimal point in the postprocessing module
* Extension of the configuration
* Adaptation of the HTML configuration

#### #4 Initial Shifting and Rotation

* https://github.com/jomjol/AI-on-the-edge-device/issues/123

Implementation of a shift in addition to the initial rotation of the raw camera input.

To do:

* Implementation of the shifting
* Extension of the configuration
* Adaptation of the HTML configuration to implement shifting

#### #3 Allow grouping of digits to multiple reading values

* https://github.com/jomjol/AI-on-the-edge-device/issues/123

Implementation of two different, independent readouts in one setup.

To do:

* Extend the configuration, settings and processing flow for two independent readouts

____

#### #2 MQTT control with callback

* https://github.com/jomjol/AI-on-the-edge-device/issues/105

Extend the MQTT client to also enable callbacks for configuration settings.

To do:

* Implement a callback for receiving information and overriding `config.ini` settings
* Change the configuration management to handle online updates (currently changes need a restart)
* Think about the startup, as the default config is loaded there

____

#### #1 Optional GPIO for external flash/lighting

* https://github.com/jomjol/AI-on-the-edge-device/issues/133

Implementation of an external flash / lighting through GPIOs (a rough sketch follows below).

* Available GPIOs: 12 & 13 (currently in use for HTML switching)

To do:

* Implementation of a software module for an external light source (e.g. WS8132 LED controller, ...)
* Update of the camera module to use the external light instead of the internal flash light
* Adapt the configuration algorithm to a configurable light source
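For idea #1, a minimal sketch of how such an external light could be driven with ESP-IDF's LEDC (PWM) peripheral, assuming GPIO 12 as named above; the function names, frequency and duty values are illustrative assumptions, not part of this project:

```cpp
#include <stdint.h>
#include "driver/ledc.h"

// Hypothetical external-light module: dimmable LED on GPIO 12 via LEDC PWM.
void external_light_init(void)
{
    ledc_timer_config_t tcfg = {};
    tcfg.speed_mode      = LEDC_LOW_SPEED_MODE;
    tcfg.duty_resolution = LEDC_TIMER_8_BIT;     // duty range 0..255
    tcfg.timer_num       = LEDC_TIMER_0;
    tcfg.freq_hz         = 5000;
    tcfg.clk_cfg         = LEDC_AUTO_CLK;
    ledc_timer_config(&tcfg);

    ledc_channel_config_t ccfg = {};
    ccfg.gpio_num   = 12;                        // one of the free GPIOs named above
    ccfg.speed_mode = LEDC_LOW_SPEED_MODE;
    ccfg.channel    = LEDC_CHANNEL_0;
    ccfg.timer_sel  = LEDC_TIMER_0;
    ccfg.duty       = 0;                         // start switched off
    ledc_channel_config(&ccfg);
}

void external_light_set(uint8_t duty)            // 0 = off, 255 = full brightness
{
    ledc_set_duty(LEDC_LOW_SPEED_MODE, LEDC_CHANNEL_0, duty);
    ledc_update_duty(LEDC_LOW_SPEED_MODE, LEDC_CHANNEL_0);
}
```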
185 README.md

```diff
@@ -4,7 +4,9 @@ This is an example of Artificial Intelligence (AI) calculations on a very cheap

 ### Details on **function**, **installation** and **configuration** can be found on the **[Wiki Page](https://github.com/jomjol/AI-on-the-edge-device/wiki)**

-A 3d-printable housing can be found here: https://www.thingiverse.com/thing:4571627
+A 3d-printable housing can be found here: https://www.thingiverse.com/thing:4573481
+
+or the ESP32-Cam housing only: https://www.thingiverse.com/thing:4571627

 <img src="https://raw.githubusercontent.com/jomjol/AI-on-the-edge-device/master/images/watermeter_all.jpg" width="200"><img src="https://raw.githubusercontent.com/jomjol/AI-on-the-edge-device/master/images/main.jpg" width="200"><img src="https://raw.githubusercontent.com/jomjol/AI-on-the-edge-device/master/images/size.png" width="200">
```
```diff
@@ -24,7 +26,9 @@ If you would like to support the developer with a cup of coffee you can do that

 <input type="image" src="https://www.paypalobjects.com/en_US/DK/i/btn/btn_donateCC_LG.gif" border="0" name="submit" title="PayPal - The safer, easier way to pay online!" alt="Donate with PayPal button" />
 <img alt="" border="0" src="https://www.paypal.com/en_DE/i/scr/pixel.gif" width="1" height="1" />
 </form>

+If you have any technical topics, you can file an issue in this repository.
+
+In other cases you can contact the developer via email: <img src="https://raw.githubusercontent.com/jomjol/AI-on-the-edge-device/master/images/mail.jpg" height="25">

 ## Change log
```
```diff
@@ -41,180 +45,49 @@ If you would like to support the developer with a cup of coffee you can do that

-##### 6.2.2 Image Processing in Memory - (2021-03-10)
+##### 7.0.1 MQTT-Update - (2021-05-13)

-* NEW 6.2.2: bug fixing
-* NEW 6.2.1: Changed brightness and contrast to defaults if not enabled (resolves too-bright images)
-* Determination of fixed illumination settings during startup - speed-up of 5 s in each run
-* Update digital CNN to v8.1.1 (additional digital images trained)
-* Extended error details in the MQTT error message
-* Image brightness is now adjustable
-* Bug fixing: minor topics
+* NEW 7.0.1: bug fix for WLAN passwords containing "="
+* Upgrade digital CNN to v8.5.0 (added new images)
+* New MQTT topics: flow rate (units/minute), time stamp (last correct readout)
+* MQTT error topic now publishes " " in case of no error (instead of an empty string)
+* Portrait or landscape image orientation in the rotated image (avoids cropping)

-##### 6.1.0 Image Processing in Memory - (2021-01-20)
-
-* Disabling of analog / digital counters in the configuration
-* Improved alignment algorithm (`AlignmentAlgo` = `Default`, `Accurate`, `Fast`)
-* Analog counters: `ExtendedResolution` (last digit is extended by the sub-comma value of the CNN)
-* `config.ini`: additional parameter `hostname` (in addition to wlan.ini)
-* Switching of GPIO12/13 via the HTTP interface: `/GPIO?GPIO=12&Status=high/low`
-* Bug fixing: HTML configuration page, WLAN password ("=" now possible)
-
-##### 6.0.0 Image Processing in Memory - (2021-01-02)
-
-* **Major change**: image processing fully in memory - no need for an SD-card buffer anymore
-* Camera resolution needs to be limited to VGA (due to memory limits)
-* MQTT: Last Will and Testament (LWT) implemented: "connection lost" is published to `TopicError` if the connection is lost
-* Disabled `CheckDigitIncreaseConsistency` in the default configuration - must now be explicitly enabled if needed
-* Update digital CNN to v7.2.1 (additional digital images trained)
-* Setting of an arbitrary time server in `config.ini`
-* Option for fixed IP and DNS settings in `wlan.ini`
-* Increased stability (internal image and camera handling)
-* Bug fixing: edit digits, handling of PreValue, HTML bugs
+## Additional ideas
+
+There are some ideas and feature requests which are currently not being followed up - mainly due to capacity reasons on the side of the developer. They are collected here: [FeatureRequest.md](FeatureRequest.md)
+
+------
+
+## History
+
+##### 6.7.2 Image Processing in Memory - (2021-05-01)

 ##### 5.0.0 Setup Modus - (2020-12-06)

-* Implementation of an initial setup mode for fresh installations
-* Code restructuring (full compatibility between pure ESP-IDF and PlatformIO w/ espressif)
-
 ##### 4.1.1 Configuration editor - (2020-12-02)

-* Bug fixing: internal improvement of file handling (reduces unresponsiveness)
-
-##### 4.1.0 Configuration editor - (2020-11-30)
-
-* Implementation of a configuration editor (including basic and expert mode)
-* Adjustable time zone to match the local time setting (incl. daylight saving time)
-* MQTT: additional topic for error reporting
-* Standardized access to the current logfile via `http://IP-ADDRESS/logfileact`
-* Update digital CNN to v7.2.0, analog CNN to v6.3.0
-* Bug fixing: truncation error, CheckDigitConsistency & PreValue implementation
-
 ##### 4.0.0 Tflite Core - (2020-11-15)

-* Implementation of rolling log files
-* Update Tflite core to master@20201108 (v2.4)
-* Bug fixing to reduce reboots
-
 ##### 3.1.0 MQTT-Client - (2020-10-26)

-* Update digital CNN to v6.5.0 and HTML (info on hostname, IP, SSID)
-* New implementation of "checkDigitConsistency" also for digits
-* MQTT adapter: user and password for signing in to the MQTT broker
-
-##### 3.0.0 MQTT-Client (2020-10-14)
-
-* Implementation of the MQTT client
-* Improved version control
-* Bug fixing
-
-##### 2.2.1 Version Control (2020-09-27)
+##### 2.2.1 Version Control - (2020-09-27)

-* Bug fixing (hostname in wlan.ini and error handling inside the flow)
-
-##### 2.2.0 Version Control (2020-09-27)
-
-* Integrated automated versioning system (menu: SYSTEM --> INFO)
-* Update of the build system to PlatformIO - Espressif 32 v2.0.0 (ESP-IDF 4.1)
-
-##### 2.1.0 Decimal Shift, Chrome & Edge (2020-09-25)
+##### 2.1.0 Decimal Shift, Chrome & Edge - (2020-09-25)

-* Implementation of decimal shift
-* Update default CNN for digits to v6.4.0
-* HTML improvements
-* Support for Chrome and Edge
-* Reduced logging to a minimum - extended logging on demand
-* Implementation of hostname in wlan.ini (`hostname = "HOSTNAME"`)
-* Bug fixing, code corrections
-
-##### 2.0.0 Layout update (2020-09-12)
+##### 2.0.0 Layout update - (2020-09-12)

-* Update to a **new and modern layout**
-* Improved support for Chrome
-* Improved robustness: better error handling in the auto flow reduces spontaneous reboots
-* File server: option for "DELETE ALL"
-* WLAN: support of spaces in SSID and password
-* Reference image: option for mirroring, option for image update on the fly
-* Additional parameter `wasserzaehler.html?noerror=true` to suppress a potential error message
-* Bug fixing
-
-##### 1.1.3 (2020-09-09)
+##### 1.1.3 Initial Version - (2020-09-09)

-* **Bug in configuration of analog ROIs corrected** - the correction in v1.0.2 did not work properly
-* Improved update page for the web server (`/html` can be updated via a zip file, which is provided in `/firmware/html.zip`)
-* Improved Chrome support
-
-##### 1.1.0 (2020-09-06)
-
-* Implementation of "delete complete directory"
-
-**Attention: besides the `firmware.bin`, the content of `/html` also needs to be updated!**
-
-##### 1.0.2 (2020-09-06)
-
-* Bug in configuration of analog ROIs corrected
-* Minor bug corrections
-
-##### 1.0.1 (2020-09-05)
-
-* Bug in preValue.ini corrected
-* Minor bug corrections
-
-##### 1.0.0 (2020-09-04)
-
-* **First usable version** - compatible with the previous project (https://github.com/jomjol/water-meter-system-complete)
-* NEW:
-  * no docker container for the CNN calculation necessary
-  * web-based configuration editor on board
-
-##### 0.1.0 (2020-08-07)
-
-* Initial Version

 #### [Full Changelog](Changelog.md)
```
||||||
@@ -223,4 +96,4 @@ If you would like to support the developer with a cup of coffee you can do that
|
|||||||
|
|
||||||
## Solved topics
|
## Solved topics
|
||||||
|
|
||||||
* n.a.
|
* n.a.
|
||||||
@@ -1,539 +0,0 @@

```cpp
#include "connect_wlan.h"

#include <string.h>
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "freertos/event_groups.h"
#include "esp_wifi.h"
#include "esp_log.h"

#include <fstream>
#include <vector>
#include <sstream>

#include "Helper.h"

static const char *TAG = "connect_wlan";

std::string ssid = "";
std::string passphrase = "";
std::string hostname = "";
std::string ipaddress = "";
std::string gw = "";
std::string netmask = "";
std::string dns = "";

std::string std_hostname = "watermeter";

#define BLINK_GPIO GPIO_NUM_33

static EventGroupHandle_t s_wifi_event_group;

#define WIFI_CONNECTED_BIT BIT0
#define WIFI_FAIL_BIT      BIT1

static int s_retry_num = 0;

// Splits a line into tokens ("ZerlegeZeile" = "split line"); by default it
// splits on spaces, '=' and ','; an explicit delimiter set can be passed in.
std::vector<string> ZerlegeZeile(std::string input, std::string _delimiter = "")
{
    std::vector<string> Output;
    std::string delimiter = " =,";
    if (_delimiter.length() > 0){
        delimiter = _delimiter;
    }

    input = trim(input, delimiter);
    size_t pos = findDelimiterPos(input, delimiter);
    std::string token;
    while (pos != std::string::npos) {
        token = input.substr(0, pos);
        token = trim(token, delimiter);
        Output.push_back(token);
        input.erase(0, pos + 1);
        input = trim(input, delimiter);
        pos = findDelimiterPos(input, delimiter);
    }
    Output.push_back(input);

    return Output;
}
```
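A minimal usage sketch of `ZerlegeZeile` (hypothetical input; it assumes `trim` and `findDelimiterPos` from `Helper.h` behave as used above):

```cpp
// Splitting a typical wlan.ini line on '=':
std::vector<string> parts = ZerlegeZeile("hostname = \"watermeter\"", "=");
// parts[0] is roughly "hostname " and parts[1] roughly " \"watermeter\"";
// the callers below trim the spaces and strip the surrounding quotes.
```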
```cpp
// Blinks the status LED _anzahl times, each on/off phase lasting dauer ms
// ("dauer" = duration, "anzahl" = count).
void blinkstatus(int dauer, int _anzahl)
{
    gpio_reset_pin(BLINK_GPIO);
    gpio_set_direction(BLINK_GPIO, GPIO_MODE_OUTPUT);
    for (int i = 0; i < _anzahl; ++i)
    {
        gpio_set_level(BLINK_GPIO, 0);
        vTaskDelay(dauer / portTICK_PERIOD_MS);
        gpio_set_level(BLINK_GPIO, 1);
        vTaskDelay(dauer / portTICK_PERIOD_MS);
    }
}

// Parses a dotted-quad IPv4 string into its four octets.
void strinttoip4(std::string ip, int &a, int &b, int &c, int &d) {
    std::stringstream s(ip);
    char ch; // temporarily stores the '.'
    s >> a >> ch >> b >> ch >> c >> ch >> d;
}
```
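A quick sketch of `strinttoip4` in use (values illustrative); the four octets feed lwIP's `IP4_ADDR` macro as in `initialise_wifi_fixed_ip2()` below:

```cpp
int a, b, c, d;
strinttoip4("192.168.178.42", a, b, c, d);
// a == 192, b == 168, c == 178, d == 42
```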
```cpp
// WiFi/IP event handler ("neu" = new): the blink codes signal the connection
// state (1 short blink = STA started, 5 short blinks = disconnected/retrying,
// 3 long blinks = got an IP address).
static void event_handler_neu(void* arg, esp_event_base_t event_base,
                              int32_t event_id, void* event_data)
{
    if (event_base == WIFI_EVENT && event_id == WIFI_EVENT_STA_START) {
        blinkstatus(200, 1);
        esp_wifi_connect();
    } else if (event_base == WIFI_EVENT && event_id == WIFI_EVENT_STA_DISCONNECTED) {
        blinkstatus(200, 5);
        esp_wifi_connect();
        s_retry_num++;
        ESP_LOGI(TAG, "retry to connect to the AP");
    } else if (event_base == IP_EVENT && event_id == IP_EVENT_STA_GOT_IP) {
        blinkstatus(1000, 3);
        ip_event_got_ip_t* event = (ip_event_got_ip_t*) event_data;
        ESP_LOGI(TAG, "got ip:" IPSTR, IP2STR(&event->ip_info.ip));
        s_retry_num = 0;
        xEventGroupSetBits(s_wifi_event_group, WIFI_CONNECTED_BIT);
    }
}

// Connects as a WiFi station with a DHCP-assigned (dynamic) IP.
void initialise_wifi()
{
    s_wifi_event_group = xEventGroupCreate();
    ESP_ERROR_CHECK(esp_netif_init());
    ESP_ERROR_CHECK(esp_event_loop_create_default());
    esp_netif_create_default_wifi_sta();

    wifi_init_config_t cfg = WIFI_INIT_CONFIG_DEFAULT();
    ESP_ERROR_CHECK(esp_wifi_init(&cfg));

    esp_event_handler_instance_t instance_any_id;
    esp_event_handler_instance_t instance_got_ip;
    ESP_ERROR_CHECK(esp_event_handler_instance_register(WIFI_EVENT,
                                                        ESP_EVENT_ANY_ID,
                                                        &event_handler_neu,
                                                        NULL,
                                                        &instance_any_id));
    ESP_ERROR_CHECK(esp_event_handler_instance_register(IP_EVENT,
                                                        IP_EVENT_STA_GOT_IP,
                                                        &event_handler_neu,
                                                        NULL,
                                                        &instance_got_ip));

    wifi_config_t wifi_config = { };
    strcpy((char*)wifi_config.sta.ssid, (const char*)ssid.c_str());
    strcpy((char*)wifi_config.sta.password, (const char*)passphrase.c_str());

    ESP_ERROR_CHECK(esp_wifi_set_mode(WIFI_MODE_STA));
    ESP_ERROR_CHECK(esp_wifi_set_config(ESP_IF_WIFI_STA, &wifi_config));
    ESP_ERROR_CHECK(esp_wifi_start());

    ESP_LOGI(TAG, "wifi_init_sta finished.");

    // Waiting until either the connection is established (WIFI_CONNECTED_BIT) or the
    // connection failed for the maximum number of retries (WIFI_FAIL_BIT).
    // The bits are set by event_handler_neu() (see above).
    EventBits_t bits = xEventGroupWaitBits(s_wifi_event_group,
                                           WIFI_CONNECTED_BIT | WIFI_FAIL_BIT,
                                           pdFALSE,
                                           pdFALSE,
                                           portMAX_DELAY);

    // xEventGroupWaitBits() returns the bits before the call returned, hence we can
    // test which event actually happened.
    if (bits & WIFI_CONNECTED_BIT) {
        ESP_LOGI(TAG, "connected to ap SSID:%s password:%s",
                 ssid.c_str(), passphrase.c_str());
    } else if (bits & WIFI_FAIL_BIT) {
        ESP_LOGI(TAG, "Failed to connect to SSID:%s, password:%s",
                 ssid.c_str(), passphrase.c_str());
    } else {
        ESP_LOGE(TAG, "UNEXPECTED EVENT");
    }

    // The event will not be processed after unregistering.
    ESP_ERROR_CHECK(esp_event_handler_instance_unregister(IP_EVENT, IP_EVENT_STA_GOT_IP, instance_got_ip));
    ESP_ERROR_CHECK(esp_event_handler_instance_unregister(WIFI_EVENT, ESP_EVENT_ANY_ID, instance_any_id));
    vEventGroupDelete(s_wifi_event_group);

    tcpip_adapter_ip_info_t ip_info;
    ESP_ERROR_CHECK(tcpip_adapter_get_ip_info(TCPIP_ADAPTER_IF_STA, &ip_info));
    ipaddress = std::string(ip4addr_ntoa(&ip_info.ip));
    netmask = std::string(ip4addr_ntoa(&ip_info.netmask));
    gw = std::string(ip4addr_ntoa(&ip_info.gw));
    printf("IPv4 : %s\n", ip4addr_ntoa(&ip_info.ip));
    printf("HostName : %s\n", hostname.c_str());
}

// Connects as a WiFi station with a fixed (static) IP configuration taken from wlan.ini.
void initialise_wifi_fixed_ip2()
{
    s_wifi_event_group = xEventGroupCreate();
    ESP_ERROR_CHECK(esp_netif_init());
    ESP_ERROR_CHECK(esp_event_loop_create_default());
    esp_netif_t *my_sta = esp_netif_create_default_wifi_sta();

    esp_netif_dhcpc_stop(my_sta);   // disable the DHCP client before assigning the static address

    esp_netif_ip_info_t ip_info;

    int a, b, c, d;

    strinttoip4(ipaddress, a, b, c, d);
    IP4_ADDR(&ip_info.ip, a, b, c, d);

    strinttoip4(gw, a, b, c, d);
    IP4_ADDR(&ip_info.gw, a, b, c, d);

    strinttoip4(netmask, a, b, c, d);
    IP4_ADDR(&ip_info.netmask, a, b, c, d);

    esp_netif_set_ip_info(my_sta, &ip_info);

    wifi_init_config_t cfg = WIFI_INIT_CONFIG_DEFAULT();
    ESP_ERROR_CHECK(esp_wifi_init(&cfg));

    if (dns.length() > 0) {
        esp_netif_dns_info_t dns_info;
        ip4_addr_t ip;
        ip.addr = esp_ip4addr_aton(dns.c_str());
        ip_addr_set_ip4_u32(&dns_info.ip, ip.addr);
        ESP_ERROR_CHECK(esp_netif_set_dns_info(my_sta, ESP_NETIF_DNS_MAIN, &dns_info));
    }

    esp_event_handler_instance_t instance_any_id;
    esp_event_handler_instance_t instance_got_ip;
    ESP_ERROR_CHECK(esp_event_handler_instance_register(WIFI_EVENT,
                                                        ESP_EVENT_ANY_ID,
                                                        &event_handler_neu,
                                                        NULL,
                                                        &instance_any_id));
    ESP_ERROR_CHECK(esp_event_handler_instance_register(IP_EVENT,
                                                        IP_EVENT_STA_GOT_IP,
                                                        &event_handler_neu,
                                                        NULL,
                                                        &instance_got_ip));

    wifi_config_t wifi_config = { };
    strcpy((char*)wifi_config.sta.ssid, (const char*)ssid.c_str());
    strcpy((char*)wifi_config.sta.password, (const char*)passphrase.c_str());

    ESP_ERROR_CHECK(esp_wifi_set_mode(WIFI_MODE_STA));
    ESP_ERROR_CHECK(esp_wifi_set_config(ESP_IF_WIFI_STA, &wifi_config));
    ESP_ERROR_CHECK(esp_wifi_start());

    ESP_LOGI(TAG, "wifi_init_sta finished.");

    // Waiting until either the connection is established (WIFI_CONNECTED_BIT) or the
    // connection failed for the maximum number of retries (WIFI_FAIL_BIT).
    // The bits are set by event_handler_neu() (see above).
    EventBits_t bits = xEventGroupWaitBits(s_wifi_event_group,
                                           WIFI_CONNECTED_BIT | WIFI_FAIL_BIT,
                                           pdFALSE,
                                           pdFALSE,
                                           portMAX_DELAY);

    // xEventGroupWaitBits() returns the bits before the call returned, hence we can
    // test which event actually happened.
    if (bits & WIFI_CONNECTED_BIT) {
        ESP_LOGI(TAG, "connected to ap SSID:%s password:%s",
                 ssid.c_str(), passphrase.c_str());
    } else if (bits & WIFI_FAIL_BIT) {
        ESP_LOGI(TAG, "Failed to connect to SSID:%s, password:%s",
                 ssid.c_str(), passphrase.c_str());
    } else {
        ESP_LOGE(TAG, "UNEXPECTED EVENT");
    }

    // The event will not be processed after unregistering.
    ESP_ERROR_CHECK(esp_event_handler_instance_unregister(IP_EVENT, IP_EVENT_STA_GOT_IP, instance_got_ip));
    ESP_ERROR_CHECK(esp_event_handler_instance_unregister(WIFI_EVENT, ESP_EVENT_ANY_ID, instance_any_id));
    vEventGroupDelete(s_wifi_event_group);

    tcpip_adapter_ip_info_t ip_info2;
    ESP_ERROR_CHECK(tcpip_adapter_get_ip_info(TCPIP_ADAPTER_IF_STA, &ip_info2));
    ipaddress = std::string(ip4addr_ntoa(&ip_info2.ip));
    netmask = std::string(ip4addr_ntoa(&ip_info2.netmask));
    gw = std::string(ip4addr_ntoa(&ip_info2.gw));
}

// Entry point: chooses dynamic vs. fixed IP depending on what wlan.ini provided.
void ConnectToWLAN()
{
    if (ipaddress.length() == 0 || gw.length() == 0 || netmask.length() == 0)
    {
        printf("Connect to WLAN with dyn. IP\n");
        initialise_wifi();
    }
    else
    {
        printf("Connect to WLAN with fixed IP\n");
        initialise_wifi_fixed_ip2();
    }
}

// Rewrites the "hostname" entry in the given ini file (appends one if missing).
// Returns false if the hostname is unchanged or the file cannot be opened.
bool ChangeHostName(std::string fn, std::string _newhostname)
{
    if (_newhostname == hostname)
        return false;

    string line = "";
    std::vector<string> zerlegt;

    bool found = false;

    std::vector<string> neuesfile;   // the rewritten file content ("neues file" = new file)

    FILE* pFile;
    fn = FormatFileName(fn);
    pFile = OpenFileAndWait(fn.c_str(), "r");

    printf("file loaded\n");

    if (pFile == NULL)
        return false;

    char zw[1024];
    fgets(zw, 1024, pFile);
    line = std::string(zw);

    while ((line.size() > 0) || !(feof(pFile)))
    {
        printf("%s", line.c_str());
        zerlegt = ZerlegeZeile(line, "=");
        zerlegt[0] = trim(zerlegt[0], " ");

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "HOSTNAME")){
            line = "hostname = \"" + _newhostname + "\"\n";
            found = true;
        }

        neuesfile.push_back(line);

        if (fgets(zw, 1024, pFile) == NULL)
        {
            line = "";
        }
        else
        {
            line = std::string(zw);
        }
    }

    if (!found)
    {
        line = "hostname = \"" + _newhostname + "\"\n";
        neuesfile.push_back(line);
    }

    fclose(pFile);

    pFile = OpenFileAndWait(fn.c_str(), "w+");

    for (int i = 0; i < neuesfile.size(); ++i)
    {
        fputs(neuesfile[i].c_str(), pFile);
    }

    fclose(pFile);

    return true;
}
```
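A hypothetical call site (the path matches the one used elsewhere in this project):

```cpp
// Rewrite the hostname stored in wlan.ini on the SD card:
if (ChangeHostName("/sdcard/wlan.ini", "watermeter2"))
    printf("hostname updated - reboot so the new wlan.ini takes effect\n");
```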
```cpp
// Reads SSID, password, hostname and the optional static-IP settings from the
// given ini file (typically /sdcard/wlan.ini); surrounding quotes are stripped.
void LoadWlanFromFile(std::string fn)
{
    string line = "";
    std::vector<string> zerlegt;
    hostname = std_hostname;

    FILE* pFile;
    fn = FormatFileName(fn);

    pFile = OpenFileAndWait(fn.c_str(), "r");
    printf("file loaded\n");

    if (pFile == NULL)
        return;

    char zw[1024];
    fgets(zw, 1024, pFile);
    line = std::string(zw);

    while ((line.size() > 0) || !(feof(pFile)))
    {
        printf("%s", line.c_str());
        zerlegt = ZerlegeZeile(line, "=");
        zerlegt[0] = trim(zerlegt[0], " ");
        // Concatenates any tokens after the first '=' (note: the separator itself
        // is lost here; a later commit in this compare re-inserts the '=').
        for (int i = 2; i < zerlegt.size(); ++i)
            zerlegt[i] = zerlegt[i-1] + zerlegt[i];

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "HOSTNAME")){
            hostname = trim(zerlegt[1]);
            if ((hostname[0] == '"') && (hostname[hostname.length()-1] == '"')){
                hostname = hostname.substr(1, hostname.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "SSID")){
            ssid = trim(zerlegt[1]);
            if ((ssid[0] == '"') && (ssid[ssid.length()-1] == '"')){
                ssid = ssid.substr(1, ssid.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "PASSWORD")){
            passphrase = zerlegt[1];
            if ((passphrase[0] == '"') && (passphrase[passphrase.length()-1] == '"')){
                passphrase = passphrase.substr(1, passphrase.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "IP")){
            ipaddress = zerlegt[1];
            if ((ipaddress[0] == '"') && (ipaddress[ipaddress.length()-1] == '"')){
                ipaddress = ipaddress.substr(1, ipaddress.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "GATEWAY")){
            gw = zerlegt[1];
            if ((gw[0] == '"') && (gw[gw.length()-1] == '"')){
                gw = gw.substr(1, gw.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "NETMASK")){
            netmask = zerlegt[1];
            if ((netmask[0] == '"') && (netmask[netmask.length()-1] == '"')){
                netmask = netmask.substr(1, netmask.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "DNS")){
            dns = zerlegt[1];
            if ((dns[0] == '"') && (dns[dns.length()-1] == '"')){
                dns = dns.substr(1, dns.length()-2);
            }
        }

        if (fgets(zw, 1024, pFile) == NULL)
        {
            line = "";
        }
        else
        {
            line = std::string(zw);
        }
    }

    fclose(pFile);

    // If the hostname was empty in the .ini, fall back to std_hostname.
    if(hostname.length() <= 0){
        hostname = std_hostname;
    }

    printf("\nWLan: %s, %s\n", ssid.c_str(), passphrase.c_str());
    printf("Hostname: %s\n", hostname.c_str());
    printf("Fixed IP: %s, Gateway %s, Netmask %s, DNS %s\n", ipaddress.c_str(), gw.c_str(), netmask.c_str(), dns.c_str());
}
```
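For reference, a `wlan.ini` of the shape this parser accepts might look like the following; all values are placeholders, and the last four keys are only needed for a fixed-IP setup (keys are matched case-insensitively):

```ini
hostname = "watermeter"
ssid = "MY-SSID"
password = "my-secret"
ip = "192.168.178.42"
gateway = "192.168.178.1"
netmask = "255.255.255.0"
dns = "192.168.178.1"
```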
```cpp
// Reads only the network settings (IP, gateway, netmask, DNS) from an ini file
// into the caller-provided strings; surrounding quotes are stripped.
void LoadNetConfigFromFile(std::string _fn, std::string &_ip, std::string &_gw, std::string &_netmask, std::string &_dns)
{
    string line = "";
    std::vector<string> zerlegt;

    FILE* pFile;
    _fn = FormatFileName(_fn);
    pFile = OpenFileAndWait(_fn.c_str(), "r");

    if (pFile == NULL)
        return;

    char zw[1024];
    fgets(zw, 1024, pFile);
    line = std::string(zw);

    while ((line.size() > 0) || !(feof(pFile)))
    {
        printf("%s", line.c_str());
        zerlegt = ZerlegeZeile(line, "=");
        zerlegt[0] = trim(zerlegt[0], " ");

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "IP")){
            _ip = zerlegt[1];
            if ((_ip[0] == '"') && (_ip[_ip.length()-1] == '"')){
                _ip = _ip.substr(1, _ip.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "GATEWAY")){
            _gw = zerlegt[1];
            if ((_gw[0] == '"') && (_gw[_gw.length()-1] == '"')){
                _gw = _gw.substr(1, _gw.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "NETMASK")){
            _netmask = zerlegt[1];
            if ((_netmask[0] == '"') && (_netmask[_netmask.length()-1] == '"')){
                _netmask = _netmask.substr(1, _netmask.length()-2);
            }
        }

        if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "DNS")){
            _dns = zerlegt[1];
            if ((_dns[0] == '"') && (_dns[_dns.length()-1] == '"')){
                _dns = _dns.substr(1, _dns.length()-2);
            }
        }

        if (fgets(zw, 1024, pFile) == NULL)
        {
            line = "";
        }
        else
        {
            line = std::string(zw);
        }
    }

    fclose(pFile);
}

std::string getHostname(){
    return hostname;
}

std::string getIPAddress(){
    return ipaddress;
}

std::string getSSID(){
    return ssid;
}

std::string getNetMask(){
    return netmask;
}

std::string getGW(){
    return gw;
}
```
@@ -1,21 +0,0 @@

```cpp
//#ifndef CONNECT_WLAN_H
//#define CONNECT_WLAN_H

#include <string>
#include "driver/gpio.h"

const int CONNECTED_BIT = BIT0;

void ConnectToWLAN();

void LoadWlanFromFile(std::string fn);

bool ChangeHostName(std::string fn, std::string _newhostname);

std::string getHostname();
std::string getIPAddress();
std::string getSSID();
std::string getNetMask();
std::string getGW();

//#endif
```
```diff
@@ -282,7 +282,7 @@ bool ChangeHostName(std::string fn, std::string _newhostname)

     if (!found)
     {
-        line = "hostname = \"" + _newhostname + "\"\n";
+        line = "\nhostname = \"" + _newhostname + "\"\n";
         neuesfile.push_back(line);
     }
```
```diff
@@ -292,11 +292,14 @@ bool ChangeHostName(std::string fn, std::string _newhostname)

     for (int i = 0; i < neuesfile.size(); ++i)
     {
+        printf(neuesfile[i].c_str());
         fputs(neuesfile[i].c_str(), pFile);
     }

     fclose(pFile);

+    printf("*** Update hostname done ***\n");
+
     return true;
 }
```
```diff
@@ -326,7 +329,7 @@ void LoadWlanFromFile(std::string fn)

         zerlegt = ZerlegeZeile(line, "=");
         zerlegt[0] = trim(zerlegt[0], " ");
         for (int i = 2; i < zerlegt.size(); ++i)
-            zerlegt[i] = zerlegt[i-1] + zerlegt[i];
+            zerlegt[1] = zerlegt[1] + "=" + zerlegt[i];

         if ((zerlegt.size() > 1) && (toUpper(zerlegt[0]) == "HOSTNAME")){
             hostname = trim(zerlegt[1]);
```
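This change matters for values that contain "=" (e.g. WLAN passwords): the old loop concatenated the tokens but dropped the separator. A small worked example of the new logic, with a hypothetical input line (whitespace handling elided):

```cpp
// Input line:  password = "abc=def"
// ZerlegeZeile(line, "=") yields roughly: {"password", "\"abc", "def\""}
// The fixed loop rebuilds the value and restores the separator:
//   zerlegt[1] = "\"abc" + "=" + "def\""   ->   "\"abc=def\""
```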
```diff
@@ -62,7 +62,8 @@ bool frame2jpg_cb(camera_fb_t * fb, uint8_t quality, jpg_out_cb cb, void * arg);

  * @param height Height in pixels of the source image
  * @param format Format of the source image
  * @param quality JPEG quality of the resulting image
- * @param out Pointer to be populated with the address of the resulting buffer
+ * @param out Pointer to be populated with the address of the resulting buffer.
+ *            You MUST free the pointer once you are done with it.
  * @param out_len Pointer to be populated with the length of the output buffer
  *
  * @return true on success
```
```diff
@@ -317,7 +317,7 @@ bool fmt2bmp(uint8_t *src, size_t src_len, uint16_t width, uint16_t height, pixf

     }
     *out = out_buf;
     *out_len = out_size;
     return true;
 }

 bool frame2bmp(camera_fb_t * fb, uint8_t ** out, size_t * out_len)
```
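Per the clarified doc comment above, the caller owns the output buffer. A minimal usage sketch against the esp32-camera API (camera initialization omitted):

```cpp
#include "esp_camera.h"
#include "img_converters.h"

void capture_one_bmp(void)
{
    camera_fb_t *fb = esp_camera_fb_get();   // grab one frame from the driver
    if (!fb)
        return;

    uint8_t *bmp = NULL;
    size_t bmp_len = 0;
    if (frame2bmp(fb, &bmp, &bmp_len)) {
        // ... use bmp / bmp_len ...
        free(bmp);                           // the caller MUST free the buffer
    }
    esp_camera_fb_return(fb);                // hand the frame buffer back
}
```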
```diff
@@ -1321,7 +1321,7 @@ esp_err_t camera_init(const camera_config_t* config)

     }

     vsync_intr_disable();
-    err = gpio_install_isr_service(ESP_INTR_FLAG_LEVEL1 | ESP_INTR_FLAG_IRAM);
+    err = gpio_install_isr_service(ESP_INTR_FLAG_LOWMED | ESP_INTR_FLAG_IRAM);
     if (err != ESP_OK) {
         if (err != ESP_ERR_INVALID_STATE) {
             ESP_LOGE(TAG, "gpio_install_isr_service failed (%x)", err);
```
```diff
@@ -1,5 +1,3 @@

-name: "esp32-camera"
-
 version: "1.0.0"
 description: This package hosts ESP32 compatible driver for OV2640 image sensors. Additionally it provides a few tools, which allow converting the captured frame data to the more common BMP and JPEG formats.
+url: https://github.com/espressif/esp32-camera
```
```diff
@@ -9,4 +9,5 @@ static const char *TAGPARTOTA = "server_ota";

 void register_server_ota_sdcard_uri(httpd_handle_t server);
 void CheckOTAUpdate();
 void doReboot();
+void hard_restart();
```
```diff
@@ -19,6 +19,7 @@ void ClassFlowAlignment::SetInitialParameter(void)

     initalrotate = 0;
     anz_ref = 0;
     initialmirror = false;
+    initialflip = false;
     SaveAllFiles = false;
     namerawimage = "/sdcard/img_tmp/raw.jpg";
     FileStoreRefAlignment = "/sdcard/config/align.txt";
```
```diff
@@ -72,6 +73,11 @@ bool ClassFlowAlignment::ReadParameter(FILE* pfile, string& aktparamgraph)

     while (this->getNextLine(pfile, &aktparamgraph) && !this->isNewParagraph(aktparamgraph))
     {
         zerlegt = ZerlegeZeile(aktparamgraph);
+        if ((toUpper(zerlegt[0]) == "FLIPIMAGESIZE") && (zerlegt.size() > 1))
+        {
+            if (toUpper(zerlegt[1]) == "TRUE")
+                initialflip = true;
+        }
         if ((toUpper(zerlegt[0]) == "INITIALMIRROR") && (zerlegt.size() > 1))
         {
             if (toUpper(zerlegt[1]) == "TRUE")
```
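The new `FlipImageSize` switch would then live in the alignment paragraph of `config.ini`; a hypothetical fragment (the section name and neighbouring keys are assumptions based on the parameters parsed above):

```ini
[Alignment]
InitialRotate = 90
FlipImageSize = true
InitialMirror = false
```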
```diff
@@ -153,7 +159,13 @@ bool ClassFlowAlignment::doFlow(string time)

     delete AlignAndCutImage;
     AlignAndCutImage = new CAlignAndCutImage(ImageBasis, ImageTMP);

-    CRotateImage rt(AlignAndCutImage, ImageTMP);
+    CRotateImage rt(AlignAndCutImage, ImageTMP, initialflip);
+    if (initialflip)
+    {
+        int _zw = ImageBasis->height;
+        ImageBasis->height = ImageBasis->width;
+        ImageBasis->width = _zw;
+    }

     if (initialmirror){
         printf("do mirror\n");
```
```diff
@@ -161,7 +173,7 @@ bool ClassFlowAlignment::doFlow(string time)

         if (SaveAllFiles) AlignAndCutImage->SaveToFile(FormatFileName("/sdcard/img_tmp/mirror.jpg"));
     }

-    if (initalrotate != 0)
+    if ((initalrotate != 0) || initialflip)
     {
         rt.Rotate(initalrotate);
         if (SaveAllFiles) AlignAndCutImage->SaveToFile(FormatFileName("/sdcard/img_tmp/rot.jpg"));
```
```diff
@@ -176,6 +188,12 @@ bool ClassFlowAlignment::doFlow(string time)

     if (SaveAllFiles)
     {
+        if (initialflip)
+        {
+            int _zw = ImageTMP->width;
+            ImageTMP->width = ImageTMP->height;
+            ImageTMP->height = _zw;
+        }
         DrawRef(ImageTMP);
         ImageTMP->SaveToFile(FormatFileName("/sdcard/img_tmp/alg_roi.jpg"));
     }
```
```diff
@@ -209,7 +227,7 @@ void ClassFlowAlignment::SaveReferenceAlignmentValues()

         time(&rawtime);
         timeinfo = localtime(&rawtime);

-        strftime(buffer, 80, "%Y-%m-%d_%H-%M-%S", timeinfo);
+        strftime(buffer, 80, "%Y-%m-%dT%H:%M:%S", timeinfo);
         zwtime = std::string(buffer);
     }
```
```diff
@@ -15,6 +15,7 @@ class ClassFlowAlignment :

 protected:
     float initalrotate;
     bool initialmirror;
+    bool initialflip;
     RefInfo References[2];
     int anz_ref;
     string namerawimage;
```
```diff
@@ -3,6 +3,8 @@

 #include <math.h>
 #include <iomanip>
 #include <sys/types.h>
+#include <sstream>          // std::stringstream

 // #define OHNETFLITE
```
```diff
@@ -382,6 +382,9 @@ bool ClassFlowControll::ReadParameter(FILE* pfile, string& aktparamgraph)

     {
         // reboot required so that the new wlan.ini is actually used !!!
         fclose(pfile);
+        printf("do reboot\n");
+        esp_restart();
+        hard_restart();
         doReboot();
     }
 }
```
```diff
@@ -11,6 +11,8 @@ void ClassFlowMQTT::SetInitialParameter(void)

     uri = "";
     topic = "";
     topicError = "";
+    topicRate = "";
+    topicTimeStamp = "";
     clientname = "watermeter";
     OldValue = "";
     flowpostprocessing = NULL;
```
```diff
@@ -94,6 +96,15 @@ bool ClassFlowMQTT::ReadParameter(FILE* pfile, string& aktparamgraph)

     {
         this->topicError = zerlegt[1];
     }
+    if ((toUpper(zerlegt[0]) == "TOPICRATE") && (zerlegt.size() > 1))
+    {
+        this->topicRate = zerlegt[1];
+    }
+    if ((toUpper(zerlegt[0]) == "TOPICTIMESTAMP") && (zerlegt.size() > 1))
+    {
+        this->topicTimeStamp = zerlegt[1];
+    }

     if ((toUpper(zerlegt[0]) == "CLIENTID") && (zerlegt.size() > 1))
     {
         this->clientname = zerlegt[1];
```
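The two new topics would be configured in the MQTT paragraph of `config.ini`; a hypothetical fragment (the key spellings follow the parser above; the `Uri`/`Topic` keys and all values are placeholders):

```ini
[MQTT]
Uri = mqtt://192.168.178.2:1883
Topic = watermeter/value
TopicError = watermeter/error
TopicRate = watermeter/rate
TopicTimeStamp = watermeter/timestamp
ClientID = watermeter
```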
```diff
@@ -114,12 +125,16 @@ bool ClassFlowMQTT::doFlow(string zwtime)

 {
     std::string result;
     std::string resulterror = "";
+    std::string resultrate = "";
+    std::string resulttimestamp = "";
     string zw = "";

     if (flowpostprocessing)
     {
         result = flowpostprocessing->getReadoutParam(false, true);
         resulterror = flowpostprocessing->getReadoutError();
+        resultrate = flowpostprocessing->getReadoutRate();
+        resulttimestamp = flowpostprocessing->getReadoutTimeStamp();
     }
     else
     {
```
```diff
@@ -139,9 +154,21 @@ bool ClassFlowMQTT::doFlow(string zwtime)

     MQTTPublish(topic, result);

     if (topicError.length() > 0) {
+        if (resulterror.length() == 0)
+        {
+            resulterror = " ";
+        }
         MQTTPublish(topicError, resulterror);
     }

+    if (topicRate.length() > 0) {
+        MQTTPublish(topicRate, resultrate);
+    }
+
+    if (topicTimeStamp.length() > 0) {
+        MQTTPublish(topicTimeStamp, resulttimestamp);
+    }

     OldValue = result;

     return true;
```
```diff
@@ -9,7 +9,7 @@ class ClassFlowMQTT :

     public ClassFlow
 {
 protected:
-    std::string uri, topic, topicError, clientname;
+    std::string uri, topic, topicError, clientname, topicRate, topicTimeStamp;
     std::string OldValue;
     ClassFlowPostProcessing* flowpostprocessing;
     std::string user, password;
```
@@ -19,6 +19,10 @@ esp_err_t ClassFlowMakeImage::camera_capture(){
|
|||||||
|
|
||||||
void ClassFlowMakeImage::takePictureWithFlash(int flashdauer)
|
void ClassFlowMakeImage::takePictureWithFlash(int flashdauer)
|
||||||
{
|
{
|
||||||
|
// für den Fall, dass das Bild geflippt wird, muss es hier zurück gesetzt werden ////
|
||||||
|
rawImage->width = image_width;
|
||||||
|
rawImage->height = image_height;
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////
|
||||||
Camera.CaptureToBasisImage(rawImage, flashdauer);
|
Camera.CaptureToBasisImage(rawImage, flashdauer);
|
||||||
if (SaveAllFiles) rawImage->SaveToFile(namerawimage);
|
if (SaveAllFiles) rawImage->SaveToFile(namerawimage);
|
||||||
}
|
}
|
||||||
@@ -63,7 +63,7 @@ bool ClassFlowPostProcessing::LoadPreValue(void)
     int yy, month, dd, hh, mm, ss;
     struct tm whenStart;
 
-    sscanf(zwtime.c_str(), "%d-%d-%d_%d-%d-%d", &yy, &month, &dd, &hh, &mm, &ss);
+    sscanf(zwtime.c_str(), "%d-%d-%dT%d:%d:%d", &yy, &month, &dd, &hh, &mm, &ss);
     whenStart.tm_year = yy - 1900;
     whenStart.tm_mon = month - 1;
     whenStart.tm_mday = dd;
@@ -74,10 +74,9 @@ bool ClassFlowPostProcessing::LoadPreValue(void)
 
     tStart = mktime(&whenStart);
 
-    time_t now;
-    time(&now);
-    localtime(&now);
-    double difference = difftime(now, tStart);
+    time(&lastvalue);
+    localtime(&lastvalue);
+    double difference = difftime(lastvalue, tStart);
     difference /= 60;
     if (difference > PreValueAgeStartup)
         return false;
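The prevalue timestamp now round-trips through the ISO-like pattern `%Y-%m-%dT%H:%M:%S`: SavePreValue() writes it with strftime and LoadPreValue() parses it back with the matching sscanf format above. A minimal standalone sketch of that round trip (not repository code):

    #include <cstdio>
    #include <ctime>

    int main() {
        char buf[80];
        time_t now;
        time(&now);
        strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%S", localtime(&now));  // writer side

        int yy, month, dd, hh, mm, ss;                                     // reader side
        sscanf(buf, "%d-%d-%dT%d:%d:%d", &yy, &month, &dd, &hh, &mm, &ss);
        printf("%s -> %04d-%02d-%02d %02d:%02d:%02d\n", buf, yy, month, dd, hh, mm, ss);
        return 0;
    }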
@@ -122,13 +121,17 @@ void ClassFlowPostProcessing::SavePreValue(float value, string zwtime)
         time(&rawtime);
         timeinfo = localtime(&rawtime);
 
-        strftime(buffer, 80, "%Y-%m-%d_%H-%M-%S", timeinfo);
-        zwtime = std::string(buffer);
+        strftime(buffer, 80, "%Y-%m-%dT%H:%M:%S", timeinfo);
+        timeStamp = std::string(buffer);
+    }
+    else
+    {
+        timeStamp = zwtime;
     }
 
     PreValue = value;
 
-    fputs(zwtime.c_str(), pFile);
+    fputs(timeStamp.c_str(), pFile);
     fputs("\n", pFile);
     fputs(to_string(value).c_str(), pFile);
     fputs("\n", pFile);
@@ -139,6 +142,7 @@ void ClassFlowPostProcessing::SavePreValue(float value, string zwtime)
 
 ClassFlowPostProcessing::ClassFlowPostProcessing(std::vector<ClassFlow*>* lfc)
 {
+    FlowRateAct = 0;
     PreValueUse = false;
     PreValueAgeStartup = 30;
     AllowNegativeRates = false;
@@ -150,6 +154,7 @@ ClassFlowPostProcessing::ClassFlowPostProcessing(std::vector<ClassFlow*>* lfc)
     checkDigitIncreaseConsistency = false;
     DecimalShift = 0;
     ErrorMessageText = "";
+    timeStamp = "";
     FilePreValue = FormatFileName("/sdcard/config/prevalue.ini");
     ListFlowControll = lfc;
 }
@@ -300,7 +305,7 @@ bool ClassFlowPostProcessing::doFlow(string zwtime)
     timeinfo = localtime(&imagetime);
 
     char strftime_buf[64];
-    strftime(strftime_buf, sizeof(strftime_buf), "%Y-%m-%d_%H-%M-%S", timeinfo);
+    strftime(strftime_buf, sizeof(strftime_buf), "%Y-%m-%dT%H:%M:%S", timeinfo);
     zwtime = std::string(strftime_buf);
 
 
@@ -343,12 +348,15 @@ bool ClassFlowPostProcessing::doFlow(string zwtime)
 
         PreValueOkay = true;
         PreValue = Value;
+        time(&lastvalue);
+        localtime(&lastvalue);
+
         SavePreValue(Value, zwtime);
         }
         return true;
     }
 
 
     zw = ErsetzteN(ReturnRawValue);
 
     Value = std::stof(zw);
@@ -373,6 +381,7 @@ bool ClassFlowPostProcessing::doFlow(string zwtime)
         zwvalue = RundeOutput(Value, AnzahlAnalog - DecimalShift);
     }
 
+
     ReturnValueNoError = zwvalue;
     ReturnValue = zwvalue;
     if (ErrorMessage && (ErrorMessageText.length() > 0))
@@ -380,10 +389,15 @@ bool ClassFlowPostProcessing::doFlow(string zwtime)
 
     if (ErrorMessageText.length() == 0)
     {
+        time_t currenttime;
+        time(&currenttime);
+        localtime(&currenttime);
+        double difference = difftime(currenttime, lastvalue);  // in seconds
+        difference /= 60;  // in minutes
+        FlowRateAct = (Value - PreValue) / difference;
+
         PreValue = Value;
 
         SavePreValue(Value, zwtime);
 
     }
     return true;
 }
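FlowRateAct is simply the value delta divided by the elapsed minutes since the last good readout (declared as m3/min in the header further down). Note that `difference` can be 0 if two readouts fall into the same second; a defensive standalone sketch of the same computation (the guard is an addition of this sketch, not of the commit):

    #include <ctime>

    // Flow rate in units per minute between two readouts.
    float flow_rate_per_min(float value, float prev_value, time_t now, time_t last) {
        double minutes = difftime(now, last) / 60.0;  // seconds -> minutes
        if (minutes <= 0.0)                           // guard against same-second readouts
            return 0.0f;
        return (float)((value - prev_value) / minutes);
    }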
@@ -506,6 +520,16 @@ float ClassFlowPostProcessing::checkDigitConsistency(float input, int _decilamsh
     return input;
 }
 
+string ClassFlowPostProcessing::getReadoutRate()
+{
+    return std::to_string(FlowRateAct);
+}
+
+string ClassFlowPostProcessing::getReadoutTimeStamp()
+{
+    return timeStamp;
+}
+
 
 string ClassFlowPostProcessing::getReadoutError()
 {
@@ -17,6 +17,9 @@ protected:
     bool PreValueOkay;
     bool checkDigitIncreaseConsistency;
     int DecimalShift;
+    time_t lastvalue;
+    float FlowRateAct;  // m3 / min
+
 
     string FilePreValue;
     float PreValue;  // last value that was read out successfully
@@ -25,6 +28,7 @@ protected:
     string ReturnValue;  // corrected return value, possibly with error message
     string ReturnValueNoError;  // corrected return value without error message
     string ErrorMessageText;  // error message from the consistency check
+    string timeStamp;
 
     bool LoadPreValue(void);
     string ShiftDecimal(string in, int _decShift);
@@ -40,6 +44,8 @@ public:
     string getReadout();
     string getReadoutParam(bool _rawValue, bool _noerror);
     string getReadoutError();
+    string getReadoutRate();
+    string getReadoutTimeStamp();
     void SavePreValue(float value, string time = "");
     string GetPreValue();
 
@@ -1,7 +1,7 @@
 #include "CRotateImage.h"
 
 
-CRotateImage::CRotateImage(CImageBasis *_org, CImageBasis *_temp)
+CRotateImage::CRotateImage(CImageBasis *_org, CImageBasis *_temp, bool _flip)
 {
     rgb_image = _org->rgb_image;
     channels = _org->channels;
@@ -9,8 +9,10 @@ CRotateImage::CRotateImage(CImageBasis *_org, CImageBasis *_temp)
     height = _org->height;
     bpp = _org->bpp;
     externalImage = true;
     ImageTMP = _temp;
+    ImageOrg = _org;
     islocked = false;
+    doflip = _flip;
 }
 
 void CRotateImage::Mirror(){
@@ -58,12 +60,33 @@ void CRotateImage::Mirror(){
 
 void CRotateImage::Rotate(float _angle, int _centerx, int _centery)
 {
+    int org_width, org_height;
     float m[2][3];
 
     float x_center = _centerx;
     float y_center = _centery;
     _angle = _angle / 180 * M_PI;
 
+    if (doflip)
+    {
+        org_width = width;
+        org_height = height;
+        height = org_width;
+        width = org_height;
+        x_center = x_center - (org_width/2) + (org_height/2);
+        y_center = y_center + (org_width/2) - (org_height/2);
+        if (ImageOrg)
+        {
+            ImageOrg->height = height;
+            ImageOrg->width = width;
+        }
+    }
+    else
+    {
+        org_width = width;
+        org_height = height;
+    }
+
     m[0][0] = cos(_angle);
     m[0][1] = sin(_angle);
     m[0][2] = (1 - m[0][0]) * x_center - m[0][1] * y_center;
@@ -72,6 +95,12 @@ void CRotateImage::Rotate(float _angle, int _centerx, int _centery)
     m[1][1] = m[0][0];
     m[1][2] = m[0][1] * x_center + (1 - m[0][0]) * y_center;
 
+    if (doflip)
+    {
+        m[0][2] = m[0][2] + (org_width/2) - (org_height/2);
+        m[1][2] = m[1][2] - (org_width/2) + (org_height/2);
+    }
+
     int memsize = width * height * channels;
     uint8_t* odata;
     if (ImageTMP)
@@ -101,9 +130,9 @@ void CRotateImage::Rotate(float _angle, int _centerx, int _centery)
             x_source += int(m[0][2]);
             y_source += int(m[1][2]);
 
-            if ((x_source >= 0) && (x_source < width) && (y_source >= 0) && (y_source < height))
+            if ((x_source >= 0) && (x_source < org_width) && (y_source >= 0) && (y_source < org_height))
             {
-                p_source = rgb_image + (channels * (y_source * width + x_source));
+                p_source = rgb_image + (channels * (y_source * org_width + x_source));
                 for (int _channels = 0; _channels < channels; ++_channels)
                     p_target[_channels] = p_source[_channels];
             }
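The flip path treats the buffer as a 90-degree rotated image: width and height swap, and both the rotation center and the translation terms m[0][2]/m[1][2] are shifted by (org_width/2 - org_height/2), because the matrix maps each target pixel back to its source pixel (inverse mapping). A standalone sketch of that inverse mapping for a single pixel (not repository code):

    #include <cmath>

    // Map a target pixel back to its source pixel for a rotation around (cx, cy).
    void map_target_to_source(float angle_deg, float cx, float cy,
                              int x_target, int y_target,
                              int* x_source, int* y_source) {
        float a = angle_deg / 180.0f * (float)M_PI;
        float m00 = cosf(a), m01 = sinf(a);
        float m02 = (1 - m00) * cx - m01 * cy;  // translation, as in m[0][2]
        float m12 = m01 * cx + (1 - m00) * cy;  // translation, as in m[1][2]
        *x_source = (int)(m00 * x_target + m01 * y_target + m02);
        *y_source = (int)(-m01 * x_target + m00 * y_target + m12);
    }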
@@ -4,10 +4,11 @@
 class CRotateImage: public CImageBasis
 {
 public:
-    CImageBasis *ImageTMP;
-    CRotateImage(std::string _image) : CImageBasis(_image) {ImageTMP = NULL;};
-    CRotateImage(uint8_t* _rgb_image, int _channels, int _width, int _height, int _bpp) : CImageBasis(_rgb_image, _channels, _width, _height, _bpp) {ImageTMP = NULL;};
-    CRotateImage(CImageBasis *_org, CImageBasis *_temp);
+    CImageBasis *ImageTMP, *ImageOrg;
+    bool doflip;
+    CRotateImage(std::string _image, bool _flip = false) : CImageBasis(_image) {ImageTMP = NULL; ImageOrg = NULL; doflip = _flip;};
+    CRotateImage(uint8_t* _rgb_image, int _channels, int _width, int _height, int _bpp, bool _flip = false) : CImageBasis(_rgb_image, _channels, _width, _height, _bpp) {ImageTMP = NULL; ImageOrg = NULL; doflip = _flip;};
+    CRotateImage(CImageBasis *_org, CImageBasis *_temp, bool _flip = false);
 
     void Rotate(float _angle);
     void Rotate(float _angle, int _centerx, int _centery);
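Since all three constructors default `_flip` to false, existing call sites keep compiling unchanged. A hypothetical call site with the flip enabled (identifiers other than CRotateImage/Rotate are assumptions):

    // rawImage/tmpImage are CImageBasis* owned by the caller (assumed names).
    CRotateImage rt(rawImage, tmpImage, true);  // _flip = true: width/height swap
    rt.Rotate(180.0f, rawImage->width / 2, rawImage->height / 2);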
@@ -77,7 +77,7 @@ void ClassLogFile::WriteToDedicatedFile(std::string _fn, std::string info, bool
     time(&rawtime);
     timeinfo = localtime(&rawtime);
 
-    strftime(buffer, 80, "%Y-%m-%d_%H-%M-%S", timeinfo);
+    strftime(buffer, 80, "%Y-%m-%dT%H:%M:%S", timeinfo);
 
     zwtime = std::string(buffer);
     info = zwtime + ": " + info;
@@ -9,7 +9,7 @@
 #include "tensorflow/lite/micro/micro_error_reporter.h"
 #include "tensorflow/lite/micro/micro_interpreter.h"
 #include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/version.h"
+//#include "tensorflow/lite/version.h"
 #include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "esp_err.h"
 #include "esp_log.h"
@@ -23,7 +23,7 @@ if(NOT DEFINED ENV{IDF_PATH})
 endif()
 
 idf_component_register(
-            SRCS tensorflow/lite/micro/micro_error_reporter.cc tensorflow/lite/micro/simple_memory_allocator.cc tensorflow/lite/micro/memory_helpers.cc tensorflow/lite/micro/test_helpers.cc tensorflow/lite/micro/recording_micro_allocator.cc tensorflow/lite/micro/micro_time.cc tensorflow/lite/micro/recording_simple_memory_allocator.cc tensorflow/lite/micro/micro_string.cc tensorflow/lite/micro/micro_profiler.cc tensorflow/lite/micro/debug_log.cc tensorflow/lite/micro/all_ops_resolver.cc tensorflow/lite/micro/micro_utils.cc tensorflow/lite/micro/micro_interpreter.cc tensorflow/lite/micro/micro_allocator.cc tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc tensorflow/lite/micro/memory_planner/linear_memory_planner.cc tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc tensorflow/lite/micro/testing/test_conv_model.cc tensorflow/lite/c/common.c tensorflow/lite/core/api/error_reporter.cc tensorflow/lite/core/api/flatbuffer_conversions.cc tensorflow/lite/core/api/op_resolver.cc tensorflow/lite/core/api/tensor_utils.cc tensorflow/lite/kernels/internal/quantization_util.cc tensorflow/lite/kernels/kernel_util.cc tensorflow/lite/schema/schema_utils.cc tensorflow/lite/micro/kernels/prelu.cc tensorflow/lite/micro/kernels/dequantize.cc tensorflow/lite/micro/kernels/pad.cc tensorflow/lite/micro/kernels/shape.cc tensorflow/lite/micro/kernels/l2norm.cc tensorflow/lite/micro/kernels/tanh.cc tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc tensorflow/lite/micro/kernels/logical.cc tensorflow/lite/micro/kernels/kernel_util.cc tensorflow/lite/micro/kernels/ceil.cc tensorflow/lite/micro/kernels/arg_min_max.cc tensorflow/lite/micro/kernels/softmax.cc tensorflow/lite/micro/kernels/sub.cc tensorflow/lite/micro/kernels/add.cc tensorflow/lite/micro/kernels/floor.cc tensorflow/lite/micro/kernels/kernel_runner.cc tensorflow/lite/micro/kernels/split_v.cc tensorflow/lite/micro/kernels/hard_swish.cc tensorflow/lite/micro/kernels/pooling.cc tensorflow/lite/micro/kernels/concatenation.cc tensorflow/lite/micro/kernels/mul.cc tensorflow/lite/micro/kernels/unpack.cc tensorflow/lite/micro/kernels/round.cc tensorflow/lite/micro/kernels/quantize.cc tensorflow/lite/micro/kernels/ethosu.cc tensorflow/lite/micro/kernels/svdf.cc tensorflow/lite/micro/kernels/maximum_minimum.cc tensorflow/lite/micro/kernels/reshape.cc tensorflow/lite/micro/kernels/reduce.cc tensorflow/lite/micro/kernels/strided_slice.cc tensorflow/lite/micro/kernels/neg.cc tensorflow/lite/micro/kernels/pack.cc tensorflow/lite/micro/kernels/elementwise.cc tensorflow/lite/micro/kernels/comparisons.cc tensorflow/lite/micro/kernels/fully_connected.cc tensorflow/lite/micro/kernels/depthwise_conv.cc tensorflow/lite/micro/kernels/split.cc tensorflow/lite/micro/kernels/logistic.cc tensorflow/lite/micro/kernels/circular_buffer.cc tensorflow/lite/micro/kernels/conv.cc tensorflow/lite/micro/kernels/activations.cc
+            SRCS tensorflow/lite/micro/simple_memory_allocator.cc tensorflow/lite/micro/micro_error_reporter.cc tensorflow/lite/micro/memory_helpers.cc tensorflow/lite/micro/test_helpers.cc tensorflow/lite/micro/recording_micro_allocator.cc tensorflow/lite/micro/micro_time.cc tensorflow/lite/micro/recording_simple_memory_allocator.cc tensorflow/lite/micro/micro_string.cc tensorflow/lite/micro/micro_profiler.cc tensorflow/lite/micro/debug_log.cc tensorflow/lite/micro/all_ops_resolver.cc tensorflow/lite/micro/micro_utils.cc tensorflow/lite/micro/micro_interpreter.cc tensorflow/lite/micro/micro_allocator.cc tensorflow/lite/micro/system_setup.cc tensorflow/lite/micro/memory_planner/linear_memory_planner.cc tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc tensorflow/lite/c/common.c tensorflow/lite/core/api/error_reporter.cc tensorflow/lite/core/api/flatbuffer_conversions.cc tensorflow/lite/core/api/op_resolver.cc tensorflow/lite/core/api/tensor_utils.cc tensorflow/lite/kernels/internal/quantization_util.cc tensorflow/lite/kernels/kernel_util.cc tensorflow/lite/schema/schema_utils.cc tensorflow/lite/micro/kernels/activations.cc tensorflow/lite/micro/kernels/add.cc tensorflow/lite/micro/kernels/add_n.cc tensorflow/lite/micro/kernels/arg_min_max.cc tensorflow/lite/micro/kernels/batch_to_space_nd.cc tensorflow/lite/micro/kernels/cast.cc tensorflow/lite/micro/kernels/ceil.cc tensorflow/lite/micro/kernels/circular_buffer.cc tensorflow/lite/micro/kernels/comparisons.cc tensorflow/lite/micro/kernels/concatenation.cc tensorflow/lite/micro/kernels/conv.cc tensorflow/lite/micro/kernels/conv_common.cc tensorflow/lite/micro/kernels/depthwise_conv.cc tensorflow/lite/micro/kernels/depthwise_conv_common.cc tensorflow/lite/micro/kernels/dequantize.cc tensorflow/lite/micro/kernels/detection_postprocess.cc tensorflow/lite/micro/kernels/div.cc tensorflow/lite/micro/kernels/elementwise.cc tensorflow/lite/micro/kernels/elu.cc tensorflow/lite/micro/kernels/ethosu.cc tensorflow/lite/micro/kernels/exp.cc tensorflow/lite/micro/kernels/expand_dims.cc tensorflow/lite/micro/kernels/fill.cc tensorflow/lite/micro/kernels/floor.cc tensorflow/lite/micro/kernels/fully_connected.cc tensorflow/lite/micro/kernels/fully_connected_common.cc tensorflow/lite/micro/kernels/hard_swish.cc tensorflow/lite/micro/kernels/kernel_runner.cc tensorflow/lite/micro/kernels/kernel_util.cc tensorflow/lite/micro/kernels/l2norm.cc tensorflow/lite/micro/kernels/l2_pool_2d.cc tensorflow/lite/micro/kernels/leaky_relu.cc tensorflow/lite/micro/kernels/logical.cc tensorflow/lite/micro/kernels/logistic.cc tensorflow/lite/micro/kernels/maximum_minimum.cc tensorflow/lite/micro/kernels/mul.cc tensorflow/lite/micro/kernels/neg.cc tensorflow/lite/micro/kernels/pack.cc tensorflow/lite/micro/kernels/pad.cc tensorflow/lite/micro/kernels/pooling.cc tensorflow/lite/micro/kernels/prelu.cc tensorflow/lite/micro/kernels/quantize.cc tensorflow/lite/micro/kernels/quantize_common.cc tensorflow/lite/micro/kernels/reduce.cc tensorflow/lite/micro/kernels/reshape.cc tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc tensorflow/lite/micro/kernels/round.cc tensorflow/lite/micro/kernels/shape.cc tensorflow/lite/micro/kernels/softmax.cc tensorflow/lite/micro/kernels/softmax_common.cc tensorflow/lite/micro/kernels/space_to_batch_nd.cc tensorflow/lite/micro/kernels/split.cc tensorflow/lite/micro/kernels/split_v.cc tensorflow/lite/micro/kernels/squeeze.cc tensorflow/lite/micro/kernels/strided_slice.cc tensorflow/lite/micro/kernels/sub.cc tensorflow/lite/micro/kernels/svdf.cc tensorflow/lite/micro/kernels/svdf_common.cc tensorflow/lite/micro/kernels/tanh.cc tensorflow/lite/micro/kernels/transpose_conv.cc tensorflow/lite/micro/kernels/unpack.cc tensorflow/lite/micro/kernels/zeros_like.cc
             INCLUDE_DIRS . third_party/gemmlowp third_party/flatbuffers/include third_party/ruy)
 
 # Reduce the level of paranoia to be able to compile TF sources
@@ -32,7 +32,7 @@ target_compile_options(${COMPONENT_LIB} PRIVATE
     -Wno-missing-field-initializers
     -Wno-type-limits)
 
-target_compile_options(${COMPONENT_LIB} PRIVATE -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter)
+target_compile_options(${COMPONENT_LIB} PRIVATE -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DESP)
-target_compile_options(${COMPONENT_LIB} PRIVATE $<$<COMPILE_LANGUAGE:CXX>: -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter >)
+target_compile_options(${COMPONENT_LIB} PRIVATE $<$<COMPILE_LANGUAGE:CXX>: -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DESP >)
 target_compile_options(${COMPONENT_LIB} INTERFACE $<$<IN_LIST:-DTF_LITE_STATIC_MEMORY,$<TARGET_PROPERTY:${COMPONENT_LIB},COMPILE_OPTIONS>>:-DTF_LITE_STATIC_MEMORY>)
 target_link_libraries(${COMPONENT_LIB} PRIVATE -lm)
@@ -1,139 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PUBLIC_VERSION_H_
-#define TENSORFLOW_CORE_PUBLIC_VERSION_H_
-
-// TensorFlow uses semantic versioning, see http://semver.org/.
-
-// Also update tensorflow/tensorflow.bzl and
-// tensorflow/tools/pip_package/setup.py
-#define TF_MAJOR_VERSION 2
-#define TF_MINOR_VERSION 5
-#define TF_PATCH_VERSION 0
-
-// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
-// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
-
-#define TF_STR_HELPER(x) #x
-#define TF_STR(x) TF_STR_HELPER(x)
-
-// e.g. "0.5.0" or "0.6.0-alpha".
-#define TF_VERSION_STRING \
-  (TF_STR(TF_MAJOR_VERSION) "." TF_STR(TF_MINOR_VERSION) "." TF_STR( \
-      TF_PATCH_VERSION) TF_VERSION_SUFFIX)
-
-// GraphDef compatibility versions (the versions field in graph.proto).
-//
-// Each graph has producer and min_consumer versions, and each
-// consumer has its own version and a min_producer. In addition, graphs can
-// mark specific consumer versions as bad (to prevent bugs from executing).
-// A consumer will execute a graph if the consumer's version is at least the
-// graph's min_consumer, the graph's producer version is at least the consumer's
-// min_producer, and the consumer version isn't specifically disallowed by the
-// graph.
-//
-// By default, newly created graphs have producer version TF_GRAPH_DEF_VERSION
-// min_consumer TF_GRAPH_DEF_MIN_CONSUMER, and no other bad consumer versions.
-//
-// Version history:
-//
-// 0. Graphs created before GraphDef versioning
-// 1. First real version (2dec2015)
-// 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015)
-// 3. Remove TileGrad, since it was equivalent to reduce_sum (30dec2015)
-// 4. When support for this version is removed, we can safely make AttrValue
-//    parsing more strict with respect to empty list values (see
-//    111635679, 7jan2016).
-// 5. Graphs are wholly-validated during Session::Create() (7jan2016).
-// 6. TensorFlow is scalar strict within Google (27jan2016).
-// 7. Remove TopK in favor of TopKV2 (5feb2016).
-// 8. Replace RandomCrop from C++ with pure Python (5feb2016).
-// 9. Deprecate batch_norm_with_global_normalization (16feb2016).
-// 10. Deprecate conv3d_backprop_{filter,input} (10jun2016).
-// 11. Deprecate {batch}_self_adjoint_eig (3aug2016).
-// 12. Graph consumers understand the node_def field of FunctionDef (22aug2016).
-// 13. Deprecate multiple batch linear algebra ops (9sep2016).
-// 14. Deprecate batch_matrix_* ops. (10sep2016).
-// 15. Deprecate batch_fft_* ops. (14sep2016).
-// 16. Deprecate tensor_array (v1) ops in favor of v2 (10nov2016).
-// 17. Deprecate inv (11nov2016).
-// 17. Expose reverse_v2 (10nov2016)
-// 18. Add VariableV2 (30nov2016)
-// 19. Deprecated ops created by models moved out of core SkipGram, NegTrain.
-//     (08dec2016)
-// 20. Catch all version 1.0 changes to Python API generation. SplitV is now
-//     used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
-//     now used by tf.concat. Graphs use flooring
-//     division and mod semantics. TensorArrayV3. (12dec2016)
-//     Also considered the version for when it is required for reduction
-//     ops' indices to be scalar or vector, and not higher rank.
-//     Some earlier graph def versions allowed this.
-// 21. Dropped FunctionDef.Node support, switched to node_def introduced
-//     in version 12. (11jan2017)
-// 22. Placeholder now can specify and enforce scalar and partial
-//     shapes, particularly when restoring a graph from GraphDef
-//     produced at version 22 or later. (04/10/2016)
-// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
-// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
-// 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
-// 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25).
-// 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating
-//     whether default-valued attrs have been stripped from the nodes in the
-//     GraphDef. (7dec2017)
-// 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops
-//     deprecated in favor of V2 ops. (2018/01/23)
-// 28. Deprecate MatrixExponential op in favor of Python implementation.
-//     (2018/08/21).
-//     (2019/02/15). Added `control_ret` field to FunctionDef proto, and
-//     `control_output` field to OpDef proto.
-// 29. Deprecate StatefulStandardNormal op in favor of StatefulStandardNormalV2.
-//     (2019/03/25).
-//     (2019/04/17). Added `arg_attr` field to FunctionDefProto.
-// 30. (2019/05/09) First date based GraphDef version. GraphDef
-//     versions advance by 1 each day after this point.
-
-#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
-#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 578  // Updated: 2020/11/7
-
-// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
-//
-// The checkpoint versions have the same semantics as GraphDef versions, but the
-// numbering scheme is separate. We have no plans to ever deprecate checkpoint
-// versions, but it's good to have this in place in case we ever need to.
-//
-// Version history:
-//
-// 0. Checkpoints saved before checkpoint versioning.
-// 1. First real version (10feb2015).
-#define TF_CHECKPOINT_VERSION_MIN_PRODUCER 0
-#define TF_CHECKPOINT_VERSION_MIN_CONSUMER 0
-#define TF_CHECKPOINT_VERSION 1
-
-/// Version query functions (defined in generated version_info.cc)
-
-// Host compiler version (declared elsewhere to be __VERSION__)
-extern const char* tf_compiler_version();
-// The git commit designator when tensorflow was built
-// If no git repository, this will be "internal".
-extern const char* tf_git_version();
-// Value of the _GLIBCXX_USE_CXX11_ABI flag, or 0 if it's not set.
-extern int tf_cxx11_abi_flag();
-// Returns 1 if build is monolithic, or 0 otherwise.
-extern int tf_monolithic_build();
-
-#endif  // TENSORFLOW_CORE_PUBLIC_VERSION_H_
@@ -67,9 +67,8 @@ typedef struct {
 typedef enum {
   kTfLiteActNone = 0,
   kTfLiteActRelu,
   kTfLiteActReluN1To1,  // min(max(-1, x), 1)
-  kTfLiteActRelu1 = kTfLiteActReluN1To1,  // kTfLiteActRelu1 will be deprecated.
   kTfLiteActRelu6,  // min(max(0, x), 6)
   kTfLiteActTanh,
   kTfLiteActSignBit,
   kTfLiteActSigmoid,
@@ -88,6 +87,17 @@ typedef struct {
   int dilation_height_factor;
 } TfLiteConvParams;
 
+typedef struct {
+  TfLitePadding padding;
+  int stride_width;
+  int stride_height;
+  int stride_depth;
+  int dilation_width_factor;
+  int dilation_height_factor;
+  int dilation_depth_factor;
+  TfLiteFusedActivation activation;
+} TfLiteConv3DParams;
+
 typedef struct {
   TfLitePadding padding;
   int stride_width;
@@ -214,6 +224,10 @@ typedef struct {
 typedef struct {
   bool adj_x;
   bool adj_y;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true and the weights are quantized, then non constant inputs
+  // are quantized at evaluation time with asymmetric quantization.
+  bool asymmetric_quantize_inputs;
 } TfLiteBatchMatMulParams;
 
 typedef struct {
@@ -351,6 +365,7 @@ typedef struct {
 
 typedef struct {
   int axis;
+  int batch_dims;
 } TfLiteGatherParams;
 
 typedef struct {
@@ -474,6 +489,12 @@ typedef struct {
   int init_subgraph_index;
 } TfLiteCallOnceParams;
 
+typedef struct {
+  int table_id;
+  TfLiteType key_dtype;
+  TfLiteType value_dtype;
+} TfLiteHashtableParams;
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
code/components/tfmicro/tensorflow/lite/c/c_api_types.h (new file, 95 lines)
@@ -0,0 +1,95 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file declares types used by the pure C inference API defined in c_api.h,
+// some of which are also used in the C++ and C kernel and interpreter APIs.
+
+#ifndef TENSORFLOW_LITE_C_C_API_TYPES_H_
+#define TENSORFLOW_LITE_C_C_API_TYPES_H_
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
+// library.
+#ifdef SWIG
+#define TFL_CAPI_EXPORT
+#else
+#if defined(_WIN32)
+#ifdef TFL_COMPILE_LIBRARY
+#define TFL_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TFL_CAPI_EXPORT __declspec(dllimport)
+#endif  // TFL_COMPILE_LIBRARY
+#else
+#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // _WIN32
+#endif  // SWIG
+
+typedef enum TfLiteStatus {
+  kTfLiteOk = 0,
+
+  // Generally referring to an error in the runtime (i.e. interpreter)
+  kTfLiteError = 1,
+
+  // Generally referring to an error from a TfLiteDelegate itself.
+  kTfLiteDelegateError = 2,
+
+  // Generally referring to an error in applying a delegate due to
+  // incompatibility between runtime and delegate, e.g., this error is returned
+  // when trying to apply a TfLite delegate onto a model graph that's already
+  // immutable.
+  kTfLiteApplicationError = 3
+} TfLiteStatus;
+
+// Types supported by tensor
+typedef enum {
+  kTfLiteNoType = 0,
+  kTfLiteFloat32 = 1,
+  kTfLiteInt32 = 2,
+  kTfLiteUInt8 = 3,
+  kTfLiteInt64 = 4,
+  kTfLiteString = 5,
+  kTfLiteBool = 6,
+  kTfLiteInt16 = 7,
+  kTfLiteComplex64 = 8,
+  kTfLiteInt8 = 9,
+  kTfLiteFloat16 = 10,
+  kTfLiteFloat64 = 11,
+  kTfLiteComplex128 = 12,
+  kTfLiteUInt64 = 13,
+  kTfLiteResource = 14,
+  kTfLiteVariant = 15,
+  kTfLiteUInt32 = 16,
+} TfLiteType;
+
+// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
+// If per-layer quantization is specified this field will still be populated in
+// addition to TfLiteAffineQuantization.
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+//     real_value = scale * (quantized_value - zero_point)
+typedef struct TfLiteQuantizationParams {
+  float scale;
+  int32_t zero_point;
+} TfLiteQuantizationParams;
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+#endif  // TENSORFLOW_LITE_C_C_API_TYPES_H_
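The TfLiteQuantizationParams struct that moved into this new header documents the asymmetric scheme real_value = scale * (quantized_value - zero_point). A standalone sketch applying it (made-up scale/zero-point values; assumes the header is on the include path):

    #include <cstdint>
    #include <cstdio>
    #include "tensorflow/lite/c/c_api_types.h"

    int main() {
        TfLiteQuantizationParams qp;
        qp.scale = 0.0039f;    // example scale (assumption)
        qp.zero_point = -128;  // example zero point (assumption)
        int8_t q = 25;         // a quantized value
        float real = qp.scale * (float)(q - qp.zero_point);  // = scale * (q - zp)
        printf("%f\n", real);
        return 0;
    }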
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/c/c_api_types.h"
+
 #ifndef TF_LITE_STATIC_MEMORY
 #include <stdlib.h>
 #include <string.h>
@@ -197,12 +199,16 @@ const char* TfLiteTypeGetName(TfLiteType type) {
       return "INT16";
     case kTfLiteInt32:
       return "INT32";
+    case kTfLiteUInt32:
+      return "UINT32";
    case kTfLiteUInt8:
       return "UINT8";
     case kTfLiteInt8:
       return "INT8";
     case kTfLiteInt64:
       return "INT64";
+    case kTfLiteUInt64:
+      return "UINT64";
     case kTfLiteBool:
       return "BOOL";
     case kTfLiteComplex64:
@@ -215,6 +221,10 @@ const char* TfLiteTypeGetName(TfLiteType type) {
       return "FLOAT16";
     case kTfLiteFloat64:
       return "FLOAT64";
+    case kTfLiteResource:
+      return "RESOURCE";
+    case kTfLiteVariant:
+      return "VARIANT";
   }
   return "Unknown type";
 }
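A quick standalone check of the newly printable type names (assumes the tfmicro headers are on the include path):

    #include <cstdio>
    #include "tensorflow/lite/c/common.h"

    int main() {
        printf("%s %s %s %s\n",
               TfLiteTypeGetName(kTfLiteUInt32),     // "UINT32"
               TfLiteTypeGetName(kTfLiteUInt64),     // "UINT64"
               TfLiteTypeGetName(kTfLiteResource),   // "RESOURCE"
               TfLiteTypeGetName(kTfLiteVariant));   // "VARIANT"
        return 0;
    }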
@@ -40,26 +40,12 @@ limitations under the License.
 #include <stddef.h>
 #include <stdint.h>
 
+#include "tensorflow/lite/c/c_api_types.h"  // IWYU pragma: export
+
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
-typedef enum TfLiteStatus {
-  kTfLiteOk = 0,
-
-  // Generally referring to an error in the runtime (i.e. interpreter)
-  kTfLiteError = 1,
-
-  // Generally referring to an error from a TfLiteDelegate itself.
-  kTfLiteDelegateError = 2,
-
-  // Generally referring to an error in applying a delegate due to
-  // incompatibility between runtime and delegate, e.g., this error is returned
-  // when trying to apply a TfLite delegate onto a model graph that's already
-  // immutable.
-  kTfLiteApplicationError = 3
-} TfLiteStatus;
-
 // The list of external context types known to TF Lite. This list exists solely
 // to avoid conflicts and to ensure ops can share the external contexts they
 // need. Access to the external contexts is controlled by one of the
@@ -80,7 +66,7 @@ struct TfLiteRegistration;
 
 // An external context is a collection of information unrelated to the TF Lite
 // framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
+// about the actual contexts, but it keeps a list of them, and is able to
 // refresh them if configurations like the number of recommended threads
 // change.
 typedef struct TfLiteExternalContext {
@@ -98,7 +84,8 @@ typedef struct TfLiteIntArray {
 // https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
 #if (!defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
      __GNUC_MINOR__ >= 1) || \
-    defined(HEXAGON) || (__clang_major__ == 7 && __clang_minor__ == 1)
+    defined(HEXAGON) || \
+    (defined(__clang__) && __clang_major__ == 7 && __clang_minor__ == 1)
   int data[0];
 #else
   int data[];
@@ -254,22 +241,6 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
     }                                                    \
   } while (0)
 
-// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
-// library.
-#ifdef SWIG
-#define TFL_CAPI_EXPORT
-#else
-#if defined(_WIN32)
-#ifdef TFL_COMPILE_LIBRARY
-#define TFL_CAPI_EXPORT __declspec(dllexport)
-#else
-#define TFL_CAPI_EXPORT __declspec(dllimport)
-#endif  // TFL_COMPILE_LIBRARY
-#else
-#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
-#endif  // _WIN32
-#endif  // SWIG
-
 // Single-precision complex data type compatible with the C99 definition.
 typedef struct TfLiteComplex64 {
   float re, im;  // real and imaginary parts, respectively.
@@ -285,23 +256,6 @@ typedef struct TfLiteFloat16 {
   uint16_t data;
 } TfLiteFloat16;
 
-// Types supported by tensor
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-  kTfLiteInt8 = 9,
-  kTfLiteFloat16 = 10,
-  kTfLiteFloat64 = 11,
-  kTfLiteComplex128 = 12,
-} TfLiteType;
-
 // Return the name of a given type, for error reporting purposes.
 const char* TfLiteTypeGetName(TfLiteType type);
 
@@ -318,22 +272,12 @@ typedef enum TfLiteQuantizationType {
 typedef struct TfLiteQuantization {
   // The type of quantization held by params.
   TfLiteQuantizationType type;
-  // Holds a reference to one of the quantization param structures specified
-  // below.
+  // Holds an optional reference to a quantization param structure. The actual
+  // type depends on the value of the `type` field (see the comment there for
+  // the values and corresponding types).
   void* params;
 } TfLiteQuantization;
 
-// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
-// If per-layer quantization is specified this field will still be populated in
-// addition to TfLiteAffineQuantization.
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct TfLiteQuantizationParams {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
 // Parameters for asymmetric quantization across a dimension (i.e per output
 // channel quantization).
 // quantized_dimension specifies which dimension the scales and zero_points
@@ -353,7 +297,9 @@ typedef union TfLitePtrUnion {
    * GetTensorData<TYPE>(tensor) instead, otherwise only access .data, as other
    * members are deprecated. */
   int32_t* i32;
+  uint32_t* u32;
   int64_t* i64;
+  uint64_t* u64;
   float* f;
   TfLiteFloat16* f16;
   double* f64;
@@ -430,6 +376,17 @@ typedef struct TfLiteCustomAllocation {
   size_t bytes;
 } TfLiteCustomAllocation;
 
+// The flags used in `Interpreter::SetCustomAllocationForTensor`.
+// Note that this is a bitmask, so the values should be 1, 2, 4, 8, ...etc.
+typedef enum TfLiteCustomAllocationFlags {
+  kTfLiteCustomAllocationFlagsNone = 0,
+  // Skips checking whether allocation.data points to an aligned buffer as
+  // expected by the TFLite runtime.
+  // NOTE: Setting this flag can cause crashes when calling Invoke().
+  // Use with caution.
+  kTfLiteCustomAllocationFlagsSkipAlignCheck = 1,
+} TfLiteCustomAllocationFlags;
+
 // A tensor in the interpreter system which is a wrapper around a buffer of
 // data including a dimensionality (or NULL if not currently defined).
 #ifndef TF_LITE_STATIC_MEMORY
@@ -534,7 +491,7 @@ typedef struct TfLiteNode {
   // WARNING: This is an experimental interface that is subject to change.
   struct TfLiteDelegate* delegate;
 } TfLiteNode;
 #else  // defined(TF_LITE_STATIC_MEMORY)?
 // NOTE: This flag is opt-in only at compile time.
 //
 // Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
@@ -169,6 +169,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
|||||||
return ParseAdd(op, error_reporter, allocator, builtin_data);
|
return ParseAdd(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_ADD_N: {
|
||||||
|
return ParseAddN(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
case BuiltinOperator_ARG_MAX: {
|
case BuiltinOperator_ARG_MAX: {
|
||||||
return ParseArgMax(op, error_reporter, allocator, builtin_data);
|
return ParseArgMax(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
@@ -181,6 +185,14 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
|||||||
return ParsePool(op, error_reporter, allocator, builtin_data);
|
return ParsePool(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_BATCH_MATMUL: {
|
||||||
|
return ParseBatchMatMul(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_BATCH_TO_SPACE_ND: {
|
||||||
|
return ParseBatchToSpaceNd(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
case BuiltinOperator_CEIL: {
|
case BuiltinOperator_CEIL: {
|
||||||
return ParseCeil(op, error_reporter, allocator, builtin_data);
|
return ParseCeil(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
@@ -193,6 +205,14 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
|||||||
return ParseConv2D(op, error_reporter, allocator, builtin_data);
|
return ParseConv2D(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_CUMSUM: {
|
||||||
|
return ParseCumsum(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_DEPTH_TO_SPACE: {
|
||||||
|
return ParseDepthToSpace(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
case BuiltinOperator_DEPTHWISE_CONV_2D: {
|
case BuiltinOperator_DEPTHWISE_CONV_2D: {
|
||||||
return ParseDepthwiseConv2D(op, error_reporter, allocator, builtin_data);
|
return ParseDepthwiseConv2D(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
@@ -201,14 +221,46 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
|||||||
return ParseDequantize(op, error_reporter, allocator, builtin_data);
|
return ParseDequantize(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_DIV: {
|
||||||
|
return ParseDiv(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_ELU: {
|
||||||
|
return ParseElu(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_EXP: {
|
||||||
|
return ParseExp(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_EXPAND_DIMS: {
|
||||||
|
return ParseExpandDims(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_FILL: {
|
||||||
|
return ParseFill(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
case BuiltinOperator_FLOOR: {
|
case BuiltinOperator_FLOOR: {
|
||||||
return ParseFloor(op, error_reporter, allocator, builtin_data);
|
return ParseFloor(op, error_reporter, allocator, builtin_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_FLOOR_DIV: {
|
||||||
|
return ParseFloorDiv(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
case BuiltinOperator_FLOOR_MOD: {
|
||||||
|
return ParseFloorMod(op, error_reporter, allocator, builtin_data);
|
||||||
|
}
|
||||||
|
|
||||||
**`tensorflow/lite/core/api/flatbuffer_conversions.cc`**

```diff
     case BuiltinOperator_FULLY_CONNECTED: {
       return ParseFullyConnected(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_GATHER_ND: {
+      return ParseGatherNd(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_GREATER: {
       return ParseGreater(op, error_reporter, allocator, builtin_data);
     }
@@ -229,6 +281,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParsePool(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_LEAKY_RELU: {
+      return ParseLeakyRelu(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_LESS: {
       return ParseLess(op, error_reporter, allocator, builtin_data);
     }
@@ -257,6 +313,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParseLogistic(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_LOG_SOFTMAX: {
+      return ParseLogSoftmax(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_MAXIMUM: {
       return ParseMaximum(op, error_reporter, allocator, builtin_data);
     }
@@ -297,6 +357,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParsePadV2(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_POW: {
+      return ParsePow(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_PRELU: {
       return ParsePrelu(op, error_reporter, allocator, builtin_data);
     }
@@ -362,6 +426,14 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParseSoftmax(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_SPACE_TO_BATCH_ND: {
+      return ParseSpaceToBatchNd(op, error_reporter, allocator, builtin_data);
+    }
+
+    case BuiltinOperator_SPACE_TO_DEPTH: {
+      return ParseSpaceToDepth(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_SPLIT: {
       return ParseSplit(op, error_reporter, allocator, builtin_data);
     }
@@ -378,6 +450,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParseSquare(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_SQUEEZE: {
+      return ParseSqueeze(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_STRIDED_SLICE: {
       return ParseStridedSlice(op, error_reporter, allocator, builtin_data);
     }
@@ -398,23 +474,20 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParseTanh(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_TRANSPOSE_CONV: {
+      return ParseTransposeConv(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_UNPACK: {
       return ParseUnpack(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_ZEROS_LIKE: {
+      return ParseZerosLike(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_CAST: {
-      auto params = safe_allocator.Allocate<TfLiteCastParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params = op->builtin_options_as_CastOptions()) {
-        TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->in_data_type(),
-                                                &params->in_data_type,
-                                                error_reporter));
-        TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->out_data_type(),
-                                                &params->out_data_type,
-                                                error_reporter));
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
+      return ParseCast(op, error_reporter, allocator, builtin_data);
     }
     case BuiltinOperator_LSH_PROJECTION: {
       auto params = safe_allocator.Allocate<TfLiteLSHProjectionParams>();
```
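The hunks above are the heart of this part of the TFLite v2.5 upgrade: option parsing that previously lived inline in the big `ParseOpDataTfLite` switch now lives in one standalone `ParseXxx` helper per builtin, and each `case` merely delegates. A minimal sketch of the shape, using a hypothetical `ParseFoo` and `BuiltinOperator_FOO` as stand-ins for any builtin:

```cpp
// Minimal sketch of the delegation pattern; ParseFoo and
// BuiltinOperator_FOO are hypothetical stand-ins, not real TFLite symbols.
TfLiteStatus ParseFoo(const Operator* op, ErrorReporter* error_reporter,
                      BuiltinDataAllocator* allocator, void** builtin_data) {
  // Read FooOptions out of the flatbuffer, fill a TfLiteFooParams struct,
  // and hand ownership to *builtin_data (see the allocator sketch below).
  return kTfLiteOk;
}

TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
                               ErrorReporter* error_reporter,
                               BuiltinDataAllocator* allocator,
                               void** builtin_data) {
  switch (op_type) {
    case BuiltinOperator_FOO:
      return ParseFoo(op, error_reporter, allocator, builtin_data);
    // ... one delegating case per builtin ...
    default:
      return kTfLiteError;
  }
}
```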
```diff
@@ -483,16 +556,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_HASHTABLE_LOOKUP:
       // no-op.
       return kTfLiteOk;
-    case BuiltinOperator_DIV: {
-      auto params = safe_allocator.Allocate<TfLiteDivParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params = op->builtin_options_as_DivOptions()) {
-        params->activation =
-            ConvertActivation(schema_params->fused_activation_function());
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
       auto params = safe_allocator.Allocate<TfLiteLocalResponseNormParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -584,66 +648,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_SPACE_TO_DEPTH: {
-      auto params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params =
-              op->builtin_options_as_SpaceToDepthOptions()) {
-        params->block_size = schema_params->block_size();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
-    case BuiltinOperator_DEPTH_TO_SPACE: {
-      auto params = safe_allocator.Allocate<TfLiteDepthToSpaceParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params =
-              op->builtin_options_as_DepthToSpaceOptions()) {
-        params->block_size = schema_params->block_size();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_GATHER: {
-      auto params = safe_allocator.Allocate<TfLiteGatherParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      params->axis = 0;
-      if (const auto* gather_params = op->builtin_options_as_GatherOptions()) {
-        params->axis = gather_params->axis();
-      }
-
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
-
-    case BuiltinOperator_SQUEEZE: {
-      auto params = safe_allocator.Allocate<TfLiteSqueezeParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
-        const auto* squeeze_dims = schema_params->squeeze_dims();
-        if (squeeze_dims != nullptr) {
-          TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray(
-              sizeof(params->squeeze_dims), squeeze_dims, params->squeeze_dims,
-              error_reporter, "squeeze"));
-          params->num_squeeze_dims = squeeze_dims->size();
-        } else {
-          params->num_squeeze_dims = 0;
-        }
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
-    case BuiltinOperator_TRANSPOSE_CONV: {
-      auto params = safe_allocator.Allocate<TfLiteTransposeConvParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* transpose_conv_params =
-              op->builtin_options_as_TransposeConvOptions()) {
-        params->padding = ConvertPadding(transpose_conv_params->padding());
-        params->stride_width = transpose_conv_params->stride_w();
-        params->stride_height = transpose_conv_params->stride_h();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
+      return ParseGather(op, error_reporter, allocator, builtin_data);
     }
     case BuiltinOperator_SPARSE_TO_DENSE: {
       auto params = safe_allocator.Allocate<TfLiteSparseToDenseParams>();
@@ -683,16 +690,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_LEAKY_RELU: {
-      auto params = safe_allocator.Allocate<TfLiteLeakyReluParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* leaky_relu_params =
-              op->builtin_options_as_LeakyReluOptions()) {
-        params->alpha = leaky_relu_params->alpha();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_MIRROR_PAD: {
       auto params = safe_allocator.Allocate<TfLiteMirrorPaddingParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -750,17 +747,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_BATCH_MATMUL: {
-      auto params = safe_allocator.Allocate<TfLiteBatchMatMulParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* bmm_params =
-              op->builtin_options_as_BatchMatMulOptions()) {
-        params->adj_x = bmm_params->adj_x();
-        params->adj_y = bmm_params->adj_y();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_CALL_ONCE: {
       auto params = safe_allocator.Allocate<TfLiteCallOnceParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
```
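All of the removed blocks above share one allocation idiom that survives in the new `ParseXxx` helpers: `safe_allocator.Allocate<T>()` yields a smart pointer whose deleter routes through the interpreter-supplied `BuiltinDataAllocator`, and `release()` hands ownership to `*builtin_data` only once parsing has succeeded, so early returns cannot leak. A rough sketch of that wrapper, with details inferred from its call sites in this diff rather than quoted from the source:

```cpp
#include <memory>

// Rough sketch of the SafeBuiltinDataAllocator idiom, inferred from usage;
// member names beyond those visible in the diff are assumptions.
class SafeBuiltinDataAllocator {
 public:
  class BuiltinDataDeleter {
   public:
    explicit BuiltinDataDeleter(BuiltinDataAllocator* allocator)
        : allocator_(allocator) {}
    // Frees params through the same allocator that produced them.
    void operator()(void* data) { allocator_->Deallocate(data); }

   private:
    BuiltinDataAllocator* allocator_;
  };

  explicit SafeBuiltinDataAllocator(BuiltinDataAllocator* allocator)
      : allocator_(allocator) {}

  // Allocates a params struct owned by a unique_ptr, so a failed
  // TF_LITE_ENSURE mid-parse releases it automatically.
  template <typename T>
  std::unique_ptr<T, BuiltinDataDeleter> Allocate() {
    return std::unique_ptr<T, BuiltinDataDeleter>(
        allocator_->AllocatePOD<T>(), BuiltinDataDeleter(allocator_));
  }

 private:
  BuiltinDataAllocator* allocator_;
};
```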
```diff
@@ -771,50 +757,59 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_CUMSUM: {
-      auto params = safe_allocator.Allocate<TfLiteCumsumParams>();
+    case BuiltinOperator_CONV_3D: {
+      auto params = safe_allocator.Allocate<TfLiteConv3DParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* cumsum_params = op->builtin_options_as_CumsumOptions()) {
-        params->exclusive = cumsum_params->exclusive();
-        params->reverse = cumsum_params->reverse();
+      if (const auto* conv3d_params = op->builtin_options_as_Conv3DOptions()) {
+        params->padding = ConvertPadding(conv3d_params->padding());
+        params->activation =
+            ConvertActivation(conv3d_params->fused_activation_function());
+        params->stride_depth = conv3d_params->stride_d();
+        params->stride_height = conv3d_params->stride_h();
+        params->stride_width = conv3d_params->stride_w();
+        params->dilation_depth_factor = conv3d_params->dilation_d_factor();
+        params->dilation_height_factor = conv3d_params->dilation_h_factor();
+        params->dilation_width_factor = conv3d_params->dilation_w_factor();
+      }
+      *builtin_data = params.release();
+      return kTfLiteOk;
+    }
+    case BuiltinOperator_HASHTABLE: {
+      auto params = safe_allocator.Allocate<TfLiteHashtableParams>();
+      TF_LITE_ENSURE(error_reporter, params != nullptr);
+      if (const auto* hashtable_params =
+              op->builtin_options_as_HashtableOptions()) {
+        params->table_id = hashtable_params->table_id();
+        TF_LITE_ENSURE_STATUS(ConvertTensorType(
+            hashtable_params->key_dtype(), &params->key_dtype, error_reporter));
+        TF_LITE_ENSURE_STATUS(ConvertTensorType(hashtable_params->value_dtype(),
+                                                &params->value_dtype,
+                                                error_reporter));
       }
       *builtin_data = params.release();
       return kTfLiteOk;
     }
     // Below are the ops with no builtin_data structure.
-    case BuiltinOperator_BATCH_TO_SPACE_ND:
       // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
       // ok for now, since there is no call implementation either.
     case BuiltinOperator_CALL:
     case BuiltinOperator_CONCAT_EMBEDDINGS:
     case BuiltinOperator_COS:
     case BuiltinOperator_CUSTOM:
-    case BuiltinOperator_ELU:
     case BuiltinOperator_EMBEDDING_LOOKUP:
     case BuiltinOperator_EQUAL:
-    case BuiltinOperator_EXP:
-    case BuiltinOperator_EXPAND_DIMS:
-    case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_MATRIX_DIAG:
     case BuiltinOperator_MATRIX_SET_DIAG:
     case BuiltinOperator_RELU_N1_TO_1:
     case BuiltinOperator_SELECT:
     case BuiltinOperator_SELECT_V2:
     case BuiltinOperator_SLICE:
-    case BuiltinOperator_SPACE_TO_BATCH_ND:
     case BuiltinOperator_TILE:
     case BuiltinOperator_TOPK_V2:
     case BuiltinOperator_TRANSPOSE:
-    case BuiltinOperator_POW:
-    case BuiltinOperator_FLOOR_DIV:
-    case BuiltinOperator_ZEROS_LIKE:
-    case BuiltinOperator_FILL:
-    case BuiltinOperator_FLOOR_MOD:
     case BuiltinOperator_RANGE:
     case BuiltinOperator_SQUARED_DIFFERENCE:
     case BuiltinOperator_REVERSE_V2:
-    case BuiltinOperator_ADD_N:
-    case BuiltinOperator_GATHER_ND:
     case BuiltinOperator_WHERE:
     case BuiltinOperator_RANK:
     case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
```
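Whatever a `ParseXxx` helper stores via `*builtin_data` later reaches the kernel through `TfLiteNode::builtin_data`; the ops in the long fall-through list above simply leave it null. An illustrative consumer-side sketch, where the function name and the null-handling policy are stand-ins rather than the real kernel:

```cpp
// Illustrative kernel-side counterpart: read back the struct that
// ParseOpDataTfLite produced for CONV_3D. TfLiteConv3DParams is the real
// struct from this diff; the rest is a hypothetical sketch.
TfLiteStatus Conv3DPrepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      static_cast<const TfLiteConv3DParams*>(node->builtin_data);
  if (params == nullptr) return kTfLiteError;  // model carried no options
  // ... consume params->stride_depth, params->dilation_depth_factor, ...
  return kTfLiteOk;
}
```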
```diff
@@ -823,6 +818,13 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_DENSIFY:
     case BuiltinOperator_SEGMENT_SUM:
     case BuiltinOperator_BROADCAST_TO:
+    case BuiltinOperator_RFFT2D:
+    case BuiltinOperator_IMAG:
+    case BuiltinOperator_REAL:
+    case BuiltinOperator_COMPLEX_ABS:
+    case BuiltinOperator_HASHTABLE_FIND:
+    case BuiltinOperator_HASHTABLE_IMPORT:
+    case BuiltinOperator_HASHTABLE_SIZE:
       return kTfLiteOk;
     case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
       return kTfLiteError;
@@ -850,6 +852,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
     case TensorType_INT32:
       *type = kTfLiteInt32;
       return kTfLiteOk;
+    case TensorType_UINT32:
+      *type = kTfLiteUInt32;
+      return kTfLiteOk;
     case TensorType_UINT8:
       *type = kTfLiteUInt8;
       return kTfLiteOk;
@@ -859,6 +864,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
     case TensorType_INT64:
       *type = kTfLiteInt64;
       return kTfLiteOk;
+    case TensorType_UINT64:
+      *type = kTfLiteUInt64;
+      return kTfLiteOk;
     case TensorType_STRING:
       *type = kTfLiteString;
       return kTfLiteOk;
@@ -871,6 +879,12 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
     case TensorType_COMPLEX128:
       *type = kTfLiteComplex128;
       return kTfLiteOk;
+    case TensorType_RESOURCE:
+      *type = kTfLiteResource;
+      return kTfLiteOk;
+    case TensorType_VARIANT:
+      *type = kTfLiteVariant;
+      return kTfLiteOk;
     default:
       *type = kTfLiteNoType;
       TF_LITE_REPORT_ERROR(error_reporter,
```
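`ConvertTensorType` is the single mapping point between the flatbuffer schema's `TensorType` and the runtime's `TfLiteType`, so new dtypes only need a case here; this release adds unsigned 32/64-bit integers plus resource and variant handles. A small usage sketch, where the wrapper function itself is illustrative:

```cpp
// Illustrative call site: resolve a schema-level tensor type into the
// runtime enum, reporting through the usual ErrorReporter on failure.
TfLiteStatus ResolveTensorType(TensorType schema_type, TfLiteType* out,
                               ErrorReporter* error_reporter) {
  // After this upgrade, TensorType_UINT32 maps to kTfLiteUInt32,
  // TensorType_RESOURCE to kTfLiteResource, and so on; unknown values
  // yield kTfLiteNoType plus a reported error.
  return ConvertTensorType(schema_type, out, error_reporter);
}
```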
```diff
@@ -912,6 +926,11 @@ TfLiteStatus ParseAdd(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseAddN(const Operator* op, ErrorReporter* error_reporter,
+                       BuiltinDataAllocator* allocator, void** builtin_data) {
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseArgMax(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data) {
   CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
@@ -962,6 +981,56 @@ TfLiteStatus ParseArgMin(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseBatchMatMul(const Operator* op, ErrorReporter* error_reporter,
+                              BuiltinDataAllocator* allocator,
+                              void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteBatchMatMulParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* bmm_params = op->builtin_options_as_BatchMatMulOptions()) {
+    params->adj_x = bmm_params->adj_x();
+    params->adj_y = bmm_params->adj_y();
+    params->asymmetric_quantize_inputs =
+        bmm_params->asymmetric_quantize_inputs();
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseBatchToSpaceNd(const Operator*, ErrorReporter*,
+                                 BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseCast(const Operator* op, ErrorReporter* error_reporter,
+                       BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteCastParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* schema_params = op->builtin_options_as_CastOptions()) {
+    TF_LITE_ENSURE_STATUS(ConvertTensorType(
+        schema_params->in_data_type(), &params->in_data_type, error_reporter));
+    TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->out_data_type(),
+                                            &params->out_data_type,
+                                            error_reporter));
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
```
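The comment repeated above carries the design rationale: because each builtin now has its own named `ParseXxx` symbol with a uniform signature, a micro `OpResolver` can reference exactly the parsers it needs and the linker can discard the rest. A hedged sketch of that idea; the table type and its layout are hypothetical, not the actual micro API:

```cpp
// Hypothetical selective-registration table: each entry pairs an opcode
// with its parse helper. Only parsers referenced here get linked in;
// ParseBatchMatMul and ParseCast are real functions from this diff.
struct BuiltinParserEntry {
  BuiltinOperator op;
  TfLiteStatus (*parse)(const Operator*, ErrorReporter*,
                        BuiltinDataAllocator*, void**);
};

constexpr BuiltinParserEntry kParsers[] = {
    {BuiltinOperator_BATCH_MATMUL, ParseBatchMatMul},
    {BuiltinOperator_CAST, ParseCast},
};
```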
```diff
@@ -1030,6 +1099,24 @@ TfLiteStatus ParseConv2D(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseCumsum(const Operator* op, ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteCumsumParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* cumsum_params = op->builtin_options_as_CumsumOptions()) {
+    params->exclusive = cumsum_params->exclusive();
+    params->reverse = cumsum_params->reverse();
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1038,6 +1125,31 @@ TfLiteStatus ParseCos(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseDepthToSpace(const Operator* op,
+                               ErrorReporter* error_reporter,
+                               BuiltinDataAllocator* allocator,
+                               void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  std::unique_ptr<TfLiteDepthToSpaceParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteDepthToSpaceParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const auto* schema_params = op->builtin_options_as_DepthToSpaceOptions();
+  if (schema_params != nullptr) {
+    params->block_size = schema_params->block_size();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseDepthwiseConv2D(const Operator* op,
                                   ErrorReporter* error_reporter,
                                   BuiltinDataAllocator* allocator,
```
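Note the recurring `TODO(b/157480169)` branch: when a model's flatbuffer carries no options table for the op, `ParseDepthToSpace` (and the similar helpers below) deliberately keeps the freshly allocated struct's defaults and still returns `kTfLiteOk`, preserving legacy behavior instead of failing; tightening this is explicitly deferred.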
```diff
@@ -1082,6 +1194,29 @@ TfLiteStatus ParseDequantize(const Operator*, ErrorReporter*,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseDiv(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteDivParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* schema_params = op->builtin_options_as_DivOptions()) {
+    params->activation =
+        ConvertActivation(schema_params->fused_activation_function());
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseElu(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
+                      void**) {
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1090,6 +1225,30 @@ TfLiteStatus ParseEqual(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseExp(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
+                      void**) {
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseExpandDims(const Operator*, ErrorReporter*,
+                             BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseFill(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
+                       void**) {
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1098,6 +1257,22 @@ TfLiteStatus ParseFloor(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseFloorDiv(const Operator*, ErrorReporter*,
+                           BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseFloorMod(const Operator*, ErrorReporter*,
+                           BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseFullyConnected(const Operator* op,
                                  ErrorReporter* error_reporter,
                                  BuiltinDataAllocator* allocator,
@@ -1144,6 +1319,35 @@ TfLiteStatus ParseFullyConnected(const Operator* op,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseGather(const Operator* op, ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteGatherParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  params->axis = 0;
+  params->batch_dims = 0;
+  if (const auto* gather_params = op->builtin_options_as_GatherOptions()) {
+    params->axis = gather_params->axis();
+    params->batch_dims = gather_params->batch_dims();
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseGatherNd(const Operator*, ErrorReporter*,
+                           BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
```
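Compared with the old inline `GATHER` case removed earlier, `ParseGather` additionally reads the new `batch_dims` field, defaulting both `axis` and `batch_dims` to 0 when the `GatherOptions` table is absent from the model.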
```diff
@@ -1195,6 +1399,22 @@ TfLiteStatus ParseL2Normalization(const Operator* op,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter,
+                            BuiltinDataAllocator* allocator,
+                            void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteLeakyReluParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* leaky_relu_params =
+          op->builtin_options_as_LeakyReluOptions()) {
+    params->alpha = leaky_relu_params->alpha();
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1251,6 +1471,14 @@ TfLiteStatus ParseLogistic(const Operator*, ErrorReporter*,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseLogSoftmax(const Operator*, ErrorReporter*,
+                             BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1378,6 +1606,14 @@ TfLiteStatus ParsePool(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParsePow(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
+                      void**) {
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1599,6 +1835,39 @@ TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseSpaceToBatchNd(const Operator*, ErrorReporter*,
+                                 BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus ParseSpaceToDepth(const Operator* op,
+                               ErrorReporter* error_reporter,
+                               BuiltinDataAllocator* allocator,
+                               void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  std::unique_ptr<TfLiteSpaceToDepthParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions();
+  if (schema_params != nullptr) {
+    params->block_size = schema_params->block_size();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data) {
   CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
```
```diff
@@ -1647,6 +1916,39 @@ TfLiteStatus ParseSplitV(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseSqueeze(const Operator* op, ErrorReporter* error_reporter,
+                          BuiltinDataAllocator* allocator,
+                          void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+
+  std::unique_ptr<TfLiteSqueezeParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteSqueezeParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const SqueezeOptions* schema_params = op->builtin_options_as_SqueezeOptions();
+
+  if (schema_params != nullptr) {
+    const auto* squeeze_dims = schema_params->squeeze_dims();
+    if (squeeze_dims != nullptr) {
+      TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray(
+          sizeof(params->squeeze_dims), squeeze_dims, params->squeeze_dims,
+          error_reporter, "squeeze"));
+      params->num_squeeze_dims = squeeze_dims->size();
+    } else {
+      params->num_squeeze_dims = 0;
+    }
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1753,6 +2055,40 @@ TfLiteStatus ParseTanh(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
                        void**) {
   return kTfLiteOk;
 }
+//
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseTranspose(const Operator*, ErrorReporter*,
+                            BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus ParseTransposeConv(const Operator* op,
+                                ErrorReporter* error_reporter,
+                                BuiltinDataAllocator* allocator,
+                                void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  std::unique_ptr<TfLiteTransposeConvParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteTransposeConvParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  const TransposeConvOptions* transpose_conv_params =
+      op->builtin_options_as_TransposeConvOptions();
+  if (transpose_conv_params != nullptr) {
+    params->padding = ConvertPadding(transpose_conv_params->padding());
+    params->stride_width = transpose_conv_params->stride_w();
+    params->stride_height = transpose_conv_params->stride_h();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseUnpack(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data) {
@@ -1779,6 +2115,14 @@ TfLiteStatus ParseUnpack(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseZerosLike(const Operator*, ErrorReporter*,
+                            BuiltinDataAllocator*, void**) {
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                          ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data) {
```
**`tensorflow/lite/core/api/flatbuffer_conversions.h`**

```diff
@@ -75,15 +75,30 @@ TfLiteStatus ParseAbs(const Operator* op, ErrorReporter* error_reporter,
 TfLiteStatus ParseAdd(const Operator* op, ErrorReporter* error_reporter,
                       BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseAddN(const Operator* op, ErrorReporter* error_reporter,
+                       BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseArgMax(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);
 
 TfLiteStatus ParseArgMin(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseBatchMatMul(const Operator* op, ErrorReporter* error_reporter,
+                              BuiltinDataAllocator* allocator,
+                              void** builtin_data);
+
+TfLiteStatus ParseBatchToSpaceNd(const Operator* op,
+                                 ErrorReporter* error_reporter,
+                                 BuiltinDataAllocator* allocator,
+                                 void** builtin_data);
+
 TfLiteStatus ParseCeil(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseCast(const Operator* op, ErrorReporter* error_reporter,
+                       BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseConcatenation(const Operator* op,
                                 ErrorReporter* error_reporter,
                                 BuiltinDataAllocator* allocator,
@@ -95,6 +110,14 @@ TfLiteStatus ParseConv2D(const Operator* op, ErrorReporter* error_reporter,
 TfLiteStatus ParseCos(const Operator* op, ErrorReporter* error_reporter,
                       BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseCumsum(const Operator* op, ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseDepthToSpace(const Operator* op,
+                               ErrorReporter* error_reporter,
+                               BuiltinDataAllocator* allocator,
+                               void** builtin_data);
+
 TfLiteStatus ParseDepthwiseConv2D(const Operator* op,
                                   ErrorReporter* error_reporter,
                                   BuiltinDataAllocator* allocator,
@@ -104,17 +127,48 @@ TfLiteStatus ParseDequantize(const Operator* op, ErrorReporter* error_reporter,
                              BuiltinDataAllocator* allocator,
                              void** builtin_data);
 
+TfLiteStatus ParseDiv(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseElu(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseEqual(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseExp(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseExpandDims(const Operator* op, ErrorReporter* error_reporter,
+                             BuiltinDataAllocator* allocator,
+                             void** builtin_data);
+
+TfLiteStatus ParseFill(const Operator* op, ErrorReporter* error_reporter,
+                       BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseFloor(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseFloorDiv(const Operator* op, ErrorReporter* error_reporter,
+                           BuiltinDataAllocator* allocator,
+                           void** builtin_data);
+
+TfLiteStatus ParseFloorMod(const Operator* op, ErrorReporter* error_reporter,
+                           BuiltinDataAllocator* allocator,
+                           void** builtin_data);
+
 TfLiteStatus ParseFullyConnected(const Operator* op,
                                  ErrorReporter* error_reporter,
                                  BuiltinDataAllocator* allocator,
                                  void** builtin_data);
 
+TfLiteStatus ParseGather(const Operator* op, ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseGatherNd(const Operator* op, ErrorReporter* error_reporter,
+                           BuiltinDataAllocator* allocator,
+                           void** builtin_data);
+
 TfLiteStatus ParseGreater(const Operator* op, ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data);
 
@@ -132,6 +186,10 @@ TfLiteStatus ParseL2Normalization(const Operator* op,
                                   BuiltinDataAllocator* allocator,
                                   void** builtin_data);
 
+TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter,
+                            BuiltinDataAllocator* allocator,
+                            void** builtin_data);
+
 TfLiteStatus ParseLess(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
@@ -158,6 +216,10 @@ TfLiteStatus ParseLogistic(const Operator* op, ErrorReporter* error_reporter,
                            BuiltinDataAllocator* allocator,
                            void** builtin_data);
 
+TfLiteStatus ParseLogSoftmax(const Operator* op, ErrorReporter* error_reporter,
+                             BuiltinDataAllocator* allocator,
+                             void** builtin_data);
+
 TfLiteStatus ParseMaximum(const Operator* op, ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data);
 
@@ -186,6 +248,9 @@ TfLiteStatus ParsePadV2(const Operator* op, ErrorReporter* error_reporter,
 TfLiteStatus ParsePool(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParsePow(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParsePrelu(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);
 
@@ -230,12 +295,25 @@ TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
 TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseSpaceToBatchNd(const Operator* op,
+                                 ErrorReporter* error_reporter,
+                                 BuiltinDataAllocator* allocator,
+                                 void** builtin_data);
+
+TfLiteStatus ParseSpaceToDepth(const Operator* op,
+                               ErrorReporter* error_reporter,
+                               BuiltinDataAllocator* allocator,
+                               void** builtin_data);
+
 TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);
 
 TfLiteStatus ParseSplitV(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseSqueeze(const Operator* op, ErrorReporter* error_reporter,
+                          BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseSqrt(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
@@ -256,9 +334,22 @@ TfLiteStatus ParseSvdf(const Operator* op, ErrorReporter* error_reporter,
 TfLiteStatus ParseTanh(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseTranspose(const Operator* op, ErrorReporter* error_reporter,
+                            BuiltinDataAllocator* allocator,
+                            void** builtin_data);
+
+TfLiteStatus ParseTransposeConv(const Operator* op,
+                                ErrorReporter* error_reporter,
+                                BuiltinDataAllocator* allocator,
+                                void** builtin_data);
+
 TfLiteStatus ParseUnpack(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseZerosLike(const Operator* op, ErrorReporter* error_reporter,
+                            BuiltinDataAllocator* allocator,
+                            void** builtin_data);
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
```
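Every declaration above shares the same four-parameter signature, which is what lets resolvers treat parsers as interchangeable function pointers. A hedged usage sketch of calling one directly; `op`, `error_reporter`, and `allocator` are assumed to be in scope:

```cpp
// Hedged sketch: parse a single operator's options outside the interpreter.
void* builtin_data = nullptr;
if (ParseFullyConnected(op, error_reporter, allocator, &builtin_data) ==
    kTfLiteOk) {
  const auto* fc =
      static_cast<const TfLiteFullyConnectedParams*>(builtin_data);
  // ... inspect fc->activation, fc->weights_format, ...
}
```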
**`tensorflow/lite/core/api/op_resolver.cc`**

```diff
@@ -43,7 +43,9 @@ TfLiteStatus GetRegistrationFromOpCode(
   if (*registration == nullptr) {
     TF_LITE_REPORT_ERROR(
         error_reporter,
-        "Didn't find op for builtin opcode '%s' version '%d'\n",
+        "Didn't find op for builtin opcode '%s' version '%d'. "
+        "An older version of this builtin might be supported. "
+        "Are you using an old TFLite binary with a newer model?\n",
         EnumNameBuiltinOperator(builtin_code), version);
     status = kTfLiteError;
   }
```
**`tensorflow/lite/core/api/op_resolver.h`**

```diff
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
 #define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
 
+#include <memory>
 #include <vector>
 
 #include "tensorflow/lite/c/common.h"
```
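The new `<memory>` include presumably backs `std::unique_ptr` usage introduced into the `OpResolver` interface elsewhere in this v2.5 update.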
@@ -1,194 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_CORE_API_PROFILER_H_
-#define TENSORFLOW_LITE_CORE_API_PROFILER_H_
-
-#include <cstdint>
-
-namespace tflite {
-
-// A simple utility for enabling profiled event tracing in TensorFlow Lite.
-class Profiler {
- public:
-  // As certain Profiler instance might be only interested in certain event
-  // types, we define each event type value to allow a Profiler to use
-  // bitmasking bitwise operations to determine whether an event should be
-  // recorded or not.
-  enum class EventType {
-    // Default event type, the metadata field has no special significance.
-    DEFAULT = 1,
-
-    // The event is an operator invocation and the event_metadata field is the
-    // index of operator node.
-    OPERATOR_INVOKE_EVENT = 2,
-
-    // The event is an invocation for an internal operator of a TFLite delegate.
-    // The event_metadata field is the index of operator node that's specific to
-    // the delegate.
-    DELEGATE_OPERATOR_INVOKE_EVENT = 4,
-
-    // The event is a recording of runtime instrumentation such as the overall
-    // TFLite runtime status, the TFLite delegate status (if a delegate
-    // is applied), and the overall model inference latency etc.
-    // Note, the delegate status and overall status are stored as separate
-    // event_metadata fields. In particular, the delegate status is encoded
-    // as DelegateStatus::full_status().
-    GENERAL_RUNTIME_INSTRUMENTATION_EVENT = 8,
-  };
-
-  virtual ~Profiler() {}
-
-  // Signals the beginning of an event and returns a handle to the profile
-  // event. The `event_metadata1` and `event_metadata2` have different
-  // interpretations based on the actual Profiler instance and the `event_type`.
-  // For example, as for the 'SubgraphAwareProfiler' defined in
-  // lite/core/subgraph.h, when the event_type is OPERATOR_INVOKE_EVENT,
-  // `event_metadata1` represents the index of a TFLite node, and
-  // `event_metadata2` represents the index of the subgraph that this event
-  // comes from.
-  virtual uint32_t BeginEvent(const char* tag, EventType event_type,
-                              int64_t event_metadata1,
-                              int64_t event_metadata2) = 0;
-  // Similar w/ the above, but `event_metadata2` defaults to 0.
-  uint32_t BeginEvent(const char* tag, EventType event_type,
-                      int64_t event_metadata) {
-    return BeginEvent(tag, event_type, event_metadata, /*event_metadata2*/ 0);
-  }
-
-  // Signals an end to the specified profile event with 'event_metadata's, This
-  // is useful when 'event_metadata's are not available when the event begins
-  // or when one wants to overwrite the 'event_metadata's set at the beginning.
-  virtual void EndEvent(uint32_t event_handle, int64_t event_metadata1,
-                        int64_t event_metadata2) {}
-  // Signals an end to the specified profile event.
-  virtual void EndEvent(uint32_t event_handle) = 0;
-
-  // Appends an event of type 'event_type' with 'tag' and 'event_metadata'
-  // which started at 'start' and ended at 'end'
-  // Note:
-  // In cases were ProfileSimmarizer and tensorflow::StatsCalculator are used
-  // they assume the value is in "usec", if in any case subclasses
-  // didn't put usec, then the values are not meaningful.
-  // TODO karimnosseir: Revisit and make the function more clear.
-  void AddEvent(const char* tag, EventType event_type, uint64_t start,
-                uint64_t end, int64_t event_metadata) {
-    AddEvent(tag, event_type, start, end, event_metadata,
-             /*event_metadata2*/ 0);
-  }
-
-  virtual void AddEvent(const char* tag, EventType event_type, uint64_t start,
-                        uint64_t end, int64_t event_metadata1,
-                        int64_t event_metadata2) {}
-
- protected:
-  friend class ScopedProfile;
-};
-
-// Adds a profile event to `profiler` that begins with the construction
-// of the object and ends when the object goes out of scope.
-// The lifetime of tag should be at least the lifetime of `profiler`.
-// `profiler` may be null, in which case nothing is profiled.
-class ScopedProfile {
- public:
-  ScopedProfile(Profiler* profiler, const char* tag,
-                Profiler::EventType event_type = Profiler::EventType::DEFAULT,
-                int64_t event_metadata = 0)
-      : profiler_(profiler), event_handle_(0) {
-    if (profiler) {
-      event_handle_ = profiler_->BeginEvent(tag, event_type, event_metadata);
-    }
-  }
-
-  ~ScopedProfile() {
-    if (profiler_) {
-      profiler_->EndEvent(event_handle_);
-    }
-  }
-
- protected:
-  Profiler* profiler_;
-  uint32_t event_handle_;
-};
-
-class ScopedOperatorProfile : public ScopedProfile {
- public:
-  ScopedOperatorProfile(Profiler* profiler, const char* tag, int node_index)
-      : ScopedProfile(profiler, tag, Profiler::EventType::OPERATOR_INVOKE_EVENT,
-                      static_cast<uint32_t>(node_index)) {}
-};
-
-class ScopedDelegateOperatorProfile : public ScopedProfile {
- public:
-  ScopedDelegateOperatorProfile(Profiler* profiler, const char* tag,
-                                int node_index)
-      : ScopedProfile(profiler, tag,
-                      Profiler::EventType::DELEGATE_OPERATOR_INVOKE_EVENT,
-                      static_cast<uint32_t>(node_index)) {}
-};
-
-class ScopedRuntimeInstrumentationProfile : public ScopedProfile {
- public:
-  ScopedRuntimeInstrumentationProfile(Profiler* profiler, const char* tag)
-      : ScopedProfile(
-            profiler, tag,
-            Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, -1) {}
-
-  void set_runtime_status(int64_t delegate_status, int64_t interpreter_status) {
-    if (profiler_) {
-      delegate_status_ = delegate_status;
-      interpreter_status_ = interpreter_status;
-    }
-  }
-
-  ~ScopedRuntimeInstrumentationProfile() {
-    if (profiler_) {
-      profiler_->EndEvent(event_handle_, delegate_status_, interpreter_status_);
-    }
-  }
-
- private:
-  int64_t delegate_status_;
-  int64_t interpreter_status_;
-};
-
-}  // namespace tflite
-
-#define TFLITE_VARNAME_UNIQ_IMPL(name, ctr) name##ctr
-#define TFLITE_VARNAME_UNIQ(name, ctr) TFLITE_VARNAME_UNIQ_IMPL(name, ctr)
-
-#define TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler, tag)          \
-  tflite::ScopedProfile TFLITE_VARNAME_UNIQ(_profile_, __COUNTER__)( \
-      (profiler), (tag))
-
-#define TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE(profiler, tag, node_index)     \
-  tflite::ScopedOperatorProfile TFLITE_VARNAME_UNIQ(_profile_, __COUNTER__)( \
-      (profiler), (tag), (node_index))
-
-#define TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(profiler, tag, node_index) \
-  tflite::ScopedDelegateOperatorProfile TFLITE_VARNAME_UNIQ(               \
-      _profile_, __COUNTER__)((profiler), (tag), (node_index))
-
-#define TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT(                          \
-    profiler, tag, delegate_status, interpreter_status)                    \
-  do {                                                                     \
-    if (!profiler) {                                                       \
-      const auto handle = profiler->BeginEvent(                            \
-          tag, Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, \
-          delegate_status, interpreter_status);                            \
-      profiler->EndEvent(handle);                                          \
-    }                                                                      \
-  } while (false);
-
-#endif  // TENSORFLOW_LITE_CORE_API_PROFILER_H_
@@ -178,14 +178,54 @@ inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
   // - input x is in the range -(1<<47) <= x < (1<<47)
   assert(quantized_multiplier >= 0);
   assert(shift >= -31 && shift < 8);
+  assert(x >= -(static_cast<int64_t>(1) << 47) &&
+         x < (static_cast<int64_t>(1) << 47));
 
-  int32_t reduced_multiplier = (quantized_multiplier + (1 << 15)) >> 16;
+  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
+                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
+                                   : 0x7FFF;
   int total_shift = 15 - shift;
   x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
   int32_t result = x >> total_shift;
   return result;
 }
 
+#ifdef USE_NEON
+// Round uses ARM's rounding shift right.
+inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
+    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
+  const int left_shift = std::max(shift, 0);
+  const int right_shift = std::min(shift, 0);
+  int32x4x4_t result;
+
+  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
+  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
+  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
+
+  result.val[0] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  result.val[1] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  result.val[2] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  result.val[3] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  return result;
+}
+#endif
+
 template <typename T>
 int CountLeadingZeros(T integer_input) {
   static_assert(std::is_unsigned<T>::value,
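The new `0x7FFF0000` guard matters because adding the `1 << 15` rounding bias to a Q31 multiplier close to `INT32_MAX` would overflow before the shift. A standalone sketch of that saturating reduction, for illustration only (not the library code itself):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Round the top 16 bits of a non-negative Q31 multiplier, saturating at
// 0x7FFF so the +(1 << 15) rounding bias cannot overflow int32_t when the
// multiplier is near INT32_MAX.
int32_t ReduceMultiplier(int32_t quantized_multiplier) {
  assert(quantized_multiplier >= 0);
  return (quantized_multiplier < 0x7FFF0000)
             ? ((quantized_multiplier + (1 << 15)) >> 16)
             : 0x7FFF;
}

int main() {
  printf("%d\n", (int)ReduceMultiplier(0x40000000));  // 16384 (0.5 in Q15)
  printf("%d\n", (int)ReduceMultiplier(0x7FFFFFFF));  // 32767, saturated
  return 0;
}
```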
@@ -261,10 +301,11 @@ inline void gen_lut(double (*func)(double), double min, double max,
         TfLiteRound(func(min + i * step + half_step) * 32768.0);
     double midpoint_err = midpoint_interp_val - midpoint_val;
     double bias = TfLiteRound(midpoint_err / 2.0);
-    table[i] = std::min(std::max(sample_val - bias, -32768.0), 32767.0);
+    table[i] = std::min<double>(std::max<double>(sample_val - bias, -32768.0),
+                                32767.0);
   }
-  table[num - 1] =
-      std::min(std::max(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
+  table[num - 1] = std::min<double>(
+      std::max<double>(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
 }
 
 // generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
@@ -289,10 +330,11 @@ inline void gen_lut(float (*func)(float), float min, float max, int16_t* table,
         TfLiteRound(func(min + i * step + half_step) * 32768.0f);
     float midpoint_err = midpoint_interp_val - midpoint_val;
     float bias = TfLiteRound(midpoint_err / 2.0f);
-    table[i] = std::min(std::max(sample_val - bias, -32768.0f), 32767.0f);
+    table[i] = std::min<float>(std::max<float>(sample_val - bias, -32768.0f),
+                               32767.0f);
   }
-  table[num - 1] = std::min(
-      std::max(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
+  table[num - 1] = std::min<float>(
+      std::max<float>(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
 }
 
 // int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax
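Pinning the template argument with `std::min<double>` / `std::max<double>` (and the `float` twins) presumably avoids template-deduction failures when the two clamp operands end up with different floating types on some toolchains. A minimal illustration of the failure mode being sidestepped:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  float sample = 32767.9f;  // e.g., a rounded sample value
  double bound = 32767.0;   // the clamp limit
  // std::min(sample, bound);  // would not compile: float vs. double deduction
  double clamped = std::min<double>(sample, bound);  // both convert to double
  printf("%f\n", clamped);  // 32767.000000
  return 0;
}
```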
@@ -34,6 +34,7 @@ namespace tflite {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_STD_GLOBAL_SWITCH1(TfLiteRound, round);
|
DECLARE_STD_GLOBAL_SWITCH1(TfLiteRound, round);
|
||||||
|
DECLARE_STD_GLOBAL_SWITCH1(TfLiteExpm1, expm1);
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
|||||||
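The new `DECLARE_STD_GLOBAL_SWITCH1(TfLiteExpm1, expm1)` routes `expm1` through the same portability shim as `TfLiteRound`. The new ELU kernel further down relies on it because `expm1(x) = e^x - 1` keeps precision where the naive `exp(x) - 1.0` cancels catastrophically for small |x|. A quick demonstration with the standard function (the TFLite macro itself is not used here):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double x = 1e-12;
  // Naive form: exp(x) rounds to a value indistinguishable from 1 + x in
  // double precision, so the subtraction wipes out most significant digits.
  printf("exp(x)-1 = %.17g\n", std::exp(x) - 1.0);
  // expm1 computes e^x - 1 directly and stays accurate near zero.
  printf("expm1(x) = %.17g\n", std::expm1(x));
  return 0;
}
```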
@@ -15,7 +15,6 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
 
-#include <complex>
 #include <vector>
 
 #include "tensorflow/lite/c/common.h"
@@ -289,7 +289,7 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
     input_beta_real_multiplier = (1ll << 31) - 1.0;
   }
 #else  // TFLITE_EMULATE_FLOAT
-  const double input_beta_real_multiplier = std::min(
+  const double input_beta_real_multiplier = std::min<double>(
       beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
 #endif  // TFLITE_EMULATE_FLOAT
 
@@ -202,14 +202,6 @@ inline void Add(const ArithmeticParams& params,
   }
 }
 
-// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
 inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
                                const RuntimeShape& input1_shape,
                                const float* input1_data,
@@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_N_H_
|
||||||
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_N_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace reference_ops {
|
||||||
|
|
||||||
|
// T is expected to be either float or int.
|
||||||
|
template <typename T>
|
||||||
|
inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
|
||||||
|
const T* const* input_data, T* output_data) {
|
||||||
|
// All inputs and output should have the same shape, this is checked during
|
||||||
|
// Prepare stage.
|
||||||
|
const size_t size = input_shape.FlatSize();
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
T x = 0;
|
||||||
|
for (size_t j = 0; j < num_inputs; ++j) {
|
||||||
|
x += input_data[j][i];
|
||||||
|
}
|
||||||
|
output_data[i] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace reference_ops
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_N_H_
|
||||||
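A hypothetical usage of the new reference `AddN` loop, with local stand-ins for the TFLite shape types so the sketch stays self-contained:

```cpp
#include <cstdio>

// Same loop structure as reference_ops::AddN above, with the RuntimeShape
// replaced by a plain flat size for illustration.
template <typename T>
void AddN(int flat_size, int num_inputs, const T* const* input_data,
          T* output_data) {
  for (int i = 0; i < flat_size; ++i) {
    T x = 0;
    for (int j = 0; j < num_inputs; ++j) x += input_data[j][i];
    output_data[i] = x;
  }
}

int main() {
  const float a[4] = {1, 2, 3, 4};
  const float b[4] = {10, 20, 30, 40};
  const float c[4] = {100, 200, 300, 400};
  const float* inputs[3] = {a, b, c};
  float out[4];
  AddN(4, 3, inputs, out);
  for (float v : out) printf("%g ", v);  // 111 222 333 444
  printf("\n");
  return 0;
}
```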
@@ -15,12 +15,23 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
 
+#include <functional>
+
 #include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
 
 namespace reference_ops {
 
+template <typename T>
+std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
+  if (is_arg_max) {
+    return std::greater<T>();
+  } else {
+    return std::less<T>();
+  }
+}
+
 template <typename T1, typename T2, typename T3, typename Cmp>
 void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
                const T3* input2_data, const RuntimeShape& output_shape,
@@ -62,6 +73,15 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
       }
     }
   }
 }
 
+template <typename T1, typename T2, typename T3>
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+               const T3* input2_data, const RuntimeShape& output_shape,
+               T2* output_data, const bool is_arg_max) {
+  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
+            GetComparefunction<T1>(is_arg_max));
+}
+
 }  // namespace reference_ops
 }  // namespace tflite
 
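A sketch of the comparator dispatch the new `ArgMinMax` overload introduces, reduced to a 1-D argmin/argmax so it runs standalone (the real kernel's shape flattening is omitted):

```cpp
#include <cstdio>
#include <functional>

// is_arg_max selects std::greater (argmax) or std::less (argmin), exactly
// the dispatch idea of GetComparefunction in the diff above.
template <typename T>
std::function<bool(T, T)> GetCompareFunction(bool is_arg_max) {
  if (is_arg_max) return std::greater<T>();
  return std::less<T>();
}

template <typename T>
int ArgMinMax1D(const T* data, int size, bool is_arg_max) {
  auto cmp = GetCompareFunction<T>(is_arg_max);
  int best = 0;
  for (int i = 1; i < size; ++i) {
    if (cmp(data[i], data[best])) best = i;  // ties keep the first index
  }
  return best;
}

int main() {
  const float v[5] = {3.f, 9.f, -1.f, 9.f, 4.f};
  printf("argmax=%d argmin=%d\n", ArgMinMax1D(v, 5, true),
         ArgMinMax1D(v, 5, false));  // argmax=1 argmin=2
  return 0;
}
```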
@@ -0,0 +1,101 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_TO_SPACE_ND_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_TO_SPACE_ND_H_
+
+#include <cmath>
+
+#include "ruy/profiler/instrumentation.h"  // from @ruy
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+// TODO(b/135760455): Move this method anonymous namespace in a cc file.
+inline RuntimeShape ExtendShapeBatchToSpace(const RuntimeShape& shape) {
+  if (shape.DimensionsCount() == 4) {
+    return shape;
+  }
+  RuntimeShape new_shape(4, 1);
+  new_shape.SetDim(0, shape.Dims(0));
+  new_shape.SetDim(1, shape.Dims(1));
+  new_shape.SetDim(3, shape.Dims(2));
+  return new_shape;
+}
+
+template <typename T>
+inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape,
+                           const T* input1_data,
+                           const RuntimeShape& unextended_input2_shape,
+                           const int32_t* block_shape_data,
+                           const RuntimeShape& unextended_input3_shape,
+                           const int32_t* crops_data,
+                           const RuntimeShape& unextended_output_shape,
+                           T* output_data) {
+  ruy::profiler::ScopeLabel label("BatchToSpaceND");
+  TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
+                   unextended_output_shape.DimensionsCount());
+
+  const RuntimeShape input1_shape =
+      ExtendShapeBatchToSpace(unextended_input1_shape);
+  const RuntimeShape output_shape =
+      ExtendShapeBatchToSpace(unextended_output_shape);
+
+  const int output_width = output_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_batch_size = output_shape.Dims(0);
+
+  const int depth = input1_shape.Dims(3);
+  const int input_width = input1_shape.Dims(2);
+  const int input_height = input1_shape.Dims(1);
+  const int input_batch_size = input1_shape.Dims(0);
+
+  const int block_shape_height = block_shape_data[0];
+  const int block_shape_width =
+      unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
+  const int crops_top = crops_data[0];
+  const int crops_left =
+      unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
+  for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
+    const int out_batch = in_batch % output_batch_size;
+    const int spatial_offset = in_batch / output_batch_size;
+    for (int in_h = 0; in_h < input_height; ++in_h) {
+      const int out_h = in_h * block_shape_height +
+                        spatial_offset / block_shape_width - crops_top;
+      if (out_h < 0 || out_h >= output_height) {
+        continue;
+      }
+      for (int in_w = 0; in_w < input_width; ++in_w) {
+        const int out_w = in_w * block_shape_width +
+                          spatial_offset % block_shape_width - crops_left;
+
+        if (out_w < 0 || out_w >= output_width) {
+          continue;
+        }
+        T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+        const T* in =
+            input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
+        memcpy(out, in, depth * sizeof(T));
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_TO_SPACE_ND_H_
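The index arithmetic above is the whole trick: each input batch is one interleave phase of the output rows/columns. A standalone walk-through of the `out_batch` / `spatial_offset` / `out_h` mapping for a 3-D case (one spatial block dimension, so `block_shape_width` is 1), with made-up sizes:

```cpp
#include <cstdio>

// Trace the BatchToSpaceND row mapping: 4 input batches collapse into 2
// output batches, and spatial_offset selects which interleaved row phase
// each input batch fills (block_shape_width == 1 here, so the
// spatial_offset / block_shape_width term is just spatial_offset).
int main() {
  const int input_batch_size = 4, output_batch_size = 2;
  const int input_height = 3, block_shape_height = 2, crops_top = 0;
  for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
    const int out_batch = in_batch % output_batch_size;
    const int spatial_offset = in_batch / output_batch_size;
    for (int in_h = 0; in_h < input_height; ++in_h) {
      const int out_h = in_h * block_shape_height + spatial_offset - crops_top;
      printf("in_batch=%d in_h=%d -> out_batch=%d out_h=%d\n", in_batch, in_h,
             out_batch, out_h);
    }
  }
  return 0;
}
```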
@@ -23,9 +23,6 @@ namespace tflite {
 
 namespace reference_ops {
 
-// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
-// generalized and efficient BroadcastBinaryFunction.
-//
 // Also appears to duplicate MinimumMaximum.
 //
 // R: Result type. T1: Input 1 type. T2: Input 2 type.
@@ -63,7 +60,6 @@ inline void BroadcastBinaryFunction4DSlow(
 }
 
 // R: Result type. T1: Input 1 type. T2: Input 2 type.
-// TODO(renjieliu): Refactor other binary functions to use this one.
 template <typename R, typename T1, typename T2>
 inline void BinaryFunction(const RuntimeShape& input1_shape,
                            const T1* input1_data,
@@ -68,8 +68,7 @@ inline void Concatenation(const ConcatenationParams& params,
   }
 }
 
-// TODO(prabhumk): This is the same as the optimized implementation.
-// TODO(prabhumk): The quantized implementation of concatentation isn't fully
+// TODO(b/174275780): The quantized implementation of concatentation isn't fully
 // quantized as it takes scale as a floating point value. This should be fixed
 // when optimizng this routine further.
 inline void ConcatenationWithScaling(const ConcatenationParams& params,
@@ -15,16 +15,13 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
 
-#include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
 
 namespace reference_ops {
 
 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  const float* input_data, const RuntimeShape& filter_shape,
                  const float* filter_data, const RuntimeShape& bias_shape,
@@ -108,8 +105,8 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  uint8_t* output_data, const RuntimeShape& im2col_shape,
                  uint8_t* im2col_data, void* cpu_backend_context) {
   (void)cpu_backend_context;  // only used in optimized code.
   (void)im2col_data;          // only used in optimized code.
   (void)im2col_shape;         // only used in optimized code.
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const int dilation_width_factor = params.dilation_width_factor;
@@ -0,0 +1,239 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
+
+#include <algorithm>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+template <typename T>
+inline void DivCheckArithmeticParams(const ArithmeticParams& params) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+  // Input offset is negative input zero point. Activation tensors are
+  // asymmetric quantized so they span the full int8 range.
+  constexpr int32_t max_value =
+      static_cast<int32_t>(std::numeric_limits<T>::max());
+  TFLITE_DCHECK_GE(params.input1_offset, -max_value);
+  TFLITE_DCHECK_LE(params.input1_offset, max_value);
+  TFLITE_DCHECK_GE(params.input2_offset, -max_value);
+  TFLITE_DCHECK_LE(params.input2_offset, max_value);
+  TFLITE_DCHECK_GE(params.output_offset, -max_value);
+  TFLITE_DCHECK_LE(params.output_offset, max_value);
+}
+
+// Element-wise div that can often be used for inner loop of broadcast Div as
+// well as the non-broadcast Div.
+template <typename T>
+inline void DivElementwise(int size, const ArithmeticParams& params,
+                           const T* input1_data, const T* input2_data,
+                           T* output_data) {
+  DivCheckArithmeticParams<T>(params);
+
+  for (int i = 0; i < size; ++i) {
+    const int32_t input1_val = params.input1_offset + input1_data[i];
+    const int32_t input2_val = params.input2_offset + input2_data[i];
+    TFLITE_DCHECK_NE(input2_val, 0);
+    int recip_shift;
+    const int32_t input2_inv =
+        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
+                         : -GetReciprocal(-input2_val, 31, &recip_shift);
+    const int headroom = CountLeadingSignBits(input1_val);
+    const int32_t unscaled_quotient =
+        MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv,
+                                                    headroom);
+    const int total_shift = params.output_shift - recip_shift - headroom;
+    const int32_t unclamped_result =
+        params.output_offset +
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            unscaled_quotient, params.output_multiplier, total_shift);
+    const int32_t clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, unclamped_result));
+    output_data[i] = static_cast<T>(clamped_output);
+  }
+}
+
+inline void Div(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const uint8_t* input1_data,
+                const RuntimeShape& input2_shape, const uint8_t* input2_data,
+                const RuntimeShape& output_shape, uint8_t* output_data) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+
+  DivElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void Div(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const int8_t* input1_data,
+                const RuntimeShape& input2_shape, const int8_t* input2_data,
+                const RuntimeShape& output_shape, int8_t* output_data) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+
+  DivElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+template <typename T, int N = 5>
+inline void BroadcastDivSlowQuantized(
+    const ArithmeticParams& params, const RuntimeShape& unextended_input1_shape,
+    const T* input1_data, const RuntimeShape& unextended_input2_shape,
+    const T* input2_data, const RuntimeShape& unextended_output_shape,
+    T* output_data) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+                 &output_desc);
+
+  DivCheckArithmeticParams<T>(params);
+
+  auto div_func = [&](int indexes[N]) {
+    const int32_t input1_val =
+        params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
+    const int32_t input2_val =
+        params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
+    TFLITE_DCHECK_NE(input2_val, 0);
+    int recip_shift;
+    const int32_t input2_inv =
+        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
+                         : -GetReciprocal(-input2_val, 31, &recip_shift);
+    const int headroom = CountLeadingSignBits(input1_val);
+    const int32_t unscaled_quotient =
+        MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv,
+                                                    headroom);
+    const int total_shift = params.output_shift - recip_shift - headroom;
+    const int32_t unclamped_result =
+        params.output_offset +
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            unscaled_quotient, params.output_multiplier, total_shift);
+    const int32_t clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, unclamped_result));
+    output_data[SubscriptToIndex(output_desc, indexes)] =
+        static_cast<T>(clamped_output);
+  };
+  NDOpsHelper<N>(output_desc, div_func);
+}
+
+template <int N = 5>
+inline void BroadcastDivSlow(const ArithmeticParams& params,
+                             const RuntimeShape& unextended_input1_shape,
+                             const uint8_t* input1_data,
+                             const RuntimeShape& unextended_input2_shape,
+                             const uint8_t* input2_data,
+                             const RuntimeShape& unextended_output_shape,
+                             uint8_t* output_data) {
+  BroadcastDivSlowQuantized<uint8_t, N>(
+      params, unextended_input1_shape, input1_data, unextended_input2_shape,
+      input2_data, unextended_output_shape, output_data);
+}
+
+template <int N = 5>
+inline void BroadcastDivSlow(const ArithmeticParams& params,
+                             const RuntimeShape& unextended_input1_shape,
+                             const int8_t* input1_data,
+                             const RuntimeShape& unextended_input2_shape,
+                             const int8_t* input2_data,
+                             const RuntimeShape& unextended_output_shape,
+                             int8_t* output_data) {
+  BroadcastDivSlowQuantized<int8_t, N>(
+      params, unextended_input1_shape, input1_data, unextended_input2_shape,
+      input2_data, unextended_output_shape, output_data);
+}
+
+// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <typename T, int N = 5>
+void BroadcastDivSlow(const ArithmeticParams& params,
+                      const RuntimeShape& unextended_input1_shape,
+                      const T* input1_data,
+                      const RuntimeShape& unextended_input2_shape,
+                      const T* input2_data,
+                      const RuntimeShape& unextended_output_shape,
+                      T* output_data) {
+  T output_activation_min;
+  T output_activation_max;
+  GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+                 &output_desc);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+
+  auto div_func = [&](int indexes[N]) {
+    output_data[SubscriptToIndex(output_desc, indexes)] =
+        ActivationFunctionWithMinMax(
+            input1_data[SubscriptToIndex(desc1, indexes)] /
+                input2_data[SubscriptToIndex(desc2, indexes)],
+            output_activation_min, output_activation_max);
+  };
+  NDOpsHelper<N>(output_desc, div_func);
+}
+
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const T* input1_data,
+                const RuntimeShape& input2_shape, const T* input2_data,
+                const RuntimeShape& output_shape, T* output_data) {
+  T output_activation_min;
+  T output_activation_max;
+  GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = ActivationFunctionWithMinMax(
+        input1_data[i] / input2_data[i], output_activation_min,
+        output_activation_max);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
@@ -0,0 +1,37 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ELU_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ELU_H_
+
+#include "tensorflow/lite/kernels/internal/cppmath.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+inline void Elu(const RuntimeShape& input_shape, const float* input_data,
+                const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    output_data[i] = val < 0.0f ? TfLiteExpm1(val) : val;
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ELU_H_
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_EXP_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_EXP_H_
+
+#include <cmath>
+
+#include "ruy/profiler/instrumentation.h"  // from @ruy
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+template <typename T>
+inline void Exp(const T* input_data, const size_t num_elements,
+                T* output_data) {
+  ruy::profiler::ScopeLabel label("Exp");
+  for (size_t idx = 0; idx < num_elements; ++idx) {
+    output_data[idx] = std::exp(input_data[idx]);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_EXP_H_
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FILL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FILL_H_
+
+#include <cmath>
+
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+template <typename T>
+void Fill(const RuntimeShape& value_shape, const T* value_data,
+          const RuntimeShape& output_shape, T* output_data) {
+  TFLITE_DCHECK_EQ(value_shape.DimensionsCount(), 0);
+  const int flat_size = output_shape.FlatSize();
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = *value_data;
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FILL_H_
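The quantized `Div` kernel above avoids integer division in its inner loop: it looks up a fixed-point reciprocal of the divisor (`GetReciprocal`), multiplies, and folds the bookkeeping shifts into the output rescale. A self-contained sketch of that divide-as-multiply idea, using a simplified local reciprocal helper (not TFLite's internals):

```cpp
#include <cstdint>
#include <cstdio>

// n / d via a fixed-point reciprocal: normalize d into [2^30, 2^31), take
// recip ~= 2^61 / d_normalized once, then every division becomes a 64-bit
// multiply plus a right shift. Positive d only; |n| should stay well below
// 2^30 so n * recip fits in int64_t.
int64_t FixedPointDiv(int64_t n, int32_t d) {
  int shift = 0;
  int64_t nd = d;
  while (nd < (1LL << 30)) {  // normalize the divisor
    nd <<= 1;
    ++shift;
  }
  const int64_t recip = ((1LL << 61) - 1) / nd;  // ~2^61 / nd
  // n / d == n * 2^shift / nd ~= (n * recip) >> (61 - shift)
  return (n * recip) >> (61 - shift);
}

int main() {
  printf("%lld\n", (long long)FixedPointDiv(1000, 3));     // 333
  printf("%lld\n", (long long)FixedPointDiv(7654321, 7));  // 1093474
  return 0;
}
```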
@@ -31,7 +31,7 @@ inline void FullyConnected(
     float* output_data) {
   const float output_activation_min = params.float_activation_min;
   const float output_activation_max = params.float_activation_max;
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
@@ -76,7 +76,7 @@ inline void FullyConnected(
   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
 
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
@@ -123,7 +123,7 @@ inline void FullyConnected(
 
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   TFLITE_DCHECK_EQ(output_offset, 0);
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
@@ -176,7 +176,7 @@ inline void ShuffledFullyConnected(
   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
@@ -34,55 +34,24 @@ inline void CheckArithmeticParams(const ArithmeticParams& params) {
|
|||||||
TFLITE_DCHECK_LE(-params.input2_offset, std::numeric_limits<int8_t>::max());
|
TFLITE_DCHECK_LE(-params.input2_offset, std::numeric_limits<int8_t>::max());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Element-wise add that can often be used for inner loop of broadcast add as
|
inline void ElementWise(
|
||||||
// well as the non-broadcast add.
|
int size, const ArithmeticParams& params, const int8_t* input1_data,
|
||||||
inline void AddElementwise(int size, const ArithmeticParams& params,
|
const int8_t* input2_data, int8_t* output_data,
|
||||||
const int8_t* input1_data, const int8_t* input2_data,
|
void (*check_arithmetic_params)(const ArithmeticParams&),
|
||||||
int8_t* output_data) {
|
int8_t (*binary_func)(int8_t, int8_t, const ArithmeticParams&)) {
|
||||||
CheckArithmeticParams(params);
|
CheckArithmeticParams(params);
|
||||||
|
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
const int32_t input1_val = params.input1_offset + input1_data[i];
|
output_data[i] = binary_func(input1_data[i], input2_data[i], params);
|
||||||
const int32_t input2_val = params.input2_offset + input2_data[i];
|
|
||||||
const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
|
|
||||||
const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
|
|
||||||
const int32_t scaled_input1_val =
|
|
||||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
|
||||||
shifted_input1_val, params.input1_multiplier, params.input1_shift);
|
|
||||||
const int32_t scaled_input2_val =
|
|
||||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
|
||||||
shifted_input2_val, params.input2_multiplier, params.input2_shift);
|
|
||||||
const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
|
|
||||||
const int32_t raw_output =
|
|
||||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
|
||||||
raw_sum, params.output_multiplier, params.output_shift) +
|
|
||||||
params.output_offset;
|
|
||||||
const int32_t clamped_output =
|
|
||||||
std::min(params.quantized_activation_max,
|
|
||||||
std::max(params.quantized_activation_min, raw_output));
|
|
||||||
output_data[i] = static_cast<int8_t>(clamped_output);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void Add(const ArithmeticParams& params,
|
inline void BroadcastBinaryFunction4DSlow(
|
||||||
const RuntimeShape& input1_shape, const int8_t* input1_data,
|
const ArithmeticParams& params, const RuntimeShape& input1_shape,
|
||||||
const RuntimeShape& input2_shape, const int8_t* input2_data,
|
const int8_t* input1_data, const RuntimeShape& input2_shape,
|
||||||
const RuntimeShape& output_shape, int8_t* output_data) {
|
const int8_t* input2_data, const RuntimeShape& output_shape,
|
||||||
CheckArithmeticParams(params);
|
int8_t* output_data,
|
||||||
|
void (*check_arithmetic_params)(const ArithmeticParams&),
|
||||||
const int flat_size =
|
int8_t (*binary_func)(int8_t, int8_t, const ArithmeticParams&)) {
|
||||||
MatchingElementsSize(input1_shape, input2_shape, output_shape);
|
|
||||||
|
|
||||||
AddElementwise(flat_size, params, input1_data, input2_data, output_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
|
|
||||||
const RuntimeShape& input1_shape,
|
|
||||||
const int8_t* input1_data,
|
|
||||||
const RuntimeShape& input2_shape,
|
|
||||||
const int8_t* input2_data,
|
|
||||||
const RuntimeShape& output_shape,
|
|
||||||
int8_t* output_data) {
|
|
||||||
NdArrayDesc<4> desc1;
|
NdArrayDesc<4> desc1;
|
||||||
NdArrayDesc<4> desc2;
|
NdArrayDesc<4> desc2;
|
||||||
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
|
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
|
||||||
@@ -105,40 +74,70 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
|
|||||||
for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
|
for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
|
||||||
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
|
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
|
||||||
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
|
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
|
||||||
const int32_t input1_val =
|
output_data[Offset(extended_output_shape, b, y, x, c)] = binary_func(
|
||||||
params.input1_offset +
|
input1_data[SubscriptToIndex(desc1, b, y, x, c)],
|
||||||
input1_data[SubscriptToIndex(desc1, b, y, x, c)];
|
input2_data[SubscriptToIndex(desc2, b, y, x, c)], params);
|
||||||
const int32_t input2_val =
|
|
||||||
params.input2_offset +
|
|
||||||
input2_data[SubscriptToIndex(desc2, b, y, x, c)];
|
|
||||||
const int32_t shifted_input1_val =
|
|
||||||
input1_val * (1 << params.left_shift);
|
|
||||||
const int32_t shifted_input2_val =
|
|
||||||
input2_val * (1 << params.left_shift);
|
|
||||||
const int32_t scaled_input1_val =
|
|
||||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
|
||||||
-                    shifted_input1_val, params.input1_multiplier,
-                    params.input1_shift);
-            const int32_t scaled_input2_val =
-                MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                    shifted_input2_val, params.input2_multiplier,
-                    params.input2_shift);
-            const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
-            const int32_t raw_output =
-                MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                    raw_sum, params.output_multiplier, params.output_shift) +
-                params.output_offset;
-            const int32_t clamped_output =
-                std::min(params.quantized_activation_max,
-                         std::max(params.quantized_activation_min, raw_output));
-            output_data[Offset(extended_output_shape, b, y, x, c)] =
-                static_cast<int8_t>(clamped_output);
           }
         }
       }
     }
   }
 
+inline int8_t AddFunc(int8_t x, int8_t y, const ArithmeticParams& params) {
+  const int32_t input1_val = params.input1_offset + x;
+  const int32_t input2_val = params.input2_offset + y;
+  const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
+  const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
+  const int32_t scaled_input1_val =
+      MultiplyByQuantizedMultiplierSmallerThanOneExp(
+          shifted_input1_val, params.input1_multiplier, params.input1_shift);
+  const int32_t scaled_input2_val =
+      MultiplyByQuantizedMultiplierSmallerThanOneExp(
+          shifted_input2_val, params.input2_multiplier, params.input2_shift);
+  const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
+  const int32_t raw_output =
+      MultiplyByQuantizedMultiplierSmallerThanOneExp(
+          raw_sum, params.output_multiplier, params.output_shift) +
+      params.output_offset;
+  const int32_t clamped_output =
+      std::min(params.quantized_activation_max,
+               std::max(params.quantized_activation_min, raw_output));
+  return static_cast<int8_t>(clamped_output);
+}
+
+// Element-wise add that can often be used for inner loop of broadcast add as
+// well as the non-broadcast add.
+inline void AddElementwise(int size, const ArithmeticParams& params,
+                           const int8_t* input1_data, const int8_t* input2_data,
+                           int8_t* output_data) {
+  ElementWise(size, params, input1_data, input2_data, output_data,
+              CheckArithmeticParams, AddFunc);
+}
+
+inline void Add(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const int8_t* input1_data,
+                const RuntimeShape& input2_shape, const int8_t* input2_data,
+                const RuntimeShape& output_shape, int8_t* output_data) {
+  CheckArithmeticParams(params);
+
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+
+  AddElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+                               const RuntimeShape& input1_shape,
+                               const int8_t* input1_data,
+                               const RuntimeShape& input2_shape,
+                               const int8_t* input2_data,
+                               const RuntimeShape& output_shape,
+                               int8_t* output_data) {
+  BroadcastBinaryFunction4DSlow(params, input1_shape, input1_data, input2_shape,
+                                input2_data, output_shape, output_data,
+                                CheckArithmeticParams, AddFunc);
+}
+
 }  // namespace reference_integer_ops
 }  // namespace tflite
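The quantized add above never touches float: each int8 input is shifted by its zero point, brought to a common scale with a fixed-point multiplier, summed, then rescaled to the output scale and clamped. Below is a minimal float sketch of what that fixed-point path approximates; the scales and zero points are made up for illustration, not taken from any real model.

```cpp
// Float reference for quantized int8 add: dequantize, add, requantize.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t AddQuantizedRef(int8_t x, int8_t y) {
  const float s1 = 0.05f, s2 = 0.05f, s_out = 0.1f;  // hypothetical scales
  const int zp1 = 0, zp2 = 0, zp_out = -128;         // hypothetical zero points
  const float real = s1 * (x - zp1) + s2 * (y - zp2);  // dequantize and add
  const int q = zp_out + static_cast<int>(std::round(real / s_out));
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}

int main() {
  // 0.05*40 + 0.05*20 = 3.0; 3.0/0.1 + (-128) = -98.
  std::printf("%d\n", AddQuantizedRef(40, 20));
}
```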
@@ -101,7 +101,7 @@ inline void ConvPerChannel(
             // long as the filter size (filter_y * filter_x * in_channel)
             // does not exceed 2^16, which is the case in all the models
             // we have seen so far.
-            // TODO(jianlijianli): Add a check to make sure the
+            // TODO(b/174275578): Add a check to make sure the
             // accumulator depth is smaller than 2^16.
             acc += filter_val * (input_val + input_offset);
           }
@@ -95,7 +95,7 @@ inline void DepthwiseConvPerChannel(
             // long as the filter size (filter_y * filter_x * in_channel)
             // does not exceed 2^16, which is the case in all the models
             // we have seen so far.
-            // TODO(jianlijianli): Add a check to make sure the
+            // TODO(b/174275578): Add a check to make sure the
            // accumulator depth is smaller than 2^16.
             acc += filter_val * (input_val + input_offset);
           }
@@ -58,23 +58,36 @@ inline void Logistic(int32_t input_zero_point, int32_t input_range_radius,
   }
 }
 
-inline void Logistic(int32_t input_multiplier, int32_t input_size,
-                     const int16_t* ptr_input_data, int16_t* ptr_output_data) {
+inline void Logistic(int32_t input_multiplier, int32_t input_left_shift,
+                     int32_t input_size, const int16_t* ptr_input_data,
+                     int16_t* ptr_output_data) {
   // We use the LUT for sigmoid and take into account, that
   // tanh(x) = 2*sigmoid(2*x) - 1
-  int32_t input_data_mul = (input_multiplier > 0) ? input_multiplier : 1;
+  // We scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
+  // In case of general parameter scale, multiplier 3 is taken into account
+  // in TanhPrepare function and it is included in
+  // input_multiplier already.
+
+  TFLITE_DCHECK_GE(input_left_shift, 0);
+  if (input_multiplier == 0) {  // power of two case
+    input_multiplier = 3 << input_left_shift;
+    input_left_shift = 0;
+  }
+
+  int32_t round = (input_left_shift > 0) ? 1 << (input_left_shift - 1) : 0;
 
   for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
-    int32_t input_data = (*ptr_input_data) * input_data_mul;
+    int32_t input_data =
+        ((*ptr_input_data) * input_multiplier + round) >> input_left_shift;
 
-    // Scale by 3/4 to expand range [-8,8]->[-10.7,10.7] and
-    // we do interpolation on unsigned values.
-    uint32_t abs_input_data = 3 * abs(input_data);
+    // We do interpolation on unsigned values.
+    uint32_t abs_input_data = abs(input_data);
 
     // We divide by 2 power of 9, because
     // we need to divide by 2 in power of 7 for
     // the input conversion + 1/4 from the scale above.
+
     // Define uh as uint32_t type not to make this function overflow.
     uint32_t uh = abs_input_data >> 9;
     uint32_t result;
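Both the Logistic rewrite above and the Tanh rewrite below replace a plain multiply with the rounded fixed-point rescale `(v * multiplier + round) >> shift`, where `round = 1 << (shift - 1)` turns truncation into round-to-nearest. A small standalone check of that idiom, with illustrative values:

```cpp
// Rounded fixed-point rescale: adding half the divisor before the shift
// rounds to nearest instead of truncating.
#include <cstdint>
#include <cstdio>

int32_t RescaleRounded(int32_t v, int32_t multiplier, int32_t shift) {
  const int32_t round = (shift > 0) ? 1 << (shift - 1) : 0;
  return (v * multiplier + round) >> shift;
}

int main() {
  // 5 * 3 = 15; 15 / 4 = 3.75, which rounds to 4 rather than truncating to 3.
  std::printf("%d\n", RescaleRounded(5, 3, 2));  // prints 4
}
```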
@@ -65,19 +65,25 @@ inline void Tanh(int32_t input_multiplier, int32_t input_left_shift,
   // We use the LUT for sigmoid and take into account, that
   // tanh(x) = 2*sigmoid(2*x) - 1
-  int32_t input_data_mul = (input_multiplier > 0) ? input_multiplier : 1;
+  // We scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
+  // In case of general parameter scale, multiplier 3 is taken into account
+  // in TanhPrepare function and it is included in
+  // input_multiplier already.
+
+  if (input_multiplier == 0) {  // power of two case
+    input_multiplier = 3 << input_left_shift;
+    input_left_shift = 0;
+  }
+
+  int32_t round = (input_left_shift > 0) ? 1 << (input_left_shift - 1) : 0;
 
   int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++) {
-    int32_t input_data = (*ptr_input_data) * input_data_mul;
+    int32_t input_data =
+        ((*ptr_input_data) * input_multiplier + round) >> input_left_shift;
 
-    if (input_left_shift == 1) {
-      input_data <<= 1;
-    }
-
-    // Scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
-    uint32_t abs_input_data = 3 * abs(input_data);
+    uint32_t abs_input_data = abs(input_data);
     uint32_t uh = abs_input_data >> 8;
     int32_t result;
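The "LUT for sigmoid" comments in both kernels lean on the identity tanh(x) = 2·sigmoid(2x) − 1, which lets one table serve both activations. A quick numeric sanity check:

```cpp
// Verify tanh(x) == 2*sigmoid(2*x) - 1 at an arbitrary point.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 0.5;
  const double via_sigmoid = 2.0 / (1.0 + std::exp(-2.0 * x)) - 1.0;
  std::printf("%f %f\n", std::tanh(x), via_sigmoid);  // both ~0.462117
}
```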
@@ -0,0 +1,221 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+// Fixed-point per-channel-quantization transpose convolution reference kernel.
+inline void TransposeConv(
+    const ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const RuntimeShape& input_shape,
+    const int8_t* input_data, const RuntimeShape& filter_shape,
+    const int8_t* filter_data, const RuntimeShape& bias_shape,
+    const int32_t* bias_data, const RuntimeShape& output_shape,
+    int8_t* output_data, const RuntimeShape& im2col_shape, int8_t* im2col_data,
+    int32_t* scratch_buffer) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = std::numeric_limits<int8_t>::min();
+  const int32_t output_activation_max = std::numeric_limits<int8_t>::max();
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  const int num_elements = output_shape.FlatSize();
+  // We need to initialize scratch_buffer to all 0s, as we apply the same
+  // 'scatter' based trick as in float version.
+  memset(scratch_buffer, 0, num_elements * sizeof(int32_t));
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence.
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location.
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds.
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  const int8_t input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  const int8_t filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                        out_channel)] +=
+                      (input_value + input_offset) * filter_value;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          int32_t acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                              out_channel)];
+          if (bias_data) {
+            acc += bias_data[out_channel];
+          }
+          acc = MultiplyByQuantizedMultiplier(
+              acc, output_multiplier[out_channel], output_shift[out_channel]);
+          acc += output_offset;
+          acc = std::max(acc, output_activation_min);
+          acc = std::min(acc, output_activation_max);
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+              static_cast<int8_t>(acc);
+        }
+      }
+    }
+  }
+}
+
+// int16_t input (zero_point=0), int8_t filter, int64 accumulator
+inline void TransposeConv(
+    const ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const RuntimeShape& input_shape,
+    const int16_t* input_data, const RuntimeShape& filter_shape,
+    const int8_t* filter_data, const RuntimeShape& bias_shape,
+    const std::int64_t* bias_data, const RuntimeShape& output_shape,
+    int16_t* output_data, const RuntimeShape& im2col_shape, int8_t* im2col_data,
+    std::int64_t* scratch_buffer) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int32_t output_activation_min = std::numeric_limits<int16_t>::min();
+  const int32_t output_activation_max = std::numeric_limits<int16_t>::max();
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  const int num_elements = output_shape.FlatSize();
+  // We need to initialize scratch_buffer to all 0s, as we apply the same
+  // 'scatter' based trick as in float version.
+  memset(scratch_buffer, 0, num_elements * sizeof(std::int64_t));
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence.
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location.
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds.
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  const int32_t input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  const int32_t filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                        out_channel)] +=
+                      input_value * filter_value;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          std::int64_t acc = scratch_buffer[Offset(output_shape, batch, out_y,
+                                                   out_x, out_channel)];
+          if (bias_data) {
+            acc += bias_data[out_channel];
+          }
+          int32_t scaled_acc = MultiplyByQuantizedMultiplier(
+              acc, output_multiplier[out_channel], output_shift[out_channel]);
+          scaled_acc = std::max(scaled_acc, output_activation_min);
+          scaled_acc = std::min(scaled_acc, output_activation_max);
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+              static_cast<int16_t>(scaled_acc);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
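Both kernels above use the "scatter" access pattern: instead of gathering inputs per output pixel, each input pixel is projected onto the output pixels it influences, via `out_origin = in * stride - pad` plus the filter tap. A tiny standalone sketch of that 1-D index mapping, with illustrative sizes:

```cpp
// Which output columns does each input column of a transpose conv touch?
#include <cstdio>

int main() {
  const int stride = 2, pad = 0, filter_width = 3, input_width = 3;
  for (int in_x = 0; in_x < input_width; ++in_x) {
    const int out_x_origin = in_x * stride - pad;
    std::printf("input %d -> outputs", in_x);
    for (int fx = 0; fx < filter_width; ++fx) {
      std::printf(" %d", out_x_origin + fx);
    }
    std::printf("\n");
  }
  // input 0 -> 0 1 2; input 1 -> 2 3 4; input 2 -> 4 5 6 (overlaps accumulate)
}
```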
@@ -0,0 +1,69 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEAKY_RELU_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEAKY_RELU_H_
+
+#include <algorithm>
+#include <limits>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void LeakyRelu(const tflite::LeakyReluParams& params,
+                      const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    // Note that alpha might be > 1 or < 0, so we don't use std::max here.
+    output_data[i] = val > 0 ? val : val * params.alpha;
+  }
+}
+
+template <typename T>
+inline void QuantizeLeakyRelu(const LeakyReluParams& params,
+                              const RuntimeShape& input_shape,
+                              const T* input_data,
+                              const RuntimeShape& output_shape,
+                              T* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  static const int32_t quantized_min = std::numeric_limits<T>::min();
+  static const int32_t quantized_max = std::numeric_limits<T>::max();
+  for (int i = 0; i < flat_size; ++i) {
+    const int32_t input_value = input_data[i] - params.input_offset;
+    int32_t unclamped_output;
+    if (input_value >= 0) {
+      unclamped_output = params.output_offset +
+                         MultiplyByQuantizedMultiplier(
+                             input_value, params.output_multiplier_identity,
+                             params.output_shift_identity);
+    } else {
+      unclamped_output = params.output_offset +
+                         MultiplyByQuantizedMultiplier(
+                             input_value, params.output_multiplier_alpha,
+                             params.output_shift_alpha);
+    }
+    const T clamped_output =
+        std::min(quantized_max, std::max(quantized_min, unclamped_output));
+    output_data[i] = static_cast<T>(clamped_output);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEAKY_RELU_H_
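The comment in the float kernel about avoiding `std::max` matters because the op allows alpha outside [0, 1]. A small demonstration, with alpha deliberately chosen > 1 to expose the difference:

```cpp
// Why LeakyRelu uses a branch instead of max(x, alpha*x): with alpha > 1,
// max() picks the wrong branch for positive inputs.
#include <algorithm>
#include <cstdio>

int main() {
  const float alpha = 2.0f, x = 1.0f;
  const float correct = x > 0 ? x : x * alpha;    // 1.0
  const float with_max = std::max(x, x * alpha);  // 2.0 -- wrong
  std::printf("%f %f\n", correct, with_max);
}
```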
@@ -45,6 +45,7 @@ inline void Requantize(const input_type* input_data, int32_t size,
       for (int i = 0; i < size; ++i) {
         output_data[i] = input_data[i] ^ 0x80;
       }
+      return;
     }
   }
   static constexpr int32_t kMinOutput = std::numeric_limits<output_type>::min();
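The fast path the `return` above short-circuits is the int8/uint8 conversion: when two quantized types share a scale and their zero points differ by exactly 128, flipping the top bit of each element performs the whole requantization. A standalone check:

```cpp
// XOR with 0x80 converts between int8 and uint8 quantized representations
// when the zero points differ by 128: it adds (mod 256) exactly 128.
#include <cstdint>
#include <cstdio>

int main() {
  const int8_t q_int8 = -3;
  const uint8_t q_uint8 = static_cast<uint8_t>(q_int8) ^ 0x80;  // -3 + 128
  std::printf("%d -> %u\n", q_int8, static_cast<unsigned>(q_uint8));  // 125
}
```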
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_BATCH_ND_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_BATCH_ND_H_
+
+#include <cmath>
+
+#include "ruy/profiler/instrumentation.h"  // from @ruy
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+// TODO(b/135760455): Move this method anonymous namespace in a cc file.
+inline RuntimeShape ExtendShapeSpaceToBatch(const RuntimeShape& shape) {
+  if (shape.DimensionsCount() == 4) {
+    return shape;
+  }
+  RuntimeShape new_shape(4, 1);
+  new_shape.SetDim(0, shape.Dims(0));
+  new_shape.SetDim(1, shape.Dims(1));
+  new_shape.SetDim(3, shape.Dims(2));
+  return new_shape;
+}
+
+template <typename T>
+inline void SpaceToBatchND(const SpaceToBatchParams& params,
+                           const RuntimeShape& unextended_input1_shape,
+                           const T* input1_data,
+                           const RuntimeShape& unextended_input2_shape,
+                           const int32_t* block_shape_data,
+                           const RuntimeShape& unextended_input3_shape,
+                           const int32_t* paddings_data,
+                           const RuntimeShape& unextended_output_shape,
+                           T* output_data) {
+  ruy::profiler::ScopeLabel label("SpaceToBatchND");
+  TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
+                   unextended_output_shape.DimensionsCount());
+
+  // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
+  const RuntimeShape input1_shape =
+      ExtendShapeSpaceToBatch(unextended_input1_shape);
+  const RuntimeShape output_shape =
+      ExtendShapeSpaceToBatch(unextended_output_shape);
+
+  const int depth = input1_shape.Dims(3);
+  const int input_width = input1_shape.Dims(2);
+  const int input_height = input1_shape.Dims(1);
+  const int input_batch_size = input1_shape.Dims(0);
+
+  const int output_width = output_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_batch_size = output_shape.Dims(0);
+
+  const int block_shape_height = block_shape_data[0];
+  const int block_shape_width =
+      unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
+  const int padding_top = paddings_data[0];
+  const int padding_left =
+      unextended_input1_shape.DimensionsCount() == 4 ? paddings_data[2] : 0;
+
+  // For uint8 quantized, the correct padding "zero value" is the output offset.
+  const int32_t pad_value = params.output_offset;
+  for (int out_b = 0; out_b < output_batch_size; ++out_b) {
+    int input_batch = out_b % input_batch_size;
+    int shift_w = (out_b / input_batch_size) % block_shape_width;
+    int shift_h = (out_b / input_batch_size) / block_shape_width;
+    for (int out_h = 0; out_h < output_height; ++out_h) {
+      for (int out_w = 0; out_w < output_width; ++out_w) {
+        T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
+        if (out_h * block_shape_height + shift_h < padding_top ||
+            out_h * block_shape_height + shift_h >=
+                padding_top + input_height ||
+            out_w * block_shape_width + shift_w < padding_left ||
+            out_w * block_shape_width + shift_w >= padding_left + input_width) {
+          // This may not execute correctly when pad_value != 0 and T != uint8.
+          memset(out, pad_value, depth * sizeof(T));
+        } else {
+          const T* in =
+              input1_data +
+              Offset(input1_shape, input_batch,
+                     (out_h * block_shape_height + shift_h) - padding_top,
+                     (out_w * block_shape_width + shift_w) - padding_left, 0);
+          memcpy(out, in, depth * sizeof(T));
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_BATCH_ND_H_
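The batch loop above decodes each output batch index into an input batch plus per-block row/column offsets: with input batch size B and block width bw, `out_b` maps to `out_b % B`, `(out_b / B) % bw`, and `(out_b / B) / bw`. A standalone sketch of that decomposition, with illustrative sizes:

```cpp
// Decompose SpaceToBatchND output batch indices for B=1 and a 2x2 block.
#include <cstdio>

int main() {
  const int input_batch_size = 1, block_w = 2;
  for (int out_b = 0; out_b < 4; ++out_b) {  // B * bh * bw = 1 * 2 * 2
    const int input_batch = out_b % input_batch_size;
    const int shift_w = (out_b / input_batch_size) % block_w;
    const int shift_h = (out_b / input_batch_size) / block_w;
    std::printf("out_b=%d -> batch=%d shift_h=%d shift_w=%d\n", out_b,
                input_batch, shift_h, shift_w);
  }
}
```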
@@ -15,23 +15,28 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 
+#include "ruy/profiler/instrumentation.h"  // from @ruy
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
 
 namespace reference_ops {
 
 template <typename T>
 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
                          const RuntimeShape& unextended_input_shape,
-                         const T* input_data,
                          const RuntimeShape& unextended_output_shape,
-                         T* output_data) {
+                         SequentialTensorWriter<T>* writer) {
   using strided_slice::LoopCondition;
   using strided_slice::StartForAxis;
   using strided_slice::StopForAxis;
+
+  ruy::profiler::ScopeLabel label("StridedSlice");
+
   // Note that the output_shape is not used herein.
   tflite::StridedSliceParams params_copy = op_params;

@@ -57,7 +62,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
   const int start_4 = StartForAxis(params_copy, input_shape, 4);
   const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
 
-  T* out_ptr = output_data;
   for (int offset_0 = start_0 * input_shape.Dims(1),
            end_0 = stop_0 * input_shape.Dims(1),
            step_0 = params_copy.strides[0] * input_shape.Dims(1);

@@ -81,13 +85,36 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
           for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
                !LoopCondition(offset_4, end_4, params_copy.strides[4]);
                offset_4 += params_copy.strides[4]) {
-            *out_ptr++ = input_data[offset_4];
+            writer->Write(offset_4);
           }
         }
       }
     }
   }
 }
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  SequentialTensorWriter<T> writer(input_data, output_data);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const TfLiteTensor* input,
+                         const RuntimeShape& unextended_output_shape,
+                         TfLiteTensor* output) {
+  SequentialTensorWriter<T> writer(input, output);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
 }  // namespace reference_ops
 }  // namespace tflite
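The refactor above routes all element copies through a writer object, so one slicing loop can serve both raw buffers and, via the `TfLiteTensor` overload, tensors whose payloads need special handling. A minimal sketch of such a writer; the real `SequentialTensorWriter` lives in `portable_tensor.h` and does more, this is a simplified stand-in for illustration:

```cpp
// The slice loop only emits source positions; the writer decides how to copy.
#include <cstdio>

template <typename T>
class MiniWriter {
 public:
  MiniWriter(const T* input, T* output) : input_(input), output_(output) {}
  void Write(int position) { *output_++ = input_[position]; }

 private:
  const T* input_;
  T* output_;
};

int main() {
  const int in[6] = {10, 11, 12, 13, 14, 15};
  int out[3];
  MiniWriter<int> writer(in, out);
  for (int pos = 0; pos < 6; pos += 2) writer.Write(pos);  // stride-2 slice
  std::printf("%d %d %d\n", out[0], out[1], out[2]);       // 10 12 14
}
```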
@@ -65,10 +65,6 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
 // dimensionality if the runtime code does a single loop over one dimension
 // that handles broadcasting as the base case. The code generator would then
 // generate max(D1, D2) nested for loops.
-// TODO(b/151345101): BroadcastSub is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
 template <int N = 5>
 inline void BroadcastSubSlow(const ArithmeticParams& params,
                              const RuntimeShape& input1_shape,

@@ -336,6 +332,50 @@ void BroadcastSubSlow(const ArithmeticParams& params,
   NDOpsHelper<N>(output_desc, sub_func);
 }
 
+template <int N = 5>
+inline void BroadcastSub16POTSlow(const ArithmeticParams& params,
+                                  const RuntimeShape& input1_shape,
+                                  const int16_t* input1_data,
+                                  const RuntimeShape& input2_shape,
+                                  const int16_t* input2_data,
+                                  const RuntimeShape& output_shape,
+                                  int16_t* output_data) {
+  ruy::profiler::ScopeLabel label("BroadcastSub16POTSlow/int16_t");
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+                                      &desc2);
+  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest stride,
+  // typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+  //
+  // We name our variables by their Tensorflow convention, but generate C code
+  // nesting loops such that the innermost loop has the smallest stride for the
+  // best cache behavior.
+  auto sub_func = [&](int indexes[N]) {
+    const int32_t input1_val = input1_data[SubscriptToIndex(desc1, indexes)];
+    const int32_t input2_val = input2_data[SubscriptToIndex(desc2, indexes)];
+    const int32_t scaled_input1_val =
+        gemmlowp::RoundingDivideByPOT(input1_val, -params.input1_shift);
+    const int32_t scaled_input2_val =
+        gemmlowp::RoundingDivideByPOT(input2_val, -params.input2_shift);
+    const int32_t raw_output = scaled_input1_val - scaled_input2_val;
+    const int32_t clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, raw_output));
+    output_data[SubscriptToIndex(output_desc, indexes)] =
+        static_cast<int16_t>(clamped_output);
+  };
+  NDOpsHelper<N>(output_desc, sub_func);
+}
+
 // Element-wise Sub that can often be used for inner loop of broadcast sub as
 // well as the non-broadcast sub.
 inline void SubElementwise(int size, const ArithmeticParams& params,
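`BroadcastSub16POTSlow` handles the case where both inputs rescale by pure powers of two, so the fixed-point multiplier is unnecessary. Below is a hedged reference version of the rounding divide it relies on; `gemmlowp::RoundingDivideByPOT` behaves like this for in-range values, and the sketch assumes arithmetic right shift of negative integers:

```cpp
// Divide by 2^exponent, rounding to nearest (ties away from zero).
#include <cstdint>
#include <cstdio>

int32_t RoundingDivideByPOTRef(int32_t x, int exponent) {
  const int32_t mask = (1 << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

int main() {
  std::printf("%d\n", RoundingDivideByPOTRef(9, 2));   // 9/4  = 2.25 -> 2
  std::printf("%d\n", RoundingDivideByPOTRef(10, 2));  // 10/4 = 2.5  -> 3
  std::printf("%d\n", RoundingDivideByPOTRef(-9, 2));  // -2.25       -> -2
}
```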
@@ -0,0 +1,217 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+inline void TransposeConv(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& bias_shape,
+    const float* bias_data, const RuntimeShape& output_shape,
+    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Although transpose convolution simplifies to convolution with transposed
+  // weights for strides of 1, non-unitary striding complicates matters. To
+  // keep this reference implementation as clear as possible, we use a
+  // "scatter" access pattern, where we loop through all the input elements,
+  // computing their influence on the output, rather than looping through the
+  // output elements in the typical "gather" access pattern of a conv. We
+  // therefore must initialize the output array to zero.
+  const int num_elements = output_shape.FlatSize();
+  for (int i = 0; i < num_elements; i++) {
+    output_data[i] = 0.0f;
+  }
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  float input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  float filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  output_data[Offset(output_shape, batch, out_y, out_x,
+                                     out_channel)] +=
+                      input_value * filter_value;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  if (bias_data) {
+    for (int batch = 0; batch < batches; ++batch) {
+      for (int out_y = 0; out_y < output_height; ++out_y) {
+        for (int out_x = 0; out_x < output_width; ++out_x) {
+          for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+            output_data[Offset(output_shape, batch, out_y, out_x,
+                               out_channel)] += bias_data[out_channel];
+          }
+        }
+      }
+    }
+  }
+}
+
+inline void TransposeConv(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const uint8_t* input_data, const RuntimeShape& filter_shape,
+    const uint8_t* filter_data, const RuntimeShape& bias_shape,
+    const int32_t* bias_data, const RuntimeShape& output_shape,
+    uint8_t* output_data, const RuntimeShape& im2col_shape,
+    uint8_t* im2col_data, int32_t* scratch_buffer) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int32_t input_offset = params.input_offset;
+  const int32_t filter_offset = params.weights_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  const int num_elements = output_shape.FlatSize();
+  // We need to initialize scratch_buffer to all 0s, as we apply the same
+  // 'scatter' based trick as in float version.
+  memset(scratch_buffer, 0, num_elements * sizeof(int32_t));
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence.
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location.
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds.
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  uint8_t input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  uint8_t filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                        out_channel)] +=
+                      (input_value + input_offset) *
+                      (filter_value + filter_offset);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          int32_t acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                              out_channel)];
+          if (bias_data) {
+            acc += bias_data[out_channel];
+          }
+          int32_t scaled_acc = MultiplyByQuantizedMultiplier(
+              acc, output_multiplier, output_shift);
+          scaled_acc += output_offset;
+          scaled_acc = std::max(scaled_acc, output_activation_min);
+          scaled_acc = std::min(scaled_acc, output_activation_max);
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+              static_cast<uint8_t>(scaled_acc);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TRANSPOSE_CONV_H_
@@ -140,7 +140,7 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
   // already been adjusted for negative indices.
   if (shrink_axis) {
-    stop = start_for_axis + 1;
+    return start_for_axis + 1;
   }
 
   // end_mask override
@@ -43,6 +43,20 @@ struct PaddingValues {
   int16_t height_offset;
 };
 
+struct Padding3DValues {
+  int16_t width;
+  int16_t height;
+  int16_t depth;
+  // offset is used for calculating "remaining" padding, for example, `width`
+  // is 1 and `width_offset` is 1, so padding_left is 1 while padding_right is
+  // 1 + 1 = 2.
+  int16_t width_offset;
+  // Same as width_offset except it's over the height dimension.
+  int16_t height_offset;
+  // Same as width_offset except it's over the depth dimension.
+  int16_t depth_offset;
+};
+
 // This enumeration allows for non-default formats for the weights array
 // of a fully-connected operator, allowing the use of special optimized
 // runtime paths.
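A numeric illustration of the offset fields in `Padding3DValues`: the total padding along an axis splits into a leading half plus a remainder that lands on the trailing side, which is how the 2-D padding helpers compute it as well (the split shown here is an assumption drawn from the struct's own comment):

```cpp
// Leading padding = total/2; trailing padding = leading + (total % 2).
#include <cstdio>

int main() {
  const int total_padding = 3;                 // e.g. from SAME padding math
  const int width = total_padding / 2;         // leading padding: 1
  const int width_offset = total_padding % 2;  // remainder: 1
  std::printf("left=%d right=%d\n", width, width + width_offset);  // 1 and 2
}
```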
@@ -170,7 +184,11 @@ class RuntimeShape {
   // rolls out.
   RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
     if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+      TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else
       dims_pointer_ = new int32_t[size_];
+#endif
     }
     std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * size_);
   }
@@ -392,6 +410,20 @@ inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
   return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
 }
 
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
+                  int i4) {
+  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
+  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
+  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
+  TFLITE_DCHECK(i4 >= 0 && i4 < dims_data[4]);
+  return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
+             dims_data[4] +
+         i4;
+}
+
 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
   TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
   TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
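The new 5-D `Offset` is ordinary row-major flattening written in Horner form. A quick check with an illustrative shape:

```cpp
// Row-major flatten: offset = ((((i0*d1 + i1)*d2 + i2)*d3 + i3)*d4 + i4.
#include <cstdio>

int main() {
  const int dims[5] = {2, 3, 4, 5, 6};
  const int i[5] = {1, 2, 3, 4, 5};  // last element of the tensor
  int offset = i[0];
  for (int d = 1; d < 5; ++d) offset = offset * dims[d] + i[d];
  std::printf("%d\n", offset);  // 719 == 2*3*4*5*6 - 1, as expected
}
```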
@@ -840,6 +872,19 @@ struct ConvParams {
   float float_activation_max;
 };
 
+struct Conv3DParams {
+  Padding3DValues padding_values;
+  int stride_width;
+  int stride_height;
+  int stride_depth;
+  int dilation_width;
+  int dilation_height;
+  int dilation_depth;
+  // float activation params.
+  float float_activation_min;
+  float float_activation_max;
+};
+
 struct DepthToSpaceParams {
   int32_t block_size;
 };
@@ -907,6 +952,7 @@ struct FullyConnectedParams {
 
 struct GatherParams {
   int16_t axis;
+  int16_t batch_dims;
 };
 
 struct L2NormalizationParams {
@@ -1025,9 +1071,9 @@ struct ResizeNearestNeighborParams {
 
 struct SliceParams {
   int8_t begin_count;
-  int32_t begin[4];
+  int32_t begin[5];
   int8_t size_count;
-  int32_t size[4];
+  int32_t size[5];
 };
 
 struct SoftmaxParams {
@@ -21,12 +21,19 @@ limitations under the License.
 #include <complex>
 #include <limits>
 #include <memory>
+#ifndef TF_LITE_STATIC_MEMORY
+#include <string>
+#endif  // TF_LITE_STATIC_MEMORY
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/cppmath.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 
+#if defined(__APPLE__)
+#include "TargetConditionals.h"
+#endif
+
 namespace tflite {
 
 namespace {
@@ -283,8 +290,7 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
                                               double* multiplier) {
   const double input_product_scale = static_cast<double>(input->params.scale) *
                                      static_cast<double>(filter->params.scale);
-  // TODO(ahentz): The following conditions must be guaranteed by the training
-  // pipeline.
+  // The following conditions must be guaranteed by the training pipeline.
   if (bias) {
     const double bias_scale = static_cast<double>(bias->params.scale);
     // Here we're making sure the input_product_scale & bias_scale are about the
@@ -383,9 +389,25 @@ bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
   return TfLiteIntArrayEqual(input1->dims, input2->dims);
 }
 
-// TODO(petewarden): Having macros around this is ugly, look at other strategies
-// before replicating this approach elsewhere.
 #ifndef TF_LITE_STATIC_MEMORY
+
+// TODO(b/172067338): Having this function be part of TF_LITE_STATIC_MEMORY
+// build results in a 6KB size increase, even though the function is unsused for
+// that build. What appears to be happening is that while the linker drops the
+// unsused function, the string library that gets pulled in is not dropped,
+// resulting in the increased binary size.
+std::string GetShapeDebugString(const TfLiteIntArray* shape) {
+  std::string str;
+  for (int d = 0; d < shape->size; ++d) {
+    if (str.empty())
+      str = "[" + std::to_string(shape->data[d]);
+    else
+      str += ", " + std::to_string(shape->data[d]);
+  }
+  str += "]";
+  return str;
+}
+
 TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                         const TfLiteTensor* input1,
                                         const TfLiteTensor* input2,
@@ -402,7 +424,13 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
   for (int i = 0; i < out_dims; ++i) {
     int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
     int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
-    TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1);
+    if (!(d1 == d2 || d1 == 1 || d2 == 1)) {
+      context->ReportError(context,
+                           "Given shapes, %s and %s, are not broadcastable.",
+                           GetShapeDebugString(input1->dims).c_str(),
+                           GetShapeDebugString(input2->dims).c_str());
+      return kTfLiteError;
+    }
     shape->data[out_dims - i - 1] = std::max(d1, d2);
   }
   *output_shape = shape.release();
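The new error path reports exactly which shapes failed the broadcast rule: aligning from the trailing dimension, two extents are compatible when they are equal or when one of them is 1. A tiny standalone check of that predicate on illustrative shapes:

```cpp
// Broadcast compatibility per trailing dimension, as enforced above.
#include <cstdio>

int main() {
  auto ok = [](int d1, int d2) { return d1 == d2 || d1 == 1 || d2 == 1; };
  // [2,3] vs [3]: trailing dims 3 vs 3 match; missing dim is treated as 1.
  std::printf("[2,3] vs [3]: %s\n", ok(3, 3) ? "broadcastable" : "error");
  // [2,3] vs [4]: trailing dims 3 vs 4, neither is 1 -> rejected.
  std::printf("[2,3] vs [4]: %s\n", ok(3, 4) ? "broadcastable" : "error");
}
```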
@@ -425,9 +453,15 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
     int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
     int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
     int max_value = std::max(std::max(d1, d2), d3);
-    TF_LITE_ENSURE(context, d1 == 1 || d1 == max_value);
-    TF_LITE_ENSURE(context, d2 == 1 || d2 == max_value);
-    TF_LITE_ENSURE(context, d3 == 1 || d3 == max_value);
+    if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
+        !(d3 == 1 || d3 == max_value)) {
+      context->ReportError(
+          context, "Given shapes, %s, %s and %s, are not broadcastable.",
+          GetShapeDebugString(input1->dims).c_str(),
+          GetShapeDebugString(input2->dims).c_str(),
+          GetShapeDebugString(input3->dims).c_str());
+      return kTfLiteError;
+    }
     shape->data[out_dims - i - 1] = max_value;
   }
   *output_shape = shape.release();
@@ -458,9 +492,15 @@ int TfLiteTypeGetSize(TfLiteType type) {
|
|||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
|
TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
|
||||||
return 4;
|
return 4;
|
||||||
|
case kTfLiteUInt32:
|
||||||
|
TF_LITE_ASSERT_EQ(sizeof(uint32_t), 4);
|
||||||
|
return 4;
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
|
TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
|
||||||
return 8;
|
return 8;
|
||||||
|
case kTfLiteUInt64:
|
||||||
|
TF_LITE_ASSERT_EQ(sizeof(uint64_t), 8);
|
||||||
|
return 8;
|
||||||
case kTfLiteFloat64:
|
case kTfLiteFloat64:
|
||||||
TF_LITE_ASSERT_EQ(sizeof(double), 8);
|
TF_LITE_ASSERT_EQ(sizeof(double), 8);
|
||||||
return 8;
|
return 8;
|
||||||
@@ -475,4 +515,15 @@ int TfLiteTypeGetSize(TfLiteType type) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsMobilePlatform() {
|
||||||
|
#if defined(ANDROID) || defined(__ANDROID__)
|
||||||
|
return true;
|
||||||
|
#elif defined(__APPLE__)
|
||||||
|
#if TARGET_IPHONE_SIMULATOR || TARGET_OS_IPHONE
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|||||||
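
Note on the two CalculateShapeForBroadcast hunks above: the bare TF_LITE_ENSURE checks become explicit if/ReportError paths, so a failed broadcast now names the offending shapes via GetShapeDebugString. A minimal standalone sketch of the rule being enforced (dimensions compared right-aligned, each pair equal or 1) — example code, not part of the diff:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Standalone sketch of the two-input broadcast rule enforced above:
    // dimensions are compared right-aligned, and each pair must be equal
    // or contain a 1.
    bool BroadcastShape(const std::vector<int>& a, const std::vector<int>& b,
                        std::vector<int>* out) {
      const int out_dims = static_cast<int>(std::max(a.size(), b.size()));
      out->assign(out_dims, 1);
      for (int i = 0; i < out_dims; ++i) {
        const int d1 = i >= static_cast<int>(a.size()) ? 1 : a[a.size() - i - 1];
        const int d2 = i >= static_cast<int>(b.size()) ? 1 : b[b.size() - i - 1];
        if (!(d1 == d2 || d1 == 1 || d2 == 1)) return false;  // not broadcastable
        (*out)[out_dims - i - 1] = std::max(d1, d2);
      }
      return true;
    }

    int main() {
      std::vector<int> out;
      if (BroadcastShape({2, 1, 3}, {4, 3}, &out))
        std::printf("[%d, %d, %d]\n", out[0], out[1], out[2]);  // [2, 4, 3]
      // A mismatch such as {2, 3} vs {4, 3} is what now triggers ReportError.
      std::printf("%d\n", BroadcastShape({2, 3}, {4, 3}, &out));  // 0
    }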
code/components/tfmicro/tensorflow/lite/kernels/kernel_util.h
@@ -288,6 +288,9 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
 // Return the size of given type in bytes. Return 0 in in case of string.
 int TfLiteTypeGetSize(TfLiteType type);

+// Whether the current platform is mobile (Android or iOS).
+bool IsMobilePlatform();
+
 }  // namespace tflite

 #endif  // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
code/components/tfmicro/tensorflow/lite/kernels/op_macros.h
@@ -57,7 +57,7 @@ inline void InfiniteLoop() {

 #endif  // TF_LITE_MCU_DEBUG_LOG

-#ifdef NDEBUG
+#if defined(NDEBUG) || defined(ARDUINO)
 #define TFLITE_ASSERT_FALSE (static_cast<void>(0))
 #else
 #define TFLITE_ASSERT_FALSE TFLITE_ABORT
code/components/tfmicro/tensorflow/lite/kernels/padding.h
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_LITE_KERNELS_PADDING_H_

 #include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/kernels/internal/types.h"

 namespace tflite {

@@ -75,6 +76,36 @@ inline TfLitePaddingValues ComputePaddingHeightWidth(
   padding_values.width_offset = offset;
   return padding_values;
 }

+inline Padding3DValues ComputePadding3DValues(
+    int stride_height, int stride_width, int stride_depth,
+    int dilation_rate_height, int dilation_rate_width, int dilation_rate_depth,
+    int in_height, int in_width, int in_depth, int filter_height,
+    int filter_width, int filter_depth, TfLitePadding padding, int* out_height,
+    int* out_width, int* out_depth) {
+  *out_width = ComputeOutSize(padding, in_width, filter_width, stride_width,
+                              dilation_rate_width);
+  *out_height = ComputeOutSize(padding, in_height, filter_height, stride_height,
+                               dilation_rate_height);
+  *out_depth = ComputeOutSize(padding, in_depth, filter_depth, stride_depth,
+                              dilation_rate_depth);
+
+  Padding3DValues padding_values;
+  int offset = 0;
+  padding_values.depth =
+      ComputePaddingWithOffset(stride_depth, dilation_rate_depth, in_depth,
+                               filter_depth, *out_depth, &offset);
+  padding_values.depth_offset = offset;
+  padding_values.height =
+      ComputePaddingWithOffset(stride_height, dilation_rate_height, in_height,
+                               filter_height, *out_height, &offset);
+  padding_values.height_offset = offset;
+  padding_values.width =
+      ComputePaddingWithOffset(stride_width, dilation_rate_width, in_width,
+                               filter_width, *out_width, &offset);
+  padding_values.width_offset = offset;
+  return padding_values;
+}
 }  // namespace tflite

 #endif  // TENSORFLOW_LITE_KERNELS_PADDING_H_
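
The new ComputePadding3DValues simply repeats the existing per-axis arithmetic for a third (depth) axis. As a sketch of what ComputeOutSize/ComputePaddingWithOffset work out for SAME padding on one axis — a standalone example under the usual TF convention, not the library code:

    #include <algorithm>
    #include <cstdio>

    // Sketch of the SAME-padding arithmetic the helpers above apply per axis.
    int OutSizeSame(int in, int stride) {
      return (in + stride - 1) / stride;  // ceil(in / stride)
    }

    int TotalPadSame(int in, int filter, int stride, int dilation) {
      const int out = OutSizeSame(in, stride);
      // Dilation stretches the filter's effective footprint.
      const int effective_filter = dilation * (filter - 1) + 1;
      return std::max(0, (out - 1) * stride + effective_filter - in);
    }

    int main() {
      // 96-wide axis, 3-tap filter, stride 2, no dilation:
      // the axis shrinks to 48 and needs 1 pixel of total padding.
      std::printf("out=%d pad=%d\n", OutSizeSame(96, 2),
                  TotalPadSame(96, 3, 2, 1));
    }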
code/components/tfmicro/tensorflow/lite/micro/all_ops_resolver.cc
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

@@ -15,35 +18,35 @@ limitations under the License.
 #include "tensorflow/lite/micro/kernels/micro_ops.h"

 namespace tflite {
-namespace ops {
-namespace micro {
-namespace custom {
-TfLiteRegistration* Register_ETHOSU();
-const char* GetString_ETHOSU();
-}  // namespace custom
-}  // namespace micro
-}  // namespace ops

 AllOpsResolver::AllOpsResolver() {
   // Please keep this list of Builtin Operators in alphabetical order.
   AddAbs();
   AddAdd();
+  AddAddN();
   AddArgMax();
   AddArgMin();
   AddAveragePool2D();
+  AddBatchToSpaceNd();
   AddCeil();
   AddConcatenation();
   AddConv2D();
   AddCos();
   AddDepthwiseConv2D();
   AddDequantize();
+  AddDetectionPostprocess();
+  AddDiv();
+  AddElu();
   AddEqual();
+  AddEthosU();
   AddFloor();
   AddFullyConnected();
   AddGreater();
   AddGreaterEqual();
   AddHardSwish();
   AddL2Normalization();
+  AddL2Pool2D();
+  AddLeakyRelu();
   AddLess();
   AddLessEqual();
   AddLog();
@@ -51,8 +54,8 @@ AllOpsResolver::AllOpsResolver() {
   AddLogicalNot();
   AddLogicalOr();
   AddLogistic();
-  AddMaximum();
   AddMaxPool2D();
+  AddMaximum();
   AddMean();
   AddMinimum();
   AddMul();
@@ -73,22 +76,18 @@ AllOpsResolver::AllOpsResolver() {
   AddShape();
   AddSin();
   AddSoftmax();
+  AddSpaceToBatchNd();
   AddSplit();
   AddSplitV();
   AddSqrt();
   AddSquare();
+  AddSqueeze();
   AddStridedSlice();
   AddSub();
   AddSvdf();
   AddTanh();
+  AddTransposeConv();
   AddUnpack();
-
-  // TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver.
-  TfLiteRegistration* registration =
-      tflite::ops::micro::custom::Register_ETHOSU();
-  if (registration) {
-    AddCustom(tflite::ops::micro::custom::GetString_ETHOSU(), registration);
-  }
 }

 }  // namespace tflite
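
The resolver changes drop the old tflite::ops::micro::custom ETHOSU plumbing in favor of a plain AddEthosU() call and register the newly ported kernels (ADD_N, BATCH_TO_SPACE_ND, DETECTION_POSTPROCESS, DIV, ELU, L2_POOL_2D, LEAKY_RELU, SPACE_TO_BATCH_ND, SQUEEZE, TRANSPOSE_CONV). Caller code is unchanged; a hedged sketch of the usual wiring for this TFLM vintage (the model blob and arena size are placeholders):

    #include <cstdint>

    #include "tensorflow/lite/micro/all_ops_resolver.h"
    #include "tensorflow/lite/micro/micro_error_reporter.h"
    #include "tensorflow/lite/micro/micro_interpreter.h"
    #include "tensorflow/lite/schema/schema_generated.h"

    // g_model_data is a placeholder for a flatbuffer model compiled into flash.
    extern const unsigned char g_model_data[];
    constexpr int kArenaSize = 100 * 1024;  // sized per model in real firmware
    static uint8_t tensor_arena[kArenaSize];

    tflite::MicroInterpreter* MakeInterpreter() {
      static tflite::MicroErrorReporter error_reporter;
      const tflite::Model* model = tflite::GetModel(g_model_data);
      // AllOpsResolver links every kernel registered above; projects that care
      // about flash size usually switch to MicroMutableOpResolver instead.
      static tflite::AllOpsResolver resolver;
      static tflite::MicroInterpreter interpreter(
          model, resolver, tensor_arena, kArenaSize, &error_reporter);
      interpreter.AllocateTensors();
      return &interpreter;
    }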
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
File diff suppressed because it is too large

code/components/tfmicro/tensorflow/lite/micro/kernels/add_n.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/add_n.h"
+
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputTensor0 = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
+  int num_inputs = NumInputs(node);
+  TF_LITE_ENSURE(context, num_inputs >= 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input_tensor_first;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputTensor0, &input_tensor_first));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+
+  // Check that all tensors have the same shape and type.
+  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_tensor_first->type);
+  for (int i = kInputTensor0 + 1; i < num_inputs; ++i) {
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
+    TF_LITE_ENSURE(context, HaveSameShapes(input_tensor_first, input));
+    TF_LITE_ENSURE_TYPES_EQ(context, input_tensor_first->type, input->type);
+  }
+
+  // Allocate scratch buffer space for pointer to each tensor's data
+  // and store the scratch buffer index in the node's user_data
+  if (output->type == kTfLiteFloat32) {
+    int scratch_index;
+    size_t scratch_size = sizeof(float*) * num_inputs;
+    TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
+                                   context, scratch_size, &scratch_index));
+    node->user_data =
+        reinterpret_cast<decltype(node->user_data)>(scratch_index);
+  } else {
+    TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32, got %s.",
+                       TfLiteTypeGetName(output->type));
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return CalculateOpData(context, node);
+}
+
+template <typename T>
+void EvalAddN(TfLiteContext* context, TfLiteNode* node,
+              TfLiteEvalTensor* output) {
+  int num_inputs = NumInputs(node);
+
+  int scratch_index =
+      static_cast<int>(reinterpret_cast<intptr_t>(node->user_data));
+  void* scratch_buffer = context->GetScratchBuffer(context, scratch_index);
+  const T** all_inputs = static_cast<decltype(all_inputs)>(scratch_buffer);
+  for (int i = 0; i < num_inputs; i++) {
+    const TfLiteEvalTensor* next_input =
+        tflite::micro::GetEvalInput(context, node, kInputTensor0 + i);
+    all_inputs[i] = tflite::micro::GetTensorData<T>(next_input);
+  }
+
+  reference_ops::AddN<T>(tflite::micro::GetTensorShape(output), num_inputs,
+                         all_inputs, tflite::micro::GetTensorData<T>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  if (output->type == kTfLiteFloat32) {
+    EvalAddN<float>(context, node, output);
+  } else {
+    TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32, got %s.",
+                       TfLiteTypeGetName(output->type));
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_ADD_N() {
+  return {/*init=*/nullptr,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
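
ADD_N stages one data pointer per input in a scratch-arena buffer and then hands the whole set to reference_ops::AddN, which is plain elementwise accumulation across inputs. A standalone sketch of that reference semantics (example code, not the library implementation):

    #include <cstdio>

    // Sketch of reference AddN semantics: output[i] = sum over all inputs[i].
    template <typename T>
    void AddN(int num_elements, int num_inputs, const T* const* inputs,
              T* output) {
      for (int i = 0; i < num_elements; ++i) {
        T sum = 0;
        for (int k = 0; k < num_inputs; ++k) sum += inputs[k][i];
        output[i] = sum;
      }
    }

    int main() {
      const float a[3] = {1.f, 2.f, 3.f};
      const float b[3] = {10.f, 20.f, 30.f};
      const float* inputs[2] = {a, b};
      float out[3];
      AddN(3, 2, inputs, out);
      std::printf("%g %g %g\n", out[0], out[1], out[2]);  // 11 22 33
    }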
code/components/tfmicro/tensorflow/lite/micro/kernels/batch_to_space_nd.cc (new file, 111 lines)
@@ -0,0 +1,111 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+namespace tflite {
+
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kBlockShapeTensor = 1;
+constexpr int kCropsTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
+// In case of 3D input, it will be extended to 3D NHWC by adding W=1.
+// The 4D array need to have exactly 2 spatial dimensions.
+// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
+const int kInputOutputMinDimensionNum = 3;
+const int kInputOutputMaxDimensionNum = 4;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
+
+  TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
+  TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
+  TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
+  TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* block_shape =
+      tflite::micro::GetEvalInput(context, node, kBlockShapeTensor);
+  const TfLiteEvalTensor* crops =
+      tflite::micro::GetEvalInput(context, node, kCropsTensor);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteFloat32:
+      reference_ops::BatchToSpaceND(
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(block_shape),
+          tflite::micro::GetTensorData<int32_t>(block_shape),
+          tflite::micro::GetTensorShape(crops),
+          tflite::micro::GetTensorData<int32_t>(crops),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+      break;
+    case kTfLiteInt8:
+      reference_ops::BatchToSpaceND(
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(block_shape),
+          tflite::micro::GetTensorData<int32_t>(block_shape),
+          tflite::micro::GetTensorShape(crops),
+          tflite::micro::GetTensorData<int32_t>(crops),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace.
+
+TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
+  return {/*init=*/nullptr,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
code/components/tfmicro/tensorflow/lite/micro/kernels/cast.cc (new file, 96 lines)
@@ -0,0 +1,96 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+
+  return kTfLiteOk;
+}
+
+template <typename FromT, typename ToT>
+void copyCast(const FromT* in, ToT* out, int num_elements) {
+  std::transform(in, in + num_elements, out,
+                 [](FromT a) { return static_cast<ToT>(a); });
+}
+
+template <typename FromT>
+TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
+                          TfLiteEvalTensor* out, int num_elements) {
+  switch (out->type) {
+    case kTfLiteInt8:
+      copyCast(in, out->data.int8, num_elements);
+      break;
+    case kTfLiteFloat32:
+      copyCast(in, tflite::micro::GetTensorData<float>(out), num_elements);
+      break;
+    default:
+      // Unsupported type.
+      TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.",
+                         TfLiteTypeGetName(out->type), out->type);
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  int num_elements = MatchingFlatSize(tflite::micro::GetTensorShape(input),
+                                      tflite::micro::GetTensorShape(output));
+
+  switch (input->type) {
+    case kTfLiteInt8:
+      return copyToTensor(context, input->data.int8, output, num_elements);
+    case kTfLiteFloat32:
+      return copyToTensor(context, tflite::micro::GetTensorData<float>(input),
+                          output, num_elements);
+    default:
+      // Unsupported type.
+      TF_LITE_KERNEL_LOG(context, "Input type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
+TfLiteRegistration Register_CAST() {
+  return {/*init=*/nullptr,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
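
The CAST kernel above reduces to an element-wise static_cast: copyCast is std::transform with a per-element conversion, dispatched twice (once on the input type, once on the output type). A standalone sketch:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Same idea as copyCast above: element-wise static_cast between types.
    template <typename FromT, typename ToT>
    void CopyCast(const FromT* in, ToT* out, int num_elements) {
      std::transform(in, in + num_elements, out,
                     [](FromT a) { return static_cast<ToT>(a); });
    }

    int main() {
      const int8_t q[4] = {-128, -1, 0, 127};
      float f[4];
      CopyCast(q, f, 4);  // int8 -> float32, as the kernel's float32 branch does
      std::printf("%g %g %g %g\n", f[0], f[1], f[2], f[3]);
    }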
code/components/tfmicro/tensorflow/lite/micro/kernels/circular_buffer.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#define FLATBUFFERS_LOCALE_INDEPENDENT 0
+#include "flatbuffers/flexbuffers.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
@@ -55,7 +57,7 @@ constexpr int kInputTensor = 0;
 constexpr int kOutputTensor = 0;

 // TODO(b/149795762): Add this to TfLiteStatus enum.
-constexpr int kTfLiteAbort = -9;
+constexpr TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(-9);

 // These fields control the stride period of a strided streaming model. This op
 // returns kTfLiteAbort until cycles_until_run-- is zero. At this time,
@@ -65,47 +67,64 @@ struct OpData {
   int cycles_max;
 };

-// These constants represent constants specific to the music detect model.
-// They exist until (b/132070898) is fixed.
-constexpr int kMaxOpDataSize = 7;
-int op_data_counter = 0;
-OpData op_data_array[kMaxOpDataSize];
-
 }  // namespace

-void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  OpData* op_data = static_cast<OpData*>(
+      context->AllocatePersistentBuffer(context, sizeof(OpData)));
+
+  if (buffer != nullptr && length > 0) {
+    const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+    const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+    op_data->cycles_max = m["cycles_max"].AsInt32();
+  } else {
+    op_data->cycles_max = 0;
+  }
+
+  return op_data;
+}

 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  OpData* op_data = static_cast<OpData*>(node->user_data);
+
   TF_LITE_ENSURE(context, input != nullptr);
   TF_LITE_ENSURE(context, output != nullptr);
-  TF_LITE_ENSURE_EQ(context, 1, output->dims->data[0]);
-  TF_LITE_ENSURE_EQ(context, 1, input->dims->data[0]);
+  TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
   TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
-  TF_LITE_ENSURE_EQ(context, 1, output->dims->data[2]);
-  TF_LITE_ENSURE_EQ(context, 1, input->dims->data[2]);
+  TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
   TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);

   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

-  // The circular buffer custom operator currently only supports int8_t.
+  // The circular buffer custom operator currently only supports int8.
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);

-  // TODO(b/132070898): Use statically slotted OpData structures until a
-  // scratch memory API is ready.
-  TFLITE_DCHECK_LE(op_data_counter, kMaxOpDataSize);
-  OpData* op_data = &op_data_array[op_data_counter++];
-  // The last circular buffer layer (length 5) simply accumulates outputs, and
-  // does not run periodically.
-  // TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
-  if (output->dims->data[1] == 5) {
-    op_data->cycles_max = 1;
-  } else {
-    op_data->cycles_max = 2;
+  if (op_data->cycles_max <= 0) {
+    // The last circular buffer layer simply accumulates outputs, and does not
+    // run periodically.
+    // TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
+    static int cb_prepare_count = 0;
+    cb_prepare_count++;
+    // These checks specifically work for the only two streaming models
+    // supported on TFLM. They use the shape of the output tensor along with the
+    // layer number to determine if the circular buffer period should be 1 or 2.
+
+    // These models are outlined int the following documents:
+    // https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
+    // https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
+    if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
+        (cb_prepare_count == 5 && output->dims->data[2] == 2 &&
+         output->dims->data[3] == 96)) {
+      op_data->cycles_max = 1;
+      cb_prepare_count = 0;
+    } else {
+      op_data->cycles_max = 2;
+    }
   }
   op_data->cycles_until_run = op_data->cycles_max;
   node->user_data = op_data;
@@ -127,10 +146,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kOutputTensor);

+  TFLITE_DCHECK(node->user_data != nullptr);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);

   int num_slots = output->dims->data[1];
-  int depth = output->dims->data[3];
+  int depth = output->dims->data[2] * output->dims->data[3];

   if (input->type == kTfLiteInt8) {
     EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
@@ -148,12 +168,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     return static_cast<TfLiteStatus>(kTfLiteAbort);
   }

-  // If prepare is ever called more than one time (for example, when testing the
-  // ambient model, the interpreter is created a few times), this op data
-  // counter needs to be reset so that future instances do not overrun this op
-  // data array.
-  op_data_counter = 0;
-
   data->cycles_until_run = data->cycles_max;

   return kTfLiteOk;
@@ -162,8 +176,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace circular_buffer

 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
-  static TfLiteRegistration r = {/*init=*/nullptr,
-                                 /*free=*/circular_buffer::Free,
+  static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
+                                 /*free=*/nullptr,
                                  /*prepare=*/circular_buffer::Prepare,
                                  /*invoke=*/circular_buffer::Eval,
                                  /*profiling_string=*/nullptr,
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FLEXBUFFERS_GENERATED_DATA_H
|
||||||
|
#define TENSORFLOW_LITE_MICRO_KERNELS_FLEXBUFFERS_GENERATED_DATA_H
|
||||||
|
|
||||||
|
extern const int g_gen_data_size_circular_buffer_config;
|
||||||
|
extern const unsigned char g_gen_data_circular_buffer_config[];
|
||||||
|
|
||||||
|
#endif
|
||||||
code/components/tfmicro/tensorflow/lite/micro/kernels/conv.cc
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"

 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -28,294 +29,60 @@ limitations under the License.
 namespace tflite {
 namespace {

-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-// This file has 2 implementation of Conv.
-
-struct OpData {
-  TfLitePaddingValues padding;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Per channel output multiplier and shift.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
-inline PaddingType RuntimePaddingType(TfLitePadding padding) {
-  switch (padding) {
-    case TfLitePadding::kTfLitePaddingSame:
-      return PaddingType::kSame;
-    case TfLitePadding::kTfLitePaddingValid:
-      return PaddingType::kValid;
-    case TfLitePadding::kTfLitePaddingUnknown:
-    default:
-      return PaddingType::kNone;
-  }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             const TfLiteConvParams* params, int width,
-                             int height, int filter_width, int filter_height,
-                             int out_width, int out_height,
-                             const TfLiteType data_type, OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  // Matching GetWindowedOutputSize in TensorFlow.
-  auto padding = params->padding;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width,
-      params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    TF_LITE_ENSURE(context, input != nullptr);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    TF_LITE_ENSURE(context, filter != nullptr);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    TF_LITE_ENSURE(context, output != nullptr);
-    int output_channels = filter->dims->data[kConvQuantizedDimension];
-
-    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift),
-        output_channels));
-  }
-  return kTfLiteOk;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-
-  OpData* data = static_cast<OpData*>(node->user_data);
-  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
-
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  TF_LITE_ENSURE(context, filter != nullptr);
-
-  int input_width = input->dims->data[2];
-  int input_height = input->dims->data[1];
-  int filter_width = filter->dims->data[2];
-  int filter_height = filter->dims->data[1];
-  int output_width = output->dims->data[2];
-  int output_height = output->dims->data[1];
-
-  // Dynimically allocate per-channel quantization parameters.
-  const int num_channels = filter->dims->data[kConvQuantizedDimension];
-  data->per_channel_output_multiplier =
-      static_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  data->per_channel_output_shift =
-      static_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
-    TF_LITE_ENSURE(context,
-                   affine_quantization->scale->size == 1 ||
-                       affine_quantization->scale->size ==
-                           filter->dims->data[kConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  TF_LITE_ENSURE_STATUS(CalculateOpData(
-      context, node, params, input_width, input_height, filter_width,
-      filter_height, output_width, output_height, input->type, data));
-
-  data->input_zero_point = input->params.zero_point;
-  data->filter_zero_point = filter->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
-
-  return kTfLiteOk;
-}  // namespace conv
-
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                   TfLiteConvParams* params, const OpData& data,
-                   const TfLiteEvalTensor* input,
-                   const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-                   TfLiteEvalTensor* im2col, TfLiteEvalTensor* hwcn_weights,
-                   TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.padding_type = RuntimePaddingType(params->padding);
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
-                      tflite::micro::GetTensorData<uint8_t>(input),
-                      tflite::micro::GetTensorShape(filter),
-                      tflite::micro::GetTensorData<uint8_t>(filter),
-                      tflite::micro::GetTensorShape(bias),
-                      tflite::micro::GetTensorData<int32_t>(bias),
-                      tflite::micro::GetTensorShape(output),
-                      tflite::micro::GetTensorData<uint8_t>(output),
-                      tflite::micro::GetTensorShape(im2col),
-                      tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteConvParams* params, const OpData& data,
-                             const TfLiteEvalTensor* input,
-                             const TfLiteEvalTensor* filter,
-                             const TfLiteEvalTensor* bias,
-                             TfLiteEvalTensor* output,
-                             TfLiteEvalTensor* im2col) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.output_offset = data.output_zero_point;
-  op_params.stride_height = params->stride_height;
-  op_params.stride_width = params->stride_width;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.padding_values.height = data.padding.height;
-  op_params.padding_values.width = data.padding.width;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  reference_integer_ops::ConvPerChannel(
-      op_params, data.per_channel_output_multiplier,
-      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-}
-
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
-               TfLiteConvParams* params, const OpData& data,
-               const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
-               const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
-               TfLiteEvalTensor* hwcn_weights, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.padding_type = RuntimePaddingType(params->padding);
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-
-  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
-                      tflite::micro::GetTensorData<float>(input),
-                      tflite::micro::GetTensorShape(filter),
-                      tflite::micro::GetTensorData<float>(filter),
-                      tflite::micro::GetTensorShape(bias),
-                      tflite::micro::GetTensorData<float>(bias),
-                      tflite::micro::GetTensorShape(output),
-                      tflite::micro::GetTensorData<float>(output),
-                      tflite::micro::GetTensorShape(im2col),
-                      tflite::micro::GetTensorData<float>(im2col));
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);

   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kConvInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+      tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
   const TfLiteEvalTensor* bias =
       (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
           : nullptr;
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);

+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  const auto& params =
+      *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
   TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+  const auto& data = *(static_cast<const OpDataConv*>(node->user_data));

   TF_LITE_ENSURE_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");

   switch (input->type) {  // Already know in/out types are same.
-    case kTfLiteFloat32:
-      EvalFloat(context, node, params, data, input, filter, bias, nullptr,
-                nullptr, output);
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Conv(
+          ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output),
+          tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
-                              output, nullptr);
-      break;
-    case kTfLiteUInt8:
-      EvalQuantized(context, node, params, data, input, filter, bias, nullptr,
-                    nullptr, output);
+    }
+    case kTfLiteInt8: {
+      reference_integer_ops::ConvPerChannel(
+          ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
+          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
       break;
+    }
     default:
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
@@ -329,7 +96,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 TfLiteRegistration Register_CONV_2D() {
   return {/*init=*/Init,
           /*free=*/nullptr,
-          /*prepare=*/Prepare,
+          /*prepare=*/ConvPrepare,
           /*invoke=*/Eval,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
77
code/components/tfmicro/tensorflow/lite/micro/kernels/conv.h
Normal file
77
code/components/tfmicro/tensorflow/lite/micro/kernels/conv.h
Normal file
@@ -0,0 +1,77 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_

#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

struct OpDataConv {
  TfLitePaddingValues padding;

  // Cached tensor zero point values for quantized operations.
  int32_t input_zero_point;
  int32_t filter_zero_point;
  int32_t output_zero_point;

  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;

  // Per channel output multiplier and shift.
  int32_t* per_channel_output_multiplier;
  int32_t* per_channel_output_shift;

  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
};

extern const int kConvInputTensor;
extern const int kConvWeightsTensor;
extern const int kConvBiasTensor;
extern const int kConvOutputTensor;
extern const int kConvQuantizedDimension;

// Returns a ConvParams struct with all the parameters needed for a
// float computation.
ConvParams ConvParamsFloat(const TfLiteConvParams& params,
                           const OpDataConv& data);

// Returns a ConvParams struct with all the parameters needed for a
// quantized computation.
ConvParams ConvParamsQuantized(const TfLiteConvParams& params,
                               const OpDataConv& data);

TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
                                 const TfLiteConvParams& params, int width,
                                 int height, int filter_width,
                                 int filter_height, int out_width,
                                 int out_height, const TfLiteType data_type,
                                 OpDataConv* data);

TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node);

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
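The output_multiplier/output_shift pair above stores the real rescale factor M = input_scale * filter_scale / output_scale as a Q31 fixed-point multiplier plus a shift. A standalone sketch of that decomposition (illustrative only; the actual implementation lives in tensorflow/lite/kernels/internal/quantization_util.h and assumes M > 0 here):

#include <cmath>
#include <cstdint>

// Decompose m into q * 2^shift with q in [0.5, 1), then store q in Q31.
void QuantizeMultiplierSketch(double m, int32_t* quantized, int* shift) {
  if (m == 0.0) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(m, shift);  // m == q * 2^shift
  int64_t q_fixed = static_cast<int64_t>(std::llround(q * (1ll << 31)));
  if (q_fixed == (1ll << 31)) {  // rounding pushed q up to 1.0
    q_fixed /= 2;
    ++*shift;
  }
  *quantized = static_cast<int32_t>(q_fixed);
}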
@@ -0,0 +1,182 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {

const int kConvInputTensor = 0;
const int kConvWeightsTensor = 1;
const int kConvBiasTensor = 2;
const int kConvOutputTensor = 0;

// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
const int kConvQuantizedDimension = 0;

// Returns a ConvParams struct with all the parameters needed for a
// float computation.
ConvParams ConvParamsFloat(const TfLiteConvParams& params,
                           const OpDataConv& data) {
  ConvParams op_params;
  CalculateActivationRange(params.activation, &op_params.float_activation_min,
                           &op_params.float_activation_max);
  op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params.stride_width;
  op_params.stride_height = params.stride_height;
  op_params.dilation_width_factor = params.dilation_width_factor;
  op_params.dilation_height_factor = params.dilation_height_factor;
  return op_params;
}

// Returns a ConvParams struct with all the parameters needed for a
// quantized computation.
ConvParams ConvParamsQuantized(const TfLiteConvParams& params,
                               const OpDataConv& data) {
  ConvParams op_params;
  op_params.input_offset = -data.input_zero_point;
  op_params.weights_offset = -data.filter_zero_point;
  op_params.output_offset = data.output_zero_point;
  op_params.output_multiplier = data.output_multiplier;
  op_params.output_shift = -data.output_shift;
  op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.stride_height = params.stride_height;
  op_params.stride_width = params.stride_width;
  op_params.dilation_height_factor = params.dilation_height_factor;
  op_params.dilation_width_factor = params.dilation_width_factor;
  op_params.quantized_activation_min = data.output_activation_min;
  op_params.quantized_activation_max = data.output_activation_max;
  return op_params;
}

TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
                                 const TfLiteConvParams& params, int width,
                                 int height, int filter_width,
                                 int filter_height, int out_width,
                                 int out_height, const TfLiteType data_type,
                                 OpDataConv* data) {
  bool has_bias = node->inputs->size == 3;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params.padding;
  data->padding = ComputePaddingHeightWidth(
      params.stride_height, params.stride_width, params.dilation_height_factor,
      params.dilation_width_factor, height, width, filter_height, filter_width,
      padding, &out_height, &out_width);

  const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
  TF_LITE_ENSURE(context, filter != nullptr);
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kConvBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    int output_channels = filter->dims->data[kConvQuantizedDimension];

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params.activation,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift),
        output_channels));
  }

  data->input_zero_point = input->params.zero_point;
  data->filter_zero_point = filter->params.zero_point;
  data->output_zero_point = output->params.zero_point;

  return kTfLiteOk;
}

TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
  const auto& params =
      *(static_cast<const TfLiteConvParams*>(node->builtin_data));

  TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
  TF_LITE_ENSURE(context, filter != nullptr);

  const int input_width = input->dims->data[2];
  const int input_height = input->dims->data[1];
  const int filter_width = filter->dims->data[2];
  const int filter_height = filter->dims->data[1];
  const int output_width = output->dims->data[2];
  const int output_height = output->dims->data[1];

  // Dynamically allocate per-channel quantization parameters.
  const int num_channels = filter->dims->data[kConvQuantizedDimension];
  data->per_channel_output_multiplier =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
    TFLITE_DCHECK(affine_quantization != nullptr);
    TFLITE_DCHECK(affine_quantization->scale != nullptr);
    TFLITE_DCHECK(affine_quantization->zero_point != nullptr);

    TF_LITE_ENSURE(context,
                   affine_quantization->scale->size == 1 ||
                       affine_quantization->scale->size ==
                           filter->dims->data[kConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, data));

  return kTfLiteOk;
}
}  // namespace tflite
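Note the sign conventions ConvParamsQuantized sets up: the input and weights offsets are the negated zero points, so the kernel can recentre int8 data by addition, and output_shift is negated because the reference ops treat positive shifts as left shifts. An illustrative-only sketch of one output element under those conventions (the real kernel rescales with the fixed-point multiplier/shift pair; a float real_multiplier stands in for that here):

#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t QuantizedConvElementSketch(const int8_t* in, const int8_t* w, int n,
                                  int32_t bias, int32_t input_offset,
                                  int32_t output_offset, float real_multiplier,
                                  int32_t act_min, int32_t act_max) {
  int32_t acc = bias;
  for (int i = 0; i < n; ++i) {
    // input_offset == -input_zero_point, so this recentres the int8 input.
    // For per-channel int8, the filter zero point is 0 by spec.
    acc += (static_cast<int32_t>(in[i]) + input_offset) *
           static_cast<int32_t>(w[i]);
  }
  int32_t out = static_cast<int32_t>(std::lround(acc * real_multiplier));
  out += output_offset;
  out = std::min(std::max(out, act_min), act_max);  // fused activation range
  return static_cast<int8_t>(out);
}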
@@ -0,0 +1,94 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

namespace tflite {
namespace testing {

TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
                        int output_length, TfLiteConvParams* conv_params,
                        TfLiteRegistration registration, float* output_data);

TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
                        int output_length, TfLiteConvParams* conv_params,
                        TfLiteRegistration registration, int8_t* output_data);

TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
                        int output_length, TfLiteConvParams* conv_params,
                        TfLiteRegistration registration, uint8_t* output_data);

TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
                                 const float* expected_output_data,
                                 int output_length,
                                 TfLiteConvParams* conv_params,
                                 TfLiteRegistration registration,
                                 float* output_data, float tolerance = 1e-5);

TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
                                 const int8_t* expected_output_data,
                                 int output_length,
                                 TfLiteConvParams* conv_params,
                                 TfLiteRegistration registration,
                                 int8_t* output_data, float tolerance = 1e-5);

TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
                                 const uint8_t* expected_output_data,
                                 int output_length,
                                 TfLiteConvParams* conv_params,
                                 TfLiteRegistration registration,
                                 uint8_t* output_data, float tolerance = 1e-5);

TfLiteStatus TestConvFloat(const int* input_dims_data, const float* input_data,
                           const int* filter_dims_data,
                           const float* filter_data, const int* bias_dims_data,
                           const float* bias_data, const int* output_dims_data,
                           const float* expected_output_data,
                           TfLiteConvParams* conv_params,
                           TfLiteRegistration registration, float* output_data);

TfLiteStatus TestConvQuantizedPerLayer(
    const int* input_dims_data, const float* input_data,
    uint8_t* input_quantized, float input_scale, const int* filter_dims_data,
    const float* filter_data, uint8_t* filter_quantized, float filter_scale,
    const int* bias_dims_data, const float* bias_data, int32_t* bias_quantized,
    const int* output_dims_data, const float* expected_output_data,
    uint8_t* expected_output_quantized, float output_scale,
    TfLiteConvParams* conv_params, TfLiteRegistration registration,
    uint8_t* output_data);

TfLiteStatus TestConvQuantizedPerChannel(
    const int* input_dims_data, const float* input_data,
    int8_t* input_quantized, float input_scale, int input_zero_point,
    const int* filter_dims_data, const float* filter_data,
    int8_t* filter_data_quantized, const int* bias_dims_data,
    const float* bias_data, int32_t* bias_data_quantized, float* bias_scales,
    int* bias_zero_points, const int* output_dims_data,
    const float* expected_output_data, int8_t* expected_output_data_quantized,
    float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
    TfLiteRegistration registration, int8_t* output_data);

}  // namespace testing
}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
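A hypothetical use of the helpers declared above (shapes and values invented for illustration; assumes the TF_LITE_MICRO_TEST harness from micro_test.h, Register_CONV_2D visible via micro_ops.h, and the TfLiteConvParams field order of this snapshot's builtin_op_data.h): a 1x2x2x1 input convolved with one 2x2 filter, VALID padding, stride 1, so the single output is 1*1 + 4*1 = 5.

TF_LITE_MICRO_TEST(SimpleConvFloatSketch) {
  const int input_dims[] = {4, 1, 2, 2, 1};    // leading 4 = number of dims
  const float input_data[] = {1, 2, 3, 4};
  const int filter_dims[] = {4, 1, 2, 2, 1};
  const float filter_data[] = {1, 0, 0, 1};    // picks out corners 1 and 4
  const int bias_dims[] = {1, 1};
  const float bias_data[] = {0};
  const int output_dims[] = {4, 1, 1, 1, 1};
  const float expected_output[] = {5};
  float output_data[1];

  TfLiteConvParams conv_params = {kTfLitePaddingValid, 1, 1,
                                  kTfLiteActNone,      1, 1};
  tflite::testing::TestConvFloat(
      input_dims, input_data, filter_dims, filter_data, bias_dims, bias_data,
      output_dims, expected_output, &conv_params,
      tflite::Register_CONV_2D(), output_data);
}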
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
+#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
@@ -29,279 +30,58 @@ limitations under the License.
 namespace tflite {
 namespace {
 
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Depthwise conv is quantized along dimension 3:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kDepthwiseConvQuantizedDimension = 3;
-
-struct OpData {
-  TfLitePaddingValues padding;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Per channel output multiplier and shift.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteDepthwiseConvParams* params, int width,
-                             int height, int filter_width, int filter_height,
-                             const TfLiteType data_type, OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  int unused_output_height, unused_output_width;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width, 1, 1, height, width,
-      filter_height, filter_width, params->padding, &unused_output_height,
-      &unused_output_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    TF_LITE_ENSURE(context, input != nullptr);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    TF_LITE_ENSURE(context, filter != nullptr);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    TF_LITE_ENSURE(context, output != nullptr);
-    int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
-
-    return tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift), num_channels);
-  }
-  return kTfLiteOk;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
 }
 
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-
-  auto* params =
-      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
-  OpData* data = static_cast<OpData*>(node->user_data);
-
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  TF_LITE_ENSURE(context, filter != nullptr);
-
-  const TfLiteType data_type = input->type;
-  int width = SizeOfDimension(input, 2);
-  int height = SizeOfDimension(input, 1);
-  int filter_width = SizeOfDimension(filter, 2);
-  int filter_height = SizeOfDimension(filter, 1);
-
-  // Per channel quantization is only needed for int8_t inference. For other
-  // quantized types, only a single scale and zero point is needed.
-  const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
-  // Dynimically allocate per-channel quantization parameters.
-  data->per_channel_output_multiplier =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  data->per_channel_output_shift =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        reinterpret_cast<TfLiteAffineQuantization*>(
-            filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-    TF_LITE_ENSURE(
-        context, affine_quantization->scale->size == 1 ||
-                     affine_quantization->scale->size ==
-                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
-                                        filter_width, filter_height, data_type,
-                                        data));
-
-  data->input_zero_point = input->params.zero_point;
-  data->filter_zero_point = filter->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
-
-  return kTfLiteOk;
-}
-
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
-               TfLiteDepthwiseConvParams* params, const OpData& data,
-               const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
-               const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-
-  tflite::DepthwiseParams op_params;
-  // Padding type is ignored, but still set.
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-
-  tflite::reference_ops::DepthwiseConv(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<float>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<float>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<float>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<float>(output));
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteDepthwiseConvParams* params,
-                             const OpData& data, const TfLiteEvalTensor* input,
-                             const TfLiteEvalTensor* filter,
-                             const TfLiteEvalTensor* bias,
-                             TfLiteEvalTensor* output) {
-  DepthwiseParams op_params;
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.weights_offset = 0;
-  op_params.output_offset = data.output_zero_point;
-  // TODO(b/130439627): Use calculated value for clamping.
-  op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
-  op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
-
-  reference_integer_ops::DepthwiseConvPerChannel(
-      op_params, data.per_channel_output_multiplier,
-      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-}
-
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                   TfLiteDepthwiseConvParams* params, const OpData& data,
-                   const TfLiteEvalTensor* input,
-                   const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-                   TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  tflite::DepthwiseParams op_params;
-  // Padding type is ignored, but still set.
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
-  op_params.output_shift = -data.output_shift;
-
-  tflite::reference_ops::DepthwiseConv(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<uint8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<uint8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<uint8_t>(output));
-}
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
-  auto* params =
-      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+  auto& params =
+      *(reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data));
+  const OpDataConv& data = *(static_cast<const OpDataConv*>(node->user_data));
 
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kDepthwiseConvOutputTensor);
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kDepthwiseConvInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+      tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
   const TfLiteEvalTensor* bias =
       (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
           : nullptr;
 
-  // TODO(aselle): Consider whether float conv and quantized conv should be
-  // separate ops to avoid dispatch overhead here.
   switch (input->type) {  // Already know in/out types are same.
-    case kTfLiteFloat32:
-      EvalFloat(context, node, params, data, input, filter, bias, output);
+    case kTfLiteFloat32: {
+      tflite::reference_ops::DepthwiseConv(
+          DepthwiseConvParamsFloat(params, data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
       break;
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
-                              output);
-      break;
-    case kTfLiteUInt8:
-      EvalQuantized(context, node, params, data, input, filter, bias, output);
+    }
+    case kTfLiteInt8: {
+      reference_integer_ops::DepthwiseConvPerChannel(
+          DepthwiseConvParamsQuantized(params, data),
+          data.per_channel_output_multiplier, data.per_channel_output_shift,
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
      break;
+    }
     default:
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
@@ -315,7 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
   return {/*init=*/Init,
           /*free=*/nullptr,
-          /*prepare=*/Prepare,
+          /*prepare=*/DepthwiseConvPrepare,
           /*invoke=*/Eval,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
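The float path above clamps every output to the fused activation range that DepthwiseConvParamsFloat obtains from CalculateActivationRange. A standalone sketch of what that helper supplies for the common cases (other activation kinds omitted; enum names per this snapshot's builtin_op_data.h):

#include <limits>
#include "tensorflow/lite/c/builtin_op_data.h"

void ActivationRangeSketch(TfLiteFusedActivation activation, float* act_min,
                           float* act_max) {
  *act_min = -std::numeric_limits<float>::infinity();
  *act_max = std::numeric_limits<float>::infinity();
  if (activation == kTfLiteActRelu) {
    *act_min = 0.0f;                       // ReLU: [0, +inf)
  } else if (activation == kTfLiteActRelu6) {
    *act_min = 0.0f;                       // ReLU6: [0, 6]
    *act_max = 6.0f;
  }
}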
@@ -0,0 +1,54 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_

#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/kernels/conv.h"

namespace tflite {

extern const int kDepthwiseConvInputTensor;
extern const int kDepthwiseConvWeightsTensor;
extern const int kDepthwiseConvBiasTensor;
extern const int kDepthwiseConvOutputTensor;
extern const int kDepthwiseConvQuantizedDimension;

// Returns a DepthwiseParams struct with all the parameters needed for a
// float computation.
DepthwiseParams DepthwiseConvParamsFloat(
    const TfLiteDepthwiseConvParams& params, const OpDataConv& data);

// Returns a DepthwiseParams struct with all the parameters needed for a
// quantized computation.
DepthwiseParams DepthwiseConvParamsQuantized(
    const TfLiteDepthwiseConvParams& params, const OpDataConv& data);

TfLiteStatus CalculateOpDataDepthwiseConv(
    TfLiteContext* context, TfLiteNode* node,
    const TfLiteDepthwiseConvParams& params, int width, int height,
    int filter_width, int filter_height, int out_width, int out_height,
    const TfLiteType data_type, OpDataConv* data);

TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
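One depthwise-specific parameter worth spelling out is depth_multiplier: each input channel feeds depth_multiplier output channels, which is also why the filter is quantized along dimension 3 (the output-channel axis) rather than dimension 0 as in regular conv. A standalone sketch of the implied channel bookkeeping:

// Illustrative only: the channel indexing depth_multiplier implies.
inline int DepthwiseOutputChannel(int input_channel, int depth_multiplier,
                                  int m) {
  return input_channel * depth_multiplier + m;  // m in [0, depth_multiplier)
}

inline int DepthwiseInputChannel(int output_channel, int depth_multiplier) {
  return output_channel / depth_multiplier;
}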
@@ -0,0 +1,188 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {

const int kDepthwiseConvInputTensor = 0;
const int kDepthwiseConvWeightsTensor = 1;
const int kDepthwiseConvBiasTensor = 2;
const int kDepthwiseConvOutputTensor = 0;

// DepthwiseConv is quantized along dimension 3:
// https://www.tensorflow.org/lite/performance/quantization_spec
const int kDepthwiseConvQuantizedDimension = 3;

// Returns a DepthwiseParams struct with all the parameters needed for a
// float computation.
DepthwiseParams DepthwiseConvParamsFloat(
    const TfLiteDepthwiseConvParams& params, const OpDataConv& data) {
  DepthwiseParams op_params;
  CalculateActivationRange(params.activation, &op_params.float_activation_min,
                           &op_params.float_activation_max);
  op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params.stride_width;
  op_params.stride_height = params.stride_height;
  op_params.dilation_width_factor = params.dilation_width_factor;
  op_params.dilation_height_factor = params.dilation_height_factor;
  op_params.depth_multiplier = params.depth_multiplier;
  return op_params;
}

// Returns a DepthwiseParams struct with all the parameters needed for a
// quantized computation.
DepthwiseParams DepthwiseConvParamsQuantized(
    const TfLiteDepthwiseConvParams& params, const OpDataConv& data) {
  DepthwiseParams op_params;
  op_params.input_offset = -data.input_zero_point;
  op_params.weights_offset = -data.filter_zero_point;
  op_params.output_offset = data.output_zero_point;
  op_params.output_multiplier = data.output_multiplier;
  op_params.output_shift = -data.output_shift;
  op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.stride_height = params.stride_height;
  op_params.stride_width = params.stride_width;
  op_params.dilation_height_factor = params.dilation_height_factor;
  op_params.dilation_width_factor = params.dilation_width_factor;
  op_params.depth_multiplier = params.depth_multiplier;
  op_params.quantized_activation_min = data.output_activation_min;
  op_params.quantized_activation_max = data.output_activation_max;
  return op_params;
}

TfLiteStatus CalculateOpDataDepthwiseConv(
    TfLiteContext* context, TfLiteNode* node,
    const TfLiteDepthwiseConvParams& params, int width, int height,
    int filter_width, int filter_height, int out_width, int out_height,
    const TfLiteType data_type, OpDataConv* data) {
  bool has_bias = node->inputs->size == 3;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params.padding;
  data->padding = ComputePaddingHeightWidth(
      params.stride_height, params.stride_width, params.dilation_height_factor,
      params.dilation_width_factor, height, width, filter_height, filter_width,
      padding, &out_height, &out_width);

  const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
  TF_LITE_ENSURE(context, filter != nullptr);
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kConvBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    int output_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params.activation,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift),
        output_channels));
  }

  data->input_zero_point = input->params.zero_point;
  data->filter_zero_point = filter->params.zero_point;
  data->output_zero_point = output->params.zero_point;

  return kTfLiteOk;
}

TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
  const auto& params =
      *(static_cast<const TfLiteDepthwiseConvParams*>(node->builtin_data));

  TfLiteTensor* output = GetOutput(context, node, kDepthwiseConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  const TfLiteTensor* input =
      GetInput(context, node, kDepthwiseConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter =
      GetInput(context, node, kDepthwiseConvWeightsTensor);
  TF_LITE_ENSURE(context, filter != nullptr);

  const int input_width = input->dims->data[2];
  const int input_height = input->dims->data[1];
  const int filter_width = filter->dims->data[2];
  const int filter_height = filter->dims->data[1];
  const int output_width = output->dims->data[2];
  const int output_height = output->dims->data[1];

  // Dynamically allocate per-channel quantization parameters.
  const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
  data->per_channel_output_multiplier =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
    TFLITE_DCHECK(affine_quantization != nullptr);
    TFLITE_DCHECK(affine_quantization->scale != nullptr);
    TFLITE_DCHECK(affine_quantization->zero_point != nullptr);

    TF_LITE_ENSURE(
        context, affine_quantization->scale->size == 1 ||
                     affine_quantization->scale->size ==
                         filter->dims->data[kDepthwiseConvQuantizedDimension]);

    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, data));

  return kTfLiteOk;
}

}  // namespace tflite
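Both Prepare paths lean on ComputePaddingHeightWidth, which matches TensorFlow's GetWindowedOutputSize. A standalone sketch of the SAME/VALID arithmetic it implements (illustrative; the library version also returns a padding offset for odd totals):

#include <algorithm>

int OutputSizeSketch(bool same_padding, int in, int filter, int stride,
                     int dilation) {
  const int effective_filter = (filter - 1) * dilation + 1;
  // SAME: ceil(in / stride); VALID: floor((in - effective_filter) / stride) + 1
  return same_padding ? (in + stride - 1) / stride
                      : (in - effective_filter + stride) / stride;
}

int PaddingBeforeSketch(int in, int out, int effective_filter, int stride) {
  // Total padding needed so the last window fits; split roughly in half.
  const int total = std::max(0, (out - 1) * stride + effective_filter - in);
  return total / 2;
}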
@@ -59,8 +59,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
                               input->type == kTfLiteInt8 ||
                               input->type == kTfLiteInt16);
-  TF_LITE_ENSURE(
-      context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt32);
+  TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
 
   if (output->type == kTfLiteInt32) {
     const double effective_output_scale =
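This hunk narrows Dequantize to float32 outputs (the int32 requantize path is dropped in the next hunk). What remains is the affine mapping from the quantization spec; a standalone sketch:

#include <cstdint>

// real_value = scale * (quantized_value - zero_point)
float DequantizeSketch(int8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}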
@@ -112,32 +111,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                            TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
-  } else if (output->type == kTfLiteInt32) {
-    int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
-                                     tflite::micro::GetTensorShape(output));
-    switch (input->type) {
-      case kTfLiteInt16: {
-        reference_ops::Requantize(
-            tflite::micro::GetTensorData<int16_t>(input), flat_size,
-            data->output_multiplier, data->output_shift,
-            data->quantization_params.zero_point, data->output_zero_point,
-            tflite::micro::GetTensorData<int32_t>(output));
-        break;
-      }
-      case kTfLiteInt8: {
-        reference_ops::Requantize(
-            tflite::micro::GetTensorData<int8_t>(input), flat_size,
-            data->output_multiplier, data->output_shift,
-            data->quantization_params.zero_point, data->output_zero_point,
-            tflite::micro::GetTensorData<int32_t>(output));
-        break;
-      }
-      default:
-        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                           TfLiteTypeGetName(input->type),
-                           TfLiteTypeGetName(output->type));
-        return kTfLiteError;
-    }
   } else {
     TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                        TfLiteTypeGetName(input->type),
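The removed branch performed a requantization: re-expressing a value from one integer quantization domain in another so both denote the same real number. A float-math sketch of that operation (the kernel itself uses the fixed-point multiplier/shift form):

#include <cmath>
#include <cstdint>

// Map q_in under (scale_in, zp_in) to the domain (scale_out, zp_out).
int32_t RequantizeSketch(int32_t q_in, float scale_in, int32_t zp_in,
                         float scale_out, int32_t zp_out) {
  const float real = scale_in * static_cast<float>(q_in - zp_in);
  return static_cast<int32_t>(std::lround(real / scale_out)) + zp_out;
}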
@@ -0,0 +1,805 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#define FLATBUFFERS_LOCALE_INDEPENDENT 0
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/kernels/op_macros.h"
|
||||||
|
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/micro/micro_utils.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This version of detection_postprocess is specific to TFLite Micro. It
|
||||||
|
* contains the following differences between the TFLite version:
|
||||||
|
*
|
||||||
|
* 1.) Temporaries (temporary tensors) - Micro use instead scratch buffer API.
|
||||||
|
* 2.) Output dimensions - the TFLite version does not support undefined out
|
||||||
|
* dimensions. So model must have static out dimensions.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Input tensors
|
||||||
|
constexpr int kInputTensorBoxEncodings = 0;
|
||||||
|
constexpr int kInputTensorClassPredictions = 1;
|
||||||
|
constexpr int kInputTensorAnchors = 2;
|
||||||
|
|
||||||
|
// Output tensors
|
||||||
|
constexpr int kOutputTensorDetectionBoxes = 0;
|
||||||
|
constexpr int kOutputTensorDetectionClasses = 1;
|
||||||
|
constexpr int kOutputTensorDetectionScores = 2;
|
||||||
|
constexpr int kOutputTensorNumDetections = 3;
|
||||||
|
|
||||||
|
constexpr int kNumCoordBox = 4;
|
||||||
|
constexpr int kBatchSize = 1;
|
||||||
|
|
||||||
|
constexpr int kNumDetectionsPerClass = 100;
|
||||||
|
|
||||||
|
// Object Detection model produces axis-aligned boxes in two formats:
|
||||||
|
// BoxCorner represents the lower left corner (xmin, ymin) and
|
||||||
|
// the upper right corner (xmax, ymax).
|
||||||
|
// CenterSize represents the center (xcenter, ycenter), height and width.
|
||||||
|
// BoxCornerEncoding and CenterSizeEncoding are related as follows:
|
||||||
|
// ycenter = y / y_scale * anchor.h + anchor.y;
|
||||||
|
// xcenter = x / x_scale * anchor.w + anchor.x;
|
||||||
|
// half_h = 0.5*exp(h/ h_scale)) * anchor.h;
|
||||||
|
// half_w = 0.5*exp(w / w_scale)) * anchor.w;
|
||||||
|
// ymin = ycenter - half_h
|
||||||
|
// ymax = ycenter + half_h
|
||||||
|
// xmin = xcenter - half_w
|
||||||
|
// xmax = xcenter + half_w
|
||||||
|
struct BoxCornerEncoding {
|
||||||
|
float ymin;
|
||||||
|
float xmin;
|
||||||
|
float ymax;
|
||||||
|
float xmax;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CenterSizeEncoding {
|
||||||
|
float y;
|
||||||
|
float x;
|
||||||
|
float h;
|
||||||
|
float w;
|
||||||
|
};
|
||||||
|
// We make sure that the memory allocations are contiguous with static_assert.
|
||||||
|
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
|
||||||
|
"Size of BoxCornerEncoding is 4 float values");
|
||||||
|
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
|
||||||
|
"Size of CenterSizeEncoding is 4 float values");
|
||||||
|
|
||||||
|
struct OpData {
|
||||||
|
int max_detections;
|
||||||
|
int max_classes_per_detection; // Fast Non-Max-Suppression
|
||||||
|
int detections_per_class; // Regular Non-Max-Suppression
|
||||||
|
float non_max_suppression_score_threshold;
|
||||||
|
float intersection_over_union_threshold;
|
||||||
|
int num_classes;
|
||||||
|
bool use_regular_non_max_suppression;
|
||||||
|
CenterSizeEncoding scale_values;
|
||||||
|
|
||||||
|
// Scratch buffers indexes
|
||||||
|
int active_candidate_idx;
|
||||||
|
int decoded_boxes_idx;
|
||||||
|
int scores_idx;
|
||||||
|
int score_buffer_idx;
|
||||||
|
int keep_scores_idx;
|
||||||
|
int scores_after_regular_non_max_suppression_idx;
|
||||||
|
int sorted_values_idx;
|
||||||
|
int keep_indices_idx;
|
||||||
|
int sorted_indices_idx;
|
||||||
|
int buffer_idx;
|
||||||
|
int selected_idx;
|
||||||
|
|
||||||
|
// Cached tensor scale and zero point values for quantized operations
|
||||||
|
TfLiteQuantizationParams input_box_encodings;
|
||||||
|
TfLiteQuantizationParams input_class_predictions;
|
||||||
|
TfLiteQuantizationParams input_anchors;
|
||||||
|
};
|
||||||
|
|
||||||
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
OpData* op_data = nullptr;
|
||||||
|
|
||||||
|
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
||||||
|
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
|
||||||
|
|
||||||
|
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||||
|
op_data = reinterpret_cast<OpData*>(
|
||||||
|
context->AllocatePersistentBuffer(context, sizeof(OpData)));
|
||||||
|
|
||||||
|
op_data->max_detections = m["max_detections"].AsInt32();
|
||||||
|
op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
|
||||||
|
if (m["detections_per_class"].IsNull())
|
||||||
|
op_data->detections_per_class = kNumDetectionsPerClass;
|
||||||
|
else
|
||||||
|
op_data->detections_per_class = m["detections_per_class"].AsInt32();
|
||||||
|
if (m["use_regular_nms"].IsNull())
|
||||||
|
op_data->use_regular_non_max_suppression = false;
|
||||||
|
else
|
||||||
|
op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();
|
||||||
|
|
||||||
|
op_data->non_max_suppression_score_threshold =
|
||||||
|
m["nms_score_threshold"].AsFloat();
|
||||||
|
op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
|
||||||
|
op_data->num_classes = m["num_classes"].AsInt32();
|
||||||
|
op_data->scale_values.y = m["y_scale"].AsFloat();
|
||||||
|
op_data->scale_values.x = m["x_scale"].AsFloat();
|
||||||
|
op_data->scale_values.h = m["h_scale"].AsFloat();
|
||||||
|
op_data->scale_values.w = m["w_scale"].AsFloat();
|
||||||
|
|
||||||
|
return op_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Free(TfLiteContext* context, void* buffer) {}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = static_cast<OpData*>(node->user_data);

  // Inputs: box_encodings, scores, anchors
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* input_class_predictions =
      GetInput(context, node, kInputTensorClassPredictions);
  const TfLiteTensor* input_anchors =
      GetInput(context, node, kInputTensorAnchors);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;

  op_data->input_box_encodings.scale = input_box_encodings->params.scale;
  op_data->input_box_encodings.zero_point =
      input_box_encodings->params.zero_point;
  op_data->input_class_predictions.scale =
      input_class_predictions->params.scale;
  op_data->input_class_predictions.zero_point =
      input_class_predictions->params.zero_point;
  op_data->input_anchors.scale = input_anchors->params.scale;
  op_data->input_anchors.zero_point = input_anchors->params.zero_point;

  // Scratch tensors
  context->RequestScratchBufferInArena(context, num_boxes,
                                       &op_data->active_candidate_idx);
  context->RequestScratchBufferInArena(context,
                                       num_boxes * kNumCoordBox * sizeof(float),
                                       &op_data->decoded_boxes_idx);
  context->RequestScratchBufferInArena(
      context,
      input_class_predictions->dims->data[1] *
          input_class_predictions->dims->data[2] * sizeof(float),
      &op_data->scores_idx);

  // Additional buffers
  context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
                                       &op_data->score_buffer_idx);
  context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
                                       &op_data->keep_scores_idx);
  context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(float),
      &op_data->scores_after_regular_non_max_suppression_idx);
  context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(float),
      &op_data->sorted_values_idx);
  context->RequestScratchBufferInArena(context, num_boxes * sizeof(int),
                                       &op_data->keep_indices_idx);
  context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(int),
      &op_data->sorted_indices_idx);
  int buffer_size = std::max(num_classes, op_data->max_detections);
  context->RequestScratchBufferInArena(
      context, buffer_size * num_boxes * sizeof(int), &op_data->buffer_idx);
  buffer_size = std::min(num_boxes, op_data->max_detections);
  context->RequestScratchBufferInArena(
      context, buffer_size * num_boxes * sizeof(int), &op_data->selected_idx);

  // Outputs: detection_boxes, detection_scores, detection_classes,
  // num_detections
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);

  return kTfLiteOk;
}
class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}
  float operator()(uint8_t x) {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
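// Editor's note (illustrative sketch, not part of the original kernel): the
// affine mapping above is float = (uint8 - zero_point) * scale. With the
// assumed parameters zero_point = 128 and scale = 0.5f:
inline bool DequantizerExampleHolds() {
  Dequantizer dequantize(/*zero_point=*/128, /*scale=*/0.5f);
  // 126 -> -1.0f, 128 -> 0.0f, 130 -> 1.0f
  return dequantize(126) == -1.0f && dequantize(128) == 0.0f &&
         dequantize(130) == 1.0f;
}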
void DequantizeBoxEncodings(const TfLiteEvalTensor* input_box_encodings,
                            int idx, float quant_zero_point, float quant_scale,
                            int length_box_encoding,
                            CenterSizeEncoding* box_centersize) {
  const uint8_t* boxes =
      tflite::micro::GetTensorData<uint8_t>(input_box_encodings) +
      length_box_encoding * idx;
  Dequantizer dequantize(quant_zero_point, quant_scale);
  // See the definition of the KeyPointBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
  // The first four elements are the box coordinates, which is the same as the
  // FasterRcnnBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
  box_centersize->y = dequantize(boxes[0]);
  box_centersize->x = dequantize(boxes[1]);
  box_centersize->h = dequantize(boxes[2]);
  box_centersize->w = dequantize(boxes[3]);
}

template <class T>
T ReInterpretTensor(const TfLiteEvalTensor* tensor) {
  const float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

template <class T>
T ReInterpretTensor(TfLiteEvalTensor* tensor) {
  float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}
TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                   OpData* op_data) {
  // Parse input tensor box_encodings
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
  const int num_boxes = input_box_encodings->dims->data[1];
  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
  const TfLiteEvalTensor* input_anchors =
      tflite::micro::GetEvalInput(context, node, kInputTensorAnchors);

  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
  CenterSizeEncoding box_centersize;
  CenterSizeEncoding scale_values = op_data->scale_values;
  CenterSizeEncoding anchor;
  for (int idx = 0; idx < num_boxes; ++idx) {
    switch (input_box_encodings->type) {
      // Quantized
      case kTfLiteUInt8:
        DequantizeBoxEncodings(
            input_box_encodings, idx,
            static_cast<float>(op_data->input_box_encodings.zero_point),
            static_cast<float>(op_data->input_box_encodings.scale),
            input_box_encodings->dims->data[2], &box_centersize);
        DequantizeBoxEncodings(
            input_anchors, idx,
            static_cast<float>(op_data->input_anchors.zero_point),
            static_cast<float>(op_data->input_anchors.scale), kNumCoordBox,
            &anchor);
        break;
      // Float
      case kTfLiteFloat32: {
        // Please see the DequantizeBoxEncodings function for details.
        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
        const float* boxes = &(tflite::micro::GetTensorData<float>(
            input_box_encodings)[box_encoding_idx]);
        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
        anchor =
            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
        break;
      }
      default:
        // Unsupported type.
        return kTfLiteError;
    }

    float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) /
                                           static_cast<double>(scale_values.y) *
                                           static_cast<double>(anchor.h) +
                                       static_cast<double>(anchor.y));

    float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) /
                                           static_cast<double>(scale_values.x) *
                                           static_cast<double>(anchor.w) +
                                       static_cast<double>(anchor.x));

    float half_h =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.h) /
                                     static_cast<double>(scale_values.h))) *
                           static_cast<double>(anchor.h));
    float half_w =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.w) /
                                     static_cast<double>(scale_values.w))) *
                           static_cast<double>(anchor.w));

    float* decoded_boxes = reinterpret_cast<float*>(
        context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
    auto& box = reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[idx];
    box.ymin = ycenter - half_h;
    box.xmin = xcenter - half_w;
    box.ymax = ycenter + half_h;
    box.xmax = xcenter + half_w;
  }
  return kTfLiteOk;
}
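// Editor's note (worked example with assumed values, not from the model):
// with the common SSD scales y_scale = x_scale = 10, h_scale = w_scale = 5,
// an anchor (y=0.5, x=0.5, h=0.2, w=0.2) and the zero encoding
// (y=0, x=0, h=0, w=0), the decode above gives
//   ycenter = 0/10 * 0.2 + 0.5 = 0.5,   half_h = 0.5 * exp(0/5) * 0.2 = 0.1,
// and likewise for x/w, i.e. the corner box (ymin, xmin, ymax, xmax) =
// (0.4, 0.4, 0.6, 0.6).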
void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  std::iota(indices, indices + num_values, 0);
  std::partial_sort(
      indices, indices + num_to_sort, indices + num_values,
      [&values](const int i, const int j) { return values[i] > values[j]; });
}

int SelectDetectionsAboveScoreThreshold(const float* values, int size,
                                        const float threshold,
                                        float* keep_values, int* keep_indices) {
  int counter = 0;
  for (int i = 0; i < size; i++) {
    if (values[i] >= threshold) {
      keep_values[counter] = values[i];
      keep_indices[counter] = i;
      counter++;
    }
  }
  return counter;
}
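// Editor's note (usage sketch under assumed inputs, not part of the kernel):
inline void DecreasingPartialArgSortExample() {
  const float values[3] = {0.1f, 0.9f, 0.4f};
  int indices[3] = {0, 0, 0};
  DecreasingPartialArgSort(values, /*num_values=*/3, /*num_to_sort=*/2,
                           indices);
  // Now indices[0] == 1 (score 0.9f) and indices[1] == 2 (score 0.4f):
  // the positions of the two largest values, in decreasing score order.
}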
bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) {
  for (int i = 0; i < num_boxes; ++i) {
    // Each box must satisfy ymax > ymin and xmax > xmin.
    auto& box = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
      return false;
    }
  }
  return true;
}

float ComputeIntersectionOverUnion(const float* decoded_boxes, const int i,
                                   const int j) {
  auto& box_i = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
  auto& box_j = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[j];
  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
  if (area_i <= 0 || area_j <= 0) return 0.0;
  const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
  const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
  const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
  const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
  const float intersection_area =
      std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
      std::max<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}
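// Editor's note (worked example with assumed boxes): for the unit square
// box_i = (ymin=0, xmin=0, ymax=1, xmax=1) and the shifted square
// box_j = (0, 0.5, 1, 1.5), the intersection is 1 * 0.5 = 0.5 and the
// union is 1 + 1 - 0.5 = 1.5, so the function returns 0.5 / 1.5 = 1/3.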
// NonMaxSuppressionSingleClassHelper() prunes out box locations with high
// overlap before selecting the highest-scoring boxes (max_detections of them).
// It assumes all boxes are valid at the beginning and sorts them by score.
// If a lower-scoring box overlaps too much with a higher-scoring box,
// the lower-scoring box is dropped.
// Complexity is O(N^2): pairwise comparison between boxes.
TfLiteStatus NonMaxSuppressionSingleClassHelper(
    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
    const float* scores, int* selected, int* selected_size,
    int max_detections) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const int num_boxes = input_box_encodings->dims->data[1];
  const float non_max_suppression_score_threshold =
      op_data->non_max_suppression_score_threshold;
  const float intersection_over_union_threshold =
      op_data->intersection_over_union_threshold;
  // The maximum number of detections must be non-negative.
  TF_LITE_ENSURE(context, (max_detections >= 0));
  // intersection_over_union_threshold must be positive
  // and less than or equal to 1.
  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
                              (intersection_over_union_threshold <= 1.0f));
  // Validate boxes
  float* decoded_boxes = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->decoded_boxes_idx));

  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));

  // Threshold scores
  int* keep_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->keep_indices_idx));
  float* keep_scores = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->keep_scores_idx));
  int num_scores_kept = SelectDetectionsAboveScoreThreshold(
      scores, num_boxes, non_max_suppression_score_threshold, keep_scores,
      keep_indices);
  int* sorted_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->sorted_indices_idx));

  DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept,
                           sorted_indices);

  const int num_boxes_kept = num_scores_kept;
  const int output_size = std::min(num_boxes_kept, max_detections);
  *selected_size = 0;

  int num_active_candidate = num_boxes_kept;
  uint8_t* active_box_candidate = reinterpret_cast<uint8_t*>(
      context->GetScratchBuffer(context, op_data->active_candidate_idx));

  for (int row = 0; row < num_boxes_kept; row++) {
    active_box_candidate[row] = 1;
  }
  for (int i = 0; i < num_boxes_kept; ++i) {
    if (num_active_candidate == 0 || *selected_size >= output_size) break;
    if (active_box_candidate[i] == 1) {
      selected[(*selected_size)++] = keep_indices[sorted_indices[i]];
      active_box_candidate[i] = 0;
      num_active_candidate--;
    } else {
      continue;
    }
    for (int j = i + 1; j < num_boxes_kept; ++j) {
      if (active_box_candidate[j] == 1) {
        float intersection_over_union = ComputeIntersectionOverUnion(
            decoded_boxes, keep_indices[sorted_indices[i]],
            keep_indices[sorted_indices[j]]);

        if (intersection_over_union > intersection_over_union_threshold) {
          active_box_candidate[j] = 0;
          num_active_candidate--;
        }
      }
    }
  }

  return kTfLiteOk;
}
// This function implements the regular version of Non-Maximal Suppression
// (NMS) for multiple classes, where
// 1) NMS is performed separately for each class across all anchors, and
// 2) only the highest anchor scores across all classes are kept.
// 3) The worst-case runtime of regular NMS is O(K*N^2), where N is the
//    number of anchors and K the number of classes.
TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
                                                      TfLiteNode* node,
                                                      OpData* op_data,
                                                      const float* scores) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
  TfLiteEvalTensor* detection_boxes =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
  TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
      context, node, kOutputTensorDetectionClasses);
  TfLiteEvalTensor* detection_scores =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
  TfLiteEvalTensor* num_detections =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int num_detections_per_class = op_data->detections_per_class;
  const int max_detections = op_data->max_detections;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];
  // The row index offset is 1 if the background class is included and 0
  // otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, num_detections_per_class > 0);

  // For each class, perform non-max suppression.
  float* class_scores = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->score_buffer_idx));
  int* box_indices_after_regular_non_max_suppression = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->buffer_idx));
  float* scores_after_regular_non_max_suppression =
      reinterpret_cast<float*>(context->GetScratchBuffer(
          context, op_data->scores_after_regular_non_max_suppression_idx));

  int size_of_sorted_indices = 0;
  int* sorted_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->sorted_indices_idx));
  float* sorted_values = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->sorted_values_idx));

  for (int col = 0; col < num_classes; col++) {
    for (int row = 0; row < num_boxes; row++) {
      // Get the scores of the boxes corresponding to all anchors for a
      // single class
      class_scores[row] =
          *(scores + row * num_classes_with_background + col + label_offset);
    }
    // Perform non-maximal suppression on a single class
    int selected_size = 0;
    int* selected = reinterpret_cast<int*>(
        context->GetScratchBuffer(context, op_data->selected_idx));
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
        context, node, op_data, class_scores, selected, &selected_size,
        num_detections_per_class));
    // Add the indices selected by non-max suppression of boxes in this class
    int output_index = size_of_sorted_indices;
    for (int i = 0; i < selected_size; i++) {
      int selected_index = selected[i];

      box_indices_after_regular_non_max_suppression[output_index] =
          (selected_index * num_classes_with_background + col + label_offset);
      scores_after_regular_non_max_suppression[output_index] =
          class_scores[selected_index];
      output_index++;
    }
    // Sort the max scores among the selected indices
    // and get the indices of the top scores
    int num_indices_to_sort = std::min(output_index, max_detections);
    DecreasingPartialArgSort(scores_after_regular_non_max_suppression,
                             output_index, num_indices_to_sort, sorted_indices);

    // Copy values to temporary vectors
    for (int row = 0; row < num_indices_to_sort; row++) {
      int temp = sorted_indices[row];
      sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
      sorted_values[row] = scores_after_regular_non_max_suppression[temp];
    }
    // Copy scores and indices back from the temporary vectors
    for (int row = 0; row < num_indices_to_sort; row++) {
      box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
      scores_after_regular_non_max_suppression[row] = sorted_values[row];
    }
    size_of_sorted_indices = num_indices_to_sort;
  }

  // Fill the output tensors
  for (int output_box_index = 0; output_box_index < max_detections;
       output_box_index++) {
    if (output_box_index < size_of_sorted_indices) {
      const int anchor_index = floor(
          box_indices_after_regular_non_max_suppression[output_box_index] /
          num_classes_with_background);
      const int class_index =
          box_indices_after_regular_non_max_suppression[output_box_index] -
          anchor_index * num_classes_with_background - label_offset;
      const float selected_score =
          scores_after_regular_non_max_suppression[output_box_index];
      // detection_boxes
      float* decoded_boxes = reinterpret_cast<float*>(
          context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
          reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[anchor_index];
      // detection_classes
      tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
          class_index;
      // detection_scores
      tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
          selected_score;
    } else {
      ReInterpretTensor<BoxCornerEncoding*>(
          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
      // detection_classes
      tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
          0.0f;
      // detection_scores
      tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
          0.0f;
    }
  }
  tflite::micro::GetTensorData<float>(num_detections)[0] =
      size_of_sorted_indices;

  return kTfLiteOk;
}
// This function implements a fast version of Non-Maximal Suppression for
// multiple classes, where
// 1) the top-k scores are kept for each anchor, and
// 2) during NMS, each anchor only uses its highest class score for sorting.
// 3) Compared to standard NMS, the worst-case runtime of this version is
//    O(N^2) instead of O(K*N^2), where N is the number of anchors and K the
//    number of classes.
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                   TfLiteNode* node,
                                                   OpData* op_data,
                                                   const float* scores) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
  TfLiteEvalTensor* detection_boxes =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);

  TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
      context, node, kOutputTensorDetectionClasses);
  TfLiteEvalTensor* detection_scores =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
  TfLiteEvalTensor* num_detections =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int max_categories_per_anchor = op_data->max_classes_per_detection;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  // The row index offset is 1 if the background class is included and 0
  // otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
  const int num_categories_per_anchor =
      std::min(max_categories_per_anchor, num_classes);
  float* max_scores = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->score_buffer_idx));
  int* sorted_class_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->buffer_idx));

  for (int row = 0; row < num_boxes; row++) {
    const float* box_scores =
        scores + row * num_classes_with_background + label_offset;
    int* class_indices = sorted_class_indices + row * num_classes;
    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
                             class_indices);
    max_scores[row] = box_scores[class_indices[0]];
  }

  // Perform non-maximal suppression on the max scores
  int selected_size = 0;
  int* selected = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->selected_idx));
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
      context, node, op_data, max_scores, selected, &selected_size,
      op_data->max_detections));

  // Fill the output tensors
  int output_box_index = 0;

  for (int i = 0; i < selected_size; i++) {
    int selected_index = selected[i];

    const float* box_scores =
        scores + selected_index * num_classes_with_background + label_offset;
    const int* class_indices =
        sorted_class_indices + selected_index * num_classes;

    for (int col = 0; col < num_categories_per_anchor; ++col) {
      int box_offset = num_categories_per_anchor * output_box_index + col;

      // detection_boxes
      float* decoded_boxes = reinterpret_cast<float*>(
          context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
          reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[selected_index];

      // detection_classes
      tflite::micro::GetTensorData<float>(detection_classes)[box_offset] =
          class_indices[col];

      // detection_scores
      tflite::micro::GetTensorData<float>(detection_scores)[box_offset] =
          box_scores[class_indices[col]];

      output_box_index++;
    }
  }

  tflite::micro::GetTensorData<float>(num_detections)[0] = output_box_index;
  return kTfLiteOk;
}
void DequantizeClassPredictions(const TfLiteEvalTensor* input_class_predictions,
                                const int num_boxes,
                                const int num_classes_with_background,
                                float* scores, OpData* op_data) {
  float quant_zero_point =
      static_cast<float>(op_data->input_class_predictions.zero_point);
  float quant_scale =
      static_cast<float>(op_data->input_class_predictions.scale);
  Dequantizer dequantize(quant_zero_point, quant_scale);
  const uint8_t* scores_quant =
      tflite::micro::GetTensorData<uint8_t>(input_class_predictions);
  for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
    scores[idx] = dequantize(scores_quant[idx]);
  }
}
TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  // Get the input tensors
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;

  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));

  const float* scores;
  switch (input_class_predictions->type) {
    case kTfLiteUInt8: {
      float* temporary_scores = reinterpret_cast<float*>(
          context->GetScratchBuffer(context, op_data->scores_idx));
      DequantizeClassPredictions(input_class_predictions, num_boxes,
                                 num_classes_with_background, temporary_scores,
                                 op_data);
      scores = temporary_scores;
    } break;
    case kTfLiteFloat32:
      scores = tflite::micro::GetTensorData<float>(input_class_predictions);
      break;
    default:
      // Unsupported type.
      return kTfLiteError;
  }

  if (op_data->use_regular_non_max_suppression) {
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
        context, node, op_data, scores));
  } else {
    TF_LITE_ENSURE_STATUS(
        NonMaxSuppressionMultiClassFastHelper(context, node, op_data, scores));
  }

  return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  auto* op_data = static_cast<OpData*>(node->user_data);

  // These two functions correspond to two blocks in the Object Detection
  // model. In the future, we would like to split the custom op into two
  // blocks, which is currently not feasible because we want to accept
  // quantized inputs but do all calculations in float. Mixed quantized/float
  // calculations are currently not supported in TFLite.

  // This fills in the temporary decoded_boxes buffer by transforming
  // input_box_encodings and input_anchors from CenterSizeEncoding to
  // BoxCornerEncoding.
  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));

  // This fills in the output tensors by choosing an effective set of decoded
  // boxes via Non-Maximal Suppression, i.e. selecting the highest-scoring
  // non-overlapping boxes.
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));

  return kTfLiteOk;
}
}  // namespace

TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
  static TfLiteRegistration r = {/*init=*/Init,
                                 /*free=*/Free,
                                 /*prepare=*/Prepare,
                                 /*invoke=*/Eval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}

}  // namespace tflite
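Register_DETECTION_POSTPROCESS() only builds the registration; the application still has to attach it to an op resolver. The sketch below shows one common way to wire it up with a MicroMutableOpResolver; the op name string, the resolver capacity of 1, and the helper function are assumptions for illustration, not taken from this commit.

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Hypothetical helper: "TFLite_Detection_PostProcess" is the name the TFLite
// converter usually assigns to this custom op; size the resolver as needed.
inline TfLiteStatus AddDetectionPostprocess(
    tflite::MicroMutableOpResolver<1>& resolver) {
  return resolver.AddCustom("TFLite_Detection_PostProcess",
                            tflite::Register_DETECTION_POSTPROCESS());
}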
@@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FLEXBUFFERS_GENERATED_DATA_H
#define TENSORFLOW_LITE_MICRO_KERNELS_FLEXBUFFERS_GENERATED_DATA_H

extern const int g_gen_data_size_none_regular_nms;
extern const unsigned char g_gen_data_none_regular_nms[];

extern const int g_gen_data_size_regular_nms;
extern const unsigned char g_gen_data_regular_nms[];

#endif
code/components/tfmicro/tensorflow/lite/micro/kernels/div.cc (new file, 206 lines)
@@ -0,0 +1,206 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/div.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  // Parameters used in the quantized paths where the output is 8-bit
  int32_t input1_zero_point;
  int32_t input2_zero_point;
  int32_t output_zero_point;
  int32_t output_activation_min;
  int32_t output_activation_max;

  // Parameters used in all quantized paths
  int32_t output_multiplier;
  int output_shift;
};

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             TfLiteDivParams* params, OpData* data) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, output->type);

  if (output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
    const double real_multiplier = static_cast<double>(
        input1->params.scale / (input2->params.scale * output->params.scale));
    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                       &data->output_shift);
    data->input1_zero_point = input1->params.zero_point;
    data->input2_zero_point = input2->params.zero_point;
    data->output_zero_point = output->params.zero_point;
  }

  return kTfLiteOk;
}
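// Editor's note (worked example with assumed scales): for input1 scale 0.02,
// input2 scale 0.01 and output scale 0.1, the real multiplier above is
// 0.02 / (0.01 * 0.1) = 20.0. QuantizeMultiplier() stores this as a
// normalized fixed-point mantissa plus a shift: 20.0 = 0.625 * 2^5, so
// output_shift becomes 5 and output_multiplier becomes
// round(0.625 * 2^31) = 1342177280.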
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = static_cast<TfLiteDivParams*>(node->builtin_data);
  auto* data = static_cast<OpData*>(node->user_data);
  return CalculateOpData(context, node, params, data);
}

void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
             const OpData* data, const TfLiteEvalTensor* input1,
             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
  tflite::ArithmeticParams op_params = {};

#define TF_LITE_DIV(type, opname, data_type)                            \
  data_type output_activation_min, output_activation_max;               \
  CalculateActivationRange(params->activation, &output_activation_min,  \
                           &output_activation_max);                     \
  SetActivationParams(output_activation_min, output_activation_max,     \
                      &op_params);                                      \
  type::opname(op_params, tflite::micro::GetTensorShape(input1),        \
               tflite::micro::GetTensorData<data_type>(input1),         \
               tflite::micro::GetTensorShape(input2),                   \
               tflite::micro::GetTensorData<data_type>(input2),         \
               tflite::micro::GetTensorShape(output),                   \
               tflite::micro::GetTensorData<data_type>(output))

  bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
      tflite::micro::GetTensorShape(input1),
      tflite::micro::GetTensorShape(input2), &op_params);

  if (requires_broadcast) {
    TF_LITE_DIV(reference_ops, BroadcastDivSlow, float);
  } else {
    TF_LITE_DIV(reference_ops, Div, float);
  }
#undef TF_LITE_DIV
}

TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteDivParams* params, const OpData* data,
                           const TfLiteEvalTensor* input1,
                           const TfLiteEvalTensor* input2,
                           TfLiteEvalTensor* output) {
  tflite::ArithmeticParams op_params = {};

#define TF_LITE_DIV(type, opname, dtype)                          \
  type::opname(op_params, tflite::micro::GetTensorShape(input1),  \
               tflite::micro::GetTensorData<dtype>(input1),       \
               tflite::micro::GetTensorShape(input2),             \
               tflite::micro::GetTensorData<dtype>(input2),       \
               tflite::micro::GetTensorShape(output),             \
               tflite::micro::GetTensorData<dtype>(output))

  if (input1->type == kTfLiteInt8 && input2->type == kTfLiteInt8 &&
      output->type == kTfLiteInt8) {
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    op_params.input1_offset = -data->input1_zero_point;
    op_params.input2_offset = -data->input2_zero_point;
    op_params.output_offset = data->output_zero_point;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;

    bool requires_broadcast = reference_ops::ProcessBroadcastShapes(
        tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorShape(input2), &op_params);

    if (requires_broadcast) {
      TF_LITE_DIV(reference_ops, BroadcastDivSlow, int8_t);
    } else {
      TF_LITE_DIV(reference_ops, Div, int8_t);
    }
#undef TF_LITE_DIV
  } else {
    TF_LITE_KERNEL_LOG(
        context, "Unsupported combination of input and output types in DIV.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = static_cast<TfLiteDivParams*>(node->builtin_data);
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* data = static_cast<OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  if (output->type == kTfLiteFloat32) {
    EvalDiv(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, data,
                                             input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "DIV only supports FLOAT32 and quantized INT8 "
                       "for now, got type %s (%d).",
                       TfLiteTypeGetName(output->type), output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration Register_DIV() {
  return {/*init=*/Init,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite
code/components/tfmicro/tensorflow/lite/micro/kernels/elu.cc (new file, 151 lines)
@@ -0,0 +1,151 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/elu.h"

#include <algorithm>
#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

// Input/output tensor index.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

// OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
// of the activation ops below.

struct OpData {
  int8_t table[256];
};

using TransformFunc = float (*)(float);

template <typename T>
void PopulateLookupTable(const TfLiteTensor* input, const TfLiteTensor* output,
                         const TransformFunc transform, OpData* data) {
  if (sizeof(T) != 1) TF_LITE_FATAL("Lookup table valid only for 8bit");

  const float inverse_scale = 1 / output->params.scale;
  int32_t maxval = std::numeric_limits<T>::max();
  int32_t minval = std::numeric_limits<T>::min();
  for (int32_t val = minval; val <= maxval; ++val) {
    const float dequantized =
        input->params.scale * (val - input->params.zero_point);
    const float transformed = transform(dequantized);
    const float rescaled = TfLiteRound(transformed * inverse_scale);
    const int32_t quantized =
        static_cast<int32_t>(rescaled + output->params.zero_point);
    data->table[static_cast<uint8_t>(static_cast<T>(val))] =
        static_cast<T>(std::max(std::min(maxval, quantized), minval));
  }
}
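// Editor's note (worked example with assumed quantization parameters): with
// input scale 0.1, input zero point 0, output scale 0.1 and output zero
// point 0, the int8 input value -10 dequantizes to -1.0, the ELU transform
// gives exp(-1.0) - 1 ~ -0.632, and requantizing yields round(-6.32) = -6;
// the loop above therefore stores table[static_cast<uint8_t>(-10)] == -6.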
// OLD-TODO(b/143696793): move this to optimized_ops.
void EvalUsingLookupTable(const OpData* data, const TfLiteEvalTensor* input,
                          TfLiteEvalTensor* output) {
  const int size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
                                    tflite::micro::GetTensorShape(output));
  int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
  const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);

  for (int i = 0; i < size; ++i) {
    output_data[i] = data->table[static_cast<uint8_t>(input_data[i])];
  }
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

  // Use a LUT to handle the quantized ELU path.
  if (input->type == kTfLiteInt8) {
    OpData* data = static_cast<OpData*>(node->user_data);
    TransformFunc transform = [](float value) {
      return value < 0.0f ? std::exp(value) - 1.0f : value;
    };
    PopulateLookupTable<int8_t>(input, output, transform, data);
  }

  return kTfLiteOk;
}

void* EluInit(TfLiteContext* context, const char* buffer, size_t length) {
  // This is a builtin op, so we don't use the contents in 'buffer', if any.
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
  return CalculateOpData(context, node);
}

TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  switch (input->type) {
    case kTfLiteFloat32: {
      reference_ops::Elu(tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<float>(input),
                         tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<float>(output));
      return kTfLiteOk;
    }
    case kTfLiteInt8: {
      const OpData* data = static_cast<OpData*>(node->user_data);
      EvalUsingLookupTable(data, input, output);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(
          context, "ELU only supports float32 and int8 currently, got %s.",
          TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
}

}  // namespace

TfLiteRegistration Register_ELU() {
  return {/*init=*/EluInit,
          /*free=*/nullptr,
          /*prepare=*/EluPrepare,
          /*invoke=*/EluEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite
@@ -19,14 +19,9 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"

 namespace tflite {
-namespace ops {
-namespace micro {
-namespace custom {

 TfLiteRegistration* Register_ETHOSU() { return nullptr; }

 const char* GetString_ETHOSU() { return ""; }

-}  // namespace custom
-}  // namespace micro
-}  // namespace ops
 }  // namespace tflite
@@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_ETHOSU_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_ETHOSU_H_

#include "tensorflow/lite/c/common.h"

namespace tflite {

TfLiteRegistration* Register_ETHOSU();

const char* GetString_ETHOSU();

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_ETHOSU_H_
code/components/tfmicro/tensorflow/lite/micro/kernels/exp.cc (new file, 78 lines)
@@ -0,0 +1,78 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/exp.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
  TF_LITE_ENSURE_EQ(context, output->bytes, input->bytes);
  TF_LITE_ENSURE_EQ(context, output->dims->size, input->dims->size);
  for (int i = 0; i < output->dims->size; ++i) {
    TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
                                   tflite::micro::GetTensorShape(output));

  if (input->type == kTfLiteFloat32) {
    reference_ops::Exp(tflite::micro::GetTensorData<float>(input),
                       static_cast<size_t>(flat_size),
                       tflite::micro::GetTensorData<float>(output));
  } else {
    TF_LITE_KERNEL_LOG(context, "Type %s (%d) currently not supported by Exp.",
                       TfLiteTypeGetName(input->type), input->type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace

TfLiteRegistration Register_EXP() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite
@@ -0,0 +1,152 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace {

constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;

TfLiteStatus ExpandTensorDim(TfLiteContext* context,
                             const TfLiteEvalTensor* input, int32_t axis,
                             TfLiteEvalTensor* output) {
  const TfLiteIntArray* input_dims = input->dims;
  TfLiteIntArray* output_dims = output->dims;
  if (axis < 0) {
    axis = input_dims->size + 1 + axis;
  }
  TF_LITE_ENSURE(context, (axis <= input_dims->size));

  output_dims->size = input_dims->size + 1;
  for (int i = 0; i < output_dims->size; ++i) {
    if (i < axis) {
      output_dims->data[i] = input_dims->data[i];
    } else if (i == axis) {
      output_dims->data[i] = 1;
    } else {
      output_dims->data[i] = input_dims->data[i - 1];
    }
  }
  return kTfLiteOk;
}
TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
|
||||||
|
const TfLiteEvalTensor* axis,
|
||||||
|
int32_t* axis_value) {
|
||||||
|
const int axis_dims = (tflite::micro::GetTensorShape(axis)).DimensionsCount();
|
||||||
  if (axis_dims > 1) {
    TF_LITE_KERNEL_LOG(context,
                       "Expand_Dims requires a scalar axis, but the axis "
                       "tensor has %d dimensions.",
                       axis_dims);
    return kTfLiteError;
  }

  if (kTfLiteInt32 == (axis->type)) {
    const int32_t* axis_ptr = tflite::micro::GetTensorData<int32_t>(axis);
    *axis_value = axis_ptr[0];
    return kTfLiteOk;
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Axis type %s (%d) not supported by Expand_Dims.",
                       TfLiteTypeGetName(axis->type), axis->type);
    return kTfLiteError;
  }
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = input->type;
  if (IsDynamicTensor(axis)) {
    TF_LITE_KERNEL_LOG(context,
                       "DynamicTensor is not yet supported by Expand_Dims.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

template <typename T>
void memCopyN(T* out, const T* in, const int num_elements) {
  for (int i = 0; i < num_elements; ++i) {
    out[i] = in[i];
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* axis =
      tflite::micro::GetEvalInput(context, node, kAxisTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  const int flat_size = ElementCount(*input->dims);
  const int input_dims = input->dims->size;

  int32_t axis_value;
  TF_LITE_ENSURE_OK(context,
                    GetAxisValueFromTensor(context, axis, &axis_value));
  if ((axis_value > static_cast<int32_t>(input_dims)) ||
      (axis_value < static_cast<int32_t>(-(input_dims + 1)))) {
    TF_LITE_KERNEL_LOG(context, "Invalid Expand_Dims axis value (%d).",
                       axis_value);
    return kTfLiteError;
  }
  ExpandTensorDim(context, input, axis_value, output);

  switch (input->type) {
    case kTfLiteFloat32: {
      memCopyN(tflite::micro::GetTensorData<float>(output),
               tflite::micro::GetTensorData<float>(input), flat_size);
    } break;
    case kTfLiteInt8: {
      memCopyN(tflite::micro::GetTensorData<int8_t>(output),
               tflite::micro::GetTensorData<int8_t>(input), flat_size);
    } break;
    default:
      TF_LITE_KERNEL_LOG(
          context,
          "Expand_Dims only currently supports int8 and float32, got %d.",
          input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration Register_EXPAND_DIMS() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite
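The axis handling in ExpandTensorDim above follows NumPy-style negative indexing: axis -1 inserts the new unit dimension at the end. A small, self-contained sketch of the same shape computation, where ExpandDimsShape is a hypothetical helper for illustration only:

#include <vector>

// Mirrors ExpandTensorDim's shape logic: insert a dimension of size 1 at
// `axis`, counting negative axes from the back (illustration only).
std::vector<int> ExpandDimsShape(std::vector<int> dims, int axis) {
  if (axis < 0) axis = static_cast<int>(dims.size()) + 1 + axis;
  dims.insert(dims.begin() + axis, 1);  // e.g. {2, 3}, axis -1 -> {2, 3, 1}
  return dims;
}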
131 code/components/tfmicro/tensorflow/lite/micro/kernels/fill.cc (Normal file)
@@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/fill.h"

#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {

namespace {

template <typename T>
TfLiteStatus EnsureEqImpl(TfLiteContext* context, const TfLiteIntArray* array,
                          const TfLiteTensor* tensor) {
  for (int i = 0; i < array->size; ++i) {
    TF_LITE_ENSURE_EQ(context, array->data[i], GetTensorData<T>(tensor)[i]);
  }
  return kTfLiteOk;
}

// Ensure the equality of an int array and a tensor, which must be
// one-dimensional and of an integer type.
TfLiteStatus EnsureEq(TfLiteContext* context, const TfLiteIntArray* array,
                      const TfLiteTensor* tensor) {
  TF_LITE_ENSURE_EQ(context, NumDimensions(tensor), 1);
  const auto tensor_len = tensor->dims->data[0];
  TF_LITE_ENSURE_EQ(context, array->size, tensor_len);

  switch (tensor->type) {
    case kTfLiteInt8:
      return EnsureEqImpl<int8_t>(context, array, tensor);
    case kTfLiteUInt8:
      return EnsureEqImpl<uint8_t>(context, array, tensor);
    case kTfLiteInt16:
      return EnsureEqImpl<int16_t>(context, array, tensor);
    case kTfLiteInt32:
      return EnsureEqImpl<int32_t>(context, array, tensor);
    case kTfLiteInt64:
      return EnsureEqImpl<int64_t>(context, array, tensor);
    default:
      TF_LITE_KERNEL_LOG(context,
                         "cannot compare int array to tensor of type %d.",
                         tensor->type);
      return kTfLiteError;
  }
}

constexpr int kDimsTensor = 0;
constexpr int kValueTensor = 1;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Ensure inputs and outputs exist.
  const TfLiteTensor* dims;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // The value tensor must be a scalar.
  TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);

  // The value type and output type must match.
  TF_LITE_ENSURE_EQ(context, value->type, output->type);

  // The dims tensor must match the output tensor shape. As a byproduct,
  // ensures the dims tensor is of an integer type.
  TF_LITE_ENSURE_OK(context, EnsureEq(context, output->dims, dims));

  return kTfLiteOk;
}

template <typename T>
void FillImpl(const TfLiteEvalTensor* value, TfLiteEvalTensor* output) {
  reference_ops::Fill(
      micro::GetTensorShape(value), micro::GetTensorData<T>(value),
      micro::GetTensorShape(output), micro::GetTensorData<T>(output));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* value =
      micro::GetEvalInput(context, node, kValueTensor);
  TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);

  switch (value->type) {
    case kTfLiteFloat32:
      FillImpl<float>(value, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Fill only currently supports float32 for input 1, got %s.",
          TfLiteTypeGetName(value->type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration Register_FILL() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite
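Since Prepare requires the value tensor to be a scalar, the Fill kernel effectively broadcasts one number across the whole output; the dims input only fixes the output shape. A hypothetical sketch of that semantics (not the actual reference_ops::Fill signature):

// Illustration of Fill with a scalar value: every output element receives
// the same value (shape handling and the real API are omitted).
template <typename T>
void FillSketch(T value, T* output, int flat_size) {
  for (int i = 0; i < flat_size; ++i) {
    output[i] = value;
  }
}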
@@ -28,176 +28,37 @@ limitations under the License.
 namespace tflite {
 namespace {
 
-struct OpData {
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-  // The index of the temporary tensor where the quantized inputs are cached.
-  int input_quantized_index;
-  // Cached zero point values of tensors.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-};
-
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFusedActivation activation,
-                             TfLiteType data_type, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
-  TfLiteStatus status = kTfLiteOk;
-  if (data_type != kTfLiteFloat32) {
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, filter, bias, output, &real_multiplier));
-    int exponent;
-    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
-        context, activation, output, &data->output_activation_min,
-        &data->output_activation_max));
-
-    data->input_zero_point = input->params.zero_point;
-    data->filter_zero_point = filter->params.zero_point;
-    data->output_zero_point = output->params.zero_point;
-  }
-  return status;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+  return context->AllocatePersistentBuffer(context,
+                                           sizeof(OpDataFullyConnected));
 }
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
-  OpData* data = static_cast<OpData*>(node->user_data);
+  auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
   const auto params =
       static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
 
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input =
+      GetInput(context, node, kFullyConnectedInputTensor);
   TF_LITE_ENSURE(context, input != nullptr);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* filter =
+      GetInput(context, node, kFullyConnectedWeightsTensor);
   TF_LITE_ENSURE(context, filter != nullptr);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");
 
-  return CalculateOpData(context, params->activation, input->type, input,
-                         filter, bias, output, data);
+  return CalculateOpDataFullyConnected(context, params->activation, input->type,
+                                       input, filter, bias, output, data);
-}
-
-TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               const OpData& data,
-                               const TfLiteEvalTensor* input,
-                               const TfLiteEvalTensor* filter,
-                               const TfLiteEvalTensor* bias,
-                               TfLiteEvalTensor* output) {
-  tflite::FullyConnectedParams op_params;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.weights_offset = -data.filter_zero_point;
-  op_params.output_offset = data.output_zero_point;
-  op_params.output_multiplier = data.output_multiplier;
-  // TODO(b/138810107): Figure out whether output shift should be inverted
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  reference_integer_ops::FullyConnected(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                           const OpData& data, const TfLiteEvalTensor* input,
-                           const TfLiteEvalTensor* filter,
-                           const TfLiteEvalTensor* bias,
-                           TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  tflite::FullyConnectedParams op_params;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-#define TF_LITE_FULLY_CONNECTED(output_data_type)          \
-  reference_ops::FullyConnected(                           \
-      op_params, tflite::micro::GetTensorShape(input),     \
-      tflite::micro::GetTensorData<uint8_t>(input),        \
-      tflite::micro::GetTensorShape(filter),               \
-      tflite::micro::GetTensorData<uint8_t>(filter),       \
-      tflite::micro::GetTensorShape(bias),                 \
-      tflite::micro::GetTensorData<int32_t>(bias),         \
-      tflite::micro::GetTensorShape(output),               \
-      tflite::micro::GetTensorData<output_data_type>(output))
-  switch (output->type) {
-    case kTfLiteUInt8:
-      TF_LITE_FULLY_CONNECTED(uint8_t);
-      break;
-    case kTfLiteInt16:
-      TF_LITE_FULLY_CONNECTED(int16_t);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(output->type), output->type);
-      return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       TfLiteFusedActivation activation,
-                       const TfLiteEvalTensor* input,
-                       const TfLiteEvalTensor* filter,
-                       const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(activation, &output_activation_min,
-                           &output_activation_max);
-  tflite::FullyConnectedParams op_params;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-  tflite::reference_ops::FullyConnected(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<float>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<float>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<float>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<float>(output));
-  return kTfLiteOk;
-}
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -206,33 +67,66 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
   const TfLiteEvalTensor* bias =
-      tflite::micro::GetEvalInput(context, node, kBiasTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
   TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+  const auto& data =
+      *(static_cast<const OpDataFullyConnected*>(node->user_data));
 
   // Checks in Prepare ensure input, output and filter types are all the same.
   switch (input->type) {
-    case kTfLiteFloat32:
-      return EvalFloat(context, node, params->activation, input, filter, bias,
-                       output);
-    case kTfLiteInt8:
-      return EvalQuantizedInt8(context, node, data, input, filter, bias,
-                               output);
-    case kTfLiteUInt8:
-      return EvalQuantized(context, node, data, input, filter, bias, output);
-    default:
+    case kTfLiteFloat32: {
+      tflite::reference_ops::FullyConnected(
+          FullyConnectedParamsFloat(params->activation),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+      break;
+    }
+
+    case kTfLiteInt8: {
+      tflite::reference_integer_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+      break;
+    }
+
+    case kTfLiteUInt8: {
+      tflite::reference_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<uint8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<uint8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<uint8_t>(output));
+      break;
+    }
+
+    default: {
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
+    }
   }
   return kTfLiteOk;
 }
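The refactor above replaces the per-type Eval helpers with shared FullyConnectedParamsFloat/FullyConnectedParamsQuantized builders. For orientation, the integer arithmetic that the cached zero points feed is roughly the following; this is a simplified sketch with illustrative names, and the fixed-point rescale by output_multiplier/output_shift plus clamping to the activation range are omitted.

#include <cstdint>

// Simplified per-output accumulation of a quantized fully-connected layer:
// the offsets undo the int8 zero points before the multiply-accumulate. The
// accumulator would then be rescaled, shifted by the output zero point, and
// clamped (omitted here).
int32_t FullyConnectedAccSketch(const int8_t* input, const int8_t* filter,
                                int depth, int32_t input_offset,
                                int32_t filter_offset, int32_t bias) {
  int32_t acc = bias;
  for (int d = 0; d < depth; ++d) {
    acc += (input[d] + input_offset) * (filter[d] + filter_offset);
  }
  return acc;
}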
@@ -15,10 +15,51 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
 #define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
 
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
 
+struct OpDataFullyConnected {
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int input_quantized_index;
+  // Cached zero point values of tensors.
+  int32_t input_zero_point;
+  int32_t filter_zero_point;
+  int32_t output_zero_point;
+};
+
+extern const int kFullyConnectedInputTensor;
+extern const int kFullyConnectedWeightsTensor;
+extern const int kFullyConnectedBiasTensor;
+extern const int kFullyConnectedOutputTensor;
+
+// Returns a FullyConnectedParams struct with all the parameters needed for a
+// float computation.
+FullyConnectedParams FullyConnectedParamsFloat(
+    TfLiteFusedActivation activation);
+
+// Returns a FullyConnectedParams struct with all the parameters needed for a
+// quantized computation.
+FullyConnectedParams FullyConnectedParamsQuantized(
+    const OpDataFullyConnected& op_data);
+
+TfLiteStatus CalculateOpDataFullyConnected(
+    TfLiteContext* context, TfLiteFusedActivation activation,
+    TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
+    const TfLiteTensor* bias, TfLiteTensor* output, OpDataFullyConnected* data);
+
 // This is the most generic TfLiteRegistration. The actual supported types may
 // still be target dependent. The only requirement is that every implementation
 // (reference or optimized) must define this function.
@@ -30,7 +71,7 @@ TfLiteRegistration Register_FULLY_CONNECTED();
 // part of the build. As a result, we use defined(ARDUINO) as proxy for the
 // CMSIS kernels for this one special case.
 
-// Returns a TfLiteRegistration struct for cmsis-nn kernel variant that only
+// Returns a TfLiteRegistration struct for cmsis_nn kernel variant that only
 // supports int8.
 TfLiteRegistration Register_FULLY_CONNECTED_INT8();
 
@@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {

const int kFullyConnectedInputTensor = 0;
const int kFullyConnectedWeightsTensor = 1;
const int kFullyConnectedBiasTensor = 2;
const int kFullyConnectedOutputTensor = 0;

FullyConnectedParams FullyConnectedParamsQuantized(
    const OpDataFullyConnected& op_data) {
  FullyConnectedParams op_params;
  op_params.input_offset = -op_data.input_zero_point;
  op_params.weights_offset = -op_data.filter_zero_point;
  op_params.output_offset = op_data.output_zero_point;
  op_params.output_multiplier = op_data.output_multiplier;
  op_params.output_shift = op_data.output_shift;
  op_params.quantized_activation_min = op_data.output_activation_min;
  op_params.quantized_activation_max = op_data.output_activation_max;
  return op_params;
}

FullyConnectedParams FullyConnectedParamsFloat(
    TfLiteFusedActivation activation) {
  FullyConnectedParams op_params;
  CalculateActivationRange(activation, &op_params.float_activation_min,
                           &op_params.float_activation_max);
  return op_params;
}

TfLiteStatus CalculateOpDataFullyConnected(
    TfLiteContext* context, TfLiteFusedActivation activation,
    TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
    const TfLiteTensor* bias, TfLiteTensor* output,
    OpDataFullyConnected* data) {
  if (data_type != kTfLiteFloat32) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                       &data->output_shift);

    data->input_zero_point = input->params.zero_point;
    data->filter_zero_point = filter->params.zero_point;
    data->output_zero_point = output->params.zero_point;

    return CalculateActivationRangeQuantized(context, activation, output,
                                             &data->output_activation_min,
                                             &data->output_activation_max);
  }
  return kTfLiteOk;
}

}  // namespace tflite
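Note that CalculateOpDataFullyConnected now feeds QuantizeMultiplier's shift straight into output_shift instead of negating an exponent, which is why the params builders above no longer flip its sign. Under the usual TFLite Q31 convention, the (multiplier, shift) pair approximates the real rescale factor as in this illustrative helper, which is not part of the diff:

#include <cmath>
#include <cstdint>

// Illustration of the Q31 convention: real_multiplier is approximately
// quantized_multiplier * 2^(shift - 31).
double ApproxRealMultiplier(int32_t quantized_multiplier, int shift) {
  return static_cast<double>(quantized_multiplier) * std::pow(2.0, shift - 31);
}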
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
 
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
 namespace tflite {
 namespace micro {
 
@@ -30,12 +32,12 @@ uint8_t KernelRunner::kKernelRunnerBuffer_[];
 KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                            TfLiteTensor* tensors, int tensors_size,
                            TfLiteIntArray* inputs, TfLiteIntArray* outputs,
-                           void* builtin_data, ErrorReporter* error_reporter)
-    : allocator_(SimpleMemoryAllocator::Create(
-          error_reporter, kKernelRunnerBuffer_, kKernelRunnerBufferSize_)),
+                           void* builtin_data)
+    : allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
+                                               kKernelRunnerBuffer_,
+                                               kKernelRunnerBufferSize_)),
       registration_(registration),
-      tensors_(tensors),
-      error_reporter_(error_reporter) {
+      tensors_(tensors) {
   // Prepare TfLiteContext:
   context_.impl_ = static_cast<void*>(this);
   context_.ReportError = ReportOpError;
@@ -52,9 +54,10 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
   node_.builtin_data = builtin_data;
 }
 
-TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data) {
+TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
+                                          size_t length) {
   if (registration_.init) {
-    node_.user_data = registration_.init(&context_, init_data, /*length=*/0);
+    node_.user_data = registration_.init(&context_, init_data, length);
   }
   if (registration_.prepare) {
     TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
@@ -64,8 +67,7 @@ TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data) {
 
 TfLiteStatus KernelRunner::Invoke() {
   if (registration_.invoke == nullptr) {
-    TF_LITE_REPORT_ERROR(error_reporter_,
-                         "TfLiteRegistration missing invoke function pointer!");
+    MicroPrintf("TfLiteRegistration missing invoke function pointer!");
     return kTfLiteError;
   }
   return registration_.invoke(&context_, &node_);
@@ -118,10 +120,8 @@ TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
   TFLITE_DCHECK(runner != nullptr);
 
   if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
-    TF_LITE_REPORT_ERROR(
-        runner->error_reporter_,
-        "Exceeded the maximum number of scratch tensors allowed (%d).",
-        kNumScratchBuffers_);
+    MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
+                kNumScratchBuffers_);
     return kTfLiteError;
   }
 
@@ -151,13 +151,9 @@ void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
 
 void KernelRunner::ReportOpError(struct TfLiteContext* context,
                                  const char* format, ...) {
-  TFLITE_DCHECK(context != nullptr);
-  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
-  TFLITE_DCHECK(runner != nullptr);
-
   va_list args;
   va_start(args, format);
-  TF_LITE_REPORT_ERROR(runner->error_reporter_, format, args);
+  GetMicroErrorReporter()->Report(format, args);
   va_end(args);
 }
 
@@ -23,23 +23,22 @@ limitations under the License.
 namespace tflite {
 namespace micro {
 
-// Helper class to perform a simulated kernel (i.e. TfLiteRegistration) lifecyle
-// (init, prepare, invoke). All internal allocations are handled by this class.
-// Simply pass in the registration, list of required tensors, inputs array,
-// outputs array, and any pre-builtin data. Calling Invoke() will automatically
-// walk the kernl and outputs will be ready on the the TfLiteTensor output
-// provided during construction.
+// Helper class to perform a simulated kernel (i.e. TfLiteRegistration)
+// lifecycle (init, prepare, invoke). All internal allocations are handled by
+// this class. Simply pass in the registration, list of required tensors, inputs
+// array, outputs array, and any pre-builtin data. Calling Invoke() will
+// automatically walk the kernel and outputs will be ready on the TfLiteTensor
+// output provided during construction.
 class KernelRunner {
  public:
   KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
                int tensors_size, TfLiteIntArray* inputs,
-               TfLiteIntArray* outputs, void* builtin_data,
-               ErrorReporter* error_reporter);
+               TfLiteIntArray* outputs, void* builtin_data);
 
   // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
-  // exceptions will be reported through the error_reporter and returned as a
-  // status code here.
-  TfLiteStatus InitAndPrepare(const char* init_data = nullptr);
+  // exceptions will be DebugLog'd and returned as a status code.
+  TfLiteStatus InitAndPrepare(const char* init_data = nullptr,
+                              size_t length = 0);
 
   // Calls init, prepare, and invoke on a given TfLiteRegistration pointer.
   // After successful invoke, results will be available in the output tensor as
@@ -60,7 +59,7 @@ class KernelRunner {
                ...);
 
  private:
-  static constexpr int kNumScratchBuffers_ = 5;
+  static constexpr int kNumScratchBuffers_ = 12;
 
   static constexpr int kKernelRunnerBufferSize_ = 10000;
   static uint8_t kKernelRunnerBuffer_[kKernelRunnerBufferSize_];
@@ -68,7 +67,6 @@ class KernelRunner {
   SimpleMemoryAllocator* allocator_ = nullptr;
   const TfLiteRegistration& registration_;
   TfLiteTensor* tensors_ = nullptr;
-  ErrorReporter* error_reporter_ = nullptr;
 
   TfLiteContext context_ = {};
   TfLiteNode node_ = {};
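With the ErrorReporter parameter gone and InitAndPrepare taking an optional length, driving a kernel through its lifecycle looks roughly like the following sketch. RunKernelOnce is a hypothetical wrapper, not part of the diff; the caller prepares `tensors`, `inputs`, `outputs`, and `builtin_data` exactly as the TFLM kernel tests do, and that setup is elided here.

#include "tensorflow/lite/micro/kernels/kernel_runner.h"

// Hypothetical helper showing the post-change KernelRunner call sequence:
// construct without an ErrorReporter, then init/prepare, then invoke.
TfLiteStatus RunKernelOnce(const TfLiteRegistration& registration,
                           TfLiteTensor* tensors, int tensors_size,
                           TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                           void* builtin_data) {
  tflite::micro::KernelRunner runner(registration, tensors, tensors_size,
                                     inputs, outputs, builtin_data);
  TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
  return runner.Invoke();
}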
@@ -37,5 +37,17 @@ const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor) {
   return RuntimeShape(dims_size, dims_data);
 }
 
+PaddingType RuntimePaddingType(TfLitePadding padding) {
+  switch (padding) {
+    case TfLitePadding::kTfLitePaddingSame:
+      return PaddingType::kSame;
+    case TfLitePadding::kTfLitePaddingValid:
+      return PaddingType::kValid;
+    case TfLitePadding::kTfLitePaddingUnknown:
+    default:
+      return PaddingType::kNone;
+  }
+}
+
 }  // namespace micro
 }  // namespace tflite
Some files were not shown because too many files have changed in this diff.