// float-weight.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Float weight set and associated semiring operation definitions.
//

#ifndef FST_LIB_FLOAT_WEIGHT_H__
#define FST_LIB_FLOAT_WEIGHT_H__

#include <climits>
#include <cmath>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>

#include <fst/util.h>
#include <fst/weight.h>


namespace fst {

// Special floating-point constants for type T, taken from numeric_limits<T>.
template <class T>
class FloatLimits {
 public:
  // Positive infinity of T.
  static const T PosInfinity() {
    return numeric_limits<T>::infinity();
  }

  // Negative infinity of T, i.e. the negation of PosInfinity().
  static const T NegInfinity() {
    return -numeric_limits<T>::infinity();
  }

  // Quiet NaN of T, used by the weight classes to mark an invalid value.
  static const T NumberBad() {
    return numeric_limits<T>::quiet_NaN();
  }

};

// Weight class templated on a floating-point type T.
// Wraps a single value of type T and provides stream I/O, hashing and read
// access to it. The semiring classes in this file derive from it.
template <class T = float>
class FloatWeightTpl {
 public:
  // Default construction leaves the stored value uninitialized.
  FloatWeightTpl() {}

  // Implicit conversion from the underlying value type.
  FloatWeightTpl(T f) : value_(f) {}

  FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}

  FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
    value_ = w.value_;
    return *this;
  }

  // Reads the stored value from strm by delegating to ReadType().
  istream &Read(istream &strm) {
    return ReadType(strm, &value_);
  }

  // Writes the stored value to strm by delegating to WriteType().
  ostream &Write(ostream &strm) const {
    return WriteType(strm, value_);
  }

  // Returns a hash built from the leading bytes of the value's object
  // representation (as many as fit into a size_t).
  size_t Hash() const {
    // +0.0 and -0.0 compare equal under operator==, so they must hash alike;
    // 0 is what the all-zero bit pattern of +0.0 produces below.
    if (value_ == 0) return 0;
    size_t bits = 0;
    // memcpy avoids reading an inactive union member and never copies more
    // bytes than a size_t can hold.
    std::memcpy(&bits, &value_,
                sizeof(T) < sizeof(bits) ? sizeof(T) : sizeof(bits));
    return bits;
  }

  // Read-only access to the stored value.
  const T &Value() const { return value_; }

 protected:
  // Overwrites the stored value.
  void SetValue(const T &f) { value_ = f; }

  // Suffix appended to the weight type names: empty when T has the size of
  // float, otherwise the size of T in bits as formatted by Int64ToStr().
  inline static string GetPrecisionString() {
    int64 size = sizeof(T);
    if (size == sizeof(float)) return "";
    size *= CHAR_BIT;

    string result;
    Int64ToStr(size, &result);
    return result;
  }

 private:
  T value_;
};

// Single-precision float weight: FloatWeightTpl instantiated with T = float.
typedef FloatWeightTpl<float> FloatWeight;

// Equality of two float weights: true when their stored values compare equal.
template <class T>
inline bool operator==(const FloatWeightTpl<T> &w1,
                       const FloatWeightTpl<T> &w2) {
  // The copies are volatile to thwart over-aggressive compiler optimizations
  // that lead to problems, especially with NaturalLess().
  volatile T lhs = w1.Value();
  volatile T rhs = w2.Value();
  const bool equal = (lhs == rhs);
  return equal;
}

// Non-template overload for double-precision weights; forwards to the
// templated operator== instantiated with T = double.
inline bool operator==(const FloatWeightTpl<double> &w1,
                       const FloatWeightTpl<double> &w2) {
  bool equal = operator==<double>(w1, w2);
  return equal;
}

// Non-template overload for single-precision weights; forwards to the
// templated operator== instantiated with T = float.
inline bool operator==(const FloatWeightTpl<float> &w1,
                       const FloatWeightTpl<float> &w2) {
  bool equal = operator==<float>(w1, w2);
  return equal;
}

// Inequality of two float weights, defined as the negation of operator==.
template <class T>
inline bool operator!=(const FloatWeightTpl<T> &w1,
                       const FloatWeightTpl<T> &w2) {
  if (w1 == w2)
    return false;
  return true;
}

