// tuple-weight.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google (Cyril Allauzen)
//
// \file
// Tuple weight set operation definitions.

#ifndef FST_LIB_TUPLE_WEIGHT_H__
#define FST_LIB_TUPLE_WEIGHT_H__

#include <string>
#include <vector>
using std::vector;

#include <fst/weight.h>


DECLARE_string(fst_weight_parentheses);
DECLARE_string(fst_weight_separator);

namespace fst {

// Forward declarations. TupleWeight befriends the matching operator>>
// specialization (friend declaration at the end of the class body) so that
// operator>> can call the protected ReadNoParen/ReadWithParen helpers; naming
// that specialization as a friend requires the operator template to be
// declared before the class.
template<class W, unsigned int n> class TupleWeight;
template <class W, unsigned int n>
istream &operator>>(istream &strm, TupleWeight<W, n> &w);

// n-tuple weight: an element of the n-th Cartesian power of W.
template <class W, unsigned int n>
class TupleWeight {
 public:
  typedef TupleWeight<typename W::ReverseWeight, n> ReverseWeight;

  // Every component is default-constructed.
  TupleWeight() {}

  TupleWeight(const TupleWeight &w) {
    for (size_t i = 0; i < n; ++i)
      values_[i] = w.values_[i];
  }

  // Builds a tuple from the range [begin, end). At most n elements are
  // consumed. With a shorter range, the remaining components keep their
  // default-constructed value. Elements past the n-th are ignored, so
  // values_ is never written out of bounds. Only increment, dereference
  // and inequality are required of Iterator (no random access).
  template <class Iterator>
  TupleWeight(Iterator begin, Iterator end) {
    size_t i = 0;
    for (Iterator iter = begin; iter != end && i < n; ++iter, ++i)
      values_[i] = *iter;
  }

  // Builds a tuple whose n components are all copies of w.
  TupleWeight(const W &w) {
    for (size_t i = 0; i < n; ++i)
      values_[i] = w;
  }

  // Tuple whose every component is W::Zero().
  static const TupleWeight<W, n> &Zero() {
    static const TupleWeight<W, n> zero(W::Zero());
    return zero;
  }

  // Tuple whose every component is W::One().
  static const TupleWeight<W, n> &One() {
    static const TupleWeight<W, n> one(W::One());
    return one;
  }

  // Tuple whose every component is W::NoWeight().
  static const TupleWeight<W, n> &NoWeight() {
    static const TupleWeight<W, n> no_weight(W::NoWeight());
    return no_weight;
  }

  // Number of components in the tuple.
  static unsigned int Length() {
    return n;
  }

  // Reads the n components in order, delegating to W::Read.
  istream &Read(istream &strm) {
    for (size_t i = 0; i < n; ++i)
      values_[i].Read(strm);
    return strm;
  }

  // Writes the n components in order, delegating to W::Write.
  ostream &Write(ostream &strm) const {
    for (size_t i = 0; i < n; ++i)
      values_[i].Write(strm);
    return strm;
  }

  TupleWeight<W, n> &operator=(const TupleWeight<W, n> &w) {
    for (size_t i = 0; i < n; ++i)
      values_[i] = w.values_[i];
    return *this;
  }

  // True iff every component is a member of W's weight set.
  bool Member() const {
    bool member = true;
    for (size_t i = 0; i < n; ++i)
      member = member && values_[i].Member();
    return member;
  }

  // Combines the component hashes as hash = 5 * hash + component_hash.
  size_t Hash() const {
    uint64 hash = 0;
    for (size_t i = 0; i < n; ++i)
      hash = 5 * hash + values_[i].Hash();
    return size_t(hash);
  }

  // Component-wise quantization with step delta.
  TupleWeight<W, n> Quantize(float delta = kDelta) const {
    TupleWeight<W, n> w;
    for (size_t i = 0; i < n; ++i)
      w.values_[i] = values_[i].Quantize(delta);
    return w;
  }

