#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

#include <marisa.h>

#include "assert.h"
namespace {
class FindCallback {
public:
FindCallback(std::vector<marisa::UInt32> *key_ids,
std::vector<std::size_t> *key_lengths)
: key_ids_(key_ids), key_lengths_(key_lengths) {}
FindCallback(const FindCallback &callback)
: key_ids_(callback.key_ids_), key_lengths_(callback.key_lengths_) {}
bool operator()(marisa::UInt32 key_id, std::size_t key_length) const {
key_ids_->push_back(key_id);
key_lengths_->push_back(key_length);
return true;
}
private:
std::vector<marisa::UInt32> *key_ids_;
std::vector<std::size_t> *key_lengths_;
// Disallows assignment.
FindCallback &operator=(const FindCallback &);
};
class PredictCallback {
public:
PredictCallback(std::vector<marisa::UInt32> *key_ids,
std::vector<std::string> *keys)
: key_ids_(key_ids), keys_(keys) {}
PredictCallback(const PredictCallback &callback)
: key_ids_(callback.key_ids_), keys_(callback.keys_) {}
bool operator()(marisa::UInt32 key_id, const std::string &key) const {
key_ids_->push_back(key_id);
keys_->push_back(key);
return true;
}
private:
std::vector<marisa::UInt32> *key_ids_;
std::vector<std::string> *keys_;
// Disallows assignment.
PredictCallback &operator=(const PredictCallback &);
};
// Exercises a single-level trie:
// - the state of a default-constructed trie and after clear();
// - building from an empty key set and from a set containing a duplicate;
// - ID <-> key lookup and restore() under MARISA_LABEL_ORDER and
//   MARISA_WEIGHT_ORDER;
// - exact lookup, find_first/find_last/find/find_callback;
// - predict() in its counting, ID-collecting and ID+key-collecting forms.
void TestTrie() {
  TEST_START();

  // A default-constructed trie holds nothing.
  marisa::Trie trie;
  ASSERT(trie.num_tries() == 0);
  ASSERT(trie.num_keys() == 0);
  ASSERT(trie.num_nodes() == 0);
  ASSERT(trie.total_size() == (sizeof(marisa::UInt32) * 23));

  // Building from zero keys still produces one trie with a single node.
  std::vector<std::string> keys;
  trie.build(keys);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 0);
  ASSERT(trie.num_nodes() == 1);