// Non-template overload for double-precision weights; forwards to the
// templated operator!= instantiated with T = double.
inline bool operator!=(const FloatWeightTpl<double> &w1,
                       const FloatWeightTpl<double> &w2) {
  bool differ = operator!=<double>(w1, w2);
  return differ;
}

// Non-template overload for single-precision weights; forwards to the
// templated operator!= instantiated with T = float.
inline bool operator!=(const FloatWeightTpl<float> &w1,
                       const FloatWeightTpl<float> &w2) {
  bool differ = operator!=<float>(w1, w2);
  return differ;
}

// Approximate equality: true when neither value exceeds the other by more
// than delta. Any comparison involving NaN is false, so NaN is never
// approximately equal to anything.
template <class T>
inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
                        const FloatWeightTpl<T> &w2,
                        float delta = kDelta) {
  const T &lhs = w1.Value();
  const T &rhs = w2.Value();
  // Written with positive "<=" tests (not negated ">") to keep NaN handling.
  if (!(lhs <= rhs + delta))
    return false;
  return rhs <= lhs + delta;
}

// Text output of a float weight. The infinities are printed as "Infinity" and
// "-Infinity", NaN as "BadNumber", and any other value with the stream's
// normal formatting for T.
template <class T>
inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
  const T &value = w.Value();
  if (value == FloatLimits<T>::PosInfinity())
    return strm << "Infinity";
  if (value == FloatLimits<T>::NegInfinity())
    return strm << "-Infinity";
  if (value != value)   // Only NaN differs from itself
    return strm << "BadNumber";
  return strm << value;
}

// Text input of a float weight. Reads one whitespace-delimited token and
// accepts "Infinity", "-Infinity" or anything strtod() consumes completely.
// If strtod() stops before the end of the token, the stream state is set to
// badbit and w is left untouched.
template <class T>
inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
  string token;
  strm >> token;

  if (token == "Infinity") {
    w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
    return strm;
  }
  if (token == "-Infinity") {
    w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
    return strm;
  }

  char *stop;
  T parsed = strtod(token.c_str(), &stop);
  // Trailing characters that strtod() did not consume mean a malformed number.
  if (stop < token.c_str() + token.size()) {
    strm.clear(std::ios::badbit);
    return strm;
  }
  w = FloatWeightTpl<T>(parsed);
  return strm;
}


// Tropical semiring: (min, +, inf, 0)
// Zero() is +infinity, One() is 0 and NoWeight() is NaN.
template <class T>
class TropicalWeightTpl : public FloatWeightTpl<T> {
 public:
  using FloatWeightTpl<T>::Value;

  typedef TropicalWeightTpl<T> ReverseWeight;

  // Leaves the value uninitialized, like the base class default constructor.
  TropicalWeightTpl() : FloatWeightTpl<T>() {}

  // Implicit conversion from the underlying value type.
  TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}

  // Additive identity: +infinity.
  static const TropicalWeightTpl<T> Zero() {
    const T inf = FloatLimits<T>::PosInfinity();
    return TropicalWeightTpl<T>(inf);
  }

  // Multiplicative identity: 0.
  static const TropicalWeightTpl<T> One() {
    return TropicalWeightTpl<T>(0.0F);
  }

  // Invalid weight marker: NaN.
  static const TropicalWeightTpl<T> NoWeight() {
    const T nan = FloatLimits<T>::NumberBad();
    return TropicalWeightTpl<T>(nan);
  }

  // Type name: "tropical" followed by the precision suffix of T.
  static const string &Type() {
    static const string type =
        string("tropical") + FloatWeightTpl<T>::GetPrecisionString();
    return type;
  }

  // A value is a member unless it is NaN or -infinity.
  bool Member() const {
    const T &v = Value();
    if (v != v)  // Only NaN differs from itself
      return false;
    return v != FloatLimits<T>::NegInfinity();
  }

  // Rounds a finite value to the nearest multiple of delta; infinities and
  // NaN are returned unchanged.
  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
    const T &v = Value();
    const bool is_infinite = v == FloatLimits<T>::NegInfinity() ||
                             v == FloatLimits<T>::PosInfinity();
    const bool is_nan = v != v;
    if (is_infinite || is_nan)
      return *this;
    return TropicalWeightTpl<T>(floor(v/delta + 0.5F) * delta);
  }

  // The reverse of a tropical weight is the weight itself.
  TropicalWeightTpl<T> Reverse() const { return *this; }

  static uint64 Properties() {
    const uint64 props = kLeftSemiring | kRightSemiring | kCommutative |
        kPath | kIdempotent;
    return props;
  }
};

