// Plain text | 900 lines | 30.05 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/intents/intent-generator.h"

#include <vector>

#include "actions/lua-utils.h"
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/hash/farmhash.h"
#include "utils/java/jni-base.h"
#include "utils/java/string_utils.h"
#include "utils/lua-utils.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/substitute.h"
#include "utils/utf8/unicodetext.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
#include "flatbuffers/reflection_generated.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#ifdef __cplusplus
}
#endif

namespace libtextclassifier3 {
namespace {

// Field names understood on the Lua-side `external` and `external.android`
// tables. The anonymous namespace already gives these internal linkage.
// NOTE(review): despite its identifier, kReferenceTimeUsecKey names a value in
// milliseconds ("reference_time_ms_utc"); the identifier is left as-is because
// other code in this file refers to it.
constexpr char kReferenceTimeUsecKey[] = "reference_time_ms_utc";
constexpr char kHashKey[] = "hash";
constexpr char kUrlSchemaKey[] = "url_schema";
constexpr char kUrlHostKey[] = "url_host";
constexpr char kUrlEncodeKey[] = "urlencode";
constexpr char kPackageNameKey[] = "package_name";
constexpr char kDeviceLocaleKey[] = "device_locales";
constexpr char kFormatKey[] = "format";

// An Android specific Lua environment with JNI backed callbacks.
//
// Exposes an `external` Lua global whose lookups are served by the Handle*
// methods below. Intent generator snippets are run via RunIntentGenerator and
// their results are converted to RemoteActionTemplate values.
class JniLuaEnvironment : public LuaEnvironment {
 public:
  // `resources` is stored by reference and must outlive this object.
  // `jni_cache` may be null at construction (the constructor guards it), but
  // the JNI-backed methods dereference it.
  // NOTE(review): `context` is stored as-is, not as a global reference; the
  // caller presumably keeps it valid for the lifetime of this object.
  JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
                    const jobject context,
                    const std::vector<Locale>& device_locales);
  // Environment setup: allocates the cached Java strings and installs the
  // `external` Lua global. Returns false on failure.
  bool Initialize();

  // Runs an intent generator snippet and appends the templates it returns to
  // `remote_actions`. Returns false if loading, running or reading fails.
  bool RunIntentGenerator(const std::string& generator_snippet,
                          std::vector<RemoteActionTemplate>* remote_actions);

 protected:
  // Pushes the `external` table onto the Lua stack. Subclasses override this
  // to add more fields after calling the base version.
  virtual void SetupExternalHook();

  // Lua callbacks. Each returns the number of values it pushed onto the Lua
  // stack (0 on the error paths, which call lua_error).
  int HandleExternalCallback();
  int HandleAndroidCallback();
  int HandleUserRestrictionsCallback();
  int HandleUrlEncode();
  int HandleUrlSchema();
  int HandleHash();
  int HandleFormat();
  int HandleAndroidStringResources();
  int HandleUrlHost();

  // Checks and retrieves string resources from the model. Pushes the content
  // and returns true if the model can serve the requested name.
  bool LookupModelStringResource();

  // Reads and creates a RemoteAction result from the Lua table at the top of
  // the stack.
  RemoteActionTemplate ReadRemoteActionTemplateResult();

  // Reads the extras from the Lua result.
  void ReadExtras(std::map<std::string, Variant>* extra);

  // Reads the intent categories array from a Lua result.
  void ReadCategories(std::vector<std::string>* category);

  // Retrieves user manager if not previously done.
  bool RetrieveUserManager();

  // Retrieves system resources if not previously done.
  bool RetrieveSystemResources();

  // Parse the url string by using Uri.parse from Java.
  ScopedLocalRef<jobject> ParseUri(StringPiece url) const;

  // Read remote action templates from lua generator. Returns a Lua status
  // code (LUA_OK on success).
  int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);

  // NOTE: the constructor initializes these members in declaration order;
  // keep the two in sync.
  const Resources& resources_;
  // Obtained from jni_cache->GetEnv() at construction; null if no cache.
  JNIEnv* jenv_;
  const JniCache* jni_cache_;
  const jobject context_;
  std::vector<Locale> device_locales_;

  ScopedGlobalRef<jobject> usermanager_;
  // Whether we previously attempted to retrieve the UserManager before.
  bool usermanager_retrieved_;

  ScopedGlobalRef<jobject> system_resources_;
  // Whether we previously attempted to retrieve the system resources.
  bool system_resources_resources_retrieved_;

