#include <algorithm>
#include <stdexcept>
#include "trie.h"
namespace marisa_alpha {
namespace {
// Functor used by the callback-driven prediction path. Each call stores one
// (key id, key) result into whichever output containers are valid and tells
// the caller whether more results are still wanted.
template <typename T, typename U>
class PredictCallback {
 public:
  PredictCallback(T key_ids, U keys, std::size_t max_num_results)
      : key_ids_(key_ids),
        keys_(keys),
        max_num_results_(max_num_results),
        num_results_(0) {}

  // Copies the containers, the limit and the number of results stored so far.
  PredictCallback(const PredictCallback &callback)
      : key_ids_(callback.key_ids_),
        keys_(callback.keys_),
        max_num_results_(callback.max_num_results_),
        num_results_(callback.num_results_) {}

  // Stores the result at index num_results_ in every valid container, then
  // returns true while the number of stored results is below the limit.
  bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) {
    if (key_ids_.is_valid()) {
      key_ids_.insert(num_results_, key_id);
    }
    if (keys_.is_valid()) {
      keys_.insert(num_results_, key);
    }
    ++num_results_;
    return num_results_ < max_num_results_;
  }

 private:
  T key_ids_;
  U keys_;
  const std::size_t max_num_results_;
  std::size_t num_results_;

