/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Contains classes that can execute different models/parts of a model.
#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#include <memory>
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
// Returns an owning pointer to a tflite::OpResolver.
// Only the declaration lives in this header; which ops the resolver registers
// is decided by the out-of-line definition.
std::unique_ptr<tflite::OpResolver> BuildOpResolver();
// Produces an immutable tflite::FlatBufferModel from a tflite::Model spec.
// Only the declaration lives in this header. The returned pointer may be null:
// TfLiteModelExecutor::FromModelSpec() checks for that and reports it to its
// caller as nullptr.
// NOTE(review): whether the returned model keeps referring to the memory behind
// the argument is not visible here - confirm against the definition.
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
const tflite::Model*);
// Produces an immutable tflite::FlatBufferModel from a flatbuffers byte vector.
// Only the declaration lives in this header. The returned pointer may be null:
// TfLiteModelExecutor::FromBuffer() checks for that and reports it to its
// caller as nullptr.
// NOTE(review): whether the returned model keeps referring to the bytes of the
// argument is not visible here - confirm against the definition.
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
const flatbuffers::Vector<uint8_t>*);
// Executor for the text selection prediction and classification models.
class TfLiteModelExecutor {
public:
static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
const tflite::Model* model_spec) {
auto model = TfLiteModelFromModelSpec(model_spec);
if (!model) {
return nullptr;
}
return std::unique_ptr<TfLiteModelExecutor>(
new TfLiteModelExecutor(std::move(model)));
}
static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
auto model = TfLiteModelFromBuffer(model_spec_buffer);
if (!model) {
return nullptr;
}
return std::unique_ptr<TfLiteModelExecutor>(
new TfLiteModelExecutor(std::move(model)));
}
// Creates an Interpreter for the model that serves as a scratch-pad for the
// inference. The Interpreter is NOT thread-safe.
std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
template <typename T>
void SetInput(const int input_index, const TensorView<T>& input_data,
tflite::Interpreter* interpreter) const {
input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
input_data.size());
}
template <typename T>
void SetInput(const int input_index, const std::vector<T>& input_data,
tflite::Interpreter* interpreter) const {
std::copy(input_data.begin(), input_data.end(),
interpreter->typed_input_tensor<T>(input_index));
}
template <typename T>
void SetInput(const int input_index, const T input_value,
tflite::Interpreter* interpreter) const {
TfLiteTensor* input_tensor =
interpreter->tensor(interpreter->inputs()[input_index]);
switch (input_tensor->type) {
case kTfLiteFloat32:
*(input_tensor->data.f) = input_value;
break;
case kTfLiteInt32:
*(input_tensor->data.i32) = input_value;
break;
case kTfLiteUInt8:
*(input_tensor->data.uint8) = input_value;
break;
case kTfLiteInt64:
*(input_tensor->data.i64) = input_value;
break;
case kTfLiteBool:
*(input_tensor->data.b) = input_value;
break;
case kTfLiteInt16:
*(input_tensor->data.i16) = input_value;
break;
case kTfLiteInt8:
*(input_tensor->data.int8) = input_value;
break;
default:
break;
}
}
template <typename T>
TensorView<T> OutputView(const int output_index,
const tflite::Interpreter* interpreter) const {
const TfLiteTensor* output_tensor =
interpreter->tensor(interpreter->outputs()[output_index]);
return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
std::vector<int>(output_tensor->dims->data,
output_tensor->dims->data +
output_tensor->dims->size));
}
template <typename T>
std::vector<T> Output(const int output_index,
const tflite::Interpreter* interpreter) const {
TensorView<T> output_view = OutputView<T>(output_index, interpreter);
return std::vector<T>(output_view.data(),
output_view.data() + output_view.size());
}
protected:
explicit TfLiteModelExecutor(
std::unique_ptr<const tflite::FlatBufferModel> model);
std::unique_ptr<const tflite::FlatBufferModel> model_;
std::unique_ptr<tflite::OpResolver> resolver_;
};
// Explicit specialization of SetInput() for a std::vector<std::string> input.
// Only declared here; the definition lives outside this header. The
// declaration has to be visible to every user of the header so that no
// translation unit implicitly instantiates the generic std::copy-based version
// for strings.
// NOTE(review): presumably fills a TFLite string tensor (string_util.h is
// included above) - confirm against the definition.
template <>
void TfLiteModelExecutor::SetInput(const int input_index,
const std::vector<std::string>& input_data,
tflite::Interpreter* interpreter) const;
// Explicit specialization of Output() for T = tflite::StringRef.
// Only declared here; the definition lives outside this header, and it
// replaces the generic pointer-range copy for this element type.
// NOTE(review): a StringRef is presumably a non-owning reference into the
// interpreter's tensor memory, so the result would be tied to the
// interpreter's lifetime - confirm against the definition.
template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
const int output_index, const tflite::Interpreter* interpreter) const;
// Explicit specialization of Output() for T = std::string.
// Only declared here; the definition lives outside this header, and it
// replaces the generic pointer-range copy for this element type.
// NOTE(review): presumably returns owned copies of the strings stored in the
// output tensor - confirm against the definition.
template <>
std::vector<std::string> TfLiteModelExecutor::Output(
const int output_index, const tflite::Interpreter* interpreter) const;
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_