// Copyright 2015 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // fixedpoint.h: fixed-point arithmetic, with basic operations and // a few math functions such as tanh. // This is only used in output.h // for some specific output pipeline stages (tanh); most of gemmlowp // uses only plain integer arithmetic, not fixed-point arithmetic. // At the most basic level, we distinguish between plain integer // arithmetic and fixed-point arithmetic by the type of multiplication // that is used: plain integer arithmetic uses plain (overflowing) // integer multiplication, whereas fixed-point arithmetic uses // "multiply-high" instructions, which means using only the most // significant bits of the product, or equivalently, multiplying // fixed-point numbers in the [-1 .. +1] interval. #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ #include "common.h" #include <limits> #include <cassert> namespace gemmlowp { template <typename tIntegerType> tIntegerType BitAnd(tIntegerType a, tIntegerType b) { return a & b; } template <typename tIntegerType> tIntegerType BitOr(tIntegerType a, tIntegerType b) { return a | b; } template <typename tIntegerType> tIntegerType BitXor(tIntegerType a, tIntegerType b) { return a ^ b; } template <typename tIntegerType> tIntegerType BitNot(tIntegerType a) { return ~a; } template <typename tIntegerType> tIntegerType Add(tIntegerType a, tIntegerType b) { return a + b; } template <typename tIntegerType> tIntegerType Sub(tIntegerType a, tIntegerType b) { return a - b; } template <typename tIntegerType> tIntegerType Neg(tIntegerType a) { return -a; } template <typename tIntegerType> tIntegerType ShiftLeft(tIntegerType a, int offset) { return a * (1 << offset); } template <typename tIntegerType> tIntegerType ShiftRight(tIntegerType a, int offset) { return a / (1 << offset); } template <typename tIntegerType> tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, tIntegerType else_val) { return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val)); } template <typename tIntegerType> tIntegerType MaskIfNonZero(tIntegerType a) { static const tIntegerType zero = 0; return a ? BitNot(zero) : zero; } template <typename tIntegerType> tIntegerType MaskIfZero(tIntegerType a) { return MaskIfNonZero<tIntegerType>(!a); } template <typename tIntegerType> tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) { return MaskIfNonZero<tIntegerType>(a == b); } template <typename tIntegerType> tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { return MaskIfNonZero<tIntegerType>(a != b); } template <typename tIntegerType> tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { return MaskIfNonZero<tIntegerType>(a > b); } template <typename tIntegerType> tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { return MaskIfNonZero<tIntegerType>(a >= b); } template <typename tIntegerType> tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { return MaskIfNonZero<tIntegerType>(a < b); } template <typename tIntegerType> tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { return MaskIfNonZero<tIntegerType>(a <= b); } template <typename tIntegerType> bool All(tIntegerType a) { return a; } template <typename tIntegerType> bool Any(tIntegerType a) { return a; } template <typename IntegerType> IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); return a; } template <> inline int32_t RoundingHalfSum(int32_t a, int32_t b) { int64_t a64 = a; int64_t b64 = b; int64_t sum = a64 + b64; int64_t sign = sum >= 0 ? 1 : -1; return static_cast<int32_t>((sum + sign) / 2); } template <typename IntegerType> IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); return a; } // This function implements the same computation as the ARMv7 NEON VQRDMULH // instruction. template <> inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) { bool overflow = a == b && a == std::numeric_limits<int32_t>::min(); int64_t a_64(a); int64_t b_64(b); int64_t ab_64 = a_64 * b_64; int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31)); return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32; } template <int Exponent, typename IntegerType, int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> struct ImplSaturatingRoundingMultiplyByPOT {}; template <int Exponent, typename IntegerType> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { static IntegerType eval(IntegerType x) { return x; } }; template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, 1> { static int32_t eval(int32_t x) { const int64_t min = std::numeric_limits<int32_t>::min(); const int64_t max = std::numeric_limits<int32_t>::max(); return x >= (1 << (31 - Exponent)) ? max : x <= -(1 << (31 - Exponent)) ? min : x * (1 << Exponent); } }; template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, -1> { static int32_t eval(int32_t x) { int32_t b = (std::abs(x) & (1 << (-Exponent - 1))) >> (-Exponent - 1); int32_t nudge = x >= 0 ? b : -b; return x / (1 << -Exponent) + nudge; } }; template <int Exponent, typename IntegerType> IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); } template <typename tIntegerType> struct FixedPointRawTypeTraits {}; template <> struct FixedPointRawTypeTraits<int32_t> { typedef int32_t ScalarRawType; static const int kLanes = 1; }; template <typename tRawType> tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) { return x; } template <typename tRawType, int tIntegerBits> class FixedPoint { public: typedef tRawType RawType; typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; typedef typename RawTypeTraits::ScalarRawType ScalarRawType; static const int kTotalBits = 8 * sizeof(ScalarRawType); static const int kIntegerBits = tIntegerBits; static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; static const ScalarRawType ScalarRawMin() { return std::numeric_limits<ScalarRawType>::min(); } static const ScalarRawType ScalarRawMax() { return std::numeric_limits<ScalarRawType>::max(); } static const ScalarRawType RawMin() { return VectorFromScalar(ScalarRawMin()); } static const ScalarRawType RawMax() { return VectorFromScalar(ScalarRawMax()); } static FixedPoint FromRaw(RawType x) { FixedPoint retval; retval.raw() = x; return retval; } static FixedPoint FromScalarRaw(ScalarRawType x) { FixedPoint retval; retval.raw() = Dup<RawType>(x); return retval; } static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); } template <int Exponent> static FixedPoint ConstantPOT() { static const int kOffset = kFractionalBits + Exponent; static_assert( kOffset < 31, "Constant not exactly representable in this fixed-point format"); return FromScalarRaw(ScalarRawType(1) << kOffset); } static FixedPoint Zero() { return FromScalarRaw(0); } static FixedPoint One() { return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() : (ScalarRawType(1) << kFractionalBits)); } RawType raw() const { return i_; } RawType& raw() { return i_; } private: RawType i_; }; template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( FixedPoint<tRawType, tIntegerBits_a> a, FixedPoint<tRawType, tIntegerBits_b> b) { FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); return c; } template <int tExponent, typename tRawType, int tIntegerBits> FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( FixedPoint<tRawType, tIntegerBits> a) { FixedPoint<tRawType, tExponent + tIntegerBits> c; c.raw() = a.raw(); return c; } template <int tExponent, typename tRawType, int tIntegerBits> FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( FixedPoint<tRawType, tIntegerBits> a) { return FixedPoint<tRawType, tIntegerBits>::FromRaw( SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); } #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ template <typename tRawType, int tIntegerBits> \ FixedPoint<tRawType, tIntegerBits> FuncName( \ FixedPoint<tRawType, tIntegerBits> a) { \ return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ } #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ template <typename tRawType, int tIntegerBits> \ FixedPoint<tRawType, tIntegerBits> FuncName( \ FixedPoint<tRawType, tIntegerBits> a, \ FixedPoint<tRawType, tIntegerBits> b) { \ return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ ImplFuncName(a.raw(), b.raw())); \ } MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) #undef MAKE_FIXEDPOINT_UNARY_FUNC #undef MAKE_FIXEDPOINT_BINARY_FUNC #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ template <typename tRawType, int tIntegerBits> \ tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ return FuncName(a.raw()); \ } #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ template <typename tRawType, int tIntegerBits> \ tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ FixedPoint<tRawType, tIntegerBits> b) { \ return FuncName(a.raw(), b.raw()); \ } MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW template <typename tRawType, int tIntegerBits> FixedPoint<tRawType, tIntegerBits> SelectUsingMask( tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, FixedPoint<tRawType, tIntegerBits> else_val) { return FixedPoint<tRawType, tIntegerBits>::FromRaw( SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); } template <typename tRawType, int tIntegerBits> bool operator==(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { return All(MaskIfEqual(a.raw(), b.raw())); } template <typename tRawType, int tIntegerBits> bool operator!=(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { return !(a == b); } template <typename tRawType, int tIntegerBits> double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, "not applicable to SIMD types"); typedef FixedPoint<tRawType, tIntegerBits> F; return x.raw() / double(1ll << F::kFractionalBits); } template <typename tRawType, int tIntegerBits> FixedPoint<tRawType, tIntegerBits> ToFixedPoint(double x) { typedef FixedPoint<tRawType, tIntegerBits> F; return F::FromScalarRaw(static_cast<int32_t>( std::min(std::max(round(x * double(1ll << F::kFractionalBits)), double(F::ScalarRawMin())), double(F::ScalarRawMax())))); } template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> FixedPoint<tRawType, tIntegerBitsDst> Rescale( FixedPoint<tRawType, tIntegerBitsSrc> x) { static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; FixedPoint<tRawType, tIntegerBitsDst> result; result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); return result; } #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS template <typename FixedPointType> FixedPointType CheckedFixedPointConstant( typename FixedPointType::ScalarRawType raw_value, double double_value) { typedef typename FixedPointType::RawType RawType; static const int kIntegerBits = FixedPointType::kIntegerBits; FixedPointType ref = FixedPointType::FromScalarRaw(raw_value); FixedPointType check = ToFixedPoint<RawType, kIntegerBits>(double_value); assert(ref == check); return ref; } #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ DoubleValue) \ (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue)) #else #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ DoubleValue) \ (FixedPointType::FromScalarRaw(ScalarRawValue)) #endif template <typename tRawType> FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( FixedPoint<tRawType, 0> a) { typedef FixedPoint<tRawType, 0> F; const F constant_term = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); const F constant_1_over_3 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); // We're evaluating a Taylor expansion around -1/8, so we do the change of // variable: x = a + 1/8. // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. F x = a + F::template ConstantPOT<-3>(); F x2 = x * x; F x3 = x2 * x; F x4 = x2 * x2; F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); F x4_over_24_plus_x3_over_6_plus_x2_over_2 = SaturatingRoundingMultiplyByPOT<-1>( ((x4_over_4 + x3) * constant_1_over_3) + x2); return constant_term + constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2); } template <typename tRawType, int tIntegerBits> FixedPoint<tRawType, 0> exp_on_negative_values( FixedPoint<tRawType, tIntegerBits> a) { typedef FixedPoint<tRawType, tIntegerBits> InputF; typedef FixedPoint<tRawType, 0> ResultF; static const int kFractionalBits = InputF::kFractionalBits; static const int kIntegerBits = InputF::kIntegerBits; static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); InputF mask = kOneQuarter - InputF::FromScalarRaw(1); InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( Rescale<0>(a_mod_quarter_minus_one_quarter)); tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ if (kIntegerBits > Exponent) { \ const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ result = SelectUsingMask( \ MaskIfNonZero(BitAnd( \ remainder, Dup<tRawType>(1 << (kFractionalBits + Exponent)))), \ result * kMultiplier, result); \ } GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); #undef GEMMLOWP_EXP_BARREL_SHIFTER if (kIntegerBits > 5) { static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0; const InputF clamp = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); } result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); return result; } template <typename tRawType> FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( FixedPoint<tRawType, 0> a) { typedef FixedPoint<tRawType, 0> F0; typedef FixedPoint<tRawType, 2> F2; F0 half_denominator = RoundingHalfSum(a, F0::One()); const F2 constant_48_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; for (int i = 0; i < 3; i++) { F2 half_denominator_times_x = half_denominator * x; F2 one_minus_half_denominator_times_x = F2::One() - half_denominator_times_x; x = x + Rescale<2>(x * one_minus_half_denominator_times_x); } return Rescale<0>(x - F2::One()); } template <typename tRawType, int tIntegerBits> FixedPoint<tRawType, 0> neg_tanh_on_negative_values( FixedPoint<tRawType, tIntegerBits> a) { return one_minus_x_over_one_plus_x_for_x_in_0_1( exp_on_negative_values(ExactMulByPot<1>(a))); } template <typename tRawType, int tIntegerBits> FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { typedef FixedPoint<tRawType, tIntegerBits> InputF; typedef FixedPoint<tRawType, 0> ResultF; tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); tRawType mask_if_zero = MaskIfZero(a); InputF n = SelectUsingMask(mask_if_negative, a, -a); ResultF t = neg_tanh_on_negative_values(n); return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t)); } } // end namespace gemmlowp #ifdef GEMMLOWP_NEON #include "fixedpoint_neon.h" #endif #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_