/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/intents/intent-generator.h"

#include <vector>

#include "actions/lua-utils.h"
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/hash/farmhash.h"
#include "utils/java/jni-base.h"
#include "utils/java/string_utils.h"
#include "utils/lua-utils.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/substitute.h"
#include "utils/utf8/unicodetext.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
#include "flatbuffers/reflection_generated.h"

// The Lua headers are plain C; give them C linkage when compiled as C++.
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#ifdef __cplusplus
}
#endif

namespace libtextclassifier3 {
namespace {

// Field names that generator snippets can look up on the `external` table and
// on its `android` sub-table.

// Field under which the reference time is exposed.
// NOTE(review): the identifier says "Usec" but the key (and the value stored
// under it) is in milliseconds -- confirm before renaming.
static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
// `external.hash`: farmhash of a string.
static constexpr const char* kHashKey = "hash";
// `external.android.url_schema`: scheme of a url.
static constexpr const char* kUrlSchemaKey = "url_schema";
// `external.android.url_host`: host of a url.
static constexpr const char* kUrlHostKey = "url_host";
// `external.android.urlencode`: url-encodes a string.
static constexpr const char* kUrlEncodeKey = "urlencode";
// `external.android.package_name`: package name of the context.
static constexpr const char* kPackageNameKey = "package_name";
// `external.android.device_locales`: array of locale tables.
static constexpr const char* kDeviceLocaleKey = "device_locales";
// `external.format`: string substitution.
static constexpr const char* kFormatKey = "format";

// An Android specific Lua environment with JNI backed callbacks.
class JniLuaEnvironment : public LuaEnvironment { public: JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache, const jobject context, const std::vector<Locale>& device_locales); // Environment setup. bool Initialize(); // Runs an intent generator snippet. bool RunIntentGenerator(const std::string& generator_snippet, std::vector<RemoteActionTemplate>* remote_actions); protected: virtual void SetupExternalHook(); int HandleExternalCallback(); int HandleAndroidCallback(); int HandleUserRestrictionsCallback(); int HandleUrlEncode(); int HandleUrlSchema(); int HandleHash(); int HandleFormat(); int HandleAndroidStringResources(); int HandleUrlHost(); // Checks and retrieves string resources from the model. bool LookupModelStringResource(); // Reads and create a RemoteAction result from Lua. RemoteActionTemplate ReadRemoteActionTemplateResult(); // Reads the extras from the Lua result. void ReadExtras(std::map<std::string, Variant>* extra); // Reads the intent categories array from a Lua result. void ReadCategories(std::vector<std::string>* category); // Retrieves user manager if not previously done. bool RetrieveUserManager(); // Retrieves system resources if not previously done. bool RetrieveSystemResources(); // Parse the url string by using Uri.parse from Java. ScopedLocalRef<jobject> ParseUri(StringPiece url) const; // Read remote action templates from lua generator. int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result); const Resources& resources_; JNIEnv* jenv_; const JniCache* jni_cache_; const jobject context_; std::vector<Locale> device_locales_; ScopedGlobalRef<jobject> usermanager_; // Whether we previously attempted to retrieve the UserManager before. bool usermanager_retrieved_; ScopedGlobalRef<jobject> system_resources_; // Whether we previously attempted to retrieve the system resources. bool system_resources_resources_retrieved_; // Cached JNI references for Java strings `string` and `android`. 
ScopedGlobalRef<jstring> string_; ScopedGlobalRef<jstring> android_; }; JniLuaEnvironment::JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache, const jobject context, const std::vector<Locale>& device_locales) : resources_(resources), jenv_(jni_cache ? jni_cache->GetEnv() : nullptr), jni_cache_(jni_cache), context_(context), device_locales_(device_locales), usermanager_(/*object=*/nullptr, /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)), usermanager_retrieved_(false), system_resources_(/*object=*/nullptr, /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)), system_resources_resources_retrieved_(false), string_(/*object=*/nullptr, /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)), android_(/*object=*/nullptr, /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {} bool JniLuaEnvironment::Initialize() { string_ = MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm); android_ = MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm); if (string_ == nullptr || android_ == nullptr) { TC3_LOG(ERROR) << "Could not allocate constant strings references."; return false; } return (RunProtected([this] { LoadDefaultLibraries(); SetupExternalHook(); lua_setglobal(state_, "external"); return LUA_OK; }) == LUA_OK); } void JniLuaEnvironment::SetupExternalHook() { // This exposes an `external` object with the following fields: // * entity: the bundle with all information about a classification. // * android: callbacks into specific android provided methods. // * android.user_restrictions: callbacks to check user permissions. // * android.R: callbacks to retrieve string resources. 
BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>( "external"); // android BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>( "android"); { // android.user_restrictions BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleUserRestrictionsCallback>( "user_restrictions"); lua_setfield(state_, /*idx=*/-2, "user_restrictions"); // android.R // Callback to access android string resources. BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidStringResources>("R"); lua_setfield(state_, /*idx=*/-2, "R"); } lua_setfield(state_, /*idx=*/-2, "android"); } int JniLuaEnvironment::HandleExternalCallback() { const StringPiece key = ReadString(/*index=*/-1); if (key.Equals(kHashKey)) { Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>(); return 1; } else if (key.Equals(kFormatKey)) { Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>(); return 1; } else { TC3_LOG(ERROR) << "Undefined external access " << key.ToString(); lua_error(state_); return 0; } } int JniLuaEnvironment::HandleAndroidCallback() { const StringPiece key = ReadString(/*index=*/-1); if (key.Equals(kDeviceLocaleKey)) { // Provide the locale as table with the individual fields set. lua_newtable(state_); for (int i = 0; i < device_locales_.size(); i++) { // Adjust index to 1-based indexing for Lua. 
lua_pushinteger(state_, i + 1); lua_newtable(state_); PushString(device_locales_[i].Language()); lua_setfield(state_, -2, "language"); PushString(device_locales_[i].Region()); lua_setfield(state_, -2, "region"); PushString(device_locales_[i].Script()); lua_setfield(state_, -2, "script"); lua_settable(state_, /*idx=*/-3); } return 1; } else if (key.Equals(kPackageNameKey)) { if (context_ == nullptr) { TC3_LOG(ERROR) << "Context invalid."; lua_error(state_); return 0; } ScopedLocalRef<jstring> package_name_str( static_cast<jstring>(jenv_->CallObjectMethod( context_, jni_cache_->context_get_package_name))); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling Context.getPackageName"; lua_error(state_); return 0; } PushString(ToStlString(jenv_, package_name_str.get())); return 1; } else if (key.Equals(kUrlEncodeKey)) { Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>(); return 1; } else if (key.Equals(kUrlHostKey)) { Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>(); return 1; } else if (key.Equals(kUrlSchemaKey)) { Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>(); return 1; } else { TC3_LOG(ERROR) << "Undefined android reference " << key.ToString(); lua_error(state_); return 0; } } int JniLuaEnvironment::HandleUserRestrictionsCallback() { if (jni_cache_->usermanager_class == nullptr || jni_cache_->usermanager_get_user_restrictions == nullptr) { // UserManager is only available for API level >= 17 and // getUserRestrictions only for API level >= 18, so we just return false // normally here. lua_pushboolean(state_, false); return 1; } // Get user manager if not previously retrieved. 
if (!RetrieveUserManager()) { TC3_LOG(ERROR) << "Error retrieving user manager."; lua_error(state_); return 0; } ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod( usermanager_.get(), jni_cache_->usermanager_get_user_restrictions)); if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) { TC3_LOG(ERROR) << "Error calling getUserRestrictions"; lua_error(state_); return 0; } const StringPiece key_str = ReadString(/*index=*/-1); if (key_str.empty()) { TC3_LOG(ERROR) << "Expected string, got null."; lua_error(state_); return 0; } ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str); if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) { TC3_LOG(ERROR) << "Expected string, got null."; lua_error(state_); return 0; } const bool permission = jenv_->CallBooleanMethod( bundle.get(), jni_cache_->bundle_get_boolean, key.get()); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error getting bundle value"; lua_pushboolean(state_, false); } else { lua_pushboolean(state_, permission); } return 1; } int JniLuaEnvironment::HandleUrlEncode() { const StringPiece input = ReadString(/*index=*/1); if (input.empty()) { TC3_LOG(ERROR) << "Expected string, got null."; lua_error(state_); return 0; } // Call Java URL encoder. 
ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input); if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) { TC3_LOG(ERROR) << "Expected string, got null."; lua_error(state_); return 0; } ScopedLocalRef<jstring> encoded_str( static_cast<jstring>(jenv_->CallStaticObjectMethod( jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode, input_str.get(), jni_cache_->string_utf8.get()))); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling UrlEncoder.encode"; lua_error(state_); return 0; } PushString(ToStlString(jenv_, encoded_str.get())); return 1; } ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const { if (url.empty()) { return nullptr; } // Call to Java URI parser. ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url); if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) { TC3_LOG(ERROR) << "Expected string, got null"; return nullptr; } // Try to parse uri and get scheme. 
ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod( jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get())); if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) { TC3_LOG(ERROR) << "Error calling Uri.parse"; } return uri; } int JniLuaEnvironment::HandleUrlSchema() { StringPiece url = ReadString(/*index=*/1); ScopedLocalRef<jobject> parsed_uri = ParseUri(url); if (parsed_uri == nullptr) { lua_error(state_); return 0; } ScopedLocalRef<jstring> scheme_str(static_cast<jstring>( jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme))); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling Uri.getScheme"; lua_error(state_); return 0; } if (scheme_str == nullptr) { lua_pushnil(state_); } else { PushString(ToStlString(jenv_, scheme_str.get())); } return 1; } int JniLuaEnvironment::HandleUrlHost() { StringPiece url = ReadString(/*index=*/-1); ScopedLocalRef<jobject> parsed_uri = ParseUri(url); if (parsed_uri == nullptr) { lua_error(state_); return 0; } ScopedLocalRef<jstring> host_str(static_cast<jstring>( jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host))); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling Uri.getHost"; lua_error(state_); return 0; } if (host_str == nullptr) { lua_pushnil(state_); } else { PushString(ToStlString(jenv_, host_str.get())); } return 1; } int JniLuaEnvironment::HandleHash() { const StringPiece input = ReadString(/*index=*/-1); lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length())); return 1; } int JniLuaEnvironment::HandleFormat() { const int num_args = lua_gettop(state_); std::vector<StringPiece> args(num_args - 1); for (int i = 0; i < num_args - 1; i++) { args[i] = ReadString(/*index=*/i + 2); } PushString(strings::Substitute(ReadString(/*index=*/1), args)); return 1; } bool JniLuaEnvironment::LookupModelStringResource() { // Handle only lookup by name. 
if (lua_type(state_, 2) != LUA_TSTRING) { return false; } const StringPiece resource_name = ReadString(/*index=*/-1); std::string resource_content; if (!resources_.GetResourceContent(device_locales_, resource_name, &resource_content)) { // Resource cannot be provided by the model. return false; } PushString(resource_content); return true; } int JniLuaEnvironment::HandleAndroidStringResources() { // Check whether the requested resource can be served from the model data. if (LookupModelStringResource()) { return 1; } // Get system resources if not previously retrieved. if (!RetrieveSystemResources()) { TC3_LOG(ERROR) << "Error retrieving system resources."; lua_error(state_); return 0; } int resource_id; switch (lua_type(state_, -1)) { case LUA_TNUMBER: resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1)); break; case LUA_TSTRING: { const StringPiece resource_name_str = ReadString(/*index=*/-1); if (resource_name_str.empty()) { TC3_LOG(ERROR) << "No resource name provided."; lua_error(state_); return 0; } ScopedLocalRef<jstring> resource_name = jni_cache_->ConvertToJavaString(resource_name_str); if (resource_name == nullptr) { TC3_LOG(ERROR) << "Invalid resource name."; lua_error(state_); return 0; } resource_id = jenv_->CallIntMethod( system_resources_.get(), jni_cache_->resources_get_identifier, resource_name.get(), string_.get(), android_.get()); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling getIdentifier."; lua_error(state_); return 0; } break; } default: TC3_LOG(ERROR) << "Unexpected type for resource lookup."; lua_error(state_); return 0; } if (resource_id == 0) { TC3_LOG(ERROR) << "Resource not found."; lua_pushnil(state_); return 1; } ScopedLocalRef<jstring> resource_str(static_cast<jstring>( jenv_->CallObjectMethod(system_resources_.get(), jni_cache_->resources_get_string, resource_id))); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling getString."; lua_error(state_); return 0; } if 
(resource_str == nullptr) { lua_pushnil(state_); } else { PushString(ToStlString(jenv_, resource_str.get())); } return 1; } bool JniLuaEnvironment::RetrieveSystemResources() { if (system_resources_resources_retrieved_) { return (system_resources_ != nullptr); } system_resources_resources_retrieved_ = true; jobject system_resources_ref = jenv_->CallStaticObjectMethod( jni_cache_->resources_class.get(), jni_cache_->resources_get_system); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling getSystem."; return false; } system_resources_ = MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm); return (system_resources_ != nullptr); } bool JniLuaEnvironment::RetrieveUserManager() { if (context_ == nullptr) { return false; } if (usermanager_retrieved_) { return (usermanager_ != nullptr); } usermanager_retrieved_ = true; ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user")); jobject usermanager_ref = jenv_->CallObjectMethod( context_, jni_cache_->context_get_system_service, service.get()); if (jni_cache_->ExceptionCheckAndClear()) { TC3_LOG(ERROR) << "Error calling getSystemService."; return false; } usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm); return (usermanager_ != nullptr); } RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() { RemoteActionTemplate result; // Read intent template. 
lua_pushnil(state_); while (lua_next(state_, /*idx=*/-2)) { const StringPiece key = ReadString(/*index=*/-2); if (key.Equals("title_without_entity")) { result.title_without_entity = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("title_with_entity")) { result.title_with_entity = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("description")) { result.description = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("description_with_app_name")) { result.description_with_app_name = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("action")) { result.action = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("data")) { result.data = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("type")) { result.type = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("flags")) { result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1)); } else if (key.Equals("package_name")) { result.package_name = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("request_code")) { result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1)); } else if (key.Equals("category")) { ReadCategories(&result.category); } else if (key.Equals("extra")) { ReadExtras(&result.extra); } else { TC3_LOG(INFO) << "Unknown entry: " << key.ToString(); } lua_pop(state_, 1); } lua_pop(state_, 1); return result; } void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) { // Read category array. 
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected categories table, got: " << lua_type(state_, /*idx=*/-1); lua_pop(state_, 1); return; } lua_pushnil(state_); while (lua_next(state_, /*idx=*/-2)) { category->push_back(ReadString(/*index=*/-1).ToString()); lua_pop(state_, 1); } } void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) { if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected extras table, got: " << lua_type(state_, /*idx=*/-1); lua_pop(state_, 1); return; } lua_pushnil(state_); while (lua_next(state_, /*idx=*/-2)) { // Each entry is a table specifying name and value. // The value is specified via a type specific field as Lua doesn't allow // to easily distinguish between different number types. if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected a table for an extra, got: " << lua_type(state_, /*idx=*/-1); lua_pop(state_, 1); return; } std::string name; Variant value; lua_pushnil(state_); while (lua_next(state_, /*idx=*/-2)) { const StringPiece key = ReadString(/*index=*/-2); if (key.Equals("name")) { name = ReadString(/*index=*/-1).ToString(); } else if (key.Equals("int_value")) { value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1))); } else if (key.Equals("long_value")) { value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1))); } else if (key.Equals("float_value")) { value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1))); } else if (key.Equals("bool_value")) { value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1))); } else if (key.Equals("string_value")) { value = Variant(ReadString(/*index=*/-1).ToString()); } else { TC3_LOG(INFO) << "Unknown extra field: " << key.ToString(); } lua_pop(state_, 1); } if (!name.empty()) { (*extra)[name] = value; } else { TC3_LOG(ERROR) << "Unnamed extra entry. 
Skipping."; } lua_pop(state_, 1); } } int JniLuaEnvironment::ReadRemoteActionTemplates( std::vector<RemoteActionTemplate>* result) { // Read result. if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1); lua_error(state_); return LUA_ERRRUN; } // Read remote action templates array. lua_pushnil(state_); while (lua_next(state_, /*idx=*/-2)) { if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected intent table, got: " << lua_type(state_, /*idx=*/-1); lua_pop(state_, 1); continue; } result->push_back(ReadRemoteActionTemplateResult()); } lua_pop(state_, /*n=*/1); return LUA_OK; } bool JniLuaEnvironment::RunIntentGenerator( const std::string& generator_snippet, std::vector<RemoteActionTemplate>* remote_actions) { int status; status = luaL_loadbuffer(state_, generator_snippet.data(), generator_snippet.size(), /*name=*/nullptr); if (status != LUA_OK) { TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status; return false; } status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0); if (status != LUA_OK) { TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status; return false; } if (RunProtected( [this, remote_actions] { return ReadRemoteActionTemplates(remote_actions); }, /*num_args=*/1) != LUA_OK) { TC3_LOG(ERROR) << "Could not read results."; return false; } // Check that we correctly cleaned-up the state. const int stack_size = lua_gettop(state_); if (stack_size > 0) { TC3_LOG(ERROR) << "Unexpected stack size."; lua_settop(state_, 0); return false; } return true; } // Lua environment for classfication result intent generation. 
class AnnotatorJniEnvironment : public JniLuaEnvironment { public: AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache, const jobject context, const std::vector<Locale>& device_locales, const std::string& entity_text, const ClassificationResult& classification, const int64 reference_time_ms_utc, const reflection::Schema* entity_data_schema) : JniLuaEnvironment(resources, jni_cache, context, device_locales), entity_text_(entity_text), classification_(classification), reference_time_ms_utc_(reference_time_ms_utc), entity_data_schema_(entity_data_schema) {} protected: void SetupExternalHook() override { JniLuaEnvironment::SetupExternalHook(); lua_pushinteger(state_, reference_time_ms_utc_); lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey); PushAnnotation(classification_, entity_text_, entity_data_schema_, this); lua_setfield(state_, /*idx=*/-2, "entity"); } const std::string& entity_text_; const ClassificationResult& classification_; const int64 reference_time_ms_utc_; // Reflection schema data. const reflection::Schema* const entity_data_schema_; }; // Lua environment for actions intent generation. 
class ActionsJniLuaEnvironment : public JniLuaEnvironment { public: ActionsJniLuaEnvironment( const Resources& resources, const JniCache* jni_cache, const jobject context, const std::vector<Locale>& device_locales, const Conversation& conversation, const ActionSuggestion& action, const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema) : JniLuaEnvironment(resources, jni_cache, context, device_locales), conversation_(conversation), action_(action), annotation_iterator_(annotations_entity_data_schema, this), conversation_iterator_(annotations_entity_data_schema, this), entity_data_schema_(actions_entity_data_schema) {} protected: void SetupExternalHook() override { JniLuaEnvironment::SetupExternalHook(); conversation_iterator_.NewIterator("conversation", &conversation_.messages, state_); lua_setfield(state_, /*idx=*/-2, "conversation"); PushAction(action_, entity_data_schema_, annotation_iterator_, this); lua_setfield(state_, /*idx=*/-2, "entity"); } const Conversation& conversation_; const ActionSuggestion& action_; const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_; const ConversationIterator conversation_iterator_; const reflection::Schema* entity_data_schema_; }; } // namespace std::unique_ptr<IntentGenerator> IntentGenerator::Create( const IntentFactoryModel* options, const ResourcePool* resources, const std::shared_ptr<JniCache>& jni_cache) { std::unique_ptr<IntentGenerator> intent_generator( new IntentGenerator(options, resources, jni_cache)); if (options == nullptr || options->generator() == nullptr) { TC3_LOG(ERROR) << "No intent generator options."; return nullptr; } std::unique_ptr<ZlibDecompressor> zlib_decompressor = ZlibDecompressor::Instance(); if (!zlib_decompressor) { TC3_LOG(ERROR) << "Cannot initialize decompressor."; return nullptr; } for (const IntentFactoryModel_::IntentGenerator* generator : *options->generator()) { std::string lua_template_generator; if 
(!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer( generator->lua_template_generator(), generator->compressed_lua_template_generator(), &lua_template_generator)) { TC3_LOG(ERROR) << "Could not decompress generator template."; return nullptr; } std::string lua_code = lua_template_generator; if (options->precompile_generators()) { if (!Compile(lua_template_generator, &lua_code)) { TC3_LOG(ERROR) << "Could not precompile generator template."; return nullptr; } } intent_generator->generators_[generator->type()->str()] = lua_code; } return intent_generator; } std::vector<Locale> IntentGenerator::ParseDeviceLocales( const jstring device_locales) const { if (device_locales == nullptr) { TC3_LOG(ERROR) << "No locales provided."; return {}; } ScopedStringChars locales_str = GetScopedStringChars(jni_cache_->GetEnv(), device_locales); if (locales_str == nullptr) { TC3_LOG(ERROR) << "Cannot retrieve provided locales."; return {}; } std::vector<Locale> locales; if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()), &locales)) { TC3_LOG(ERROR) << "Cannot parse locales."; return {}; } return locales; } bool IntentGenerator::GenerateIntents( const jstring device_locales, const ClassificationResult& classification, const int64 reference_time_ms_utc, const std::string& text, const CodepointSpan selection_indices, const jobject context, const reflection::Schema* annotations_entity_data_schema, std::vector<RemoteActionTemplate>* remote_actions) const { if (options_ == nullptr) { return false; } // Retrieve generator for specified entity. 
auto it = generators_.find(classification.collection); if (it == generators_.end()) { return true; } const std::string entity_text = UTF8ToUnicodeText(text, /*do_copy=*/false) .UTF8Substring(selection_indices.first, selection_indices.second); std::unique_ptr<AnnotatorJniEnvironment> interpreter( new AnnotatorJniEnvironment( resources_, jni_cache_.get(), context, ParseDeviceLocales(device_locales), entity_text, classification, reference_time_ms_utc, annotations_entity_data_schema)); if (!interpreter->Initialize()) { TC3_LOG(ERROR) << "Could not create Lua interpreter."; return false; } return interpreter->RunIntentGenerator(it->second, remote_actions); } bool IntentGenerator::GenerateIntents( const jstring device_locales, const ActionSuggestion& action, const Conversation& conversation, const jobject context, const reflection::Schema* annotations_entity_data_schema, const reflection::Schema* actions_entity_data_schema, std::vector<RemoteActionTemplate>* remote_actions) const { if (options_ == nullptr) { return false; } // Retrieve generator for specified action. auto it = generators_.find(action.type); if (it == generators_.end()) { return true; } std::unique_ptr<ActionsJniLuaEnvironment> interpreter( new ActionsJniLuaEnvironment( resources_, jni_cache_.get(), context, ParseDeviceLocales(device_locales), conversation, action, actions_entity_data_schema, annotations_entity_data_schema)); if (!interpreter->Initialize()) { TC3_LOG(ERROR) << "Could not create Lua interpreter."; return false; } return interpreter->RunIntentGenerator(it->second, remote_actions); } } // namespace libtextclassifier3