/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// JNI wrapper for actions.

#include "actions/actions_jni.h"

#include <jni.h>
#include <map>
#include <type_traits>
#include <vector>

#include "actions/actions-suggestions.h"
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
#include "utils/java/jni-cache.h"
#include "utils/java/scoped_local_ref.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"

using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionsSuggestionsResponse;
using libtextclassifier3::ActionSuggestion;
using libtextclassifier3::ActionSuggestionOptions;
using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
using libtextclassifier3::IntentGenerator;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::ToStlString;

// When using Java's ICU, UniLib needs to be instantiated with a JavaVM
// pointer from JNI. When using a standard ICU, the pointer is not needed and
// the objects are instantiated implicitly.
#ifdef TC3_UNILIB_JAVAICU
using libtextclassifier3::UniLib;
#endif

namespace libtextclassifier3 {

namespace {

// Per-model state shared by the JNI entry points.
// Holds the JNI cache, the actions model, the intent generator and the remote
// action template handler, so none of them is rebuilt on every call.
class ActionsSuggestionsJniContext {
 public:
  // Assembles a context around `model`.
  // Yields nullptr when `jni_cache` or `model` is null, or when either the
  // intent generator or the template handler could not be created. On success
  // the object is allocated with `new` and handed to the caller as a raw
  // pointer.
  static ActionsSuggestionsJniContext* Create(
      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
      std::unique_ptr<ActionsSuggestions> model) {
    if (!jni_cache || !model) {
      return nullptr;
    }

    // Both helpers are created before either is inspected.
    std::unique_ptr<IntentGenerator> generator =
        IntentGenerator::Create(model->model()->android_intent_options(),
                                model->model()->resources(), jni_cache);
    std::unique_ptr<RemoteActionTemplatesHandler> handler =
        RemoteActionTemplatesHandler::Create(jni_cache);
    if (!generator || !handler) {
      return nullptr;
    }

    return new ActionsSuggestionsJniContext(
        jni_cache, std::move(model), std::move(generator), std::move(handler));
  }

  // Non-owning access to the loaded model.
  ActionsSuggestions* model() const { return model_.get(); }

  // Non-owning access to the intent generator.
  IntentGenerator* intent_generator() const { return intent_generator_.get(); }

  // Non-owning access to the remote action template handler.
  RemoteActionTemplatesHandler* template_handler() const {
    return template_handler_.get();
  }

  // Returns a copy of the shared pointer to the JNI cache.
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
    return jni_cache_;
  }

