// C++ source  |  246 lines  |  8.52 KB

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/float16.h"
#include "common/little-endian-data.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
// the EmbeddingNetworkParams interface.
//
// The EmbeddingNetworkParams interface encapsulates the weight matrices of the
// embeddings, hidden and softmax layers as transposed versions of their
// counterparts in the original EmbeddingNetworkProto. The matrices in the proto
// passed to this class' constructor must likewise already have been transposed.
// See embedding-network-params.h for details.
//
// Clients should call is_valid() after construction: it returns false when the
// proto was null or when its quantized embedding data could not be decoded.
class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
 public:
  // Constructor that takes ownership of the provided proto. See class-comment
  // for the requirements that certain weight matrices must satisfy.
  //
  // For every quantized embedding matrix, the byte-encoded weights and scales
  // stored in the proto are decoded into vectors owned by this object, and the
  // corresponding proto fields are then cleared.  Any failure (null proto,
  // invalid dimensions, decoding error) is reported via is_valid().
  explicit EmbeddingNetworkParamsFromProto(
      std::unique_ptr<EmbeddingNetworkProto> proto)
      : proto_(std::move(proto)) {
    if (!proto_) {
      // Nothing to decode.  GetTaskSpec() and the *_size() accessors also
      // tolerate a null proto, so this object behaves as invalid and empty
      // instead of crashing here.
      TC_LOG(ERROR) << "Null EmbeddingNetworkProto";
      valid_ = false;
      return;
    }
    valid_ = true;

    // Initialize these vectors to have the required number of elements
    // regardless of quantization status. This is to support the unlikely case
    // where only some embeddings are quantized, along with the fact that
    // EmbeddingNetworkParams interface accesses them by index.
    const int num_embeddings = proto_->embeddings_size();
    embeddings_quant_scales_.resize(num_embeddings);
    embeddings_quant_weights_.resize(num_embeddings);
    for (int i = 0; i < num_embeddings; ++i) {
      MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
      if (!embedding->is_quantized()) {
        continue;
      }

      // The dimensions come from a possibly corrupted proto.  Validate them
      // and compute the element count in a wider type, so that neither a
      // negative dimension nor a signed int overflow of rows * cols reaches
      // the decoder.
      const long long num_rows = embedding->rows();
      const long long num_cols = embedding->cols();
      const bool dims_ok =
          num_rows >= 0 && num_cols >= 0 &&
          (num_cols == 0 ||
           num_rows <= std::numeric_limits<int>::max() / num_cols);
      if (!dims_ok) {
        TC_LOG(ERROR) << "Invalid dimensions for quantized embeddings #" << i;
        valid_ = false;
        // The raw bytes cannot be decoded; release their memory anyway.
        embedding->clear_bytes_for_quantized_values();
        embedding->clear_bytes_for_col_scales();
        continue;
      }

      bool success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_quantized_values(),
          static_cast<int>(num_rows * num_cols),
          &(embeddings_quant_weights_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i;
        valid_ = false;
      }

      // The field bytes_for_quantized_values uses a lot of memory.
      // Since it's no longer necessary (and we own the proto), we clear it.
      embedding->clear_bytes_for_quantized_values();

      // One scale per row of the (already transposed) embedding matrix.
      success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_col_scales(),
          static_cast<int>(num_rows),
          &(embeddings_quant_scales_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
        valid_ = false;
      }

      // Same reasoning as for bytes_for_quantized_values: free the memory.
      embedding->clear_bytes_for_col_scales();
    }
  }

  // Returns the TaskSpec stored as an extension of the wrapped proto, or
  // nullptr if there is no proto or the proto carries no such extension.  The
  // returned pointer refers to data owned by the wrapped proto.
  const TaskSpec *GetTaskSpec() override {
    if (!proto_) {
      return nullptr;
    }
    auto extension_id = task_spec_in_embedding_network_proto;
    if (proto_->HasExtension(extension_id)) {
      return &(proto_->GetExtension(extension_id));
    } else {
      TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
      return nullptr;
    }
  }

  // Returns true if these params are valid.  False otherwise (e.g., if the
  // original proto data was corrupted or no proto was provided).
  bool is_valid() { return valid_; }

 protected:
  // **** Embedding matrices.

  // Number of embedding matrices; 0 if there is no proto.
  int embeddings_size() const override {
    return proto_ ? proto_->embeddings_size() : 0;
  }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).cols();
  }

  // Returns the weights of embedding matrix #i: the decoded quantized values
  // if the matrix is quantized, the values stored in the proto otherwise.
  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (proto_->embeddings(i).is_quantized()) {
      return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
    } else {
      return static_cast<const void *>(proto_->embeddings(i).value().data());
    }
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
                                                : QuantizationType::NONE;
  }

  // Returns the decoded quantization scales of embedding matrix #i, or nullptr
  // if that matrix is not quantized.
  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized()
               ? embeddings_quant_scales_.at(i).data()
               : nullptr;
  }

  // **** Hidden layers: weights.

  // Number of hidden-layer weight matrices; 0 if there is no proto.
  int hidden_size() const override {
    return proto_ ? proto_->hidden_size() : 0;
  }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).value().data();
  }

  // **** Hidden layers: biases.

  // Number of hidden-layer bias matrices; 0 if there is no proto.
  int hidden_bias_size() const override {
    return proto_ ? proto_->hidden_bias_size() : 0;
  }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).value().data();
  }

  // **** Softmax layer: weights.  The proto holds at most one softmax matrix,
  // so the "size" is either 0 or 1.

  int softmax_size() const override {
    return (proto_ && proto_->has_softmax()) ? 1 : 0;
  }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().rows() : 0;
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().cols() : 0;
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
  }

  // **** Softmax layer: bias.  As above, there is at most one such matrix.

  int softmax_bias_size() const override {
    return (proto_ && proto_->has_softmax_bias()) ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
                                      : nullptr;
  }

  // **** Number of features per embedding space.

  // Number of embedding_num_features entries; 0 if there is no proto.
  int embedding_num_features_size() const override {
    return proto_ ? proto_->embedding_num_features_size() : 0;
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return proto_->embedding_num_features(i);
  }

 private:
  // The wrapped proto.  May be null if the constructor received a null
  // pointer; in that case valid_ is false.
  std::unique_ptr<EmbeddingNetworkProto> proto_;

  // True if these params are valid.  May be false if the original proto was
  // missing or corrupted.  We prefer setting this to false over CHECK-failing.
  bool valid_ = false;

  // When the embeddings are quantized, these members are used to store their
  // numeric values using the types expected by the rest of the class. Due to
  // technical reasons, the proto stores this info using larger types (i.e.,
  // more bits).  Both vectors have one entry per embedding matrix; entries for
  // non-quantized matrices stay empty.
  std::vector<std::vector<float16>> embeddings_quant_scales_;
  std::vector<std::vector<uint8>> embeddings_quant_weights_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_