C++程序  |  416行  |  11.34 KB

#ifndef __FST_IO_H__
#define __FST_IO_H__

// fst-io.h 
// This is a copy of the OPENFST SDK application sample files ...
// except for the main functions ifdef'ed out
// 2007, 2008 Nuance Communications
//
// print-main.h compile-main.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Classes and functions to compile a binary Fst from textual input.
// Includes helper function for fstcompile.cc that templates the main
// on the arc type to support multiple and extensible arc types.

#include <fstream>
#include <sstream>

#include "fst/lib/fst.h"
#include "fst/lib/fstlib.h"
#include "fst/lib/fst-decl.h"
#include "fst/lib/vector-fst.h"
#include "fst/lib/arcsort.h"
#include "fst/lib/invert.h"

namespace fst {
  
  template <class A> class FstPrinter {
  public:
    typedef A Arc;
    typedef typename A::StateId StateId;
    typedef typename A::Label Label;
    typedef typename A::Weight Weight;
    
    FstPrinter(const Fst<A> &fst,
	       const SymbolTable *isyms,
	       const SymbolTable *osyms,
	       const SymbolTable *ssyms,
	       bool accep)
      : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
      accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {}
    
    // Print Fst to an output strm
    void Print(ostream *ostrm, const string &dest) {
      ostrm_ = ostrm;
      dest_ = dest;
      StateId start = fst_.Start();
      if (start == kNoStateId)
	return;
      // initial state first
      PrintState(start);
      for (StateIterator< Fst<A> > siter(fst_);
	   !siter.Done();
	   siter.Next()) {
	StateId s = siter.Value();
	if (s != start)
	  PrintState(s);
      }
    }
    
  private:
    // Maximum line length in text file.
    static const int kLineLen = 8096;
    
    void PrintId(int64 id, const SymbolTable *syms,
		 const char *name) const {
      if (syms) {
	string symbol = syms->Find(id);
	if (symbol == "") {
	  LOG(ERROR) << "FstPrinter: Integer " << id
		     << " is not mapped to any textual symbol"
		     << ", symbol table = " << syms->Name()
		     << ", destination = " << dest_;
	  exit(1);
	}
	*ostrm_ << symbol;
      } else {
	*ostrm_ << id;
      }
    }
    
    void PrintStateId(StateId s) const {
      PrintId(s, ssyms_, "state ID");
    }
    
    void PrintILabel(Label l) const {
      PrintId(l, isyms_, "arc input label");
    }
    
    void PrintOLabel(Label l) const {
      PrintId(l, osyms_, "arc output label");
    }
    
    void PrintState(StateId s) const {
      bool output = false;
      for (ArcIterator< Fst<A> > aiter(fst_, s);
	   !aiter.Done();
	   aiter.Next()) {
	Arc arc = aiter.Value();
	PrintStateId(s);
	*ostrm_ << "\t";
	PrintStateId(arc.nextstate);
	*ostrm_ << "\t";
	PrintILabel(arc.ilabel);
	if (!accep_) {
	  *ostrm_ << "\t";
	  PrintOLabel(arc.olabel);
	}
	if (arc.weight != Weight::One())
	  *ostrm_ << "\t" << arc.weight;
	*ostrm_ << "\n";
	output = true;
      }
      Weight final = fst_.Final(s);
      if (final != Weight::Zero() || !output) {
	PrintStateId(s);
	if (final != Weight::One()) {
	  *ostrm_ << "\t" << final;
	}
	*ostrm_ << "\n";
      }
    }
    
    const Fst<A> &fst_;
    const SymbolTable *isyms_;     // ilabel symbol table
    const SymbolTable *osyms_;     // olabel symbol table
    const SymbolTable *ssyms_;     // slabel symbol table
    bool accep_;                   // print as acceptor when possible
    ostream *ostrm_;                // binary FST destination
    string dest_;                  // binary FST destination name
    DISALLOW_EVIL_CONSTRUCTORS(FstPrinter);
  };
  
#if 0
  // Main function for fstprint templated on the arc type.
  template <class Arc>
    int PrintMain(int argc, char **argv, istream &istrm,
		  const FstReadOptions &opts) {
    Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts);
    if (!fst) return 1;
    
