diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 90204ff3827..138f76dada6 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -40,6 +40,7 @@ PythonOpsResolver::PythonOpsResolver() { AddConv2D(); AddCos(); AddCumSum(); + AddDecode(); AddDelay(); AddDepthToSpace(); AddDepthwiseConv2D(); diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 1bb250885a8..264ccba383e 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -236,6 +236,10 @@ tflm_kernel_cc_library( "conv.cc", "conv_common.cc", "cumsum.cc", + "decode.cc", + "decode_state.cc", + "decode_state_lut.cc", + "decode_state_prune.cc", "depth_to_space.cc", "depthwise_conv.cc", "depthwise_conv_common.cc", @@ -327,6 +331,9 @@ tflm_kernel_cc_library( "batch_matmul.h", "circular_buffer.h", "conv.h", + "decode_state.h", + "decode_state_lut.h", + "decode_state_prune.h", "depthwise_conv.h", "dequantize.h", "ethosu.h", @@ -643,6 +650,21 @@ tflm_cc_test( ], ) +tflm_cc_test( + name = "decode_test", + srcs = [ + "decode_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:debug_log", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + tflm_cc_test( name = "decompress_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index 11684278801..62e9324995e 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \ diff --git a/tensorflow/lite/micro/kernels/decode.cc b/tensorflow/lite/micro/kernels/decode.cc new file mode 100644 index 00000000000..92d7c121d89 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode.cc @@ -0,0 +1,155 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/decode_state.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+namespace tflite {
+namespace {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  const size_t num_inputs = NumInputs(node);
+  const size_t num_outputs = NumOutputs(node);
+  TF_LITE_ENSURE(context, num_outputs > 0);
+  TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs * 2);
+
+  MicroContext* const micro_context = GetMicroContext(context);
+
+  node->user_data = micro_context->AllocatePersistentBuffer(
+      num_outputs * sizeof(DecodeState*));
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  DecodeState** const dsp_arr =
+      reinterpret_cast<DecodeState**>(node->user_data);
+
+  TfLiteTensor* input = nullptr;
+  TfLiteTensor* ancillary = nullptr;
+  TfLiteTensor* output = nullptr;
+  TfLiteStatus status = kTfLiteOk;
+
+  for (size_t i = 0; i < num_inputs; i += 2) {
+    input = micro_context->AllocateTempInputTensor(node, i);
+    if (input == nullptr) {
+      MicroPrintf("failed to allocate input tensor %u", i);
+      status = kTfLiteError;
+      break;
+    }
+    ancillary = micro_context->AllocateTempInputTensor(node, i + 1);
+    if (ancillary == nullptr) {
+      MicroPrintf("failed to allocate ancillary tensor %u", i + 1);
+      status = kTfLiteError;
+      break;
+    }
+    output = micro_context->AllocateTempOutputTensor(node, i / 2);
+    if (output == nullptr) {
+      MicroPrintf("failed to allocate output tensor %u", i / 2);
+      status = kTfLiteError;
+      break;
+    }
+
+    TF_LITE_ENSURE(context, IsConstantTensor(input));
+    TF_LITE_ENSURE(context, IsConstantTensor(ancillary));
+
+    if (DecodeState::Version(*ancillary) != 1) {
+      MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
+      status = kTfLiteError;
+      break;
+    }
+
+    DecodeState* dsp = nullptr;
+    switch (DecodeState::Type(*ancillary)) {
+      case DecodeState::kDcmTypeLUT:
+        dsp = DecodeState::CreateDecodeStateLUT(
+            context, micro_context->GetAlternateProfiler());
+        break;
+      case DecodeState::kDcmTypePrune:
+        dsp = DecodeState::CreateDecodeStatePrune(
+            context, micro_context->GetAlternateProfiler());
+        break;
+      case DecodeState::kDcmTypeCustom:
+        MicroPrintf("Custom decode type not yet supported");
+        break;
+      default:
+        MicroPrintf("unsupported decode type %u",
+                    DecodeState::Type(*ancillary));
+        break;
+    }
+
+    if (dsp != nullptr) {
+      status = dsp->Setup(*input, *ancillary, *output);
+      if (status != kTfLiteOk) {
+        break;
+      }
+      dsp_arr[i / 2] = dsp;
+    } else {
+      MicroPrintf("failed to allocate DecodeState[%u]", i / 2);
+      status = kTfLiteError;
+      break;
+    }
+
+    micro_context->DeallocateTempTfLiteTensor(input);
+    micro_context->DeallocateTempTfLiteTensor(ancillary);
+    micro_context->DeallocateTempTfLiteTensor(output);
+    input = nullptr;
+    ancillary = nullptr;
+    output = nullptr;
+  }
+
+  if (input != nullptr) {
+    micro_context->DeallocateTempTfLiteTensor(input);
+  }
+  if (ancillary != nullptr) {
+    micro_context->DeallocateTempTfLiteTensor(ancillary);
+  }
+  if (output != nullptr) {
+    micro_context->DeallocateTempTfLiteTensor(output);
+  }
+
+  return status;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const size_t num_inputs = NumInputs(node);
+  DecodeState** const dsp_arr =
+      reinterpret_cast<DecodeState**>(node->user_data);
+
+  for (size_t i = 0; i < 
num_inputs; i += 2) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, i); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteEvalTensor* ancillary = + tflite::micro::GetEvalInput(context, node, i + 1); + TF_LITE_ENSURE(context, ancillary != nullptr); + const TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, i / 2); + TF_LITE_ENSURE(context, output != nullptr); + + TfLiteStatus status = dsp_arr[i / 2]->Decode(*input, *ancillary, *output); + TF_LITE_ENSURE(context, status == kTfLiteOk); + } + + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_DECODE() { + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state.cc b/tensorflow/lite/micro/kernels/decode_state.cc new file mode 100644 index 00000000000..af0cc9eef44 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state.cc @@ -0,0 +1,50 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state.h" + +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" +#include "tensorflow/lite/micro/micro_context.h" + +namespace tflite { + +DecodeState* DecodeState::CreateDecodeStateLUT( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); + void* buffer = + micro_context->AllocatePersistentBuffer(sizeof(DecodeStateLUT)); + if (buffer == nullptr) { + return nullptr; + } + DecodeState* dsp = new (buffer) DecodeStateLUT(context, profiler); + + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStatePrune( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); + void* buffer = + micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune)); + if (buffer == nullptr) { + return nullptr; + } + DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler); + + return dsp; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state.h b/tensorflow/lite/micro/kernels/decode_state.h new file mode 100644 index 00000000000..4c61a0b1056 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state.h @@ -0,0 +1,90 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/c/c_api_types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_profiler_interface.h"
+
+namespace tflite {
+
+struct DecodeState {
+  DecodeState() = delete;
+
+  DecodeState(const TfLiteContext* context, MicroProfilerInterface* profiler)
+      : context_(context), micro_profiler_(profiler) {}
+
+  virtual TfLiteStatus Setup(const TfLiteTensor& input,
+                             const TfLiteTensor& ancillary,
+                             const TfLiteTensor& output) = 0;
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) = 0;
+
+  static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
+                                           MicroProfilerInterface* profiler);
+  static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context,
+                                             MicroProfilerInterface* profiler);
+
+  static uint8_t Type(const TfLiteTensor& ancillary) {
+    return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
+  }
+
+  static uint8_t Type(const TfLiteEvalTensor& ancillary) {
+    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
+  }
+
+  static uint8_t Version(const TfLiteTensor& ancillary) {
+    return GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
+  }
+
+  static uint8_t Version(const TfLiteEvalTensor& ancillary) {
+    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
+  }
+
+ protected:
+  virtual ~DecodeState() = default;
+
+  // Decode Common Metadata constants
+ public:
+  static constexpr uint8_t kDcmTypeLUT = 0;
+  static constexpr uint8_t kDcmTypePrune = 2;
+  static constexpr uint8_t kDcmTypeCustom = 127;
+
+  static constexpr size_t kDcmSizeInBytes = 16;
+
+ private:
+  static constexpr size_t kDcmDecodeTypeOffset = 0;
+  static constexpr size_t kDcmVersionOffset = 1;
+
+  // DecodeState vars
+ protected:
+  const TfLiteContext* context_;
+  MicroProfilerInterface* micro_profiler_;
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
diff --git a/tensorflow/lite/micro/kernels/decode_state_lut.cc b/tensorflow/lite/micro/kernels/decode_state_lut.cc
new file mode 100644
index 00000000000..477c21d80a7
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_state_lut.cc
@@ -0,0 +1,630 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
+
+namespace tflite {
+
+TfLiteStatus DecodeStateLUT::Setup(const TfLiteTensor& input,
+                                   const TfLiteTensor& ancillary,
+                                   const TfLiteTensor& output) {
+  const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
+  if (ancillary_data[kDcmVersionOffset] != 1) {
+    MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
+    return kTfLiteError;
+  }
+
+  // resolve num_channels_ and use_alternate_axis_
+  if (output.quantization.type == kTfLiteAffineQuantization &&
+      output.quantization.params != nullptr) {
+    const TfLiteAffineQuantization* quantization =
+        reinterpret_cast<TfLiteAffineQuantization*>(
+            output.quantization.params);
+    num_channels_ = quantization->scale->size;
+    if ((quantization->quantized_dimension == output.dims->size - 1) &&
+        num_channels_ > 1) {
+      use_alternate_axis_ = true;
+    } else if (quantization->quantized_dimension != 0) {
+      MicroPrintf("unsupported quantization axis %u",
+                  quantization->quantized_dimension);
+      return kTfLiteError;
+    }
+  }
+
+  compressed_indices_ = GetTensorData<uint8_t>(&input);
+  count_indices_ = NumElements(&output);
+  elements_per_channel_ =
+      use_alternate_axis_ ? 1 : count_indices_ / num_channels_;
+  value_table_ = &ancillary_data[kDcmSizeInBytes];
+  value_table_channel_stride_ = ancillary_data[kDcmValueTableStrideOffset];
+  compressed_bit_width_ =
+      ancillary_data[kDcmParamsOffset] & kDcmParamsBitWidthMask;
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus DecodeStateLUT::Decode(const TfLiteEvalTensor& input,
+                                    const TfLiteEvalTensor& ancillary,
+                                    const TfLiteEvalTensor& output) {
+  void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
+  TFLITE_DCHECK(buffer != nullptr);
+
+  switch (output.type) {
+    case kTfLiteBool:
+      DecompressToBuffer<bool>(buffer);
+      break;
+    case kTfLiteFloat32:
+      DecompressToBuffer<float>(buffer);
+      break;
+    case kTfLiteInt8:
+      DecompressToBuffer<int8_t>(buffer);
+      break;
+    case kTfLiteInt16:
+      DecompressToBuffer<int16_t>(buffer);
+      break;
+    case kTfLiteInt32:
+      DecompressToBuffer<int32_t>(buffer);
+      break;
+    case kTfLiteInt64:
+      DecompressToBuffer<int64_t>(buffer);
+      break;
+    default:
+      MicroPrintf("unsupported tensor type %s",
+                  TfLiteTypeGetName(output.type));
+      return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+template <typename T>
+T* DecodeStateLUT::DecompressToBuffer(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  if (std::is_same<T, int8_t>::value && compressed_bit_width_ == 4 &&
+      !use_alternate_axis_) {
+    DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
+  } else if (std::is_same<T, int8_t>::value && compressed_bit_width_ == 3 &&
+             !use_alternate_axis_) {
+    DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
+  } else if (std::is_same<T, int8_t>::value && compressed_bit_width_ == 2 &&
+             !use_alternate_axis_) {
+    DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
+  } else {
+    DecompressToBufferWidthAny<T>(static_cast<T*>(buffer));
+  }
+
+  return static_cast<T*>(buffer);
+}
+
+template bool* DecodeStateLUT::DecompressToBuffer<bool>(void*);
+template float* DecodeStateLUT::DecompressToBuffer<float>(void*);
+template int8_t* DecodeStateLUT::DecompressToBuffer<int8_t>(void*);
+template int16_t* DecodeStateLUT::DecompressToBuffer<int16_t>(void*);
+template int32_t* 
DecodeStateLUT::DecompressToBuffer(void*); +template int64_t* DecodeStateLUT::DecompressToBuffer(void*); + +void DecodeStateLUT::DecompressToBufferWidth4_16(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = value_table_channel_stride_; + const uint8_t* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint64_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint64_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 1]); + + while (count >= 16) { + count -= 16; + uint64_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 4) & 0x0F]); + value |= static_cast(value_table[index & 0x0F]) << 8; + value |= static_cast(value_table[(index >> 12) & 0x0F]) << 16; + value |= static_cast(value_table[(index >> 8) & 0x0F]) << 24; + value |= static_cast(value_table[(index >> 20) & 0x0F]) << 32; + value |= static_cast(value_table[(index >> 16) & 0x0F]) << 40; + value |= static_cast(value_table[(index >> 28) & 0x0F]) << 48; + value |= static_cast(value_table[(index >> 24) & 0x0F]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 36) & 0x0F]); + value2 |= static_cast(value_table[(index >> 32) & 0x0F]) << 8; + value2 |= static_cast(value_table[(index >> 44) & 0x0F]) + << 16; + value2 |= static_cast(value_table[(index >> 40) & 0x0F]) + << 24; + value2 |= static_cast(value_table[(index >> 52) & 0x0F]) + << 32; + value2 |= static_cast(value_table[(index >> 48) & 0x0F]) + << 40; + value2 |= static_cast(value_table[(index >> 60) & 0x0F]) + << 48; + value2 |= static_cast(value_table[(index >> 56) & 0x0F]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 1; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecodeStateLUT::DecompressToBufferWidth2_16(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = value_table_channel_stride_; + const uint8_t* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 2]); + + while (count >= 16) { + count -= 16; + uint32_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 6) & 0x03]); + value |= 
static_cast(value_table[(index >> 4) & 0x03]) << 8; + value |= static_cast(value_table[(index >> 2) & 0x03]) << 16; + value |= static_cast(value_table[index & 0x03]) << 24; + value |= static_cast(value_table[(index >> 14) & 0x03]) << 32; + value |= static_cast(value_table[(index >> 12) & 0x03]) << 40; + value |= static_cast(value_table[(index >> 10) & 0x03]) << 48; + value |= static_cast(value_table[(index >> 8) & 0x03]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 22) & 0x03]); + value2 |= static_cast(value_table[(index >> 20) & 0x03]) << 8; + value2 |= static_cast(value_table[(index >> 18) & 0x03]) + << 16; + value2 |= static_cast(value_table[(index >> 16) & 0x03]) + << 24; + value2 |= static_cast(value_table[(index >> 30) & 0x03]) + << 32; + value2 |= static_cast(value_table[(index >> 28) & 0x03]) + << 40; + value2 |= static_cast(value_table[(index >> 26) & 0x03]) + << 48; + value2 |= static_cast(value_table[(index >> 24) & 0x03]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 2; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecodeStateLUT::DecompressToBufferWidth3_32(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = value_table_channel_stride_; + const uint8_t* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x1F)) { + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 32 + if (count >= 32) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[(current_offset >> 5) * 12]); + + while (count >= 32) { + count -= 32; + uint32_t index0 = *indices++; + uint32_t index1 = *indices++; + uint32_t index2 = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index0 >> 5) & 0x07]); + value |= static_cast(value_table[(index0 >> 2) & 0x07]) << 8; + value |= + static_cast( + value_table[((index0 << 1) & 0b110) | ((index0 >> 15) & 0b1)]) + << 16; + value |= static_cast(value_table[(index0 >> 12) & 0x07]) + << 24; + value |= static_cast(value_table[(index0 >> 9) & 0x07]) << 32; + value |= + static_cast( + value_table[((index0 >> 6) & 0b100) | ((index0 >> 22) & 0b11)]) + << 40; + value |= static_cast(value_table[(index0 >> 19) & 0x07]) + << 48; + value |= static_cast(value_table[(index0 >> 16) & 0x07]) + << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index0 >> 29) & 0x07]); + value2 |= static_cast(value_table[(index0 >> 26) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index0 >> 23) & 0b110) | ((index1 >> 7) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index1 >> 4) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index1 >> 1) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index1 << 2) & 0b100) | ((index1 >> 14) & 0b11)]) + << 40; + value2 |= 
static_cast(value_table[(index1 >> 11) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index1 >> 8) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + value = static_cast(value_table[(index1 >> 21) & 0x07]); + value |= static_cast(value_table[(index1 >> 18) & 0x07]) << 8; + value |= + static_cast( + value_table[((index1 >> 15) & 0b110) | ((index1 >> 31) & 0b1)]) + << 16; + value |= static_cast(value_table[(index1 >> 28) & 0x07]) + << 24; + value |= static_cast(value_table[(index1 >> 25) & 0x07]) + << 32; + value |= + static_cast( + value_table[((index1 >> 22) & 0b100) | ((index2 >> 6) & 0b11)]) + << 40; + value |= static_cast(value_table[(index2 >> 3) & 0x07]) << 48; + value |= static_cast(value_table[(index2 >> 0) & 0x07]) << 56; + + *reinterpret_cast(buffer + 16) = value; + + value2 = static_cast(value_table[(index2 >> 13) & 0x07]); + value2 |= static_cast(value_table[(index2 >> 10) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index2 >> 7) & 0b110) | ((index2 >> 23) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index2 >> 20) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index2 >> 17) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index2 >> 14) & 0b100) | ((index2 >> 30) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index2 >> 27) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index2 >> 24) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 24) = value2; + + buffer += 32; + current_offset += 32; + } + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +// TODO(ddavis-2015): templating GetNextTableIndexWidth makes this method +// more than 2x faster, but with a large code size increase +template +void DecodeStateLUT::DecompressToBufferWidthAny(T* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + if (use_alternate_axis_) { + const size_t stride = value_table_channel_stride_; + size_t current_offset = 0; + size_t count = count_indices_; + + while (count > 0) { + const T* value_table = static_cast(value_table_); + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + value_table += stride; + } + count -= num_channels_; + } + } else { + const size_t stride = value_table_channel_stride_; + const T* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + while (count-- > 0) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index 
= GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + } + value_table += stride; + } + } +} + +template void DecodeStateLUT::DecompressToBufferWidthAny(bool*); +template void DecodeStateLUT::DecompressToBufferWidthAny(float*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int8_t*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int16_t*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int32_t*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int64_t*); + +inline size_t DecodeStateLUT::GetNextTableIndexWidth7( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 7; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 1; + case 1: + return ((indices[0] & 0b1) << 6) | (indices[1] >> 2); + case 2: + return ((indices[1] & 0b11) << 5) | (indices[2] >> 3); + case 3: + return ((indices[2] & 0b111) << 4) | (indices[3] >> 4); + case 4: + return ((indices[3] & 0x0F) << 3) | (indices[4] >> 5); + case 5: + return ((indices[4] & 0x1F) << 2) | (indices[5] >> 6); + case 6: + return ((indices[5] & 0x3F) << 1) | (indices[6] >> 7); + case 7: + return indices[6] & 0x7F; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth6( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 2) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b11) { + case 0: + return indices[0] >> 2; + case 1: + return ((indices[0] & 0b11) << 4) | (indices[1] >> 4); + case 2: + return ((indices[1] & 0x0F) << 2) | (indices[2] >> 6); + case 3: + return indices[2] & 0x3F; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth5( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 5; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 3; + case 1: + return ((indices[0] & 0b111) << 2) | (indices[1] >> 6); + case 2: + return (indices[1] >> 1) & 0x1F; + case 3: + return ((indices[1] & 0b1) << 4) | (indices[2] >> 4); + case 4: + return ((indices[2] & 0x0F) << 1) | (indices[3] >> 7); + case 5: + return (indices[3] >> 2) & 0x1F; + case 6: + return ((indices[3] & 0b11) << 3) | (indices[4] >> 5); + case 7: + return indices[4] & 0x1F; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth4( + const size_t current_offset) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 1] & 0x0F; + } else { + return compressed_indices_[current_offset >> 1] >> 4; + } +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth3( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 5; + case 1: + return (indices[0] >> 2) & 0b111; + case 2: + return ((indices[0] & 0b11) << 1) | (indices[1] >> 7); + case 3: + return (indices[1] >> 4) & 0b111; + case 4: + return (indices[1] >> 1) & 0b111; + case 5: + return ((indices[1] & 
0b1) << 2) | (indices[2] >> 6); + case 6: + return (indices[2] >> 3) & 0b111; + case 7: + return indices[2] & 0b111; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth2( + const size_t current_offset) { + if (current_offset & 0b10) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 2] & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 2) & 0x03; + } + } else { + if (current_offset & 1) { + return (compressed_indices_[current_offset >> 2] >> 4) & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 6) & 0x03; + } + } +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth1( + const size_t current_offset) { + const size_t shift = ~current_offset & 0b111; + return (compressed_indices_[current_offset >> 3] >> shift) & 0b1; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state_lut.h b/tensorflow/lite/micro/kernels/decode_state_lut.h new file mode 100644 index 00000000000..dbb64683960 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_lut.h @@ -0,0 +1,92 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_LUT_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_LUT_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" + +namespace tflite { + +struct DecodeStateLUT : public DecodeState { + DecodeStateLUT() = delete; + + DecodeStateLUT(const TfLiteContext* context, MicroProfilerInterface* profiler) + : DecodeState(context, profiler) {} + + virtual TfLiteStatus Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) override; + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + // LUT compression constants + static constexpr size_t kMaxBitWidth = 7; + static constexpr size_t kMaxValueTableChannelStride = 128; + + private: + // LUT Decode Common Metadata constants + static constexpr size_t kDcmVersionOffset = 4; + static constexpr size_t kDcmParamsOffset = 5; + static constexpr uint8_t kDcmParamsBitWidthMask = 0x07; + static constexpr size_t kDcmValueTableStrideOffset = 6; + + protected: + virtual ~DecodeStateLUT() = default; + + template + T* DecompressToBuffer(void* buffer); + + // optimized C++ for INT8, use_alt_axis == false + void DecompressToBufferWidth4_16(int8_t* buffer); + void DecompressToBufferWidth3_32(int8_t* buffer); + void DecompressToBufferWidth2_16(int8_t* buffer); + + // generic C++ for any bit width and value table type + template + void DecompressToBufferWidthAny(T* buffer); + + // Optimized C++ table index fetch + inline size_t GetNextTableIndexWidth7(const size_t current_offset); + inline size_t GetNextTableIndexWidth6(const size_t 
current_offset); + inline size_t GetNextTableIndexWidth5(const size_t current_offset); + inline size_t GetNextTableIndexWidth4(const size_t current_offset); + inline size_t GetNextTableIndexWidth3(const size_t current_offset); + inline size_t GetNextTableIndexWidth2(const size_t current_offset); + inline size_t GetNextTableIndexWidth1(const size_t current_offset); + + protected: + const uint8_t* compressed_indices_ = nullptr; + size_t count_indices_ = 0; + size_t num_channels_ = 1; + size_t elements_per_channel_ = 0; // computed from use_alternate_axis_ + const void* value_table_ = nullptr; // Pointer into FlatBuffer values + uint8_t value_table_channel_stride_ = 0; // elements per channel + uint8_t compressed_bit_width_ = 0; // 1 to 7 bits + bool use_alternate_axis_ = false; // shape channel axis: + // false = first, true = last + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_LUT_H_ diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.cc b/tensorflow/lite/micro/kernels/decode_state_prune.cc new file mode 100644 index 00000000000..f5ff7ac6a58 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_prune.cc @@ -0,0 +1,199 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_context.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) { + const uint8_t* const ancillary_data = GetTensorData(&ancillary); + if (ancillary_data[kDcmVersionOffset] != 1) { + MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]); + return kTfLiteError; + } + + // resolve num_channels_, use_alternate_axis_, and zero points + if (output.quantization.type == kTfLiteAffineQuantization && + output.quantization.params != nullptr) { + const TfLiteAffineQuantization* quantization = + reinterpret_cast(output.quantization.params); + num_channels_ = quantization->scale->size; + if ((quantization->quantized_dimension == output.dims->size - 1) && + num_channels_ > 1) { + use_alternate_axis_ = true; + } else if (quantization->quantized_dimension != 0) { + MicroPrintf("unsupported quantization axis %u", + quantization->quantized_dimension); + return kTfLiteError; + } + + if (output.type != kTfLiteInt8) { + // make sure all zero points are 0 (zero) + for (size_t i = 0; i < num_channels_; i++) { + TF_LITE_ENSURE(const_cast(context_), + quantization->zero_point->data[i] == 0); + } + } + + if (num_channels_ > 1 && output.type == kTfLiteInt8) { + // copy zero points + MicroContext* micro_context = GetMicroContext(context_); + const size_t bufsize = num_channels_ * sizeof(*zero_points_); + zero_points_ = static_cast( + micro_context->AllocatePersistentBuffer(bufsize)); + if (zero_points_ == nullptr) { + MicroPrintf("unable to allocate zero_points_"); + return kTfLiteError; + } + std::copy_n(quantization->zero_point->data, num_channels_, zero_points_); + } else { + single_zero_point_ = quantization->zero_point->data[0]; + } + } + + compressed_indices_ = GetTensorData(&input); + count_indices_ = NumElements(&output); + elements_per_channel_ = + use_alternate_axis_ ? 
1 : count_indices_ / num_channels_; + value_table_ = &ancillary_data[kDcmSizeInBytes]; + + return kTfLiteOk; +} + +TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBuffer(buffer); + break; + case kTfLiteFloat32: + DecompressToBuffer(buffer); + break; + case kTfLiteInt8: + if (num_channels_ > 1) { + DecompressToBufferPerChannelInt8(buffer); + } else { + DecompressToBuffer(buffer); + } + break; + case kTfLiteInt16: + DecompressToBuffer(buffer); + break; + case kTfLiteInt32: + DecompressToBuffer(buffer); + break; + case kTfLiteInt64: + DecompressToBuffer(buffer); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +template +void DecodeStatePrune::DecompressToBuffer(void* vp) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + T* buffer = static_cast(vp); + const T* value_table = static_cast(value_table_); + const size_t max_count = count_indices_; + const uint8_t* const indices = compressed_indices_; + + for (size_t index = 0; index < max_count; index++) { + size_t shift = ~index & 0b111; + size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1; + + if (is_not_zp) { + *buffer++ = *value_table++; + } else { + *buffer++ = single_zero_point_; + } + } +} + +void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int8_t* buffer = static_cast(vp); + size_t current_offset = 0; + const uint8_t* const indices = compressed_indices_; + const int8_t* value_table = static_cast(value_table_); + + if (use_alternate_axis_) { + const size_t max_channels = num_channels_; + size_t count = count_indices_; + + while (count > 0) { + for (size_t channel = 0; channel < max_channels; channel++) { + const int8_t zp = zero_points_[channel]; + size_t shift = ~current_offset & 0b111; + size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1; + + if (is_not_zp) { + *buffer++ = *value_table++; + } else { + *buffer++ = zp; + } + current_offset++; + } + count -= max_channels; + } + } else { + const size_t max_count = elements_per_channel_; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + const int8_t zp = zero_points_[channel]; + + while (count-- > 0) { + size_t shift = ~current_offset & 0b111; + size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1; + + if (is_not_zp) { + *buffer++ = *value_table++; + } else { + *buffer++ = zp; + } + current_offset++; + } + } + } +} + +template void DecodeStatePrune::DecompressToBuffer(void*); +template void DecodeStatePrune::DecompressToBuffer(void*); +template void DecodeStatePrune::DecompressToBuffer(void*); +template void DecodeStatePrune::DecompressToBuffer(void*); + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.h b/tensorflow/lite/micro/kernels/decode_state_prune.h new file mode 100644 index 00000000000..de5ddd84249 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_prune.h @@ -0,0 +1,69 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" + +namespace tflite { + +struct DecodeStatePrune : public DecodeState { + DecodeStatePrune() = delete; + + DecodeStatePrune(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeState(context, profiler) {} + + virtual TfLiteStatus Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) override; + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + private: + // Prune Decode Common Metadata constants + static constexpr size_t kDcmVersionOffset = 4; + + protected: + virtual ~DecodeStatePrune() = default; + + template + void DecompressToBuffer(void* buffer); + + void DecompressToBufferPerChannelInt8(void* buffer); + + protected: + const uint8_t* compressed_indices_ = nullptr; + size_t count_indices_ = 0; + size_t num_channels_ = 1; + size_t elements_per_channel_ = 0; // computed from use_alternate_axis_ + const void* value_table_ = nullptr; // original non-pruned values + int8_t* zero_points_ = nullptr; // quantized per-channel zero points + int8_t single_zero_point_ = 0; // single channel zero point + bool use_alternate_axis_ = false; // shape channel axis: + // false = first, true = last + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_ diff --git a/tensorflow/lite/micro/kernels/decode_test.cc b/tensorflow/lite/micro/kernels/decode_test.cc new file mode 100644 index 00000000000..3408d71eadb --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_test.cc @@ -0,0 +1,796 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +struct TensorInDatum { + const void* const data; + const TfLiteIntArray& dims; +}; + +struct TensorOutDatum { + void* const data; + const TfLiteIntArray& dims; + const TfLiteType type; + const TfLiteFloatArray& scales; + const TfLiteIntArray& zero_points; + const int quantized_dimension; + + // initialized by CreatePerChannelQuantizedTensor + const TfLiteAffineQuantization affine_quantization; +}; + +template +struct AncillaryData { + AncillaryData() = delete; + AncillaryData(const uint8_t (&dcm)[tflite::DecodeState::kDcmSizeInBytes], + const T (&values)[N]) { + std::copy(std::begin(dcm), std::end(dcm), std::begin(dcm_)); + std::copy(std::begin(values), std::end(values), std::begin(value_table_)); + } + + private: + uint8_t dcm_[tflite::DecodeState::kDcmSizeInBytes]; + T value_table_[N > 0 ? N : 1]; // assure not zero length +}; + +// +// LUT test data +// +constexpr int kBitWidthLUT = 2; + +constexpr int8_t kAncillaryDataLUT0[] = {1, 2, 3, 4}; +constexpr int16_t kAncillaryDataLUT1[] = {5, 6, 7, 8}; + +constexpr uint8_t kDcmLUT0[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypeLUT, // type: LUT + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // LUT version: 1 + kBitWidthLUT, // Parameters: bit-width 2 + std::size(kAncillaryDataLUT0), // channel stride +}; + +constexpr uint8_t kDcmLUT1[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypeLUT, // type: LUT + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // LUT version: 1 + kBitWidthLUT, // Parameters: bit-width 2 + std::size(kAncillaryDataLUT1), // channel stride +}; + +// Align the tensor data the same as a Buffer in the TfLite schema +alignas(16) const uint8_t kEncodedLUT[] = {0x1B, 0xE4}; + +// Tensor shapes as TfLiteIntArray +constexpr int kOutputShapeLUT[] = {3, 1, 2, 4}; +constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)}; + +constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1}; +constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5}; + +// +// Prune test data +// +constexpr int8_t kAncillaryDataPrune0[] = { + 1, 2, 3, 4, // 0 + 1, 2, 3, 4, // 1 + 1, 2, 3, 4, // 2 + 1, 2, 3, 4, // 3 + 1, 2, 3, 4 // 4 +}; +constexpr int16_t kAncillaryDataPrune1[] = { + 5, 6, 7, 8, // 0 + 5, 6, 7, 8, // 1 + 5, 6, 7, 8, // 2 + 5, 6, 7, 8, // 3 + 5, 6, 7, 8 // 4 +}; +constexpr float kAncillaryDataPrune2[] = { + 9.0f, 10.0f, 11.0f, 12.0f, // 0 + 9.0f, 10.0f, 11.0f, 12.0f, // 1 + 9.0f, 10.0f, 11.0f, 12.0f, // 2 + 9.0f, 10.0f, 11.0f, 12.0f, // 3 + 9.0f, 10.0f, 11.0f, 12.0f // 4 +}; +constexpr int8_t kAncillaryDataPrune3[] = { + 13, 14, 15, 16, // 0 + 13, 14, 15, 16, // 1 + 13, 14, 15, 16, // 2 + 13, 14, 15, 16, // 3 + 13, 14, 15, 16 // 4 +}; +constexpr int8_t kAncillaryDataPrune4[] = { + 17, 18, 19, 20, // 0 + 17, 18, 19, 20, // 1 + 17, 18, 19, 20, // 2 + 17, 18, 19, 20, // 3 + 17, 18, 19, 20 // 4 +}; + +constexpr uint8_t kDcmPrune[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypePrune, // type: Prune + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // 
Prune version: 1 +}; + +// Align the tensor data the same as a Buffer in the TfLite schema +alignas(16) const uint8_t kEncodedPrune[] = {0xA5, 0xA5, 0xA5, 0xA5, 0xA5}; + +// Tensor shapes as TfLiteIntArray +constexpr int kOutputShapePrune[] = {3, 2, 5, 4}; +constexpr int kEncodedShapePrune[] = {1, sizeof(kEncodedPrune)}; + +// Quantization datum as TfLiteIntArray. +// Scales are modified by FloatArrayFromFloats. As globals they cannot be +// without causing a processor exception. +float kScalesPrune0[] = {2, 1.0f, 1.0f}; +constexpr int kZeroPointsPrune0[] = {2, -128, -64}; +float kScalesPrune1[] = {4, 1.0f, 1.0f, 1.0f, 1.0f}; +constexpr int kZeroPointsPrune1[] = {4, 0, 0, 0, 0}; +float kScalesPrune4[] = {4, 1.0f, 1.0f, 1.0f, 1.0f}; +constexpr int kZeroPointsPrune4[] = {4, -126, -62, -30, -14}; + +constexpr int8_t kExpectPrune0[] = { + 1, -128, 2, -128, -128, 3, -128, 4, 1, -128, // 0 + 2, -128, -128, 3, -128, 4, 1, -128, 2, -128, // 0 + -64, 3, -64, 4, 1, -64, 2, -64, -64, 3, // 1 + -64, 4, 1, -64, 2, -64, -64, 3, -64, 4 // 1 +}; +constexpr int16_t kExpectPrune1[] = { + 5, 0, 6, 0, // 0 + 0, 7, 0, 8, // 1 + 5, 0, 6, 0, // 2 + 0, 7, 0, 8, // 3 + 5, 0, 6, 0, // 4 + 0, 7, 0, 8, // 5 + 5, 0, 6, 0, // 6 + 0, 7, 0, 8, // 7 + 5, 0, 6, 0, // 8 + 0, 7, 0, 8 // 9 +}; +constexpr float kExpectPrune2[] = { + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 0 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 1 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 2 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 3 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f // 4 +}; +constexpr int8_t kExpectPrune3[] = { + 13, 0, 14, 0, 0, 15, 0, 16, // 0 + 13, 0, 14, 0, 0, 15, 0, 16, // 1 + 13, 0, 14, 0, 0, 15, 0, 16, // 2 + 13, 0, 14, 0, 0, 15, 0, 16, // 3 + 13, 0, 14, 0, 0, 15, 0, 16 // 4 +}; +constexpr int8_t kExpectPrune4[] = { + 17, -62, 18, -14, // 0 + -126, 19, -30, 20, // 1 + 17, -62, 18, -14, // 2 + -126, 19, -30, 20, // 3 + 17, -62, 18, -14, // 4 + -126, 19, -30, 20, // 5 + 17, -62, 18, -14, // 6 + -126, 19, -30, 20, // 7 + 17, -62, 18, -14, // 8 + -126, 19, -30, 20 // 9 +}; + +template +TfLiteStatus CheckOutput(const TfLiteTensor& output, + const void* const expected) { + const T* const expected_data = reinterpret_cast(expected); + const T* const output_data = tflite::GetTensorData(&output); + + constexpr float kTolerance = 1e-5; + const size_t kOutputCount = tflite::NumElements(&output); + for (size_t i = 0; i < kOutputCount; i++) { + TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance); + TF_LITE_MICRO_CHECK_FAIL(); + } + + return kTfLiteOk; +} + +template +TfLiteStatus ExecuteDecodeTest( + TfLiteTensor* tensors, const TFLMRegistration& registration, + const std::initializer_list& expected) { + int kInputArrayData[kNumInputs + 1] = {kNumInputs}; + for (size_t i = 0; i < kNumInputs; i++) { + kInputArrayData[i + 1] = i; + } + TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData); + + int kOutputArrayData[kNumOutputs + 1] = {kNumOutputs}; + for (size_t i = 0; i < kNumOutputs; i++) { + kOutputArrayData[i + 1] = i + kNumInputs; + } + TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData); + + micro::KernelRunner runner(registration, tensors, kNumInputs + kNumOutputs, + inputs_array, outputs_array, nullptr); + + if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) { + return kTfLiteError; + } + + const TfLiteTensor* const output_tensors = &tensors[kNumInputs]; + TfLiteStatus status = kTfLiteError; + for (size_t i = 0; i < 
kNumOutputs; i++) { + switch (output_tensors[i].type) { + case kTfLiteInt8: + status = CheckOutput(output_tensors[i], expected.begin()[i]); + break; + case kTfLiteInt16: + status = CheckOutput(output_tensors[i], expected.begin()[i]); + break; + case kTfLiteFloat32: + status = CheckOutput(output_tensors[i], expected.begin()[i]); + break; + default: + TF_LITE_MICRO_FAIL("unsupported tensor type in test"); + break; + } + } + + return status; +} + +template +void TestDecode(const std::initializer_list& encodes, + const std::initializer_list& ancillaries, + const std::initializer_list& outputs, + const std::initializer_list& expected, + const TFLMRegistration& registration, + const TfLiteStatus expected_status = kTfLiteOk) { + TfLiteTensor tensors[kNumInputs + kNumOutputs] = {}; + + for (size_t i = 0; i < kNumInputs; i += 2) { + const TensorInDatum& tid_encode = *encodes.begin()[i / 2]; + tensors[i] = CreateTensor(tid_encode.data, + const_cast(&tid_encode.dims), + false, kTfLiteUInt8); + // must be a const tensor + tensors[i].allocation_type = kTfLiteMmapRo; + const TensorInDatum& tid_ancillary = *ancillaries.begin()[i / 2]; + tensors[i + 1] = CreateTensor( + tid_ancillary.data, const_cast(&tid_ancillary.dims), + false, kTfLiteUInt8); + // must be a const tensor + tensors[i + 1].allocation_type = kTfLiteMmapRo; + } + for (size_t i = 0; i < kNumOutputs; i++) { + const TensorOutDatum& tod = *outputs.begin()[i]; + if (tod.scales.size == 0) { + tensors[i + kNumInputs] = CreateTensor( + tod.data, const_cast(&tod.dims), false, tod.type); + } else { + tensors[i + kNumInputs] = CreatePerChannelQuantizedTensor( + tod.data, const_cast(&tod.dims), + const_cast(&tod.scales), + const_cast(&tod.zero_points), + const_cast(&tod.affine_quantization), + tod.quantized_dimension, false, tod.type); + } + } + + TfLiteStatus s = ExecuteDecodeTest( + tensors, registration, expected); + TF_LITE_MICRO_EXPECT_EQ(s, expected_status); +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +using tflite::testing::AncillaryData; +using tflite::testing::kAncillaryDataLUT0; +using tflite::testing::kAncillaryDataLUT1; +using tflite::testing::kAncillaryDataPrune0; +using tflite::testing::kAncillaryDataPrune1; +using tflite::testing::kAncillaryDataPrune2; +using tflite::testing::kAncillaryDataPrune3; +using tflite::testing::kAncillaryDataPrune4; +using tflite::testing::kDcmLUT0; +using tflite::testing::kDcmLUT1; +using tflite::testing::kDcmPrune; +using tflite::testing::kEncodedLUT; +using tflite::testing::kEncodedPrune; +using tflite::testing::kEncodedShapeLUT; +using tflite::testing::kEncodedShapePrune; +using tflite::testing::kExpectLUT0; +using tflite::testing::kExpectLUT1; +using tflite::testing::kExpectPrune0; +using tflite::testing::kExpectPrune1; +using tflite::testing::kExpectPrune2; +using tflite::testing::kExpectPrune3; +using tflite::testing::kExpectPrune4; +using tflite::testing::kOutputShapeLUT; +using tflite::testing::kOutputShapePrune; +using tflite::testing::kScalesPrune0; +using tflite::testing::kScalesPrune1; +using tflite::testing::kScalesPrune4; +using tflite::testing::kZeroPointsPrune0; +using tflite::testing::kZeroPointsPrune1; +using tflite::testing::kZeroPointsPrune4; +using tflite::testing::TensorInDatum; +using tflite::testing::TensorOutDatum; + +TF_LITE_MICRO_TEST(DecodeSingleTensor) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectLUT0)] = {}; + alignas(16) const AncillaryData + 
kAncillaryData = {{kDcmLUT0}, {kAncillaryDataLUT0}}; + + constexpr int kAncillaryShapeLUT[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeLUT); + static const TensorInDatum tid_encode = { + kEncodedLUT, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeLUT); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + output_data, + *output_dims, + kTfLiteInt8, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectLUT0}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodeTwoTensors) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data0[std::size(kExpectLUT0)] = {}; + alignas(16) int16_t output_data1[std::size(kExpectLUT1)] = {}; + alignas(16) const AncillaryData + kAncillaryData0 = {{kDcmLUT0}, {kAncillaryDataLUT0}}; + alignas(16) const AncillaryData + kAncillaryData1 = {{kDcmLUT1}, {kAncillaryDataLUT1}}; + + constexpr int kAncillaryShapeLUT0[] = {1, sizeof(kAncillaryData0)}; + constexpr int kAncillaryShapeLUT1[] = {1, sizeof(kAncillaryData1)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeLUT); + static const TensorInDatum tid_encode0 = { + kEncodedLUT, + *encoded_dims, + }; + static const TensorInDatum tid_encode1 = { + kEncodedLUT, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode0, &tid_encode1}; + + const TfLiteIntArray* const ancillary_dims0 = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT0); + static const TensorInDatum tid_ancillary0 = { + &kAncillaryData0, + *ancillary_dims0, + }; + const TfLiteIntArray* const ancillary_dims1 = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT1); + static const TensorInDatum tid_ancillary1 = { + &kAncillaryData1, + *ancillary_dims1, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary0, &tid_ancillary1}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeLUT); + constexpr float output_scales_data[] = {1, 1.0f}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {1, 0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod0 = { + output_data0, + *output_dims, + kTfLiteInt8, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static const TensorOutDatum tod1 = { + output_data1, + *output_dims, + kTfLiteInt16, + *output_scales, + *output_zero_points, + 0, + {}, + }; 
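+  // The DECODE op consumes inputs in (encoded, ancillary) pairs and writes
+  // one output per pair: here inputs 0-1 decode into tod0 (int8_t) and
+  // inputs 2-3 decode into tod1 (int16_t).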
+ static constexpr std::initializer_list outputs = { + &tod0, &tod1}; + + const std::initializer_list expected = {kExpectLUT0, + kExpectLUT1}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneFloat) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) float output_data[std::size(kExpectPrune2)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune2}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + constexpr float kOutputScalesData[] = {0}; + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kOutputScalesData); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteFloat32, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune2}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune3)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune3}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + constexpr float kOutputScalesData[] = {0}; + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kOutputScalesData); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune3}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); 
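  // The prune tests here and below exercise the bitmap-style decode: the
  // encoded tensor appears to carry a one-bit-per-element mask plus the
  // packed non-pruned values, so a scalar sketch of the expected output is
  // roughly (MaskBitSet and packed_values are illustrative names):
  //   out[i] = MaskBitSet(i) ? *packed_values++ : zero_point;  // 0.0f for float
  // with per-channel zero points in the quantized variants.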
+} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune0)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune0}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScalesPrune0); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune0); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune0}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedAltAxisInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune4)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune4}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScalesPrune4); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune4); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + *kOutputScales, + *kOutputZeroPoints, + (kOutputDims->size - 1), + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune4}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedAltAxisInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static 
const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScalesPrune1); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune1); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt16, + *kOutputScales, + *kOutputZeroPoints, + (kOutputDims->size - 1), + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune1}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + float kScales[] = {2, 1.0f, 1.0f}; + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScales); + const int kZeroPoints[] = {2, 0, -1}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPoints); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt16, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune1}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(), + kTfLiteError); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index 3cc0a3a9432..2d5e54ba3fa 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -53,6 +53,7 @@ TFLMRegistration Register_CONCATENATION(); TFLMRegistration Register_CONV_2D(); TFLMRegistration Register_COS(); TFLMRegistration Register_CUMSUM(); +TFLMRegistration Register_DECODE(); TFLMRegistration Register_DEPTH_TO_SPACE(); TFLMRegistration Register_DEPTHWISE_CONV_2D(); TFLMRegistration Register_DEQUANTIZE(); diff --git a/tensorflow/lite/micro/kernels/xtensa/decode_state.cc 
b/tensorflow/lite/micro/kernels/xtensa/decode_state.cc new file mode 100644 index 00000000000..4feec409e15 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decode_state.cc @@ -0,0 +1,70 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state.h" + +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" +#include "tensorflow/lite/micro/micro_context.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h" +#endif // HIFI5 + +namespace tflite { + +DecodeState* DecodeState::CreateDecodeStateLUT( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStateLUT); +#else + constexpr size_t kBufferSize = sizeof(DecodeStateLUT); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStateLUT(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStateLUT(context, profiler); +#endif // HIFI5 + + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStatePrune( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStatePrune); +#else + constexpr size_t kBufferSize = sizeof(DecodeStatePrune); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStatePrune(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler); +#endif // HIFI5 + return dsp; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc new file mode 100644 index 00000000000..de5435f4b00 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc @@ -0,0 +1,609 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +void XtensaDecodeStateLUT::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = value_table_channel_stride_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + + value_table += stride; + if (elements_per_channel_t_rem) { + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 1)); /* Loading 48 bits for decoding 16 weight values */ + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidth3_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = value_table_channel_stride_; + + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + 
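  // The constants below drive the 3-bit fast path: each iteration loads 48
  // bits (sixteen 3-bit indices), isolates the indices with the shift/AND/OR
  // sequence, gathers them into byte lanes with the AE_SEL8X8 shuffles, and
  // then uses AE_DSEL8X8 as a byte-table lookup into the per-channel value
  // table (at most 8 entries at this bit width).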
ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL); + ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL); + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 6); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 3); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + 
elements_per_channel_t_rem); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidth2_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = value_table_channel_stride_; + + int elements_per_channel_t_by_5 = elements_per_channel_ >> 5; + int elements_per_channel_t_rem = elements_per_channel_ & 0x1F; + int elements_per_channel_t_rem_minus_16 = 0; + if (elements_per_channel_t_rem > 16) { + elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16; + } + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d0, d1, d2, d3, d4, d5; + ae_int8x8 q0, q1, q2, q3; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_5; j++) { + // AE_LA8X8_IP( d_index, align_index, pIn_tmp ); /* Loading 64 bits + // for decoding 32 weight values */ + + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 8); /* Loading 64 bits for decoding 32 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 2)); /* Loading 48 bits for decoding 16 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... 
+ + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem_minus_16); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt8_Xtensa( + int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint8_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + uint32_t index_1, index_2; + uint32_t mask_bits = (1 << compressed_bit_width_) - 1; + + for (int i = 0; i < num_channels_t; i++) { + elements_per_channel_t = elements_per_channel_; + /* if output pointer is not 2 byte aligned */ + if ((unsigned int)p_out_tmp & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + elements_per_channel_t = elements_per_channel_t - 1; + } + for (int j = 0; j < (elements_per_channel_t >> 1); j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, 2 * bw); + index_1 = (index >> compressed_bit_width_) & mask_bits; + index_2 = (index)&mask_bits; + ae_int8x8 d_tmp1 = AE_L8_X((const ae_int8*)value_table, index_1); + ae_int8x8 d_tmp2 = AE_L8_X((const ae_int8*)value_table, index_2); + ae_int16x4 d_tmp = + AE_MOVINT16X4_FROMINT8X8(AE_SEL8X8I(d_tmp2, d_tmp1, 21)); + AE_S16_0_IP(d_tmp, (ae_int16*)p_out_tmp, 2); + } + if (elements_per_channel_t & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + } + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt16_Xtensa( + int16_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint16_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int16* __restrict p_out_tmp = (ae_int16*)buffer; + 
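  // Generic bit-width path (this and the other ...WidthAny... variants):
  // each output element is produced by reading one `bw`-bit index from the
  // compressed stream and using it to index the per-channel value table.
  // A scalar sketch, with ReadBits standing in for the AE_LB_DB_IP reads:
  //   for (int c = 0; c < num_channels_; ++c) {
  //     for (int j = 0; j < elements_per_channel_; ++j) {
  //       *buffer++ = value_table[ReadBits(bw)];
  //     }
  //     value_table += value_table_channel_stride_;
  //   }
  // The alternate-axis branch instead walks all channels for each element
  // position before advancing.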
const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint16_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt32_Xtensa( + int32_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint32_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int32* __restrict p_out_tmp = (ae_int32*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint32_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt64_Xtensa( + int64_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint64_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int64* __restrict p_out_tmp = (ae_int64*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint64_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + value_table += stride; 
+ } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBuffer(int8_t* buffer) { + if (compressed_bit_width_ == 4 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x01)) { + DecompressToBufferWidth4_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else if (compressed_bit_width_ == 3 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x07)) { + DecompressToBufferWidth3_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else if (compressed_bit_width_ == 2 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x03)) { + DecompressToBufferWidth2_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } +} + +TfLiteStatus XtensaDecodeStateLUT::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + TFLITE_DCHECK(compressed_bit_width_ <= kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBuffer(static_cast(buffer)); + break; + case kTfLiteFloat32: + DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt8: + DecompressToBuffer(static_cast(buffer)); + break; + case kTfLiteInt16: + DecompressToBufferWidthAnyInt16_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt32: + DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt64: + DecompressToBufferWidthAnyInt64_Xtensa(static_cast(buffer)); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h new file mode 100644 index 00000000000..b614887a4cc --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h @@ -0,0 +1,57 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" + +namespace tflite { + +struct XtensaDecodeStateLUT : public DecodeStateLUT { + XtensaDecodeStateLUT() = delete; + + XtensaDecodeStateLUT(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeStateLUT(context, profiler) {} + + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + virtual ~XtensaDecodeStateLUT() = default; + + void DecompressToBuffer(int8_t* buffer); + + void DecompressToBufferWidth4_Xtensa(int8_t* buffer); + void DecompressToBufferWidth3_Xtensa(int8_t* buffer); + void DecompressToBufferWidth2_Xtensa(int8_t* buffer); + + void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer); + void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer); + void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer); + void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_ diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc new file mode 100644 index 00000000000..c237ee3b44f --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc @@ -0,0 +1,443 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h" + +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus XtensaDecodeStatePrune::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBufferInt8_Xtensa(buffer); + break; + case kTfLiteFloat32: + DecodeStatePrune::DecompressToBuffer(buffer); + break; + case kTfLiteInt8: + DecompressToBufferInt8_Xtensa(buffer); + break; + case kTfLiteInt16: + DecompressToBufferInt16_Xtensa(buffer); + break; + case kTfLiteInt32: + DecodeStatePrune::DecompressToBuffer(buffer); + break; + case kTfLiteInt64: + DecodeStatePrune::DecompressToBuffer(buffer); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +void XtensaDecodeStatePrune::DecompressToBufferInt8_Xtensa(void* buffer) { + if (num_channels_ > 1) { + DecompressToBufferPerChannelInt8_Xtensa(buffer); + return; + } + + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + int* __restrict p_mask32 = (int*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + const int count = count_indices_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 zero = single_zero_point_; + ae_int8x8 discarded; + + for (int i = 0; i < count >> 5; i++) { + // unpack elements + int mask = *p_mask32++; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 3); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // merge into elements + AE_MOVT8X16_L(discarded, data0, zero, data0, mask); + AE_MOVT8X16_L(discarded, data1, zero, data1, mask >> 8); + AE_MOVT8X16_H(discarded, data2, zero, data2, mask); + AE_MOVT8X16_H(discarded, data3, zero, data3, mask >> 8); + + // move merged elements to output + AE_S8X8X2_IP(data0, data1, (ae_int8x16*)pCoeff, 16); + AE_S8X8X2_IP(data2, data3, (ae_int8x16*)pCoeff, 16); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + ae_valignx2 align2 = AE_ZALIGN128(); + int8_t* __restrict p_mask8 = reinterpret_cast(p_mask32); + + // unpack and merge into remaining elements + int mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + AE_MOVT8X16_L(discarded, data0, zero, data0, mask); + if (count_rem > 8) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 0); + data1 = AE_SHFL8X8(data1, shfl1); + AE_MOVT8X16_L(discarded, data1, zero, data1, mask); + } + if (count_rem > 16) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data2, 
shfl2, align, p_weights, mask, 0); + data2 = AE_SHFL8X8(data2, shfl2); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask); + } + if (count_rem > 24) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data3 = AE_SHFL8X8(data3, shfl3); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + } + + // move merged elements to output + if (count_rem <= 16) { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } +} + +void XtensaDecodeStatePrune::DecompressToBufferPerChannelInt8_Xtensa( + void* buffer) { + if (use_alternate_axis_) { + DecompressToBufferPerChannelAltAxisInt8_Xtensa(buffer); + return; + } + + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + short* __restrict p_stream = (short*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_valignx2 align2 = AE_ZALIGN128(); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + const int count = elements_per_channel_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 discarded; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + for (size_t channel = 0; channel < num_channels_; channel++) { + ae_int8x8 zero = zero_points_[channel]; + uint32_t mask_low, mask_high; + + for (int i = 0; i < count >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // merge into elements + AE_MOVT8X16_H(discarded, data0, zero, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + + // move merged elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + AE_SA128POS_FP(align2, pCoeff); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + if (count_rem > 16) { + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LB_DB_IP((unsigned short*)p_stream, mask_low, count_rem - 16); + mask_low <<= 32 - count_rem; + } else { + AE_LB_DB_IP((unsigned short*)p_stream, mask_high, count_rem); + mask_high <<= 16 - count_rem; + mask_low = 0; + } + const int mask = (mask_high << 16) | mask_low; + + // unpack and merge into remaining elements + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + data0 = AE_SHFL8X8(data0, shfl0); + AE_MOVT8X16_H(discarded, data0, zero, data0, mask >> 8); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + data1 = AE_SHFL8X8(data1, shfl1); + AE_MOVT8X16_H(discarded, data1, zero, data1, mask); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + data2 = AE_SHFL8X8(data2, 
shfl2); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask >> 8); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data3 = AE_SHFL8X8(data3, shfl3); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + + // move merged elements to output + if (count_rem <= 16) { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } + } +} + +void XtensaDecodeStatePrune::DecompressToBufferPerChannelAltAxisInt8_Xtensa( + void* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + short* __restrict p_stream = (short*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_valignx2 align2 = AE_ZALIGN128(); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + int count = count_indices_ / num_channels_; + const int max_channels = num_channels_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 discarded; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + while (count-- > 0) { + ae_int8x8 zero0, zero1, zero2, zero3; + uint32_t mask_low, mask_high; + // p_zero is always 16 byte aligned due to copy during Setup(). + int8_t* __restrict p_zero = (int8_t*)zero_points_; + + for (int i = 0; i < max_channels >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // load values + AE_L8X8X2_IP(zero0, zero1, (ae_int8x16*)p_zero, 16); + AE_L8X8X2_IP(zero2, zero3, (ae_int8x16*)p_zero, 16); + + // merge into elements + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero2, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero3, data3, mask); + + // move merged elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + AE_SA128POS_FP(align2, pCoeff); + } + + const int count_rem = max_channels & 0x1F; + if (count_rem) { + if (count_rem > 16) { + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LB_DB_IP((unsigned short*)p_stream, mask_low, count_rem - 16); + mask_low <<= 32 - count_rem; + } else { + AE_LB_DB_IP((unsigned short*)p_stream, mask_high, count_rem); + mask_high <<= 16 - count_rem; + mask_low = 0; + } + const int mask = (mask_high << 16) | mask_low; + + // unpack remaining elements + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, 
shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // load values, merge into elements and + // move merged elements to output + ae_valignx2 align_zero = AE_LA128_PP(p_zero); + if (count_rem <= 16) { + AE_LAV8X8X2_XP(zero0, zero1, align_zero, (ae_int8x16*)p_zero, + count_rem); + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_LAV8X8X2_XP(zero0, zero1, align_zero, (ae_int8x16*)p_zero, 16); + AE_LAV8X8X2_XP(zero2, zero3, align_zero, (ae_int8x16*)p_zero, + count_rem & 0xF); + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero2, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero3, data3, mask); + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } + } +} + +void XtensaDecodeStatePrune::DecompressToBufferInt16_Xtensa(void* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int16x8* p_weights = (ae_int16x8*)value_table_; + int* __restrict p_mask32 = (int*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_int16x4 data0, data1, data2, data3; + ae_int16x4 data4, data5, data6, data7; + ae_int16x4 shfl0, shfl1, shfl2, shfl3; + ae_int16x4 shfl4, shfl5, shfl6, shfl7; + const int count = count_indices_; + int16_t* __restrict pCoeff = static_cast(buffer); + + for (int i = 0; i < count >> 5; i++) { + // unpack elements and merge 0 (zero) elements + int mask = *p_mask32++; + AE_LAVUNSQZ16X4_XP(data0, shfl0, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data1, shfl1, align, p_weights, mask, 0); + AE_LAVUNSQZ16X4_XP(data2, shfl2, align, p_weights, mask, 3); + AE_LAVUNSQZ16X4_XP(data3, shfl3, align, p_weights, mask, 2); + AE_LAVUNSQZ16X4_XP(data4, shfl4, align, p_weights, mask, 5); + AE_LAVUNSQZ16X4_XP(data5, shfl5, align, p_weights, mask, 4); + AE_LAVUNSQZ16X4_XP(data6, shfl6, align, p_weights, mask, 7); + AE_LAVUNSQZ16X4_XP(data7, shfl7, align, p_weights, mask, 6); + data0 = AE_SHFL16X4(data0, shfl0); + data1 = AE_SHFL16X4(data1, shfl1); + data2 = AE_SHFL16X4(data2, shfl2); + data3 = AE_SHFL16X4(data3, shfl3); + data4 = AE_SHFL16X4(data4, shfl4); + data5 = AE_SHFL16X4(data5, shfl5); + data6 = AE_SHFL16X4(data6, shfl6); + data7 = AE_SHFL16X4(data7, shfl7); + + // move merged elements to output + AE_S16X4X2_IP(data0, data1, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data2, data3, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data4, data5, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data6, data7, (ae_int16x8*)pCoeff, 16); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + ae_valignx2 align2 = AE_ZALIGN128(); + int8_t* __restrict p_mask8 = reinterpret_cast(p_mask32); + + // unpack and merge into remaining elements + int mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data0, shfl0, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data1, shfl1, align, p_weights, mask, 0); + data0 = AE_SHFL16X4(data0, shfl0); + data1 = AE_SHFL16X4(data1, shfl1); + if (count_rem > 8) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data3, shfl3, align, p_weights, mask, 0); + data2 = AE_SHFL16X4(data2, shfl2); + data3 = AE_SHFL16X4(data3, shfl3); + } + if (count_rem > 16) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data4, shfl4, align, 
p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data5, shfl5, align, p_weights, mask, 0); + data4 = AE_SHFL16X4(data4, shfl4); + data5 = AE_SHFL16X4(data5, shfl5); + } + if (count_rem > 24) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data6, shfl6, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data7, shfl7, align, p_weights, mask, 0); + data6 = AE_SHFL16X4(data6, shfl6); + data7 = AE_SHFL16X4(data7, shfl7); + } + + // move merged elements to output + if (count_rem <= 8) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, + count_rem << 1); + } else if (count_rem <= 16) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, + (count_rem - 8) << 1); + } else if (count_rem <= 24) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data4, data5, align2, (ae_int16x8*)pCoeff, + (count_rem - 16) << 1); + } else { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data4, data5, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data6, data7, align2, (ae_int16x8*)pCoeff, + (count_rem - 24) << 1); + } + AE_SA128POS_FP(align2, pCoeff); + } +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h new file mode 100644 index 00000000000..fb6935f3383 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h @@ -0,0 +1,51 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" + +namespace tflite { + +struct XtensaDecodeStatePrune : public DecodeStatePrune { + XtensaDecodeStatePrune() = delete; + + XtensaDecodeStatePrune(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeStatePrune(context, profiler) {} + + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + virtual ~XtensaDecodeStatePrune() = default; + + void DecompressToBufferInt8_Xtensa(void* buffer); + void DecompressToBufferPerChannelInt8_Xtensa(void* buffer); + void DecompressToBufferPerChannelAltAxisInt8_Xtensa(void* buffer); + void DecompressToBufferInt16_Xtensa(void* buffer); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 07f84692e44..e19aef80ccb 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -206,6 +206,11 @@ class MicroMutableOpResolver : public MicroOpResolver { ParseCumsum); } + TfLiteStatus AddDecode() { + const TFLMRegistration& registration = tflite::Register_DECODE(); + return AddCustom("TFLM_DECODE", ®istration); + } + TfLiteStatus AddDelay() { // TODO(b/286250473): change back name to "Delay" and remove namespace return AddCustom("SignalDelay", tflite::tflm_signal::Register_DELAY()); diff --git a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h index 344bcf3845c..a89a2806e92 100644 --- a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h +++ b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h @@ -23,7 +23,7 @@ limitations under the License. 
namespace tflite { -using TflmOpResolver = MicroMutableOpResolver<113>; +using TflmOpResolver = MicroMutableOpResolver<115>; inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddAbs()); @@ -45,6 +45,7 @@ inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddConv2D()); TF_LITE_ENSURE_STATUS(op_resolver.AddCos()); TF_LITE_ENSURE_STATUS(op_resolver.AddCumSum()); + TF_LITE_ENSURE_STATUS(op_resolver.AddDecode()); TF_LITE_ENSURE_STATUS(op_resolver.AddDelay()); TF_LITE_ENSURE_STATUS(op_resolver.AddDepthToSpace()); TF_LITE_ENSURE_STATUS(op_resolver.AddDepthwiseConv2D()); diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index a21765b3454..c3e1bbab3bf 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -386,6 +386,10 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_lut.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space.cc \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc index b05a0670248..4a8c1591445 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc @@ -124,6 +124,13 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa) MICROLITE_CC_KERNEL_SRCS += \ $(TENSORFLOW_ROOT)tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/kernels/kernel_util.cc + + # Additional kernel sources for DECODE operator support + ifeq ($(TARGET_ARCH), $(filter $(TARGET_ARCH), hifi5)) + MICROLITE_CC_KERNEL_SRCS += \ + $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc \ + $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc + endif endif # override KERNEL_OPTIMIZATION_LEVEL to enable higher performance @@ -131,3 +138,11 @@ endif $(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@
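For reference, a minimal sketch of how an application picks up the new operator
once this change lands (the op set besides DECODE is illustrative; AddDecode()
wraps AddCustom("TFLM_DECODE", tflite::Register_DECODE()) as added to
micro_mutable_op_resolver.h above):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

namespace {
// Size the resolver for the ops the model actually uses.
using DecodeOpResolver = tflite::MicroMutableOpResolver<4>;

TfLiteStatus RegisterOps(DecodeOpResolver& op_resolver) {
  TF_LITE_ENSURE_STATUS(op_resolver.AddDecode());  // custom op "TFLM_DECODE"
  TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
  TF_LITE_ENSURE_STATUS(op_resolver.AddReshape());
  TF_LITE_ENSURE_STATUS(op_resolver.AddSoftmax());
  return kTfLiteOk;
}
}  // namespace

On HIFI5 builds with OPTIMIZED_KERNEL_DIR=xtensa, the makefile additions above
compile the Xtensa-specific decode states, so the same resolver code
transparently uses the optimized LUT and prune paths.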