#ifndef MARISA_ALPHA_TRIE_H_
#define MARISA_ALPHA_TRIE_H_

#include "base.h"

#ifdef __cplusplus

#include <memory>
#include <vector>

#include "progress.h"
#include "key.h"
#include "query.h"
#include "container.h"
#include "intvector.h"
#include "bitvector.h"
#include "tail.h"

namespace marisa_alpha {

// Trie declares the public build / map / read / write / restore / lookup /
// find / predict interface of the library together with its private state.
//
// Only declarations are visible in this header. The definitions are expected
// to come from "trie-inline.h" (included after this namespace) and from a
// separate source file, so every behavioral remark below is derived from the
// signatures alone and is marked as an assumption where it goes beyond them.
//
// Conventions visible in the signatures:
//  - Query strings are accepted in three forms: a `const char *` alone, a
//    (`const char *`, length) pair, and a `const std::string &`.
//  - Output arguments are pointers. The std::vector-based overloads default
//    them to NULL, i.e. they may be omitted by the caller.
//  - The std::vector-based overloads default `max_num_results` to
//    MARISA_ALPHA_MAX_NUM_KEYS.
class Trie {
 public:
  Trie();

  // Builds from a raw array of `num_keys` C strings. `key_lengths`,
  // `key_weights` and `key_ids` are optional (default NULL).
  // NOTE(review): `key_ids` is non-const, so it is presumably an output that
  // receives one ID per input key -- confirm against the implementation.
  void build(const char * const *keys, std::size_t num_keys,
      const std::size_t *key_lengths = NULL,
      const double *key_weights = NULL,
      UInt32 *key_ids = NULL, int flags = 0);

  // Container-based variants of build(). In the pair overload the `double` is
  // presumably the key weight, by analogy with `key_weights` above -- confirm.
  void build(const std::vector<std::string> &keys,
      std::vector<UInt32> *key_ids = NULL, int flags = 0);
  void build(const std::vector<std::pair<std::string, double> > &keys,
      std::vector<UInt32> *key_ids = NULL, int flags = 0);

  // Mapping entry points: from a file (through a caller-supplied Mapper), from
  // a memory range, or from an existing Mapper.
  // NOTE(review): `offset` / `whence` use the fseek-style constants (the
  // default is SEEK_SET); lifetime requirements on `mapper` / `ptr` are not
  // visible here -- check the implementation before relying on them.
  void mmap(Mapper *mapper, const char *filename,
      long offset = 0, int whence = SEEK_SET);
  void map(const void *ptr, std::size_t size);
  void map(Mapper &mapper);

  // Input entry points: by file name, C stdio stream, file descriptor,
  // C++ stream, or the project's Reader abstraction.
  void load(const char *filename,
      long offset = 0, int whence = SEEK_SET);
  void fread(std::FILE *file);
  void read(int fd);
  void read(std::istream &stream);
  void read(Reader &reader);

  // Output entry points mirroring the input ones above. All are const, so
  // they do not modify the trie object.
  void save(const char *filename, bool trunc_flag = true,
      long offset = 0, int whence = SEEK_SET) const;
  void fwrite(std::FILE *file) const;
  void write(int fd) const;
  void write(std::ostream &stream) const;
  void write(Writer &writer) const;

  // Subscript by key ID yields a string (presumably the same as restore()).
  std::string operator[](UInt32 key_id) const;

  // Subscript by string yields a key ID (presumably the same as lookup()).
  UInt32 operator[](const char *str) const;
  UInt32 operator[](const std::string &str) const;

  // Key ID -> key string. Three output styles: returned by value, written to
  // `*key`, or written into a caller-provided buffer of `key_buf_size` bytes.
  // NOTE(review): the meaning of the std::size_t result of the buffer variant
  // (key length vs. bytes written, and what happens on truncation) is not
  // visible in this header -- confirm.
  std::string restore(UInt32 key_id) const;
  void restore(UInt32 key_id, std::string *key) const;
  std::size_t restore(UInt32 key_id, char *key_buf,
      std::size_t key_buf_size) const;

