/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "actions/actions_model_generated.h"
#include "actions/feature-processor.h"
#include "actions/ngram-model.h"
#include "actions/ranker.h"
#include "actions/types.h"
#include "annotator/annotator.h"
#include "annotator/model-executor.h"
#include "annotator/types.h"
#include "utils/flatbuffers.h"
#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
// Options for suggesting actions.
//
// The struct currently declares no fields; Default() hands back a
// value-initialized instance representing the default configuration.
struct ActionSuggestionOptions {
  // Returns the default (empty) options.
  static ActionSuggestionOptions Default() { return {}; }
};
// Class for predicting actions following a conversation.
//
// Instances are obtained through the static From*() factory functions below.
// The factories differ in where the model bytes come from and in whether the
// returned object takes ownership of the UniLib it is given.
class ActionsSuggestions {
 public:
  // Creates ActionsSuggestions from given data buffer with model.
  // Unlike FromScopedMmap(), the buffer is not owned by the returned object,
  // so it presumably must outlive it.
  static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
      const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");

  // Creates ActionsSuggestions from model in the ScopedMmap object and takes
  // ownership of it.
  static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
      std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
      const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");

  // Same as above, but also takes ownership of the unilib.
  static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
      std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
      std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay);

  // Creates ActionsSuggestions from model given as a file descriptor, offset
  // and size in it. If offset and size are less than 0, will ignore them and
  // will just use the fd.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const int offset, const int size,
      const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");

  // Same as above, but also takes ownership of the unilib.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const int offset, const int size,
      std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay = "");

  // Creates ActionsSuggestions from model given as a file descriptor.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");

  // Same as above, but also takes ownership of the unilib.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay);

  // Creates ActionsSuggestions from model given as a POSIX path.
  static std::unique_ptr<ActionsSuggestions> FromPath(
      const std::string& path, const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");

  // Same as above, but also takes ownership of unilib.
  static std::unique_ptr<ActionsSuggestions> FromPath(
      const std::string& path, std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay);

  // Suggests actions for `conversation`.
  ActionsSuggestionsResponse SuggestActions(
      const Conversation& conversation,
      const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;

  // Same as above, but additionally receives an Annotator. Presumably the
  // annotator is used to derive suggestions from annotations of the messages
  // (cf. SuggestActionsFromAnnotations) -- TODO confirm nullptr handling.
  ActionsSuggestionsResponse SuggestActions(
      const Conversation& conversation, const Annotator* annotator,
      const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;

  // Returns the underlying model flatbuffer.
  const ActionsModel* model() const;

  // Returns the schema of the entity data, if any.
  const reflection::Schema* entity_data_schema() const;

  // Presumably the user id that denotes the local user in conversation
  // messages -- verify against callers.
  static const int kLocalUserId = 0;

  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& kViewCalendarType;
  static const std::string& kViewMapType;
  static const std::string& kTrackFlightType;
  static const std::string& kOpenUrlType;
  static const std::string& kSendSmsType;
  static const std::string& kCallPhoneType;
  static const std::string& kSendEmailType;
  static const std::string& kShareLocation;

 protected:
  // Exposed for testing.
  bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;

  // Embeds the tokens per message separately. Each message is padded to the
  // maximum length with the padding token.
  bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
                             std::vector<float>* embeddings,
                             int* max_num_tokens_per_message) const;

  // Concatenates the embedded message tokens - separated by start and end
  // token between messages.
  // If the total token count is greater than the maximum length, tokens at the
  // start are dropped to fit into the limit.
  // If the total token count is smaller than the minimum length, padding tokens
  // are added to the end.
  // Messages are assumed to be ordered by recency - most recent is last.
  // NOTE(review): `tokens` is taken by value, which copies every message's
  // tokens; switching to a const reference needs the out-of-line definition
  // changed in lockstep.
  bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>> tokens,
                             std::vector<float>* embeddings,
                             int* total_token_count) const;

  // The model flatbuffer; null until a factory sets it.
  const ActionsModel* model_ = nullptr;

  // Feature extractor and options.
  std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
  std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
  std::vector<float> embedded_padding_token_;
  std::vector<float> embedded_start_token_;
  std::vector<float> embedded_end_token_;
  int token_embedding_size_ = 0;

 private:
  // A rule from the model together with its compiled regex patterns.
  struct CompiledRule {
    const RulesModel_::Rule* rule;
    std::unique_ptr<UniLib::RegexPattern> pattern;
    std::unique_ptr<UniLib::RegexPattern> output_pattern;
    CompiledRule(const RulesModel_::Rule* rule,
                 std::unique_ptr<UniLib::RegexPattern> pattern,
                 std::unique_ptr<UniLib::RegexPattern> output_pattern)
        : rule(rule),
          pattern(std::move(pattern)),
          output_pattern(std::move(output_pattern)) {}
  };

  // Checks that model contains all required fields, and initializes internal
  // datastructures.
  bool ValidateAndInitialize();

  // Uses `unilib` if given; presumably creates (and owns) a default UniLib
  // otherwise -- TODO confirm against the definition.
  void SetOrCreateUnilib(const UniLib* unilib);

  // Initializes regular expression rules.
  bool InitializeRules(ZlibDecompressor* decompressor);
  bool InitializeRules(ZlibDecompressor* decompressor, const RulesModel* rules,
                       std::vector<CompiledRule>* compiled_rules) const;

  // Prepare preconditions.
  // Takes values from flag provided data, but falls back to model provided
  // values for parameters that are not explicitly provided.
  bool InitializeTriggeringPreconditions();

  // Tokenizes a conversation and produces the tokens per message.
  std::vector<std::vector<Token>> Tokenize(
      const std::vector<std::string>& context) const;

  bool AllocateInput(const int conversation_length, const int max_tokens,
                     const int total_token_count,
                     tflite::Interpreter* interpreter) const;

  bool SetupModelInput(const std::vector<std::string>& context,
                       const std::vector<int>& user_ids,
                       const std::vector<float>& time_diffs,
                       const int num_suggestions,
                       const float confidence_threshold,
                       const float diversification_distance,
                       const float empirical_probability_factor,
                       tflite::Interpreter* interpreter) const;

  bool ReadModelOutput(tflite::Interpreter* interpreter,
                       const ActionSuggestionOptions& options,
                       ActionsSuggestionsResponse* response) const;

  bool SuggestActionsFromModel(
      const Conversation& conversation, const int num_messages,
      const ActionSuggestionOptions& options,
      ActionsSuggestionsResponse* response,
      std::unique_ptr<tflite::Interpreter>* interpreter) const;

  // Creates options for annotation of a message.
  AnnotationOptions AnnotationOptionsForMessage(
      const ConversationMessage& message) const;

  void SuggestActionsFromAnnotations(
      const Conversation& conversation, const ActionSuggestionOptions& options,
      const Annotator* annotator, std::vector<ActionSuggestion>* actions) const;

  void SuggestActionsFromAnnotation(
      const int message_index, const ActionSuggestionAnnotation& annotation,
      std::vector<ActionSuggestion>* actions) const;

  // Deduplicates equivalent annotations - annotations that have the same type
  // and same span text.
  // Returns the indices of the deduplicated annotations.
  std::vector<int> DeduplicateAnnotations(
      const std::vector<ActionSuggestionAnnotation>& annotations) const;

  bool SuggestActionsFromRules(const Conversation& conversation,
                               std::vector<ActionSuggestion>* actions) const;

  bool SuggestActionsFromLua(
      const Conversation& conversation,
      const TfLiteModelExecutor* model_executor,
      const tflite::Interpreter* interpreter,
      const reflection::Schema* annotation_entity_data_schema,
      std::vector<ActionSuggestion>* actions) const;

  bool GatherActionsSuggestions(const Conversation& conversation,
                                const Annotator* annotator,
                                const ActionSuggestionOptions& options,
                                ActionsSuggestionsResponse* response) const;

  // Checks whether the input triggers the low confidence checks.
  bool IsLowConfidenceInput(const Conversation& conversation,
                            const int num_messages,
                            std::vector<int>* post_check_rules) const;

  // Checks and filters suggestions triggering the low confidence post checks.
  bool FilterConfidenceOutput(const std::vector<int>& post_check_rules,
                              std::vector<ActionSuggestion>* actions) const;

  ActionSuggestion SuggestionFromSpec(
      const ActionSuggestionSpec* action, const std::string& default_type = "",
      const std::string& default_response_text = "",
      const std::string& default_serialized_entity_data = "",
      const float default_score = 0.0f,
      const float default_priority_score = 0.0f) const;

  bool FillAnnotationFromMatchGroup(
      const UniLib::RegexMatcher* matcher,
      const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
      const int message_index, ActionSuggestionAnnotation* annotation) const;

  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;

  // Tensorflow Lite models.
  std::unique_ptr<const TfLiteModelExecutor> model_executor_;

  // Rules.
  std::vector<CompiledRule> rules_;
  std::vector<CompiledRule> low_confidence_rules_;

  // `unilib_` is the UniLib in use; `owned_unilib_` presumably holds it when
  // this object owns it (the unique_ptr<UniLib> factories / SetOrCreateUnilib).
  std::unique_ptr<UniLib> owned_unilib_;
  const UniLib* unilib_ = nullptr;

  // Locales supported by the model.
  std::vector<Locale> locales_;

  // Annotation entities used by the model.
  std::unordered_set<std::string> annotation_entity_types_;

  // Builder for creating extra data.
  const reflection::Schema* entity_data_schema_ = nullptr;
  std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
  std::unique_ptr<ActionsSuggestionsRanker> ranker_;

  std::string lua_bytecode_;

  // Triggering preconditions. These parameters can be backed by the model and
  // (partially) be provided by flags.
  TriggeringPreconditionsT preconditions_;
  std::string triggering_preconditions_overlay_buffer_;
  const TriggeringPreconditions* triggering_preconditions_overlay_ = nullptr;

  // Low confidence input ngram classifier.
  std::unique_ptr<const NGramModel> ngram_model_;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
//
// `buffer` points at `size` bytes of serialized model data.
// NOTE(review): as a "view", the returned pointer presumably aliases `buffer`
// rather than copying it, so `buffer` must stay alive while the result is in
// use; confirm whether the bytes are verified and whether nullptr is returned
// for invalid input.
const ActionsModel* ViewActionsModel(const void* buffer, int size);
// Opens model from given path and runs a function, passing the loaded Model
// flatbuffer as an argument.
//
// This is mainly useful if we don't want to pay the cost for the model
// initialization because we'll be only reading some flatbuffer values from the
// file.
//
// If the file cannot be mapped, `function` is invoked exactly once with a
// null model and its result is returned; `function` must therefore accept
// nullptr. The model pointer refers to data in a mapping owned by a local
// ScopedMmap (assumed to unmap on destruction), so `function` must not retain
// it beyond the call.
template <typename ReturnType, typename Func>
ReturnType VisitActionsModel(const std::string& path, Func function) {
  ScopedMmap mmap(path);
  if (!mmap.handle().ok()) {
    // Return here: without it, execution fell through and read from an
    // invalid mapping, then invoked `function` a second time.
    return function(/*model=*/nullptr);
  }
  const ActionsModel* model =
      ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
  return function(model);
}
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_