  // Cached JNI references for Java strings `string` and `android`, passed as
  // the type and package arguments of Resources.getIdentifier.
  ScopedGlobalRef<jstring> string_;
  ScopedGlobalRef<jstring> android_;
};

// Stores the inputs and leaves every JNI-backed member empty. A null
// `jni_cache` is tolerated: jenv_ and the JavaVM of every global reference are
// then null, and Initialize() must not be relied upon to succeed.
JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
                                     const JniCache* jni_cache,
                                     const jobject context,
                                     const std::vector<Locale>& device_locales)
    : resources_(resources),
      jenv_(jni_cache != nullptr ? jni_cache->GetEnv() : nullptr),
      jni_cache_(jni_cache),
      context_(context),
      device_locales_(device_locales),
      usermanager_(/*object=*/nullptr,
                   /*jvm=*/jni_cache != nullptr ? jni_cache->jvm : nullptr),
      usermanager_retrieved_(false),
      system_resources_(
          /*object=*/nullptr,
          /*jvm=*/jni_cache != nullptr ? jni_cache->jvm : nullptr),
      system_resources_resources_retrieved_(false),
      string_(/*object=*/nullptr,
              /*jvm=*/jni_cache != nullptr ? jni_cache->jvm : nullptr),
      android_(/*object=*/nullptr,
               /*jvm=*/jni_cache != nullptr ? jni_cache->jvm : nullptr) {}

bool JniLuaEnvironment::Initialize() {
  string_ =
      MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
  android_ =
      MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
  if (string_ == nullptr || android_ == nullptr) {
    TC3_LOG(ERROR) << "Could not allocate constant strings references.";
    return false;
  }
  return (RunProtected([this] {
            LoadDefaultLibraries();
            SetupExternalHook();
            lua_setglobal(state_, "external");
            return LUA_OK;
          }) == LUA_OK);
}

// Pushes the `external` table onto the Lua stack and leaves it there;
// Initialize() then stores it as the `external` global. Subclass overrides
// call this first and add further fields to the table on top of the stack.
void JniLuaEnvironment::SetupExternalHook() {
  // This exposes an `external` object with the following fields:
  //   * entity: the bundle with all information about a classification
  //     (added by subclass overrides, not here).
  //   * android: callbacks into specific android provided methods.
  //   * android.user_restrictions: callbacks to check user permissions.
  //   * android.R: callbacks to retrieve string resources.
  // NOTE(review): BindTable<T, handler>(name) is assumed to push a table whose
  // field lookups are forwarded to `handler`; confirm in utils/lua-utils.h.
  BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
      "external");

  // android
  // Stack after this call: ..., external, android.
  BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
      "android");
  {
    // android.user_restrictions
    // lua_setfield(-2, ...) stores the top value into `android` and pops it.
    BindTable<JniLuaEnvironment,
              &JniLuaEnvironment::HandleUserRestrictionsCallback>(
        "user_restrictions");
    lua_setfield(state_, /*idx=*/-2, "user_restrictions");

    // android.R
    // Callback to access android string resources.
    BindTable<JniLuaEnvironment,
              &JniLuaEnvironment::HandleAndroidStringResources>("R");
    lua_setfield(state_, /*idx=*/-2, "R");
  }
  // Store `android` into `external`, leaving `external` on top of the stack.
  lua_setfield(state_, /*idx=*/-2, "android");
}

int JniLuaEnvironment::HandleExternalCallback() {
  const StringPiece key = ReadString(/*index=*/-1);
  if (key.Equals(kHashKey)) {
    Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
    return 1;
  } else if (key.Equals(kFormatKey)) {
    Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
    return 1;
  } else {
    TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
    lua_error(state_);
    return 0;
  }
}

