// lookahead-matcher.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Classes to add lookahead to FST matchers, useful e.g. for improving
// composition efficiency with certain inputs.

#ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
#define FST_LIB_LOOKAHEAD_MATCHER_H__

#include <fst/add-on.h>
#include <fst/const-fst.h>
#include <fst/fst.h>
#include <fst/label-reachable.h>
#include <fst/matcher.h>


DECLARE_string(save_relabel_ipairs);
DECLARE_string(save_relabel_opairs);

namespace fst {

// LOOKAHEAD MATCHERS - these have the interface of Matchers (see
// matcher.h) and these additional methods:
//
// template <class F>
// class LookAheadMatcher {
//  public:
//   typedef F FST;
//   typedef F::Arc Arc;
//   typedef typename Arc::StateId StateId;
//   typedef typename Arc::Label Label;
//   typedef typename Arc::Weight Weight;
//
//  // Required constructors.
//  LookAheadMatcher(const F &fst, MatchType match_type);
//   // If safe=true, the copy is thread-safe (except the lookahead Fst is
//   // preserved). See Fst<>::Cop() for further doc.
//  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
//
//  Below are methods for looking ahead for a match to a label and
//  more generally, to a rational set. Each returns false if there is
//  definitely not a match and returns true if there possibly is a
//  match.

//  // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
//  // after possibly following epsilon transitions?
//  bool LookAheadLabel(Label label) const;
//
//  // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
//  // arbitrary rational set of strings, specified by an FST and a state
//  // from which to begin the matching. If the lookahead FST is a
//  // transducer, this looks on the side different from the matcher
//  // 'match_type' (cf. composition).
//
//  // Are there paths P from 's' in the lookahead FST that can be read from
//  // the cur. matcher state?
//  bool LookAheadFst(const Fst<Arc>& fst, StateId s);
//
//  // Gives an estimate of the combined weight of the paths P in the
//  // lookahead and matcher FSTs for the last call to LookAheadFst.
//  // A trivial implementation returns Weight::One(). Non-trivial
//  // implementations are useful for weight-pushing in composition.
//  Weight LookAheadWeight() const;
//
//  // Is there is a single non-epsilon arc found in the lookahead FST
//  // that begins P (after possibly following any epsilons) in the last
//  // call LookAheadFst? If so, return true and copy it to '*arc', o.w.
//  // return false. A trivial implementation returns false. Non-trivial
//  // implementations are useful for label-pushing in composition.
//  bool LookAheadPrefix(Arc *arc);
//
//  // Optionally pre-specifies the lookahead FST that will be passed
//  // to LookAheadFst() for possible precomputation. If copy is true,
//  // then 'fst' is a copy of the FST used in the previous call to
//  // this method (useful to avoid unnecessary updates).
//  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
//
// };

//
// LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
//
// Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
const uint32 kInputLookAheadMatcher =     0x00000010;

// Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
const uint32 kOutputLookAheadMatcher =    0x00000020;

// A non-trivial implementation of LookAheadWeight() method defined and
// should be used?
const uint32 kLookAheadWeight =           0x00000040;

// A non-trivial implementation of LookAheadPrefix() method defined and
// should be used?
const uint32 kLookAheadPrefix =           0x00000080;

// Look-ahead of matcher FST non-epsilon arcs?
const uint32 kLookAheadNonEpsilons =      0x00000100;

// Look-ahead of matcher FST epsilon arcs?
const uint32 kLookAheadEpsilons =         0x00000200;

// Ignore epsilon paths for the lookahead prefix? Note this gives
// correct results in composition only with an appropriate composition
// filter since it depends on the filter blocking the ignored paths.
const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;

// For LabelLookAheadMatcher, save relabeling data to file
const uint32 kLookAheadKeepRelabelData =  0x00000800;

// Flags used for lookahead matchers.
const uint32 kLookAheadFlags =            0x00000ff0;

// LookAhead Matcher interface, templated on the Arc definition; used
// for lookahead matcher specializations that are returned by the
// InitMatcher() Fst method.
template <class A>
class LookAheadMatcherBase : public MatcherBase<A> {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  LookAheadMatcherBase()
  : weight_(Weight::One()),
    prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}

  virtual ~LookAheadMatcherBase() {}

  bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }

  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
    return LookAheadFst_(fst, s);
  }

  Weight LookAheadWeight() const { return weight_; }

  bool LookAheadPrefix(Arc *arc) const {
    if (prefix_arc_.nextstate != kNoStateId) {
      *arc = prefix_arc_;
      return true;
    } else {
      return false;
    }
  }

  virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;

 protected:
  void SetLookAheadWeight(const Weight &w) { weight_ = w; }

  void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }

  void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }

 private:
  virtual bool LookAheadLabel_(Label label) const = 0;
  virtual bool LookAheadFst_(const Fst<Arc> &fst,
                             StateId s) = 0;  // This must set l.a. weight and
                                              // prefix if non-trivial.
  Weight weight_;                             // Look-ahead weight
  Arc prefix_arc_;                            // Look-ahead prefix arc
};


// Don't really lookahead, just declare future looks good regardless.
template <class M>
class TrivialLookAheadMatcher
    : public LookAheadMatcherBase<typename M::FST::Arc> {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
      : matcher_(fst, match_type) {}

  TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
                          bool safe = false)
      : matcher_(lmatcher.matcher_, safe) {}

  // General matcher methods
  TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
    return new TrivialLookAheadMatcher<M>(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_.Type(test); }
  void SetState(StateId s) { return matcher_.SetState(s); }
  bool Find(Label label) { return matcher_.Find(label); }
  bool Done() const { return matcher_.Done(); }
  const Arc& Value() const { return matcher_.Value(); }
  void Next() { matcher_.Next(); }
  virtual const FST &GetFst() const { return matcher_.GetFst(); }
  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
  uint32 Flags() const {
    return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
  }

  // Look-ahead methods.
  bool LookAheadLabel(Label label) const { return true;  }
  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
  Weight LookAheadWeight() const { return Weight::One(); }
  bool LookAheadPrefix(Arc *arc) const { return false; }
  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}

 private:
  // This allows base class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }

  bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    return LookAheadFst(fst, s);
  }

  Weight LookAheadWeight_() const { return LookAheadWeight(); }
  bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }

  M matcher_;
};

// Look-ahead of one transition. Template argument F accepts flags to
// control behavior.
template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
          kLookAheadWeight | kLookAheadPrefix>
class ArcLookAheadMatcher
    : public LookAheadMatcherBase<typename M::FST::Arc> {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  typedef NullAddOn MatcherData;

  using LookAheadMatcherBase<Arc>::LookAheadWeight;
  using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
  using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
  using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;

  ArcLookAheadMatcher(const FST &fst, MatchType match_type,
                      MatcherData *data = 0)
      : matcher_(fst, match_type),
        fst_(matcher_.GetFst()),
        lfst_(0),
        s_(kNoStateId) {}

  ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
                      bool safe = false)
      : matcher_(lmatcher.matcher_, safe),
        fst_(matcher_.GetFst()),
        lfst_(lmatcher.lfst_),
        s_(kNoStateId) {}

  // General matcher methods
  ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
    return new ArcLookAheadMatcher<M, F>(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_.Type(test); }

  void SetState(StateId s) {
    s_ = s;
    matcher_.SetState(s);
  }

  bool Find(Label label) { return matcher_.Find(label); }
  bool Done() const { return matcher_.Done(); }
  const Arc& Value() const { return matcher_.Value(); }
  void Next() { matcher_.Next(); }
  const FST &GetFst() const { return fst_; }
  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
  uint32 Flags() const {
    return matcher_.Flags() | kInputLookAheadMatcher |
        kOutputLookAheadMatcher | F;
  }

  // Writable matcher methods
  MatcherData *GetData() const { return 0; }

  // Look-ahead methods.
  bool LookAheadLabel(Label label) const { return matcher_.Find(label); }

  // Checks if there is a matching (possibly super-final) transition
  // at (s_, s).
  bool LookAheadFst(const Fst<Arc> &fst, StateId s);

  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    lfst_ = &fst;
  }

 private:
  // This allows base class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
  bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    return LookAheadFst(fst, s);
  }

  mutable M matcher_;
  const FST &fst_;         // Matcher FST
  const Fst<Arc> *lfst_;   // Look-ahead FST
  StateId s_;              // Matcher state
};

