/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <memory> #include <vector> #include "utils/base/logging.h" #include "utils/sentencepiece/double_array_trie.h" #include "utils/sentencepiece/encoder.h" #include "utils/sentencepiece/normalizer.h" #include "utils/sentencepiece/sorted_strings_table.h" #include "utils/strings/stringpiece.h" #include "utils/tflite/encoder_common.h" #include "utils/tflite/text_encoder.h" #include "utils/tflite/text_encoder_config_generated.h" #include "flatbuffers/flatbuffers.h" #include "flatbuffers/flexbuffers.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { namespace { struct TextEncoderOp { std::unique_ptr<SentencePieceNormalizer> normalizer; std::unique_ptr<Encoder> encoder; std::unique_ptr<SentencePieceMatcher> matcher; }; // Input parameters for the op. // The conversation message as a (1, conversation length) string tensor. constexpr const int kInputTexts = 0; // The number of messages, the conversation length, int scalar. constexpr const int kInputNumInputs = 1; // Maximum output length of the encoding, int scalar. constexpr const int kInputMaxLength = 2; // Additional attributes to align to the sentence pieces, e.g. user ids per // message. constexpr const int kInputAttr = 3; // Output parameters for the op. // The text sentence piece encodings as ids, (1, max output length) int tensor. constexpr const int kOutputEncoded = 0; // Relative position of each sentence piece in the input text, // (1, max output length) int tensor. constexpr const int kOutputPosition = 1; // Output length after trimming to the maximum output length specified. // int scalar. constexpr const int kOutputLengths = 2; // Padded and sentence piece aligned provided attributes, e.g. user id per // sentence piece. constexpr const int kOutputAttr = 3; const char kTextEncoderConfigAttr[] = "text_encoder_config"; // Initializes text encoder object from serialized options: // The options are a flexbuffers attribute map that contain the op config // with the key `text_encoder_config` as `TextEncoderConfig`. void* Initialize(TfLiteContext* context, const char* buffer, size_t length) { const flexbuffers::Map& attr_map = flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length) .AsMap(); const flexbuffers::Blob serialized_config = attr_map[kTextEncoderConfigAttr].AsBlob(); const TextEncoderConfig* config = flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data()); std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp()); // Create normalizer from options. const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>( config->normalization_charsmap()->Data()); const int charsmap_trie_nodes_length = config->normalization_charsmap()->Length() / sizeof(TrieNode); encoder_op->normalizer.reset(new SentencePieceNormalizer( DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length), StringPiece(config->normalization_charsmap_values()->data(), config->normalization_charsmap_values()->size()), config->add_dummy_prefix(), config->remove_extra_whitespaces(), config->escape_whitespaces())); const int num_pieces = config->pieces_scores()->Length(); switch (config->matcher_type()) { case SentencePieceMatcherType_MAPPED_TRIE: { const TrieNode* pieces_trie_nodes = reinterpret_cast<const TrieNode*>(config->pieces()->Data()); const int pieces_trie_nodes_length = config->pieces()->Length() / sizeof(TrieNode); encoder_op->matcher.reset( new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length)); break; } case SentencePieceMatcherType_SORTED_STRING_TABLE: { encoder_op->matcher.reset(new SortedStringsTable( num_pieces, config->pieces_offsets()->data(), StringPiece(config->pieces()->data(), config->pieces()->Length()))); break; } default: { TC3_LOG(ERROR) << "Unknown sentence piece matcher type."; return nullptr; } } encoder_op->encoder.reset(new Encoder( encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(), config->start_code(), config->end_code(), config->encoding_offset(), config->unknown_code(), config->unknown_score())); return encoder_op.release(); } void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast<TextEncoderOp*>(buffer); } TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, int max_output_length) { TF_LITE_ENSURE_OK( context, ResizeOutputTensor(max_output_length, &context->tensors[node->outputs->data[kOutputEncoded]], context)); TF_LITE_ENSURE_OK( context, ResizeOutputTensor( max_output_length, &context->tensors[node->outputs->data[kOutputPosition]], context)); const int num_output_attrs = node->outputs->size - kOutputAttr; for (int i = 0; i < num_output_attrs; ++i) { TF_LITE_ENSURE_OK( context, ResizeOutputTensor( max_output_length, &context->tensors[node->outputs->data[kOutputAttr + i]], context)); } return kTfLiteOk; } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check that the batch dimension is kBatchSize. const TfLiteTensor& input_text = context->tensors[node->inputs->data[kInputTexts]]; TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank); TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize); TfLiteTensor& output_lengths = context->tensors[node->outputs->data[kOutputLengths]]; TfLiteTensor& output_encoded = context->tensors[node->outputs->data[kOutputEncoded]]; TfLiteTensor& output_positions = context->tensors[node->outputs->data[kOutputPosition]]; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_lengths, CreateIntArray({kEncoderBatchSize}))); // Check that there are enough outputs for attributes. const int num_output_attrs = node->outputs->size - kOutputAttr; TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs); // Copy attribute types from input to output tensors. for (int i = 0; i < num_output_attrs; ++i) { TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]]; TfLiteTensor& output = context->tensors[node->outputs->data[kOutputAttr + i]]; output.type = input.type; } const TfLiteTensor& output_length = context->tensors[node->inputs->data[kInputMaxLength]]; if (tflite::IsConstantTensor(&output_length)) { return ResizeOutputTensors(context, node, output_length.data.i64[0]); } else { tflite::SetTensorToDynamic(&output_encoded); tflite::SetTensorToDynamic(&output_positions); for (int i = 0; i < num_output_attrs; ++i) { TfLiteTensor& output_attr = context->tensors[node->outputs->data[kOutputAttr + i]]; tflite::SetTensorToDynamic(&output_attr); } } return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (node->user_data == nullptr) { return kTfLiteError; } const TextEncoderOp* encoder_op = reinterpret_cast<TextEncoderOp*>(node->user_data); const TfLiteTensor& input_text = context->tensors[node->inputs->data[kInputTexts]]; const int num_strings = tflite::GetStringCount(&input_text); // Check that the number of strings matches the length parameter. const int num_strings_param = context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0]; TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param); TfLiteTensor& output_encoded = context->tensors[node->outputs->data[kOutputEncoded]]; if (tflite::IsDynamicTensor(&output_encoded)) { const TfLiteTensor& output_length = context->tensors[node->inputs->data[kInputMaxLength]]; TF_LITE_ENSURE_OK( context, ResizeOutputTensors(context, node, output_length.data.i64[0])); } TfLiteTensor& output_positions = context->tensors[node->outputs->data[kOutputPosition]]; std::vector<int> encoded_total; std::vector<int> encoded_offsets; std::vector<int> encoded_positions; encoded_offsets.reserve(num_strings); const int max_output_length = output_encoded.dims->data[1]; const int max_encoded_position = max_output_length; for (int i = 0; i < num_strings; ++i) { const auto& strref = tflite::GetString(&input_text, i); std::string normalized; TF_LITE_ENSURE(context, encoder_op->normalizer->Normalize( StringPiece(strref.str, strref.len), &normalized)); std::vector<int> encoded; TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded)); encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end()); encoded_offsets.push_back(encoded_total.size()); for (int i = 0; i < encoded.size(); i++) { encoded_positions.push_back(std::min(i, max_encoded_position - 1)); } } const int num_skip = CopyDataToTensorAndPadOrTruncate( max_output_length, encoded_total, /*padding_value=*/encoded_total.back(), &output_encoded); TfLiteTensor& output_lengths = context->tensors[node->outputs->data[kOutputLengths]]; output_lengths.data.i32[0] = encoded_total.size() - num_skip; CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions, /*padding_value=*/max_encoded_position, &output_positions); // Process attributes, all checks of sizes and types are done in Prepare. const int num_output_attrs = node->outputs->size - kOutputAttr; TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs); for (int i = 0; i < num_output_attrs; ++i) { TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate( context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets, num_skip, context, &context->tensors[node->outputs->data[kOutputAttr + i]]); if (attr_status != kTfLiteOk) { return attr_status; } } return kTfLiteOk; } } // namespace } // namespace libtextclassifier3 namespace tflite { namespace ops { namespace custom { TfLiteRegistration* Register_TEXT_ENCODER() { static TfLiteRegistration registration = { libtextclassifier3::Initialize, libtextclassifier3::Free, libtextclassifier3::Prepare, libtextclassifier3::Eval}; return ®istration; } } // namespace custom } // namespace ops } // namespace tflite