// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package runner

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/hmac"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/binary"
	"errors"
	"io"
	"time"
)

// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection. marshal writes it out and
// unmarshal reads it back; fields are serialized in declaration order except
// that earlyALPN is written last.
type sessionState struct {
	// vers is the protocol version. The ticket timing, flag and age-add
	// fields below are only serialized when vers >= VersionTLS13.
	vers uint16
	// cipherSuite is the 16-bit cipher suite identifier.
	cipherSuite uint16
	// masterSecret is the session secret, serialized with a 16-bit length
	// prefix.
	masterSecret []byte
	// handshakeHash is serialized with a 16-bit length prefix (presumably a
	// transcript hash of the original handshake — confirm with callers).
	handshakeHash []byte
	// certificates are serialized as a 16-bit count followed by one 32-bit
	// length-prefixed entry per certificate.
	certificates [][]byte
	// extendedMasterSecret is serialized as a single byte; unmarshal treats
	// only the value 1 as true.
	extendedMasterSecret bool
	// earlyALPN is serialized last, with a 16-bit length prefix (presumably
	// the ALPN protocol associated with early data — confirm).
	earlyALPN []byte
	// ticketCreationTime is serialized as UnixNano (TLS 1.3 only).
	ticketCreationTime time.Time
	// ticketExpiration is serialized as UnixNano (TLS 1.3 only).
	ticketExpiration time.Time
	// ticketFlags is serialized as a 32-bit value (TLS 1.3 only).
	ticketFlags uint32
	// ticketAgeAdd is serialized as a 32-bit value (TLS 1.3 only); presumably
	// the RFC 8446 ticket_age_add obfuscation value.
	ticketAgeAdd uint32
}

// marshal serializes s into the plaintext ticket layout that unmarshal parses:
//
//	u16 vers, u16 cipherSuite,
//	u16-prefixed masterSecret, u16-prefixed handshakeHash,
//	u16 certificate count, then each certificate u32-prefixed,
//	u8 extendedMasterSecret flag (1 or 0),
//	[TLS 1.3 only] u64 creation UnixNano, u64 expiration UnixNano,
//	               u32 ticketFlags, u32 ticketAgeAdd,
//	u16-prefixed earlyALPN.
func (s *sessionState) marshal() []byte {
	b := newByteBuilder()

	b.addU16(s.vers)
	b.addU16(s.cipherSuite)
	b.addU16LengthPrefixed().addBytes(s.masterSecret)
	b.addU16LengthPrefixed().addBytes(s.handshakeHash)

	b.addU16(uint16(len(s.certificates)))
	for i := range s.certificates {
		b.addU32LengthPrefixed().addBytes(s.certificates[i])
	}

	var emsFlag uint8
	if s.extendedMasterSecret {
		emsFlag = 1
	}
	b.addU8(emsFlag)

	// Ticket lifetime and age-obfuscation data only exist for TLS 1.3 sessions.
	if s.vers >= VersionTLS13 {
		created := uint64(s.ticketCreationTime.UnixNano())
		expires := uint64(s.ticketExpiration.UnixNano())
		b.addU64(created)
		b.addU64(expires)
		b.addU32(s.ticketFlags)
		b.addU32(s.ticketAgeAdd)
	}

	b.addU16LengthPrefixed().addBytes(s.earlyALPN)
	return b.finish()
}

