C++程序  |  183行  |  6.5 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
#define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_

#include "actions/types.h"
#include "annotator/types.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif

// Action specific shared lua utilities.
namespace libtextclassifier3 {

// Provides an annotation to lua.
//
// Three overloads are declared here; their definitions are not part of this
// header.
//   classification / annotation: the value to provide to lua.
//   entity_data_schema: reflection schema handed along with the value
//       (presumably describes its entity data -- confirm in the definition).
//   env: the lua environment that receives the annotation.
void PushAnnotation(const ClassificationResult& classification,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env);
// Same as above with an additional `text` argument.
// NOTE(review): presumably the text covered by the classification -- confirm
// in the definition.
void PushAnnotation(const ClassificationResult& classification,
                    StringPiece text,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env);
// Overload for an action suggestion annotation instead of a classification
// result.
void PushAnnotation(const ActionSuggestionAnnotation& annotation,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env);

// Lua item iterator over a `std::vector<Annotation>`.
//
// Index lookups are implemented inline via `PushAnnotation`; key lookups are
// only declared here and are provided per annotation type through explicit
// specializations.
template <typename Annotation>
class AnnotationIterator
    : public LuaEnvironment::ItemIterator<std::vector<Annotation>> {
 public:
  // Stores both pointers as-is; nothing is copied or freed by this class.
  AnnotationIterator(const reflection::Schema* entity_data_schema,
                     LuaEnvironment* env)
      : env_{env}, entity_data_schema_{entity_data_schema} {}

  // Pushes the annotation at index `pos` and reports one result.
  int Item(const std::vector<Annotation>* annotations, const int64 pos,
           lua_State* state) const override {
    const Annotation& selected = (*annotations)[pos];
    PushAnnotation(selected, entity_data_schema_, env_);
    return /*num results=*/1;
  }

  // Key-based lookup; defined separately for each supported annotation type.
  int Item(const std::vector<Annotation>* annotations, StringPiece key,
           lua_State* state) const override;

 private:
  // Lua environment handed to `PushAnnotation`.
  LuaEnvironment* env_;
  // Schema handed to `PushAnnotation` together with each annotation.
  const reflection::Schema* entity_data_schema_;
};

// Explicit specialization declarations of the key-based `Item` overload for
// the two annotation types used in this header. Declaring them here keeps
// including translation units from implicitly instantiating the (undefined)
// primary template member; the definitions are not part of this header.

// Key-based lookup in a list of classification results.
template <>
int AnnotationIterator<ClassificationResult>::Item(
    const std::vector<ClassificationResult>* annotations, StringPiece key,
    lua_State* state) const;

// Key-based lookup in a list of action suggestion annotations.
template <>
int AnnotationIterator<ActionSuggestionAnnotation>::Item(
    const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
    lua_State* state) const;

// Provides an annotated span to lua (declared here, defined elsewhere).
//   annotated_span: the span to provide.
//   annotation_iterator: iterator for classification results.
//       NOTE(review): presumably used to expose the span's classifications --
//       confirm in the definition.
//   env: the lua environment that receives the span.
void PushAnnotatedSpan(
    const AnnotatedSpan& annotated_span,
    const AnnotationIterator<ClassificationResult>& annotation_iterator,
    LuaEnvironment* env);

// Readers that build C++ values using a lua environment. Only the
// declarations are visible here.
// NOTE(review): presumably each one reads its input from the lua stack of
// `env` -- confirm in the definitions.

// Returns a message text span obtained via `env`.
MessageTextSpan ReadSpan(LuaEnvironment* env);
// Returns a single action suggestion annotation obtained via `env`.
ActionSuggestionAnnotation ReadAnnotation(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env);
// Takes an output vector of annotations.
// NOTE(review): presumably fills `annotations`, and the int result is
// presumably a lua status code -- neither is visible from this header;
// confirm in the definition.
int ReadAnnotations(const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env,
                    std::vector<ActionSuggestionAnnotation>* annotations);
// Returns a single classification result obtained via `env`.
ClassificationResult ReadClassificationResult(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env);

// Lua item iterator over a `std::vector<AnnotatedSpan>`.
//
// Holds its own classification-result iterator, which is passed to
// `PushAnnotatedSpan` for every span.
class AnnotatedSpanIterator
    : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> {
 public:
  // Keeps a copy of the supplied classification-result iterator.
  AnnotatedSpanIterator(
      const AnnotationIterator<ClassificationResult>& annotation_iterator,
      LuaEnvironment* env)
      : env_{env}, annotation_iterator_{annotation_iterator} {}

  // Constructs the classification-result iterator from the schema and `env`.
  AnnotatedSpanIterator(const reflection::Schema* entity_data_schema,
                        LuaEnvironment* env)
      : env_{env}, annotation_iterator_{entity_data_schema, env} {}

  // Pushes the span at index `pos` and reports one result.
  int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos,
           lua_State* state) const override {
    const AnnotatedSpan& selected = (*spans)[pos];
    PushAnnotatedSpan(selected, annotation_iterator_, env_);
    return 1;
  }

