// rmepsilon.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Functions and classes that implement epsilon-removal.

#ifndef FST_LIB_RMEPSILON_H__
#define FST_LIB_RMEPSILON_H__

#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <fst/slist.h>
#include <stack>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;

#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/connect.h>
#include <fst/factor-weight.h>
#include <fst/invert.h>
#include <fst/prune.h>
#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/topsort.h>


namespace fst {

// Options for the epsilon-removal algorithm. They extend the
// shortest-distance options (instantiated with EpsilonArcFilter) with
// the settings that control connection and pruning of the result.
template <class Arc, class Queue>
class RmEpsilonOptions
    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
  // Shorthand for the base class; only used to spell the base initializer.
  typedef ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >
      SdOptions;

 public:
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  bool connect;              // Connect output
  Weight weight_threshold;   // Pruning weight threshold.
  StateId state_threshold;   // Pruning state threshold.

  // 'q' and 'd' are forwarded to the shortest-distance options together
  // with a default-constructed epsilon arc filter and kNoStateId; 'c',
  // 'w' and 'n' initialize 'connect', 'weight_threshold' and
  // 'state_threshold' respectively.
  explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true,
                            Weight w = Weight::Zero(),
                            StateId n = kNoStateId)
      : SdOptions(q, EpsilonArcFilter<Arc>(), kNoStateId, d),
        connect(c),
        weight_threshold(w),
        state_threshold(n) {}

 private:
  RmEpsilonOptions();  // disallow
};

// Computation state of the epsilon-removal algorithm. One instance is
// reused across calls to Expand(); after each call the arcs and final
// weight computed for the expanded state are available through Arcs()
// and Final().
template <class Arc, class Queue>
class RmEpsilonState {
 public:
  typedef typename Arc::Label Label;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  // 'distance' is the vector shared with the shortest-distance
  // computation; it is read by Expand().
  RmEpsilonState(const Fst<Arc> &fst,
                 vector<Weight> *distance,
                 const RmEpsilonOptions<Arc, Queue> &opts)
      : fst_(fst),
        distance_(distance),
        sd_state_(fst_, distance, opts, true),
        expand_id_(0) {}

  // Compute arcs and final weight for state 's'
  void Expand(StateId s);

  // Returns arcs of expanded state.
  vector<Arc> &Arcs() { return arcs_; }

  // Returns final weight of expanded state.
  const Weight &Final() const { return final_; }

  // Return true if an error has occurred.
  bool Error() const { return sd_state_.Error(); }

 private:
  // Multipliers used by ElementKey to mix the labels into the hash.
  static const size_t kPrime0 = 7853;
  static const size_t kPrime1 = 7867;

  // Identifies an output arc up to its weight: two arcs with the same
  // labels and destination map to the same element.
  struct Element {
    Label ilabel;
    Label olabel;
    StateId nextstate;

    Element() {}

    Element(Label i, Label o, StateId s)
        : ilabel(i), olabel(o), nextstate(s) {}
  };

  // Hash functor for Element.
  struct ElementKey {
    size_t operator()(const Element &e) const {
      return static_cast<size_t>(e.nextstate +
                                 e.ilabel * kPrime0 +
                                 e.olabel * kPrime1);
    }
  };

  // Equality functor for Element: all three fields must match.
  struct ElementEqual {
    bool operator()(const Element &lhs, const Element &rhs) const {
      if (lhs.ilabel != rhs.ilabel) return false;
      if (lhs.olabel != rhs.olabel) return false;
      return lhs.nextstate == rhs.nextstate;
    }
  };

  typedef unordered_map<Element, pair<StateId, size_t>,
                        ElementKey, ElementEqual> ElementMap;