template <class M, uint32 F>
bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
  if (&fst != lfst_)
    InitLookAheadFst(fst);

  bool ret = false;
  ssize_t nprefix = 0;
  if (F & kLookAheadWeight)
    SetLookAheadWeight(Weight::Zero());
  if (F & kLookAheadPrefix)
    ClearLookAheadPrefix();
  if (fst_.Final(s_) != Weight::Zero() &&
      lfst_->Final(s) != Weight::Zero()) {
    if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
      return true;
    ++nprefix;
    if (F & kLookAheadWeight)
      SetLookAheadWeight(Plus(LookAheadWeight(),
                              Times(fst_.Final(s_), lfst_->Final(s))));
    ret = true;
  }
  if (matcher_.Find(kNoLabel)) {
    if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
      return true;
    ++nprefix;
    if (F & kLookAheadWeight)
      for (; !matcher_.Done(); matcher_.Next())
        SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
    ret = true;
  }
  for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
       !aiter.Done();
       aiter.Next()) {
    const Arc &arc = aiter.Value();
    Label label = kNoLabel;
    switch (matcher_.Type(false)) {
      case MATCH_INPUT:
        label = arc.olabel;
        break;
      case MATCH_OUTPUT:
        label = arc.ilabel;
        break;
      default:
        FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
        return true;
    }
    if (label == 0) {
      if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
        return true;
      if (!(F & kLookAheadNonEpsilonPrefix))
        ++nprefix;
      if (F & kLookAheadWeight)
        SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
      ret = true;
    } else if (matcher_.Find(label)) {
      if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
        return true;
      for (; !matcher_.Done(); matcher_.Next()) {
        ++nprefix;
        if (F & kLookAheadWeight)
          SetLookAheadWeight(Plus(LookAheadWeight(),
                                  Times(arc.weight,
                                        matcher_.Value().weight)));
        if ((F & kLookAheadPrefix) && nprefix == 1)
          SetLookAheadPrefix(arc);
      }
      ret = true;
    }
  }
  if (F & kLookAheadPrefix) {
    if (nprefix == 1)
      SetLookAheadWeight(Weight::One());  // Avoids double counting.
    else
      ClearLookAheadPrefix();
  }
  return ret;
}


// Template argument F accepts flags to control behavior.
// It must include precisely one of KInputLookAheadMatcher or
// KOutputLookAheadMatcher.
template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
          kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
          kLookAheadKeepRelabelData,
          class S = DefaultAccumulator<typename M::Arc> >
class LabelLookAheadMatcher
    : public LookAheadMatcherBase<typename M::FST::Arc> {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  typedef LabelReachableData<Label> MatcherData;

  using LookAheadMatcherBase<Arc>::LookAheadWeight;
  using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
  using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
  using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;

  LabelLookAheadMatcher(const FST &fst, MatchType match_type,
                        MatcherData *data = 0, S *s = 0)
      : matcher_(fst, match_type),
        lfst_(0),
        label_reachable_(0),
        s_(kNoStateId),
        error_(false) {
    if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
      FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
      error_ = true;
    }
    bool reach_input = match_type == MATCH_INPUT;
    if (data) {
      if (reach_input == data->ReachInput())
        label_reachable_ = new LabelReachable<Arc, S>(data, s);
    } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
               (!reach_input && (F & kOutputLookAheadMatcher))) {
      label_reachable_ = new LabelReachable<Arc, S>(
          fst, reach_input, s, F & kLookAheadKeepRelabelData);
    }
  }

  LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
                        bool safe = false)
      : matcher_(lmatcher.matcher_, safe),
        lfst_(lmatcher.lfst_),
        label_reachable_(
            lmatcher.label_reachable_ ?
            new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
        s_(kNoStateId),
        error_(lmatcher.error_) {}

  ~LabelLookAheadMatcher() {
    delete label_reachable_;
  }

  // General matcher methods
  LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
    return new LabelLookAheadMatcher<M, F, S>(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_.Type(test); }

  void SetState(StateId s) {
    if (s_ == s)
      return;
    s_ = s;
    match_set_state_ = false;
    reach_set_state_ = false;
  }

  bool Find(Label label) {
    if (!match_set_state_) {
      matcher_.SetState(s_);
      match_set_state_ = true;
    }
    return matcher_.Find(label);
  }

  bool Done() const { return matcher_.Done(); }
  const Arc& Value() const { return matcher_.Value(); }
  void Next() { matcher_.Next(); }
  const FST &GetFst() const { return matcher_.GetFst(); }

  uint64 Properties(uint64 inprops) const {
    uint64 outprops = matcher_.Properties(inprops);
    if (error_ || (label_reachable_ && label_reachable_->Error()))
      outprops |= kError;
    return outprops;
  }

  uint32 Flags() const {
    if (label_reachable_ && label_reachable_->GetData()->ReachInput())
      return matcher_.Flags() | F | kInputLookAheadMatcher;
    else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
      return matcher_.Flags() | F | kOutputLookAheadMatcher;
    else
      return matcher_.Flags();
  }

  // Writable matcher methods
  MatcherData *GetData() const {
    return label_reachable_ ? label_reachable_->GetData() : 0;
  };

  // Look-ahead methods.
  bool LookAheadLabel(Label label) const {
    if (label == 0)
      return true;

    if (label_reachable_) {
      if (!reach_set_state_) {
        label_reachable_->SetState(s_);
        reach_set_state_ = true;
      }
      return label_reachable_->Reach(label);
    } else {
      return true;
    }
  }

  // Checks if there is a matching (possibly super-final) transition
  // at (s_, s).
  template <class L>
  bool LookAheadFst(const L &fst, StateId s);

  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    lfst_ = &fst;
    if (label_reachable_)
      label_reachable_->ReachInit(fst, copy);
  }

  template <class L>
  void InitLookAheadFst(const L& fst, bool copy = false) {
    lfst_ = static_cast<const Fst<Arc> *>(&fst);
    if (label_reachable_)
      label_reachable_->ReachInit(fst, copy);
  }

 private:
  // This allows base class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
  bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    return LookAheadFst(fst, s);
  }

  mutable M matcher_;
  const Fst<Arc> *lfst_;                     // Look-ahead FST
  LabelReachable<Arc, S> *label_reachable_;  // Label reachability info
  StateId s_;                                // Matcher state
  bool match_set_state_;                     // matcher_.SetState called?
  mutable bool reach_set_state_;             // reachable_.SetState called?
  bool error_;
};