 private:
  // Only reachable through Create(), which validates every argument first.
  ActionsSuggestionsJniContext(
      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
      std::unique_ptr<ActionsSuggestions> model,
      std::unique_ptr<IntentGenerator> intent_generator,
      std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
      : jni_cache_(jni_cache),
        model_(std::move(model)),
        intent_generator_(std::move(intent_generator)),
        template_handler_(std::move(template_handler)) {}

  std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
  std::unique_ptr<ActionsSuggestions> model_;
  std::unique_ptr<IntentGenerator> intent_generator_;
  std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
};

// Produces the native suggestion options for a call.
// Neither `env` nor `joptions` is read here: the library defaults are returned
// regardless of what the Java caller supplied.
ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
                                                        jobject joptions) {
  return ActionSuggestionOptions::Default();
}

// Converts native action suggestions into a Java ActionSuggestion[].
//
// Args:
//   env: JNI environment of the calling thread.
//   context: cached model state; supplies the model's entity data schema, the
//     template handler, the intent generator and the JNI cache.
//   app_context: forwarded to the intent generator.
//   annotations_entity_data_schema: forwarded to the intent generator.
//   action_result: the suggestions to convert; one Java object is created per
//     entry and stored at the same index.
//   conversation: forwarded to the intent generator.
//   device_locales: forwarded to the intent generator.
//   generate_intents: when false the intent generator is not invoked and null
//     is passed to the Java constructor in place of the templates.
//
// Returns the new array, or nullptr when the Java class or its constructor
// cannot be resolved or when an array allocation fails.
jobjectArray ActionSuggestionsToJObjectArray(
    JNIEnv* env, const ActionsSuggestionsJniContext* context,
    jobject app_context,
    const reflection::Schema* annotations_entity_data_schema,
    const std::vector<ActionSuggestion>& action_result,
    const Conversation& conversation, const jstring device_locales,
    const bool generate_intents) {
  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                     "$ActionSuggestion"),
      env);
  if (!result_class) {
    TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
    return nullptr;
  }

  const jmethodID result_class_constructor = env->GetMethodID(
      result_class.get(), "<init>",
      "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
          TC3_NAMED_VARIANT_CLASS_NAME_STR
      ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
  if (result_class_constructor == nullptr) {
    // Without this check a null method id would reach NewObject below.
    TC3_LOG(ERROR) << "Couldn't find ActionSuggestion constructor.";
    return nullptr;
  }

  const jsize num_actions = static_cast<jsize>(action_result.size());
  const jobjectArray results =
      env->NewObjectArray(num_actions, result_class.get(), nullptr);
  if (results == nullptr) {
    TC3_LOG(ERROR) << "Couldn't allocate ActionSuggestion array.";
    return nullptr;
  }

  // The model's schema is the same for every suggestion, so look it up once.
  const reflection::Schema* actions_entity_data_schema =
      context->model()->entity_data_schema();

  // Every Java object created inside the loop is held in a ScopedLocalRef
  // constructed with `env` (the same form used for `result_class`), so that its
  // local reference is dropped at the end of the iteration rather than
  // accumulating for the whole native call.
  for (jsize i = 0; i < num_actions; i++) {
    const ActionSuggestion& action = action_result[i];
    const bool has_entity_data = !action.serialized_entity_data.empty();

    // Entity data as a named variant array; stays null without schema or data.
    jobject extras_raw = nullptr;
    if (actions_entity_data_schema != nullptr && has_entity_data) {
      extras_raw = context->template_handler()->EntityDataAsNamedVariantArray(
          actions_entity_data_schema, action.serialized_entity_data);
    }
    const ScopedLocalRef<jobject> extras(extras_raw, env);

    // Entity data as raw bytes; stays null when there is no data.
    const jsize entity_data_size =
        static_cast<jsize>(action.serialized_entity_data.size());
    const ScopedLocalRef<jbyteArray> serialized_entity_data(
        has_entity_data ? env->NewByteArray(entity_data_size) : nullptr, env);
    if (has_entity_data) {
      if (serialized_entity_data.get() == nullptr) {
        TC3_LOG(ERROR) << "Couldn't allocate entity data byte array.";
        return nullptr;
      }
      env->SetByteArrayRegion(
          serialized_entity_data.get(), 0, entity_data_size,
          reinterpret_cast<const jbyte*>(
              action.serialized_entity_data.data()));
    }

    // Remote action templates; stay null when intents are not requested or
    // when the generator reports failure.
    jobject remote_action_templates_raw = nullptr;
    if (generate_intents) {
      std::vector<RemoteActionTemplate> remote_action_templates;
      if (context->intent_generator()->GenerateIntents(
              device_locales, action, conversation, app_context,
              actions_entity_data_schema, annotations_entity_data_schema,
              &remote_action_templates)) {
        remote_action_templates_raw =
            context->template_handler()->RemoteActionTemplatesToJObjectArray(
                remote_action_templates);
      }
    }
    const ScopedLocalRef<jobject> remote_action_templates_result(
        remote_action_templates_raw, env);

    const ScopedLocalRef<jstring> reply =
        context->jni_cache()->ConvertToJavaString(action.response_text);
    const ScopedLocalRef<jstring> type(
        env->NewStringUTF(action.type.c_str()), env);

    const ScopedLocalRef<jobject> result(
        env->NewObject(result_class.get(), result_class_constructor,
                       reply.get(), type.get(),
                       static_cast<jfloat>(action.score), extras.get(),
                       serialized_entity_data.get(),
                       remote_action_templates_result.get()),
        env);
    env->SetObjectArrayElement(results, i, result.get());
  }
  return results;
}