// Single-precision tropical weight: TropicalWeightTpl instantiated with
// T = float.
typedef TropicalWeightTpl<float> TropicalWeight;

// Tropical addition: the operand with the smaller value (w2 on ties).
// Returns NoWeight() when either operand is not a member.
template <class T>
inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
                                 const TropicalWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return TropicalWeightTpl<T>::NoWeight();
  if (w1.Value() < w2.Value())
    return w1;
  return w2;
}

// Non-template overload for single-precision tropical weights; forwards to
// the templated Plus instantiated with T = float.
inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
                                     const TropicalWeightTpl<float> &w2) {
  TropicalWeightTpl<float> sum = Plus<float>(w1, w2);
  return sum;
}

// Non-template overload for double-precision tropical weights; forwards to
// the templated Plus instantiated with T = double.
inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
                                      const TropicalWeightTpl<double> &w2) {
  TropicalWeightTpl<double> sum = Plus<double>(w1, w2);
  return sum;
}

// Tropical multiplication: the sum of the two values. An operand equal to
// +infinity (Zero()) is returned as is; NoWeight() is returned when either
// operand is not a member.
template <class T>
inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
                                  const TropicalWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return TropicalWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::PosInfinity();
  const T lhs = w1.Value();
  const T rhs = w2.Value();
  if (lhs == inf)
    return w1;
  if (rhs == inf)
    return w2;
  return TropicalWeightTpl<T>(lhs + rhs);
}

// Non-template overload for single-precision tropical weights; forwards to
// the templated Times instantiated with T = float.
inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
                                      const TropicalWeightTpl<float> &w2) {
  TropicalWeightTpl<float> product = Times<float>(w1, w2);
  return product;
}

// Non-template overload for double-precision tropical weights; forwards to
// the templated Times instantiated with T = double.
inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
                                       const TropicalWeightTpl<double> &w2) {
  TropicalWeightTpl<double> product = Times<double>(w1, w2);
  return product;
}

// Tropical division: the difference w1 - w2 of the two values.
// Dividing by +infinity (Zero()) yields NaN; otherwise a +infinity dividend
// stays +infinity. NoWeight() is returned when either operand is not a
// member. The typ argument is accepted but not consulted.
template <class T>
inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
                                   const TropicalWeightTpl<T> &w2,
                                   DivideType typ = DIVIDE_ANY) {
  if (!(w1.Member() && w2.Member()))
    return TropicalWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::PosInfinity();
  const T dividend = w1.Value();
  const T divisor = w2.Value();
  if (divisor == inf)
    return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad());
  if (dividend == inf)
    return TropicalWeightTpl<T>(inf);
  return TropicalWeightTpl<T>(dividend - divisor);
}

// Non-template overload for single-precision tropical weights; forwards to
// the templated Divide instantiated with T = float.
inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
                                       const TropicalWeightTpl<float> &w2,
                                       DivideType typ = DIVIDE_ANY) {
  TropicalWeightTpl<float> quotient = Divide<float>(w1, w2, typ);
  return quotient;
}

// Non-template overload for double-precision tropical weights; forwards to
// the templated Divide instantiated with T = double.
inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
                                        const TropicalWeightTpl<double> &w2,
                                        DivideType typ = DIVIDE_ANY) {
  TropicalWeightTpl<double> quotient = Divide<double>(w1, w2, typ);
  return quotient;
}


// Log semiring: (-log(e^-x + e^-y), +, inf, 0)
// Zero() is +infinity, One() is 0 and NoWeight() is NaN.
template <class T>
class LogWeightTpl : public FloatWeightTpl<T> {
 public:
  using FloatWeightTpl<T>::Value;

  typedef LogWeightTpl<T> ReverseWeight;

  // Leaves the value uninitialized, like the base class default constructor.
  LogWeightTpl() : FloatWeightTpl<T>() {}

