/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>

#include "lang_id/common/embedding-network-params.h"
#include "lang_id/common/flatbuffers/embedding-network_generated.h"
#include "lang_id/common/lite_base/float16.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/stringpiece.h"

namespace libtextclassifier3 {
namespace mobile {

// Flatbuffer-backed implementation of the EmbeddingNetworkParams interface.
//
// The flatbuffer schema is described in embedding-network.fbs.
class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams {
 public:
  // Builds an instance on top of the flatbuffer stored in |bytes|.
  //
  // IMPORTANT #1: |bytes| is not copied (to avoid overhead), so the caller must
  // keep the underlying memory alive for as long as this object is in use.
  //
  // IMPORTANT #2: right after construction, we suggest you call is_valid() and
  // refrain from calling any other method if it returns false.
  explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes);

  bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override {
    // Only the parameters of the Neurosis neural network are available through
    // this class; it has no access to the overall TaskContext.
    SAFTM_LOG(DFATAL) << "Not supported";
    return false;
  }

  bool is_valid() const override { return valid_; }

  // ---- Embedding matrices: one per input chunk. ----

  int embeddings_size() const override { return SafeGetNumInputChunks(); }

  int embeddings_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetEmbeddingMatrix(i));
  }

  int embeddings_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetEmbeddingMatrix(i));
  }

  const void *embeddings_weights(int i) const override {
    return SafeGetValuesOfMatrix(SafeGetEmbeddingMatrix(i));
  }

  QuantizationType embeddings_quant_type(int i) const override {
    return SafeGetQuantizationType(SafeGetEmbeddingMatrix(i));
  }

  const float16 *embeddings_quant_scales(int i) const override {
    return SafeGetScales(SafeGetEmbeddingMatrix(i));
  }

  // ---- Hidden layers: weights. ----

  int hidden_size() const override {
    // The last layer is always the softmax layer, hence it is not counted as a
    // hidden layer.  The result is clamped so that it is never negative.
    const int num_hidden_layers = SafeGetNumLayers() - 1;
    return (num_hidden_layers > 0) ? num_hidden_layers : 0;
  }

  int hidden_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetLayerWeights(i));
  }

  int hidden_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetLayerWeights(i));
  }

  QuantizationType hidden_weights_quant_type(int i) const override {
    return SafeGetQuantizationType(SafeGetLayerWeights(i));
  }

  const void *hidden_weights(int i) const override {
    return SafeGetValuesOfMatrix(SafeGetLayerWeights(i));
  }

  // ---- Hidden layers: biases (read through SafeGetValues()). ----

  int hidden_bias_size() const override { return hidden_size(); }

  int hidden_bias_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetLayerBias(i));
  }

  int hidden_bias_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetLayerBias(i));
  }

  const void *hidden_bias_weights(int i) const override {
    return SafeGetValues(SafeGetLayerBias(i));
  }

  // ---- Softmax layer: weights.  The index |i| is ignored by these methods;
  // they always refer to the layer returned by SafeGetSoftmaxLayer(). ----

  int softmax_size() const override {
    if (SafeGetNumLayers() > 0) {
      return 1;
    }
    return 0;
  }

  int softmax_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetSoftmaxWeights());
  }

  int softmax_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetSoftmaxWeights());
  }

  QuantizationType softmax_weights_quant_type(int i) const override {
    return SafeGetQuantizationType(SafeGetSoftmaxWeights());
  }

  const void *softmax_weights(int i) const override {
    return SafeGetValuesOfMatrix(SafeGetSoftmaxWeights());
  }

  // ---- Softmax layer: bias (read through SafeGetValues()).  As above, the
  // index |i| is ignored. ----

  int softmax_bias_size() const override { return softmax_size(); }

  int softmax_bias_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetSoftmaxBias());
  }

  int softmax_bias_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetSoftmaxBias());
  }

  const void *softmax_bias_weights(int i) const override {
    return SafeGetValues(SafeGetSoftmaxBias());
  }

  // ---- Number of features for each input chunk. ----

  int embedding_num_features_size() const override {
    return SafeGetNumInputChunks();
  }

  // Returns the num_features() of input chunk #i, or 0 if |i| is out of range
  // or the chunk can't be retrieved.
  int embedding_num_features(int i) const override {
    const bool index_ok =
        InRangeIndex(i, embedding_num_features_size(), "embedding num features");
    if (index_ok) {
      const saft_fbs::InputChunk *chunk = SafeGetInputChunk(i);
      if (chunk != nullptr) {
        return chunk->num_features();
      }
    }
    return 0;
  }

  bool has_is_precomputed() const override { return false; }
  bool is_precomputed() const override { return false; }

 private:
  // Checks that index is in [0, limit); returns true if and only if that's the
  // case.  info should point to a zero-terminated array of chars (ideally a
  // string literal, e.g., "layer") describing what the index refers to; it is
  // used to produce more informative log messages.
  static bool InRangeIndex(int index, int limit, const char *info);

  // Safe version of network_->input_chunks()->size(): returns 0 if one of the
  // dereferences is not safe (i.e., nullptr).
  int SafeGetNumInputChunks() const;

  // Safe version of network_->input_chunks()->Get(i): returns nullptr if one
  // of the dereferences is not safe (i.e., nullptr).
  const saft_fbs::InputChunk *SafeGetInputChunk(int i) const;

  // Safe version of network_->input_chunks()->Get(i)->embedding(): returns
  // nullptr if one of the dereferences is not safe (i.e., nullptr).
  const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const;

  // Safe version of network_->layers()->size(): returns 0 if one of the
  // dereferences is not safe (i.e., nullptr).
  int SafeGetNumLayers() const;

  // Safe version of network_->layers()->Get(i): returns nullptr if one of the
  // dereferences is not safe (i.e., nullptr).
  const saft_fbs::NeuralLayer *SafeGetLayer(int i) const;

  // Safe version of network_->layers()->Get(i)->weights(): returns nullptr if
  // one of the dereferences is not safe (i.e., nullptr).
  const saft_fbs::Matrix *SafeGetLayerWeights(int i) const;

  // Safe version of network_->layers()->Get(i)->bias(): returns nullptr if one
  // of the dereferences is not safe (i.e., nullptr).
  const saft_fbs::Matrix *SafeGetLayerBias(int i) const;

  // Returns matrix->rows(), or 0 if |matrix| is nullptr.
  static int SafeGetNumRows(const saft_fbs::Matrix *matrix) {
    if (matrix == nullptr) {
      return 0;
    }
    return matrix->rows();
  }

  // Returns matrix->cols(), or 0 if |matrix| is nullptr.
  static int SafeGetNumCols(const saft_fbs::Matrix *matrix) {
    if (matrix == nullptr) {
      return 0;
    }
    return matrix->cols();
  }

  // Safe version of matrix->values()->data(): returns nullptr if one of the
  // dereferences is not safe (i.e., nullptr).
  static const float *SafeGetValues(const saft_fbs::Matrix *matrix);

  // Safe version of matrix->quantized_values()->data(): returns nullptr if one
  // of the dereferences is not safe (i.e., nullptr).
  static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix);

  // Safe version of matrix->scales()->data(): returns nullptr if one of the
  // dereferences is not safe (i.e., nullptr).
  static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix);

  // Returns the last layer, i.e., network_->layers()->Get(last_index), where
  // last_index = SafeGetNumLayers() - 1.  Returns nullptr if there is no layer
  // at all, or if one of the dereferences is not safe (i.e., nullptr).
  const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const;

  // Returns the weights of the softmax layer, or nullptr if that layer can't
  // be retrieved.
  const saft_fbs::Matrix *SafeGetSoftmaxWeights() const {
    const saft_fbs::NeuralLayer *softmax = SafeGetSoftmaxLayer();
    if (softmax == nullptr) {
      return nullptr;
    }
    return softmax->weights();
  }

  // Returns the bias of the softmax layer, or nullptr if that layer can't be
  // retrieved.
  const saft_fbs::Matrix *SafeGetSoftmaxBias() const {
    const saft_fbs::NeuralLayer *softmax = SafeGetSoftmaxLayer();
    if (softmax == nullptr) {
      return nullptr;
    }
    return softmax->bias();
  }

  // Returns the quantization type of |matrix|, or NONE in case of problems
  // (e.g., |matrix| is nullptr, or its quantization type is unknown).
  QuantizationType SafeGetQuantizationType(
      const saft_fbs::Matrix *matrix) const;

  // Returns a pointer to the values of |matrix|, in row-major order.  Depending
  // on the quantization, the values are float, uint8, or float16.  Returns
  // nullptr in case of a problem.
  const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const;

  // Runs some validity checks, e.g., that the dimensions of the network layers
  // match.  Also checks that all pointers we return are inside the |bytes|
  // passed to the constructor, such that a client that reads from those
  // pointers will not run into trouble.
  bool ValidityChecking(StringPiece bytes) const;

  // True if these params are valid.  May be false if the flatbuffer passed to
  // the constructor was corrupted.  We prefer setting this to false over
  // CHECK-failing.
  bool valid_ = false;

  // EmbeddingNetwork flatbuffer from the bytes passed as parameter to the
  // constructor; see the constructor doc.
  const saft_fbs::EmbeddingNetwork *network_ = nullptr;
};

}  // namespace mobile
}  // namespace libtextclassifier3

#endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_