/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ #define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ #include "actions/types.h" #include "annotator/types.h" #include "utils/flatbuffers.h" #include "utils/lua-utils.h" #ifdef __cplusplus extern "C" { #endif #include "lauxlib.h" #include "lua.h" #include "lualib.h" #ifdef __cplusplus } #endif // Action specific shared lua utilities. namespace libtextclassifier3 { // Provides an annotation to lua. void PushAnnotation(const ClassificationResult& classification, const reflection::Schema* entity_data_schema, LuaEnvironment* env); void PushAnnotation(const ClassificationResult& classification, StringPiece text, const reflection::Schema* entity_data_schema, LuaEnvironment* env); void PushAnnotation(const ActionSuggestionAnnotation& annotation, const reflection::Schema* entity_data_schema, LuaEnvironment* env); // A lua iterator to enumerate annotation. template <typename Annotation> class AnnotationIterator : public LuaEnvironment::ItemIterator<std::vector<Annotation>> { public: AnnotationIterator(const reflection::Schema* entity_data_schema, LuaEnvironment* env) : env_(env), entity_data_schema_(entity_data_schema) {} int Item(const std::vector<Annotation>* annotations, const int64 pos, lua_State* state) const override { PushAnnotation((*annotations)[pos], entity_data_schema_, env_); return 1; } int Item(const std::vector<Annotation>* annotations, StringPiece key, lua_State* state) const override; private: LuaEnvironment* env_; const reflection::Schema* entity_data_schema_; }; template <> int AnnotationIterator<ClassificationResult>::Item( const std::vector<ClassificationResult>* annotations, StringPiece key, lua_State* state) const; template <> int AnnotationIterator<ActionSuggestionAnnotation>::Item( const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key, lua_State* state) const; void PushAnnotatedSpan( const AnnotatedSpan& annotated_span, const AnnotationIterator<ClassificationResult>& annotation_iterator, LuaEnvironment* env); MessageTextSpan ReadSpan(LuaEnvironment* env); ActionSuggestionAnnotation ReadAnnotation( const reflection::Schema* entity_data_schema, LuaEnvironment* env); int ReadAnnotations(const reflection::Schema* entity_data_schema, LuaEnvironment* env, std::vector<ActionSuggestionAnnotation>* annotations); ClassificationResult ReadClassificationResult( const reflection::Schema* entity_data_schema, LuaEnvironment* env); // A lua iterator to enumerate annotated spans. class AnnotatedSpanIterator : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> { public: AnnotatedSpanIterator( const AnnotationIterator<ClassificationResult>& annotation_iterator, LuaEnvironment* env) : env_(env), annotation_iterator_(annotation_iterator) {} AnnotatedSpanIterator(const reflection::Schema* entity_data_schema, LuaEnvironment* env) : env_(env), annotation_iterator_(entity_data_schema, env) {} int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos, lua_State* state) const override { PushAnnotatedSpan((*spans)[pos], annotation_iterator_, env_); return /*num results=*/1; } private: LuaEnvironment* env_; AnnotationIterator<ClassificationResult> annotation_iterator_; }; // Provides an action to lua. void PushAction( const ActionSuggestion& action, const reflection::Schema* entity_data_schema, const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator, LuaEnvironment* env); ActionSuggestion ReadAction( const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema, LuaEnvironment* env); int ReadActions(const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema, LuaEnvironment* env, std::vector<ActionSuggestion>* actions); // A lua iterator to enumerate actions suggestions. class ActionsIterator : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> { public: ActionsIterator(const reflection::Schema* entity_data_schema, const reflection::Schema* annotations_entity_data_schema, LuaEnvironment* env) : env_(env), entity_data_schema_(entity_data_schema), annotation_iterator_(annotations_entity_data_schema, env) {} int Item(const std::vector<ActionSuggestion>* actions, const int64 pos, lua_State* state) const override { PushAction((*actions)[pos], entity_data_schema_, annotation_iterator_, env_); return /*num results=*/1; } private: LuaEnvironment* env_; const reflection::Schema* entity_data_schema_; AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_; }; // Conversation message lua iterator. class ConversationIterator : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> { public: ConversationIterator( const AnnotationIterator<ClassificationResult>& annotation_iterator, LuaEnvironment* env) : env_(env), annotated_span_iterator_( AnnotatedSpanIterator(annotation_iterator, env)) {} ConversationIterator(const reflection::Schema* entity_data_schema, LuaEnvironment* env) : env_(env), annotated_span_iterator_( AnnotatedSpanIterator(entity_data_schema, env)) {} int Item(const std::vector<ConversationMessage>* messages, const int64 pos, lua_State* state) const override; private: LuaEnvironment* env_; AnnotatedSpanIterator annotated_span_iterator_; }; } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_