// Reads a Java ConversationMessage into its native counterpart.
//
// Calls getText, getUserId, getReferenceTimeMsUtc, getReferenceTimezone and
// getDetectedTextLanguageTags on `jmessage` through CallJniMethod0.
//
// Returns a default-constructed message when `jmessage` is null, when the
// Java class cannot be found, or when CallJniMethod0 reports failure for any
// of the getters.
ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
  if (!jmessage) {
    return {};
  }

  const ScopedLocalRef<jclass> message_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                     "$ConversationMessage"),
      env);
  if (!message_class) {
    // Do not hand a null class to the method lookups below.
    TC3_LOG(ERROR) << "Couldn't find ConversationMessage class.";
    return {};
  }

  // Wraps the string returned by a getter so its local reference is released
  // when this function returns. This function runs once per message of a
  // conversation, so unreleased references would add up.
  const auto own_string =
      [env](const std::pair<bool, jobject>& status_or_value) {
        return ScopedLocalRef<jstring>(
            static_cast<jstring>(status_or_value.first ? status_or_value.second
                                                       : nullptr),
            env);
      };

  const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
      env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
      "Ljava/lang/String;");
  const ScopedLocalRef<jstring> text = own_string(status_or_text);

  const std::pair<bool, int32> status_or_user_id =
      CallJniMethod0<int32>(env, jmessage, message_class.get(),
                            &JNIEnv::CallIntMethod, "getUserId", "I");
  const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
      env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
      "getReferenceTimeMsUtc", "J");

  const std::pair<bool, jobject> status_or_reference_timezone =
      CallJniMethod0<jobject>(env, jmessage, message_class.get(),
                              &JNIEnv::CallObjectMethod, "getReferenceTimezone",
                              "Ljava/lang/String;");
  const ScopedLocalRef<jstring> reference_timezone =
      own_string(status_or_reference_timezone);

  const std::pair<bool, jobject> status_or_detected_text_language_tags =
      CallJniMethod0<jobject>(
          env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
          "getDetectedTextLanguageTags", "Ljava/lang/String;");
  const ScopedLocalRef<jstring> detected_text_language_tags =
      own_string(status_or_detected_text_language_tags);

  if (!status_or_text.first || !status_or_user_id.first ||
      !status_or_detected_text_language_tags.first ||
      !status_or_reference_time.first || !status_or_reference_timezone.first) {
    return {};
  }

  ConversationMessage message;
  message.text = ToStlString(env, text.get());
  message.user_id = status_or_user_id.second;
  message.reference_time_ms_utc = status_or_reference_time.second;
  message.reference_timezone = ToStlString(env, reference_timezone.get());
  message.detected_text_language_tags =
      ToStlString(env, detected_text_language_tags.get());
  return message;
}

// Reads a Java Conversation into its native counterpart.
//
// Fetches the array returned by getConversationMessages() and converts every
// element with FromJavaConversationMessage, preserving the order.
//
// Returns an empty conversation when `jconversation` is null, when the Java
// class cannot be found, when CallJniMethod0 reports failure for the getter,
// or when the getter returns a null array.
Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
  if (!jconversation) {
    return {};
  }

  const ScopedLocalRef<jclass> conversation_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                     "$Conversation"),
      env);
  if (!conversation_class) {
    // Do not hand a null class to the method lookup below.
    TC3_LOG(ERROR) << "Couldn't find Conversation class.";
    return {};
  }

  const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
      env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
      "getConversationMessages",
      "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");

  if (!status_or_messages.first) {
    return {};
  }

  const ScopedLocalRef<jobjectArray> jmessages(
      static_cast<jobjectArray>(status_or_messages.second), env);
  if (!jmessages) {
    // A null array must not be passed to GetArrayLength.
    return {};
  }

  const jsize size = env->GetArrayLength(jmessages.get());

  std::vector<ConversationMessage> messages;
  if (size > 0) {
    messages.reserve(static_cast<size_t>(size));
  }
  for (jsize i = 0; i < size; i++) {
    // Release each element's local reference as soon as it has been converted,
    // so long conversations do not pile up references.
    const ScopedLocalRef<jobject> jmessage(
        env->GetObjectArrayElement(jmessages.get(), i), env);
    messages.push_back(FromJavaConversationMessage(env, jmessage.get()));
  }

  Conversation conversation;
  conversation.messages = std::move(messages);
  return conversation;
}

// Returns the locales string of the actions model mapped by `mmap` as a new
// Java string. An empty string is returned when the mapping is not ok, when
// ViewActionsModel yields null, or when the model has no locales field.
jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
  const char* locales = "";
  auto&& handle = mmap->handle();
  if (handle.ok()) {
    const ActionsModel* model = libtextclassifier3::ViewActionsModel(
        handle.start(), handle.num_bytes());
    if (model && model->locales()) {
      locales = model->locales()->c_str();
    }
  }
  return env->NewStringUTF(locales);
}

// Returns the version of the actions model mapped by `mmap`, or 0 when the
// mapping is not ok or ViewActionsModel yields null. `env` is not used.
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
  auto&& handle = mmap->handle();
  if (handle.ok()) {
    const ActionsModel* model = libtextclassifier3::ViewActionsModel(
        handle.start(), handle.num_bytes());
    if (model) {
      return model->version();
    }
  }
  return 0;
}

// Returns the name of the actions model mapped by `mmap` as a new Java string.
// An empty string is returned when the mapping is not ok, when
// ViewActionsModel yields null, or when the model has no name field.
jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
  const char* name = "";
  auto&& handle = mmap->handle();
  if (handle.ok()) {
    const ActionsModel* model = libtextclassifier3::ViewActionsModel(
        handle.start(), handle.num_bytes());
    if (model && model->name()) {
      name = model->name()->c_str();
    }
  }
  return env->NewStringUTF(name);
}
}  // namespace
}  // namespace libtextclassifier3

