// Plain text | 1292 lines | 49.52 KB (paste-site header line)

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "text-classifier.h"

#include <fstream>
#include <iostream>
#include <memory>
#include <string>

#include "model_generated.h"
#include "types-test-util.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier2 {
namespace {

using testing::ElementsAreArray;
using testing::IsEmpty;
using testing::Pair;
using testing::Values;

std::string FirstResult(const std::vector<ClassificationResult>& results) {
  if (results.empty()) {
    return "<INVALID RESULTS>";
  }
  return results[0].collection;
}

// Matches an annotation whose span is exactly (start, end) and whose first
// classification result has the collection |best_class|.
MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
  const bool span_matches = testing::Value(arg.span, Pair(start, end));
  const bool class_matches =
      testing::Value(FirstResult(arg.classification), best_class);
  return span_matches && class_matches;
}

// Reads the entire contents of |file_name| into a string, byte for byte.
// The file is opened in binary mode: the test models are flatbuffers, and
// text mode would let platforms that translate newlines (or treat ^Z as EOF)
// corrupt them. If the file cannot be opened, an empty string is returned.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}

// Returns the directory that holds the test model files. Callers append file
// names directly (e.g. GetModelPath() + "test_model.fb"), so
// LIBTEXTCLASSIFIER_TEST_DATA_DIR is expected to end with a path separator.
std::string GetModelPath() {
  const std::string test_data_dir(LIBTEXTCLASSIFIER_TEST_DATA_DIR);
  return test_data_dir;
}

TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) {
  CREATE_UNILIB_FOR_TESTING;
  // Creating a classifier from wrong_embeddings.fb must not succeed.
  const std::string model_path = GetModelPath() + "wrong_embeddings.fb";
  const std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(model_path, &unilib);
  EXPECT_TRUE(classifier == nullptr);
}

// Fixture for tests that run once per model file. The test parameter is the
// model file name, relative to GetModelPath().
class TextClassifierTest : public ::testing::TestWithParam<const char*> {};

// Each instantiation runs every TEST_P of this fixture against one model file.
INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest,
                        ::testing::Values("test_model_cc.fb"));
INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest,
                        ::testing::Values("test_model.fb"));

TEST_P(TextClassifierTest, ClassifyText) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // Collection of the first classification result for |span| in |text|.
  auto first_collection = [&classifier](const std::string& text,
                                        CodepointSpan span) {
    return FirstResult(classifier->ClassifyText(text, span));
  };
  const std::string kInvalid = "<INVALID RESULTS>";

  EXPECT_EQ("other",
            first_collection("this afternoon Barack Obama gave a speech at",
                             {15, 27}));
  EXPECT_EQ("phone",
            first_collection("Call me at (800) 123-456 today", {11, 24}));

  // More lines: a longer context whose segments are separated by '|'.
  const std::string multi_line_text =
      "this afternoon Barack Obama gave a speech at|Visit "
      "www.google.com every today!|Call me at (800) 123-456 today.";
  EXPECT_EQ("other", first_collection(multi_line_text, {15, 27}));
  EXPECT_EQ("phone", first_collection(multi_line_text, {90, 103}));

  // Single word.
  EXPECT_EQ("other", first_collection("obama", {0, 5}));
  EXPECT_EQ("other", first_collection("asdf", {0, 4}));
  EXPECT_EQ(kInvalid, first_collection("asdf", {0, 0}));

  // Junk.
  EXPECT_EQ(kInvalid, first_collection("", {0, 0}));
  EXPECT_EQ(kInvalid,
            first_collection("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5}));
  // Invalid UTF-8 input.
  EXPECT_EQ(kInvalid, first_collection("\xf0\x9f\x98\x8b\x8b", {0, 0}));
}

TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());

  // Enable only selection, and drop the classification model entirely.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  ModelTriggeringOptionsT& triggering = *model->triggering_options;
  triggering.enabled_modes = ModeFlag_SELECTION;
  model->classification_model.clear();

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  // Selection scoring still needs the classification model, so creation must
  // fail even though classification itself is not an enabled mode.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(packed_model, builder.GetSize(),
                                        &unilib);
  ASSERT_FALSE(classifier);
}

