普通文本  |  355行  |  13.7 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/lua-utils.h"

namespace libtextclassifier3 {
namespace {
static constexpr const char* kTextKey = "text";
static constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
static constexpr const char* kGranularityKey = "granularity";
static constexpr const char* kCollectionKey = "collection";
static constexpr const char* kNameKey = "name";
static constexpr const char* kScoreKey = "score";
static constexpr const char* kPriorityScoreKey = "priority_score";
static constexpr const char* kTypeKey = "type";
static constexpr const char* kResponseTextKey = "response_text";
static constexpr const char* kAnnotationKey = "annotation";
static constexpr const char* kSpanKey = "span";
static constexpr const char* kMessageKey = "message";
static constexpr const char* kBeginKey = "begin";
static constexpr const char* kEndKey = "end";
static constexpr const char* kClassificationKey = "classification";
static constexpr const char* kSerializedEntity = "serialized_entity";
static constexpr const char* kEntityKey = "entity";
}  // namespace

template <>
int AnnotationIterator<ClassificationResult>::Item(
    const std::vector<ClassificationResult>* annotations, StringPiece key,
    lua_State* state) const {
  // Lookup annotation by collection.
  for (const ClassificationResult& annotation : *annotations) {
    if (key.Equals(annotation.collection)) {
      PushAnnotation(annotation, entity_data_schema_, env_);
      return 1;
    }
  }
  TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
                 << " found.";
  lua_error(state);
  return 0;
}

template <>
int AnnotationIterator<ActionSuggestionAnnotation>::Item(
    const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
    lua_State* state) const {
  // Lookup annotation by name.
  for (const ActionSuggestionAnnotation& annotation : *annotations) {
    if (key.Equals(annotation.name)) {
      PushAnnotation(annotation, entity_data_schema_, env_);
      return 1;
    }
  }
  TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
  lua_error(state);
  return 0;
}

void PushAnnotation(const ClassificationResult& classification,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env) {
  if (entity_data_schema == nullptr ||
      classification.serialized_entity_data.empty()) {
    // Empty table.
    lua_newtable(env->state());
  } else {
    env->PushFlatbuffer(entity_data_schema,
                        flatbuffers::GetRoot<flatbuffers::Table>(
                            classification.serialized_entity_data.data()));
  }
  lua_pushinteger(env->state(),
                  classification.datetime_parse_result.time_ms_utc);
  lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey);
  lua_pushinteger(env->state(),
                  classification.datetime_parse_result.granularity);
  lua_setfield(env->state(), /*idx=*/-2, kGranularityKey);
  env->PushString(classification.collection);
  lua_setfield(env->state(), /*idx=*/-2, kCollectionKey);
  lua_pushnumber(env->state(), classification.score);
  lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
  env->PushString(classification.serialized_entity_data);
  lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity);
}

void PushAnnotation(const ClassificationResult& classification,
                    StringPiece text,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env) {
  PushAnnotation(classification, entity_data_schema, env);
  env->PushString(text);
  lua_setfield(env->state(), /*idx=*/-2, kTextKey);
}

void PushAnnotatedSpan(
    const AnnotatedSpan& annotated_span,
    const AnnotationIterator<ClassificationResult>& annotation_iterator,
    LuaEnvironment* env) {
  lua_newtable(env->state());
  {
    lua_newtable(env->state());
    lua_pushinteger(env->state(), annotated_span.span.first);
    lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
    lua_pushinteger(env->state(), annotated_span.span.second);
    lua_setfield(env->state(), /*idx=*/-2, kEndKey);
  }
  lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
  annotation_iterator.NewIterator(kClassificationKey,
                                  &annotated_span.classification, env->state());
  lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
}

MessageTextSpan ReadSpan(LuaEnvironment* env) {
  MessageTextSpan span;
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kMessageKey)) {
      span.message_index =
          static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kBeginKey)) {
      span.span.first =
          static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kEndKey)) {
      span.span.second =
          static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kTextKey)) {
      span.text = env->ReadString(/*index=*/-1).ToString();
    } else {
      TC3_LOG(INFO) << "Unknown span field: " << key.ToString();
    }
    lua_pop(env->state(), 1);
  }
  return span;
}

int ReadAnnotations(const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env,
                    std::vector<ActionSuggestionAnnotation>* annotations) {
  if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected annotations table, got: "
                   << lua_type(env->state(), /*idx=*/-1);
    lua_pop(env->state(), 1);
    lua_error(env->state());
    return LUA_ERRRUN;
  }

  // Read actions.
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected annotation table, got: "
                     << lua_type(env->state(), /*idx=*/-1);
      lua_pop(env->state(), 1);
      continue;
    }
    annotations->push_back(ReadAnnotation(entity_data_schema, env));
    lua_pop(env->state(), 1);
  }
  return LUA_OK;
}