  // "apple" is listed twice: 5 inputs collapse to 4 distinct keys, but
  // key_ids still receives one ID per input (both "apple"s share ID 3).
  keys.push_back("apple");
  keys.push_back("and");
  keys.push_back("Bad");
  keys.push_back("apple");
  keys.push_back("app");
  std::vector<marisa::UInt32> key_ids;
  trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL | MARISA_LABEL_ORDER);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 11);
  ASSERT(key_ids.size() == 5);
  ASSERT(key_ids[0] == 3);
  ASSERT(key_ids[1] == 1);
  ASSERT(key_ids[2] == 0);
  ASSERT(key_ids[3] == 3);
  ASSERT(key_ids[4] == 2);

  // Every input key must round-trip through both lookup directions and
  // through restore() into a caller-provided buffer.
  char key_buf[256];
  std::size_t key_length;
  for (std::size_t i = 0; i < keys.size(); ++i) {
    key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
    ASSERT(key_length == keys[i].length());
    ASSERT(keys[i] == key_buf);
  }

  // clear() returns the trie to its default-constructed state.
  trie.clear();
  ASSERT(trie.num_tries() == 0);
  ASSERT(trie.num_keys() == 0);
  ASSERT(trie.num_nodes() == 0);
  ASSERT(trie.total_size() == (sizeof(marisa::UInt32) * 23));

  // Same keys with MARISA_WEIGHT_ORDER: same shape, different ID assignment.
  trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 11);
  ASSERT(key_ids.size() == 5);
  ASSERT(key_ids[0] == 3);
  ASSERT(key_ids[1] == 1);
  ASSERT(key_ids[2] == 2);
  ASSERT(key_ids[3] == 3);
  ASSERT(key_ids[4] == 0);
  for (std::size_t i = 0; i < keys.size(); ++i) {
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
  }

  // Exact lookup misses on strict prefixes and extensions of stored keys.
  ASSERT(trie["appl"] == trie.notfound());
  ASSERT(trie["applex"] == trie.notfound());

  // find_first/find_last report the shortest/longest stored key that is a
  // prefix of the query.
  ASSERT(trie.find_first("ap") == trie.notfound());
  ASSERT(trie.find_first("applex") == trie["app"]);
  ASSERT(trie.find_last("ap") == trie.notfound());
  ASSERT(trie.find_last("applex") == trie["apple"]);

  // find() appends the IDs (and optionally lengths) of every stored prefix.
  std::vector<marisa::UInt32> ids;
  ASSERT(trie.find("ap", &ids) == 0);
  ASSERT(trie.find("applex", &ids) == 2);
  ASSERT(ids.size() == 2);
  ASSERT(ids[0] == trie["app"]);
  ASSERT(ids[1] == trie["apple"]);
  std::vector<std::size_t> lengths;
  ASSERT(trie.find("Baddie", &ids, &lengths) == 1);
  ASSERT(ids.size() == 3);
  ASSERT(ids[2] == trie["Bad"]);
  ASSERT(lengths.size() == 1);
  ASSERT(lengths[0] == 3);
  ASSERT(trie.find_callback("anderson", FindCallback(&ids, &lengths)) == 1);
  ASSERT(ids.size() == 4);
  ASSERT(ids[3] == trie["and"]);
  ASSERT(lengths.size() == 2);
  ASSERT(lengths[1] == 3);

  // predict() counts the stored keys that start with the query.
  ASSERT(trie.predict("") == 4);
  ASSERT(trie.predict("a") == 3);
  ASSERT(trie.predict("ap") == 2);
  ASSERT(trie.predict("app") == 2);
  ASSERT(trie.predict("appl") == 1);
  ASSERT(trie.predict("apple") == 1);
  ASSERT(trie.predict("appleX") == 0);
  ASSERT(trie.predict("X") == 0);

  ids.clear();
  ASSERT(trie.predict("a", &ids) == 3);
  ASSERT(ids.size() == 3);
  ASSERT(ids[0] == trie["app"]);
  ASSERT(ids[1] == trie["and"]);
  ASSERT(ids[2] == trie["apple"]);

  // Requesting keys as well yields them in a different order than above.
  std::vector<std::string> strs;
  ASSERT(trie.predict("a", &ids, &strs) == 3);
  ASSERT(ids.size() == 6);
  ASSERT(ids[3] == trie["app"]);
  ASSERT(ids[4] == trie["apple"]);
  ASSERT(ids[5] == trie["and"]);
  // Check the size before indexing, so a short result fails the assertion
  // instead of reading past the end of the vector (undefined behavior).
  ASSERT(strs.size() == 3);
  ASSERT(strs[0] == "app");
  ASSERT(strs[1] == "apple");
  ASSERT(strs[2] == "and");
  TEST_END();
}
void TestPrefixTrie() {
TEST_START();
std::vector<std::string> keys;
keys.push_back("after");
keys.push_back("bar");
keys.push_back("car");
keys.push_back("caster");
marisa::Trie trie;
std::vector<marisa::UInt32> key_ids;
trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE
| MARISA_TEXT_TAIL | MARISA_LABEL_ORDER);
ASSERT(trie.num_tries() == 1);
ASSERT(trie.num_keys() == 4);
ASSERT(trie.num_nodes() == 7);
char key_buf[256];
std::size_t key_length;
for (std::size_t i = 0; i < keys.size(); ++i) {
key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
ASSERT(trie[keys[i]] == key_ids[i]);
ASSERT(trie[key_ids[i]] == keys[i]);
ASSERT(key_length == keys[i].length());
ASSERT(keys[i] == key_buf);
}
key_length = trie.restore(key_ids[0], NULL, 0);
ASSERT(key_length == keys[0].length());
EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_PARAM_ERROR);
key_length = trie.restore(key_ids[0], key_buf, 5);
ASSERT(key_length == keys[0].length());
key_length = trie.restore(key_ids[0], key_buf, 6);
ASSERT(key_length == keys[0].length());
trie.build(keys, &key_ids, 2 | MARISA_PREFIX_TRIE
| MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER);
ASSERT(trie.num_tries() == 2);
ASSERT(trie.num_keys() == 4);
ASSERT(trie.num_nodes() == 16);
for (std::size_t i = 0; i < keys.size(); ++i) {
key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
ASSERT(trie[keys[i]] == key_ids[i]);
ASSERT(trie[key_ids[i]] == keys[i]);
ASSERT(key_length == keys[i].length());
ASSERT(keys[i] == key_buf);
}
key_length = trie.restore(key_ids[0], NULL, 0);
ASSERT(key_length == keys[0].length());
EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_PARAM_ERROR);
key_length = trie.restore(key_ids[0], key_buf, 5);
ASSERT(key_length == keys[0].length());
key_length = trie.restore(key_ids[0], key_buf, 6);
ASSERT(key_length == keys[0].length());
trie.build(keys, &key_ids, 2 | MARISA_PREFIX_TRIE
| MARISA_TEXT_TAIL | MARISA_LABEL_ORDER);
ASSERT(trie.num_tries() == 2);
ASSERT(trie.num_keys() == 4);
ASSERT(trie.num_nodes() == 14);
for (std::size_t i = 0; i < keys.size(); ++i) {
key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
ASSERT(trie[keys[i]] == key_ids[i]);
ASSERT(trie[key_ids[i]] == keys[i]);
ASSERT(key_length == keys[i].length());
ASSERT(keys[i] == key_buf);
}
trie.save("trie-test.dat");
trie.clear();
marisa::Mapper mapper;
trie.mmap(&mapper, "trie-test.dat");
ASSERT(mapper.is_open());
ASSERT(trie.num_tries() == 2);
ASSERT(trie.num_keys() == 4);
ASSERT(trie.num_nodes() == 14);
for (std::size_t i = 0; i < keys.size(); ++i) {
key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
ASSERT(trie[keys[i]] == key_ids[i]);
ASSERT(trie[key_ids[i]] == keys[i]);
ASSERT(key_length == keys[i].length());
ASSERT(keys[i] == key_buf);
}
std::stringstream stream;
trie.write(stream);
trie.clear();
trie.read(stream);
ASSERT(trie.num_tries() == 2);
ASSERT(trie.num_keys() == 4);
ASSERT(trie.num_nodes() == 14);
for (std::size_t i = 0; i < keys.size(); ++i) {
key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
ASSERT(trie[keys[i]] == key_ids[i]);
ASSERT(trie[key_ids[i]] == keys[i]);
ASSERT(key_length == keys[i].length());
ASSERT(keys[i] == key_buf);
}
trie.build(keys, &key_ids, 3 | MARISA_PREFIX_TRIE
| MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER);
ASSERT(trie.num_tries() == 3);
ASSERT(trie.num_keys() == 4);
ASSERT(trie.num_nodes() == 19);
for (std::size_t i = 0; i < keys.size(); ++i) {
key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
ASSERT(trie[keys[i]] == key_ids[i]);
ASSERT(trie[key_ids[i]] == keys[i]);
ASSERT(key_length == keys[i].length());
ASSERT(keys[i] == key_buf);
}
ASSERT(trie["ca"] == trie.notfound());
ASSERT(trie["card"] == trie.notfound());
std::size_t length = 0;
ASSERT(trie.find_first("ca") == trie.notfound());
ASSERT(trie.find_first("car") == trie["car"]);
ASSERT(trie.find_first("card", &length) == trie["car"]);
ASSERT(length == 3);
ASSERT(trie.find_last("afte") == trie.notfound());
ASSERT(trie.find_last("after") == trie["after"]);
ASSERT(trie.find_last("afternoon", &length) == trie["after"]);
ASSERT(length == 5);
{
std::vector<marisa::UInt32> ids;
std::vector<std::size_t> lengths;
ASSERT(trie.find("card", &ids, &lengths) == 1);
ASSERT(ids.size() == 1);
ASSERT(ids[0] == trie["car"]);
ASSERT(lengths.size() == 1);
ASSERT(lengths[0] == 3);
ASSERT(trie.predict("ca", &ids) == 2);
ASSERT(ids.size() == 3);
ASSERT(ids[1] == trie["car"]);
ASSERT(ids[2] == trie["caster"]);
ASSERT(trie.predict("ca", &ids, NULL, 1) == 1);
ASSERT(ids.size() == 4);
ASSERT(ids[3] == trie["car"]);
std::vector<std::string> strs;
ASSERT(trie.predict("ca", &ids, &strs, 1) == 1);
ASSERT(ids.size() == 5);
ASSERT(ids[4] == trie["car"]);
ASSERT(strs.size() == 1);
ASSERT(strs[0] == "car");
ASSERT(trie.predict_callback("", PredictCallback(&ids, &strs)) == 4);
ASSERT(ids.size() == 9);
ASSERT(ids[5] == trie["car"]);
ASSERT(ids[6] == trie["caster"]);
ASSERT(ids[7] == trie["after"]);
ASSERT(ids[8] == trie["bar"]);
ASSERT(strs.size() == 5);
ASSERT(strs[1] == "car");
ASSERT(strs[2] == "caster");
ASSERT(strs[3] == "after");
ASSERT(strs[4] == "bar");
}
{
marisa::UInt32 ids[10];
std::size_t lengths[10];
ASSERT(trie.find("card", ids, lengths, 10) == 1);
ASSERT(ids[0] == trie["car"]);
ASSERT(lengths[0] == 3);
ASSERT(trie.predict("ca", ids, NULL, 10) == 2);
ASSERT(ids[0] == trie["car"]);
ASSERT(ids[1] == trie["caster"]);
ASSERT(trie.predict("ca", ids, NULL, 1) == 1);
ASSERT(ids[0] == trie["car"]);
std::string strs[10];
ASSERT(trie.predict("ca", ids, strs, 1) == 1);
ASSERT(ids[0] == trie["car"]);
ASSERT(strs[0] == "car");
ASSERT(trie.predict("", ids, strs, 10) == 4);
ASSERT(ids[0] == trie["car"]);
ASSERT(ids[1] == trie["caster"]);
ASSERT(ids[2] == trie["after"]);
ASSERT(ids[3] == trie["bar"]);
ASSERT(strs[0] == "car");
ASSERT(strs[1] == "caster");
ASSERT(strs[2] == "after");
ASSERT(strs[3] == "bar");
}
TEST_END();
}
// Exercises builds without MARISA_PREFIX_TRIE, with 1-3 levels and with or
// without MARISA_WITHOUT_TAIL. The input contains a duplicate key ("check"),
// so 5 inputs map to 4 distinct keys. Every build is checked for key
// round-tripping, and the final trie is also round-tripped through
// write()/read().
void TestPatriciaTrie() {
  TEST_START();
  std::vector<std::string> keys;
  keys.push_back("bach");
  keys.push_back("bet");
  keys.push_back("chat");
  keys.push_back("check");
  keys.push_back("check");

  marisa::Trie trie;
  std::vector<marisa::UInt32> key_ids;
  trie.build(keys, &key_ids, 1);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 7);
  ASSERT(key_ids.size() == 5);
  ASSERT(key_ids[0] == 2);
  ASSERT(key_ids[1] == 3);
  ASSERT(key_ids[2] == 1);
  ASSERT(key_ids[3] == 0);
  ASSERT(key_ids[4] == 0);

  char key_buf[256];
  std::size_t key_length;
  for (std::size_t i = 0; i < keys.size(); ++i) {
    key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
    ASSERT(key_length == keys[i].length());
    ASSERT(keys[i] == key_buf);
  }

  // Each rebuild refills key_ids; re-check its size before the loops index
  // it, as the first build does.
  trie.build(keys, &key_ids, 2 | MARISA_WITHOUT_TAIL);
  ASSERT(trie.num_tries() == 2);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 17);
  ASSERT(key_ids.size() == 5);
  for (std::size_t i = 0; i < keys.size(); ++i) {
    key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
    ASSERT(key_length == keys[i].length());
    ASSERT(keys[i] == key_buf);
  }

  trie.build(keys, &key_ids, 2);
  ASSERT(trie.num_tries() == 2);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 14);
  ASSERT(key_ids.size() == 5);
  for (std::size_t i = 0; i < keys.size(); ++i) {
    key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
    ASSERT(key_length == keys[i].length());
    ASSERT(keys[i] == key_buf);
  }

  trie.build(keys, &key_ids, 3 | MARISA_WITHOUT_TAIL);
  ASSERT(trie.num_tries() == 3);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 20);
  ASSERT(key_ids.size() == 5);
  for (std::size_t i = 0; i < keys.size(); ++i) {
    key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
    ASSERT(key_length == keys[i].length());
    ASSERT(keys[i] == key_buf);
  }

  // Serializing and reloading must preserve shape and key IDs.
  std::stringstream stream;
  trie.write(stream);
  trie.clear();
  trie.read(stream);
  ASSERT(trie.num_tries() == 3);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 20);
  for (std::size_t i = 0; i < keys.size(); ++i) {
    key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    ASSERT(trie[keys[i]] == key_ids[i]);
    ASSERT(trie[key_ids[i]] == keys[i]);
    ASSERT(key_length == keys[i].length());
    ASSERT(keys[i] == key_buf);
  }
  TEST_END();
}
void TestEmptyString() {
TEST_START();
std::vector<std::string> keys;
keys.push_back("");
marisa::Trie trie;
std::vector<marisa::UInt32> key_ids;
trie.build(keys, &key_ids);
ASSERT(trie.num_tries() == 1);
ASSERT(trie.num_keys() == 1);
ASSERT(trie.num_nodes() == 1);
ASSERT(key_ids.size() == 1);
ASSERT(key_ids[0] == 0);
ASSERT(trie[""] == 0);
ASSERT(trie[(marisa::UInt32)0] == "");
ASSERT(trie["x"] == trie.notfound());
ASSERT(trie.find_first("") == 0);
ASSERT(trie.find_first("x") == 0);
ASSERT(trie.find_last("") == 0);
ASSERT(trie.find_last("x") == 0);
std::vector<marisa::UInt32> ids;
ASSERT(trie.find("xyz", &ids) == 1);
ASSERT(ids.size() == 1);
ASSERT(ids[0] == trie[""]);
std::vector<std::size_t> lengths;
ASSERT(trie.find("xyz", &ids, &lengths) == 1);
ASSERT(ids.size() == 2);
ASSERT(ids[0] == trie[""]);
ASSERT(ids[1] == trie[""]);
ASSERT(lengths.size() == 1);
ASSERT(lengths[0] == 0);
ASSERT(trie.find_callback("xyz", FindCallback(&ids, &lengths)) == 1);
ASSERT(ids.size() == 3);
ASSERT(ids[2] == trie[""]);
ASSERT(lengths.size() == 2);
ASSERT(lengths[1] == 0);
ASSERT(trie.predict("xyz", &ids) == 0);
ASSERT(trie.predict("", &ids) == 1);
ASSERT(ids.size() == 4);
ASSERT(ids[3] == trie[""]);
std::vector<std::string> strs;
ASSERT(trie.predict("", &ids, &strs) == 1);
ASSERT(ids.size() == 5);
ASSERT(ids[4] == trie[""]);
ASSERT(strs[0] == "");
TEST_END();
}
// Exercises a key containing an embedded NUL byte ("NP\0Trie"). The key is
// built without a tail, with a binary tail and with a text tail, and must
// round-trip intact each time. Restored keys are compared as
// std::string(buf, len) because the key is not NUL-terminated text.
void TestBinaryKey() {
  TEST_START();
  std::string binary_key = "NP";
  binary_key += '\0';
  binary_key += "Trie";
  std::vector<std::string> keys;
  keys.push_back(binary_key);

  marisa::Trie trie;
  std::vector<marisa::UInt32> key_ids;
  trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 1);
  ASSERT(trie.num_nodes() == 8);
  ASSERT(key_ids.size() == 1);
  char key_buf[256];
  std::size_t key_length;
  // Restore by the ID the build reported rather than assuming it is 0.
  key_length = trie.restore(key_ids[0], key_buf, sizeof(key_buf));
  ASSERT(trie[keys[0]] == key_ids[0]);
  ASSERT(trie[key_ids[0]] == keys[0]);
  ASSERT(std::string(key_buf, key_length) == keys[0]);

  trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_BINARY_TAIL);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 1);
  ASSERT(trie.num_nodes() == 2);
  ASSERT(key_ids.size() == 1);
  key_length = trie.restore(key_ids[0], key_buf, sizeof(key_buf));
  ASSERT(trie[keys[0]] == key_ids[0]);
  ASSERT(trie[key_ids[0]] == keys[0]);
  ASSERT(std::string(key_buf, key_length) == keys[0]);

  trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_TEXT_TAIL);
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 1);
  ASSERT(trie.num_nodes() == 2);
  ASSERT(key_ids.size() == 1);
  key_length = trie.restore(key_ids[0], key_buf, sizeof(key_buf));
  ASSERT(trie[keys[0]] == key_ids[0]);
  ASSERT(trie[key_ids[0]] == keys[0]);
  ASSERT(std::string(key_buf, key_length) == keys[0]);

  // Both prediction orders must return the full binary key.
  std::vector<marisa::UInt32> ids;
  ASSERT(trie.predict_breadth_first("", &ids) == 1);
  ASSERT(ids.size() == 1);
  ASSERT(ids[0] == key_ids[0]);
  std::vector<std::string> strs;
  ASSERT(trie.predict_depth_first("NP", &ids, &strs) == 1);
  ASSERT(ids.size() == 2);
  ASSERT(ids[1] == key_ids[0]);
  // Check the size before indexing to avoid an out-of-bounds read.
  ASSERT(strs.size() == 1);
  ASSERT(strs[0] == keys[0]);
  TEST_END();
}
} // namespace
// Test driver: runs every test case once, in a fixed order, and returns 0.
// Failure reporting is left to the ASSERT/EXCEPT macros from "assert.h".
int main() {
  typedef void (*TestFunction)();
  static const TestFunction kTests[] = {
    TestTrie,
    TestPrefixTrie,
    TestPatriciaTrie,
    TestEmptyString,
    TestBinaryKey,
  };
  const std::size_t num_tests = sizeof(kTests) / sizeof(kTests[0]);
  for (std::size_t t = 0; t != num_tests; ++t) {
    kTests[t]();
  }
  return 0;
}