// Plain text | 383 lines | 16.05 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/ranker.h"

#include <string>

#include "actions/types.h"
#include "utils/zlib/zlib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {
namespace {

// Matches an ActionSuggestion whose `type`, `response_text` and `score` equal
// the expected values.
// Every field is checked, and each mismatching field (with its actual value)
// is written to the result listener, so a failing ElementsAreArray reports
// what differed.
// The score is compared with FloatEq (within 4 ULPs) rather than exact
// floating-point equality.
MATCHER_P3(IsAction, type, response_text, score, "") {
  bool matches = true;
  if (!testing::Value(arg.type, testing::Eq(type))) {
    *result_listener << "whose type is \"" << arg.type << "\" ";
    matches = false;
  }
  if (!testing::Value(arg.response_text, testing::Eq(response_text))) {
    *result_listener << "whose response_text is \"" << arg.response_text
                     << "\" ";
    matches = false;
  }
  if (!testing::Value(arg.score,
                      testing::FloatEq(static_cast<float>(score)))) {
    *result_listener << "whose score is " << arg.score << " ";
    matches = false;
  }
  return matches;
}

// Matches an ActionSuggestion whose `type` equals the expected value. On a
// mismatch the actual type is written to the result listener.
MATCHER_P(IsActionType, type, "") {
  if (testing::Value(arg.type, testing::Eq(type))) {
    return true;
  }
  *result_listener << "whose type is \"" << arg.type << "\"";
  return false;
}

// With `deduplicate_suggestions` set, two identical text replies collapse into
// one, and the copy with the higher score (1.0) is the one kept.
TEST(RankingTest, DeduplicationSmartReply) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0)}));
}

// Replies with the same text and type are only merged when their serialized
// entity data also matches. The reply carrying "test" entity data survives
// next to the highest-scoring plain copy.
TEST(RankingTest, DeduplicationExtraData) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0, /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5,
       /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
       /*priority_score=*/0.0,
       /*annotations=*/{}, /*serialized_entity_data=*/"test"},
  };

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
                                 // Is kept as it has different entity data.
                                 IsAction("text_reply", "hello there", 0.6)}));
}

// Two view_map actions annotate the same address span and are deduplicated.
// The kept one has score 1.0 (and, here, also the higher priority_score).
// The call_phone action on a different span is left alone.
TEST(RankingTest, DeduplicationAnnotations) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("view_map", "", 1.0),
                                         IsAction("call_phone", "", 0.5)}));
}

// When two view_map actions share the same span, the one with the higher
// priority_score (2.0) is kept even though its score (0.6) is lower.
TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.6,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray(
          {IsAction("view_map", "",
                    0.6),  // lower score wins, as priority score is higher
           IsAction("call_phone", "", 0.5)}));
}

// call_phone on "911" and copy_code on "A-911" have overlapping spans. Only
// copy_code, which has the higher priority_score, is expected to survive.
TEST(RankingTest, DeduplicatesConflictingActions) {
  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
                       /*text=*/"A-911"};
    annotation.entity = ClassificationResult("code", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"copy_code",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
}

// A zlib-compressed Lua ranking script is decompressed and run. The script
// returns the ids of every action whose type is not "text_reply", so only
// share_location and add_to_collection remain, in their original order.
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";
  RankingOptionsT options;
  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  // Guard against a missing compressor instead of crashing on Compress().
  ASSERT_TRUE(compressor != nullptr);
  compressor->Compress(test_snippet,
                       options.compressed_lua_ranking_script.get());
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  // A null decompressor would make the compressed script unusable.
  ASSERT_TRUE(decompressor != nullptr);
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure (e.g. the script did not run), the response
  // contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}

// With `suppress_smart_replies_with_actions` set, the text reply is dropped
// and only the call_phone action remains.
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});
  RankingOptionsT options;
  options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
}

// With `group_by_annotations` set, actions that share an annotation stay
// adjacent. The unannotated text reply is placed after them even though it
// outscores add_contact.
TEST(RankingTest, GroupsActionsByAnnotations) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  options.group_by_annotations = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  // The text reply should be last, even though it has a higher score than the
  // `add_contact` action.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("add_contact", "", 0.0),
                                 IsAction("text_reply", "How are you?", 0.5)}));
}

// With `group_by_annotations` disabled, actions are ordered purely by
// descending score, so the text reply (0.5) lands between call_phone (1.0)
// and add_contact (0.0).
TEST(RankingTest, SortsActionsByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing on a null ranker.
  ASSERT_TRUE(ranker != nullptr);

  // If ranking reports failure, the response contents are meaningless.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("text_reply", "How are you?", 0.5),
                                 IsAction("add_contact", "", 0.0)}));
}

}  // namespace
}  // namespace libtextclassifier3