/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Contains classes that can execute different models/parts of a model. #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ #define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ #include <memory> #include "utils/base/logging.h" #include "utils/tensor-view.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { std::unique_ptr<tflite::OpResolver> BuildOpResolver(); std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( const tflite::Model*); std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( const flatbuffers::Vector<uint8_t>*); // Executor for the text selection prediction and classification models. class TfLiteModelExecutor { public: static std::unique_ptr<TfLiteModelExecutor> FromModelSpec( const tflite::Model* model_spec) { auto model = TfLiteModelFromModelSpec(model_spec); if (!model) { return nullptr; } return std::unique_ptr<TfLiteModelExecutor>( new TfLiteModelExecutor(std::move(model))); } static std::unique_ptr<TfLiteModelExecutor> FromBuffer( const flatbuffers::Vector<uint8_t>* model_spec_buffer) { auto model = TfLiteModelFromBuffer(model_spec_buffer); if (!model) { return nullptr; } return std::unique_ptr<TfLiteModelExecutor>( new TfLiteModelExecutor(std::move(model))); } // Creates an Interpreter for the model that serves as a scratch-pad for the // inference. The Interpreter is NOT thread-safe. std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; template <typename T> void SetInput(const int input_index, const TensorView<T>& input_data, tflite::Interpreter* interpreter) const { input_data.copy_to(interpreter->typed_input_tensor<T>(input_index), input_data.size()); } template <typename T> void SetInput(const int input_index, const std::vector<T>& input_data, tflite::Interpreter* interpreter) const { std::copy(input_data.begin(), input_data.end(), interpreter->typed_input_tensor<T>(input_index)); } template <typename T> void SetInput(const int input_index, const T input_value, tflite::Interpreter* interpreter) const { TfLiteTensor* input_tensor = interpreter->tensor(interpreter->inputs()[input_index]); switch (input_tensor->type) { case kTfLiteFloat32: *(input_tensor->data.f) = input_value; break; case kTfLiteInt32: *(input_tensor->data.i32) = input_value; break; case kTfLiteUInt8: *(input_tensor->data.uint8) = input_value; break; case kTfLiteInt64: *(input_tensor->data.i64) = input_value; break; case kTfLiteBool: *(input_tensor->data.b) = input_value; break; case kTfLiteInt16: *(input_tensor->data.i16) = input_value; break; case kTfLiteInt8: *(input_tensor->data.int8) = input_value; break; default: break; } } template <typename T> TensorView<T> OutputView(const int output_index, const tflite::Interpreter* interpreter) const { const TfLiteTensor* output_tensor = interpreter->tensor(interpreter->outputs()[output_index]); return TensorView<T>(interpreter->typed_output_tensor<T>(output_index), std::vector<int>(output_tensor->dims->data, output_tensor->dims->data + output_tensor->dims->size)); } template <typename T> std::vector<T> Output(const int output_index, const tflite::Interpreter* interpreter) const { TensorView<T> output_view = OutputView<T>(output_index, interpreter); return std::vector<T>(output_view.data(), output_view.data() + output_view.size()); } protected: explicit TfLiteModelExecutor( std::unique_ptr<const tflite::FlatBufferModel> model); std::unique_ptr<const tflite::FlatBufferModel> model_; std::unique_ptr<tflite::OpResolver> resolver_; }; template <> void TfLiteModelExecutor::SetInput(const int input_index, const std::vector<std::string>& input_data, tflite::Interpreter* interpreter) const; template <> std::vector<tflite::StringRef> TfLiteModelExecutor::Output( const int output_index, const tflite::Interpreter* interpreter) const; template <> std::vector<std::string> TfLiteModelExecutor::Output( const int output_index, const tflite::Interpreter* interpreter) const; } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_