  // Component-wise reversal. The result is built directly as a
  // ReverseWeight and filled through its public SetValue(). This keeps the
  // method well-formed when W::ReverseWeight is a different type from W.
  // In that case a TupleWeight<W, n> can neither store the reversed
  // components nor convert to ReverseWeight.
  ReverseWeight Reverse() const {
    ReverseWeight w;
    for (size_t i = 0; i < n; ++i)
      w.SetValue(i, values_[i].Reverse());
    return w;
  }

  // Accessors for the i-th component (no bounds checking).
  const W& Value(size_t i) const { return values_[i]; }

  void SetValue(size_t i, const W &w) { values_[i] = w; }

 protected:
  // Reads a TupleWeight when there are no parentheses around tuple terms,
  // e.g. "a<sep>b<sep>c". Leading whitespace is skipped. The first n - 1
  // components are collected up to each separator and parsed through an
  // istringstream. The n-th component is parsed directly from strm.
  // Sets badbit if EOF is hit before all separators were seen.
  inline static istream &ReadNoParen(istream &strm,
                                     TupleWeight<W, n> &w,
                                     char separator) {
    int c;
    do {
      c = strm.get();
    } while (isspace(c));

    for (size_t i = 0; i < n - 1; ++i) {
      string s;
      if (i)
        c = strm.get();
      while (c != separator) {
        if (c == EOF) {
          strm.clear(std::ios::badbit);
          return strm;
        }
        s += c;
        c = strm.get();
      }
      // read (i+1)-th element
      istringstream sstrm(s);
      W r = W::Zero();
      sstrm >> r;
      w.SetValue(i, r);
    }

    // For n == 1 the loop above never uses `c`, the first non-space
    // character of the only component. Push it back so that the component
    // is parsed in full rather than losing its first character.
    if (n == 1 && c != EOF)
      strm.putback(static_cast<char>(c));

    // read n-th element
    W r = W::Zero();
    strm >> r;
    w.SetValue(n - 1, r);

    return strm;
  }

  // Reads a TupleWeight when there are parentheses around tuple terms,
  // e.g. "(a<sep>(b<sep>c)<sep>d)". For the first n - 1 components, a
  // separator ends a component only when it is not nested inside
  // parentheses opened within that component. The n-th component is
  // everything up to EOF. It must end with close_paren, ignoring trailing
  // whitespace. Sets badbit on a missing open paren or early EOF. Sets
  // failbit on an unmatched close paren or a missing final close paren.
  inline static istream &ReadWithParen(istream &strm,
                                       TupleWeight<W, n> &w,
                                       char separator,
                                       char open_paren,
                                       char close_paren) {
    int c;
    do {
      c = strm.get();
    } while (isspace(c));

    if (c != open_paren) {
      FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
      strm.clear(std::ios::badbit);
      return strm;
    }

    for (size_t i = 0; i < n - 1; ++i) {
      // read (i+1)-th element; `depth` counts parens opened inside it.
      size_t depth = 0;
      string s;
      c = strm.get();
      while (c != separator || depth > 0) {
        if (c == EOF) {
          strm.clear(std::ios::badbit);
          return strm;
        }
        s += c;
        // if parens encountered before separator, they must be matched
        if (c == open_paren) {
          ++depth;
        } else if (c == close_paren) {
          // Fail for mismatched parens
          if (depth == 0) {
            strm.clear(std::ios::failbit);
            return strm;
          }
          --depth;
        }
        c = strm.get();
      }
      istringstream sstrm(s);
      W r = W::Zero();
      sstrm >> r;
      w.SetValue(i, r);
    }

    // read n-th element
    string s;
    c = strm.get();
    while (c != EOF) {
      s += c;
      c = strm.get();
    }
    // Ignore trailing whitespace (e.g. a newline) after the closing paren.
    while (!s.empty() && isspace(static_cast<unsigned char>(*s.rbegin())))
      s.erase(s.size() - 1, 1);
    if (s.empty() || *s.rbegin() != close_paren) {
      FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
      strm.clear(std::ios::failbit);
      return strm;
    }
    s.erase(s.size() - 1, 1);
    istringstream sstrm(s);
    W r = W::Zero();
    sstrm >> r;
    w.SetValue(n - 1, r);

    return strm;
  }


