/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// JNI wrapper for actions.

#include "actions/actions_jni.h"

#include <jni.h>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "actions/actions-suggestions.h"
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
#include "utils/java/jni-cache.h"
#include "utils/java/scoped_local_ref.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"

using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionsSuggestionsResponse;
using libtextclassifier3::ActionSuggestion;
using libtextclassifier3::ActionSuggestionOptions;
using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
using libtextclassifier3::IntentGenerator;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::ToStlString;

// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
// pointer from JNI. When using a standard ICU the pointer is not needed and the
// objects are instantiated implicitly.
#ifdef TC3_UNILIB_JAVAICU
using libtextclassifier3::UniLib;
#endif

namespace libtextclassifier3 {

namespace {

// Cached state for model inference.
// Keeps a jni cache, intent generator and model instance so that they don't
// have to be recreated for each call.
class ActionsSuggestionsJniContext { public: static ActionsSuggestionsJniContext* Create( const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache, std::unique_ptr<ActionsSuggestions> model) { if (jni_cache == nullptr || model == nullptr) { return nullptr; } std::unique_ptr<IntentGenerator> intent_generator = IntentGenerator::Create(model->model()->android_intent_options(), model->model()->resources(), jni_cache); std::unique_ptr<RemoteActionTemplatesHandler> template_handler = libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache); if (intent_generator == nullptr || template_handler == nullptr) { return nullptr; } return new ActionsSuggestionsJniContext(jni_cache, std::move(model), std::move(intent_generator), std::move(template_handler)); } std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const { return jni_cache_; } ActionsSuggestions* model() const { return model_.get(); } IntentGenerator* intent_generator() const { return intent_generator_.get(); } RemoteActionTemplatesHandler* template_handler() const { return template_handler_.get(); } private: ActionsSuggestionsJniContext( const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache, std::unique_ptr<ActionsSuggestions> model, std::unique_ptr<IntentGenerator> intent_generator, std::unique_ptr<RemoteActionTemplatesHandler> template_handler) : jni_cache_(jni_cache), model_(std::move(model)), intent_generator_(std::move(intent_generator)), template_handler_(std::move(template_handler)) {} std::shared_ptr<libtextclassifier3::JniCache> jni_cache_; std::unique_ptr<ActionsSuggestions> model_; std::unique_ptr<IntentGenerator> intent_generator_; std::unique_ptr<RemoteActionTemplatesHandler> template_handler_; }; ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env, jobject joptions) { ActionSuggestionOptions options = ActionSuggestionOptions::Default(); return options; } jobjectArray ActionSuggestionsToJObjectArray( JNIEnv* env, const ActionsSuggestionsJniContext* context, 
jobject app_context, const reflection::Schema* annotations_entity_data_schema, const std::vector<ActionSuggestion>& action_result, const Conversation& conversation, const jstring device_locales, const bool generate_intents) { const ScopedLocalRef<jclass> result_class( env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestion"), env); if (!result_class) { TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class."; return nullptr; } const jmethodID result_class_constructor = env->GetMethodID( result_class.get(), "<init>", "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V"); const jobjectArray results = env->NewObjectArray(action_result.size(), result_class.get(), nullptr); for (int i = 0; i < action_result.size(); i++) { jobject extras = nullptr; const reflection::Schema* actions_entity_data_schema = context->model()->entity_data_schema(); if (actions_entity_data_schema != nullptr && !action_result[i].serialized_entity_data.empty()) { extras = context->template_handler()->EntityDataAsNamedVariantArray( actions_entity_data_schema, action_result[i].serialized_entity_data); } jbyteArray serialized_entity_data = nullptr; if (!action_result[i].serialized_entity_data.empty()) { serialized_entity_data = env->NewByteArray(action_result[i].serialized_entity_data.size()); env->SetByteArrayRegion( serialized_entity_data, 0, action_result[i].serialized_entity_data.size(), reinterpret_cast<const jbyte*>( action_result[i].serialized_entity_data.data())); } jobject remote_action_templates_result = nullptr; if (generate_intents) { std::vector<RemoteActionTemplate> remote_action_templates; if (context->intent_generator()->GenerateIntents( device_locales, action_result[i], conversation, app_context, actions_entity_data_schema, annotations_entity_data_schema, &remote_action_templates)) { remote_action_templates_result = 
context->template_handler()->RemoteActionTemplatesToJObjectArray( remote_action_templates); } } ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString( action_result[i].response_text); ScopedLocalRef<jobject> result(env->NewObject( result_class.get(), result_class_constructor, reply.get(), env->NewStringUTF(action_result[i].type.c_str()), static_cast<jfloat>(action_result[i].score), extras, serialized_entity_data, remote_action_templates_result)); env->SetObjectArrayElement(results, i, result.get()); } return results; } ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) { if (!jmessage) { return {}; } const ScopedLocalRef<jclass> message_class( env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage"), env); const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>( env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText", "Ljava/lang/String;"); const std::pair<bool, int32> status_or_user_id = CallJniMethod0<int32>(env, jmessage, message_class.get(), &JNIEnv::CallIntMethod, "getUserId", "I"); const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>( env, jmessage, message_class.get(), &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", "J"); const std::pair<bool, jobject> status_or_reference_timezone = CallJniMethod0<jobject>(env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getReferenceTimezone", "Ljava/lang/String;"); const std::pair<bool, jobject> status_or_detected_text_language_tags = CallJniMethod0<jobject>( env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getDetectedTextLanguageTags", "Ljava/lang/String;"); if (!status_or_text.first || !status_or_user_id.first || !status_or_detected_text_language_tags.first || !status_or_reference_time.first || !status_or_reference_timezone.first) { return {}; } ConversationMessage message; message.text = ToStlString(env, static_cast<jstring>(status_or_text.second)); 
message.user_id = status_or_user_id.second; message.reference_time_ms_utc = status_or_reference_time.second; message.reference_timezone = ToStlString( env, static_cast<jstring>(status_or_reference_timezone.second)); message.detected_text_language_tags = ToStlString( env, static_cast<jstring>(status_or_detected_text_language_tags.second)); return message; } Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) { if (!jconversation) { return {}; } const ScopedLocalRef<jclass> conversation_class( env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$Conversation"), env); const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>( env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod, "getConversationMessages", "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;"); if (!status_or_messages.first) { return {}; } const jobjectArray jmessages = reinterpret_cast<jobjectArray>(status_or_messages.second); const int size = env->GetArrayLength(jmessages); std::vector<ConversationMessage> messages; for (int i = 0; i < size; i++) { jobject jmessage = env->GetObjectArrayElement(jmessages, i); ConversationMessage message = FromJavaConversationMessage(env, jmessage); messages.push_back(message); } Conversation conversation; conversation.messages = messages; return conversation; } jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { if (!mmap->handle().ok()) { return env->NewStringUTF(""); } const ActionsModel* model = libtextclassifier3::ViewActionsModel( mmap->handle().start(), mmap->handle().num_bytes()); if (!model || !model->locales()) { return env->NewStringUTF(""); } return env->NewStringUTF(model->locales()->c_str()); } jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { if (!mmap->handle().ok()) { return 0; } const ActionsModel* model = libtextclassifier3::ViewActionsModel( mmap->handle().start(), mmap->handle().num_bytes()); if (!model) { return 
0; } return model->version(); } jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { if (!mmap->handle().ok()) { return env->NewStringUTF(""); } const ActionsModel* model = libtextclassifier3::ViewActionsModel( mmap->handle().start(), mmap->handle().num_bytes()); if (!model || !model->name()) { return env->NewStringUTF(""); } return env->NewStringUTF(model->name()->c_str()); } } // namespace } // namespace libtextclassifier3 using libtextclassifier3::ActionsSuggestionsJniContext; using libtextclassifier3::ActionSuggestionsToJObjectArray; using libtextclassifier3::FromJavaActionSuggestionOptions; using libtextclassifier3::FromJavaConversation; TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel) (JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) { std::shared_ptr<libtextclassifier3::JniCache> jni_cache = libtextclassifier3::JniCache::Create(env); std::string preconditions; if (serialized_preconditions != nullptr && !libtextclassifier3::JByteArrayToString(env, serialized_preconditions, &preconditions)) { TC3_LOG(ERROR) << "Could not convert serialized preconditions."; return 0; } #ifdef TC3_UNILIB_JAVAICU return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( jni_cache, ActionsSuggestions::FromFileDescriptor( fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions))); #else return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr, preconditions))); #endif // TC3_UNILIB_JAVAICU } TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath) (JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) { std::shared_ptr<libtextclassifier3::JniCache> jni_cache = libtextclassifier3::JniCache::Create(env); const std::string path_str = ToStlString(env, path); std::string preconditions; if (serialized_preconditions != nullptr && !libtextclassifier3::JByteArrayToString(env, 
serialized_preconditions, &preconditions)) { TC3_LOG(ERROR) << "Could not convert serialized preconditions."; return 0; } #ifdef TC3_UNILIB_JAVAICU return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( jni_cache, ActionsSuggestions::FromPath( path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions))); #else return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr, preconditions))); #endif // TC3_UNILIB_JAVAICU } TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions) (JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions, jlong annotatorPtr, jobject app_context, jstring device_locales, jboolean generate_intents) { if (!ptr) { return nullptr; } const Conversation conversation = FromJavaConversation(env, jconversation); const ActionSuggestionOptions options = FromJavaActionSuggestionOptions(env, joptions); const ActionsSuggestionsJniContext* context = reinterpret_cast<ActionsSuggestionsJniContext*>(ptr); const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr); const ActionsSuggestionsResponse response = context->model()->SuggestActions(conversation, annotator, options); const reflection::Schema* anntotations_entity_data_schema = annotator ? 
annotator->entity_data_schema() : nullptr; return ActionSuggestionsToJObjectArray( env, context, app_context, anntotations_entity_data_schema, response.actions, conversation, device_locales, generate_intents); } TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel) (JNIEnv* env, jobject clazz, jlong model_ptr) { const ActionsSuggestionsJniContext* context = reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr); delete context; } TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales) (JNIEnv* env, jobject clazz, jint fd) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd)); return libtextclassifier3::GetLocalesFromMmap(env, mmap.get()); } TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName) (JNIEnv* env, jobject clazz, jint fd) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd)); return libtextclassifier3::GetNameFromMmap(env, mmap.get()); } TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion) (JNIEnv* env, jobject clazz, jint fd) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd)); return libtextclassifier3::GetVersionFromMmap(env, mmap.get()); }