int JniLuaEnvironment::HandleAndroidCallback() {
  const StringPiece key = ReadString(/*index=*/-1);
  if (key.Equals(kDeviceLocaleKey)) {
    // Provide the locale as table with the individual fields set.
    lua_newtable(state_);
    for (int i = 0; i < device_locales_.size(); i++) {
      // Adjust index to 1-based indexing for Lua.
      lua_pushinteger(state_, i + 1);
      lua_newtable(state_);
      PushString(device_locales_[i].Language());
      lua_setfield(state_, -2, "language");
      PushString(device_locales_[i].Region());
      lua_setfield(state_, -2, "region");
      PushString(device_locales_[i].Script());
      lua_setfield(state_, -2, "script");
      lua_settable(state_, /*idx=*/-3);
    }
    return 1;
  } else if (key.Equals(kPackageNameKey)) {
    if (context_ == nullptr) {
      TC3_LOG(ERROR) << "Context invalid.";
      lua_error(state_);
      return 0;
    }
    ScopedLocalRef<jstring> package_name_str(
        static_cast<jstring>(jenv_->CallObjectMethod(
            context_, jni_cache_->context_get_package_name)));
    if (jni_cache_->ExceptionCheckAndClear()) {
      TC3_LOG(ERROR) << "Error calling Context.getPackageName";
      lua_error(state_);
      return 0;
    }
    PushString(ToStlString(jenv_, package_name_str.get()));
    return 1;
  } else if (key.Equals(kUrlEncodeKey)) {
    Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
    return 1;
  } else if (key.Equals(kUrlHostKey)) {
    Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
    return 1;
  } else if (key.Equals(kUrlSchemaKey)) {
    Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
    return 1;
  } else {
    TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
    lua_error(state_);
    return 0;
  }
}

int JniLuaEnvironment::HandleUserRestrictionsCallback() {
  if (jni_cache_->usermanager_class == nullptr ||
      jni_cache_->usermanager_get_user_restrictions == nullptr) {
    // UserManager is only available for API level >= 17 and
    // getUserRestrictions only for API level >= 18, so we just return false
    // normally here.
    lua_pushboolean(state_, false);
    return 1;
  }

  // Get user manager if not previously retrieved.
  if (!RetrieveUserManager()) {
    TC3_LOG(ERROR) << "Error retrieving user manager.";
    lua_error(state_);
    return 0;
  }

  ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
      usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
  if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
    TC3_LOG(ERROR) << "Error calling getUserRestrictions";
    lua_error(state_);
    return 0;
  }

  const StringPiece key_str = ReadString(/*index=*/-1);
  if (key_str.empty()) {
    TC3_LOG(ERROR) << "Expected string, got null.";
    lua_error(state_);
    return 0;
  }

  ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
  if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
    TC3_LOG(ERROR) << "Expected string, got null.";
    lua_error(state_);
    return 0;
  }
  const bool permission = jenv_->CallBooleanMethod(
      bundle.get(), jni_cache_->bundle_get_boolean, key.get());
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error getting bundle value";
    lua_pushboolean(state_, false);
  } else {
    lua_pushboolean(state_, permission);
  }
  return 1;
}

int JniLuaEnvironment::HandleUrlEncode() {
  const StringPiece input = ReadString(/*index=*/1);
  if (input.empty()) {
    TC3_LOG(ERROR) << "Expected string, got null.";
    lua_error(state_);
    return 0;
  }

  // Call Java URL encoder.
  ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
  if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
    TC3_LOG(ERROR) << "Expected string, got null.";
    lua_error(state_);
    return 0;
  }
  ScopedLocalRef<jstring> encoded_str(
      static_cast<jstring>(jenv_->CallStaticObjectMethod(
          jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
          input_str.get(), jni_cache_->string_utf8.get())));
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
    lua_error(state_);
    return 0;
  }
  PushString(ToStlString(jenv_, encoded_str.get()));
  return 1;
}

// Parses `url` with the Java Uri.parse method cached in the JniCache.
// Returns a null reference if `url` is empty, if it cannot be converted to a
// Java string, or if Uri.parse throws or yields null. Any pending Java
// exception is cleared before returning.
ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
  if (url.empty()) {
    return nullptr;
  }

  // Call to Java URI parser.
  ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
  if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
    TC3_LOG(ERROR) << "Expected string, got null";
    return nullptr;
  }

  // Try to parse uri and get scheme.
  ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
      jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
  if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
    TC3_LOG(ERROR) << "Error calling Uri.parse";
    // Return null explicitly rather than whatever the failed call produced;
    // callers treat null as "parse failed".
    return nullptr;
  }
  return uri;
}

int JniLuaEnvironment::HandleUrlSchema() {
  StringPiece url = ReadString(/*index=*/1);

  ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
  if (parsed_uri == nullptr) {
    lua_error(state_);
    return 0;
  }

  ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
      jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling Uri.getScheme";
    lua_error(state_);
    return 0;
  }
  if (scheme_str == nullptr) {
    lua_pushnil(state_);
  } else {
    PushString(ToStlString(jenv_, scheme_str.get()));
  }
  return 1;
}