 private:
  // Lua environment handed to `PushAnnotatedSpan`.
  LuaEnvironment* env_;
  // Iterator handed to `PushAnnotatedSpan` for each span.
  AnnotationIterator<ClassificationResult> annotation_iterator_;
};

// Provides an action to lua (declared here, defined elsewhere).
//   action: the action suggestion to provide.
//   entity_data_schema: reflection schema handed along with the action
//       (presumably describes its entity data -- confirm in the definition).
//   annotation_iterator: iterator for action suggestion annotations.
//       NOTE(review): presumably used to expose the action's annotations --
//       confirm in the definition.
//   env: the lua environment that receives the action.
void PushAction(
    const ActionSuggestion& action,
    const reflection::Schema* entity_data_schema,
    const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
    LuaEnvironment* env);

// Returns a single action suggestion obtained via `env`. Two schemas are
// taken: one for the actions' entity data and one for the annotations' entity
// data. Declared here, defined elsewhere.
// NOTE(review): presumably reads from the lua stack of `env` -- confirm in
// the definition.
ActionSuggestion ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    LuaEnvironment* env);
// Takes an output vector of action suggestions.
// NOTE(review): presumably fills `actions`, and the int result is presumably
// a lua status code -- neither is visible from this header; confirm in the
// definition.
int ReadActions(const reflection::Schema* actions_entity_data_schema,
                const reflection::Schema* annotations_entity_data_schema,
                LuaEnvironment* env, std::vector<ActionSuggestion>* actions);

// Lua item iterator over a `std::vector<ActionSuggestion>`.
//
// Holds an annotation iterator built from the annotations schema; it is
// passed to `PushAction` together with the actions schema.
class ActionsIterator
    : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> {
 public:
  // `entity_data_schema` is stored for the actions themselves, while
  // `annotations_entity_data_schema` is used to construct the nested
  // annotation iterator.
  ActionsIterator(const reflection::Schema* entity_data_schema,
                  const reflection::Schema* annotations_entity_data_schema,
                  LuaEnvironment* env)
      : env_{env},
        entity_data_schema_{entity_data_schema},
        annotation_iterator_{annotations_entity_data_schema, env} {}

  // Pushes the action at index `pos` and reports one result.
  int Item(const std::vector<ActionSuggestion>* actions, const int64 pos,
           lua_State* state) const override {
    const ActionSuggestion& selected = (*actions)[pos];
    PushAction(selected, entity_data_schema_, annotation_iterator_, env_);
    return 1;
  }

 private:
  // Lua environment handed to `PushAction`.
  LuaEnvironment* env_;
  // Schema handed to `PushAction` together with each action.
  const reflection::Schema* entity_data_schema_;
  // Iterator handed to `PushAction` for each action.
  AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
};

// Lua item iterator over a `std::vector<ConversationMessage>`.
//
// Holds an annotated-span iterator as a member; the index-based `Item` is
// only declared here and defined elsewhere.
class ConversationIterator
    : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> {
 public:
  // Constructs the annotated-span iterator from an existing
  // classification-result iterator.
  ConversationIterator(
      const AnnotationIterator<ClassificationResult>& annotation_iterator,
      LuaEnvironment* env)
      : env_{env}, annotated_span_iterator_{annotation_iterator, env} {}

  // Constructs the annotated-span iterator from the schema and `env`.
  ConversationIterator(const reflection::Schema* entity_data_schema,
                       LuaEnvironment* env)
      : env_{env}, annotated_span_iterator_{entity_data_schema, env} {}

  // Index-based lookup of a conversation message; not defined in this header.
  int Item(const std::vector<ConversationMessage>* messages, const int64 pos,
           lua_State* state) const override;

 private:
  // Lua environment this iterator was created with.
  LuaEnvironment* env_;
  // Iterator over annotated spans, constructed in the constructors above.
  AnnotatedSpanIterator annotated_span_iterator_;
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_