  // Implicit conversion from the underlying value type.
  LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}

  // Additive identity: +infinity.
  static const LogWeightTpl<T> Zero() {
    const T inf = FloatLimits<T>::PosInfinity();
    return LogWeightTpl<T>(inf);
  }

  // Multiplicative identity: 0.
  static const LogWeightTpl<T> One() {
    return LogWeightTpl<T>(0.0F);
  }

  // Invalid weight marker: NaN.
  static const LogWeightTpl<T> NoWeight() {
    const T nan = FloatLimits<T>::NumberBad();
    return LogWeightTpl<T>(nan);
  }

  // Type name: "log" followed by the precision suffix of T.
  static const string &Type() {
    static const string type =
        string("log") + FloatWeightTpl<T>::GetPrecisionString();
    return type;
  }

  // A value is a member unless it is NaN or -infinity.
  bool Member() const {
    const T &v = Value();
    if (v != v)  // Only NaN differs from itself
      return false;
    return v != FloatLimits<T>::NegInfinity();
  }

  // Rounds a finite value to the nearest multiple of delta; infinities and
  // NaN are returned unchanged.
  LogWeightTpl<T> Quantize(float delta = kDelta) const {
    const T &v = Value();
    const bool is_infinite = v == FloatLimits<T>::NegInfinity() ||
                             v == FloatLimits<T>::PosInfinity();
    const bool is_nan = v != v;
    if (is_infinite || is_nan)
      return *this;
    return LogWeightTpl<T>(floor(v/delta + 0.5F) * delta);
  }

  // The reverse of a log weight is the weight itself.
  LogWeightTpl<T> Reverse() const { return *this; }

  static uint64 Properties() {
    const uint64 props = kLeftSemiring | kRightSemiring | kCommutative;
    return props;
  }
};

// Single-precision log weight: LogWeightTpl instantiated with T = float.
typedef LogWeightTpl<float> LogWeight;
// Double-precision log weight: LogWeightTpl instantiated with T = double.
typedef LogWeightTpl<double> Log64Weight;

// Returns log(1 + e^-x).
// Evaluated with log1p() so the result stays accurate when e^-x is far below
// 1; forming 1 + e^-x first would round the sum to exactly 1 and give 0.
template <class T>
inline T LogExp(T x) { return log1p(exp(-x)); }

// Log-semiring addition: -log(e^-f1 + e^-f2) for the two stored values.
// An operand equal to +infinity (Zero()) leaves the other operand unchanged.
// Otherwise the smaller value is used as the base and LogExp() of the
// non-negative difference is subtracted from it.
template <class T>
inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
                            const LogWeightTpl<T> &w2) {
  const T inf = FloatLimits<T>::PosInfinity();
  const T lhs = w1.Value();
  const T rhs = w2.Value();
  if (lhs == inf)
    return w2;
  if (rhs == inf)
    return w1;
  if (lhs > rhs)
    return LogWeightTpl<T>(rhs - LogExp(lhs - rhs));
  return LogWeightTpl<T>(lhs - LogExp(rhs - lhs));
}

// Non-template overload for single-precision log weights; forwards to the
// templated Plus instantiated with T = float.
inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
                                const LogWeightTpl<float> &w2) {
  LogWeightTpl<float> sum = Plus<float>(w1, w2);
  return sum;
}

// Non-template overload for double-precision log weights; forwards to the
// templated Plus instantiated with T = double.
inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
                                 const LogWeightTpl<double> &w2) {
  LogWeightTpl<double> sum = Plus<double>(w1, w2);
  return sum;
}

// Log-semiring multiplication: the sum of the two values. An operand equal
// to +infinity (Zero()) is returned as is; NoWeight() is returned when either
// operand is not a member.
template <class T>
inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
                             const LogWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return LogWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::PosInfinity();
  const T lhs = w1.Value();
  const T rhs = w2.Value();
  if (lhs == inf)
    return w1;
  if (rhs == inf)
    return w2;
  return LogWeightTpl<T>(lhs + rhs);
}

// Non-template overload for single-precision log weights; forwards to the
// templated Times instantiated with T = float.
inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
                                 const LogWeightTpl<float> &w2) {
  LogWeightTpl<float> product = Times<float>(w1, w2);
  return product;
}

// Non-template overload for double-precision log weights; forwards to the
// templated Times instantiated with T = double.
inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
                                  const LogWeightTpl<double> &w2) {
  LogWeightTpl<double> product = Times<double>(w1, w2);
  return product;
}

