// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

// Conservative resource-related analysis of programs.
// The analysis figures out what file descriptors are [potentially] opened
// at a particular point in program, what pages are [potentially] mapped,
// what files were already referenced in calls, etc.

package prog

import (
	"fmt"
)

// state accumulates what the analysis has learned about a program.
// target and ct are stored as passed to newState; the remaining fields
// are filled in by analyzeImpl:
//   - files is the set of file names seen in filename buffers,
//   - resources holds the resource args seen so far, keyed by resource name,
//   - strings is the set of values seen in string buffers,
//   - ma is notified of memory ranges referenced by non-VMA pointers,
//   - va is notified of page ranges referenced by VMA pointers.
type state struct {
	target    *Target
	ct        *ChoiceTable
	files     map[string]bool
	resources map[string][]*ResultArg
	strings   map[string]bool
	ma        *memAlloc
	va        *vmaAlloc
}

// analyze builds the analysis state for program p relative to call c.
// Every call of p is scanned for memory, file and string usage, but resources
// are collected only from the calls that precede c: starting with c itself,
// calls are analyzed with resource collection turned off.
// If c does not occur in p, resources are collected from all calls.
func analyze(ct *ChoiceTable, p *Prog, c *Call) *state {
	s := newState(p.Target, ct)
	reached := false
	for _, call := range p.Calls {
		if !reached && call == c {
			reached = true
		}
		s.analyzeImpl(call, !reached)
	}
	return s
}

// newState returns an empty analysis state for the given target and choice table.
// The memory allocator is sized in bytes (NumPages * PageSize),
// the VMA allocator is sized in pages (NumPages).
func newState(target *Target, ct *ChoiceTable) *state {
	s := new(state)
	s.target, s.ct = target, ct
	s.files = make(map[string]bool)
	s.resources = make(map[string][]*ResultArg)
	s.strings = make(map[string]bool)
	s.ma = newMemAlloc(target.NumPages * target.PageSize)
	s.va = newVmaAlloc(target.NumPages)
	return s
}

// analyze folds call c into the state, including the resources it produces.
func (s *state) analyze(c *Call) {
	const withResources = true
	s.analyzeImpl(c, withResources)
}

// analyzeImpl walks all args of call c and records what they reference in s:
// memory and VMA ranges behind non-null pointers, values of non-output string
// buffers and file names of non-output filename buffers.
// Non-input resource args are recorded as well, but only if resources is set.
func (s *state) analyzeImpl(c *Call, resources bool) {
	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
		// Note the memory behind every non-null pointer as allocated.
		if ptr, isPtr := arg.(*PointerArg); isPtr && !ptr.IsNull() {
			if ptr.VmaSize != 0 {
				// VMA allocations are tracked in pages rather than bytes.
				pageSize := s.target.PageSize
				s.va.noteAlloc(ptr.Address/pageSize, ptr.VmaSize/pageSize)
			} else {
				s.ma.noteAlloc(ptr.Address, ptr.Res.Size())
			}
		}
		switch typ := arg.Type().(type) {
		case *ResourceType:
			res := arg.(*ResultArg)
			if !resources || typ.Dir() == DirIn {
				return
			}
			name := typ.Desc.Name
			s.resources[name] = append(s.resources[name], res)
			// TODO: negative PIDs and add them as well (that's process groups).
		case *BufferType:
			buf := arg.(*DataArg)
			if typ.Dir() == DirOut || len(buf.Data()) == 0 {
				return
			}
			val := string(buf.Data())
			// Collapse trailing zero padding down to a single zero byte.
			for n := len(val); n >= 2 && val[n-1] == 0 && val[n-2] == 0; n-- {
				val = val[:n-1]
			}
			if typ.Kind == BufferString {
				s.strings[val] = true
				return
			}
			if typ.Kind != BufferFilename || len(val) < 3 {
				// Either not a file name at all, or a name too short to be
				// our file, probably one of specialFiles.
				return
			}
			// File names are stored without the terminating zero byte.
			if last := len(val) - 1; val[last] == 0 {
				val = val[:last]
			}
			s.files[val] = true
		}
	})
}

// ArgCtx describes the position of an arg during traversal by
// ForeachArg/ForeachSubArg. A pointer to it is passed to the callback along
// with each visited arg; the callback can set Stop to skip subargs of that arg.
// Parent is left unset for the return value of a call and for the root arg
// passed to ForeachSubArg.
type ArgCtx struct {
	Parent *[]Arg      // GroupArg.Inner (for structs) or Call.Args containing this arg
	Base   *PointerArg // pointer to the base of the heap object containing this arg
	Offset uint64      // offset of this arg from the base
	Stop   bool        // if set by the callback, subargs of this arg are not visited
}

// ForeachSubArg calls f for arg itself and then, recursively, for all args
// nested in it. The traversal starts from an empty context.
func ForeachSubArg(arg Arg, f func(Arg, *ArgCtx)) {
	var root ArgCtx
	foreachArgImpl(arg, root, f)
}