TEST_P(TextClassifierTest, ClassifyTextDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());

  // Enable only annotation and selection.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  ModelTriggeringOptionsT& triggering = *model->triggering_options;
  triggering.enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  const std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(packed_model, builder.GetSize(),
                                        &unilib);
  ASSERT_TRUE(classifier);

  // With classification not enabled, no results come back.
  const auto results =
      classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24});
  EXPECT_THAT(results, IsEmpty());
}

TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  const std::string phone_text = "Call me at (800) 123-456 today";
  const CodepointSpan phone_span = {11, 24};

  // Baseline: the unmodified model classifies the number as "phone".
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ("phone",
            FirstResult(classifier->ClassifyText(phone_text, phone_span)));

  // Rebuild the model with "phone" filtered out of classification.
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_classification.push_back(
      "phone");

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // The filtered collection is reported as "other"...
  EXPECT_EQ("other",
            FirstResult(classifier->ClassifyText(phone_text, phone_span)));
  // ...while the address classification still passes.
  EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
                           "350 Third Street, Cambridge", {0, 27})));
}

// Builds a regex pattern entry that assigns |collection_name| to text matching
// |pattern|. The entry is enabled for each mode whose flag is true, and
// |score| is used as both its target classification score and its priority
// score.
std::unique_ptr<RegexModel_::PatternT> MakePattern(
    const std::string& collection_name, const std::string& pattern,
    const bool enabled_for_classification, const bool enabled_for_selection,
    const bool enabled_for_annotation, const float score) {
  // The ModeFlag enum cannot be combined in place with |=, so the mask is
  // assembled as a plain int and converted back at the end.
  const int modes =
      static_cast<int>(ModeFlag_NONE) |
      (enabled_for_annotation ? static_cast<int>(ModeFlag_ANNOTATION) : 0) |
      (enabled_for_classification ? static_cast<int>(ModeFlag_CLASSIFICATION)
                                  : 0) |
      (enabled_for_selection ? static_cast<int>(ModeFlag_SELECTION) : 0);

  std::unique_ptr<RegexModel_::PatternT> entry(new RegexModel_::PatternT);
  entry->collection_name = collection_name;
  entry->pattern = pattern;
  entry->enabled_modes = static_cast<ModeFlag>(modes);
  entry->target_classification_score = score;
  entry->priority_score = score;
  return entry;
}

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, ClassifyTextRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add classification-only test regex patterns.
  auto& patterns = unpacked_model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", "Barack Obama", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Collection of the first classification result for |span| in |text|.
  auto top = [&classifier](const std::string& text, CodepointSpan span) {
    return FirstResult(classifier->ClassifyText(text, span));
  };

  EXPECT_EQ("flight",
            top("Your flight LX373 is delayed by 3 hours.", {12, 17}));
  EXPECT_EQ("person",
            top("this afternoon Barack Obama gave a speech at", {15, 27}));

  // Collections not added by this test.
  EXPECT_EQ("email", top("you@android.com", {0, 15}));
  EXPECT_EQ("email", top("Contact me at you@android.com", {14, 29}));
  EXPECT_EQ("url", top("Visit www.google.com every today!", {6, 20}));

  EXPECT_EQ("flight", top("LX 37", {0, 5}));
  EXPECT_EQ("flight", top("flight LX 37 abcd", {7, 12}));

  // More lines.
  EXPECT_EQ("url",
            top("this afternoon Barack Obama gave a speech at|Visit "
                "www.google.com every today!|Call me at (800) 123-456 today.",
                {51, 65}));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add selection-only test regex patterns.
  auto& patterns = unpacked_model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Give the flight pattern a priority above its target score.
  patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // A click inside a regex match is expanded to the matched span.
  const std::string flight_text = "Your flight MA 0123 is delayed by 3 hours.";
  const std::string person_text =
      "this afternoon Barack Obama gave a speech at";
  EXPECT_EQ(std::make_pair(12, 19),
            classifier->SuggestSelection(flight_text, {12, 14}));
  EXPECT_EQ(std::make_pair(15, 27),
            classifier->SuggestSelection(person_text, {15, 21}));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsModelWins) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Lower the flight pattern's priority so the model's selection should win.
  unpacked_model->regex_model->patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // Pass the test UniLib explicitly, like every other test in this file,
  // rather than omitting the |unilib| argument.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check conflict resolution: the click on "MA" (55-57) lies in both the
  // flight regex match and the model's wider selection (26, 62); with the
  // flight priority at 0.5 the model's selection is expected.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(26, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsRegexWins) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Raise the flight pattern's priority so the regex selection should win.
  unpacked_model->regex_model->patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // Pass the test UniLib explicitly, like every other test in this file,
  // rather than omitting the |unilib| argument.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check conflict resolution: with the flight priority at 1.1, the regex
  // match "MA 0123" (55, 62) is expected instead of the wider model selection.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(55, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, AnnotateRegex) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add annotation-only test regex patterns.
  auto& patterns = unpacked_model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  const auto expected_spans = ElementsAreArray({
      IsAnnotatedSpan(6, 18, "person"),
      IsAnnotatedSpan(19, 24, "date"),
      IsAnnotatedSpan(28, 55, "address"),
      IsAnnotatedSpan(79, 91, "phone"),
  });
  EXPECT_THAT(classifier->Annotate(test_string), expected_spans);
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

