// Copyright (c) 2016, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package runner
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/asn1"
"errors"
)
// TestShimTicketKey is the testing key assumed for the shim.
var TestShimTicketKey = make([]byte, 48)
func DecryptShimTicket(in []byte) ([]byte, error) {
name := TestShimTicketKey[:16]
macKey := TestShimTicketKey[16:32]
encKey := TestShimTicketKey[32:48]
h := hmac.New(sha256.New, macKey)
block, err := aes.NewCipher(encKey)
if err != nil {
panic(err)
}
if len(in) < len(name)+block.BlockSize()+1+h.Size() {
return nil, errors.New("tls: shim ticket too short")
}
// Check the key name.
if !bytes.Equal(name, in[:len(name)]) {
return nil, errors.New("tls: shim ticket name mismatch")
}
// Check the MAC at the end of the ticket.
mac := in[len(in)-h.Size():]
in = in[:len(in)-h.Size()]
h.Write(in)
if !hmac.Equal(mac, h.Sum(nil)) {
return nil, errors.New("tls: shim ticket MAC mismatch")
}
// The MAC covers the key name, but the encryption does not.
in = in[len(name):]
// Decrypt in-place.
iv := in[:block.BlockSize()]
in = in[block.BlockSize():]
if l := len(in); l == 0 || l%block.BlockSize() != 0 {
return nil, errors.New("tls: ticket ciphertext not a multiple of the block size")
}
out := make([]byte, len(in))
cbc := cipher.NewCBCDecrypter(block, iv)
cbc.CryptBlocks(out, in)
// Remove the padding.
pad := int(out[len(out)-1])
if pad == 0 || pad > block.BlockSize() || pad > len(in) {
return nil, errors.New("tls: bad shim ticket CBC pad")
}
for i := 0; i < pad; i++ {
if out[len(out)-1-i] != byte(pad) {
return nil, errors.New("tls: bad shim ticket CBC pad")
}
}
return out[:len(out)-pad], nil
}
func EncryptShimTicket(in []byte) []byte {
name := TestShimTicketKey[:16]
macKey := TestShimTicketKey[16:32]
encKey := TestShimTicketKey[32:48]
h := hmac.New(sha256.New, macKey)
block, err := aes.NewCipher(encKey)
if err != nil {
panic(err)
}
// Use the zero IV for rewritten tickets.
iv := make([]byte, block.BlockSize())
cbc := cipher.NewCBCEncrypter(block, iv)
pad := block.BlockSize() - (len(in) % block.BlockSize())
out := make([]byte, 0, len(name)+len(iv)+len(in)+pad+h.Size())
out = append(out, name...)
out = append(out, iv...)
out = append(out, in...)
for i := 0; i < pad; i++ {
out = append(out, byte(pad))
}
ciphertext := out[len(name)+len(iv):]
cbc.CryptBlocks(ciphertext, ciphertext)
h.Write(out)
return h.Sum(out)
}
const asn1Constructed = 0x20
func parseDERElement(in []byte) (tag byte, body, rest []byte, ok bool) {
rest = in
if len(rest) < 1 {
return
}
tag = rest[0]
rest = rest[1:]
if tag&0x1f == 0x1f {
// Long-form tags not supported.
return
}
if len(rest) < 1 {
return
}
length := int(rest[0])
rest = rest[1:]
if length > 0x7f {
lengthLength := length & 0x7f
length = 0
if lengthLength == 0 {
// No indefinite-length encoding.
return
}
// Decode long-form lengths.
for lengthLength > 0 {
if len(rest) < 1 || (length<<8)>>8 != length {
return
}
if length == 0 && rest[0] == 0 {
// Length not minimally-encoded.
return
}
length <<= 8
length |= int(rest[0])
rest = rest[1:]
lengthLength--
}
if length < 0x80 {
// Length not minimally-encoded.
return
}
}
if len(rest) < length {
return
}
body = rest[:length]
rest = rest[length:]
ok = true
return
}
func SetShimTicketVersion(in []byte, vers uint16) ([]byte, error) {
plaintext, err := DecryptShimTicket(in)
if err != nil {
return nil, err
}
tag, session, _, ok := parseDERElement(plaintext)
if !ok || tag != asn1.TagSequence|asn1Constructed {
return nil, errors.New("tls: could not decode shim session")
}
// Skip the session version.
tag, _, session, ok = parseDERElement(session)
if !ok || tag != asn1.TagInteger {
return nil, errors.New("tls: could not decode shim session")
}
// Next field is the protocol version.
tag, version, _, ok := parseDERElement(session)
if !ok || tag != asn1.TagInteger {
return nil, errors.New("tls: could not decode shim session")
}
// This code assumes both old and new versions are encoded in two
// bytes. This isn't quite right as INTEGERs are minimally-encoded, but
// we do not need to support other caess for now.
if len(version) != 2 || vers < 0x80 || vers >= 0x8000 {
return nil, errors.New("tls: unsupported version in shim session")
}
version[0] = byte(vers >> 8)
version[1] = byte(vers)
return EncryptShimTicket(plaintext), nil
}
func SetShimTicketCipherSuite(in []byte, id uint16) ([]byte, error) {
plaintext, err := DecryptShimTicket(in)
if err != nil {
return nil, err
}
tag, session, _, ok := parseDERElement(plaintext)
if !ok || tag != asn1.TagSequence|asn1Constructed {
return nil, errors.New("tls: could not decode shim session")
}
// Skip the session version.
tag, _, session, ok = parseDERElement(session)
if !ok || tag != asn1.TagInteger {
return nil, errors.New("tls: could not decode shim session")
}
// Skip the protocol version.
tag, _, session, ok = parseDERElement(session)
if !ok || tag != asn1.TagInteger {
return nil, errors.New("tls: could not decode shim session")
}
// Next field is the cipher suite.
tag, cipherSuite, _, ok := parseDERElement(session)
if !ok || tag != asn1.TagOctetString || len(cipherSuite) != 2 {
return nil, errors.New("tls: could not decode shim session")
}
cipherSuite[0] = byte(id >> 8)
cipherSuite[1] = byte(id)
return EncryptShimTicket(plaintext), nil
}