/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "text-classifier.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>

#include "util/base/logging.h"
#include "util/math/softmax.h"
#include "util/utf8/unicodetext.h"

namespace libtextclassifier2 {

// Well-known collection names. Heap-allocated and intentionally leaked so the
// references stay valid for the whole process lifetime (avoids static
// destruction-order issues).
const std::string& TextClassifier::kOtherCollection =
    *[]() { return new std::string("other"); }();
const std::string& TextClassifier::kPhoneCollection =
    *[]() { return new std::string("phone"); }();
const std::string& TextClassifier::kAddressCollection =
    *[]() { return new std::string("address"); }();
const std::string& TextClassifier::kDateCollection =
    *[]() { return new std::string("date"); }();

namespace {

// Interprets 'addr' as a flatbuffer-serialized Model and verifies that the
// 'size'-byte buffer is structurally valid. Returns nullptr on verification
// failure.
const Model* LoadAndVerifyModel(const void* addr, int size) {
  const Model* model = GetModel(addr);

  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
  if (model->Verify(verifier)) {
    return model;
  } else {
    return nullptr;
  }
}

}  // namespace

// Lazily builds and caches the TFLite interpreter for the selection model.
// May return nullptr if the interpreter could not be constructed.
tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
  if (!selection_interpreter_) {
    TC_CHECK(selection_executor_);
    selection_interpreter_ = selection_executor_->CreateInterpreter();
    if (!selection_interpreter_) {
      TC_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return selection_interpreter_.get();
}

// Lazily builds and caches the TFLite interpreter for the classification
// model. May return nullptr if the interpreter could not be constructed.
tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
  if (!classification_interpreter_) {
    TC_CHECK(classification_executor_);
    classification_interpreter_ = classification_executor_->CreateInterpreter();
    if (!classification_interpreter_) {
      TC_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return classification_interpreter_.get();
}

// Creates a TextClassifier from a caller-owned buffer. The buffer must
// outlive the returned classifier. Returns nullptr on verification or
// initialization failure.
std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
    const char* buffer, int size, const UniLib* unilib) {
  const Model* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }

  auto classifier =
      std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// Creates a TextClassifier that takes ownership of the mmapped model region.
// Returns nullptr if the mmap is invalid, the model fails verification, or
// initialization fails.
std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap(
    std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) {
  if (!(*mmap)->handle().ok()) {
    TC_VLOG(1) << "Mmap failed.";
    return nullptr;
  }

  const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
                                          (*mmap)->handle().num_bytes());
  if (!model) {
    TC_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }

  auto classifier =
      std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// Creates a TextClassifier by mmapping 'size' bytes of 'fd' starting at
// 'offset'.
std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
    int fd, int offset, int size, const UniLib* unilib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
  return FromScopedMmap(&mmap, unilib);
}

// Creates a TextClassifier by mmapping the whole file behind 'fd'.
std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
    int fd, const UniLib* unilib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
  return FromScopedMmap(&mmap, unilib);
}

// Creates a TextClassifier by mmapping the model file at 'path'.
std::unique_ptr<TextClassifier> TextClassifier::FromPath(
    const std::string& path, const UniLib* unilib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
  return FromScopedMmap(&mmap, unilib);
}