TEST_P(TextClassifierTest, PhoneFiltering) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  auto top = [&classifier](const std::string& text, CodepointSpan span) {
    return FirstResult(classifier->ClassifyText(text, span));
  };
  const std::string long_text = "phone: (123) 456 789,0001112";

  EXPECT_EQ("phone", top("phone: (123) 456 789", {7, 20}));
  // Part of the longer digit run is still a phone number...
  EXPECT_EQ("phone", top(long_text, {7, 25}));
  // ...but the whole run up to the end is not.
  EXPECT_EQ("other", top(long_text, {7, 28}));
}

TEST_P(TextClassifierTest, SuggestSelection) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  auto suggest = [&classifier](const std::string& text, CodepointSpan click) {
    return classifier->SuggestSelection(text, click);
  };

  EXPECT_EQ(suggest("this afternoon Barack Obama gave a speech at", {15, 21}),
            std::make_pair(15, 21));

  // Try passing the whole string: when more than one token is selected, the
  // input selection is returned unchanged.
  EXPECT_EQ(suggest("350 Third Street, Cambridge", {0, 27}),
            std::make_pair(0, 27));

  // Single letter.
  EXPECT_EQ(suggest("a", {0, 1}), std::make_pair(0, 1));

  // Single word.
  EXPECT_EQ(suggest("asdf", {0, 4}), std::make_pair(0, 4));

  EXPECT_EQ(suggest("call me at 857 225 3556 today", {11, 14}),
            std::make_pair(11, 23));

  // Unpaired bracket stripping.
  EXPECT_EQ(suggest("call me at (857) 225 3556 today", {11, 16}),
            std::make_pair(11, 25));
  EXPECT_EQ(suggest("call me at (857 today", {11, 15}),
            std::make_pair(12, 15));
  EXPECT_EQ(suggest("call me at 3556) today", {11, 16}),
            std::make_pair(11, 15));
  EXPECT_EQ(suggest("call me at )857( today", {11, 16}),
            std::make_pair(12, 15));

  // If the resulting selection would be empty, the original span is returned.
  EXPECT_EQ(suggest("call me at )( today", {11, 13}), std::make_pair(11, 13));
  EXPECT_EQ(suggest("call me at ( today", {11, 12}), std::make_pair(11, 12));
  EXPECT_EQ(suggest("call me at ) today", {11, 12}), std::make_pair(11, 12));
}

TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());

  // Enable only annotation, and drop the selection model entirely.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  ModelTriggeringOptionsT& triggering = *model->triggering_options;
  triggering.enabled_modes = ModeFlag_ANNOTATION;
  model->selection_model.clear();

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  // Annotation needs the selection model, so creation must fail.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(packed_model, builder.GetSize(),
                                        &unilib);
  ASSERT_FALSE(classifier);
}

TEST_P(TextClassifierTest, SuggestSelectionDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());

  // Drop the selection model and restrict both the model and its triggering
  // options to classification only.
  model->selection_model.clear();
  model->enabled_modes = ModeFlag_CLASSIFICATION;
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Selection leaves the click untouched.
  EXPECT_EQ(std::make_pair(11, 14),
            classifier->SuggestSelection("call me at 857 225 3556 today",
                                         {11, 14}));

  // Classification still works.
  const std::string phone_text = "call me at (800) 123-456 today";
  EXPECT_EQ("phone",
            FirstResult(classifier->ClassifyText(phone_text, {11, 24})));

  // Annotation produces nothing.
  EXPECT_THAT(classifier->Annotate(phone_text), IsEmpty());
}

TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  const std::string phone_text = "call me at 857 225 3556 today";
  const CodepointSpan phone_click = {11, 14};

  // Baseline: the unmodified model expands the click to the whole number.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(classifier->SuggestSelection(phone_text, phone_click),
            std::make_pair(11, 23));

  // Rebuild the model with "phone" filtered out of selection.
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_selection.push_back("phone");
  // Filtering needs the suggested selection to be classified, so force it.
  model->selection_options->always_classify_suggested_selection = true;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // The phone selection is suppressed: the click comes back unchanged.
  EXPECT_EQ(classifier->SuggestSelection(phone_text, phone_click),
            std::make_pair(11, 14));

  // Address selection should still work.
  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
            std::make_pair(0, 27));
}

TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // Clicking on any token of the address selects the whole address.
  const std::string address = "350 Third Street, Cambridge";
  for (const CodepointSpan& click :
       {CodepointSpan(0, 3), CodepointSpan(4, 9), CodepointSpan(10, 16)}) {
    EXPECT_EQ(classifier->SuggestSelection(address, click),
              std::make_pair(0, 27));
  }

  // Same, with preceding lines of unrelated text.
  EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
                                         {16, 22}),
            std::make_pair(6, 33));
}

TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // The number is selected up to, but not across, the newline.
  EXPECT_EQ(std::make_pair(4, 16),
            classifier->SuggestSelection("abc\n857 225 3556", {4, 7}));
  EXPECT_EQ(std::make_pair(0, 12),
            classifier->SuggestSelection("857 225 3556\nabc", {0, 3}));

  // Same with explicitly passed default-constructed options.
  SelectionOptions options;
  EXPECT_EQ(std::make_pair(0, 7),
            classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options));
}

TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // In every case the click covers just the word next to the punctuation, and
  // the suggested selection is expected to equal the click.
  struct PunctuationCase {
    const char* text;
    CodepointSpan click;
  };
  const PunctuationCase cases[] = {
      // From the right.
      {"this afternoon BarackObama, gave a speech at", {15, 26}},
      // From the right multiple.
      {"this afternoon BarackObama,.,.,, gave a speech at", {15, 26}},
      // From the left multiple.
      {"this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}},
      // From both sides.
      {"this afternoon !BarackObama,- gave a speech at", {16, 27}},
  };
  for (const PunctuationCase& test_case : cases) {
    EXPECT_EQ(classifier->SuggestSelection(test_case.text, test_case.click),
              test_case.click);
  }
}

TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // Invalid selections (and invalid UTF-8) are returned unchanged.
  struct JunkCase {
    const char* text;
    CodepointSpan selection;
  };
  const JunkCase cases[] = {
      {"", {0, 27}},
      {"", {-10, 27}},
      {"Word 1 2 3 hello!", {0, 27}},
      {"Word 1 2 3 hello!", {-30, 300}},
      {"Word 1 2 3 hello!", {-10, -1}},
      {"Word 1 2 3 hello!", {100, 17}},
      // Invalid UTF-8.
      {"\xf0\x9f\x98\x8b\x8b", {-1, -1}},
  };
  for (const JunkCase& junk : cases) {
    EXPECT_EQ(classifier->SuggestSelection(junk.text, junk.selection),
              junk.selection);
  }
}

TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  auto suggest = [&classifier](const std::string& text, CodepointSpan click) {
    return classifier->SuggestSelection(text, click);
  };
  const std::string phone_text = "call me at 857 225 3556 today";

  // A space inside the number expands to the whole number...
  EXPECT_EQ(suggest(phone_text, {14, 15}), std::make_pair(11, 23));
  // ...while the spaces just outside it are left as they are.
  EXPECT_EQ(suggest(phone_text, {10, 11}), std::make_pair(10, 11));
  EXPECT_EQ(suggest(phone_text, {23, 24}), std::make_pair(23, 24));
  EXPECT_EQ(suggest("call me at 857 225 3556, today", {23, 24}),
            std::make_pair(23, 24));
  EXPECT_EQ(suggest("call me at 857   225 3556, today", {14, 17}),
            std::make_pair(11, 25));
  EXPECT_EQ(suggest("call me at 857-225 3556, today", {14, 17}),
            std::make_pair(11, 23));
  EXPECT_EQ(suggest("let's meet at 350 Third Street Cambridge and go there",
                    {30, 31}),
            std::make_pair(14, 40));
  EXPECT_EQ(suggest("call me today", {4, 5}), std::make_pair(4, 5));
  EXPECT_EQ(suggest("call me today", {7, 8}), std::make_pair(7, 8));

  // With a punctuation around the selected whitespace.
  EXPECT_EQ(suggest("let's meet at 350 Third Street, Cambridge and go there",
                    {31, 32}),
            std::make_pair(14, 41));

  // When all's whitespace, should return the original indices.
  const std::string spaces = "      ";
  for (const CodepointSpan& click :
       {CodepointSpan(0, 1), CodepointSpan(0, 3), CodepointSpan(2, 3),
        CodepointSpan(5, 6)}) {
    EXPECT_EQ(suggest(spaces, click), click);
  }
}

TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) {
  CREATE_UNILIB_FOR_TESTING;

  // Runs SnapLeftIfWhitespaceSelection on |span| within the UTF-8 literal
  // |utf8|. The UnicodeText does not copy the bytes, so |utf8| must outlive
  // the call; string literals do.
  auto snap = [&](const char* utf8, CodepointSpan span) {
    UnicodeText text = UTF8ToUnicodeText(utf8, /*do_copy=*/false);
    return internal::SnapLeftIfWhitespaceSelection(span, text, unilib);
  };

  EXPECT_EQ(snap("abcd efgh", {4, 5}), std::make_pair(3, 4));
  EXPECT_EQ(snap("abcd     ", {4, 5}), std::make_pair(3, 4));

  // Nothing on the left.
  EXPECT_EQ(snap("     efgh", {4, 5}), std::make_pair(4, 5));
  EXPECT_EQ(snap("     efgh", {0, 1}), std::make_pair(0, 1));

  // Whitespace only.
  EXPECT_EQ(snap("     ", {2, 3}), std::make_pair(2, 3));
  EXPECT_EQ(snap("     ", {4, 5}), std::make_pair(4, 5));
  EXPECT_EQ(snap("     ", {0, 1}), std::make_pair(0, 1));
}