    string dest = "standard output";
    ostream *ostrm = &std::cout;
    if (argc == 3) {
      dest = argv[2];
      ostrm = new ofstream(argv[2]);
      if (!*ostrm) {
	LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2];
	return 0;
      }
    }
    ostrm->precision(9);
    
    const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
    
    if (!FLAGS_isymbols.empty() && !FLAGS_numeric) {
      isyms = SymbolTable::ReadText(FLAGS_isymbols);
      if (!isyms) exit(1);
    }
    
    if (!FLAGS_osymbols.empty() && !FLAGS_numeric) {
      osyms = SymbolTable::ReadText(FLAGS_osymbols);
      if (!osyms) exit(1);
    }
    
    if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) {
      ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
      if (!ssyms) exit(1);
    }
    
    if (!isyms && !FLAGS_numeric)
      isyms = fst->InputSymbols();
    if (!osyms && !FLAGS_numeric)
      osyms = fst->OutputSymbols();
    
    FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor);
    fstprinter.Print(ostrm, dest);
    
    if (isyms && !FLAGS_save_isymbols.empty())
      isyms->WriteText(FLAGS_save_isymbols);
    
    if (osyms && !FLAGS_save_osymbols.empty())
      osyms->WriteText(FLAGS_save_osymbols);
    
    if (ostrm != &std::cout)
      delete ostrm;
    return 0;
  }
#endif
  
  
  template <class A> class FstReader {
  public:
    typedef A Arc;
    typedef typename A::StateId StateId;
    typedef typename A::Label Label;
    typedef typename A::Weight Weight;
    
    FstReader(istream &istrm, const string &source,
	      const SymbolTable *isyms, const SymbolTable *osyms,
	      const SymbolTable *ssyms, bool accep, bool ikeep,
	      bool okeep, bool nkeep)
      : nline_(0), source_(source),
      isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
      nstates_(0), keep_state_numbering_(nkeep) {
      char line[kLineLen];
      while (istrm.getline(line, kLineLen)) {
	++nline_;
	vector<char *> col;
	SplitToVector(line, "\n\t ", &col, true);
	if (col.size() == 0 || col[0][0] == '\0')  // empty line
	  continue;
	if (col.size() > 5 ||
	    col.size() > 4 && accep ||
	    col.size() == 3 && !accep) {
	  LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_
		     << ", line = " << nline_;
	  exit(1);
	}
	StateId s = StrToStateId(col[0]);
	while (s >= fst_.NumStates())
	  fst_.AddState();
	if (nline_ == 1)
	  fst_.SetStart(s);
	
	Arc arc;
	StateId d = s;
	switch (col.size()) {
	case 1:
	  fst_.SetFinal(s, Weight::One());
	  break;
	case 2:
	  fst_.SetFinal(s, StrToWeight(col[1], true));
	  break;
	case 3:
	  arc.nextstate = d = StrToStateId(col[1]);
	  arc.ilabel = StrToILabel(col[2]);
	  arc.olabel = arc.ilabel;
	  arc.weight = Weight::One();
	  fst_.AddArc(s, arc);
	  break;
	case 4:
	  arc.nextstate = d = StrToStateId(col[1]);
	  arc.ilabel = StrToILabel(col[2]);
	  if (accep) {
	    arc.olabel = arc.ilabel;
	    arc.weight = StrToWeight(col[3], false);
	  } else {
	    arc.olabel = StrToOLabel(col[3]);
	    arc.weight = Weight::One();
	  }
	  fst_.AddArc(s, arc);
	  break;
	case 5:
	  arc.nextstate = d = StrToStateId(col[1]);
	  arc.ilabel = StrToILabel(col[2]);
	  arc.olabel = StrToOLabel(col[3]);
	  arc.weight = StrToWeight(col[4], false);
	  fst_.AddArc(s, arc);
	}
	while (d >= fst_.NumStates())
	  fst_.AddState();
      }
      if (ikeep)
	fst_.SetInputSymbols(isyms);
      if (okeep)
	fst_.SetOutputSymbols(osyms);
    }
    
    const VectorFst<A> &Fst() const { return fst_; }
    
  private:
    // Maximum line length in text file.
    static const int kLineLen = 8096;
    
    int64 StrToId(const char *s, const SymbolTable *syms,
		  const char *name) const {
      int64 n;
      
      if (syms) {
	n = syms->Find(s);
	if (n < 0) {
	  LOG(ERROR) << "FstReader: Symbol \"" << s
		     << "\" is not mapped to any integer " << name
		     << ", symbol table = " << syms->Name()
		     << ", source = " << source_ << ", line = " << nline_;
	  exit(1);
	}
      } else {
	char *p;
	n = strtoll(s, &p, 10);
	if (p < s + strlen(s) || n < 0) {
	  LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s
		     << "\", source = " << source_ << ", line = " << nline_;
	  exit(1);
	}
      }
      return n;
    }
    
    StateId StrToStateId(const char *s) {
      StateId n = StrToId(s, ssyms_, "state ID");
      
      if (keep_state_numbering_)
	return n;
      
      // remap state IDs to make dense set
      typename hash_map<StateId, StateId>::const_iterator it = states_.find(n);
      if (it == states_.end()) {
	states_[n] = nstates_;
	return nstates_++;
      } else {
	return it->second;
      }
    }
    
    StateId StrToILabel(const char *s) const {
      return StrToId(s, isyms_, "arc ilabel");
    }
    
    StateId StrToOLabel(const char *s) const {
      return StrToId(s, osyms_, "arc olabel");
    }
    
    Weight StrToWeight(const char *s, bool allow_zero) const {
      Weight w;
      istringstream strm(s);
      strm >> w;
      if (strm.fail() || !allow_zero && w == Weight::Zero()) {
	LOG(ERROR) << "FstReader: Bad weight = \"" << s
		   << "\", source = " << source_ << ", line = " << nline_;
	exit(1);
      }
      return w;
    }
    
    VectorFst<A> fst_;
    size_t nline_;
    string source_;                      // text FST source name
    const SymbolTable *isyms_;           // ilabel symbol table
    const SymbolTable *osyms_;           // olabel symbol table
    const SymbolTable *ssyms_;           // slabel symbol table
    hash_map<StateId, StateId> states_;  // state ID map
    StateId nstates_;                    // number of seen states
    bool keep_state_numbering_;
    DISALLOW_EVIL_CONSTRUCTORS(FstReader);
  };
  
