/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_

#include <string>

#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/lite_base/float16.h"
#include "lang_id/common/lite_base/logging.h"

namespace libtextclassifier3 {

enum class QuantizationType {
  NONE = 0,

  // Quantization to 8 bit unsigned ints.
  UINT8,

  // Quantization to 4 bit unsigned ints.
  UINT4,

  // Quantization to 16 bit floats, the type defined in
  // lang_id/common/lite_base/float16.h
  FLOAT16,

  // NOTE: for backward compatibility, if you add a new value to this enum, add
  // it *at the end*, such that you do not change the integer values of the
  // existing enum values.
};

// Converts "UINT8" -> QuantizationType::UINT8, and so on.
QuantizationType ParseQuantizationType(const string &s);

// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
// In fact, we provide two APIs: a high-level (and highly-recommended) API,
// with methods named in BigCamel notation (e.g., GetEmbeddingMatrix()), and a
// low-level API, using C-style names (e.g., softmax_num_cols()).
//
// Note: the API below is meant to allow the inference code (the class
// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly,
// with no need for transposing any matrix (which would require extra overhead
// on mobile devices).  Hence, as indicated by the comments for the API
// methods, some of the matrices below are the transposes of the corresponding
// matrices from the original proto.
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}

  // Returns true if these params are valid.  False otherwise (e.g., if the
  // underlying data is corrupted).  If is_valid() returns false, clients
  // should not call any other method on that instance of
  // EmbeddingNetworkParams.  If is_valid() returns true, then calls to the
  // API methods below should not crash *if they are called with index
  // parameters in bounds*.  E.g., if is_valid() and
  // 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i) should not crash.
  virtual bool is_valid() const = 0;

  // **** High-level API.

  // Simple representation of a matrix.  This small struct doesn't own any
  // resources; it intentionally supports copy / assign, to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows = 0;

    // Number of columns.
    int cols = 0;

    QuantizationType quant_type = QuantizationType::NONE;

    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order).  Not owned.
    const void *elements = nullptr;

    // Quantization scales: one scale for each row.
    const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
  };
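
  // A hedged sketch (not part of this API) of how a client might read one
  // element of a UINT8-quantized Matrix.  Float16To32() stands in for a
  // float16-to-float conversion helper and is an assumption here, as is the
  // per-row scaling scheme:
  //
  //   const EmbeddingNetworkParams::Matrix m = params->GetEmbeddingMatrix(0);
  //   if (m.quant_type == QuantizationType::UINT8) {
  //     const int r = 0, c = 0;  // Some in-bounds row / column.
  //     const uint8 *bytes = reinterpret_cast<const uint8 *>(m.elements);
  //     // Row-major: element (r, c) is at index r * m.cols + c; one
  //     // quantization scale per row.
  //     const float scale = Float16To32(m.quant_scales[r]);
  //     const float value = scale * bytes[r * m.cols + c];
  //   }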
  // Returns the i-th embedding matrix.  Crashes on out-of-bounds indices.
  //
  // This is the transpose of the corresponding matrix from the original
  // proto.
  Matrix GetEmbeddingMatrix(int i) const {
    CheckIndex(i, embeddings_size(), "embedding matrix");
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }

  // Returns weight matrix for the i-th hidden layer.  Crashes on
  // out-of-bounds indices.
  //
  // This is the transpose of the corresponding matrix from the original
  // proto.
  Matrix GetHiddenLayerMatrix(int i) const {
    CheckIndex(i, hidden_size(), "hidden layer");
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);

    // The weights may be quantized; propagate the quantization type.
    matrix.quant_type = hidden_weights_quant_type(i);
    matrix.elements = hidden_weights(i);
    return matrix;
  }

  // Returns bias for the i-th hidden layer.  Technically a Matrix, but we
  // expect it to be a row/column vector (i.e., num rows or num cols is 1).
  // However, we don't CHECK for that: we just provide access to the
  // underlying data.  Crashes on out-of-bounds indices.
  Matrix GetHiddenLayerBias(int i) const {
    CheckIndex(i, hidden_bias_size(), "hidden layer bias");
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }

  // Returns true if a softmax layer exists.
  bool HasSoftmax() const { return softmax_size() == 1; }

  // Returns weight matrix for the softmax layer.  Note: should be called only
  // if HasSoftmax() is true.
  //
  // This is the transpose of the corresponding matrix from the original
  // proto.
  Matrix GetSoftmaxMatrix() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);

    // The weights may be quantized; propagate the quantization type.
    matrix.quant_type = softmax_weights_quant_type(0);
    matrix.elements = softmax_weights(0);
    return matrix;
  }

  // Returns bias for the softmax layer.  Technically a Matrix, but we expect
  // it to be a row/column vector (i.e., num rows or num cols is 1).  However,
  // we don't CHECK for that: we just provide access to the underlying data.
  Matrix GetSoftmaxBias() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }

  // Updates the EmbeddingNetwork-related parameters from task_context.
  // Returns true on success, false on error.
  virtual bool UpdateTaskContextParameters(
      mobile::TaskContext *task_context) = 0;
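
  // A hedged usage sketch of the high-level API above (not part of this
  // class).  "params" and "input" are hypothetical, and the code assumes an
  // unquantized hidden layer whose input dimension equals w.rows:
  //
  //   if (!params->is_valid()) return;
  //   EmbeddingNetworkParams::Matrix w = params->GetHiddenLayerMatrix(0);
  //   EmbeddingNetworkParams::Matrix b = params->GetHiddenLayerBias(0);
  //   if (w.quant_type == QuantizationType::NONE) {
  //     const float *weights = static_cast<const float *>(w.elements);
  //     const float *bias = static_cast<const float *>(b.elements);
  //     std::vector<float> output(w.cols);
  //     for (int c = 0; c < w.cols; ++c) {
  //       float sum = bias[c];
  //       for (int r = 0; r < w.rows; ++r) {
  //         // Row-major: element (r, c) is at index r * w.cols + c.
  //         sum += input[r] * weights[r * w.cols + c];
  //       }
  //       output[c] = sum;
  //     }
  //   }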
  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving an equivalent
  //   function call on proto, the original proto (of type
  //   EmbeddingNetworkProto) which was used to generate the C++ code.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.

  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;

  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.embeddings(i)), in
  // row-major order.  NOTE: for unquantized embeddings, this returns a
  // pointer to float; for quantized embeddings, this returns a pointer to
  // uint8.
  virtual const void *embeddings_weights(int i) const = 0;

  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
      int i) const {
    return nullptr;
  }

  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns proto.hidden_size().
  virtual int hidden_size() const = 0;

  // Returns proto.hidden(i).rows().
  virtual int hidden_num_rows(int i) const = 0;

  // Returns proto.hidden(i).cols().
  virtual int hidden_num_cols(int i) const = 0;

  // Returns quantization mode for the weights of the i-th hidden layer.
  virtual QuantizationType hidden_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to elements of proto.hidden(i), in row-major order.  The
  // element type depends on hidden_weights_quant_type(i).
  virtual const void *hidden_weights(int i) const = 0;

  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;

  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;

  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;

  // Returns quantization mode for the softmax weights.
  virtual QuantizationType softmax_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;

  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;

  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;

  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;
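
  // A hedged sketch of the low-level conventions above (the "params" pointer
  // is hypothetical): repeated fields are iterated via *_size(), and optional
  // message fields report a size of 0 or 1:
  //
  //   for (int i = 0; i < params->embeddings_size(); ++i) {
  //     const int rows = params->embeddings_num_rows(i);
  //     const int cols = params->embeddings_num_cols(i);
  //     // embeddings_weights(i) points to rows * cols elements in row-major
  //     // order; the element type depends on embeddings_quant_type(i).
  //   }
  //   if (params->softmax_size() == 1) {
  //     // Optional softmax present; 0 is the only valid index.
  //     const void *sm = params->softmax_weights(0);
  //   }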
  // ** Access methods for is_precomputed.
  //
  // Returns proto.has_is_precomputed().
  virtual bool has_is_precomputed() const = 0;

  // Returns proto.is_precomputed().
  virtual bool is_precomputed() const = 0;

 protected:
  void CheckIndex(int index, int size, const string &description) const {
    SAFTM_CHECK_GE(index, 0) << "Out-of-range index for " << description
                             << ": " << index;
    SAFTM_CHECK_LT(index, size) << "Out-of-range index for " << description
                                << ": " << index;
  }
};  // class EmbeddingNetworkParams

}  // namespace libtextclassifier3

#endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_