 private:
  W values_[n];

  friend istream &operator>><W, n>(istream&, TupleWeight<W, n>&);
};

// Component-wise equality: true iff every pair of corresponding components
// compares equal under W's operator==. Returns at the first mismatch, so no
// further components are compared.
template <class W, unsigned int n>
inline bool operator==(const TupleWeight<W, n> &w1,
                       const TupleWeight<W, n> &w2) {
  for (size_t k = 0; k < n; ++k) {
    if (!(w1.Value(k) == w2.Value(k)))
      return false;
  }
  return true;
}

// Component-wise inequality: true as soon as one pair of corresponding
// components differs under W's operator!=; false if none do.
template <class W, unsigned int n>
inline bool operator!=(const TupleWeight<W, n> &w1,
                       const TupleWeight<W, n> &w2) {
  for (size_t k = 0; k < n; ++k) {
    if (w1.Value(k) != w2.Value(k))
      return true;
  }
  return false;
}

// Approximate equality: true iff ApproxEqual holds, with the given delta,
// for every pair of corresponding components. Stops at the first pair that
// is not approximately equal.
template <class W, unsigned int n>
inline bool ApproxEqual(const TupleWeight<W, n> &w1,
                        const TupleWeight<W, n> &w2,
                        float delta = kDelta) {
  for (size_t k = 0; k < n; ++k) {
    if (!ApproxEqual(w1.Value(k), w2.Value(k), delta))
      return false;
  }
  return true;
}

// Text output: writes the components joined by the single character in
// FLAGS_fst_weight_separator. If FLAGS_fst_weight_parentheses is non-empty,
// its two characters are written before and after the tuple. An invalid
// flag value is reported through FSTERROR() and leaves strm with badbit set,
// with nothing written.
template <class W, unsigned int n>
inline ostream &operator<<(ostream &strm, const TupleWeight<W, n> &w) {
  if (FLAGS_fst_weight_separator.size() != 1) {
    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const bool write_parens = !FLAGS_fst_weight_parentheses.empty();
  if (write_parens && FLAGS_fst_weight_parentheses.size() != 2) {
    FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char separator = FLAGS_fst_weight_separator[0];

  if (write_parens)
    strm << FLAGS_fst_weight_parentheses[0];
  for (size_t k = 0; k < n; ++k) {
    if (k != 0)
      strm << separator;
    strm << w.Value(k);
  }
  if (write_parens)
    strm << FLAGS_fst_weight_parentheses[1];

  return strm;
}

// Text input, mirroring operator<<. The separator is the single character in
// FLAGS_fst_weight_separator. An empty FLAGS_fst_weight_parentheses selects
// the unparenthesized format; otherwise its two characters are the open and
// close parens. An invalid flag value is reported through FSTERROR() and
// leaves strm with badbit set.
template <class W, unsigned int n>
inline istream &operator>>(istream &strm, TupleWeight<W, n> &w) {
  if (FLAGS_fst_weight_separator.size() != 1) {
    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char separator = FLAGS_fst_weight_separator[0];

  if (FLAGS_fst_weight_parentheses.empty())
    return TupleWeight<W, n>::ReadNoParen(strm, w, separator);

  if (FLAGS_fst_weight_parentheses.size() != 2) {
    FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char open_paren = FLAGS_fst_weight_parentheses[0];
  const char close_paren = FLAGS_fst_weight_parentheses[1];
  return TupleWeight<W, n>::ReadWithParen(strm, w, separator, open_paren,
                                          close_paren);
}



}  // namespace fst

#endif  // FST_LIB_TUPLE_WEIGHT_H__