template <class M, uint32 F, class S>
template <class L> inline
bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
  if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
    InitLookAheadFst(fst);

  SetLookAheadWeight(Weight::One());
  ClearLookAheadPrefix();

  if (!label_reachable_)
    return true;

  label_reachable_->SetState(s_, s);
  reach_set_state_ = true;

  bool compute_weight = F & kLookAheadWeight;
  bool compute_prefix = F & kLookAheadPrefix;

  bool reach_input = Type(false) == MATCH_OUTPUT;
  ArcIterator<L> aiter(fst, s);
  bool reach_arc = label_reachable_->Reach(&aiter, 0,
                                           internal::NumArcs(*lfst_, s),
                                           reach_input, compute_weight);
  Weight lfinal = internal::Final(*lfst_, s);
  bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal();
  if (reach_arc) {
    ssize_t begin = label_reachable_->ReachBegin();
    ssize_t end = label_reachable_->ReachEnd();
    if (compute_prefix && end - begin == 1 && !reach_final) {
      aiter.Seek(begin);
      SetLookAheadPrefix(aiter.Value());
      compute_weight = false;
    } else if (compute_weight) {
      SetLookAheadWeight(label_reachable_->ReachWeight());
    }
  }
  if (reach_final && compute_weight)
    SetLookAheadWeight(reach_arc ?
                       Plus(LookAheadWeight(), lfinal) : lfinal);

  return reach_arc || reach_final;
}


// Label-lookahead relabeling class.
template <class A>
class LabelLookAheadRelabeler {
 public:
  typedef typename A::Label Label;
  typedef LabelReachableData<Label> MatcherData;
  typedef AddOnPair<MatcherData, MatcherData> D;

  // Relabels matcher Fst - initialization function object.
  template <typename I>
  LabelLookAheadRelabeler(I **impl);

  // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
  template <class L>
  static void Relabel(MutableFst<A> *fst, const L &mfst,
                      bool relabel_input) {
    typename L::Impl *impl = mfst.GetImpl();
    D *data = impl->GetAddOn();
    LabelReachable<A> reachable(data->First() ?
                                  data->First() : data->Second());
    reachable.Relabel(fst, relabel_input);
  }

  // Returns relabeling pairs (cf. relabel.h::Relabel()).
  // Class L should be a label-lookahead Fst.
  // If 'avoid_collisions' is true, extra pairs are added to
  // ensure no collisions when relabeling automata that have
  // labels unseen here.
  template <class L>
  static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
                           bool avoid_collisions = false) {
    typename L::Impl *impl = mfst.GetImpl();
    D *data = impl->GetAddOn();
    LabelReachable<A> reachable(data->First() ?
                                  data->First() : data->Second());
    reachable.RelabelPairs(pairs, avoid_collisions);
  }
};

