// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package socktest provides utilities for socket testing.
package socktest

import (
	"fmt"
	"sync"
)

// A Switch represents a callpath point switch for socket system
// calls.
type Switch struct {
	once sync.Once

	fmu   sync.RWMutex
	fltab map[FilterType]Filter

	smu   sync.RWMutex
	sotab Sockets
	stats stats
}

func (sw *Switch) init() {
	sw.fltab = make(map[FilterType]Filter)
	sw.sotab = make(Sockets)
	sw.stats = make(stats)
}

// Stats returns a list of per-cookie socket statistics.
func (sw *Switch) Stats() []Stat {
	var st []Stat
	sw.smu.RLock()
	for _, s := range sw.stats {
		ns := *s
		st = append(st, ns)
	}
	sw.smu.RUnlock()
	return st
}

// Sockets returns mappings of socket descriptor to socket status.
func (sw *Switch) Sockets() Sockets {
	sw.smu.RLock()
	tab := make(Sockets, len(sw.sotab))
	for i, s := range sw.sotab {
		tab[i] = s
	}
	sw.smu.RUnlock()
	return tab
}

// A Cookie represents a 3-tuple of a socket; address family, socket
// type and protocol number.
type Cookie uint64

// Family returns an address family.
func (c Cookie) Family() int { return int(c >> 48) }

// Type returns a socket type.
func (c Cookie) Type() int { return int(c << 16 >> 32) }

// Protocol returns a protocol number.
func (c Cookie) Protocol() int { return int(c & 0xff) }

func cookie(family, sotype, proto int) Cookie {
	return Cookie(family)<<48 | Cookie(sotype)&0xffffffff<<16 | Cookie(proto)&0xff
}

// A Status represents the status of a socket.
type Status struct {
	Cookie    Cookie
	Err       error // error status of socket system call
	SocketErr error // error status of socket by SO_ERROR
}

func (so Status) String() string {
	return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
}

// A Stat represents a per-cookie socket statistics.
type Stat struct {
	Family   int // address family
	Type     int // socket type
	Protocol int // protocol number

	Opened    uint64 // number of sockets opened
	Connected uint64 // number of sockets connected
	Listened  uint64 // number of sockets listened
	Accepted  uint64 // number of sockets accepted
	Closed    uint64 // number of sockets closed

	OpenFailed    uint64 // number of sockets open failed
	ConnectFailed uint64 // number of sockets connect failed
	ListenFailed  uint64 // number of sockets listen failed
	AcceptFailed  uint64 // number of sockets accept failed
	CloseFailed   uint64 // number of sockets close failed
}

func (st Stat) String() string {
	return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
}

type stats map[Cookie]*Stat

func (st stats) getLocked(c Cookie) *Stat {
	s, ok := st[c]
	if !ok {
		s = &Stat{Family: c.Family(), Type: c.Type(), Protocol: c.Protocol()}
		st[c] = s
	}
	return s
}

// A FilterType represents a filter type.
type FilterType int

const (
	FilterSocket        FilterType = iota // for Socket
	FilterConnect                         // for Connect or ConnectEx
	FilterListen                          // for Listen
	FilterAccept                          // for Accept, Accept4 or AcceptEx
	FilterGetsockoptInt                   // for GetsockoptInt
	FilterClose                           // for Close or Closesocket
)

// A Filter represents a socket system call filter.
//
// It will only be executed before a system call for a socket that has
// an entry in internal table.
// If the filter returns a non-nil error, the execution of system call
// will be canceled and the system call function returns the non-nil
// error.
// It can return a non-nil AfterFilter for filtering after the
// execution of the system call.
type Filter func(*Status) (AfterFilter, error)

func (f Filter) apply(st *Status) (AfterFilter, error) {
	if f == nil {
		return nil, nil
	}
	return f(st)
}

// An AfterFilter represents a socket system call filter after an
// execution of a system call.
//
// It will only be executed after a system call for a socket that has
// an entry in internal table.
// If the filter returns a non-nil error, the system call function
// returns the non-nil error.
type AfterFilter func(*Status) error

func (f AfterFilter) apply(st *Status) error {
	if f == nil {
		return nil
	}
	return f(st)
}

// Set deploys the socket system call filter f for the filter type t.
func (sw *Switch) Set(t FilterType, f Filter) {
	sw.once.Do(sw.init)
	sw.fmu.Lock()
	sw.fltab[t] = f
	sw.fmu.Unlock()
}