// ForeachArg calls f for the return value of call c (if it has one)
// and then for every argument of c, recursing into nested args.
// Top-level arguments are visited with ArgCtx.Parent pointing to c.Args,
// the return value is visited with an empty context.
func ForeachArg(c *Call, f func(Arg, *ArgCtx)) {
	if ret := c.Ret; ret != nil {
		foreachArgImpl(ret, ArgCtx{}, f)
	}
	top := ArgCtx{Parent: &c.Args}
	for _, a := range c.Args {
		foreachArgImpl(a, top, f)
	}
}

// foreachArgImpl calls f for arg and then recursively visits its subargs:
// inner args of a group, the pointee of a pointer and the selected option
// of a union. ctx is taken by value, so changes made to it while visiting
// one arg are visible to subargs of that arg but not to its siblings.
// The callback gets a pointer to ctx and can set ctx.Stop to prevent
// descending into subargs of arg.
//
// While walking a group the function also verifies that the sizes of inner
// args are consistent with the size of the group and panics if they are not.
func foreachArgImpl(arg Arg, ctx ArgCtx, f func(Arg, *ArgCtx)) {
	f(arg, &ctx)
	if ctx.Stop {
		return
	}
	switch a := arg.(type) {
	case *GroupArg:
		if _, ok := a.Type().(*StructType); ok {
			// Only structs become the parent of their inner args.
			ctx.Parent = &a.Inner
		}
		var totalSize uint64
		for _, arg1 := range a.Inner {
			foreachArgImpl(arg1, ctx, f)
			// Args of BitfieldMiddle types do not advance the offset.
			if !arg1.Type().BitfieldMiddle() {
				size := arg1.Size()
				ctx.Offset += size
				totalSize += size
			}
		}
		// Inner args of a varlen group may take less than the claimed size,
		// for any other group the sizes must match exactly.
		claimedSize := a.Size()
		relation, bad := "==", totalSize != claimedSize
		if a.Type().Varlen() {
			relation, bad = "<=", totalSize > claimedSize
		}
		if bad {
			// Report the relation that was actually violated.
			panic(fmt.Sprintf("bad group arg size %v, should be %v %v for %#v type %#v",
				totalSize, relation, claimedSize, a, a.Type()))
		}
	case *PointerArg:
		if a.Res != nil {
			// The pointee starts a new object: offsets are counted from its base.
			ctx.Base = a
			ctx.Offset = 0
			foreachArgImpl(a.Res, ctx, f)
		}
	case *UnionArg:
		foreachArgImpl(a.Option, ctx, f)
	}
}

// RequiredFeatures scans all args of all calls of p and reports:
//   - bitmasks: whether some const arg has a type with a non-zero bitfield
//     offset or length,
//   - csums: whether some arg has a checksum type.
func RequiredFeatures(p *Prog) (bitmasks, csums bool) {
	inspect := func(arg Arg, _ *ArgCtx) {
		if c, isConst := arg.(*ConstArg); isConst {
			t := c.Type()
			if t.BitfieldOffset() != 0 || t.BitfieldLength() != 0 {
				bitmasks = true
			}
		}
		switch arg.Type().(type) {
		case *CsumType:
			csums = true
		}
	}
	for _, call := range p.Calls {
		ForeachArg(call, inspect)
	}
	return bitmasks, csums
}

// CallFlags is a set of bit flags describing how far the execution of a call got.
type CallFlags int

// Individual bits of CallFlags.
const (
	CallExecuted CallFlags = 1 << iota // was started at all
	CallFinished                       // finished executing (rather than blocked forever)
	CallBlocked                        // finished but blocked during execution
)

// CallInfo holds the execution result of a single call.
// Flags is the set of CallFlags bits for the call, Errno is the error code
// of the call (FallbackSignal treats any non-zero value as failure)
// and Signal is the list of signal values; FallbackSignal appends to it.
type CallInfo struct {
	Flags  CallFlags
	Errno  int
	Signal []uint32
}

// Types of fallback signal produced by FallbackSignal. The type occupies the
// low 3 bits of an encoded signal (see encodeFallbackSignal) and determines
// what the auxiliary value of the signal holds:
//   - fallbackSignalErrno: errno of the call,
//   - fallbackSignalErrnoBlocked: errno of a call that finished but blocked,
//   - fallbackSignalCtor: ID of the call that produced a resource used by the call,
//   - fallbackSignalFlags: summary of top-level argument values.
//
// fallbackCallMask is the mask of call ID bits in an encoded signal
// (13 bits, stored right above the type).
const (
	fallbackSignalErrno = iota
	fallbackSignalErrnoBlocked
	fallbackSignalCtor
	fallbackSignalFlags
	fallbackCallMask = 0x1fff
)