// Validates the model flatbuffer and constructs the executors and feature
// processors required by the enabled modes. 'initialized_' becomes true only
// if every required component was successfully built.
void TextClassifier::ValidateAndInitialize() {
  initialized_ = false;

  if (model_ == nullptr) {
    TC_LOG(ERROR) << "No model specified.";
    return;
  }

  const bool model_enabled_for_annotation =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
  const bool model_enabled_for_classification =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION));
  const bool model_enabled_for_selection =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));

  // Annotation requires the selection model.
  if (model_enabled_for_annotation || model_enabled_for_selection) {
    if (!model_->selection_options()) {
      TC_LOG(ERROR) << "No selection options.";
      return;
    }
    if (!model_->selection_feature_options()) {
      TC_LOG(ERROR) << "No selection feature options.";
      return;
    }
    if (!model_->selection_feature_options()->bounds_sensitive_features()) {
      TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
      return;
    }
    if (!model_->selection_model()) {
      TC_LOG(ERROR) << "No selection model.";
      return;
    }
    selection_executor_ = ModelExecutor::Instance(model_->selection_model());
    if (!selection_executor_) {
      TC_LOG(ERROR) << "Could not initialize selection executor.";
      return;
    }
    selection_feature_processor_.reset(
        new FeatureProcessor(model_->selection_feature_options(), unilib_));
  }

  // Annotation requires the classification model for conflict resolution and
  // scoring.
  // Selection requires the classification model for conflict resolution.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->classification_options()) {
      TC_LOG(ERROR) << "No classification options.";
      return;
    }
    if (!model_->classification_feature_options()) {
      TC_LOG(ERROR) << "No classification feature options.";
      return;
    }
    if (!model_->classification_feature_options()
             ->bounds_sensitive_features()) {
      TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
      return;
    }
    if (!model_->classification_model()) {
      TC_LOG(ERROR) << "No clf model.";
      return;
    }
    classification_executor_ =
        ModelExecutor::Instance(model_->classification_model());
    if (!classification_executor_) {
      TC_LOG(ERROR) << "Could not initialize classification executor.";
      return;
    }
    classification_feature_processor_.reset(new FeatureProcessor(
        model_->classification_feature_options(), unilib_));
  }

  // The embeddings need to be specified if the model is to be used for
  // classification or selection.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->embedding_model()) {
      TC_LOG(ERROR) << "No embedding model.";
      return;
    }

    // Check that the embedding size of the selection and classification model
    // matches, as they are using the same embeddings.
    if (model_enabled_for_selection &&
        (model_->selection_feature_options()->embedding_size() !=
             model_->classification_feature_options()->embedding_size() ||
         model_->selection_feature_options()->embedding_quantization_bits() !=
             model_->classification_feature_options()
                 ->embedding_quantization_bits())) {
      TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
      return;
    }

    embedding_executor_ = TFLiteEmbeddingExecutor::Instance(
        model_->embedding_model(),
        model_->classification_feature_options()->embedding_size(),
        model_->classification_feature_options()
            ->embedding_quantization_bits());
    if (!embedding_executor_) {
      TC_LOG(ERROR) << "Could not initialize embedding executor.";
      return;
    }
  }

  // Optional components: compressed regex patterns and the datetime model
  // share one zlib decompressor instance.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (model_->regex_model()) {
    if (!InitializeRegexModel(decompressor.get())) {
      TC_LOG(ERROR) << "Could not initialize regex model.";
      return;
    }
  }

  if (model_->datetime_model()) {
    datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(),
                                                *unilib_, decompressor.get());
    if (!datetime_parser_) {
      TC_LOG(ERROR) << "Could not initialize datetime parser.";
      return;
    }
  }

  // Collections listed in output_options are suppressed per API entry point.
  if (model_->output_options()) {
    if (model_->output_options()->filtered_collections_annotation()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_annotation()) {
        filtered_collections_annotation_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_classification()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_classification()) {
        filtered_collections_classification_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_selection()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_selection()) {
        filtered_collections_selection_.insert(collection->str());
      }
    }
  }

  initialized_ = true;
}