int JniLuaEnvironment::HandleUrlHost() {
  StringPiece url = ReadString(/*index=*/-1);

  ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
  if (parsed_uri == nullptr) {
    lua_error(state_);
    return 0;
  }

  ScopedLocalRef<jstring> host_str(static_cast<jstring>(
      jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling Uri.getHost";
    lua_error(state_);
    return 0;
  }
  if (host_str == nullptr) {
    lua_pushnil(state_);
  } else {
    PushString(ToStlString(jenv_, host_str.get()));
  }
  return 1;
}

int JniLuaEnvironment::HandleHash() {
  const StringPiece input = ReadString(/*index=*/-1);
  lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
  return 1;
}

int JniLuaEnvironment::HandleFormat() {
  const int num_args = lua_gettop(state_);
  std::vector<StringPiece> args(num_args - 1);
  for (int i = 0; i < num_args - 1; i++) {
    args[i] = ReadString(/*index=*/i + 2);
  }
  PushString(strings::Substitute(ReadString(/*index=*/1), args));
  return 1;
}

// Tries to serve a string resource lookup from the model's own resources.
// Only name-based lookups (a string at the top of the Lua stack) are handled.
// On success, pushes the resource content and returns true. Otherwise pushes
// nothing and returns false.
bool JniLuaEnvironment::LookupModelStringResource() {
  // Handle only lookup by name. Type-check the same stack slot that is read
  // below, and that HandleAndroidStringResources inspects.
  if (lua_type(state_, /*idx=*/-1) != LUA_TSTRING) {
    return false;
  }

  const StringPiece resource_name = ReadString(/*index=*/-1);
  std::string resource_content;
  if (!resources_.GetResourceContent(device_locales_, resource_name,
                                     &resource_content)) {
    // Resource cannot be provided by the model.
    return false;
  }

  PushString(resource_content);
  return true;
}

int JniLuaEnvironment::HandleAndroidStringResources() {
  // Check whether the requested resource can be served from the model data.
  if (LookupModelStringResource()) {
    return 1;
  }

  // Get system resources if not previously retrieved.
  if (!RetrieveSystemResources()) {
    TC3_LOG(ERROR) << "Error retrieving system resources.";
    lua_error(state_);
    return 0;
  }

  int resource_id;
  switch (lua_type(state_, -1)) {
    case LUA_TNUMBER:
      resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
      break;
    case LUA_TSTRING: {
      const StringPiece resource_name_str = ReadString(/*index=*/-1);
      if (resource_name_str.empty()) {
        TC3_LOG(ERROR) << "No resource name provided.";
        lua_error(state_);
        return 0;
      }
      ScopedLocalRef<jstring> resource_name =
          jni_cache_->ConvertToJavaString(resource_name_str);
      if (resource_name == nullptr) {
        TC3_LOG(ERROR) << "Invalid resource name.";
        lua_error(state_);
        return 0;
      }
      resource_id = jenv_->CallIntMethod(
          system_resources_.get(), jni_cache_->resources_get_identifier,
          resource_name.get(), string_.get(), android_.get());
      if (jni_cache_->ExceptionCheckAndClear()) {
        TC3_LOG(ERROR) << "Error calling getIdentifier.";
        lua_error(state_);
        return 0;
      }
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
      lua_error(state_);
      return 0;
  }
  if (resource_id == 0) {
    TC3_LOG(ERROR) << "Resource not found.";
    lua_pushnil(state_);
    return 1;
  }
  ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
      jenv_->CallObjectMethod(system_resources_.get(),
                              jni_cache_->resources_get_string, resource_id)));
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling getString.";
    lua_error(state_);
    return 0;
  }
  if (resource_str == nullptr) {
    lua_pushnil(state_);
  } else {
    PushString(ToStlString(jenv_, resource_str.get()));
  }
  return 1;
}

bool JniLuaEnvironment::RetrieveSystemResources() {
  if (system_resources_resources_retrieved_) {
    return (system_resources_ != nullptr);
  }
  system_resources_resources_retrieved_ = true;
  jobject system_resources_ref = jenv_->CallStaticObjectMethod(
      jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling getSystem.";
    return false;
  }
  system_resources_ =
      MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
  return (system_resources_ != nullptr);
}