  // Key string -> key ID.
  // NOTE(review): presumably returns notfound() when the key is absent --
  // confirm.
  UInt32 lookup(const char *str) const;
  UInt32 lookup(const char *ptr, std::size_t length) const;
  UInt32 lookup(const std::string &str) const;

  // find(): array-output variants. Results go to the caller-provided
  // `key_ids` / `key_lengths` arrays, bounded by `max_num_results`; the return
  // value is a count.
  // NOTE(review): the pairing of a key ID with a key *length* suggests these
  // report keys that are prefixes of the query -- confirm.
  std::size_t find(const char *str,
      UInt32 *key_ids, std::size_t *key_lengths,
      std::size_t max_num_results) const;
  std::size_t find(const char *ptr, std::size_t length,
      UInt32 *key_ids, std::size_t *key_lengths,
      std::size_t max_num_results) const;
  std::size_t find(const std::string &str,
      UInt32 *key_ids, std::size_t *key_lengths,
      std::size_t max_num_results) const;

  // find(): std::vector-output variants. Both outputs are optional.
  std::size_t find(const char *str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::size_t> *key_lengths = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t find(const char *ptr, std::size_t length,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::size_t> *key_lengths = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t find(const std::string &str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::size_t> *key_lengths = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;

  // Single-result variants of find(): return one key ID and optionally store
  // the matched length in `*key_length`.
  UInt32 find_first(const char *str,
      std::size_t *key_length = NULL) const;
  UInt32 find_first(const char *ptr, std::size_t length,
      std::size_t *key_length = NULL) const;
  UInt32 find_first(const std::string &str,
      std::size_t *key_length = NULL) const;

  UInt32 find_last(const char *str,
      std::size_t *key_length = NULL) const;
  UInt32 find_last(const char *ptr, std::size_t length,
      std::size_t *key_length = NULL) const;
  UInt32 find_last(const std::string &str,
      std::size_t *key_length = NULL) const;

  // Callback-driven variants of find(). `T` is any callable, taken by value,
  // that is expected to be usable with the following signature:
  // bool callback(UInt32 key_id, std::size_t key_length);
  // NOTE(review): the meaning of the bool result (presumably "continue?") is
  // not visible here -- confirm in trie-inline.h.
  template <typename T>
  std::size_t find_callback(const char *str, T callback) const;
  template <typename T>
  std::size_t find_callback(const char *ptr, std::size_t length,
      T callback) const;
  template <typename T>
  std::size_t find_callback(const std::string &str, T callback) const;

  // predict(): array-output variants. Results are key IDs and key strings,
  // bounded by `max_num_results`; the return value is a count.
  // NOTE(review): the name suggests these enumerate keys that start with the
  // query; which traversal order plain predict() uses is not visible here.
  std::size_t predict(const char *str,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
  std::size_t predict(const char *ptr, std::size_t length,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
  std::size_t predict(const std::string &str,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;

  // predict(): std::vector-output variants. Both outputs are optional.
  std::size_t predict(const char *str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t predict(const char *ptr, std::size_t length,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t predict(const std::string &str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;

  // Same parameter sets as predict(), with the traversal order named
  // explicitly (breadth-first).
  std::size_t predict_breadth_first(const char *str,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
  std::size_t predict_breadth_first(const char *ptr, std::size_t length,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
  std::size_t predict_breadth_first(const std::string &str,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;

  std::size_t predict_breadth_first(const char *str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t predict_breadth_first(const char *ptr, std::size_t length,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t predict_breadth_first(const std::string &str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;

  // Same parameter sets as predict(), with the traversal order named
  // explicitly (depth-first).
  std::size_t predict_depth_first(const char *str,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
  std::size_t predict_depth_first(const char *ptr, std::size_t length,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
  std::size_t predict_depth_first(const std::string &str,
      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;

  std::size_t predict_depth_first(const char *str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t predict_depth_first(const char *ptr, std::size_t length,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
  std::size_t predict_depth_first(const std::string &str,
      std::vector<UInt32> *key_ids = NULL,
      std::vector<std::string> *keys = NULL,
      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;

  // Callback-driven variants of predict(). `T` is any callable, taken by
  // value, that is expected to be usable with the following signature:
  // bool callback(UInt32 key_id, const std::string &key);
  template <typename T>
  std::size_t predict_callback(const char *str, T callback) const;
  template <typename T>
  std::size_t predict_callback(const char *ptr, std::size_t length,
      T callback) const;
  template <typename T>
  std::size_t predict_callback(const std::string &str, T callback) const;

  // Read-only accessors for the state and size of the trie.
  // NOTE(review): the unit of total_size() (presumably bytes) is not visible
  // in this header -- confirm.
  bool empty() const;
  std::size_t num_tries() const;
  std::size_t num_keys() const;
  std::size_t num_nodes() const;
  std::size_t total_size() const;

  // clear() and swap() are the only non-const members besides the build / map
  // / read families. swap() takes its peer by pointer, not by reference.
  void clear();
  void swap(Trie *rhs);

  // Static sentinel accessors: one typed like a key ID (UInt32) and one typed
  // like a length / position (std::size_t).
  // NOTE(review): presumably the "no such key" ID and the "no match" length
  // returned by the query functions above -- confirm.
  static UInt32 notfound();
  static std::size_t mismatch();

 private:
  // Data members. Their declaration order fixes construction order and object
  // layout, so do not reorder them without checking the constructor and the
  // read / write / map code.
  // NOTE(review): `louds_` presumably holds a LOUDS bit sequence describing
  // the tree shape, with `labels_`, `terminal_flags_`, `link_flags_` and
  // `links_` as the per-node side tables -- names only, confirm in the source.
  BitVector louds_;
  Vector<UInt8> labels_;
  BitVector terminal_flags_;
  BitVector link_flags_;
  IntVector links_;
  // Owning pointer to another Trie (see has_trie() / num_tries()).
  // NOTE(review): std::auto_ptr was deprecated in C++11 and removed in C++17;
  // this header does not compile as-is under -std=c++17 with standard
  // libraries that drop it. Migrating to std::unique_ptr requires touching the
  // out-of-view definitions that copy / release / reset this member.
  std::auto_ptr<Trie> trie_;
  Tail tail_;
  UInt32 num_first_branches_;
  UInt32 num_keys_;

  // Build helpers. The two non-template build_trie() overloads differ only in
  // how key IDs are handed back (std::vector vs. raw array), matching the
  // public build() overloads.
  void build_trie(Vector<Key<String> > &keys,
      std::vector<UInt32> *key_ids, int flags);
  void build_trie(Vector<Key<String> > &keys,
      UInt32 *key_ids, int flags);

  // Templated on the key string type; build_next() below is declared for both
  // Key<String> and Key<RString>.
  template <typename T>
  void build_trie(Vector<Key<T> > &keys,
      Vector<UInt32> *terminals, Progress &progress);

  template <typename T>
  void build_cur(Vector<Key<T> > &keys,
      Vector<UInt32> *terminals, Progress &progress);

  void build_next(Vector<Key<String> > &keys,
      Vector<UInt32> *terminals, Progress &progress);
  void build_next(Vector<Key<RString> > &rkeys,
      Vector<UInt32> *terminals, Progress &progress);

  // Takes `keys` by non-const reference (so it may modify them) and returns a
  // UInt32 whose meaning is not visible in this header.
  template <typename T>
  UInt32 sort_keys(Vector<Key<T> > &keys) const;

  template <typename T>
  void build_terminals(const Vector<Key<T> > &keys,
      Vector<UInt32> *terminals) const;

  // Restore helpers, std::string flavor.
  void restore_(UInt32 key_id, std::string *key) const;
  void trie_restore(UInt32 node, std::string *key) const;
  void tail_restore(UInt32 node, std::string *key) const;

  // Restore helpers, caller-buffer flavor. `key_pos` is an in/out cursor
  // passed by reference.
  std::size_t restore_(UInt32 key_id, char *key_buf,
      std::size_t key_buf_size) const;
  void trie_restore(UInt32 node, char *key_buf,
      std::size_t key_buf_size, std::size_t &key_pos) const;
  void tail_restore(UInt32 node, char *key_buf,
      std::size_t key_buf_size, std::size_t &key_pos) const;

  // Lookup helpers. `T` is the query type (see "query.h"); `node` and `pos`
  // are in/out where they are taken by reference.
  template <typename T>
  UInt32 lookup_(T query) const;
  template <typename T>
  bool find_child(UInt32 &node, T query, std::size_t &pos) const;
  template <typename T>
  std::size_t trie_match(UInt32 node, T query, std::size_t pos) const;
  template <typename T>
  std::size_t tail_match(UInt32 node, UInt32 link_id,
      T query, std::size_t pos) const;

  // Shared implementations behind the public find*() overloads. `U` / `V`
  // abstract over the output style (raw arrays vs. vectors).
  template <typename T, typename U, typename V>
  std::size_t find_(T query, U key_ids, V key_lengths,
      std::size_t max_num_results) const;
  template <typename T>
  UInt32 find_first_(T query, std::size_t *key_length) const;
  template <typename T>
  UInt32 find_last_(T query, std::size_t *key_length) const;
  template <typename T, typename U>
  std::size_t find_callback_(T query, U callback) const;

  // Shared implementations behind the public predict*() overloads.
  template <typename T, typename U, typename V>
  std::size_t predict_breadth_first_(T query, U key_ids, V keys,
      std::size_t max_num_results) const;
  template <typename T, typename U, typename V>
  std::size_t predict_depth_first_(T query, U key_ids, V keys,
      std::size_t max_num_results) const;
  template <typename T, typename U>
  std::size_t predict_callback_(T query, U callback) const;

  // Prefix-matching helpers used by the predict family; unlike the lookup
  // helpers above they also take a `std::string *key` output.
  template <typename T>
  bool predict_child(UInt32 &node, T query, std::size_t &pos,
      std::string *key) const;
  template <typename T>
  std::size_t trie_prefix_match(UInt32 node, T query,
      std::size_t pos, std::string *key) const;
  template <typename T>
  std::size_t tail_prefix_match(UInt32 node, UInt32 link_id,
      T query, std::size_t pos, std::string *key) const;

  // Conversions between key IDs, node IDs and LOUDS positions.
  UInt32 key_id_to_node(UInt32 key_id) const;
  UInt32 node_to_key_id(UInt32 node) const;
  UInt32 louds_pos_to_node(UInt32 louds_pos, UInt32 parent_node) const;

  // Tree navigation by node ID.
  UInt32 get_child(UInt32 node) const;
  UInt32 get_parent(UInt32 node) const;

  // Per-node link queries.
  bool has_link(UInt32 node) const;
  UInt32 get_link_id(UInt32 node) const;
  UInt32 get_link(UInt32 node) const;
  UInt32 get_link(UInt32 node, UInt32 link_id) const;

  // Whole-trie predicates (the zero-argument has_link() overload is distinct
  // from the per-node has_link(UInt32) above).
  bool has_link() const;
  bool has_trie() const;
  bool has_tail() const;

  // Disallows copy and assignment: both are declared private so that outside
  // code cannot copy a Trie.
  Trie(const Trie &);
  Trie &operator=(const Trie &);
};

}  // namespace marisa_alpha

#include "trie-inline.h"

#else  // __cplusplus

#include <stdio.h>

#endif  // __cplusplus

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

/*
 * C interface. The prototypes below mirror the (pointer, length) overloads of
 * the C++ Trie class declared earlier in this header. Every function except
 * the marisa_alpha_get_*() accessors returns a marisa_alpha_status and hands
 * results back through pointer arguments.
 * NOTE(review): the status values and the handling of NULL arguments are not
 * visible in this header -- see the definition of marisa_alpha_status and the
 * implementation.
 */

/* Handle type; struct marisa_alpha_trie_ is left incomplete in this header. */
typedef struct marisa_alpha_trie_ marisa_alpha_trie;

/*
 * marisa_alpha_init() receives the address of a handle pointer;
 * marisa_alpha_end() receives the handle itself.
 * NOTE(review): presumably allocate / release the handle -- confirm.
 */
marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h);
marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h);

/*
 * Same parameter list as the array-based Trie::build(), without the C++
 * default arguments: every argument must be passed explicitly.
 */
marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
    const char * const *keys, size_t num_keys, const size_t *key_lengths,
    const double *key_weights, marisa_alpha_uint32 *key_ids, int flags);

/* Mapping from a file or from a memory range. */
marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
    const char *filename, long offset, int whence);
marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
    size_t size);

/* Input by file name, stdio stream, or file descriptor. */
marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
    const char *filename, long offset, int whence);
marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file);
marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd);

/*
 * Output counterparts; the handle is const. `trunc_flag` is an int here,
 * whereas the C++ Trie::save() takes a bool.
 */
marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
    const char *filename, int trunc_flag, long offset, int whence);
marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
    FILE *file);
marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd);

/*
 * Key ID -> key string, written into a caller buffer of `key_buf_size` bytes;
 * a length is reported through `*key_length`.
 */
marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
    marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
    size_t *key_length);

/* Key string (pointer + length) -> key ID, reported through `*key_id`. */
marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
    const char *ptr, size_t length, marisa_alpha_uint32 *key_id);

/*
 * find family. Outputs are caller-provided arrays bounded by
 * `max_num_results`; the number of results is reported through
 * `*num_results`. The callback variant forwards `first_arg_to_callback` as
 * the callback's leading void * argument (per the parameter name).
 */
marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
    const char *ptr, size_t length,
    marisa_alpha_uint32 *key_ids, size_t *key_lengths,
    size_t max_num_results, size_t *num_results);
marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
    const char *ptr, size_t length,
    marisa_alpha_uint32 *key_id, size_t *key_length);
marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
    const char *ptr, size_t length,
    marisa_alpha_uint32 *key_id, size_t *key_length);
marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
    const char *ptr, size_t length,
    int (*callback)(void *, marisa_alpha_uint32, size_t),
    void *first_arg_to_callback);

/*
 * predict family. Unlike the C++ predict overloads, the array variants have
 * no key-string output: only key IDs are reported. The callback variant's
 * callback additionally receives a (const char *, size_t) pair.
 * NOTE(review): that pair is presumably the predicted key and its length --
 * confirm.
 */
marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
    const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
    size_t max_num_results, size_t *num_results);
marisa_alpha_status marisa_alpha_predict_breadth_first(
    const marisa_alpha_trie *h, const char *ptr, size_t length,
    marisa_alpha_uint32 *key_ids, size_t max_num_results, size_t *num_results);
marisa_alpha_status marisa_alpha_predict_depth_first(
    const marisa_alpha_trie *h, const char *ptr, size_t length,
    marisa_alpha_uint32 *key_ids, size_t max_num_results, size_t *num_results);
marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
    const char *ptr, size_t length,
    int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
    void *first_arg_to_callback);

/*
 * Accessors. These return the value directly as a size_t instead of a
 * status, so they have no channel for reporting an error.
 */
size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h);
size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h);
size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h);
size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h);

/* Counterpart of Trie::clear(); takes a non-const handle. */
marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // MARISA_ALPHA_TRIE_H_