// Compiles the regex patterns from the model and indexes each pattern id by
// the modes (annotation/classification/selection) it is enabled for. Returns
// false if any pattern fails to compile.
bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) {
  if (!model_->regex_model()->patterns()) {
return true; } // Initialize pattern recognizers. int regex_pattern_id = 0; for (const auto& regex_pattern : *model_->regex_model()->patterns()) { std::unique_ptr<UniLib::RegexPattern> compiled_pattern = UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(), regex_pattern->compressed_pattern(), decompressor); if (!compiled_pattern) { TC_LOG(INFO) << "Failed to load regex pattern"; return false; } if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) { annotation_regex_patterns_.push_back(regex_pattern_id); } if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) { classification_regex_patterns_.push_back(regex_pattern_id); } if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) { selection_regex_patterns_.push_back(regex_pattern_id); } regex_patterns_.push_back({regex_pattern->collection_name()->str(), regex_pattern->target_classification_score(), regex_pattern->priority_score(), std::move(compiled_pattern)}); if (regex_pattern->use_approximate_matching()) { regex_approximate_match_pattern_ids_.insert(regex_pattern_id); } ++regex_pattern_id; } return true; } namespace { int CountDigits(const std::string& str, CodepointSpan selection_indices) { int count = 0; int i = 0; const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { if (i >= selection_indices.first && i < selection_indices.second && isdigit(*it)) { ++count; } } return count; } std::string ExtractSelection(const std::string& context, CodepointSpan selection_indices) { const UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); auto selection_begin = context_unicode.begin(); std::advance(selection_begin, selection_indices.first); auto selection_end = context_unicode.begin(); std::advance(selection_end, selection_indices.second); return UnicodeText::UTF8Substring(selection_begin, selection_end); } } // namespace namespace internal { // Helper function, which if the initial 
'span' contains only white-spaces, // moves the selection to a single-codepoint selection on a left or right side // of this space. CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, const UnicodeText& context_unicode, const UniLib& unilib) { TC_CHECK(ValidNonEmptySpan(span)); UnicodeText::const_iterator it; // Check that the current selection is all whitespaces. it = context_unicode.begin(); std::advance(it, span.first); for (int i = 0; i < (span.second - span.first); ++i, ++it) { if (!unilib.IsWhitespace(*it)) { return span; } } CodepointSpan result; // Try moving left. result = span; it = context_unicode.begin(); std::advance(it, span.first); while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) { --result.first; --it; } result.second = result.first + 1; if (!unilib.IsWhitespace(*it)) { return result; } // If moving left didn't find a non-whitespace character, just return the // original span. return span; } } // namespace internal bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const { return !span.classification.empty() && filtered_collections_annotation_.find( span.classification[0].collection) != filtered_collections_annotation_.end(); } bool TextClassifier::FilteredForClassification( const ClassificationResult& classification) const { return filtered_collections_classification_.find(classification.collection) != filtered_collections_classification_.end(); } bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const { return !span.classification.empty() && filtered_collections_selection_.find( span.classification[0].collection) != filtered_collections_selection_.end(); } CodepointSpan TextClassifier::SuggestSelection( const std::string& context, CodepointSpan click_indices, const SelectionOptions& options) const { CodepointSpan original_click_indices = click_indices; if (!initialized_) { TC_LOG(ERROR) << "Not initialized"; return original_click_indices; } if (!(model_->enabled_modes() & 
ModeFlag_SELECTION)) { return original_click_indices; } const UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); if (!context_unicode.is_valid()) { return original_click_indices; } const int context_codepoint_size = context_unicode.size_codepoints(); if (click_indices.first < 0 || click_indices.second < 0 || click_indices.first >= context_codepoint_size || click_indices.second > context_codepoint_size || click_indices.first >= click_indices.second) { TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: " << click_indices.first << " " << click_indices.second; return original_click_indices; } if (model_->snap_whitespace_selections()) { // We want to expand a purely white-space selection to a multi-selection it // would've been part of. But with this feature disabled we would do a no- // op, because no token is found. Therefore, we need to modify the // 'click_indices' a bit to include a part of the token, so that the click- // finding logic finds the clicked token correctly. This modification is // done by the following function. Note, that it's enough to check the left // side of the current selection, because if the white-space is a part of a // multi-selection, neccessarily both tokens - on the left and the right // sides need to be selected. Thus snapping only to the left is sufficient // (there's a check at the bottom that makes sure that if we snap to the // left token but the result does not contain the initial white-space, // returns the original indices). 
click_indices = internal::SnapLeftIfWhitespaceSelection( click_indices, context_unicode, *unilib_); } std::vector<AnnotatedSpan> candidates; InterpreterManager interpreter_manager(selection_executor_.get(), classification_executor_.get()); std::vector<Token> tokens; if (!ModelSuggestSelection(context_unicode, click_indices, &interpreter_manager, &tokens, &candidates)) { TC_LOG(ERROR) << "Model suggest selection failed."; return original_click_indices; } if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) { TC_LOG(ERROR) << "Regex suggest selection failed."; return original_click_indices; } if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), /*reference_time_ms_utc=*/0, /*reference_timezone=*/"", options.locales, ModeFlag_SELECTION, &candidates)) { TC_LOG(ERROR) << "Datetime suggest selection failed."; return original_click_indices; } // Sort candidates according to their position in the input, so that the next // code can assume that any connected component of overlapping spans forms a // contiguous block. std::sort(candidates.begin(), candidates.end(), [](const AnnotatedSpan& a, const AnnotatedSpan& b) { return a.span.first < b.span.first; }); std::vector<int> candidate_indices; if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, &candidate_indices)) { TC_LOG(ERROR) << "Couldn't resolve conflicts."; return original_click_indices; } for (const int i : candidate_indices) { if (SpansOverlap(candidates[i].span, click_indices) && SpansOverlap(candidates[i].span, original_click_indices)) { // Run model classification if not present but requested and there's a // classification collection filter specified. 
if (candidates[i].classification.empty() && model_->selection_options()->always_classify_suggested_selection() && !filtered_collections_selection_.empty()) { if (!ModelClassifyText( context, candidates[i].span, &interpreter_manager, /*embedding_cache=*/nullptr, &candidates[i].classification)) { return original_click_indices; } } // Ignore if span classification is filtered. if (FilteredForSelection(candidates[i])) { return original_click_indices; } return candidates[i].span; } } return original_click_indices; } namespace { // Helper function that returns the index of the first candidate that // transitively does not overlap with the candidate on 'start_index'. If the end // of 'candidates' is reached, it returns the index that points right behind the // array. int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates, int start_index) { int first_non_overlapping = start_index + 1; CodepointSpan conflicting_span = candidates[start_index].span; while ( first_non_overlapping < candidates.size() && SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) { // Grow the span to include the current one. 
conflicting_span.second = std::max( conflicting_span.second, candidates[first_non_overlapping].span.second); ++first_non_overlapping; } return first_non_overlapping; } } // namespace bool TextClassifier::ResolveConflicts( const std::vector<AnnotatedSpan>& candidates, const std::string& context, const std::vector<Token>& cached_tokens, InterpreterManager* interpreter_manager, std::vector<int>* result) const { result->clear(); result->reserve(candidates.size()); for (int i = 0; i < candidates.size();) { int first_non_overlapping = FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i); const bool conflict_found = first_non_overlapping != (i + 1); if (conflict_found) { std::vector<int> candidate_indices; if (!ResolveConflict(context, cached_tokens, candidates, i, first_non_overlapping, interpreter_manager, &candidate_indices)) { return false; } result->insert(result->end(), candidate_indices.begin(), candidate_indices.end()); } else { result->push_back(i); } // Skip over the whole conflicting group/go to next candidate. 
i = first_non_overlapping; } return true; } namespace { inline bool ClassifiedAsOther( const std::vector<ClassificationResult>& classification) { return !classification.empty() && classification[0].collection == TextClassifier::kOtherCollection; } float GetPriorityScore( const std::vector<ClassificationResult>& classification) { if (!ClassifiedAsOther(classification)) { return classification[0].priority_score; } else { return -1.0; } } } // namespace bool TextClassifier::ResolveConflict( const std::string& context, const std::vector<Token>& cached_tokens, const std::vector<AnnotatedSpan>& candidates, int start_index, int end_index, InterpreterManager* interpreter_manager, std::vector<int>* chosen_indices) const { std::vector<int> conflicting_indices; std::unordered_map<int, float> scores; for (int i = start_index; i < end_index; ++i) { conflicting_indices.push_back(i); if (!candidates[i].classification.empty()) { scores[i] = GetPriorityScore(candidates[i].classification); continue; } // OPTIMIZATION: So that we don't have to classify all the ML model // spans apriori, we wait until we get here, when they conflict with // something and we need the actual classification scores. So if the // candidate conflicts and comes from the model, we need to run a // classification to determine its priority: std::vector<ClassificationResult> classification; if (!ModelClassifyText(context, cached_tokens, candidates[i].span, interpreter_manager, /*embedding_cache=*/nullptr, &classification)) { return false; } if (!classification.empty()) { scores[i] = GetPriorityScore(classification); } } std::sort(conflicting_indices.begin(), conflicting_indices.end(), [&scores](int i, int j) { return scores[i] > scores[j]; }); // Keeps the candidates sorted by their position in the text (their left span // index) for fast retrieval down. 
std::set<int, std::function<bool(int, int)>> chosen_indices_set( [&candidates](int a, int b) { return candidates[a].span.first < candidates[b].span.first; }); // Greedily place the candidates if they don't conflict with the already // placed ones. for (int i = 0; i < conflicting_indices.size(); ++i) { const int considered_candidate = conflicting_indices[i]; if (!DoesCandidateConflict(considered_candidate, candidates, chosen_indices_set)) { chosen_indices_set.insert(considered_candidate); } } *chosen_indices = std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end()); return true; } bool TextClassifier::ModelSuggestSelection( const UnicodeText& context_unicode, CodepointSpan click_indices, InterpreterManager* interpreter_manager, std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const { if (model_->triggering_options() == nullptr || !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) { return true; } int click_pos; *tokens = selection_feature_processor_->Tokenize(context_unicode); selection_feature_processor_->RetokenizeAndFindClick( context_unicode, click_indices, selection_feature_processor_->GetOptions()->only_use_line_with_click(), tokens, &click_pos); if (click_pos == kInvalidIndex) { TC_VLOG(1) << "Could not calculate the click position."; return false; } const int symmetry_context_size = model_->selection_options()->symmetry_context_size(); const FeatureProcessorOptions_::BoundsSensitiveFeatures* bounds_sensitive_features = selection_feature_processor_->GetOptions() ->bounds_sensitive_features(); // The symmetry context span is the clicked token with symmetry_context_size // tokens on either side. const TokenSpan symmetry_context_span = IntersectTokenSpans( ExpandTokenSpan(SingleTokenSpan(click_pos), /*num_tokens_left=*/symmetry_context_size, /*num_tokens_right=*/symmetry_context_size), {0, tokens->size()}); // Compute the extraction span based on the model type. 
TokenSpan extraction_span; if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { // The extraction span is the symmetry context span expanded to include // max_selection_span tokens on either side, which is how far a selection // can stretch from the click, plus a relevant number of tokens outside of // the bounds of the selection. const int max_selection_span = selection_feature_processor_->GetOptions()->max_selection_span(); extraction_span = ExpandTokenSpan(symmetry_context_span, /*num_tokens_left=*/max_selection_span + bounds_sensitive_features->num_tokens_before(), /*num_tokens_right=*/max_selection_span + bounds_sensitive_features->num_tokens_after()); } else { // The extraction span is the symmetry context span expanded to include // context_size tokens on either side. const int context_size = selection_feature_processor_->GetOptions()->context_size(); extraction_span = ExpandTokenSpan(symmetry_context_span, /*num_tokens_left=*/context_size, /*num_tokens_right=*/context_size); } extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()}); if (!selection_feature_processor_->HasEnoughSupportedCodepoints( *tokens, extraction_span)) { return true; } std::unique_ptr<CachedFeatures> cached_features; if (!selection_feature_processor_->ExtractFeatures( *tokens, extraction_span, /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, embedding_executor_.get(), /*embedding_cache=*/nullptr, selection_feature_processor_->EmbeddingSize() + selection_feature_processor_->DenseFeaturesCount(), &cached_features)) { TC_LOG(ERROR) << "Could not extract features."; return false; } // Produce selection model candidates. 
std::vector<TokenSpan> chunks; if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span, interpreter_manager->SelectionInterpreter(), *cached_features, &chunks)) { TC_LOG(ERROR) << "Could not chunk."; return false; } for (const TokenSpan& chunk : chunks) { AnnotatedSpan candidate; candidate.span = selection_feature_processor_->StripBoundaryCodepoints( context_unicode, TokenSpanToCodepointSpan(*tokens, chunk)); if (model_->selection_options()->strip_unpaired_brackets()) { candidate.span = StripUnpairedBrackets(context_unicode, candidate.span, *unilib_); } // Only output non-empty spans. if (candidate.span.first != candidate.span.second) { result->push_back(candidate); } } return true; } bool TextClassifier::ModelClassifyText( const std::string& context, CodepointSpan selection_indices, InterpreterManager* interpreter_manager, FeatureProcessor::EmbeddingCache* embedding_cache, std::vector<ClassificationResult>* classification_results) const { if (model_->triggering_options() == nullptr || !(model_->triggering_options()->enabled_modes() & ModeFlag_CLASSIFICATION)) { return true; } return ModelClassifyText(context, {}, selection_indices, interpreter_manager, embedding_cache, classification_results); } namespace internal { std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, CodepointSpan selection_indices, TokenSpan tokens_around_selection_to_copy) { const auto first_selection_token = std::upper_bound( cached_tokens.begin(), cached_tokens.end(), selection_indices.first, [](int selection_start, const Token& token) { return selection_start < token.end; }); const auto last_selection_token = std::lower_bound( cached_tokens.begin(), cached_tokens.end(), selection_indices.second, [](const Token& token, int selection_end) { return token.start < selection_end; }); const int64 first_token = std::max( static_cast<int64>(0), static_cast<int64>((first_selection_token - cached_tokens.begin()) - tokens_around_selection_to_copy.first)); const 
int64 last_token = std::min( static_cast<int64>(cached_tokens.size()), static_cast<int64>((last_selection_token - cached_tokens.begin()) + tokens_around_selection_to_copy.second)); std::vector<Token> tokens; tokens.reserve(last_token - first_token); for (int i = first_token; i < last_token; ++i) { tokens.push_back(cached_tokens[i]); } return tokens; } } // namespace internal TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const { const FeatureProcessorOptions_::BoundsSensitiveFeatures* bounds_sensitive_features = classification_feature_processor_->GetOptions() ->bounds_sensitive_features(); if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { // The extraction span is the selection span expanded to include a relevant // number of tokens outside of the bounds of the selection. return {bounds_sensitive_features->num_tokens_before(), bounds_sensitive_features->num_tokens_after()}; } else { // The extraction span is the clicked token with context_size tokens on // either side. 
const int context_size = selection_feature_processor_->GetOptions()->context_size(); return {context_size, context_size}; } } bool TextClassifier::ModelClassifyText( const std::string& context, const std::vector<Token>& cached_tokens, CodepointSpan selection_indices, InterpreterManager* interpreter_manager, FeatureProcessor::EmbeddingCache* embedding_cache, std::vector<ClassificationResult>* classification_results) const { std::vector<Token> tokens; if (cached_tokens.empty()) { tokens = classification_feature_processor_->Tokenize(context); } else { tokens = internal::CopyCachedTokens(cached_tokens, selection_indices, ClassifyTextUpperBoundNeededTokens()); } int click_pos; classification_feature_processor_->RetokenizeAndFindClick( context, selection_indices, classification_feature_processor_->GetOptions() ->only_use_line_with_click(), &tokens, &click_pos); const TokenSpan selection_token_span = CodepointSpanToTokenSpan(tokens, selection_indices); const int selection_num_tokens = TokenSpanSize(selection_token_span); if (model_->classification_options()->max_num_tokens() > 0 && model_->classification_options()->max_num_tokens() < selection_num_tokens) { *classification_results = {{kOtherCollection, 1.0}}; return true; } const FeatureProcessorOptions_::BoundsSensitiveFeatures* bounds_sensitive_features = classification_feature_processor_->GetOptions() ->bounds_sensitive_features(); if (selection_token_span.first == kInvalidIndex || selection_token_span.second == kInvalidIndex) { TC_LOG(ERROR) << "Could not determine span."; return false; } // Compute the extraction span based on the model type. TokenSpan extraction_span; if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { // The extraction span is the selection span expanded to include a relevant // number of tokens outside of the bounds of the selection. 
extraction_span = ExpandTokenSpan( selection_token_span, /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(), /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after()); } else { if (click_pos == kInvalidIndex) { TC_LOG(ERROR) << "Couldn't choose a click position."; return false; } // The extraction span is the clicked token with context_size tokens on // either side. const int context_size = classification_feature_processor_->GetOptions()->context_size(); extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos), /*num_tokens_left=*/context_size, /*num_tokens_right=*/context_size); } extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()}); if (!classification_feature_processor_->HasEnoughSupportedCodepoints( tokens, extraction_span)) { *classification_results = {{kOtherCollection, 1.0}}; return true; } std::unique_ptr<CachedFeatures> cached_features; if (!classification_feature_processor_->ExtractFeatures( tokens, extraction_span, selection_indices, embedding_executor_.get(), embedding_cache, classification_feature_processor_->EmbeddingSize() + classification_feature_processor_->DenseFeaturesCount(), &cached_features)) { TC_LOG(ERROR) << "Could not extract features."; return false; } std::vector<float> features; features.reserve(cached_features->OutputFeaturesSize()); if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span, &features); } else { cached_features->AppendClickContextFeaturesForClick(click_pos, &features); } TensorView<float> logits = classification_executor_->ComputeLogits( TensorView<float>(features.data(), {1, static_cast<int>(features.size())}), interpreter_manager->ClassificationInterpreter()); if (!logits.is_valid()) { TC_LOG(ERROR) << "Couldn't compute logits."; return false; } if (logits.dims() != 2 || logits.dim(0) != 1 || logits.dim(1) != classification_feature_processor_->NumCollections()) { TC_LOG(ERROR) 
<< "Mismatching output"; return false; } const std::vector<float> scores = ComputeSoftmax(logits.data(), logits.dim(1)); classification_results->resize(scores.size()); for (int i = 0; i < scores.size(); i++) { (*classification_results)[i] = { classification_feature_processor_->LabelToCollection(i), scores[i]}; } std::sort(classification_results->begin(), classification_results->end(), [](const ClassificationResult& a, const ClassificationResult& b) { return a.score > b.score; }); // Phone class sanity check. if (!classification_results->empty() && classification_results->begin()->collection == kPhoneCollection) { const int digit_count = CountDigits(context, selection_indices); if (digit_count < model_->classification_options()->phone_min_num_digits() || digit_count > model_->classification_options()->phone_max_num_digits()) { *classification_results = {{kOtherCollection, 1.0}}; } } // Address class sanity check. if (!classification_results->empty() && classification_results->begin()->collection == kAddressCollection) { if (selection_num_tokens < model_->classification_options()->address_min_num_tokens()) { *classification_results = {{kOtherCollection, 1.0}}; } } return true; } bool TextClassifier::RegexClassifyText( const std::string& context, CodepointSpan selection_indices, ClassificationResult* classification_result) const { const std::string selection_text = ExtractSelection(context, selection_indices); const UnicodeText selection_text_unicode( UTF8ToUnicodeText(selection_text, /*do_copy=*/false)); // Check whether any of the regular expressions match. 
for (const int pattern_id : classification_regex_patterns_) { const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; const std::unique_ptr<UniLib::RegexMatcher> matcher = regex_pattern.pattern->Matcher(selection_text_unicode); int status = UniLib::RegexMatcher::kNoError; bool matches; if (regex_approximate_match_pattern_ids_.find(pattern_id) != regex_approximate_match_pattern_ids_.end()) { matches = matcher->ApproximatelyMatches(&status); } else { matches = matcher->Matches(&status); } if (status != UniLib::RegexMatcher::kNoError) { return false; } if (matches) { *classification_result = {regex_pattern.collection_name, regex_pattern.target_classification_score, regex_pattern.priority_score}; return true; } if (status != UniLib::RegexMatcher::kNoError) { TC_LOG(ERROR) << "Cound't match regex: " << pattern_id; } } return false; } bool TextClassifier::DatetimeClassifyText( const std::string& context, CodepointSpan selection_indices, const ClassificationOptions& options, ClassificationResult* classification_result) const { if (!datetime_parser_) { return false; } const std::string selection_text = ExtractSelection(context, selection_indices); std::vector<DatetimeParseResultSpan> datetime_spans; if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc, options.reference_timezone, options.locales, ModeFlag_CLASSIFICATION, /*anchor_start_end=*/true, &datetime_spans)) { TC_LOG(ERROR) << "Error during parsing datetime."; return false; } for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { // Only consider the result valid if the selection and extracted datetime // spans exactly match. 
if (std::make_pair(datetime_span.span.first + selection_indices.first, datetime_span.span.second + selection_indices.first) == selection_indices) { *classification_result = {kDateCollection, datetime_span.target_classification_score}; classification_result->datetime_parse_result = datetime_span.data; return true; } } return false; } std::vector<ClassificationResult> TextClassifier::ClassifyText( const std::string& context, CodepointSpan selection_indices, const ClassificationOptions& options) const { if (!initialized_) { TC_LOG(ERROR) << "Not initialized"; return {}; } if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { return {}; } if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) { return {}; } if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) { TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: " << std::get<0>(selection_indices) << " " << std::get<1>(selection_indices); return {}; } // Try the regular expression models. ClassificationResult regex_result; if (RegexClassifyText(context, selection_indices, ®ex_result)) { if (!FilteredForClassification(regex_result)) { return {regex_result}; } else { return {{kOtherCollection, 1.0}}; } } // Try the date model. ClassificationResult datetime_result; if (DatetimeClassifyText(context, selection_indices, options, &datetime_result)) { if (!FilteredForClassification(datetime_result)) { return {datetime_result}; } else { return {{kOtherCollection, 1.0}}; } } // Fallback to the model. std::vector<ClassificationResult> model_result; InterpreterManager interpreter_manager(selection_executor_.get(), classification_executor_.get()); if (ModelClassifyText(context, selection_indices, &interpreter_manager, /*embedding_cache=*/nullptr, &model_result) && !model_result.empty()) { if (!FilteredForClassification(model_result[0])) { return model_result; } else { return {{kOtherCollection, 1.0}}; } } // No classifications. 
  return {};
}