// Log-semiring division: the difference w1 - w2 of the two values.
// Dividing by +infinity (Zero()) yields NaN; otherwise a +infinity dividend
// stays +infinity. NoWeight() is returned when either operand is not a
// member. The typ argument is accepted but not consulted.
template <class T>
inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
                              const LogWeightTpl<T> &w2,
                              DivideType typ = DIVIDE_ANY) {
  if (!(w1.Member() && w2.Member()))
    return LogWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::PosInfinity();
  const T dividend = w1.Value();
  const T divisor = w2.Value();
  if (divisor == inf)
    return LogWeightTpl<T>(FloatLimits<T>::NumberBad());
  if (dividend == inf)
    return LogWeightTpl<T>(inf);
  return LogWeightTpl<T>(dividend - divisor);
}

// Non-template overload for single-precision log weights; forwards to the
// templated Divide instantiated with T = float.
inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
                                  const LogWeightTpl<float> &w2,
                                  DivideType typ = DIVIDE_ANY) {
  LogWeightTpl<float> quotient = Divide<float>(w1, w2, typ);
  return quotient;
}

// Non-template overload for double-precision log weights; forwards to the
// templated Divide instantiated with T = double.
inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
                                   const LogWeightTpl<double> &w2,
                                   DivideType typ = DIVIDE_ANY) {
  LogWeightTpl<double> quotient = Divide<double>(w1, w2, typ);
  return quotient;
}

// MinMax semiring: (min, max, inf, -inf)
// Zero() is +infinity, One() is -infinity and NoWeight() is NaN.
template <class T>
class MinMaxWeightTpl : public FloatWeightTpl<T> {
 public:
  using FloatWeightTpl<T>::Value;

  typedef MinMaxWeightTpl<T> ReverseWeight;

  // Leaves the value uninitialized, like the base class default constructor.
  MinMaxWeightTpl() : FloatWeightTpl<T>() {}

  // Implicit conversion from the underlying value type.
  MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}

  // Additive identity: +infinity.
  static const MinMaxWeightTpl<T> Zero() {
    const T inf = FloatLimits<T>::PosInfinity();
    return MinMaxWeightTpl<T>(inf);
  }

  // Multiplicative identity: -infinity.
  static const MinMaxWeightTpl<T> One() {
    const T neg_inf = FloatLimits<T>::NegInfinity();
    return MinMaxWeightTpl<T>(neg_inf);
  }

  // Invalid weight marker: NaN.
  static const MinMaxWeightTpl<T> NoWeight() {
    const T nan = FloatLimits<T>::NumberBad();
    return MinMaxWeightTpl<T>(nan);
  }

  // Type name: "minmax" followed by the precision suffix of T.
  static const string &Type() {
    static const string type =
        string("minmax") + FloatWeightTpl<T>::GetPrecisionString();
    return type;
  }

  // Every value except NaN is a member (both infinities included).
  bool Member() const {
    const T &v = Value();
    // Only NaN differs from itself
    return !(v != v);
  }

  // Rounds a finite value to the nearest multiple of delta; infinities and
  // NaN are returned unchanged.
  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
    const T &v = Value();
    const bool is_infinite = v == FloatLimits<T>::NegInfinity() ||
                             v == FloatLimits<T>::PosInfinity();
    const bool is_nan = v != v;
    if (is_infinite || is_nan)
      return *this;
    return MinMaxWeightTpl<T>(floor(v/delta + 0.5F) * delta);
  }

  // The reverse of a min-max weight is the weight itself.
  MinMaxWeightTpl<T> Reverse() const { return *this; }

  static uint64 Properties() {
    const uint64 props = kLeftSemiring | kRightSemiring | kCommutative |
        kIdempotent | kPath;
    return props;
  }
};

// Single-precision min-max weight: MinMaxWeightTpl instantiated with
// T = float.
typedef MinMaxWeightTpl<float> MinMaxWeight;

// Min: the operand with the smaller value (w2 on ties).
// Returns NoWeight() when either operand is not a member.
template <class T>
inline MinMaxWeightTpl<T> Plus(
    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return MinMaxWeightTpl<T>::NoWeight();
  if (w1.Value() < w2.Value())
    return w1;
  return w2;
}