bool JniLuaEnvironment::RetrieveUserManager() {
  if (context_ == nullptr) {
    return false;
  }
  if (usermanager_retrieved_) {
    return (usermanager_ != nullptr);
  }
  usermanager_retrieved_ = true;
  ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
  jobject usermanager_ref = jenv_->CallObjectMethod(
      context_, jni_cache_->context_get_system_service, service.get());
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling getSystemService.";
    return false;
  }
  usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
  return (usermanager_ != nullptr);
}

// Converts the Lua table at the top of the stack into a RemoteActionTemplate
// and pops that table. Unknown string keys are logged and ignored. Non-string
// keys are skipped.
RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
  RemoteActionTemplate result;
  // Read intent template. Stack during iteration: ..., table, key, value.
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    // ReadString presumably goes through lua_tolstring, which converts a
    // non-string key in place and breaks the lua_next traversal (the Lua
    // manual warns against this), so only string keys are read.
    if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Skipping entry with non-string key of type: "
                    << lua_type(state_, /*idx=*/-2);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/-2);
    if (key.Equals("title_without_entity")) {
      result.title_without_entity = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("title_with_entity")) {
      result.title_with_entity = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("description")) {
      result.description = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("description_with_app_name")) {
      result.description_with_app_name = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("action")) {
      result.action = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("data")) {
      result.data = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("type")) {
      result.type = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("flags")) {
      result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    } else if (key.Equals("package_name")) {
      result.package_name = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("request_code")) {
      result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    } else if (key.Equals("category")) {
      ReadCategories(&result.category);
    } else if (key.Equals("extra")) {
      ReadExtras(&result.extra);
    } else {
      TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
    }
    // Pop the value and keep the key for the next lua_next call.
    lua_pop(state_, 1);
  }
  // Pop the template table itself.
  lua_pop(state_, 1);
  return result;
}

// Appends the string values of the categories table at the top of the Lua
// stack to `category`. The value at the top of the stack (table or not) is
// always left in place; the caller (ReadRemoteActionTemplateResult) pops it.
void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
  // Read category array.
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected categories table, got: "
                   << lua_type(state_, /*idx=*/-1);
    // Do not pop here. The caller pops this value after we return, and a
    // second pop would remove the key its lua_next loop depends on.
    return;
  }
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    category->push_back(ReadString(/*index=*/-1).ToString());
    lua_pop(state_, 1);
  }
}

void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected extras table, got: "
                   << lua_type(state_, /*idx=*/-1);
    lua_pop(state_, 1);
    return;
  }
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    // Each entry is a table specifying name and value.
    // The value is specified via a type specific field as Lua doesn't allow
    // to easily distinguish between different number types.
    if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected a table for an extra, got: "
                     << lua_type(state_, /*idx=*/-1);
      lua_pop(state_, 1);
      return;
    }
    std::string name;
    Variant value;

    lua_pushnil(state_);
    while (lua_next(state_, /*idx=*/-2)) {
      const StringPiece key = ReadString(/*index=*/-2);
      if (key.Equals("name")) {
        name = ReadString(/*index=*/-1).ToString();
      } else if (key.Equals("int_value")) {
        value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
      } else if (key.Equals("long_value")) {
        value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
      } else if (key.Equals("float_value")) {
        value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
      } else if (key.Equals("bool_value")) {
        value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
      } else if (key.Equals("string_value")) {
        value = Variant(ReadString(/*index=*/-1).ToString());
      } else {
        TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
      }
      lua_pop(state_, 1);
    }
    if (!name.empty()) {
      (*extra)[name] = value;
    } else {
      TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
    }
    lua_pop(state_, 1);
  }
}

int JniLuaEnvironment::ReadRemoteActionTemplates(
    std::vector<RemoteActionTemplate>* result) {
  // Read result.
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
    lua_error(state_);
    return LUA_ERRRUN;
  }

  // Read remote action templates array.
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected intent table, got: "
                     << lua_type(state_, /*idx=*/-1);
      lua_pop(state_, 1);
      continue;
    }
    result->push_back(ReadRemoteActionTemplateResult());
  }
  lua_pop(state_, /*n=*/1);
  return LUA_OK;
}