// Runs the TFLite selection and classification models over the whole context
// and appends the produced spans to |result|. |tokens| receives the
// tokenization of the last processed line.
bool TextClassifier::ModelAnnotate(const std::string& context,
                                   InterpreterManager* interpreter_manager,
                                   std::vector<Token>* tokens,
                                   std::vector<AnnotatedSpan>* result) const {
  // Model-based annotation may be disabled via the triggering options.
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
    return true;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);
  // Either treat the whole context as a single "line", or split it into
  // lines, depending on the feature-processor options.
  std::vector<UnicodeTextRange> lines;
  if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
    lines.push_back({context_unicode.begin(), context_unicode.end()});
  } else {
    lines = selection_feature_processor_->SplitContext(context_unicode);
  }

  const float min_annotate_confidence =
      (model_->triggering_options() != nullptr
           ? model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  FeatureProcessor::EmbeddingCache embedding_cache;
  for (const UnicodeTextRange& line : lines) {
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    // Skip lines without enough supported codepoints.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    // Propose candidate token spans for this line.
    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    // Codepoint offset of this line within the whole context, used to
    // translate line-relative spans back to context-relative ones.
    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, codepoint_span,
                               interpreter_manager, &embedding_cache,
                               &classification)) {
          TC_LOG(ERROR) << "Could not classify text: "
                        << (codepoint_span.first + offset) << " "
                        << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
if (!classification.empty() && !ClassifiedAsOther(classification) && classification[0].score >= min_annotate_confidence) { AnnotatedSpan result_span; result_span.span = {codepoint_span.first + offset, codepoint_span.second + offset}; result_span.classification = std::move(classification); result->push_back(std::move(result_span)); } } } } return true; } const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests() const { return selection_feature_processor_.get(); } const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests() const { return classification_feature_processor_.get(); } const DatetimeParser* TextClassifier::DatetimeParserForTests() const { return datetime_parser_.get(); } std::vector<AnnotatedSpan> TextClassifier::Annotate( const std::string& context, const AnnotationOptions& options) const { std::vector<AnnotatedSpan> candidates; if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { return {}; } if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) { return {}; } InterpreterManager interpreter_manager(selection_executor_.get(), classification_executor_.get()); // Annotate with the selection model. std::vector<Token> tokens; if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) { TC_LOG(ERROR) << "Couldn't run ModelAnnotate."; return {}; } // Annotate with the regular expression models. if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), annotation_regex_patterns_, &candidates)) { TC_LOG(ERROR) << "Couldn't run RegexChunk."; return {}; } // Annotate with the datetime model. if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), options.reference_time_ms_utc, options.reference_timezone, options.locales, ModeFlag_ANNOTATION, &candidates)) { TC_LOG(ERROR) << "Couldn't run RegexChunk."; return {}; } // Sort candidates according to their position in the input, so that the next // code can assume that any connected component of overlapping spans forms a // contiguous block. 
  std::sort(candidates.begin(), candidates.end(),
            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
              return a.span.first < b.span.first;
            });

  // Pick a non-conflicting subset of the candidates.
  std::vector<int> candidate_indices;
  if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
                        &candidate_indices)) {
    TC_LOG(ERROR) << "Couldn't resolve conflicts.";
    return {};
  }

  std::vector<AnnotatedSpan> result;
  result.reserve(candidate_indices.size());
  for (const int i : candidate_indices) {
    // Drop unclassified, "other"-classified, and filtered-out spans.
    if (!candidates[i].classification.empty() &&
        !ClassifiedAsOther(candidates[i].classification) &&
        !FilteredForAnnotation(candidates[i])) {
      result.push_back(std::move(candidates[i]));
    }
  }

  return result;
}

