// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gc

import (
	"fmt"
	"math/big"
)

// implements integer arithmetic

// Mpint represents an integer constant.
type Mpint struct {
	Val  big.Int
	Ovf  bool // set if Val overflowed compiler limit (sticky)
	Rune bool // set if syntax indicates default type rune
}

func (a *Mpint) SetOverflow() {
	a.Val.SetUint64(1) // avoid spurious div-zero errors
	a.Ovf = true
}

func (a *Mpint) checkOverflow(extra int) bool {
	// We don't need to be precise here, any reasonable upper limit would do.
	// For now, use existing limit so we pass all the tests unchanged.
	if a.Val.BitLen()+extra > Mpprec {
		a.SetOverflow()
	}
	return a.Ovf
}

func (a *Mpint) Set(b *Mpint) {
	a.Val.Set(&b.Val)
}

func (a *Mpint) SetFloat(b *Mpflt) bool {
	// avoid converting huge floating-point numbers to integers
	// (2*Mpprec is large enough to permit all tests to pass)
	if b.Val.MantExp(nil) > 2*Mpprec {
		a.SetOverflow()
		return false
	}

	if _, acc := b.Val.Int(&a.Val); acc == big.Exact {
		return true
	}

	const delta = 16 // a reasonably small number of bits > 0
	var t big.Float
	t.SetPrec(Mpprec - delta)

	// try rounding down a little
	t.SetMode(big.ToZero)
	t.Set(&b.Val)
	if _, acc := t.Int(&a.Val); acc == big.Exact {
		return true
	}

	// try rounding up a little
	t.SetMode(big.AwayFromZero)
	t.Set(&b.Val)
	if _, acc := t.Int(&a.Val); acc == big.Exact {
		return true
	}

	a.Ovf = false
	return false
}

func (a *Mpint) Add(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Add")
		}
		a.SetOverflow()
		return
	}

	a.Val.Add(&a.Val, &b.Val)

	if a.checkOverflow(0) {
		yyerror("constant addition overflow")
	}
}

func (a *Mpint) Sub(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Sub")
		}
		a.SetOverflow()
		return
	}

	a.Val.Sub(&a.Val, &b.Val)

	if a.checkOverflow(0) {
		yyerror("constant subtraction overflow")
	}
}

func (a *Mpint) Mul(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Mul")
		}
		a.SetOverflow()
		return
	}

	a.Val.Mul(&a.Val, &b.Val)

	if a.checkOverflow(0) {
		yyerror("constant multiplication overflow")
	}
}

func (a *Mpint) Quo(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Quo")
		}
		a.SetOverflow()
		return
	}

	a.Val.Quo(&a.Val, &b.Val)

	if a.checkOverflow(0) {
		// can only happen for div-0 which should be checked elsewhere
		yyerror("constant division overflow")
	}
}

func (a *Mpint) Rem(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Rem")
		}
		a.SetOverflow()
		return
	}

	a.Val.Rem(&a.Val, &b.Val)

	if a.checkOverflow(0) {
		// should never happen
		yyerror("constant modulo overflow")
	}
}

func (a *Mpint) Or(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Or")
		}
		a.SetOverflow()
		return
	}

	a.Val.Or(&a.Val, &b.Val)
}

func (a *Mpint) And(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint And")
		}
		a.SetOverflow()
		return
	}

	a.Val.And(&a.Val, &b.Val)
}

func (a *Mpint) AndNot(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint AndNot")
		}
		a.SetOverflow()
		return
	}

	a.Val.AndNot(&a.Val, &b.Val)
}

func (a *Mpint) Xor(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Xor")
		}
		a.SetOverflow()
		return
	}

	a.Val.Xor(&a.Val, &b.Val)
}

func (a *Mpint) Lsh(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Lsh")
		}
		a.SetOverflow()
		return
	}

	s := b.Int64()
	if s < 0 || s >= Mpprec {
		msg := "shift count too large"
		if s < 0 {
			msg = "invalid negative shift count"
		}
		yyerror("%s: %d", msg, s)
		a.SetInt64(0)
		return
	}

	if a.checkOverflow(int(s)) {
		yyerror("constant shift overflow")
		return
	}
	a.Val.Lsh(&a.Val, uint(s))
}

func (a *Mpint) Rsh(b *Mpint) {
	if a.Ovf || b.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("ovf in Mpint Rsh")
		}
		a.SetOverflow()
		return
	}

	s := b.Int64()
	if s < 0 {
		yyerror("invalid negative shift count: %d", s)
		if a.Val.Sign() < 0 {
			a.SetInt64(-1)
		} else {
			a.SetInt64(0)
		}
		return
	}

	a.Val.Rsh(&a.Val, uint(s))
}

func (a *Mpint) Cmp(b *Mpint) int {
	return a.Val.Cmp(&b.Val)
}

func (a *Mpint) CmpInt64(c int64) int {
	if c == 0 {
		return a.Val.Sign() // common case shortcut
	}
	return a.Val.Cmp(big.NewInt(c))
}

func (a *Mpint) Neg() {
	a.Val.Neg(&a.Val)
}

func (a *Mpint) Int64() int64 {
	if a.Ovf {
		if nsavederrors+nerrors == 0 {
			Fatalf("constant overflow")
		}
		return 0
	}

	return a.Val.Int64()
}

func (a *Mpint) SetInt64(c int64) {
	a.Val.SetInt64(c)
}

func (a *Mpint) SetString(as string) {
	_, ok := a.Val.SetString(as, 0)
	if !ok {
		// required syntax is [+-][0[x]]d*
		// At the moment we lose precise error cause;
		// the old code distinguished between:
		// - malformed hex constant
		// - malformed octal constant
		// - malformed decimal constant
		// TODO(gri) use different conversion function
		yyerror("malformed integer constant: %s", as)
		a.Val.SetUint64(0)
		return
	}
	if a.checkOverflow(0) {
		yyerror("constant too large: %s", as)
	}
}

func (x *Mpint) String() string {
	return bconv(x, 0)
}

func bconv(xval *Mpint, flag FmtFlag) string {
	if flag&FmtSharp != 0 {
		return fmt.Sprintf("%#x", &xval.Val)
	}
	return xval.Val.String()
}