ActionSuggestionAnnotation ReadAnnotation(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
  ActionSuggestionAnnotation annotation;
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kNameKey)) {
      annotation.name = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kSpanKey)) {
      annotation.span = ReadSpan(env);
    } else if (key.Equals(kEntityKey)) {
      annotation.entity = ReadClassificationResult(entity_data_schema, env);
    } else {
      TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString();
    }
    lua_pop(env->state(), 1);
  }
  return annotation;
}

ClassificationResult ReadClassificationResult(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
  ClassificationResult classification;
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kCollectionKey)) {
      classification.collection = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kScoreKey)) {
      classification.score =
          static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kTimeUsecKey)) {
      classification.datetime_parse_result.time_ms_utc =
          static_cast<int64>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kGranularityKey)) {
      classification.datetime_parse_result.granularity =
          static_cast<DatetimeGranularity>(
              lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kSerializedEntity)) {
      classification.serialized_entity_data =
          env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kEntityKey)) {
      auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
      env->ReadFlatbuffer(buffer.get());
      classification.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown classification result field: "
                    << key.ToString();
    }
    lua_pop(env->state(), 1);
  }
  return classification;
}

void PushAnnotation(const ActionSuggestionAnnotation& annotation,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env) {
  PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema,
                 env);
  env->PushString(annotation.name);
  lua_setfield(env->state(), /*idx=*/-2, kNameKey);
  {
    lua_newtable(env->state());
    lua_pushinteger(env->state(), annotation.span.message_index);
    lua_setfield(env->state(), /*idx=*/-2, kMessageKey);
    lua_pushinteger(env->state(), annotation.span.span.first);
    lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
    lua_pushinteger(env->state(), annotation.span.span.second);
    lua_setfield(env->state(), /*idx=*/-2, kEndKey);
  }
  lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
}

void PushAction(
    const ActionSuggestion& action,
    const reflection::Schema* entity_data_schema,
    const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
    LuaEnvironment* env) {
  if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) {
    // Empty table.
    lua_newtable(env->state());
  } else {
    env->PushFlatbuffer(entity_data_schema,
                        flatbuffers::GetRoot<flatbuffers::Table>(
                            action.serialized_entity_data.data()));
  }
  env->PushString(action.type);
  lua_setfield(env->state(), /*idx=*/-2, kTypeKey);
  env->PushString(action.response_text);
  lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey);
  lua_pushnumber(env->state(), action.score);
  lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
  lua_pushnumber(env->state(), action.priority_score);
  lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey);
  annotation_iterator.NewIterator(kAnnotationKey, &action.annotations,
                                  env->state());
  lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey);
}

ActionSuggestion ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    LuaEnvironment* env) {
  ActionSuggestion action;
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kResponseTextKey)) {
      action.response_text = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kTypeKey)) {
      action.type = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kScoreKey)) {
      action.score = static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kPriorityScoreKey)) {
      action.priority_score =
          static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kAnnotationKey)) {
      ReadAnnotations(actions_entity_data_schema, env, &action.annotations);
    } else if (key.Equals(kEntityKey)) {
      auto buffer =
          ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
      env->ReadFlatbuffer(buffer.get());
      action.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
    }
    lua_pop(env->state(), 1);
  }
  return action;
}

int ReadActions(const reflection::Schema* actions_entity_data_schema,
                const reflection::Schema* annotations_entity_data_schema,
                LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
  if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: "
                   << lua_type(env->state(), /*idx=*/-1);
    lua_pop(env->state(), 1);
    lua_error(env->state());
    return LUA_ERRRUN;
  }

  // Read actions.
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected action table, got: "
                     << lua_type(env->state(), /*idx=*/-1);
      lua_pop(env->state(), 1);
      continue;
    }
    actions->push_back(ReadAction(actions_entity_data_schema,
                                  annotations_entity_data_schema, env));
    lua_pop(env->state(), /*n=1*/ 1);
  }
  lua_pop(env->state(), /*n=*/1);

  return LUA_OK;
}

int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
                               const int64 pos, lua_State* state) const {
  const ConversationMessage& message = (*messages)[pos];
  lua_newtable(state);
  lua_pushinteger(state, message.user_id);
  lua_setfield(state, /*idx=*/-2, "user_id");
  env_->PushString(message.text);
  lua_setfield(state, /*idx=*/-2, "text");
  lua_pushinteger(state, message.reference_time_ms_utc);
  lua_setfield(state, /*idx=*/-2, "time_ms_utc");
  env_->PushString(message.reference_timezone);
  lua_setfield(state, /*idx=*/-2, "timezone");
  annotated_span_iterator_.NewIterator("annotation", &message.annotations,
                                       state);
  lua_setfield(state, /*idx=*/-2, "annotation");
  return 1;
}

}  // namespace libtextclassifier3