// Runs the regex patterns selected by |rules| over the whole context and
// appends an annotated span for every match of a pattern's capturing group 1.
bool TextClassifier::RegexChunk(const UnicodeText& context_unicode,
                                const std::vector<int>& rules,
                                std::vector<AnnotatedSpan>* result) const {
  for (int pattern_id : rules) {
    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
    const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
    if (!matcher) {
      TC_LOG(ERROR) << "Could not get regex matcher for pattern: "
                    << pattern_id;
      return false;
    }

    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) &&
           status == UniLib::RegexMatcher::kNoError) {
      result->emplace_back();
      // Selection/annotation regular expressions need to specify a capturing
      // group specifying the selection.
      // NOTE(review): the status of Start/End is not re-checked here; a
      // failed group-1 lookup would produce an invalid span — presumably all
      // annotation patterns define a capturing group 1; confirm upstream.
      result->back().span = {matcher->Start(1, &status),
                             matcher->End(1, &status)};
      result->back().classification = {
          {regex_pattern.collection_name,
           regex_pattern.target_classification_score,
           regex_pattern.priority_score}};
    }
  }
  return true;
}

// Proposes non-overlapping token chunks inside |span_of_interest|, scored by
// the selection model, returned sorted by position.
bool TextClassifier::ModelChunk(int num_tokens,
                                const TokenSpan& span_of_interest,
                                tflite::Interpreter* selection_interpreter,
                                const CachedFeatures& cached_features,
                                std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span = IntersectTokenSpans(
      ExpandTokenSpan(span_of_interest,
                      /*num_tokens_left=*/max_selection_span,
                      /*num_tokens_right=*/max_selection_span),
      {0, num_tokens});

  std::vector<ScoredChunk> scored_chunks;
  // Score candidates with bounds-sensitive features when enabled, otherwise
  // with click-context features.
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  // Ascending sort over reverse iterators leaves the vector ordered by
  // descending score.
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  // token_used[i] marks token (inference_span.first + i) as claimed by an
  // already-picked chunk.
  std::vector<bool> token_used(TokenSpanSize(inference_span));
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }
    if (!feasible) {
      continue;
    }

    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  // Return the picked chunks in document order.
  std::sort(chunks->begin(), chunks->end());

  return true;
}

namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto it = map->find(key);
  if (it != map->end()) {
    it->second = std::max(it->second, value);
  } else {
    (*map)[key] = value;
  }
}
}  // namespace

// Scores candidate spans reachable from clicks inside |span_of_interest|
// using click-context features, batching the TFLite inference over clicks.
bool TextClassifier::ModelClickContextScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  std::map<TokenSpan, float> chunk_scores;
  for (int batch_start = span_of_interest.first;
       batch_start < span_of_interest.second; batch_start += max_batch_size) {
    const int batch_end =
        std::min(batch_start + max_batch_size, span_of_interest.second);

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                         &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // Expect one row of selection-label logits per click in the batch.
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) !=
            selection_feature_processor_->GetSelectionLabelCount()) {
      TC_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      // Per-click probability distribution over the selection labels.
      const std::vector<float> scores = ComputeSoftmax(
          logits.data() + logits.dim(1) * (click_pos - batch_start),
          logits.dim(1));
      for (int j = 0;
           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
        TokenSpan relative_token_span;
        if (!selection_feature_processor_->LabelToTokenSpan(
                j, &relative_token_span)) {
          TC_LOG(ERROR) << "Couldn't map the label to a token span.";
          return false;
        }
        // A label encodes a span relative to the click position; keep the
        // best score seen for each in-bounds absolute candidate span.
        const TokenSpan candidate_span = ExpandTokenSpan(
            SingleTokenSpan(click_pos), relative_token_span.first,
            relative_token_span.second);
        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
          UpdateMax(&chunk_scores, candidate_span, scores[j]);
        }
      }
    }
  }

  scored_chunks->clear();
  scored_chunks->reserve(chunk_scores.size());
  for (const auto& entry : chunk_scores) {
    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
  }

  return true;
}