  const Fst<Arc> &fst_;
  // Distance from state being expanded in epsilon-closure.
  vector<Weight> *distance_;
  // Shortest distance algorithm computation state.
  ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
  // Maps an element 'e' to a pair 'p' corresponding to a position
  // in the arcs vector of the state being expanded. 'e' corresponds
  // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
  // equal to the state being expanded.
  ElementMap element_map_;
  EpsilonArcFilter<Arc> eps_filter_;
  stack<StateId> eps_queue_;       // Queue used to visit the epsilon-closure
  vector<bool> visited_;           // '[i] = true' if state 'i' has been visited
  slist<StateId> visited_states_;  // List of visited states
  vector<Arc> arcs_;               // Arcs of state being expanded
  Weight final_;                   // Final weight of state being expanded
  StateId expand_id_;              // Unique ID for each call to Expand

  DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
};

// Namespace-scope definitions of the static constants kPrime0 and kPrime1.
// Their values are given by the initializers in the RmEpsilonState class
// body, so no initializer is repeated here.
template <class Arc, class Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime0;
template <class Arc, class Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime1;


// Computes the arcs and final weight of 'source' in the epsilon-free
// result and stores them in 'arcs_' and 'final_'. The epsilon-closure
// of 'source' is explored depth-first; every non-epsilon arc met on the
// way is re-weighted by the distance of its origin state and merged
// with any earlier arc carrying the same labels and destination.
// If the shortest-distance computation reports an error, the function
// returns with no arcs and a final weight of Weight::Zero().
template <class Arc, class Queue>
void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
  // Reset the results of the previous expansion.
  arcs_.clear();
  final_ = Weight::Zero();

  sd_state_.ShortestDistance(source);
  if (sd_state_.Error()) return;

  const vector<Weight> &dist = *distance_;

  eps_queue_.push(source);
  while (!eps_queue_.empty()) {
    const StateId cur = eps_queue_.top();
    eps_queue_.pop();

    // Grow the visited table on demand, then skip states already seen.
    while (visited_.size() <= cur) visited_.push_back(false);
    if (visited_[cur]) continue;
    visited_[cur] = true;
    visited_states_.push_front(cur);

    for (ArcIterator< Fst<Arc> > aiter(fst_, cur);
         !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      arc.weight = Times(dist[cur], arc.weight);

      if (eps_filter_(arc)) {
        // Epsilon arc: schedule its destination unless already visited.
        while (visited_.size() <= arc.nextstate)
          visited_.push_back(false);
        if (!visited_[arc.nextstate])
          eps_queue_.push(arc.nextstate);
        continue;
      }

      // Non-epsilon arc: emit it, or fold its weight into an equal arc
      // already emitted during this expansion.
      const Element key(arc.ilabel, arc.olabel, arc.nextstate);
      typename ElementMap::iterator pos = element_map_.find(key);
      if (pos == element_map_.end()) {
        element_map_.insert(typename ElementMap::value_type(
            key, pair<StateId, size_t>(expand_id_, arcs_.size())));
        arcs_.push_back(arc);
        continue;
      }

      pair<StateId, size_t> &slot = pos->second;
      if (slot.first != expand_id_) {
        // Entry left over from an earlier expansion: claim it.
        slot.first = expand_id_;
        slot.second = arcs_.size();
        arcs_.push_back(arc);
      } else {
        Weight &merged = arcs_[slot.second].weight;
        merged = Plus(merged, arc.weight);
      }
    }

    final_ = Plus(final_, Times(dist[cur], fst_.Final(cur)));
  }

  // Clear the visited marks of exactly the states touched above.
  for (; !visited_states_.empty(); visited_states_.pop_front())
    visited_[visited_states_.front()] = false;

  ++expand_id_;
}

// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions.  This version modifies
// its input. It allows fine control via the options argument; see
// below for a simpler interface.
//
// The vector 'distance' will be used to hold the shortest distances
// during the epsilon-closure computation. The state queue discipline
// and convergence delta are taken in the options argument.
template <class Arc, class Queue>
void RmEpsilon(MutableFst<Arc> *fst,
               vector<typename Arc::Weight> *distance,
               const RmEpsilonOptions<Arc, Queue> &opts) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  // Nothing to do for an FST without a start state.
  const StateId start = fst->Start();
  if (start == kNoStateId) return;

  const StateId num_states = fst->NumStates();

  // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon
  // incoming transition or is the start state.
  vector<bool> noneps_in(num_states, false);
  noneps_in[start] = true;
  for (StateId s = 0; s < num_states; ++s) {
    for (ArcIterator<Fst<Arc> > aiter(*fst, s);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel == 0 && arc.olabel == 0) continue;
      noneps_in[arc.nextstate] = true;
    }
  }

  // States sorted in topological order when (acyclic) or generic
  // topological order (cyclic).
  vector<StateId> states;
  states.reserve(num_states);

  if (fst->Properties(kTopSorted, false) & kTopSorted) {
    // Already top-sorted: the state numbering is the order.
    for (StateId s = 0; s < num_states; ++s)
      states.push_back(s);
  } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
    vector<StateId> order;
    bool acyclic;
    TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
    // Sanity check: should be acyclic if property bit is set.
    if (!acyclic) {
      FSTERROR() << "RmEpsilon: inconsistent acyclic property bit";
      fst->SetProperties(kError, kError);
      return;
    }
    // 'order[i]' is the position of state 'i'; invert it.
    states.resize(order.size());
    for (StateId i = 0; i < order.size(); ++i)
      states[order[i]] = i;
  } else {
    uint64 props;
    vector<StateId> scc;
    SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
    DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
    // Chain the states of each component into a linked list headed by
    // 'first[component]' and threaded through 'next'. 'next' starts out
    // as kNoStateId, so copying an empty head leaves it unchanged.
    vector<StateId> first(scc.size(), kNoStateId);
    vector<StateId> next(scc.size(), kNoStateId);
    for (StateId i = 0; i < scc.size(); ++i) {
      const StateId component = scc[i];
      next[i] = first[component];
      first[component] = i;
    }
    for (StateId c = 0; c < first.size(); ++c) {
      for (StateId j = first[c]; j != kNoStateId; j = next[j])
        states.push_back(j);
    }
  }

  RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);

  // Process the states from the back of the ordering to the front,
  // replacing the arcs and final weight of every state that can be
  // entered through a non-epsilon arc (or is the start state).
  for (size_t pos = states.size(); pos > 0; --pos) {
    const StateId state = states[pos - 1];
    if (!noneps_in[state]) continue;

    rmeps_state.Expand(state);
    fst->SetFinal(state, rmeps_state.Final());
    fst->DeleteArcs(state);

    vector<Arc> &arcs = rmeps_state.Arcs();
    fst->ReserveArcs(state, arcs.size());
    // Arcs are added last-to-first, then the buffer is emptied.
    for (size_t k = arcs.size(); k > 0; --k)
      fst->AddArc(state, arcs[k - 1]);
    arcs.clear();
  }

  // The remaining states were not expanded; drop their arcs.
  for (StateId s = 0; s < num_states; ++s) {
    if (!noneps_in[s]) fst->DeleteArcs(s);
  }

  if (rmeps_state.Error()) fst->SetProperties(kError, kError);
  fst->SetProperties(
      RmEpsilonProperties(fst->Properties(kFstProperties, false)),
      kFstProperties);

  const bool prune_requested =
      opts.weight_threshold != Weight::Zero() ||
      opts.state_threshold != kNoStateId;
  if (prune_requested)
    Prune(fst, opts.weight_threshold, opts.state_threshold);

  // Connect() is skipped only when pruning used a weight threshold
  // and no state threshold.
  const bool connect_applies =
      opts.weight_threshold == Weight::Zero() ||
      opts.state_threshold != kNoStateId;
  if (opts.connect && connect_applies)
    Connect(fst);
}

// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions. This version modifies its
// input. It has a simplified interface; see above for a version that
// allows finer control.
//
// Complexity:
// - Time:
//   - Unweighted: O(V^2 + V E)
//   - Acyclic: O(V^2 + V E)
//   - Tropical semiring: O(V^2 log V + V E)
//   - General: exponential
// - Space: O(V E)
// where V = # of states visited, E = # of arcs.
//
// References:
// - Mehryar Mohri. Generic Epsilon-Removal and Input
//   Epsilon-Normalization Algorithms for Weighted Transducers,
//   "International Journal of Computer Science", 13(1):129-143 (2002).
template <class Arc>
void RmEpsilon(MutableFst<Arc> *fst,
               bool connect = true,
               typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
               typename Arc::StateId state_threshold = kNoStateId,
               float delta = kDelta) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef AutoQueue<StateId> Queue;

  // The distance vector is shared by the queue and the epsilon-closure
  // computation performed by the options-based overload.
  vector<Weight> distance;
  Queue state_queue(*fst, &distance, EpsilonArcFilter<Arc>());

  // Bundle the arguments and delegate to the options-based overload.
  const RmEpsilonOptions<Arc, Queue> opts(
      &state_queue, delta, connect, weight_threshold, state_threshold);
  RmEpsilon(fst, &distance, opts);
}


// Options for the delayed epsilon-removal FST: the cache options plus
// a 'delta' value, which defaults to kDelta.
struct RmEpsilonFstOptions : CacheOptions {
  float delta;

  // Default cache options with the given delta.
  explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}

  // Copies the given cache options and sets the delta.
  RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
      : CacheOptions(opts), delta(delta) {}
};