  // Assignment is declared but never defined, so it cannot be used.
  PredictCallback &operator=(const PredictCallback &);
};
} // namespace
// Returns the key stored under key_id as a freshly built string.
// Reports MARISA_ALPHA_STATE_ERROR when the trie is empty and
// MARISA_ALPHA_PARAM_ERROR when key_id is not below num_keys_.
std::string Trie::restore(UInt32 key_id) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);

  std::string restored;
  restore_(key_id, &restored);
  return restored;
}
// Restores the key stored under key_id into *key by delegating to restore_.
// Reports MARISA_ALPHA_STATE_ERROR when the trie is empty and
// MARISA_ALPHA_PARAM_ERROR when key_id is out of range or key is NULL.
void Trie::restore(UInt32 key_id, std::string *key) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
  MARISA_ALPHA_THROW_IF(key == NULL, MARISA_ALPHA_PARAM_ERROR);

  restore_(key_id, key);
}
// Restores the key stored under key_id into a caller-supplied buffer and
// returns the value produced by restore_ (the key length).
// A NULL buffer is accepted only together with a zero buffer size.
std::size_t Trie::restore(UInt32 key_id, char *key_buf,
    std::size_t key_buf_size) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
  MARISA_ALPHA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const std::size_t key_length = restore_(key_id, key_buf, key_buf_size);
  return key_length;
}
// Looks up str (wrapped in a CQuery) and returns the result of lookup_.
// Reports MARISA_ALPHA_STATE_ERROR on an empty trie and
// MARISA_ALPHA_PARAM_ERROR when str is NULL.
UInt32 Trie::lookup(const char *str) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);

  const CQuery query(str);
  return lookup_<CQuery>(query);
}
// Looks up the byte range [ptr, ptr + length) (wrapped in a Query) and
// returns the result of lookup_. A NULL ptr is accepted only when length is 0.
UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return lookup_<const Query &>(query);
}
std::size_t Trie::find(const char *str,
UInt32 *key_ids, std::size_t *key_lengths,
std::size_t max_num_results) const {
MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
return find_<CQuery>(CQuery(str),
MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
}
// Runs find_ on the byte range [ptr, ptr + length) (wrapped in a Query),
// writing results into the raw output arrays via MakeContainer.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::find(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::size_t *key_lengths,
    std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return find_<const Query &>(query, MakeContainer(key_ids),
      MakeContainer(key_lengths), max_num_results);
}
std::size_t Trie::find(const char *str,
std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
std::size_t max_num_results) const {
MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
return find_<CQuery>(CQuery(str),
MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
}
// Runs find_ on the byte range [ptr, ptr + length) (wrapped in a Query),
// writing results into the output vectors via MakeContainer.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::find(const char *ptr, std::size_t length,
    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
    std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return find_<const Query &>(query, MakeContainer(key_ids),
      MakeContainer(key_lengths), max_num_results);
}
// Runs find_first_ on str (wrapped in a CQuery) and returns its result.
// key_length is passed through unchanged and may be NULL.
UInt32 Trie::find_first(const char *str,
    std::size_t *key_length) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);

  const CQuery query(str);
  return find_first_<CQuery>(query, key_length);
}
// Runs find_first_ on the byte range [ptr, ptr + length) (wrapped in a
// Query). A NULL ptr is accepted only when length is 0.
UInt32 Trie::find_first(const char *ptr, std::size_t length,
    std::size_t *key_length) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return find_first_<const Query &>(query, key_length);
}
// Runs find_last_ on str (wrapped in a CQuery) and returns its result.
// key_length is passed through unchanged and may be NULL.
UInt32 Trie::find_last(const char *str,
    std::size_t *key_length) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);

  const CQuery query(str);
  return find_last_<CQuery>(query, key_length);
}
// Runs find_last_ on the byte range [ptr, ptr + length) (wrapped in a
// Query). A NULL ptr is accepted only when length is 0.
UInt32 Trie::find_last(const char *ptr, std::size_t length,
    std::size_t *key_length) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return find_last_<const Query &>(query, key_length);
}
// Dispatches a prediction on str: breadth-first when no key output array is
// supplied (keys == NULL), depth-first otherwise.
// Reports MARISA_ALPHA_STATE_ERROR on an empty trie and
// MARISA_ALPHA_PARAM_ERROR when str is NULL.
std::size_t Trie::predict(const char *str,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);

  if (keys == NULL) {
    return predict_breadth_first(str, key_ids, keys, max_num_results);
  }
  return predict_depth_first(str, key_ids, keys, max_num_results);
}
// Dispatches a prediction on [ptr, ptr + length): breadth-first when no key
// output array is supplied (keys == NULL), depth-first otherwise.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::predict(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  if (keys == NULL) {
    return predict_breadth_first(ptr, length, key_ids, keys, max_num_results);
  }
  return predict_depth_first(ptr, length, key_ids, keys, max_num_results);
}
// Dispatches a prediction on str: breadth-first when no key output vector is
// supplied (keys == NULL), depth-first otherwise.
// Reports MARISA_ALPHA_STATE_ERROR on an empty trie and
// MARISA_ALPHA_PARAM_ERROR when str is NULL.
std::size_t Trie::predict(const char *str,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);

  if (keys == NULL) {
    return predict_breadth_first(str, key_ids, keys, max_num_results);
  }
  return predict_depth_first(str, key_ids, keys, max_num_results);
}
// Dispatches a prediction on [ptr, ptr + length): breadth-first when no key
// output vector is supplied (keys == NULL), depth-first otherwise.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::predict(const char *ptr, std::size_t length,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  if (keys == NULL) {
    return predict_breadth_first(ptr, length, key_ids, keys, max_num_results);
  }
  return predict_depth_first(ptr, length, key_ids, keys, max_num_results);
}
std::size_t Trie::predict_breadth_first(const char *str,
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
return predict_breadth_first_<CQuery>(CQuery(str),
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
// Runs predict_breadth_first_ on [ptr, ptr + length) (wrapped in a Query)
// with the raw output arrays wrapped by MakeContainer.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return predict_breadth_first_<const Query &>(query,
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
std::size_t Trie::predict_breadth_first(const char *str,
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
std::size_t max_num_results) const {
MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
return predict_breadth_first_<CQuery>(CQuery(str),
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
// Runs predict_breadth_first_ on [ptr, ptr + length) (wrapped in a Query)
// with the output vectors wrapped by MakeContainer.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return predict_breadth_first_<const Query &>(query,
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
std::size_t Trie::predict_depth_first(const char *str,
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
return predict_depth_first_<CQuery>(CQuery(str),
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
// Runs predict_depth_first_ on [ptr, ptr + length) (wrapped in a Query) with
// the raw output arrays wrapped by MakeContainer.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return predict_depth_first_<const Query &>(query,
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
std::size_t Trie::predict_depth_first(
const char *str, std::vector<UInt32> *key_ids,
std::vector<std::string> *keys, std::size_t max_num_results) const {
MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
return predict_depth_first_<CQuery>(CQuery(str),
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
// Runs predict_depth_first_ on [ptr, ptr + length) (wrapped in a Query) with
// the output vectors wrapped by MakeContainer.
// A NULL ptr is accepted only when length is 0.
std::size_t Trie::predict_depth_first(
    const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
    std::vector<std::string> *keys, std::size_t max_num_results) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
      MARISA_ALPHA_PARAM_ERROR);

  const Query query(ptr, length);
  return predict_depth_first_<const Query &>(query,
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}
// Appends the key identified by key_id to *key.
// The key is collected while walking from the key's node up through its
// parents, i.e. back to front: every linked segment is reversed as soon as it
// has been appended, and the whole appended range is reversed at the end so
// that it reads front to back.
// On std::bad_alloc / std::length_error *key is truncated back to its
// original length before the matching MARISA_ALPHA error is raised.
void Trie::restore_(UInt32 key_id, std::string *key) const {
  const std::size_t start_pos = key->length();
  try {
    for (UInt32 node = key_id_to_node(key_id); node != 0;
        node = get_parent(node)) {
      if (!has_link(node)) {
        // Plain node: contributes exactly one label.
        *key += labels_[node];
        continue;
      }
      const std::size_t segment_begin = key->length();
      if (has_trie()) {
        trie_->trie_restore(get_link(node), key);
      } else {
        tail_restore(node, key);
      }
      std::reverse(key->begin() + segment_begin, key->end());
    }
    std::reverse(key->begin() + start_pos, key->end());
  } catch (const std::bad_alloc &) {
    key->resize(start_pos);
    MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
  } catch (const std::length_error &) {
    key->resize(start_pos);
    MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
  }
}
// Appends to *key the labels found while walking from node up through its
// parents until node 0 is reached. Linked nodes are expanded through the
// nested trie (when has_trie()) or through the tail.
void Trie::trie_restore(UInt32 node, std::string *key) const {
  do {
    if (!has_link(node)) {
      *key += labels_[node];
    } else if (has_trie()) {
      trie_->trie_restore(get_link(node), key);
    } else {
      tail_restore(node, key);
    }
    node = get_parent(node);
  } while (node != 0);
}
// Appends the tail entry linked from node to *key.
// In binary tail mode the entry length is the distance to the next entry's
// offset; otherwise the entry is appended as a NUL-terminated string.
void Trie::tail_restore(UInt32 node, std::string *key) const {
  const UInt32 link_id = link_flags_.rank1(node);
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  if (tail_.mode() != MARISA_ALPHA_BINARY_TAIL) {
    key->append(reinterpret_cast<const char *>(tail_[offset]));
    return;
  }
  const UInt32 length = (links_[link_id + 1] * 256)
      + labels_[link_flags_.select1(link_id + 1)] - offset;
  key->append(reinterpret_cast<const char *>(tail_[offset]), length);
}
// Restores the key identified by key_id into key_buf and returns its length.
// Bytes are only written at positions below key_buf_size, but pos keeps
// counting, so the returned length is complete even if the buffer is too
// small. The in-place reversals and the terminating '\0' are applied only
// while pos < key_buf_size.
std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
    std::size_t key_buf_size) const {
  std::size_t pos = 0;
  for (UInt32 node = key_id_to_node(key_id); node != 0;
      node = get_parent(node)) {
    if (!has_link(node)) {
      if (pos < key_buf_size) {
        key_buf[pos] = labels_[node];
      }
      ++pos;
      continue;
    }
    const std::size_t segment_begin = pos;
    if (has_trie()) {
      trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    } else {
      tail_restore(node, key_buf, key_buf_size, pos);
    }
    if (pos < key_buf_size) {
      std::reverse(key_buf + segment_begin, key_buf + pos);
    }
  }
  if (pos < key_buf_size) {
    key_buf[pos] = '\0';
    std::reverse(key_buf, key_buf + pos);
  }
  return pos;
}
// Buffer variant of trie_restore: walks from node up through its parents
// until node 0, writing labels at key_buf[pos] while pos < key_buf_size and
// always advancing pos.
void Trie::trie_restore(UInt32 node, char *key_buf,
    std::size_t key_buf_size, std::size_t &pos) const {
  do {
    if (!has_link(node)) {
      if (pos < key_buf_size) {
        key_buf[pos] = labels_[node];
      }
      ++pos;
    } else if (has_trie()) {
      trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    } else {
      tail_restore(node, key_buf, key_buf_size, pos);
    }
    node = get_parent(node);
  } while (node != 0);
}
// Buffer variant of tail_restore: copies the tail entry linked from node to
// key_buf starting at pos. Bytes beyond key_buf_size are not written, but pos
// is advanced for every byte of the entry.
void Trie::tail_restore(UInt32 node, char *key_buf,
    std::size_t key_buf_size, std::size_t &pos) const {
  const UInt32 link_id = link_flags_.rank1(node);
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    // Binary mode: the length comes from the next entry's offset.
    const UInt8 *ptr = tail_[offset];
    const UInt32 length = (links_[link_id + 1] * 256)
        + labels_[link_flags_.select1(link_id + 1)] - offset;
    for (UInt32 i = 0; i < length; ++i, ++pos) {
      if (pos < key_buf_size) {
        key_buf[pos] = ptr[i];
      }
    }
  } else {
    // Text mode: the entry ends at the first '\0'.
    for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str, ++pos) {
      if (pos < key_buf_size) {
        key_buf[pos] = *str;
      }
    }
  }
}
// Follows query from the root with find_child until the query is consumed.
// Returns the key id of the reached node if it is terminal, and notfound()
// if a step fails or the reached node is not terminal.
template <typename T>
UInt32 Trie::lookup_(T query) const {
  UInt32 node = 0;
  for (std::size_t pos = 0; !query.ends_at(pos); ) {
    if (!find_child<T>(node, query, pos)) {
      return notfound();
    }
  }
  if (!terminal_flags_[node]) {
    return notfound();
  }
  return node_to_key_id(node);
}
// Matches query from pos against the labels found while walking from node up
// through its parents.
// Return value:
//   - pos unchanged: the very first label (or linked segment) did not match;
//   - mismatch():    the match failed after at least one label matched, or
//                    the query ended before node 0 was reached;
//   - otherwise:     the query position just past the matched labels.
template <typename T>
std::size_t Trie::trie_match(UInt32 node, T query,
    std::size_t pos) const {
  // First step: a failure here may still be reported as "pos unchanged".
  if (has_link(node)) {
    const std::size_t next_pos = has_trie() ?
        trie_->trie_match<T>(get_link(node), query, pos) :
        tail_match<T>(node, get_link_id(node), query, pos);
    if ((next_pos == mismatch()) || (next_pos == pos)) {
      return next_pos;
    }
    pos = next_pos;
  } else if (labels_[node] != query[pos]) {
    return pos;
  } else {
    ++pos;
  }

  // Remaining steps: any failure is a definite mismatch.
  for (node = get_parent(node); node != 0; node = get_parent(node)) {
    if (query.ends_at(pos)) {
      return mismatch();
    }
    if (has_link(node)) {
      const std::size_t next_pos = has_trie() ?
          trie_->trie_match<T>(get_link(node), query, pos) :
          tail_match<T>(node, get_link_id(node), query, pos);
      if ((next_pos == mismatch()) || (next_pos == pos)) {
        return mismatch();
      }
      pos = next_pos;
    } else if (labels_[node] != query[pos]) {
      return mismatch();
    } else {
      ++pos;
    }
  }
  return pos;
}

template std::size_t Trie::trie_match<CQuery>(UInt32 node,
    CQuery query, std::size_t pos) const;
template std::size_t Trie::trie_match<const Query &>(UInt32 node,
    const Query &query, std::size_t pos) const;
// Matches query from pos against the tail entry selected by node / link_id.
// Return value:
//   - pos unchanged: the first byte of the entry differs from query[pos];
//   - mismatch():    a later byte differs or the query ends inside the entry;
//   - otherwise:     the query position just past the entry.
template <typename T>
std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
    T query, std::size_t pos) const {
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const UInt8 *ptr = tail_[offset];
  if (*ptr != query[pos]) {
    return pos;
  }

  if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    // Binary mode: the length comes from the next entry's offset.
    const UInt32 length = (links_[link_id + 1] * 256)
        + labels_[link_flags_.select1(link_id + 1)] - offset;
    for (UInt32 i = 1; i < length; ++i) {
      if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
        return mismatch();
      }
    }
    return pos + length;
  }

  // Text mode: the entry ends at the first '\0'.
  ++ptr;
  ++pos;
  while (*ptr != '\0') {
    if (query.ends_at(pos) || (*ptr != query[pos])) {
      return mismatch();
    }
    ++ptr;
    ++pos;
  }
  return pos;
}

template std::size_t Trie::tail_match<CQuery>(UInt32 node,
    UInt32 link_id, CQuery query, std::size_t pos) const;
template std::size_t Trie::tail_match<const Query &>(UInt32 node,
    UInt32 link_id, const Query &query, std::size_t pos) const;
// Descends from the root along query and records every terminal node passed
// on the way (the root included): its key id goes to key_ids and the query
// position at that node goes to key_lengths, for whichever container is
// valid. Stops after max_num_results results and returns the result count.
// std::bad_alloc / std::length_error are translated into
// MARISA_ALPHA_MEMORY_ERROR / MARISA_ALPHA_SIZE_ERROR.
template <typename T, typename U, typename V>
std::size_t Trie::find_(T query, U key_ids, V key_lengths,
    std::size_t max_num_results) const try {
  if (max_num_results == 0) {
    return 0;
  }

  std::size_t count = 0;
  UInt32 node = 0;
  std::size_t pos = 0;
  for (;;) {
    if (terminal_flags_[node]) {
      if (key_ids.is_valid()) {
        key_ids.insert(count, node_to_key_id(node));
      }
      if (key_lengths.is_valid()) {
        key_lengths.insert(count, pos);
      }
      if (++count >= max_num_results) {
        return count;
      }
    }
    if (query.ends_at(pos) || !find_child<T>(node, query, pos)) {
      break;
    }
  }
  return count;
} catch (const std::bad_alloc &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
// Descends from the root along query and returns the key id of the first
// terminal node encountered (the root included). When key_length is not NULL
// it receives the query position at that node. Returns notfound() when no
// terminal node is reached.
template <typename T>
UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
  UInt32 node = 0;
  std::size_t pos = 0;
  for (;;) {
    if (terminal_flags_[node]) {
      if (key_length != NULL) {
        *key_length = pos;
      }
      return node_to_key_id(node);
    }
    if (query.ends_at(pos) || !find_child<T>(node, query, pos)) {
      break;
    }
  }
  return notfound();
}
// Descends from the root along query and returns the key id of the last
// terminal node encountered (the root included). When key_length is not NULL
// it receives the query position at that node. Returns notfound(), leaving
// *key_length untouched, when no terminal node is reached.
template <typename T>
UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
  UInt32 node = 0;
  UInt32 node_found = notfound();
  std::size_t pos = 0;
  std::size_t pos_found = mismatch();
  for (;;) {
    if (terminal_flags_[node]) {
      node_found = node;
      pos_found = pos;
    }
    if (query.ends_at(pos) || !find_child<T>(node, query, pos)) {
      break;
    }
  }

  if (node_found == notfound()) {
    return notfound();
  }
  if (key_length != NULL) {
    *key_length = pos_found;
  }
  return node_to_key_id(node_found);
}
template <typename T, typename U, typename V>
std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
std::size_t max_num_results) const try {
if (max_num_results == 0) {
return 0;
}
UInt32 node = 0;
std::size_t pos = 0;
while (!query.ends_at(pos)) {
if (!predict_child<T>(node, query, pos, NULL)) {
return 0;
}
}
std::string key;
std::size_t count = 0;
if (terminal_flags_[node]) {
const UInt32 key_id = node_to_key_id(node);
if (key_ids.is_valid()) {
key_ids.insert(count, key_id);
}
if (keys.is_valid()) {
restore(key_id, &key);
keys.insert(count, key);
}
if (++count >= max_num_results) {
return count;
}
}
const UInt32 louds_pos = get_child(node);
if (!louds_[louds_pos]) {
return count;
}
UInt32 node_begin = louds_pos_to_node(louds_pos, node);
UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
while (node_begin < node_end) {
const UInt32 key_id_begin = node_to_key_id(node_begin);
const UInt32 key_id_end = node_to_key_id(node_end);
if (key_ids.is_valid()) {
UInt32 temp_count = count;
for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
key_ids.insert(temp_count, key_id);
if (++temp_count >= max_num_results) {
break;
}
}
}
if (keys.is_valid()) {
UInt32 temp_count = count;
for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
key.clear();
restore(key_id, &key);
keys.insert(temp_count, key);
if (++temp_count >= max_num_results) {
break;
}
}
}
count += key_id_end - key_id_begin;
if (count >= max_num_results) {
return max_num_results;
}
node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
node_end = louds_pos_to_node(get_child(node_end), node_end);
}
return count;
} catch (const std::bad_alloc &) {
MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
template <typename T, typename U, typename V>
std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
std::size_t max_num_results) const try {
if (max_num_results == 0) {
return 0;
} else if (keys.is_valid()) {
PredictCallback<U, V> callback(key_ids, keys, max_num_results);
return predict_callback_(query, callback);
}
UInt32 node = 0;
std::size_t pos = 0;
while (!query.ends_at(pos)) {
if (!predict_child<T>(node, query, pos, NULL)) {
return 0;
}
}
std::size_t count = 0;
if (terminal_flags_[node]) {
if (key_ids.is_valid()) {
key_ids.insert(count, node_to_key_id(node));
}
if (++count >= max_num_results) {
return count;
}
}
Cell cell;
cell.set_louds_pos(get_child(node));
if (!louds_[cell.louds_pos()]) {
return count;
}
cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
cell.set_key_id(node_to_key_id(cell.node()));
Vector<Cell> stack;
stack.push_back(cell);
std::size_t stack_pos = 1;
while (stack_pos != 0) {
Cell &cur = stack[stack_pos - 1];
if (!louds_[cur.louds_pos()]) {
cur.set_louds_pos(cur.louds_pos() + 1);
--stack_pos;
continue;
}
cur.set_louds_pos(cur.louds_pos() + 1);
if (terminal_flags_[cur.node()]) {
if (key_ids.is_valid()) {
key_ids.insert(count, cur.key_id());
}
if (++count >= max_num_results) {
return count;
}
cur.set_key_id(cur.key_id() + 1);
}
if (stack_pos == stack.size()) {
cell.set_louds_pos(get_child(cur.node()));
cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
cell.set_key_id(node_to_key_id(cell.node()));
stack.push_back(cell);
}
stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
++stack_pos;
}
return count;
} catch (const std::bad_alloc &) {
MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
// Matches query from pos against the labels found while walking from node up
// through its parents, allowing the query to end part-way (prefix match).
//
// Return value:
//   - pos unchanged: the very first label (or linked segment) did not match;
//   - mismatch():    the match failed after at least one label matched;
//   - otherwise:     the query position reached. If the query ended before
//                    node 0 was reached and key is not NULL, the labels that
//                    were not consumed by the query are appended to *key.
//
// A "no progress" result from a nested match inside the parent loop is
// reported as mismatch(), exactly as trie_match does. Returning the
// (already advanced) position there would make the caller treat a failed
// match as a successful one, because the value differs from the position
// the caller passed in.
template <typename T>
std::size_t Trie::trie_prefix_match(UInt32 node, T query,
    std::size_t pos, std::string *key) const {
  // First step: a failure here may still be reported as "pos unchanged".
  if (has_link(node)) {
    std::size_t next_pos;
    if (has_trie()) {
      next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
    } else {
      next_pos = tail_prefix_match<T>(
          node, get_link_id(node), query, pos, key);
    }
    if ((next_pos == mismatch()) || (next_pos == pos)) {
      return next_pos;
    }
    pos = next_pos;
  } else if (labels_[node] != query[pos]) {
    return pos;
  } else {
    ++pos;
  }

  node = get_parent(node);
  while (node != 0) {
    if (query.ends_at(pos)) {
      // The query is a proper prefix: hand back the remaining labels.
      if (key != NULL) {
        trie_restore(node, key);
      }
      return pos;
    }
    if (has_link(node)) {
      std::size_t next_pos;
      if (has_trie()) {
        next_pos = trie_->trie_prefix_match<T>(
            get_link(node), query, pos, key);
      } else {
        next_pos = tail_prefix_match<T>(
            node, get_link_id(node), query, pos, key);
      }
      // Labels were already consumed, so "no progress" is a real mismatch.
      if ((next_pos == mismatch()) || (next_pos == pos)) {
        return mismatch();
      }
      pos = next_pos;
    } else if (labels_[node] != query[pos]) {
      return mismatch();
    } else {
      ++pos;
    }
    node = get_parent(node);
  }
  return pos;
}

template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
    CQuery query, std::size_t pos, std::string *key) const;
