/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_

#include <functional>
#include <vector>

#include "utils/flatbuffers.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif

namespace libtextclassifier3 {

// Names of the metatable fields this header installs callbacks under.
// (`constexpr` at namespace scope already implies internal linkage.)
constexpr const char *kLengthKey = "__len";   // Length query.
constexpr const char *kPairsKey = "__pairs";  // Start of an iteration.
constexpr const char *kIndexKey = "__index";  // Item / member access.

// Converts a value into the untyped `void *` representation used for Lua
// (light) user data.
//
// Pointer-to-const overload: drops the const qualifier, since the user data
// pointer type is a plain `void *`.
template <typename T>
void *AsUserData(const T *value) {
  T *const unqualified = const_cast<T *>(value);
  return unqualified;
}
// By-value overload: reinterprets the bits of `value` itself as a pointer.
template <typename T>
void *AsUserData(const T value) {
  void *const as_pointer = reinterpret_cast<void *>(value);
  return as_pointer;
}

// Reads the user data stored in up-value number `index` of the currently
// running C closure and casts it to `T`.
template <typename T>
T FromUpValue(const int index, lua_State *state) {
  void *const raw = lua_touserdata(state, lua_upvalueindex(index));
  return static_cast<T>(raw);
}

class LuaEnvironment {
 public:
  // Wrapper for handling an iterator.
  class Iterator {
   public:
    virtual ~Iterator() {}
    static int NextCallback(lua_State *state);
    static int LengthCallback(lua_State *state);
    static int ItemCallback(lua_State *state);
    static int IteritemsCallback(lua_State *state);

    // Called when the next element of an iterator is fetched.
    virtual int Next(lua_State *state) const = 0;

    // Called when the length of the iterator is queried.
    virtual int Length(lua_State *state) const = 0;

    // Called when an item is queried.
    virtual int Item(lua_State *state) const = 0;

    // Called when a new iterator is started.
    virtual int Iteritems(lua_State *state) const = 0;

   protected:
    static constexpr int kIteratorArgId = 1;
  };

  template <typename T>
  class ItemIterator : public Iterator {
   public:
    void NewIterator(StringPiece name, const T *items, lua_State *state) const {
      lua_newtable(state);
      luaL_newmetatable(state, name.data());
      lua_pushlightuserdata(state, AsUserData(this));
      lua_pushlightuserdata(state, AsUserData(items));
      lua_pushcclosure(state, &Iterator::ItemCallback, 2);
      lua_setfield(state, -2, kIndexKey);
      lua_pushlightuserdata(state, AsUserData(this));
      lua_pushlightuserdata(state, AsUserData(items));
      lua_pushcclosure(state, &Iterator::LengthCallback, 2);
      lua_setfield(state, -2, kLengthKey);
      lua_pushlightuserdata(state, AsUserData(this));
      lua_pushlightuserdata(state, AsUserData(items));
      lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
      lua_setfield(state, -2, kPairsKey);
      lua_setmetatable(state, -2);
    }

    int Iteritems(lua_State *state) const override {
      lua_pushlightuserdata(state, AsUserData(this));
      lua_pushlightuserdata(
          state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
      lua_pushnumber(state, 0);
      lua_pushcclosure(state, &Iterator::NextCallback, 3);
      return /*num results=*/1;
    }

    int Length(lua_State *state) const override {
      lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
      return /*num results=*/1;
    }

    int Next(lua_State *state) const override {
      return Next(FromUpValue<T *>(kItemsArgId, state),
                  lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
                  state);
    }

    int Next(const T *items, const int64 pos, lua_State *state) const {
      if (pos >= items->size()) {
        return 0;
      }

      // Update iterator value.
      lua_pushnumber(state, pos + 1);
      lua_replace(state, lua_upvalueindex(3));

      // Push key.
      lua_pushinteger(state, pos + 1);

      // Push item.
      return 1 + Item(items, pos, state);
    }

    int Item(lua_State *state) const override {
      const T *items = FromUpValue<T *>(kItemsArgId, state);
      switch (lua_type(state, -1)) {
        case LUA_TNUMBER: {
          // Lua is one based, so adjust the index here.
          const int64 index =
              static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
          if (index < 0 || index >= items->size()) {
            TC3_LOG(ERROR) << "Invalid index: " << index;
            lua_error(state);
            return 0;
          }
          return Item(items, index, state);
        }
        case LUA_TSTRING: {
          size_t key_length = 0;
          const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
          return Item(items, StringPiece(key, key_length), state);
        }
        default:
          TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
          lua_error(state);
          return 0;
      }
    }

    virtual int Item(const T *items, const int64 pos,
                     lua_State *state) const = 0;

    virtual int Item(const T *items, StringPiece key, lua_State *state) const {
      TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
      lua_error(state);
      return 0;
    }

   protected:
    static constexpr int kItemsArgId = 2;
    static constexpr int kIterValueArgId = 3;
  };

  virtual ~LuaEnvironment();
  LuaEnvironment();

  // Compile a lua snippet into binary bytecode.
  // NOTE: The compiled bytecode might not be compatible across Lua versions
  // and platforms.
  bool Compile(StringPiece snippet, std::string *bytecode);

  typedef int (*CallbackHandler)(lua_State *);

  // Loads default libraries.
  void LoadDefaultLibraries();

  // Provides a callback to Lua.
  template <typename T, int (T::*handler)()>
  void Bind() {
    lua_pushlightuserdata(state_, static_cast<void *>(this));
    lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
  }

  // Setup a named table that callsback whenever a member is accessed.
  // This allows to lazily provide required information to the script.
  template <typename T, int (T::*handler)()>
  void BindTable(const char *name) {
    lua_newtable(state_);
    luaL_newmetatable(state_, name);
    lua_pushlightuserdata(state_, static_cast<void *>(this));
    lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
    lua_setfield(state_, -2, kIndexKey);
    lua_setmetatable(state_, -2);
  }

  void PushValue(const Variant &value);

  // Reads a string from the stack.
  StringPiece ReadString(const int index) const;

  // Pushes a string to the stack.
  void PushString(const StringPiece str);

  // Pushes a flatbuffer to the stack.
  void PushFlatbuffer(const reflection::Schema *schema,
                      const flatbuffers::Table *table);

  // Reads a flatbuffer from the stack.
  int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);

  // Runs a closure in protected mode.
  // `func`: closure to run in protected mode.
  // `num_lua_args`: number of arguments from the lua stack to process.
  // `num_results`: number of result values pushed on the stack.
  int RunProtected(const std::function<int()> &func, const int num_args = 0,
                   const int num_results = 0);

  lua_State *state() const { return state_; }

 protected:
  lua_State *state_;

 private:
  // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
  static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
                             const reflection::Object *type,
                             const flatbuffers::Table *table, lua_State *state);
  static int GetFieldCallback(lua_State *state);
  static int GetField(const reflection::Schema *schema,
                      const reflection::Object *type,
                      const flatbuffers::Table *table, lua_State *state);

  template <typename T, int (T::*handler)()>
  static int Dispatch(lua_State *state) {
    T *env = FromUpValue<T *>(1, state);
    return ((*env).*handler)();
  }
};

// Free-function counterpart of LuaEnvironment::Compile with the same
// signature.
// NOTE(review): the definition is not in this header; presumably it compiles
// `snippet` into Lua bytecode written to `*bytecode` and returns whether that
// succeeded -- confirm against the implementation.
bool Compile(StringPiece snippet, std::string *bytecode);

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_