// FallbackSignal synthesizes signal for the calls of p from their execution
// results and arguments and appends it to info[i].Signal of the respective call.
// info is indexed by call position, so it must have at least len(p.Calls) entries.
//
// Calls that were not executed get no signal. Every executed call gets
// an errno signal; successful calls (Errno == 0) additionally get a signal
// for each constructor of a resource passed as a top-level argument
// and a signal summarizing values of top-level arguments.
func (p *Prog) FallbackSignal(info []CallInfo) {
	// Maps resources of successful calls seen so far to the call containing them.
	resources := make(map[*ResultArg]*Call)
	for i, c := range p.Calls {
		inf := &info[i]
		if inf.Flags&CallExecuted == 0 {
			continue
		}
		id := c.Meta.ID
		typ := fallbackSignalErrno
		if inf.Flags&CallFinished != 0 && inf.Flags&CallBlocked != 0 {
			typ = fallbackSignalErrnoBlocked
		}
		inf.Signal = append(inf.Signal, encodeFallbackSignal(typ, id, inf.Errno))
		if inf.Errno != 0 {
			continue
		}
		ForeachArg(c, func(arg Arg, _ *ArgCtx) {
			if a, ok := arg.(*ResultArg); ok {
				resources[a] = c
			}
		})
		// Specifically look only at top-level arguments,
		// deeper arguments can produce too much false signal.
		flags := 0
		for _, arg := range c.Args {
			flags = fallbackArgSignal(arg, id, flags, inf, resources)
		}
		if flags != 0 {
			inf.Signal = append(inf.Signal,
				encodeFallbackSignal(fallbackSignalFlags, id, flags))
		}
	}
}

// fallbackArgSignal folds a single top-level argument of call callID into
// the flags accumulator and returns the updated accumulator. If the argument
// refers to a resource with a known constructor call in resources,
// a constructor signal is appended to inf.Signal.
func fallbackArgSignal(arg Arg, callID, flags int, inf *CallInfo, resources map[*ResultArg]*Call) int {
	switch a := arg.(type) {
	case *ResultArg:
		flags <<= 1
		if a.Res != nil {
			if ctor := resources[a.Res]; ctor != nil {
				inf.Signal = append(inf.Signal,
					encodeFallbackSignal(fallbackSignalCtor, callID, ctor.Meta.ID))
			}
		} else if a.Val != a.Type().(*ResourceType).SpecialValues()[0] {
			// A literal resource value different from the first special value.
			flags |= 1
		}
	case *ConstArg:
		const width = 3
		flags <<= width
		switch typ := a.Type().(type) {
		case *FlagsType:
			if typ.BitMask {
				// Fold all set flag bits into width bits.
				for bit, v := range typ.Vals {
					if a.Val&v != 0 {
						flags ^= 1 << (uint(bit) % width)
					}
				}
			} else {
				// Record the index of the first matching enum value.
				for idx, v := range typ.Vals {
					if a.Val == v {
						flags |= idx % (1 << width)
						break
					}
				}
			}
		case *LenType:
			// Only distinguish zero from non-zero lengths.
			flags <<= 1
			if a.Val == 0 {
				flags |= 1
			}
		}
	case *PointerArg:
		flags <<= 1
		if a.IsNull() {
			flags |= 1
		}
	}
	return flags
}

// DecodeFallbackSignal extracts the call ID and errno from fallback signal s.
// The errno is taken from the signal for errno-type signals and is reported
// as 0 for constructor and flags signals.
// It panics if s has an unknown signal type.
func DecodeFallbackSignal(s uint32) (callID, errno int) {
	kind, call, payload := decodeFallbackSignal(s)
	if kind == fallbackSignalErrno || kind == fallbackSignalErrnoBlocked {
		return call, payload
	}
	if kind == fallbackSignalCtor || kind == fallbackSignalFlags {
		return call, 0
	}
	panic(fmt.Sprintf("bad fallback signal type %v", kind))
}

// encodeFallbackSignal packs a fallback signal into 32 bits:
// bits 0-2 hold typ, bits 3-15 hold the call id and bits 16-31 hold
// the low 16 bits of aux. It panics if typ does not fit into 3 bits
// or id does not fit into fallbackCallMask.
func encodeFallbackSignal(typ, id, aux int) uint32 {
	const typMask = 7
	switch {
	case typ&^typMask != 0:
		panic(fmt.Sprintf("bad fallback signal type %v", typ))
	case id&^fallbackCallMask != 0:
		panic(fmt.Sprintf("bad call id in fallback signal %v", id))
	}
	// id is known to fit into the mask at this point.
	return uint32(typ) | uint32(id)<<3 | uint32(aux)<<16
}

// decodeFallbackSignal splits an encoded fallback signal into its parts:
// the type (bits 0-2), the call id (the fallbackCallMask bits starting
// at bit 3) and the auxiliary value (bits 16-31).
func decodeFallbackSignal(s uint32) (typ, id, aux int) {
	typ = int(s & 7)
	id = int((s >> 3) & fallbackCallMask)
	aux = int(s >> 16)
	return typ, id, aux
}