template <class A>
template <typename I> inline
LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
  Fst<A> &fst = (*impl)->GetFst();
  D *data = (*impl)->GetAddOn();
  const string name = (*impl)->Type();
  bool is_mutable = fst.Properties(kMutable, false);
  MutableFst<A> *mfst = 0;
  if (is_mutable) {
    mfst = static_cast<MutableFst<A> *>(&fst);
  } else {
    mfst = new VectorFst<A>(fst);
    data->IncrRefCount();
    delete *impl;
  }
  if (data->First()) {  // reach_input
    LabelReachable<A> reachable(data->First());
    reachable.Relabel(mfst, true);
    if (!FLAGS_save_relabel_ipairs.empty()) {
      vector<pair<Label, Label> > pairs;
      reachable.RelabelPairs(&pairs, true);
      WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
    }
  } else {
    LabelReachable<A> reachable(data->Second());
    reachable.Relabel(mfst, false);
    if (!FLAGS_save_relabel_opairs.empty()) {
      vector<pair<Label, Label> > pairs;
      reachable.RelabelPairs(&pairs, true);
      WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
    }
  }
  if (!is_mutable) {
    *impl = new I(*mfst, name);
    (*impl)->SetAddOn(data);
    delete mfst;
    data->DecrRefCount();
  }
}


// Generic lookahead matcher, templated on the FST definition
// - a wrapper around pointer to specific one.
template <class F>
class LookAheadMatcher {
 public:
  typedef F FST;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  typedef LookAheadMatcherBase<Arc> LBase;

  LookAheadMatcher(const F &fst, MatchType match_type) {
    base_ = fst.InitMatcher(match_type);
    if (!base_)
      base_ = new SortedMatcher<F>(fst, match_type);
    lookahead_ = false;
  }

  LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
    base_ = matcher.base_->Copy(safe);
    lookahead_ = matcher.lookahead_;
  }

  ~LookAheadMatcher() { delete base_; }

  // General matcher methods
  LookAheadMatcher<F> *Copy(bool safe = false) const {
      return new LookAheadMatcher<F>(*this, safe);
  }

  MatchType Type(bool test) const { return base_->Type(test); }
  void SetState(StateId s) { base_->SetState(s); }
  bool Find(Label label) { return base_->Find(label); }
  bool Done() const { return base_->Done(); }
  const Arc& Value() const { return base_->Value(); }
  void Next() { base_->Next(); }
  const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }

  uint64 Properties(uint64 props) const { return base_->Properties(props); }

  uint32 Flags() const { return base_->Flags(); }

  // Look-ahead methods
  bool LookAheadLabel(Label label) const {
    if (LookAheadCheck()) {
      LBase *lbase = static_cast<LBase *>(base_);
      return lbase->LookAheadLabel(label);
    } else {
      return true;
    }
  }

  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
    if (LookAheadCheck()) {
      LBase *lbase = static_cast<LBase *>(base_);
      return lbase->LookAheadFst(fst, s);
    } else {
      return true;
    }
  }

  Weight LookAheadWeight() const {
    if (LookAheadCheck()) {
      LBase *lbase = static_cast<LBase *>(base_);
      return lbase->LookAheadWeight();
    } else {
      return Weight::One();
    }
  }

  bool LookAheadPrefix(Arc *arc) const {
    if (LookAheadCheck()) {
      LBase *lbase = static_cast<LBase *>(base_);
      return lbase->LookAheadPrefix(arc);
    } else {
      return false;
    }
  }

  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    if (LookAheadCheck()) {
      LBase *lbase = static_cast<LBase *>(base_);
      lbase->InitLookAheadFst(fst, copy);
    }
  }

 private:
  bool LookAheadCheck() const {
    if (!lookahead_) {
      lookahead_ = base_->Flags() &
          (kInputLookAheadMatcher | kOutputLookAheadMatcher);
      if (!lookahead_) {
        FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
      }
    }
    return lookahead_;
  }

  MatcherBase<Arc> *base_;
  mutable bool lookahead_;

  void operator=(const LookAheadMatcher<Arc> &);  // disallow
};

}  // namespace fst

#endif  // FST_LIB_LOOKAHEAD_MATCHER_H__