普通文本  |  202行  |  7.3 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/lua-actions.h"

#include <map>
#include <string>

#include "actions/test_utils.h"
#include "actions/types.h"
#include "utils/tflite-model-executor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {
namespace {

MATCHER_P2(IsAction, type, response_text, "") {
  return testing::Value(arg.type, type) &&
         testing::Value(arg.response_text, response_text);
}

MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }

TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
                  test_snippet, conversation,
                  /*model_executor=*/nullptr,
                  /*model_spec=*/nullptr,
                  /*interpreter=*/nullptr,
                  /*actions_entity_data_schema=*/nullptr,
                  /*annotations_entity_data_schema=*/nullptr)
                  ->SuggestActions(&actions));
  EXPECT_THAT(actions,
              testing::ElementsAreArray({IsActionType("test_action")}));
}

TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
                  test_snippet, conversation,
                  /*model_executor=*/nullptr,
                  /*model_spec=*/nullptr,
                  /*interpreter=*/nullptr,
                  /*actions_entity_data_schema=*/nullptr,
                  /*annotations_entity_data_schema=*/nullptr)
                  ->SuggestActions(&actions));
  EXPECT_THAT(actions, testing::ElementsAreArray(
                           {IsAction("text_reply", "you are a bold one!")}));
}

TEST(LuaActions, SimpleModelAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    if #model.actions_scores == 0 then
      return {{ type = "test_action" }}
    end
    return {}
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
                  test_snippet, conversation,
                  /*model_executor=*/nullptr,
                  /*model_spec=*/nullptr,
                  /*interpreter=*/nullptr,
                  /*actions_entity_data_schema=*/nullptr,
                  /*annotations_entity_data_schema=*/nullptr)
                  ->SuggestActions(&actions));
  EXPECT_THAT(actions,
              testing::ElementsAreArray({IsActionType("test_action")}));
}

TEST(LuaActions, AnnotationActions) {
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  Conversation conversation = {{{/*user_id=*/1, "are you at home?",
                                 /*reference_time_ms_utc=*/0,
                                 /*reference_timezone=*/"Europe/Zurich",
                                 /*annotations=*/{annotation},
                                 /*locales=*/"en"}}};
  const std::string test_snippet = R"(
    local actions = {}
    local last_message = messages[#messages]
    for i, annotation in pairs(last_message.annotation) do
      if #annotation.classification > 0 then
        if annotation.classification[1].collection == "address" then
           local text = string.sub(last_message.text,
                            annotation.span["begin"] + 1,
                            annotation.span["end"])
           table.insert(actions, {
             type = "text_reply",
             response_text = "i am at " .. text,
             annotation = {{
               name = "location",
               span = {
                 text = text
               },
               entity = annotation.classification[1]
             }},
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
                  test_snippet, conversation,
                  /*model_executor=*/nullptr,
                  /*model_spec=*/nullptr,
                  /*interpreter=*/nullptr,
                  /*actions_entity_data_schema=*/nullptr,
                  /*annotations_entity_data_schema=*/nullptr)
                  ->SuggestActions(&actions));
  EXPECT_THAT(actions, testing::ElementsAreArray(
                           {IsAction("text_reply", "i am at home")}));
  EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
}

TEST(LuaActions, EntityData) {
  std::string test_schema = TestEntityDataSchema();
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  const std::string test_snippet = R"(
    return {{
      type = "test",
      entity = {
        greeting = "hello",
        location = "there",
        person = "Kenobi",
      },
    }};
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
                  test_snippet, conversation,
                  /*model_executor=*/nullptr,
                  /*model_spec=*/nullptr,
                  /*interpreter=*/nullptr,
                  /*actions_entity_data_schema=*/
                  flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
                  /*annotations_entity_data_schema=*/nullptr)
                  ->SuggestActions(&actions));
  EXPECT_THAT(actions, testing::SizeIs(1));
  EXPECT_EQ("test", actions.front().type);
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          actions.front().serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
            "hello");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
            "Kenobi");
}

}  // namespace
}  // namespace libtextclassifier3