普通文本  |  209行  |  6.77 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/common/flatbuffers/model-utils.h"

#include <string.h>

#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/checksum.h"

namespace libtextclassifier3 {
namespace saft_fbs {

namespace {

// Returns true if we have clear evidence that |model| fails its checksum.
//
// E.g., if |model| has the crc32 field, and the value of that field does not
// match the checksum, then this function returns true.  If there is no crc32
// field, then we don't know what the original (at build time) checksum was, so
// we don't know anything clear and this function returns false.
bool ClearlyFailsChecksum(const Model &model) {
  if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
    SAFTM_LOG(WARNING)
        << "No CRC32, most likely an old model; skip CRC32 check";
    return false;
  }
  const mobile::uint32 expected_crc32 = model.crc32();
  const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
  if (actual_crc32 != expected_crc32) {
    SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
                     << " vs " << expected_crc32;
    return true;
  }
  SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
  return false;
}
}  // namespace

// Interprets the |num_bytes| bytes at |data| as a Model flatbuffer.
//
// Returns nullptr (after logging) when the input is null/empty, when the
// flatbuffer verifier rejects the bytes, or when the model carries a crc32
// field that does not match its recomputed checksum.  Otherwise returns a
// pointer into |data|; no copy is made, so |data| must outlive the result.
const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
  const bool no_bytes = (data == nullptr) || (num_bytes == 0);
  if (no_bytes) {
    SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
    return nullptr;
  }

  const auto *buffer = reinterpret_cast<const uint8_t *>(data);
  flatbuffers::Verifier buffer_verifier(buffer, num_bytes);
  if (!VerifyModelBuffer(buffer_verifier)) {
    SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
    return nullptr;
  }

  const Model *root = GetModel(buffer);
  const bool usable = (root != nullptr) && !ClearlyFailsChecksum(*root);
  return usable ? root : nullptr;
}

// Looks up the first ModelInput of |model| whose name equals |name|.
//
// Null entries and entries without a name are skipped.  Returns nullptr when
// no input matches, and also (after logging an error) when |model| is null or
// has no inputs list at all.
const ModelInput *GetInputByName(const Model *model, const string &name) {
  if (model == nullptr) {
    SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
    return nullptr;
  }

  const auto *all_inputs = model->inputs();
  if (all_inputs == nullptr) {
    // A well-formed model always carries an inputs list (possibly empty), so
    // a missing list is reported as an error.
    SAFTM_LOG(ERROR) << "null inputs";
    return nullptr;
  }

  for (const ModelInput *candidate : *all_inputs) {
    if (candidate == nullptr) continue;
    const flatbuffers::String *candidate_name = candidate->name();
    if ((candidate_name != nullptr) && (name == candidate_name->str())) {
      return candidate;
    }
  }
  return nullptr;
}

// Returns a non-owning view of the bytes stored in |input|'s data vector.
//
// The returned StringPiece points directly at the data vector's bytes, so it
// is only valid while the buffer holding |input| stays alive.
//
// If |input| is null, or if it has no data vector, an error is logged and an
// empty StringPiece (null pointer, size 0) is returned.
mobile::StringPiece GetInputBytes(const ModelInput *input) {
  if (input == nullptr) {
    SAFTM_LOG(ERROR) << "ModelInput has no content";
    return mobile::StringPiece(nullptr, 0);
  }

  // Fetch the data vector once.  Checking it separately from |input| keeps
  // each error message reachable and specific to what is actually missing.
  const flatbuffers::Vector<uint8_t> *input_data = input->data();
  if (input_data == nullptr) {
    SAFTM_LOG(ERROR) << "null input data";
    return mobile::StringPiece(nullptr, 0);
  }
  return mobile::StringPiece(reinterpret_cast<const char *>(input_data->data()),
                             input_data->size());
}

// Copies every (name, value) parameter pair of |model| into |context| via
// TaskContext::SetParameter.
//
// Returns true once all parameters have been copied.  Returns false, after
// logging an error, if:
// - |context| is null;
// - the model has no parameter list;
// - a parameter entry is null;
// - an entry has a null or empty name;
// - an entry has a null value.
// Parameters that precede the failing entry have already been stored in
// |context| when false is returned.
bool FillParameters(const Model &model, mobile::TaskContext *context) {
  if (context == nullptr) {
    SAFTM_LOG(ERROR) << "null context";
    return false;
  }
  const auto *parameters = model.parameters();
  if (parameters == nullptr) {
    // We should always have a list of parameters; maybe an empty one, if no
    // parameters, but the list should be there.
    SAFTM_LOG(ERROR) << "null list of parameters";
    return false;
  }
  for (const ModelParameter *p : *parameters) {
    if (p == nullptr) {
      SAFTM_LOG(ERROR) << "null parameter";
      return false;
    }
    const auto *raw_name = p->name();
    if (raw_name == nullptr) {
      SAFTM_LOG(ERROR) << "null parameter name";
      return false;
    }
    const string name = raw_name->str();
    if (name.empty()) {
      SAFTM_LOG(ERROR) << "empty parameter name";
      return false;
    }
    const auto *raw_value = p->value();
    if (raw_value == nullptr) {
      // Report the field that is actually missing (the value), together with
      // the parameter it belongs to.
      SAFTM_LOG(ERROR) << "null parameter value for parameter " << name;
      return false;
    }
    context->SetParameter(name, raw_value->str());
  }
  return true;
}

namespace {
// Feeds one tagged flatbuffer vector into |crc|; helper for
// ComputeCrc2Checksum.
//
// The tag is written first as "|" + |info| + ":".  |info| should be a short
// label describing what |vec| holds, which gives the checksummed stream some
// structure.  The tag is followed by the raw bytes of |vec|
// (size() * sizeof(T) bytes), or by the literal "empty" when |vec| is null.
template <typename T>
void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *vec,
               mobile::StringPiece info) {
  crc->Update("|");
  crc->Update(info.data(), info.size());
  crc->Update(":");

  if (vec == nullptr) {
    crc->Update("empty");
    return;
  }

  const char *payload = reinterpret_cast<const char *>(vec->data());
  crc->Update(payload, vec->size() * sizeof(T));
}
}  // namespace

mobile::uint32 ComputeCrc2Checksum(const Model *model) {
  // Why not simply checksum the serialized model bytes?  The expected checksum
  // is itself stored inside those bytes, and because we don't control the
  // flatbuffer layout, we cannot place it at a fixed head/tail position to
  // exclude it.  Instead, this function walks |model| and feeds into the CRC32
  // only the parts we care about, which excludes the crc32 field.
  //
  // Storing the checksum outside the Model would be too disruptive for the
  // way models are currently shipped.
  //
  // The order and content of the Update calls below define the checksum
  // format; changing them invalidates checksums of existing models.
  mobile::Crc32 crc;
  if (model != nullptr) {
    crc.Update("|Parameters:");
    if (const auto *params = model->parameters()) {
      for (const ModelParameter *param : *params) {
        if (param == nullptr) continue;
        UpdateCrc(&crc, param->name(), "name");
        UpdateCrc(&crc, param->value(), "value");
      }
    }

    crc.Update("|Inputs:");
    if (const auto *model_inputs = model->inputs()) {
      for (const ModelInput *in : *model_inputs) {
        if (in == nullptr) continue;
        UpdateCrc(&crc, in->name(), "name");
        UpdateCrc(&crc, in->type(), "type");
        UpdateCrc(&crc, in->sub_type(), "sub-type");
        UpdateCrc(&crc, in->data(), "data");
      }
    }
  }
  return crc.Get();
}

}  // namespace saft_fbs
}  // namespace libtextclassifier3