TEST_P(TextClassifierTest, Annotate) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // The expected spans are built outside EXPECT_THAT on purpose: a
  // preprocessor directive (#ifdef) inside macro arguments is undefined
  // behavior in standard C++ and is rejected by some compilers.
  const auto expected_spans = ElementsAreArray({
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
      IsAnnotatedSpan(19, 24, "date"),
#endif
      IsAnnotatedSpan(28, 55, "address"),
      IsAnnotatedSpan(79, 91, "phone"),
  });
  EXPECT_THAT(classifier->Annotate(test_string), expected_spans);

  AnnotationOptions options;
  EXPECT_THAT(classifier->Annotate("853 225 3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());

  // Try passing invalid utf8.
  EXPECT_TRUE(
      classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
          .empty());
}

TEST_P(TextClassifierTest, AnnotateSmallBatches) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Set the batch size.
  unpacked_model->selection_options->batch_size = 4;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // The expected spans are built outside EXPECT_THAT on purpose: a
  // preprocessor directive (#ifdef) inside macro arguments is undefined
  // behavior in standard C++ and is rejected by some compilers.
  const auto expected_spans = ElementsAreArray({
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
      IsAnnotatedSpan(19, 24, "date"),
#endif
      IsAnnotatedSpan(28, 55, "address"),
      IsAnnotatedSpan(79, 91, "phone"),
  });
  EXPECT_THAT(classifier->Annotate(test_string), expected_spans);

  AnnotationOptions options;
  EXPECT_THAT(classifier->Annotate("853 225 3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
}

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  // Add test threshold. A minimum annotation confidence of 2 is presumably
  // above any score the model can produce, so its annotations are dropped.
  unpacked_model->triggering_options->min_annotate_confidence = 2.f;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // NOTE(review): despite the threshold, exactly one annotation survives.
  // This test only builds with ICU, where the other annotation tests also
  // report a "date" span; the survivor is presumably that one — confirm.
  // The unsigned literal avoids a signed/unsigned comparison with size().
  EXPECT_EQ(classifier->Annotate(test_string).size(), 1u);
}
#endif

// With min_annotate_confidence set to 0 and every mode enabled, all
// annotations the model finds in the test string must be kept.
TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test thresholds.
  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  unpacked_model->triggering_options->min_annotate_confidence =
      0.f;  // Keeps all results.
  unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // SizeIs() avoids comparing the unsigned size() with a signed literal.
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
  EXPECT_THAT(classifier->Annotate(test_string), testing::SizeIs(3));
#else
  // In non-ICU mode there is no "date" result.
  EXPECT_THAT(classifier->Annotate(test_string), testing::SizeIs(2));
#endif
}

// When the model's enabled modes exclude annotation, Annotate() must return
// nothing, even for text that normally produces several annotations.
TEST_P(TextClassifierTest, AnnotateDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string model_buffer = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Keep classification and selection, drop annotation.
  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(classifier->Annotate(text), IsEmpty());
}