// Implementation of delayed RmEpsilonFst. States are expanded on
// demand: the first request for the final weight or the arcs of a
// state runs the epsilon-removal expansion for that state and stores
// the result in the cache.
template <class A>
class RmEpsilonFstImpl : public CacheImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<A> >::PushArc;
  using CacheBaseImpl< CacheState<A> >::HasArcs;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::SetArcs;
  using CacheBaseImpl< CacheState<A> >::SetFinal;
  using CacheBaseImpl< CacheState<A> >::SetStart;

  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Builds the implementation over a copy of 'fst'; properties and
  // symbol tables are derived from 'fst'.
  RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
      : CacheImpl<A>(opts),
        fst_(fst.Copy()),
        delta_(opts.delta),
        rmeps_state_(*fst_, &distance_, Options(&queue_, delta_, false)) {
    SetType("rmepsilon");
    uint64 props = fst.Properties(kFstProperties, false);
    SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copy constructor: takes a Copy(true) of the underlying FST and
  // builds fresh epsilon-removal state over it.
  RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
      : CacheImpl<A>(impl),
        fst_(impl.fst_->Copy(true)),
        delta_(impl.delta_),
        rmeps_state_(*fst_, &distance_, Options(&queue_, delta_, false)) {
    SetType("rmepsilon");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  ~RmEpsilonFstImpl() { delete fst_; }

  // The start state is that of the underlying FST.
  StateId Start() {
    if (!HasStart()) SetStart(fst_->Start());
    return CacheImpl<A>::Start();
  }

  Weight Final(StateId s) {
    if (!HasFinal(s)) Expand(s);
    return CacheImpl<A>::Final(s);
  }

  size_t NumArcs(StateId s) {
    EnsureArcs(s);
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    EnsureArcs(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    EnsureArcs(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  uint64 Properties() const { return Properties(kFstProperties); }

  // Set error if found; return FST impl properties.
  uint64 Properties(uint64 mask) const {
    if (mask & kError) {
      if (fst_->Properties(kError, false) || rmeps_state_.Error())
        SetProperties(kError, kError);
    }
    return FstImpl<A>::Properties(mask);
  }

  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    EnsureArcs(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Runs the epsilon-removal expansion for 's' and stores its final
  // weight and arcs in the cache.
  void Expand(StateId s) {
    rmeps_state_.Expand(s);
    SetFinal(s, rmeps_state_.Final());
    vector<A> &arcs = rmeps_state_.Arcs();
    // Arcs are pushed last-to-first, then the buffer is emptied.
    for (size_t i = arcs.size(); i > 0; --i)
      PushArc(s, arcs[i - 1]);
    arcs.clear();
    SetArcs(s);
  }

 private:
  typedef FifoQueue<StateId> Queue;
  typedef RmEpsilonOptions<A, Queue> Options;

  // Expands 's' unless its arcs are already in the cache.
  void EnsureArcs(StateId s) {
    if (!HasArcs(s)) Expand(s);
  }

  // NOTE: the declaration order below is the initialization order;
  // 'rmeps_state_' refers to 'fst_', 'distance_' and 'queue_', so it
  // must stay last.
  const Fst<A> *fst_;
  float delta_;
  vector<Weight> distance_;
  Queue queue_;
  RmEpsilonState<A, Queue> rmeps_state_;

  void operator=(const RmEpsilonFstImpl<A> &);  // disallow
};


// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions.  This version is a
// delayed Fst.
//
// Complexity:
// - Time:
//   - Unweighted: O(v^2 + v e)
//   - General: exponential
// - Space: O(v e)
// where v = # of states visited, e = # of arcs visited. Constant time
// to visit an input state or arc is assumed and exclusive of caching.
//
// References:
// - Mehryar Mohri. Generic Epsilon-Removal and Input
//   Epsilon-Normalization Algorithms for Weighted Transducers,
//   "International Journal of Computer Science", 13(1):129-143 (2002).
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
template <class A>
class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > {
 public:
  friend class ArcIterator< RmEpsilonFst<A> >;
  friend class StateIterator< RmEpsilonFst<A> >;

  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;
  typedef RmEpsilonFstImpl<A> Impl;

  // Delayed epsilon-removal of 'fst' with default options.
  RmEpsilonFst(const Fst<A> &fst)
      : Base(new Impl(fst, RmEpsilonFstOptions())) {}

  // Delayed epsilon-removal of 'fst' with the given options.
  RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
      : Base(new Impl(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false)
      : Base(fst, safe) {}

  // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
  virtual RmEpsilonFst<A> *Copy(bool safe = false) const {
    return new RmEpsilonFst<A>(*this, safe);
  }

  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  // Forwards to the implementation.
  virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    GetImpl()->InitArcIterator(s, data);
  }

 private:
  typedef ImplToFst<Impl> Base;

  // Makes visible to friends.
  Impl *GetImpl() const { return Base::GetImpl(); }

  void operator=(const RmEpsilonFst<A> &fst);  // disallow
};

// Specialization for RmEpsilonFst: state iteration is delegated to
// CacheStateIterator, constructed from the FST and its implementation.
template <class A>
class StateIterator< RmEpsilonFst<A> >
    : public CacheStateIterator< RmEpsilonFst<A> > {
  typedef CacheStateIterator< RmEpsilonFst<A> > Base;

 public:
  explicit StateIterator(const RmEpsilonFst<A> &fst)
      : Base(fst, fst.GetImpl()) {}
};


// Specialization for RmEpsilonFst: arc iteration is delegated to
// CacheArcIterator, after making sure the arcs of the state have been
// computed.
template <class A>
class ArcIterator< RmEpsilonFst<A> >
    : public CacheArcIterator< RmEpsilonFst<A> > {
 public:
  typedef typename A::StateId StateId;

  ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
      : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) {
    // Expand 's' unless its arcs are already in the cache.
    typename RmEpsilonFst<A>::Impl *impl = fst.GetImpl();
    if (!impl->HasArcs(s)) impl->Expand(s);
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};


// Installs a newly allocated state iterator for this FST in 'data'.
template <class A> inline
void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
  typedef StateIterator< RmEpsilonFst<A> > Iterator;
  data->base = new Iterator(*this);
}


// Useful alias when using StdArc: the delayed epsilon-removal FST
// instantiated for StdArc.
typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;

}  // namespace fst

#endif  // FST_LIB_RMEPSILON_H__