/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_

#include <algorithm>
#include <string>

#include "common/float16.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

enum class QuantizationType { NONE = 0, UINT8 };

// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
// Note: this API is closely related to embedding-network.proto.  The reason we
// have a separate API is that the proto may not be the only way of packaging
// these parameters.
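//
// Example usage (a minimal sketch; ConcreteParams is a hypothetical subclass
// that implements the pure virtual accessors declared below):
//
//   ConcreteParams params;
//   for (int i = 0; i < params.GetNumEmbeddingSpaces(); ++i) {
//     EmbeddingNetworkParams::Matrix embeddings = params.GetEmbeddingMatrix(i);
//     int num_features = params.GetNumFeaturesInEmbeddingSpace(i);
//     // ... look up |num_features| feature rows of |embeddings| ...
//   }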
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}

  // **** High-level API.

  // Simple representation of a matrix.  This small struct, which does not own
  // any resources, intentionally supports copy / assign, to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows;

    // Number of columns.
    int cols;

    QuantizationType quant_type;

    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order).  Not owned.
    const void *elements;

    // Quantization scales: one scale for each row.
    const float16 *quant_scales;
  };
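
  // Example (a sketch, not part of the API): reading element (row, col) of an
  // unquantized Matrix, using the row-major layout documented above.
  //
  //   float GetElement(const EmbeddingNetworkParams::Matrix &m,
  //                    int row, int col) {
  //     TC_DCHECK(m.quant_type == QuantizationType::NONE);
  //     return reinterpret_cast<const float *>(m.elements)[row * m.cols + col];
  //   }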

  // Returns number of embedding spaces.
  int GetNumEmbeddingSpaces() const {
    if (embeddings_size() != embedding_num_features_size()) {
      TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size()
                    << " != " << embedding_num_features_size();
    }
    return std::max(0,
                    std::min(embeddings_size(), embedding_num_features_size()));
  }

  // Returns embedding matrix for the i-th embedding space.
  //
  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
  // otherwise.
  Matrix GetEmbeddingMatrix(int i) const {
    TC_DCHECK(InRange(i, embeddings_size()));
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }
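
  // For a quantized embedding matrix (quant_type == UINT8), elements points
  // to uint8 values, and quant_scales holds one float16 scale per row.  A
  // sketch of accessing the raw quantized data (the dequantization itself is
  // performed by the network code that consumes these parameters, not by this
  // header):
  //
  //   Matrix m = params.GetEmbeddingMatrix(i);
  //   if (m.quant_type == QuantizationType::UINT8) {
  //     const uint8 *bytes = reinterpret_cast<const uint8 *>(m.elements);
  //     // bytes[row * m.cols + col] is the quantized element (row, col);
  //     // m.quant_scales[row] is the float16 scale for that row.
  //   }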

  // Returns number of features in the i-th embedding space.
  //
  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
  // otherwise.
  int GetNumFeaturesInEmbeddingSpace(int i) const {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return std::max(0, embedding_num_features(i));
  }

  // Returns number of hidden layers in the neural network.  Each such layer
  // has a weight matrix and a bias vector (a matrix with one column).
  int GetNumHiddenLayers() const {
    if (hidden_size() != hidden_bias_size()) {
      TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size()
                    << " != " << hidden_bias_size();
    }
    return std::max(0, std::min(hidden_size(), hidden_bias_size()));
  }
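
  // Example (a sketch): visiting each hidden layer together with its bias.
  //
  //   for (int i = 0; i < params.GetNumHiddenLayers(); ++i) {
  //     Matrix weights = params.GetHiddenLayerMatrix(i);
  //     Matrix bias = params.GetHiddenLayerBias(i);
  //     // ... apply |weights| and |bias| to the previous layer's output ...
  //   }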

  // Returns weight matrix for the i-th hidden layer.
  //
  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
  // otherwise.
  Matrix GetHiddenLayerMatrix(int i) const {
    TC_DCHECK(InRange(i, hidden_size()));
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_weights(i);
    return matrix;
  }

  // Returns bias matrix for the i-th hidden layer.  Technically a Matrix, but
  // we expect it to be a column vector (i.e., num cols is 1).
  //
  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
  // otherwise.
  Matrix GetHiddenLayerBias(int i) const {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }

  // Returns true if a softmax layer exists.
  bool HasSoftmaxLayer() const {
    if (softmax_size() != softmax_bias_size()) {
      TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size()
                    << " != " << softmax_bias_size();
    }
    return (softmax_size() == 1) && (softmax_bias_size() == 1);
  }
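
  // Example (a sketch): guarded access to the softmax parameters.
  //
  //   if (params.HasSoftmaxLayer()) {
  //     Matrix softmax = params.GetSoftmaxMatrix();
  //     Matrix softmax_bias = params.GetSoftmaxBias();
  //     // ... use |softmax| and |softmax_bias| to compute the logits ...
  //   }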

  // Returns weight matrix for the softmax layer.
  //
  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
  // behavior otherwise.
  Matrix GetSoftmaxMatrix() const {
    TC_DCHECK(softmax_size() == 1);
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_weights(0);
    return matrix;
  }

  // Returns bias for the softmax layer.  Technically a Matrix, but we expect
  // it to be a column vector (i.e., num cols is 1).
  //
  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
  // behavior otherwise.
  Matrix GetSoftmaxBias() const {
    TC_DCHECK(softmax_bias_size() == 1);
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }

  // Copies the EmbeddingNetwork-related parameters from the underlying
  // TaskSpec into task_context.  Returns true on success, false on error.
  virtual bool UpdateTaskContextParameters(TaskContext *task_context) {
    const TaskSpec *task_spec = GetTaskSpec();
    if (task_spec == nullptr) {
      TC_LOG(ERROR) << "Unable to get TaskSpec";
      return false;
    }
    for (const TaskSpec::Parameter &parameter : task_spec->parameter()) {
      task_context->SetParameter(parameter.name(), parameter.value());
    }
    return true;
  }
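
  // Example (a sketch): propagating the parameters into a fresh TaskContext.
  //
  //   TaskContext context;
  //   if (!params.UpdateTaskContextParameters(&context)) {
  //     // Handle the error, e.g., refuse to load the model.
  //   }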

  // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related
  // parameters.  Returns nullptr in case of problems.  Ownership of the
  // returned pointer is *not* transferred to the caller.
  virtual const TaskSpec *GetTaskSpec() {
    TC_LOG(ERROR) << "Not implemented";
    return nullptr;
  }

 protected:
  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving an equivalent
  //   function call on proto, the original proto (of type
  //   EmbeddingNetworkProto) which was used to generate the C++ code.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.
  //
  // * Behavior is undefined when trying to retrieve a piece of data that does
  //   not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2.
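
  // Example (a sketch): how a hypothetical concrete subclass might implement
  // the optional-field convention for softmax.  Member names are made up for
  // illustration.
  //
  //   class MyParams : public EmbeddingNetworkParams {
  //    protected:
  //     // proto.has_softmax() maps to a size of 0 or 1.
  //     int softmax_size() const override { return has_softmax_ ? 1 : 0; }
  //     int softmax_num_rows(int i) const override { return softmax_rows_; }
  //     // ... remaining accessors follow the same pattern ...
  //   };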

  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;

  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
  // order.  NOTE: for unquantized embeddings, this returns a pointer to float;
  // for quantized embeddings, this returns a pointer to uint8.
  virtual const void *embeddings_weights(int i) const = 0;

  // Returns quantization type for proto.embeddings(i).  The default
  // implementation indicates no quantization.
  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to quantization scales for proto.embeddings(i): one
  // float16 scale per matrix row.  The default implementation returns nullptr
  // (no quantization).
  virtual const float16 *embeddings_quant_scales(int i) const {
    return nullptr;
  }

  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns proto.hidden_size().
  virtual int hidden_size() const = 0;

  // Returns proto.hidden(i).rows().
  virtual int hidden_num_rows(int i) const = 0;

  // Returns proto.hidden(i).cols().
  virtual int hidden_num_cols(int i) const = 0;

  // Returns pointer to beginning of array of floats with all values from
  // proto.hidden(i).
  virtual const void *hidden_weights(int i) const = 0;

  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;

  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;

  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;

  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;

  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;

  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;

  // Returns true if and only if index is in range [0, size).  Logs an error
  // message otherwise.
  static bool InRange(int index, int size) {
    if ((index < 0) || (index >= size)) {
      TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")";
      return false;
    }
    return true;
  }
};  // class EmbeddingNetworkParams

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_