// Collections listed in output_options.filtered_collections_annotation must be
// dropped from Annotate() results while the other collections are kept.
TEST_P(TextClassifierTest, AnnotateFilteredCollections) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string model_buffer = ReadFile(GetModelPath() + GetParam());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline: the unmodified model also annotates the phone number.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(model_buffer.c_str(),
                                        model_buffer.size(), &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_THAT(classifier->Annotate(text),
              ElementsAreArray({
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
                  IsAnnotatedSpan(19, 24, "date"),
#endif
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // Rebuild the model with the "phone" collection filtered out of annotation.
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_annotation.push_back("phone");

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));

  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_THAT(classifier->Annotate(text),
              ElementsAreArray({
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
                  IsAnnotatedSpan(19, 24, "date"),
#endif
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// A filtered collection that wins a conflict must still suppress the span it
// won: the higher-priority "suppress" pattern beats the phone annotation and
// is then filtered out, so neither annotation appears.
TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // The whole test is already ICU-only, so "date" is always expected here; no
  // nested #ifdef is needed.
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // We add a custom annotator that wins against the phone classification
  // below and that we subsequently suppress.
  unpacked_model->output_options->filtered_collections_annotation.push_back(
      "suppress");

  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// Dates are classified as "date", and the parsed timestamp is interpreted in
// ClassificationOptions::reference_timezone.
TEST_P(TextClassifierTest, ClassifyTextDate) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // ASSERT rather than EXPECT: every check below dereferences `classifier`.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  // 2017-01-01 00:00 in Zurich (UTC+1) == 2016-12-31T23:00Z.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // 2017-03-01 00:00 in Los Angeles (UTC-8) == 2017-03-01T08:00Z.
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // 2018-01-01 10:30:20 in Los Angeles (UTC-8) == 2018-01-01T18:30:20Z.
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);

  // Date on another line.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU

#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// The same ambiguous string "03.05.1970" is parsed according to the requested
// locale: month-first for en-US, day-first for de.
TEST_P(TextClassifierTest, ClassifyTextDatePriorities) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // ASSERT rather than EXPECT: every check below dereferences `classifier`.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  // en-US: March 5, 1970, 00:00 in Zurich (UTC+1).
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);

  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // de: May 3, 1970, 00:00 in Zurich (UTC+1).
  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);

  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU

#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// With selection removed from every datetime pattern's enabled modes,
// SuggestSelection() leaves the selection unchanged. Classification and
// annotation still recognize the date.
TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection. A range-for avoids comparing a signed
  // index with the unsigned size().
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            std::make_pair(0, 7));
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU

// Test-only subclass that makes the non-public ResolveConflicts() callable so
// conflict resolution can be tested directly.
class TestingTextClassifier : public TextClassifier {
 public:
  // Builds the classifier on a view of `model_buffer` (via ViewModel on its
  // data/size). NOTE(review): the buffer is not copied here, so it presumably
  // must outlive any use of the model. The tests below pass "" (a temporary);
  // that seems safe only because an empty buffer yields no usable model --
  // confirm.
  TestingTextClassifier(const std::string& model_buffer, const UniLib* unilib)
      : TextClassifier(ViewModel(model_buffer.data(), model_buffer.size()),
                       unilib) {}

  // Re-export the base-class member with public access.
  using TextClassifier::ResolveConflicts;
};

// Returns an AnnotatedSpan covering `span` whose classification list holds a
// single result for `collection` with the given `score`.
AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
                                const std::string& collection,
                                const float score) {
  AnnotatedSpan annotated;
  annotated.span = span;
  const ClassificationResult only_result{collection, score};
  annotated.classification.push_back(only_result);
  return annotated;
}

// A lone candidate has nothing to conflict with, so it is always chosen.
TEST(TextClassifierTest, ResolveConflictsTrivial) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0)};

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0}));
}

// Back-to-back spans [i, i + 1) that only touch each other are all kept.
TEST(TextClassifierTest, ResolveConflictsSequence) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates;
  for (int begin = 0; begin < 5; ++begin) {
    candidates.push_back(MakeAnnotatedSpan({begin, begin + 1}, "phone", 1.0));
  }

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}

// The middle span overlaps both neighbors and has the lowest score, so it is
// dropped. Spans 0 and 2 only touch at offset 3 and are both kept.
TEST(TextClassifierTest, ResolveConflictsThreeSpans) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates{
      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
  };

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}

// The middle span has the highest score and overlaps both neighbors, so it is
// the only one kept.
TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates{
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
  };

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({1}));
}