// Scores candidate chunks with bounds-sensitive features. Candidates are all
// spans contained in |inference_span| that intersect |span_of_interest| and
// do not exceed the maximum chunk length.
bool TextClassifier::ModelBoundsSensitiveScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const TokenSpan& inference_span, const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  const int max_chunk_length = selection_feature_processor_->GetOptions()
                                       ->selection_reduced_output_space()
                                   ?
max_selection_span + 1
                                   : 2 * max_selection_span + 1;
  const bool score_single_token_spans_as_zero =
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->score_single_token_spans_as_zero();

  scored_chunks->clear();
  if (score_single_token_spans_as_zero) {
    scored_chunks->reserve(TokenSpanSize(span_of_interest));
  }

  // Prepare all chunk candidates into one batch:
  //   - Are contained in the inference span
  //   - Have a non-empty intersection with the span of interest
  //   - Are at least one token long
  //   - Are not longer than the maximum chunk length
  std::vector<TokenSpan> candidate_spans;
  for (int start = inference_span.first; start < span_of_interest.second;
       ++start) {
    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
    for (int end = leftmost_end_index;
         end <= inference_span.second && end - start <= max_chunk_length;
         ++end) {
      const TokenSpan candidate_span = {start, end};
      if (score_single_token_spans_as_zero &&
          TokenSpanSize(candidate_span) == 1) {
        // Do not include the single token span in the batch, add a zero score
        // for it directly to the output.
        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
      } else {
        candidate_spans.push_back(candidate_span);
      }
    }
  }

  const int max_batch_size = model_->selection_options()->batch_size();
  std::vector<float> all_features;
  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
  for (int batch_start = 0; batch_start < candidate_spans.size();
       batch_start += max_batch_size) {
    const int batch_end = std::min(batch_start + max_batch_size,
                                   static_cast<int>(candidate_spans.size()));

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int i = batch_start; i < batch_end; ++i) {
      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
                                                           &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // Bounds-sensitive scoring yields exactly one score per candidate span.
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) != 1) {
      TC_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int i = batch_start; i < batch_end; ++i) {
      scored_chunks->push_back(
          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
    }
  }

  return true;
}

