mirror of
https://github.com/jomjol/AI-on-the-edge-device.git
Rolling 20220526
@@ -60,9 +60,8 @@ void VectorBatchVectorAdd(const T* vector, int v_size, int n_batch,

// Cwise product of two vectors.
template <typename T>
inline void VectorVectorCwiseProduct(const T* __restrict__ vector1,
                                     const T* __restrict__ vector2, int v_size,
                                     T* __restrict__ result) {
inline void VectorVectorCwiseProduct(const T* vector1, const T* vector2,
                                     int v_size, T* result) {
  for (int v = 0; v < v_size; v++) {
    *result++ = *vector1++ * *vector2++;
  }
@@ -117,6 +116,367 @@ void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch,
  }
}

// Checks if all entries of vector are zero for float.
bool IsZeroVector(const float* vector, int v_size);

// Checks if all entries of vector are zero for int8.
bool IsZeroVector(const int8_t* vector, int v_size);

// Quantizes a buffer of floating point values using a symmetric quantization
// (i.e. linear quantization without an offset) to 8-bit signed integers.
// It also outputs the range (min, max) of the floating point buffer, and the
// scaling factor used to quantize the values.
void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float* min_value,
                             float* max_value, float* scaling_factor);

// Quantizes a buffer of floating point values using a symmetric quantization
// (i.e. linear quantization without an offset) to 8-bit signed integers.
// It uses the range (min, max) provided to the function to calculate the
// appropriate scaling factor to quantize the values.
void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float min_value,
                             float max_value, float* scaling_factor);

void AsymmetricQuantizeFloats(const float* values, const int size,
                              int8_t* quantized_values, float* scaling_factor,
                              int32_t* offset);

// Helper function to quantize floats.
// float_data_ptr     input float vectors
// n_batch            number of input vectors
// n_data             size of a single input vector
// quantized_data_ptr (out) vector with quantized data
// scaling_factors    (out) scaling factors (one per vector)
// zero_points        (out) zero points (one per vector)
// do_asymmetric      controls if the quantization should be asymmetric.
inline void BatchQuantizeFloats(const float* float_data_ptr, int n_batch,
                                int n_data, int8_t* quantized_data_ptr,
                                float* scaling_factors, int32_t* zero_points,
                                bool do_asymmetric) {
  for (int b = 0; b < n_batch; ++b) {
    const int offset = b * n_data;
    if (do_asymmetric) {
      tensor_utils::AsymmetricQuantizeFloats(
          float_data_ptr + offset, n_data, quantized_data_ptr + offset,
          &scaling_factors[b], &zero_points[b]);
    } else {
      float unused_min, unused_max;
      tensor_utils::SymmetricQuantizeFloats(
          float_data_ptr + offset, n_data, quantized_data_ptr + offset,
          &unused_min, &unused_max, &scaling_factors[b]);
    }
  }
}

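For orientation, the symmetric scheme described above amounts to mapping the largest absolute value of the buffer onto plus/minus 127 with a single scale and no zero point. A minimal illustrative sketch follows; it is not the library's implementation and the names are hypothetical.

// Sketch of symmetric 8-bit quantization: scale = max|x| / 127, q = round(x / scale).
#include <algorithm>
#include <cmath>
#include <cstdint>

void SymmetricQuantizeSketch(const float* values, int size, int8_t* quantized,
                             float* scaling_factor) {
  float max_abs = 0.f;
  for (int i = 0; i < size; ++i) max_abs = std::max(max_abs, std::fabs(values[i]));
  *scaling_factor = max_abs / 127.f;  // one scale for the whole buffer
  const float inv_scale = max_abs > 0.f ? 127.f / max_abs : 0.f;
  for (int i = 0; i < size; ++i) {
    const int q = static_cast<int>(std::round(values[i] * inv_scale));
    quantized[i] = static_cast<int8_t>(std::min(127, std::max(-127, q)));
  }
}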
// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
// dimension composed by input vectors independent from each other). The result
// of the multiplication is accumulated to the passed result buffer.
// More specifically, for a matrix M of shape [n, i] and a batched-vector
// of shape [i, batch] it will first compute the product of shape [n, batch].
// This product will be accumulated to the result buffer.
void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                         int m_cols, const float* vector,
                                         int n_batch, float* result);

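Read as a plain loop, the accumulation described above is the following. This is an illustrative reference sketch, not the optimized kernels; a row-major matrix layout is assumed.

// result[b * m_rows + r] += dot(row r of matrix, vector of batch b)
void MatrixBatchVectorMultiplyAccumulateSketch(const float* matrix, int m_rows,
                                               int m_cols, const float* vector,
                                               int n_batch, float* result) {
  for (int b = 0; b < n_batch; ++b) {
    for (int r = 0; r < m_rows; ++r) {
      float acc = 0.f;
      for (int c = 0; c < m_cols; ++c) {
        acc += matrix[r * m_cols + c] * vector[b * m_cols + c];
      }
      result[b * m_rows + r] += acc;  // accumulate into the existing result
    }
  }
}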
// Same as the function above, but the matrix is a sparse tensor with block
// pattern 1x4.
// This function assumes that m_cols is a multiple of the block size (4 in this
// case) so that there's no incomplete block.
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result);

// Same as the function above, but the matrix is stored in block compressed
// sparse row format with block pattern 1x16 which consists of two arrays:
// 1. A matrix array stores non-zero blocks of the matrix in row major.
// 2. A ledger array stores nrows groups, one group per row. Each group starts
//    with an integer representing the number of non-zero blocks for the
//    corresponding row and follows with column indexes of the first element
//    of each non-zero block.
// This function assumes that
// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
void SparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result);

// Same as the function above, but for values quantized using symmetric
// quantization (e.g. by calling SymmetricQuantizeFloats).
// The passed scaling factors is a buffer of the quantization scaling factors
// that will be used to dequantize the products into the final result buffer.
// These scaling factors are the multiplication of the matrix scaling factor
// by the vector's scaling factor, one per batch (i.e. this allows quantizing
// each batch in the batch-vector matrix independently).
void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result);

// Same as the function above except that vector values
// are quantized with asymmetric quantization per-batch and the matrix
// is quantized per row.
void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* __restrict__ per_channel_scale,
    const int32_t* __restrict__ input_offset);

// Same as the function above, but the matrix is a sparse tensor with block
// pattern 1x16.
// This function assumes that m_cols is a multiple of the block size (16 in this
// case) so that there's no incomplete block. Also, it assumes all offsets of
// input, output and filter are zero.
void SparseMatrixBatchVectorMultiplyAccumulate1x16(
    const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
    int n_batch, const int32_t input_offset, const int32_t output_multiplier,
    const int32_t output_shift, const int32_t output_offset,
    const int32_t output_activation_min, const int32_t output_activation_max,
    int8_t* __restrict__ result);

// Same as the function above, but the matrix is stored in block compressed
// sparse row format with block pattern 1x16 which consists of two arrays:
// 1. A matrix array stores non-zero blocks of the matrix in row major.
// 2. A ledger array stores nrows groups, one group per row. Each group starts
//    with an integer representing the number of non-zero blocks for the
//    corresponding row followed by column index of the first element of
//    each non-zero block.
// This function assumes that
// 1. m_cols is a multiple of 16 so that all blocks are full blocks.
// 2. m_cols < 254 * 16 so that block index can be represented by uint8.
void SparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result);

// Same as the above 8, 8, 8 integer matmul except for the presence of zero
// point and non-accumulative.
// TODO(b/148688698): remove this function by folding zero point calculation in
// prepare() function.
void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
                               const int8_t* input_to_gate_weights,
                               int32_t input_to_gate_effective_scale_a,
                               int32_t input_to_gate_effective_scale_b,
                               int32_t n_batch, int32_t n_input, int32_t n_cell,
                               int8_t* gate_output, int8_t gate_output_zp);

// Same as above but has 16 bit and 8 bit input and 8 bit output.
// Used in projection when hidden is 16bit.
void MatrixBatchVectorMultiply(const int16_t* hidden,
                               const int8_t* hidden_to_output_weights,
                               int32_t proj_effective_scale_a,
                               int32_t proj_effective_scale_b,
                               const int32_t* gate_bias, int32_t n_batch,
                               int32_t n_hidden, int32_t n_output,
                               int32_t output_zp, int8_t* proj_output);

// Apply Layer Normalization (https://arxiv.org/abs/1607.06450) to a quantized
// vector.
// Parameters:
// - input: batch vector of size n_batch * n_input; 16 bit.
// - layer_norm_weights: the quantized layer normalization weights.
// - bias: the bias for the layer normalization.
// - layer_norm_scale_a: multiplier for scale factor.
// - layer_norm_scale_b: shift for scale factor.
// - variance_limit: the guard to make sure the inverse does not overflow.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output
void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
                    const int32_t* bias, int32_t layer_norm_scale_a,
                    int32_t layer_norm_scale_b, int32_t variance_limit,
                    int n_batch, int n_input, int16_t* output);

// Same as above but the internal calculation is done in float.
void ApplyLayerNormFloat(const int16_t* input,
                         const int16_t* layer_norm_weights,
                         int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
                         const int32_t* bias, int n_batch, int n_input,
                         int16_t* output);

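The normalization both ApplyLayerNorm variants implement is, per batch row, the usual layer-norm formula. A minimal float sketch for orientation only; it ignores the quantized rescaling via layer_norm_scale_a/b and the variance_limit guard, and the epsilon value is an assumption.

#include <cmath>

// Hypothetical float sketch: y = (x - mean) / stddev * gamma + beta, per row.
void LayerNormRowSketch(const float* x, const float* gamma, const float* beta,
                        int n_input, float* y) {
  float mean = 0.f, var = 0.f;
  for (int i = 0; i < n_input; ++i) mean += x[i];
  mean /= n_input;
  for (int i = 0; i < n_input; ++i) var += (x[i] - mean) * (x[i] - mean);
  var /= n_input;
  const float inv_stddev = 1.f / std::sqrt(var + 1e-8f);  // small guard term
  for (int i = 0; i < n_input; ++i) {
    y[i] = (x[i] - mean) * inv_stddev * gamma[i] + beta[i];
  }
}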
// Apply Sigmoid to a quantized vector.
// Parameters:
// - input: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output
// The input is in Q3.12 format and the output is in Q0.15 format.
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
                  int16_t* output);

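As a reminder of what those Q formats mean: Q3.12 stores a real value x as x * 2^12 in an int16, covering roughly [-8, 8), while Q0.15 uses 2^15 and covers [-1, 1). The helpers below are an illustrative sketch, not library code.

#include <cstdint>

// Hypothetical helpers showing the Q3.12 / Q0.15 encodings used above.
int16_t ToQ3_12(float x) { return static_cast<int16_t>(x * (1 << 12)); }
float FromQ0_15(int16_t q) { return static_cast<float>(q) / (1 << 15); }
// Example: ToQ3_12(1.5f) == 6144; a Q0.15 output of 16384 decodes to 0.5f.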
// Same as above but the internal calculation is float.
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                       int16_t* output);

// Apply Tanh to a quantized vector.
// Parameters:
// - integer_bits: the integer bits of the input.
//                 Currently supports 0, 1, 2, 3, 4, 5, 6.
// - input: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output
// The input is in Qm.15-m format and the output is in Q0.15 format.
void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
               int32_t n_input, int16_t* output);

// Apply Tanh to a quantized vector. The internal calculation is in float.
// - Input has 2^(integer_bits) as scale.
// - Output has Q0.15 as scale.
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                    int32_t integer_bits, int16_t* output);

// Element-wise multiplication of two quantized vectors.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - shift: the shift needed to produce the output.
// - output: the 16 bit output of size n_batch * n_input.
// Output does not need to be initialized.
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int shift, int16_t* output);

// Element-wise multiplication of two quantized vectors.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - shift: the shift needed to produce the output.
// - output: the 8 bit output of size n_batch * n_input.
// Output does not need to be initialized.
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int shift, int8_t* output);

// Element-wise multiplication of two quantized vectors with rescaling.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - multiplier: the multiplier part of scale.
// - shift: the shift part of scale.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 8 bit output of size n_batch * n_input.
// - output_zp: the zero point of output.
// Output does not need to be initialized.
// Multiplier ("m") and shift ("s") are connected to the scale by
// scale = m * 2^(s - 31).
void CwiseMul(const int16_t* input_1, const int16_t* input_2,
              int32_t multiplier, int32_t shift, int32_t n_batch,
              int32_t n_input, int32_t output_zp, int8_t* output);

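A concrete reading of that relation, with hypothetical numbers: scale = multiplier * 2^(shift - 31), so multiplier = 2^30 with shift = 1 encodes a scale of exactly 1.0, and the same multiplier with shift = -3 encodes 1/16. The tiny check below is illustrative only.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative check of scale = multiplier * 2^(shift - 31).
int main() {
  const int32_t multiplier = 1 << 30;  // example value
  for (int shift : {1, 0, -3}) {
    double scale = static_cast<double>(multiplier) * std::pow(2.0, shift - 31);
    std::printf("shift=%d -> scale=%g\n", shift, scale);  // 1, 0.5, 0.0625
  }
}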
// Element-wise saturating addition of two quantized vectors without rescaling.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
// - input_2: batch vector of size n_batch * n_input; 16 bit.
// - n_batch: the number of batches.
// - n_input: the size for input and output.
// - output: the 16 bit output of size n_batch * n_input.
// Output does not need to be initialized.
void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int16_t* output);

// Element-wise in-place clipping of a vector. Overloaded for float, int16_t,
// int8_t. Parameters:
// - vector: vector of size v_size.
// - v_size: the size of the vector.
// - clipping_value: the value used for clipping.
void CwiseClipping(float* vector, const int v_size, const float clipping_value);
void CwiseClipping(int16_t* vector, const int v_size,
                   const int16_t clipping_value);
void CwiseClipping(int8_t* vector, const int v_size,
                   const int8_t clipping_value);

// Dot product of two vectors.
float VectorVectorDotProduct(const float* vector1, const float* vector2,
                             int v_size);

// Dot product of two batch vectors of size n_batch * v_size:
// vector1 = [x_1_1, x_1_2, ..., x_1_vsize,
//            x_2_1, x_2_2, ..., x_2_vsize,
//            ...
//            x_nbatch_1, ..., x_nbatch_vsize]
// vector2 = [y_1_1, y_1_2, ..., y_1_vsize,
//            y_2_1, y_2_2, ..., y_2_vsize,
//            ...
//            y_nbatch_1, ..., y_nbatch_vsize]
// Then result will be a vector of n_batch size starting from 'result':
// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize,
//  x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
//  ...
//  x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
template <typename T>
inline void BatchVectorBatchVectorDotProduct(const T* vector1, const T* vector2,
                                             int v_size, int n_batch,
                                             T* result) {
  for (int b = 0; b < n_batch; b++) {
    result[b] = VectorVectorDotProduct(vector1, vector2, v_size);
    vector1 += v_size;
    vector2 += v_size;
  }
}

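A small usage illustration of the batched dot product above, with hypothetical values:

// Two batches of 3-element vectors; result[b] = dot(v1[b], v2[b]).
const float v1[6] = {1, 2, 3, 4, 5, 6};
const float v2[6] = {1, 1, 1, 2, 2, 2};
float result[2];
tflite::tensor_utils::BatchVectorBatchVectorDotProduct(v1, v2, /*v_size=*/3,
                                                       /*n_batch=*/2, result);
// result == {6.f, 30.f}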
// Same as above but input is 16bit and output is 32bit.
void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
                                      const int16_t* vector2, int v_size,
                                      int n_batch, int32_t* result);

// Same as above, but inputs are 16bit integer and output is 16bit integer.
void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
                                             const int16_t* batch_vector,
                                             int n_batch, int32_t multiplier,
                                             int shift, int16_t* result);

// Compute "1.0f - elements of vector" (used in CIFG).
void Sub1Vector(const float* vector, int v_size, float* result);

// Compute "1.0f - elements of vector" (used in CIFG) for int16 input.
// "vector" has range [0, 32767] because it is the output of sigmoid function.
void Sub1Vector(const int16_t* vector, int v_size, int16_t* result);

// Multiply all elements of vector with a scalar.
void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
                          float* result);

// Reduce-sum on a float input vector:
// input_vector: float pointer to input vector.
// output_vector: float pointer to output vector.
// output_size: output vector size.
// reduction_size: number of consecutive elements from input vector which are
// added to get one element of output.
void ReductionSumVector(const float* input_vector, float* output_vector,
                        int output_size, int reduction_size);

// Same as above but input/output is 32 bit integer.
void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size);

// Same as above but input is 8 bit integer.
void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size);

// Layer norm for each batch.
void MeanStddevNormalization(const float* input_vector, float* output_vector,
                             int v_size, int n_batch);

// Saturate Add with rescale on both inputs.
void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
                          const int8_t* recurrent, int8_t recurrent_zp,
                          int32_t input_effective_scale_a,
                          int32_t input_effective_scale_b,
                          int32_t recurrent_effective_scale_a,
                          int32_t recurrent_effective_scale_b, int32_t n_batch,
                          int32_t n_cell, int16_t* output);

}  // namespace tensor_utils

}  // namespace tflite

@@ -20,7 +20,7 @@ limitations under the License.

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_utils_common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

@@ -0,0 +1,56 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_ARGS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_ARGS_H_

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {

template <typename T>
void BroadcastArgs(const RuntimeShape& input1_shape, const T* input1_data,
                   const RuntimeShape& input2_shape, const T* input2_data,
                   const RuntimeShape& output_shape, T* output_data) {
  // Gets data at the backward index i of the shape tensor. Returns 1 if the
  // index is out of range.
  auto get_shape_data = [](const RuntimeShape& shape, const T* data,
                           int backward_idx) -> T {
    int forward_idx = shape.FlatSize() - 1 - backward_idx;
    if (forward_idx < 0) return 1;
    return data[forward_idx];
  };

  int output_num_elements = output_shape.FlatSize();
  for (int i = 0; i < output_num_elements; ++i) {
    int backward_i = output_num_elements - 1 - i;
    int shape1_i = get_shape_data(input1_shape, input1_data, i);
    int shape2_i = get_shape_data(input2_shape, input2_data, i);
    if (shape1_i == 1) {
      output_data[backward_i] = shape2_i;
    } else if (shape2_i == 1) {
      output_data[backward_i] = shape1_i;
    } else {
      TFLITE_CHECK_EQ(shape1_i, shape2_i);
      output_data[backward_i] = shape1_i;
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_ARGS_H_
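To make the broadcast rule above concrete, here is a standalone restatement of the same right-to-left comparison, without the TFLite types. It is a hedged sketch with hypothetical names, not part of the library.

#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical re-statement of the BroadcastArgs rule: walk dimensions from
// the right; a 1 takes the other side's extent, any other mismatch is an error.
std::vector<int> BroadcastShapes(const std::vector<int>& a,
                                 const std::vector<int>& b) {
  const int n = static_cast<int>(std::max(a.size(), b.size()));
  std::vector<int> out(n);
  for (int i = 0; i < n; ++i) {
    int da = i < static_cast<int>(a.size()) ? a[a.size() - 1 - i] : 1;
    int db = i < static_cast<int>(b.size()) ? b[b.size() - 1 - i] : 1;
    assert(da == db || da == 1 || db == 1);
    out[n - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}
// Example: BroadcastShapes({1, 3}, {2, 1}) == {2, 3}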
@@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace reference_ops {
template <int N>
void BroadcastImpl(const NdArrayDesc<N>& input_desc, const char* input_data,
                   const NdArrayDesc<N>& output_desc, char* output_data,
                   int indexes[N], int dim, const int last_broadcasting_dim,
                   const int type_size) {
  // Copy data from input to output.
  if (dim == last_broadcasting_dim) {
    int copy_size = output_desc.strides[dim] * type_size;
    const char* data_src =
        input_data + SubscriptToIndex(input_desc, indexes) * type_size;
    char* data_dst =
        output_data + SubscriptToIndex(output_desc, indexes) * type_size;
    for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
      memcpy(data_dst, data_src, copy_size);
    }
    return;
  }

  // Recursive call to find the next broadcasting dimension.
  for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim];
       ++indexes[dim]) {
    BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes,
                     dim + 1, last_broadcasting_dim, type_size);
  }

  // Duplicate data in output tensor.
  indexes[dim] = 0;
  if (input_desc.extents[dim] != output_desc.extents[dim]) {
    int copy_size = output_desc.strides[dim] * type_size;
    char* data_src =
        output_data + SubscriptToIndex(output_desc, indexes) * type_size;
    char* data_dst = data_src + copy_size;
    for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
      memcpy(data_dst, data_src, copy_size);
    }
  }
}

template <int N>
inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
                        const char* input_data,
                        const RuntimeShape& unextended_output_shape,
                        char* output_data, TfLiteType data_type) {
  NdArrayDesc<N> input_desc;
  NdArrayDesc<N> output_desc;
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
                 &input_desc);
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
                 &output_desc);

  // Get the last dimension that has broadcasting. At this dimension, the data
  // is copied from input tensor to output tensor.
  int last_broadcast_dim = -1;
  for (int i = N - 1; i >= 0; --i) {
    if (input_desc.extents[i] != output_desc.extents[i]) {
      last_broadcast_dim = i;
      break;
    }
  }

  // If non-broadcasting, just copy data from input to output tensor.
  if (last_broadcast_dim == -1) {
    memcpy(output_data, input_data,
           unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
    return;
  }

  // Broadcasting using memcpy.
  int indexes[N] = {0};
  BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
                   last_broadcast_dim, TfLiteTypeGetSize(data_type));
}
}  // namespace reference_ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
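The strategy used above (copy the data once at the innermost broadcast dimension, then duplicate whole runs with memcpy at the outer ones) can be seen in a tiny standalone form. The sketch below is illustrative and does not use the TFLite descriptors.

#include <cstring>

// Broadcast one row of `cols` floats into `rows` identical rows: [1, cols] -> [rows, cols].
void BroadcastRowSketch(const float* in, int rows, int cols, float* out) {
  std::memcpy(out, in, cols * sizeof(float));  // copy the source run once
  for (int r = 1; r < rows; ++r) {             // then duplicate the run
    std::memcpy(out + r * cols, out, cols * sizeof(float));
  }
}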
@@ -43,7 +43,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
  (void)im2col_data;   // only used in optimized code.
  (void)im2col_shape;  // only used in optimized code.
  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
  const int input_depth = input_shape.Dims(3);
  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
  if (bias_data) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -52,14 +52,20 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
  const int input_width = input_shape.Dims(2);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int filter_input_depth = filter_shape.Dims(3);
  const int groups = input_depth / filter_input_depth;
  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
  const int filters_per_group = output_depth / groups;
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);

  for (int batch = 0; batch < batches; ++batch) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      const int in_y_origin = (out_y * stride_height) - pad_height;
      for (int out_x = 0; out_x < output_width; ++out_x) {
        const int in_x_origin = (out_x * stride_width) - pad_width;
        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
          auto group = out_channel / filters_per_group;
          float total = 0.f;
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            const int in_y = in_y_origin + dilation_height_factor * filter_y;
@@ -74,10 +80,11 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
              if (!is_point_inside_image) {
                continue;
              }

              for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
                float input_value = input_data[Offset(input_shape, batch, in_y,
                                                      in_x, in_channel)];
              for (int in_channel = 0; in_channel < filter_input_depth;
                   ++in_channel) {
                float input_value =
                    input_data[Offset(input_shape, batch, in_y, in_x,
                                      in_channel + group * filter_input_depth)];
                float filter_value = filter_data[Offset(
                    filter_shape, out_channel, filter_y, filter_x, in_channel)];
                total += (input_value * filter_value);
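The pattern that recurs in every convolution hunk of this commit is grouped convolution support: when the filter's innermost dimension is smaller than the input depth, the input channels are split into groups = input_depth / filter_input_depth, and an output channel reads only the channels of its own group, offset by group * filter_input_depth. A hypothetical helper, added here only as an illustration, makes the mapping explicit:

// Illustrative only: maps (out_channel, in_channel within the filter) to the
// actual input channel read by a grouped convolution.
inline int GroupedInputChannel(int out_channel, int filters_per_group,
                               int filter_input_depth, int in_channel) {
  const int group = out_channel / filters_per_group;
  return in_channel + group * filter_input_depth;
}
// Example: input_depth 8, filter_input_depth 4 -> 2 groups. With 6 output
// channels, filters_per_group is 3, so out_channel 4 reads input channels 4..7.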
@@ -126,7 +133,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
|
||||
if (bias_data) {
|
||||
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
|
||||
@@ -135,6 +142,10 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int filter_input_depth = filter_shape.Dims(3);
|
||||
const int groups = input_depth / filter_input_depth;
|
||||
TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
|
||||
const int filters_per_group = output_depth / groups;
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
@@ -143,6 +154,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
const int in_x_origin = (out_x * stride_width) - pad_width;
|
||||
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
|
||||
auto group = out_channel / filters_per_group;
|
||||
int32_t acc = 0;
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
const int in_y = in_y_origin + dilation_height_factor * filter_y;
|
||||
@@ -158,9 +170,11 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
int32_t input_val = input_data[Offset(input_shape, batch, in_y,
|
||||
in_x, in_channel)];
|
||||
for (int in_channel = 0; in_channel < filter_input_depth;
|
||||
++in_channel) {
|
||||
int32_t input_val =
|
||||
input_data[Offset(input_shape, batch, in_y, in_x,
|
||||
in_channel + group * filter_input_depth)];
|
||||
int32_t filter_val = filter_data[Offset(
|
||||
filter_shape, out_channel, filter_y, filter_x, in_channel)];
|
||||
acc +=
|
||||
@@ -206,7 +220,7 @@ inline void HybridConvPerChannel(
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
|
||||
if (bias_data) {
|
||||
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
|
||||
@@ -215,18 +229,24 @@ inline void HybridConvPerChannel(
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int filter_input_depth = filter_shape.Dims(3);
|
||||
const int groups = input_depth / filter_input_depth;
|
||||
TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
|
||||
const int filters_per_group = output_depth / groups;
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
|
||||
auto group = out_channel / filters_per_group;
|
||||
const int in_x_origin = (out_x * stride_width) - pad_width;
|
||||
const int in_y_origin = (out_y * stride_height) - pad_height;
|
||||
int32_t acc = 0;
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
for (int in_channel = 0; in_channel < filter_input_depth;
|
||||
++in_channel) {
|
||||
const int in_x = in_x_origin + dilation_width_factor * filter_x;
|
||||
const int in_y =
|
||||
in_y_origin + dilation_height_factor * filter_y;
|
||||
@@ -235,7 +255,8 @@ inline void HybridConvPerChannel(
|
||||
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
|
||||
(in_y < input_height)) {
|
||||
int32_t input_val = input_data[Offset(
|
||||
input_shape, batch, in_y, in_x, in_channel)];
|
||||
input_shape, batch, in_y, in_x,
|
||||
in_channel + group * filter_input_depth)];
|
||||
int32_t filter_val =
|
||||
filter_data[Offset(filter_shape, out_channel, filter_y,
|
||||
filter_x, in_channel)];
|
||||
|
||||
@@ -48,7 +48,7 @@ inline void ConvPerChannel(
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
|
||||
if (bias_data) {
|
||||
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
|
||||
@@ -59,6 +59,10 @@ inline void ConvPerChannel(
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int filter_input_depth = filter_shape.Dims(3);
|
||||
const int groups = input_depth / filter_input_depth;
|
||||
TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
|
||||
const int filters_per_group = output_depth / groups;
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
@@ -67,6 +71,7 @@ inline void ConvPerChannel(
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
const int in_x_origin = (out_x * stride_width) - pad_width;
|
||||
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
|
||||
auto group = out_channel / filters_per_group;
|
||||
int32_t acc = 0;
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
const int in_y = in_y_origin + dilation_height_factor * filter_y;
|
||||
@@ -82,9 +87,11 @@ inline void ConvPerChannel(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
int32_t input_val = input_data[Offset(input_shape, batch, in_y,
|
||||
in_x, in_channel)];
|
||||
for (int in_channel = 0; in_channel < filter_input_depth;
|
||||
++in_channel) {
|
||||
int32_t input_val =
|
||||
input_data[Offset(input_shape, batch, in_y, in_x,
|
||||
in_channel + group * filter_input_depth)];
|
||||
int32_t filter_val = filter_data[Offset(
|
||||
filter_shape, out_channel, filter_y, filter_x, in_channel)];
|
||||
// Accumulate with 32 bits accumulator.
|
||||
@@ -126,12 +133,13 @@ inline void ConvPerChannel(
|
||||
|
||||
// Fixed-point per-channel-quantization convolution reference kernel.
|
||||
// 16-bit data and 8-bit filter
|
||||
template <typename AccumScalar>
|
||||
inline void ConvPerChannel(
|
||||
const ConvParams& params, const int32_t* output_multiplier,
|
||||
const int32_t* output_shift, const RuntimeShape& input_shape,
|
||||
const int16_t* input_data, const RuntimeShape& filter_shape,
|
||||
const int8_t* filter_data, const RuntimeShape& bias_shape,
|
||||
const std::int64_t* bias_data, const RuntimeShape& output_shape,
|
||||
const AccumScalar* bias_data, const RuntimeShape& output_shape,
|
||||
int16_t* output_data) {
|
||||
// Get parameters.
|
||||
const int stride_width = params.stride_width;
|
||||
@@ -151,7 +159,7 @@ inline void ConvPerChannel(
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
|
||||
if (bias_data) {
|
||||
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
|
||||
@@ -162,6 +170,10 @@ inline void ConvPerChannel(
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int filter_input_depth = filter_shape.Dims(3);
|
||||
const int groups = input_depth / filter_input_depth;
|
||||
TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
|
||||
const int filters_per_group = output_depth / groups;
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
@@ -170,7 +182,8 @@ inline void ConvPerChannel(
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
const int in_x_origin = (out_x * stride_width) - pad_width;
|
||||
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
|
||||
std::int64_t acc = 0;
|
||||
auto group = out_channel / filters_per_group;
|
||||
AccumScalar acc = 0;
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
const int in_y = in_y_origin + dilation_height_factor * filter_y;
|
||||
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||
@@ -185,9 +198,11 @@ inline void ConvPerChannel(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
int32_t input_val = input_data[Offset(input_shape, batch, in_y,
|
||||
in_x, in_channel)];
|
||||
for (int in_channel = 0; in_channel < filter_input_depth;
|
||||
++in_channel) {
|
||||
int32_t input_val =
|
||||
input_data[Offset(input_shape, batch, in_y, in_x,
|
||||
in_channel + group * filter_input_depth)];
|
||||
int32_t filter_val = filter_data[Offset(
|
||||
filter_shape, out_channel, filter_y, filter_x, in_channel)];
|
||||
// Accumulate with 64 bits accumulator.
|
||||
|
||||
@@ -34,12 +34,13 @@ inline void FullyConnected(
|
||||
const int32_t output_activation_min = params.quantized_activation_min;
|
||||
const int32_t output_activation_max = params.quantized_activation_max;
|
||||
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
|
||||
TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
|
||||
|
||||
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
|
||||
const int filter_dim_count = filter_shape.DimensionsCount();
|
||||
const int batches = output_shape.Dims(0);
|
||||
const int output_depth = output_shape.Dims(1);
|
||||
const int output_dim_count = output_shape.DimensionsCount();
|
||||
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
|
||||
const int output_depth = output_shape.Dims(output_dim_count - 1);
|
||||
TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
|
||||
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
@@ -62,11 +63,12 @@ inline void FullyConnected(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AccumScalar>
|
||||
inline void FullyConnected(
|
||||
const FullyConnectedParams& params, const RuntimeShape& input_shape,
|
||||
const int16_t* input_data, const RuntimeShape& filter_shape,
|
||||
const int8_t* filter_data, const RuntimeShape& bias_shape,
|
||||
const int64_t* bias_data, const RuntimeShape& output_shape,
|
||||
const AccumScalar* bias_data, const RuntimeShape& output_shape,
|
||||
int16_t* output_data) {
|
||||
const int32_t filter_offset = params.weights_offset;
|
||||
const int32_t output_multiplier = params.output_multiplier;
|
||||
@@ -85,7 +87,7 @@ inline void FullyConnected(
|
||||
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int out_c = 0; out_c < output_depth; ++out_c) {
|
||||
int64_t acc = 0;
|
||||
AccumScalar acc = 0;
|
||||
for (int d = 0; d < accum_depth; ++d) {
|
||||
int32_t input_val = input_data[b * accum_depth + d];
|
||||
int32_t filter_val = filter_data[out_c * accum_depth + d];
|
||||
|
||||
@@ -119,15 +119,16 @@ inline void TransposeConv(
|
||||
}
|
||||
}
|
||||
|
||||
// int16_t input (zero_point=0), int8_t filter, int64 accumulator
|
||||
// int16_t input (zero_point=0), int8_t filter, int32 or int64 accumulator
|
||||
template <typename Scalar>
|
||||
inline void TransposeConv(
|
||||
const ConvParams& params, const int32_t* output_multiplier,
|
||||
const int32_t* output_shift, const RuntimeShape& input_shape,
|
||||
const int16_t* input_data, const RuntimeShape& filter_shape,
|
||||
const int8_t* filter_data, const RuntimeShape& bias_shape,
|
||||
const std::int64_t* bias_data, const RuntimeShape& output_shape,
|
||||
const Scalar* bias_data, const RuntimeShape& output_shape,
|
||||
int16_t* output_data, const RuntimeShape& im2col_shape, int8_t* im2col_data,
|
||||
std::int64_t* scratch_buffer) {
|
||||
Scalar* scratch_buffer) {
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int pad_width = params.padding_values.width;
|
||||
@@ -157,7 +158,7 @@ inline void TransposeConv(
|
||||
const int num_elements = output_shape.FlatSize();
|
||||
// We need to initialize scratch_buffer to all 0s, as we apply the same
|
||||
// 'scatter' based trick as in float version.
|
||||
memset(scratch_buffer, 0, num_elements * sizeof(std::int64_t));
|
||||
memset(scratch_buffer, 0, num_elements * sizeof(Scalar));
|
||||
|
||||
// Loop through input elements one at a time.
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
@@ -198,8 +199,8 @@ inline void TransposeConv(
|
||||
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
|
||||
std::int64_t acc = scratch_buffer[Offset(output_shape, batch, out_y,
|
||||
out_x, out_channel)];
|
||||
Scalar acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
|
||||
out_channel)];
|
||||
if (bias_data) {
|
||||
acc += bias_data[out_channel];
|
||||
}
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace reference_ops {
|
||||
|
||||
inline void LstmCell(
|
||||
const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
|
||||
const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
|
||||
const float* prev_activ_data, const RuntimeShape& weights_shape,
|
||||
const float* weights_data, const RuntimeShape& unextended_bias_shape,
|
||||
const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
|
||||
const float* prev_state_data,
|
||||
const RuntimeShape& unextended_output_state_shape, float* output_state_data,
|
||||
const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
|
||||
const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
|
||||
const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape input_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||
const RuntimeShape prev_activ_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
|
||||
const RuntimeShape bias_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_bias_shape);
|
||||
const RuntimeShape prev_state_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
|
||||
const RuntimeShape output_state_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
|
||||
const RuntimeShape output_activ_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
|
||||
const RuntimeShape concat_temp_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
|
||||
const RuntimeShape activ_temp_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
|
||||
TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
|
||||
|
||||
const int weights_dim_count = weights_shape.DimensionsCount();
|
||||
const int batches =
|
||||
MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
|
||||
output_state_shape, 0, output_activ_shape, 0);
|
||||
const int height =
|
||||
MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
|
||||
output_state_shape, 1, output_activ_shape, 1);
|
||||
const int width =
|
||||
MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
|
||||
output_state_shape, 2, output_activ_shape, 2);
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int prev_activ_depth = prev_activ_shape.Dims(3);
|
||||
const int total_input_depth = prev_activ_depth + input_depth;
|
||||
TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
|
||||
total_input_depth);
|
||||
TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
|
||||
const int intern_activ_depth =
|
||||
MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
|
||||
TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
|
||||
intern_activ_depth * total_input_depth);
|
||||
TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
|
||||
const int output_depth =
|
||||
MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
|
||||
3, output_activ_shape, 3);
|
||||
TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
|
||||
|
||||
// Concatenate prev_activ and input data together
|
||||
float const* concat_input_arrays_data[2] = {input_data, prev_activ_data};
|
||||
const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
|
||||
&prev_activ_shape};
|
||||
tflite::ConcatenationParams concat_params;
|
||||
concat_params.axis = 3;
|
||||
concat_params.inputs_count = 2;
|
||||
Concatenation(concat_params, concat_input_arrays_shapes,
|
||||
concat_input_arrays_data, concat_temp_shape, concat_temp_data);
|
||||
|
||||
// Fully connected
|
||||
tflite::FullyConnectedParams fc_params;
|
||||
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||
FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
|
||||
weights_data, bias_shape, bias_data, activ_temp_shape,
|
||||
activ_temp_data);
|
||||
|
||||
// Memory state update (the LSTM "guts")
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int w = 0; w < width; ++w) {
|
||||
for (int h = 0; h < height; ++h) {
|
||||
for (int c = 0; c < output_depth; ++c) {
|
||||
const float input_gate =
|
||||
1.f /
|
||||
(1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
|
||||
0 * output_depth + c)]));
|
||||
const float new_input = std::tanh(activ_temp_data[Offset(
|
||||
activ_temp_shape, b, h, w, 1 * output_depth + c)]);
|
||||
const float forget_gate =
|
||||
1.f /
|
||||
(1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
|
||||
2 * output_depth + c)]));
|
||||
const float output_gate =
|
||||
1.f /
|
||||
(1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
|
||||
3 * output_depth + c)]));
|
||||
const float new_state =
|
||||
input_gate * new_input +
|
||||
forget_gate *
|
||||
prev_state_data[Offset(prev_state_shape, b, h, w, c)];
|
||||
output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
|
||||
output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
|
||||
output_gate * std::tanh(new_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Quantized LSTM cell implementation.
|
||||
// The quantization of the input, output arrays is as follows:
|
||||
// - The input activations are quantized as uint8 on the interval
|
||||
// [-1, 127/128].
|
||||
// The rationale for that is that is the natural interval for output
|
||||
// activations (see next point) and these need to be concatenated together.
|
||||
// We could accommodate different ranges by re-scaling, but we empirically
|
||||
// found that setting the input activations range to be [-1, 127/128] in the
|
||||
// first place, removing the need for re-scaling, greatly improves accuracy.
|
||||
// - The output activations are quantized as uint8 on the interval
|
||||
// [-1, 127/128].
|
||||
// The rationale for that is that the definition of a LSTM cell makes them
|
||||
// intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
|
||||
// makes for simpler, more accurate fixed-point arithmetic.
|
||||
// - The output-at-previous-timestep state array is obviously quantized as
|
||||
// the output activations.
|
||||
// - The internal LSTM memory (not the output-at-previous-timestep, the other
|
||||
// internal state array) is int16-quantized and may use any power-of-two,
|
||||
// symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
|
||||
// StateIntegerBits below, see the below discussion of that template
|
||||
// parameter ("The StateIntegerBits template parameter").
|
||||
// - The output of the internal fully-connected node is int16-quantized
|
||||
// on the interval [-8, 8 * 32767/32768], the rationale for which is
|
||||
// explained just below ("Why [-8, 8] for fully-connected output?").
|
||||
//
|
||||
//
|
||||
// === The StateIntegerBits template parameter ===
|
||||
//
|
||||
// The StateIntegerBits template parameter controls the fixed-point format used
|
||||
// to represent the internal memory of the LSTM cell (not the
|
||||
// output-at-previous-timestep, the other internal state array). It's currently
|
||||
// a template parameter so that the model can control that. The most typical
|
||||
// value for StateIntegerBits is 4. Other plausible values are anywhere between
|
||||
// 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
|
||||
// and drop that template parameter. The reason why it can't be a runtime
|
||||
// parameter is that this controls the fixed-point format used, i.e. we need to
|
||||
// generate actually different code based on it. In particular, we generate code
|
||||
// for a fixed-point tanh() implementation for that format, which internally
|
||||
// uses a fixed-point exp() implementation, which internally uses a
|
||||
// barrel-shifter with a number of steps that depends on StateIntegerBits.
|
||||
// Another consequence of that is that a higher value of StateIntegerBits
|
||||
// results in a more expensive implementation (more barrel shifter steps
|
||||
// needed).
|
||||
//
|
||||
//
|
||||
// === Why [-8, 8] for fully-connected output? ===
|
||||
//
|
||||
// This array is only fed to Logistic and Tanh functions, for which
|
||||
// the quantized implementation will want to use fixed-point arithmetic,
|
||||
// requiring a power-of-two representation interval. Thus, we should right
|
||||
// away quantize this array to a power-of-two interval; otherwise,
|
||||
// implementation will need to rescale that, losing any benefit that a tighter
|
||||
// representation interval might otherwise yield, while introducing some
|
||||
// numerical error and computational overhead.
|
||||
//
|
||||
// Now, Logistic and Tanh
|
||||
// are nearly constant (nearly equal to their horizontal asymptotes)
|
||||
// outside of a small bounded interval around 0:
|
||||
//
|
||||
// Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4
|
||||
// Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7
|
||||
// Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14
|
||||
//
|
||||
// From this, we see that clamping to [-4, 4] would be too inaccurate
|
||||
// (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
|
||||
// while clamping to [-16, 16] would make no difference even in float32.
|
||||
// However, for a fixed-point implementation in 16-bit integers, using 5
|
||||
// integer bits to represent the [-16, 16] range would leave only 11
|
||||
// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
|
||||
// representable values. Notice that is higher than the
|
||||
// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
|
||||
// Using [-8, 8] thus seems like the better compromise overall, enjoying
|
||||
// an increment of 2.4e-4 between representable values and a worst-case
|
||||
// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
|
||||
// [-16, 16].
|
||||
//
|
||||
// Moreover, all other things being equal, it is nice to choose the narrower
|
||||
// representation range, as that makes the implementation of fixed-point
|
||||
// math functions a little cheaper (each integer bit requires an additional
|
||||
// barrel-shifter atep in the implementation of exp(-x)). That is further
|
||||
// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
|
||||
// sense for 32-bit float or 32-bit fixed-point quantization, but we are
|
||||
// aiming for 16-bit fixed-point quantization of these internal nodes here.
|
||||
//
|
||||
template <int StateIntegerBits>
inline void LstmCell(const LstmCellParams& params,
                     const RuntimeShape& unextended_input_shape,
                     const uint8_t* input_data_uint8,
                     const RuntimeShape& unextended_prev_activ_shape,
                     const uint8_t* prev_activ_data_uint8,
                     const RuntimeShape& weights_shape,
                     const uint8_t* weights_data_uint8,
                     const RuntimeShape& unextended_bias_shape,
                     const int32_t* bias_data_int32,
                     const RuntimeShape& unextended_prev_state_shape,
                     const int16_t* prev_state_data_int16,
                     const RuntimeShape& unextended_output_state_shape,
                     int16_t* output_state_data_int16,
                     const RuntimeShape& unextended_output_activ_shape,
                     uint8_t* output_activ_data_uint8,
                     const RuntimeShape& unextended_concat_temp_shape,
                     uint8_t* concat_temp_data_uint8,
                     const RuntimeShape& unextended_activ_temp_shape,
                     int16_t* activ_temp_data_int16, void* gemmlowp_context) {
  (void)gemmlowp_context;  // only used in optimized code.
  int32_t weights_zero_point = params.weights_zero_point;
  int32_t accum_multiplier = params.accum_multiplier;
  int accum_shift = params.accum_shift;
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  // Gather dimensions information, and perform consistency checks.
  const int weights_dim_count = weights_shape.DimensionsCount();
  const int outer_size = MatchingFlatSizeSkipDim(
      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
      output_activ_shape);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
  const int fc_output_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
  const int fc_accum_depth = total_input_depth;
  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
                                                prev_activ_data_uint8};
  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
                                                       &prev_activ_shape};
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = 2;
  Concatenation(concat_params, concat_input_arrays_shapes,
                concat_input_arrays_data, concat_temp_shape,
                concat_temp_data_uint8);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16_t input_val =
            concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
      accum =
          MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
      FS prev_state_times_forget_state = forget_gate_output * prev_state;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prev_state_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ = std::max<int16_t>(
          -128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
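
A minimal, standalone sketch of the Q3.12 arithmetic discussed in the function comment above; everything in it (ToFixed16, FromFixed16, the sample values) is illustrative and not part of the TFLite sources. It shows the 2^-12 ≈ 2.4e-4 step of a 16-bit fixed-point format with 3 integer bits and a round trip through that representation.

// Illustrative only: hypothetical helpers, not taken from lstm_cell.h.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Quantize a real value into a 16-bit fixed-point raw integer with
// kIntegerBits integer bits (e.g. 3 -> range [-8, 8), step 2^-12).
template <int kIntegerBits>
int16_t ToFixed16(float x) {
  const float scale = static_cast<float>(1 << (15 - kIntegerBits));
  const float scaled = std::round(x * scale);
  return static_cast<int16_t>(std::max(-32768.0f, std::min(32767.0f, scaled)));
}

// Recover the real value represented by a raw 16-bit fixed-point number.
template <int kIntegerBits>
float FromFixed16(int16_t raw) {
  return static_cast<float>(raw) / static_cast<float>(1 << (15 - kIntegerBits));
}

int main() {
  // Step size of the [-8, 8] format used for the fully-connected output:
  // 2^-12 ~= 2.4e-4, the figure quoted in the comment above.
  std::printf("step = %g\n", 1.0f / (1 << 12));
  // Round-trip a sample activation value through the Q3.12 representation.
  const int16_t raw = ToFixed16<3>(1.2345f);
  std::printf("1.2345 -> raw %d -> %f\n", raw, FromFixed16<3>(raw));
  return 0;
}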
@@ -227,6 +227,41 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
  }
}

void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
    const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
    int n_batch, const int32_t input_offset, const int32_t output_multiplier,
    const int32_t output_shift, const int32_t output_offset,
    const int32_t output_activation_min, const int32_t output_activation_max,
    int8_t* __restrict__ result) {
  const int kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
  for (int batch = 0; batch < n_batch; ++batch) {
    const int8_t* matrix_ptr = matrix;
    for (int row = 0; row < m_rows; ++row) {
      int32_t dot_prod = 0;
      const int8_t* vector_in_batch = vector + batch * m_cols;
      for (int i = segments[row]; i < segments[row + 1]; ++i) {
        const int block_start_index = indices[i] * kBlockSize;
        const int8_t* vector_block_in_batch_ptr =
            vector_in_batch + block_start_index;
        for (int c = 0; c < kBlockSize; c++) {
          dot_prod += *matrix_ptr * *vector_block_in_batch_ptr++;
          dot_prod += *matrix_ptr++ * input_offset;
        }
      }
      const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0;
      dot_prod = MultiplyByQuantizedMultiplier(dot_prod + bias_value,
                                               output_multiplier, output_shift);
      dot_prod += output_offset;
      result[batch * m_rows + row] =
          static_cast<int8_t>(ActivationFunctionWithMinMax(
              dot_prod, output_activation_min, output_activation_max));
    }
  }
}
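
The 1x16 kernel above assumes a CSR-like layout over 16-wide column blocks: for row r, the range segments[r]..segments[r + 1] indexes into indices, each entry naming a non-zero block whose 16 weights are stored contiguously in matrix. The sketch below is illustrative only (made-up 2x32 data, not from the TFLite tests) and walks one row the same way the kernel does, with the offset and requantization steps omitted.

// Illustrative only: hypothetical block-sparse data for the segments/indices
// layout consumed by PortableSparseMatrixBatchVectorMultiplyAccumulate1x16.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int kBlockSize = 16;
  // Hypothetical 2x32 weight matrix where one 16-wide block per row survives
  // pruning: row 0 keeps columns [16, 32), row 1 keeps columns [0, 16).
  const std::vector<int32_t> segments = {0, 1, 2};  // row r owns [segments[r], segments[r + 1])
  const std::vector<int32_t> indices = {1, 0};      // block-column index of each stored block
  std::vector<int8_t> matrix(2 * kBlockSize, 1);    // retained block values, back to back
  std::vector<int8_t> vector(32, 2);                // one dense input vector

  // Walk row 0 exactly like the kernel does (no input_offset, no requantization).
  int32_t dot_prod = 0;
  const int8_t* matrix_ptr = matrix.data();
  for (int i = segments[0]; i < segments[1]; ++i) {
    const int8_t* block = vector.data() + indices[i] * kBlockSize;
    for (int c = 0; c < kBlockSize; ++c) {
      dot_prod += *matrix_ptr++ * block[c];
    }
  }
  std::printf("row 0 dot product = %d\n", static_cast<int>(dot_prod));  // 16 * 1 * 2 = 32
  return 0;
}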

void PortableSparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

@@ -0,0 +1,333 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_

#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"

#if defined(_MSC_VER)
#define __restrict__ __restrict
#endif

namespace tflite {
namespace tensor_utils {

// Check if all entries of a vector are zero for float.
bool IsZeroVector(const float* vector, int v_size) {
  return PortableIsZeroVector(vector, v_size);
}

// Check if all entries of a vector are zero for int8_t.
bool IsZeroVector(const int8_t* vector, int v_size) {
  return PortableIsZeroVector(vector, v_size);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float* min, float* max,
                             float* scaling_factor) {
  PortableSymmetricQuantizeFloats(values, size, quantized_values, min, max,
                                  scaling_factor);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float min_value,
                             float max_value, float* scaling_factor) {
  PortableSymmetricQuantizeFloats(values, size, quantized_values, min_value,
                                  max_value, scaling_factor);
}

void AsymmetricQuantizeFloats(const float* values, const int size,
                              int8_t* quantized_values, float* scaling_factor,
                              int32_t* offset) {
  PortableAsymmetricQuantizeFloats(values, size, quantized_values,
                                   scaling_factor, offset);
}

void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                         int m_cols, const float* vector,
                                         int n_batch, float* result) {
  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
                                              n_batch, result);
}

void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
                                         const int m_rows, const int m_cols,
                                         const int8_t* __restrict__ vector,
                                         const float* scaling_factors,
                                         int n_batch,
                                         float* __restrict__ result) {
  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
                                              scaling_factors, n_batch, result);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulate(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
      context);
}

void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
                                         const int m_rows, const int m_cols,
                                         const int8_t* __restrict__ vector,
                                         const float* scaling_factors,
                                         int n_batch, int32_t* scratch,
                                         float* __restrict__ result,
                                         CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
                                              scaling_factors, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
      matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result) {
  PortableSparseMatrixBatchVectorMultiplyAccumulate(
      matrix, ledger, m_rows, m_cols, vector, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate1x16(
    const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
    int n_batch, const int32_t input_offset, const int32_t output_multiplier,
    const int32_t output_shift, const int32_t output_offset,
    const int32_t output_activation_min, const int32_t output_activation_max,
    int8_t* __restrict__ result) {
  PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
      matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch,
      input_offset, output_multiplier, output_shift, output_offset,
      output_activation_min, output_activation_max, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,
    const float* scaling_factors, int n_batch, float* __restrict__ result) {
  PortableSparseMatrixBatchVectorMultiplyAccumulate(
      matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch,
      result);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int16_t* output, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulate(
      input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
      n_output, output_zp, scratch, output, context);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int8_t* output, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulate(
      input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
      n_output, output_zp, scratch, output, context);
}

void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
                                    int32_t n_row, int32_t n_col,
                                    int32_t* output) {
  PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output);
}

void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
                               const int8_t* input_to_gate_weights,
                               int32_t input_to_gate_effective_scale_a,
                               int32_t input_to_gate_effective_scale_b,
                               int32_t n_batch, int32_t n_input, int32_t n_cell,
                               int8_t* gate_output, int8_t gate_output_zp) {
  PortableMatrixBatchVectorMultiply(
      input, input_zeropoint, input_to_gate_weights,
      input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
      n_input, n_cell, gate_output, gate_output_zp);
}

void MatrixBatchVectorMultiply(const int16_t* hidden,
                               const int8_t* hidden_to_output_weights,
                               int32_t proj_effective_scale_a,
                               int32_t proj_effective_scale_b,
                               const int32_t* gate_bias, int32_t n_batch,
                               int32_t n_hidden, int32_t n_output,
                               int32_t output_zp, int8_t* proj_output) {
  PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
                                    proj_effective_scale_a,
                                    proj_effective_scale_b, gate_bias, n_batch,
                                    n_hidden, n_output, output_zp, proj_output);
}

void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
                    const int32_t* bias, int32_t layer_norm_scale_a,
                    int32_t layer_norm_scale_b, int32_t variance_limit,
                    int n_batch, int n_input, int16_t* output) {
  PortableApplyLayerNorm(input, layer_norm_weights, bias, layer_norm_scale_a,
                         layer_norm_scale_b, variance_limit, n_batch, n_input,
                         output);
}

void ApplyLayerNormFloat(const int16_t* input,
                         const int16_t* layer_norm_weights,
                         int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
                         const int32_t* bias, int n_batch, int n_input,
                         int16_t* output) {
  PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
                              layer_norm_scale_b, bias, n_batch, n_input,
                              output);
}

void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
                  int16_t* output) {
  PortableApplySigmoid(input, n_batch, n_input, output);
}

void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                       int16_t* output) {
  PortableApplySigmoidFloat(input, n_batch, n_input, output);
}

void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
               int32_t n_input, int16_t* output) {
  PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
}

void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                    int32_t integer_bits, int16_t* output) {
  PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}

void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int shift, int16_t* output) {
  PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
}

void CwiseMul(const int16_t* input_1, const int16_t* input_2,
              int32_t multiplier, int32_t shift, int32_t n_batch,
              int32_t n_input, int32_t output_zp, int8_t* output) {
  PortableCwiseMul(input_1, input_2, multiplier, shift, n_batch, n_input,
                   output_zp, output);
}

void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int16_t* output) {
  PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
}

void CwiseClipping(float* vector, const int v_size,
                   const float clipping_value) {
  PortableCwiseClipping(vector, v_size, clipping_value);
}

void CwiseClipping(int16_t* vector, const int v_size,
                   const int16_t clipping_value) {
  PortableCwiseClipping(vector, v_size, clipping_value);
}

void CwiseClipping(int8_t* vector, const int v_size,
                   const int8_t clipping_value) {
  PortableCwiseClipping(vector, v_size, clipping_value);
}

void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
                                             const int16_t* batch_vector,
                                             int n_batch, int32_t multiplier,
                                             int shift, int16_t* result) {
  PortableVectorBatchVectorCwiseProductAccumulate(
      vector, v_size, batch_vector, n_batch, multiplier, shift, result);
}

float VectorVectorDotProduct(const float* vector1, const float* vector2,
                             int v_size) {
  return PortableVectorVectorDotProduct(vector1, vector2, v_size);
}

void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
                                      const int16_t* vector2, int v_size,
                                      int n_batch, int32_t* result) {
  PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
                                           result);
}

void Sub1Vector(const float* vector, int v_size, float* result) {
  PortableSub1Vector(vector, v_size, result);
}

void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
  PortableSub1Vector(vector, v_size, result);
}

// Multiplies all elements of the vector by a scalar.
void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
                          float* result) {
  PortableVectorScalarMultiply(vector, v_size, scale, result);
}

void ReductionSumVector(const float* input_vector, float* output_vector,
                        int output_size, int reduction_size) {
  PortableReductionSumVector(input_vector, output_vector, output_size,
                             reduction_size);
}

void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size) {
  PortableReductionSumVector(input_vector, output_vector, output_size,
                             reduction_size);
}

void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size) {
  PortableReductionSumVector(input_vector, output_vector, output_size,
                             reduction_size);
}

void MeanStddevNormalization(const float* input_vector, float* output_vector,
                             int v_size, int n_batch) {
  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
}

void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
                          const int8_t* recurrent, int8_t recurrent_zp,
                          int32_t input_effective_scale_a,
                          int32_t input_effective_scale_b,
                          int32_t recurrent_effective_scale_a,
                          int32_t recurrent_effective_scale_b, int32_t n_batch,
                          int32_t n_cell, int16_t* output) {
  PortableTwoGateSaturatingAdd(
      input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
      input_effective_scale_b, recurrent_effective_scale_a,
      recurrent_effective_scale_b, n_batch, n_cell, output);
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
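
For orientation only, a hypothetical usage sketch of the dispatch layer above (the buffer sizes and values are made up); it assumes the TensorFlow Lite sources are on the include path and simply exercises two of the wrappers, which forward to the Portable* reference implementations.

// Illustrative only; not part of the TFLite tree.
#include <cstdio>

#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h"

int main() {
  constexpr int kSize = 4;
  float values[kSize] = {-3.5f, 0.25f, 2.0f, 9.0f};
  // Clip every element to [-2.0, 2.0] in place.
  tflite::tensor_utils::CwiseClipping(values, kSize, 2.0f);
  for (int i = 0; i < kSize; ++i) {
    std::printf("%f ", values[i]);
  }
  std::printf("\n");

  const float zeros[kSize] = {0.0f, 0.0f, 0.0f, 0.0f};
  std::printf("all zero: %d\n",
              tflite::tensor_utils::IsZeroVector(zeros, kSize));
  return 0;
}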
@@ -87,6 +87,15 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate(
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result);

void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16(
    const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
    int n_batch, const int32_t input_offset, const int32_t output_multiplier,
    const int32_t output_shift, const int32_t output_offset,
    const int32_t output_activation_min, const int32_t output_activation_max,
    int8_t* __restrict__ result);

void PortableSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,

@@ -273,6 +273,9 @@ void BroadcastQuantSubSlow(const ArithmeticParams& params,
                           const T* input2_data,
                           const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("BroadcastQuantSubSlow/T");
  TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
  NdArrayDesc<N> desc1;
  NdArrayDesc<N> desc2;
  NdArrayDesc<N> output_desc;