// unmarshal parses data, in the layout written by marshal, into s and reports
// whether it succeeded. It fails if data is truncated, if any length prefix
// runs past the end of the buffer, or if bytes remain after the final field.
// On failure s may be left partially populated. The byte-slice fields of s
// (masterSecret, handshakeHash, certificates, earlyALPN) alias data rather
// than copying it, so data must not be modified while s is in use.
func (s *sessionState) unmarshal(data []byte) bool {
	// At minimum: vers (2), cipherSuite (2), master secret length (2) and
	// handshake hash length (2).
	if len(data) < 8 {
		return false
	}

	s.vers = uint16(data[0])<<8 | uint16(data[1])
	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
	masterSecretLen := int(data[4])<<8 | int(data[5])
	data = data[6:]
	if len(data) < masterSecretLen {
		return false
	}

	s.masterSecret = data[:masterSecretLen]
	data = data[masterSecretLen:]

	if len(data) < 2 {
		return false
	}

	handshakeHashLen := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if len(data) < handshakeHashLen {
		return false
	}

	s.handshakeHash = data[:handshakeHashLen]
	data = data[handshakeHashLen:]

	if len(data) < 2 {
		return false
	}

	numCerts := int(data[0])<<8 | int(data[1])
	data = data[2:]

	s.certificates = make([][]byte, numCerts)
	for i := range s.certificates {
		if len(data) < 4 {
			return false
		}
		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
		data = data[4:]
		// A 32-bit length can overflow int on 32-bit platforms.
		if certLen < 0 {
			return false
		}
		if len(data) < certLen {
			return false
		}
		s.certificates[i] = data[:certLen]
		data = data[certLen:]
	}

	if len(data) < 1 {
		return false
	}

	// Only the exact value 1 enables the flag; any other byte reads as false.
	s.extendedMasterSecret = false
	if data[0] == 1 {
		s.extendedMasterSecret = true
	}
	data = data[1:]

	if s.vers >= VersionTLS13 {
		// Two u64 timestamps plus two u32 fields.
		if len(data) < 24 {
			return false
		}
		s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
		data = data[8:]
		s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
		data = data[8:]
		s.ticketFlags = binary.BigEndian.Uint32(data)
		data = data[4:]
		s.ticketAgeAdd = binary.BigEndian.Uint32(data)
		data = data[4:]
	}

	// The early-ALPN length prefix must be bounds-checked like every other
	// field; otherwise a ticket truncated here would panic on data[0]/data[1].
	if len(data) < 2 {
		return false
	}
	earlyALPNLen := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if len(data) < earlyALPNLen {
		return false
	}
	s.earlyALPN = data[:earlyALPNLen]
	data = data[earlyALPNLen:]

	// Reject trailing garbage.
	if len(data) > 0 {
		return false
	}

	return true
}

// encryptTicket seals state into an opaque session ticket keyed by
// c.config.SessionTicketKey. Bytes [0,16) of the key are an AES key used in
// CTR mode and bytes [16,32) are an HMAC-SHA256 key. The returned ticket is
//
//	IV (aes.BlockSize bytes) || ciphertext || HMAC-SHA256(IV || ciphertext)
//
// where the IV is read from c.config.rand(). An error is returned if the IV
// cannot be read or the AES cipher cannot be constructed.
func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
	plaintext := state.marshal()

	ticket := make([]byte, aes.BlockSize+len(plaintext)+sha256.Size)
	tagStart := len(ticket) - sha256.Size
	iv := ticket[:aes.BlockSize]
	ciphertext := ticket[aes.BlockSize:tagStart]
	tag := ticket[tagStart:]

	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
		return nil, err
	}

	aesBlock, err := aes.NewCipher(c.config.SessionTicketKey[:16])
	if err != nil {
		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
	}
	stream := cipher.NewCTR(aesBlock, iv)
	stream.XORKeyStream(ciphertext, plaintext)

	// Encrypt-then-MAC: the tag covers the IV and the ciphertext, and is
	// appended in place into the tail of ticket.
	tagger := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
	tagger.Write(ticket[:tagStart])
	tagger.Sum(tag[:0])

	return ticket, nil
}

// decryptTicket authenticates and decrypts a ticket in the layout produced by
// encryptTicket (IV || ciphertext || HMAC-SHA256 tag) using
// c.config.SessionTicketKey, then parses the plaintext into a sessionState.
//
// It returns (nil, false) if the ticket is too short to hold an IV and tag,
// the tag does not verify, the AES cipher cannot be created, or the decrypted
// contents fail to parse. On success it returns the parsed state and true.
func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
	if len(encrypted) < aes.BlockSize+sha256.Size {
		return nil, false
	}

	tagStart := len(encrypted) - sha256.Size
	iv := encrypted[:aes.BlockSize]
	macBytes := encrypted[tagStart:]

	// Verify the MAC over IV || ciphertext before decrypting anything; the
	// comparison is constant-time to avoid leaking tag bytes via timing.
	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
	mac.Write(encrypted[:tagStart])
	expected := mac.Sum(nil)

	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
		return nil, false
	}

	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
	if err != nil {
		return nil, false
	}
	ciphertext := encrypted[aes.BlockSize:tagStart]
	plaintext := make([]byte, len(ciphertext))
	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)

	state := new(sessionState)
	if !state.unmarshal(plaintext) {
		// Match the other failure paths: never hand back a partially
		// populated state.
		return nil, false
	}
	return state, true
}