#if 0
  // Main function for fstcompile templated on the arc type.  Last two
  // arguments unneeded since fstcompile passes the arc type as a flag
  // unlike the other mains, which infer the arc type from an input Fst.
  template <class Arc>
    int CompileMain(int argc, char **argv, istream& /* strm */,
		    const FstReadOptions & /* opts */) {
    char *ifilename = "standard input";
    istream *istrm = &std::cin;
    if (argc > 1 && strcmp(argv[1], "-") != 0) {
      ifilename = argv[1];
      istrm = new ifstream(ifilename);
      if (!*istrm) {
	LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename;
	return 1;
      }
    }
    const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
    
    if (!FLAGS_isymbols.empty()) {
      isyms = SymbolTable::ReadText(FLAGS_isymbols);
      if (!isyms) exit(1);
    }
    
    if (!FLAGS_osymbols.empty()) {
      osyms = SymbolTable::ReadText(FLAGS_osymbols);
      if (!osyms) exit(1);
    }
    
    if (!FLAGS_ssymbols.empty()) {
      ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
      if (!ssyms) exit(1);
    }
    
    FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms,
			     FLAGS_acceptor, FLAGS_keep_isymbols,
			     FLAGS_keep_osymbols, FLAGS_keep_state_numbering);
    
    const Fst<Arc> *fst = &fstreader.Fst();
    if (FLAGS_fst_type != "vector") {
      fst = Convert<Arc>(*fst, FLAGS_fst_type);
      if (!fst) return 1;
    }
    fst->Write(argc > 2 ? argv[2] : "");
    if (istrm != &std::cin)
      delete istrm;
    return 0;
  }
#endif
  
}  // namespace fst

#endif /* __FST_IO_H__ */