// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Calibration used to determine thresholds for using
// different algorithms.  Ideally, this would be converted
// to go generate to create thresholds.go

// This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get
// a clear picture.

// Calculates lower and upper thresholds for when basicSqr
// is faster than standard multiplication.

// Usage: go test -run=TestCalibrate -v -calibrate

package big

import (
	"flag"
	"fmt"
	"testing"
	"time"
)

// calibrate is the -calibrate command-line flag. It defaults to false, in
// which case TestCalibrate does nothing; pass -calibrate to run the
// calibration.
var calibrate = flag.Bool("calibrate", false, "run calibration test")

// TestCalibrate runs the calibration when the -calibrate flag is set and
// does nothing otherwise. It first prints the Karatsuba threshold timings
// and then searches two word ranges for squaring thresholds, reporting the
// result of each search on standard output.
func TestCalibrate(t *testing.T) {
	if !*calibrate {
		return
	}

	computeKaratsubaThresholds()

	// Lower bound: search 10..30 words in steps of 1 for the point where
	// the basicSqr overhead becomes negligible.
	minSqr := computeSqrThreshold(10, 30, 1, 3)
	// Upper bound: search 300..500 words in steps of 10 for the point
	// where karatsuba becomes faster.
	maxSqr := computeSqrThreshold(300, 500, 10, 3)

	// A result of 0 means the search found no threshold in its range.
	if minSqr == 0 {
		fmt.Println("no basicSqrThreshold found")
	} else {
		fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
	}
	if maxSqr == 0 {
		fmt.Println("no karatsubaSqrThreshold found")
	} else {
		fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
	}
}

// karatsubaLoad is the benchmark work load timed by measureKaratsuba;
// it simply runs BenchmarkMul.
func karatsubaLoad(b *testing.B) {
	BenchmarkMul(b)
}

// measureKaratsuba benchmarks karatsubaLoad with karatsubaThreshold
// temporarily set to th and returns the measured time per operation.
// The previous threshold is restored before returning.
func measureKaratsuba(th int) time.Duration {
	saved := karatsubaThreshold
	karatsubaThreshold = th
	nsPerOp := testing.Benchmark(karatsubaLoad).NsPerOp()
	karatsubaThreshold = saved
	return time.Duration(nsPerOp)
}

// computeKaratsubaThresholds prints the work load execution time for a
// series of Karatsuba thresholds starting at 4, together with the percentage
// improvement over basic multiplication. It marks the first threshold at
// which Karatsuba is faster than basic multiplication ("break-even point")
// and the first threshold at which a positive improvement is smaller than
// the previous one ("diminishing return").
func computeKaratsubaThresholds() {
	fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
	fmt.Printf("(run repeatedly for good results)\n")

	// Tb is the work load execution time using basic multiplication:
	// a threshold of 1e9 disables Karatsuba multiplication.
	Tb := measureKaratsuba(1e9)
	fmt.Printf("Tb = %10s\n", Tb)

	var (
		breakEven   = -1 // first threshold with Tk < Tb; -1 while unknown
		diminishing = -1 // first threshold with a shrinking gain; -1 while unknown
		prevGain    time.Duration
	)

	// remaining stays negative until both thresholds are known. It is then
	// set to 10 and counted down by the loop; because the decrement happens
	// before the next condition check, nine further measurements follow.
	// The loop also stops once th reaches 128.
	remaining := -1
	for th := 4; remaining != 0 && th < 128; th, remaining = th+1, remaining-1 {
		// Tk is the work load execution time using Karatsuba
		// multiplication with threshold th.
		Tk := measureKaratsuba(th)

		// improvement over Tb, in percent
		gain := (Tb - Tk) * 100 / Tb

		fmt.Printf("th = %3d  Tk = %10s  %4d%%", th, Tk, gain)

		if breakEven < 0 && Tk < Tb {
			breakEven = th
			fmt.Print("  break-even point")
		}

		if diminishing < 0 && 0 < gain && gain < prevGain {
			diminishing = th
			fmt.Print("  diminishing return")
		}
		prevGain = gain

		fmt.Println()

		// start the countdown once both thresholds have been found
		if remaining < 0 && breakEven >= 0 && diminishing >= 0 {
			remaining = 10
		}
	}
}

// measureBasicSqr returns the average time per operation of benchmarkNatSqr
// for operands of the given number of words, taken over nruns benchmark runs.
//
// If enable is set, basicSqrThreshold and karatsubaSqrThreshold are set to
// words-1 and words+1 so that basicSqr is used at this number of words;
// otherwise both are set to -1 to disable basicSqr for any number of words.
// The thresholds in effect on entry are restored before returning.
//
// An nruns value below 1 is treated as 1, so that at least one measurement
// is taken and the average is well defined.
func measureBasicSqr(words, nruns int, enable bool) time.Duration {
	// Without this guard, nruns == 0 would make the averaging division
	// below panic, and a negative nruns would report 0 without measuring.
	if nruns < 1 {
		nruns = 1
	}

	// remember the current thresholds so they can be restored afterwards
	initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold

	if enable {
		// set thresholds to use basicSqr at this number of words
		basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
	} else {
		// set thresholds to disable basicSqr for any number of words
		basicSqrThreshold, karatsubaSqrThreshold = -1, -1
	}

	// more runs for better statistics: average the per-op time over nruns
	var testval int64
	for i := 0; i < nruns; i++ {
		res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
		testval += res.NsPerOp()
	}
	testval /= int64(nruns)

	basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr

	return time.Duration(testval)
}

// computeSqrThreshold benchmarks squaring with basicSqr disabled and enabled
// for operand sizes from, from+step, ... up to and including to, averaging
// nruns runs per measurement, and prints one line per size.
//
// It returns the first size at which the answer to "is basicSqr better"
// differs from the answer at size from, or 0 if the answer never changes
// within the range.
func computeSqrThreshold(from, to, step, nruns int) int {
	fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
	fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)

	found := 0           // 0 means no threshold found so far
	firstBetter := false // whether basicSqr was better at the first size
	for words := from; words <= to; words += step {
		disabled := measureBasicSqr(words, nruns, false)
		enabled := measureBasicSqr(words, nruns, true)

		saved := disabled - enabled
		better := enabled < disabled
		fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", words, saved, saved*100/disabled, better)

		if words == from {
			firstBetter = better
		}
		// only the first change of the answer is recorded
		if found == 0 && better != firstBetter {
			found = words
			fmt.Printf("  threshold  found")
		}
		fmt.Println()
	}

	if found == 0 {
		fmt.Printf("Found NO threshold between %d - %d\n", from, to)
		return 0
	}
	fmt.Printf("Found threshold = %d between %d - %d\n", found, from, to)
	return found
}