/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

// Builds the op resolver used when constructing interpreters for the models.
// NOTE(review): the set of registered ops is decided in the .cc file — confirm
// there before relying on a particular op being available.
std::unique_ptr<tflite::OpResolver> BuildOpResolver();

// Creates a FlatBufferModel for the flatbuffer model `model_spec`.
// A null result signals failure (TfLiteModelExecutor::FromModelSpec checks it).
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
    const tflite::Model* model_spec);

// Creates a FlatBufferModel from the serialized model bytes in
// `model_spec_buffer`.
// A null result signals failure (TfLiteModelExecutor::FromBuffer checks it).
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
    const flatbuffers::Vector<uint8_t>* model_spec_buffer);

// Executor for the text selection prediction and classification models.
//
// Owns a loaded TFLite model together with an op resolver. Per-inference state
// lives in a separate tflite::Interpreter obtained from CreateInterpreter();
// the templated helpers below copy data into that interpreter's input tensors
// and read back its output tensors.
class TfLiteModelExecutor {
 public:
  // Creates an executor for the flatbuffer model `model_spec`.
  // Returns nullptr if TfLiteModelFromModelSpec() does not produce a model.
  static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    // The constructor is not public, so std::make_unique cannot be used here.
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  // Creates an executor for the serialized model in `model_spec_buffer`.
  // Returns nullptr if TfLiteModelFromBuffer() does not produce a model.
  static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  // Creates an Interpreter for the model that serves as a scratch-pad for the
  // inference. The Interpreter is NOT thread-safe.
  // NOTE(review): failure behavior (e.g. a nullptr return) is defined in the
  // .cc file — confirm there.
  std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;

  // Copies `input_data` into input tensor `input_index` of `interpreter`.
  //
  // Nothing is written if the tensor's element type does not match T
  // (typed_input_tensor<T>() returns nullptr in that case). At most as many
  // elements as the tensor's buffer can hold are copied, so oversized input is
  // truncated instead of being written past the end of the buffer.
  template <typename T>
  void SetInput(const int input_index, const TensorView<T>& input_data,
                tflite::Interpreter* interpreter) const {
    T* const dest = interpreter->typed_input_tensor<T>(input_index);
    if (dest == nullptr) {
      return;
    }
    const std::size_t count = std::min<std::size_t>(
        static_cast<std::size_t>(input_data.size()),
        InputCapacity<T>(input_index, interpreter));
    std::copy_n(input_data.data(), count, dest);
  }

  // Copies `input_data` into input tensor `input_index` of `interpreter`.
  //
  // Same contract as the TensorView overload: no write on an element-type
  // mismatch, and input longer than the tensor buffer is truncated.
  template <typename T>
  void SetInput(const int input_index, const std::vector<T>& input_data,
                tflite::Interpreter* interpreter) const {
    T* const dest = interpreter->typed_input_tensor<T>(input_index);
    if (dest == nullptr) {
      return;
    }
    const std::size_t count = std::min<std::size_t>(
        input_data.size(), InputCapacity<T>(input_index, interpreter));
    // begin()-based copy also works for std::vector<bool>, which has no data().
    std::copy_n(input_data.begin(), count, dest);
  }

  // Writes the scalar `input_value` into the first element of input tensor
  // `input_index`, converting it to the tensor's element type. Tensor types
  // not listed in the switch are silently ignored.
  template <typename T>
  void SetInput(const int input_index, const T input_value,
                tflite::Interpreter* interpreter) const {
    TfLiteTensor* input_tensor =
        interpreter->tensor(interpreter->inputs()[input_index]);
    switch (input_tensor->type) {
      case kTfLiteFloat32:
        *(input_tensor->data.f) = input_value;
        break;
      case kTfLiteInt32:
        *(input_tensor->data.i32) = input_value;
        break;
      case kTfLiteUInt8:
        *(input_tensor->data.uint8) = input_value;
        break;
      case kTfLiteInt64:
        *(input_tensor->data.i64) = input_value;
        break;
      case kTfLiteBool:
        *(input_tensor->data.b) = input_value;
        break;
      case kTfLiteInt16:
        *(input_tensor->data.i16) = input_value;
        break;
      case kTfLiteInt8:
        *(input_tensor->data.int8) = input_value;
        break;
      default:
        break;
    }
  }

  // Returns a non-owning view of output tensor `output_index`, carrying the
  // tensor's shape. The view points into the interpreter's tensor buffer. Its
  // data pointer is null if the tensor's element type does not match T
  // (typed_output_tensor<T>() returns nullptr in that case).
  template <typename T>
  TensorView<T> OutputView(const int output_index,
                           const tflite::Interpreter* interpreter) const {
    const TfLiteTensor* output_tensor =
        interpreter->tensor(interpreter->outputs()[output_index]);
    return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
                         std::vector<int>(output_tensor->dims->data,
                                          output_tensor->dims->data +
                                              output_tensor->dims->size));
  }

  // Returns a copy of the contents of output tensor `output_index` in
  // flattened order. Returns an empty vector if the tensor's element type
  // does not match T, instead of copying from a null pointer.
  template <typename T>
  std::vector<T> Output(const int output_index,
                        const tflite::Interpreter* interpreter) const {
    TensorView<T> output_view = OutputView<T>(output_index, interpreter);
    if (output_view.data() == nullptr) {
      return {};
    }
    return std::vector<T>(output_view.data(),
                          output_view.data() + output_view.size());
  }

 protected:
  // Takes ownership of `model`; defined in the .cc file.
  explicit TfLiteModelExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model);

  // The loaded model.
  std::unique_ptr<const tflite::FlatBufferModel> model_;
  // Op resolver used for interpreters of this model.
  // NOTE(review): presumably initialized via BuildOpResolver() in the
  // constructor — confirm in the .cc file.
  std::unique_ptr<tflite::OpResolver> resolver_;

 private:
  // Number of elements of type T that fit in the buffer backing input tensor
  // `input_index`, derived from the tensor's allocated byte size.
  template <typename T>
  static std::size_t InputCapacity(const int input_index,
                                   const tflite::Interpreter* interpreter) {
    const TfLiteTensor* tensor =
        interpreter->tensor(interpreter->inputs()[input_index]);
    return tensor->bytes / sizeof(T);
  }
};

template <>
void TfLiteModelExecutor::SetInput(const int input_index,
                                   const std::vector<std::string>& input_data,
                                   tflite::Interpreter* interpreter) const;

template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

template <>
std::vector<std::string> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_