template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
    const Query &query, std::size_t pos, std::string *key) const;
// Prefix-matches query from pos against the tail entry selected by
// node / link_id.
// Return value:
//   - pos unchanged: the first byte of the entry differs from query[pos];
//   - mismatch():    a later byte differs from the query;
//   - otherwise:     the query position reached. If the query ends inside the
//                    entry and key is not NULL, the unmatched rest of the
//                    entry is appended to *key.
template <typename T>
std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
    T query, std::size_t pos, std::string *key) const {
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const UInt8 *ptr = tail_[offset];
  if (*ptr != query[pos]) {
    return pos;
  }

  if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    // Binary mode: the length comes from the next entry's offset.
    const UInt32 length = (links_[link_id + 1] * 256)
        + labels_[link_flags_.select1(link_id + 1)] - offset;
    for (UInt32 i = 1; i < length; ++i) {
      if (query.ends_at(pos + i)) {
        if (key != NULL) {
          key->append(reinterpret_cast<const char *>(ptr + i), length - i);
        }
        return pos + i;
      }
      if (ptr[i] != query[pos + i]) {
        return mismatch();
      }
    }
    return pos + length;
  }

  // Text mode: the entry ends at the first '\0'.
  for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
    if (query.ends_at(pos)) {
      if (key != NULL) {
        key->append(reinterpret_cast<const char *>(ptr));
      }
      return pos;
    }
    if (*ptr != query[pos]) {
      return mismatch();
    }
  }
  return pos;
}

template std::size_t Trie::tail_prefix_match<CQuery>(
    UInt32 node, UInt32 link_id,
    CQuery query, std::size_t pos, std::string *key) const;
template std::size_t Trie::tail_prefix_match<const Query &>(
    UInt32 node, UInt32 link_id,
    const Query &query, std::size_t pos, std::string *key) const;
} // namespace marisa_alpha