bool JniLuaEnvironment::RunIntentGenerator(
    const std::string& generator_snippet,
    std::vector<RemoteActionTemplate>* remote_actions) {
  int status;
  status = luaL_loadbuffer(state_, generator_snippet.data(),
                           generator_snippet.size(),
                           /*name=*/nullptr);
  if (status != LUA_OK) {
    TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
    return false;
  }
  status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
  if (status != LUA_OK) {
    TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
    return false;
  }
  if (RunProtected(
          [this, remote_actions] {
            return ReadRemoteActionTemplates(remote_actions);
          },
          /*num_args=*/1) != LUA_OK) {
    TC3_LOG(ERROR) << "Could not read results.";
    return false;
  }
  // Check that we correctly cleaned-up the state.
  const int stack_size = lua_gettop(state_);
  if (stack_size > 0) {
    TC3_LOG(ERROR) << "Unexpected stack size.";
    lua_settop(state_, 0);
    return false;
  }
  return true;
}

// Lua environment for classification result intent generation.
//
// In addition to the base `external` fields, exposes
// `external.reference_time_ms_utc` (under kReferenceTimeUsecKey) and
// `external.entity` (the classification pushed via PushAnnotation).
class AnnotatorJniEnvironment : public JniLuaEnvironment {
 public:
  // `entity_text` and `classification` are stored by reference and must
  // outlive this environment.
  AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
                          const jobject context,
                          const std::vector<Locale>& device_locales,
                          const std::string& entity_text,
                          const ClassificationResult& classification,
                          const int64 reference_time_ms_utc,
                          const reflection::Schema* entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        entity_text_(entity_text),
        classification_(classification),
        reference_time_ms_utc_(reference_time_ms_utc),
        entity_data_schema_(entity_data_schema) {}

 protected:
  // Extends the `external` table left on the stack by the base class.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    lua_pushinteger(state_, reference_time_ms_utc_);
    lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);

    PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  const std::string& entity_text_;
  const ClassificationResult& classification_;
  // Reference time in milliseconds since the epoch (UTC), per its name.
  const int64 reference_time_ms_utc_;

  // Reflection schema data.
  const reflection::Schema* const entity_data_schema_;
};

// Lua environment for actions intent generation.
//
// In addition to the base `external` fields, exposes
// `external.conversation` (an iterator over the conversation messages) and
// `external.entity` (the action pushed via PushAction).
class ActionsJniLuaEnvironment : public JniLuaEnvironment {
 public:
  // `conversation` and `action` are stored by reference and must outlive this
  // environment. Both iterators are built with the annotations schema, while
  // the action itself uses the actions schema.
  ActionsJniLuaEnvironment(
      const Resources& resources, const JniCache* jni_cache,
      const jobject context, const std::vector<Locale>& device_locales,
      const Conversation& conversation, const ActionSuggestion& action,
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        conversation_(conversation),
        action_(action),
        annotation_iterator_(annotations_entity_data_schema, this),
        conversation_iterator_(annotations_entity_data_schema, this),
        entity_data_schema_(actions_entity_data_schema) {}

 protected:
  // Extends the `external` table left on the stack by the base class.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    conversation_iterator_.NewIterator("conversation", &conversation_.messages,
                                       state_);
    lua_setfield(state_, /*idx=*/-2, "conversation");

    PushAction(action_, entity_data_schema_, annotation_iterator_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  const Conversation& conversation_;
  const ActionSuggestion& action_;
  const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
  const ConversationIterator conversation_iterator_;
  // Schema for the action's entity data.
  const reflection::Schema* entity_data_schema_;
};

}  // namespace

// Creates an IntentGenerator from `options`. Every Lua generator is
// decompressed if needed, precompiled when `options->precompile_generators()`
// is set, and indexed by its type name.
// Returns nullptr if `options` or its generator list is missing, if a
// generator has no type, or if decompression or compilation fails.
std::unique_ptr<IntentGenerator> IntentGenerator::Create(
    const IntentFactoryModel* options, const ResourcePool* resources,
    const std::shared_ptr<JniCache>& jni_cache) {
  // Validate the options before allocating anything.
  if (options == nullptr || options->generator() == nullptr) {
    TC3_LOG(ERROR) << "No intent generator options.";
    return nullptr;
  }

  std::unique_ptr<IntentGenerator> intent_generator(
      new IntentGenerator(options, resources, jni_cache));

  std::unique_ptr<ZlibDecompressor> zlib_decompressor =
      ZlibDecompressor::Instance();
  if (!zlib_decompressor) {
    TC3_LOG(ERROR) << "Cannot initialize decompressor.";
    return nullptr;
  }

  for (const IntentFactoryModel_::IntentGenerator* generator :
       *options->generator()) {
    // The type name is the lookup key; a missing one would otherwise be
    // dereferenced below.
    if (generator->type() == nullptr) {
      TC3_LOG(ERROR) << "Intent generator without type.";
      return nullptr;
    }

    std::string lua_template_generator;
    if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
            generator->lua_template_generator(),
            generator->compressed_lua_template_generator(),
            &lua_template_generator)) {
      TC3_LOG(ERROR) << "Could not decompress generator template.";
      return nullptr;
    }

    std::string lua_code = lua_template_generator;
    if (options->precompile_generators()) {
      if (!Compile(lua_template_generator, &lua_code)) {
        TC3_LOG(ERROR) << "Could not precompile generator template.";
        return nullptr;
      }
    }

    intent_generator->generators_[generator->type()->str()] = lua_code;
  }

  return intent_generator;
}