// Runs the datetime parser over the whole context and appends an annotated
// "date" span (with the parse result attached) for every datetime found.
bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
                                   int64 reference_time_ms_utc,
                                   const std::string& reference_timezone,
                                   const std::string& locales, ModeFlag mode,
                                   std::vector<AnnotatedSpan>* result) const {
  // Datetime annotation is optional; absence of the parser is not an error.
  if (!datetime_parser_) {
    return true;
  }

  std::vector<DatetimeParseResultSpan> datetime_spans;
  if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
                               reference_timezone, locales, mode,
                               /*anchor_start_end=*/false, &datetime_spans)) {
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    AnnotatedSpan annotated_span;
    annotated_span.span = datetime_span.span;
    annotated_span.classification = {{kDateCollection,
                                      datetime_span.target_classification_score,
                                      datetime_span.priority_score}};
    annotated_span.classification[0].datetime_parse_result = datetime_span.data;

    result->push_back(std::move(annotated_span));
  }
  return true;
}

// Returns the flatbuffer Model view for |buffer| after verification, or
// nullptr if the buffer is null or fails verification.
const Model* ViewModel(const void* buffer, int size) {
  if (!buffer) {
    return nullptr;
  }

  return LoadAndVerifyModel(buffer, size);
}

}  // namespace libtextclassifier2