using libtextclassifier3::ActionsSuggestionsJniContext;
using libtextclassifier3::ActionSuggestionsToJObjectArray;
using libtextclassifier3::FromJavaActionSuggestionOptions;
using libtextclassifier3::FromJavaConversation;

// Loads an actions model from the file descriptor `fd` and wraps it in an
// ActionsSuggestionsJniContext. `serialized_preconditions` may be null; when
// present it is converted to a string and passed along to the model loader.
// Returns the context pointer as a jlong, or 0 when the preconditions cannot
// be converted or the context cannot be created.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
      libtextclassifier3::JniCache::Create(env);

  std::string preconditions;
  if (serialized_preconditions != nullptr) {
    if (!libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
                                                &preconditions)) {
      TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
      return 0;
    }
  }

  // With Java's ICU a UniLib built from the JNI cache is handed to the loader;
  // otherwise a literal nullptr is passed.
#ifdef TC3_UNILIB_JAVAICU
  std::unique_ptr<ActionsSuggestions> model =
      ActionsSuggestions::FromFileDescriptor(
          fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions);
#else
  std::unique_ptr<ActionsSuggestions> model =
      ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
                                             preconditions);
#endif  // TC3_UNILIB_JAVAICU

  return reinterpret_cast<jlong>(
      ActionsSuggestionsJniContext::Create(jni_cache, std::move(model)));
}

// Loads an actions model from the file at `path` and wraps it in an
// ActionsSuggestionsJniContext. `serialized_preconditions` may be null; when
// present it is converted to a string and passed along to the model loader.
// Returns the context pointer as a jlong, or 0 when the preconditions cannot
// be converted or the context cannot be created.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
      libtextclassifier3::JniCache::Create(env);
  const std::string path_str = ToStlString(env, path);

  std::string preconditions;
  if (serialized_preconditions != nullptr) {
    if (!libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
                                                &preconditions)) {
      TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
      return 0;
    }
  }

  // With Java's ICU a UniLib built from the JNI cache is handed to the loader;
  // otherwise a literal nullptr is passed.
#ifdef TC3_UNILIB_JAVAICU
  std::unique_ptr<ActionsSuggestions> model = ActionsSuggestions::FromPath(
      path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions);
#else
  std::unique_ptr<ActionsSuggestions> model = ActionsSuggestions::FromPath(
      path_str, /*unilib=*/nullptr, preconditions);
#endif  // TC3_UNILIB_JAVAICU

  return reinterpret_cast<jlong>(
      ActionsSuggestionsJniContext::Create(jni_cache, std::move(model)));
}

// Runs the actions model referenced by `ptr` on `jconversation` and returns the
// suggestions as a Java array. `annotatorPtr` may be 0, in which case a null
// annotator is passed to the model and no annotations schema is forwarded.
// Returns nullptr when `ptr` is 0.
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
 jlong annotatorPtr, jobject app_context, jstring device_locales,
 jboolean generate_intents) {
  if (!ptr) {
    return nullptr;
  }

  // Translate the Java inputs first.
  const Conversation conversation = FromJavaConversation(env, jconversation);
  const ActionSuggestionOptions options =
      FromJavaActionSuggestionOptions(env, joptions);

  // Recover the native objects behind the opaque handles.
  const ActionsSuggestionsJniContext* const context =
      reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
  const Annotator* const annotator =
      reinterpret_cast<Annotator*>(annotatorPtr);

  const ActionsSuggestionsResponse response =
      context->model()->SuggestActions(conversation, annotator, options);

  const reflection::Schema* annotations_schema = nullptr;
  if (annotator) {
    annotations_schema = annotator->entity_data_schema();
  }

  return ActionSuggestionsToJObjectArray(
      env, context, app_context, annotations_schema, response.actions,
      conversation, device_locales, generate_intents);
}

// Destroys the ActionsSuggestionsJniContext behind `model_ptr`.
// A zero handle is a no-op, since deleting a null pointer does nothing.
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
(JNIEnv* env, jobject clazz, jlong model_ptr) {
  delete reinterpret_cast<const ActionsSuggestionsJniContext*>(model_ptr);
}

// Maps the model file behind `fd` for the duration of the call and returns its
// locales string (see GetLocalesFromMmap for the fallback to "").
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return libtextclassifier3::GetLocalesFromMmap(env, &model_mmap);
}

// Maps the model file behind `fd` for the duration of the call and returns its
// name (see GetNameFromMmap for the fallback to "").
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return libtextclassifier3::GetNameFromMmap(env, &model_mmap);
}

// Maps the model file behind `fd` for the duration of the call and returns its
// version (see GetVersionFromMmap for the fallback to 0).
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return libtextclassifier3::GetVersionFromMmap(env, &model_mmap);
}