std::vector<Locale> IntentGenerator::ParseDeviceLocales(
    const jstring device_locales) const {
  if (device_locales == nullptr) {
    TC3_LOG(ERROR) << "No locales provided.";
    return {};
  }
  ScopedStringChars locales_str =
      GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
  if (locales_str == nullptr) {
    TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
    return {};
  }
  std::vector<Locale> locales;
  if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
                    &locales)) {
    TC3_LOG(ERROR) << "Cannot parse locales.";
    return {};
  }
  return locales;
}

// Runs the generator registered for `classification.collection` and appends
// its templates to `remote_actions`. Returns true without doing anything when
// no generator is registered for that collection. Returns false when the
// generator has no options or the interpreter cannot be set up or run.
bool IntentGenerator::GenerateIntents(
    const jstring device_locales, const ClassificationResult& classification,
    const int64 reference_time_ms_utc, const std::string& text,
    const CodepointSpan selection_indices, const jobject context,
    const reflection::Schema* annotations_entity_data_schema,
    std::vector<RemoteActionTemplate>* remote_actions) const {
  if (options_ == nullptr) {
    return false;
  }

  const auto generator_it = generators_.find(classification.collection);
  if (generator_it == generators_.end()) {
    // No generator for this collection: nothing to produce, not an error.
    return true;
  }

  // The selected span, extracted by codepoint indices.
  const std::string entity_text =
      UTF8ToUnicodeText(text, /*do_copy=*/false)
          .UTF8Substring(selection_indices.first, selection_indices.second);
  const std::vector<Locale> locales = ParseDeviceLocales(device_locales);

  // The interpreter keeps references to `entity_text` and `classification`;
  // both outlive it here.
  std::unique_ptr<AnnotatorJniEnvironment> interpreter(
      new AnnotatorJniEnvironment(resources_, jni_cache_.get(), context,
                                  locales, entity_text, classification,
                                  reference_time_ms_utc,
                                  annotations_entity_data_schema));
  if (!interpreter->Initialize()) {
    TC3_LOG(ERROR) << "Could not create Lua interpreter.";
    return false;
  }
  return interpreter->RunIntentGenerator(generator_it->second, remote_actions);
}

// Runs the generator registered for `action.type` and appends its templates
// to `remote_actions`. Returns true without doing anything when no generator
// is registered for that type. Returns false when the generator has no
// options or the interpreter cannot be set up or run.
bool IntentGenerator::GenerateIntents(
    const jstring device_locales, const ActionSuggestion& action,
    const Conversation& conversation, const jobject context,
    const reflection::Schema* annotations_entity_data_schema,
    const reflection::Schema* actions_entity_data_schema,
    std::vector<RemoteActionTemplate>* remote_actions) const {
  if (options_ == nullptr) {
    return false;
  }

  const auto generator_it = generators_.find(action.type);
  if (generator_it == generators_.end()) {
    // No generator for this action type: nothing to produce, not an error.
    return true;
  }

  const std::vector<Locale> locales = ParseDeviceLocales(device_locales);

  // The interpreter keeps references to `conversation` and `action`, which
  // outlive it here.
  std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
      new ActionsJniLuaEnvironment(resources_, jni_cache_.get(), context,
                                   locales, conversation, action,
                                   actions_entity_data_schema,
                                   annotations_entity_data_schema));
  if (!interpreter->Initialize()) {
    TC3_LOG(ERROR) << "Could not create Lua interpreter.";
    return false;
  }
  return interpreter->RunIntentGenerator(generator_it->second, remote_actions);
}

}  // namespace libtextclassifier3