/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/lua-ranker.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif

namespace libtextclassifier3 {

// Builds a ranker for the given conversation and response and sets up its Lua
// environment via Initialize().
//
// Returns nullptr (after logging an error) if initialization fails.
std::unique_ptr<ActionsSuggestionsLuaRanker>
ActionsSuggestionsLuaRanker::Create(
    const Conversation& conversation, const std::string& ranker_code,
    const reflection::Schema* entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    ActionsSuggestionsResponse* response) {
  auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
      new ActionsSuggestionsLuaRanker(
          conversation, ranker_code, entity_data_schema,
          annotations_entity_data_schema, response));
  if (!ranker->Initialize()) {
    TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
    return nullptr;
  }
  return ranker;
}

// Prepares the Lua state inside a protected call: loads the default
// libraries and publishes the response's actions and the conversation's
// messages as the Lua globals `actions` and `messages`.
//
// Returns true iff the protected call reported LUA_OK.
bool ActionsSuggestionsLuaRanker::Initialize() {
  return RunProtected([this] {
           LoadDefaultLibraries();

           // Expose generated actions.
           actions_iterator_.NewIterator("actions", &response_->actions,
                                         state_);
           lua_setglobal(state_, "actions");

           // Expose conversation message stream.
           conversation_iterator_.NewIterator("messages",
                                              &conversation_.messages, state_);
           lua_setglobal(state_, "messages");
           return LUA_OK;
         }) == LUA_OK;
}

// Reads the ranking produced by the snippet from the top of the Lua stack and
// reorders `response_->actions` accordingly.
//
// The value on top of the stack must be a table whose values are 1-based
// indices into `response_->actions`. The table is traversed with lua_next and
// the referenced actions are collected in traversal order.
//
// Returns LUA_OK on success. On a malformed result an error is logged and a
// Lua error is raised via lua_error; `response_->actions` is left untouched in
// that case.
int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: "
                   << lua_type(state_, /*idx=*/-1);
    lua_pop(state_, 1);
    lua_error(state_);
    return LUA_ERRRUN;
  }

  const lua_Integer num_actions =
      static_cast<lua_Integer>(response_->actions.size());
  bool ranking_is_valid = true;
  {
    // The vector lives in its own scope so that it is destroyed before
    // lua_error is invoked below: lua_error does not return normally (C builds
    // of Lua unwind with longjmp), which would skip the destructor.
    std::vector<ActionSuggestion> ranked_actions;
    lua_pushnil(state_);
    while (lua_next(state_, /*idx=*/-2)) {
      // Validate in the full lua_Integer domain. Narrowing to int first could
      // wrap an out-of-range value into a valid-looking index.
      const lua_Integer lua_index = lua_tointeger(state_, /*idx=*/-1);
      lua_pop(state_, 1);  // Drop the value, keep the key for lua_next.
      if (lua_index < 1 || lua_index > num_actions) {
        // Logs the raw (1-based) value exactly as returned by the script.
        TC3_LOG(ERROR) << "Invalid action index: "
                       << static_cast<long long>(lua_index);
        ranking_is_valid = false;
        break;
      }
      ranked_actions.push_back(
          response_->actions[static_cast<std::size_t>(lua_index - 1)]);
    }
    if (ranking_is_valid) {
      lua_pop(state_, 1);  // Pop the result table.
      response_->actions = std::move(ranked_actions);
    }
  }

  if (!ranking_is_valid) {
    lua_error(state_);
    return LUA_ERRRUN;
  }
  return LUA_OK;
}

// Runs the ranking snippet and applies the resulting order to the response.
//
// Does nothing and returns true when there are no actions. Otherwise loads
// `ranker_code_`, calls it expecting a single result, and hands that result
// to ReadActionsRanking() inside a protected call.
//
// Returns false (after logging an error) if loading, running or reading the
// result fails.
bool ActionsSuggestionsLuaRanker::RankActions() {
  if (response_->actions.empty()) {
    // Nothing to do.
    return true;
  }

  if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
                      /*name=*/nullptr) != LUA_OK) {
    TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
    return false;
  }

  if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) !=
      LUA_OK) {
    TC3_LOG(ERROR) << "Could not run ranking snippet.";
    return false;
  }

  if (RunProtected([this] { return ReadActionsRanking(); },
                   /*num_args=*/1) != LUA_OK) {
    TC3_LOG(ERROR) << "Could not read lua result.";
    return false;
  }
  return true;
}

}  // namespace libtextclassifier3