// Two independent conflict groups: spans 0-2 overlap each other, and spans 3-4
// overlap each other.
TEST(TextClassifierTest, ResolveConflictsFiveSpans) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates{
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
      // NOTE(review): loses despite the highest score; presumably the "other"
      // collection is deprioritized by ResolveConflicts -- confirm.
      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
  };

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Each entity is placed between two 50000-space paddings. It must still be
// annotated, selected and classified as its expected collection.
TEST_P(TextClassifierTest, LongInput) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // Number of spaces on each side of the entity.
  const int kPadding = 50000;
  const std::string padding(kPadding, ' ');

  // (expected collection, entity text) pairs.
  const std::vector<std::pair<std::string, std::string>> cases = {
      {"address", "350 Third Street, Cambridge"},
      {"phone", "123 456-7890"},
      {"url", "www.google.com"},
      {"email", "someone@gmail.com"},
      {"flight", "LX 38"},
      {"date", "September 1, 2018"}};

  for (const auto& entry : cases) {
    const std::string& collection = entry.first;
    const std::string& value = entry.second;
    const std::string input_100k = padding + value + padding;
    const int value_end = kPadding + static_cast<int>(value.size());

    EXPECT_THAT(classifier->Annotate(input_100k),
                ElementsAreArray(
                    {IsAnnotatedSpan(kPadding, value_end, collection)}));
    EXPECT_EQ(classifier->SuggestSelection(input_100k, {kPadding, kPadding + 1}),
              std::make_pair(kPadding, value_end));
    EXPECT_EQ(collection, FirstResult(classifier->ClassifyText(
                              input_100k, {kPadding, value_end})));
  }
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Coarse smoke test: the results are deliberately not checked. It only
// exercises Annotate/SuggestSelection/ClassifyText on ~100k characters of
// input so that pathological slowdowns show up as a slow test.
TEST_P(TextClassifierTest, LongInputNoResultCheck) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::vector<std::string> values = {
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};
  const std::string padding(50000, ' ');

  for (const std::string& value : values) {
    const std::string input_100k = padding + value + padding;
    const int value_end = 50000 + static_cast<int>(value.size());

    classifier->Annotate(input_100k);
    classifier->SuggestSelection(input_100k, {50000, 50001});
    classifier->ClassifyText(input_100k, {50000, value_end});
  }
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// A selection with more tokens than classification_options.max_num_tokens
// must not get its normal classification.
TEST_P(TextClassifierTest, MaxTokenLength) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  std::unique_ptr<TextClassifier> classifier;

  // With an unrestricted number of tokens (-1), classification behaves
  // normally.
  unpacked_model->classification_options->max_num_tokens = -1;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "address");

  // Lower the maximum number of tokens (from unrestricted to 3) so the
  // address selection exceeds it and the classification is suppressed.
  unpacked_model->classification_options->max_num_tokens = 3;

  flatbuffers::FlatBufferBuilder builder2;
  builder2.Finish(Model::Pack(builder2, unpacked_model.get()));
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
      builder2.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "other");
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Raising classification_options.address_min_num_tokens high enough must
// suppress the "address" classification of a short address.
TEST_P(TextClassifierTest, MinAddressTokenLength) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string model_buffer = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  const std::string text = "I live at 350 Third Street, Cambridge.";
  const CodepointSpan address_span = {10, 37};

  // A minimum of 0 address tokens is unrestricted: normal classification.
  model->classification_options->address_min_num_tokens = 0;
  flatbuffers::FlatBufferBuilder unrestricted_fbb;
  unrestricted_fbb.Finish(Model::Pack(unrestricted_fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(unrestricted_fbb.GetBufferPointer()),
          unrestricted_fbb.GetSize(), &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(text, address_span)),
            "address");

  // Require 5 address tokens, which suppresses the address classification.
  model->classification_options->address_min_num_tokens = 5;
  flatbuffers::FlatBufferBuilder restricted_fbb;
  restricted_fbb.Finish(Model::Pack(restricted_fbb, model.get()));
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(restricted_fbb.GetBufferPointer()),
      restricted_fbb.GetSize(), &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(text, address_span)),
            "other");
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU

}  // namespace
}  // namespace libtextclassifier2