// Non-template overload for single-precision min-max weights; forwards to
// the templated Plus instantiated with T = float.
inline MinMaxWeightTpl<float> Plus(
    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
  MinMaxWeightTpl<float> smaller = Plus<float>(w1, w2);
  return smaller;
}

// Non-template overload for double-precision min-max weights; forwards to
// the templated Plus instantiated with T = double.
inline MinMaxWeightTpl<double> Plus(
    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
  MinMaxWeightTpl<double> smaller = Plus<double>(w1, w2);
  return smaller;
}

// Max: the operand with the larger value (w1 on ties).
// Returns NoWeight() when either operand is not a member.
template <class T>
inline MinMaxWeightTpl<T> Times(
    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return MinMaxWeightTpl<T>::NoWeight();
  if (w1.Value() >= w2.Value())
    return w1;
  return w2;
}

// Non-template overload for single-precision min-max weights; forwards to
// the templated Times instantiated with T = float.
inline MinMaxWeightTpl<float> Times(
    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
  MinMaxWeightTpl<float> larger = Times<float>(w1, w2);
  return larger;
}

// Non-template overload for double-precision min-max weights; forwards to
// the templated Times instantiated with T = double.
inline MinMaxWeightTpl<double> Times(
    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
  MinMaxWeightTpl<double> larger = Times<double>(w1, w2);
  return larger;
}

// Defined only for special cases.
// Returns w1 when w1 >= w2 and NaN otherwise; NoWeight() is returned when
// either operand is not a member. The typ argument is accepted but not
// consulted.
template <class T>
inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
                                 const MinMaxWeightTpl<T> &w2,
                                 DivideType typ = DIVIDE_ANY) {
  if (!(w1.Member() && w2.Member()))
    return MinMaxWeightTpl<T>::NoWeight();
  // Times is max, so max(x, w2) = w1 has a solution only if w1 >= w2, and
  // then x = w1 is one.
  if (w1.Value() >= w2.Value())
    return w1;
  return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad());
}

// Non-template overload for single-precision min-max weights; forwards to
// the templated Divide instantiated with T = float.
inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
                                     const MinMaxWeightTpl<float> &w2,
                                     DivideType typ = DIVIDE_ANY) {
  MinMaxWeightTpl<float> quotient = Divide<float>(w1, w2, typ);
  return quotient;
}

// Non-template overload for double-precision min-max weights; forwards to
// the templated Divide instantiated with T = double.
inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
                                      const MinMaxWeightTpl<double> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  MinMaxWeightTpl<double> quotient = Divide<double>(w1, w2, typ);
  return quotient;
}

//
// WEIGHT CONVERTER SPECIALIZATIONS.
//

// Convert to tropical
// LogWeight -> TropicalWeight: carries the stored float value over unchanged.
template <>
struct WeightConvert<LogWeight, TropicalWeight> {
  TropicalWeight operator()(LogWeight w) const {
    return TropicalWeight(w.Value());
  }
};

// Log64Weight -> TropicalWeight: the stored double value is converted to
// float.
template <>
struct WeightConvert<Log64Weight, TropicalWeight> {
  TropicalWeight operator()(Log64Weight w) const {
    return TropicalWeight(w.Value());
  }
};

// Convert to log
// TropicalWeight -> LogWeight: carries the stored float value over unchanged.
template <>
struct WeightConvert<TropicalWeight, LogWeight> {
  LogWeight operator()(TropicalWeight w) const {
    return LogWeight(w.Value());
  }
};

// Log64Weight -> LogWeight: the stored double value is converted to float.
template <>
struct WeightConvert<Log64Weight, LogWeight> {
  LogWeight operator()(Log64Weight w) const {
    return LogWeight(w.Value());
  }
};

// Convert to log64
// TropicalWeight -> Log64Weight: the stored float value is widened to double.
template <>
struct WeightConvert<TropicalWeight, Log64Weight> {
  Log64Weight operator()(TropicalWeight w) const {
    return Log64Weight(w.Value());
  }
};

// LogWeight -> Log64Weight: the stored float value is widened to double.
template <>
struct WeightConvert<LogWeight, Log64Weight> {
  Log64Weight operator()(LogWeight w) const {
    return Log64Weight(w.Value());
  }
};

}  // namespace fst

#endif  // FST_LIB_FLOAT_WEIGHT_H__