// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DTLS implementation.
//
// NOTE: This is not even a remotely production-quality DTLS
// implementation. It is the bare minimum necessary to be able to
// achieve coverage on BoringSSL's implementation. Of note is that
// this implementation assumes the underlying net.PacketConn is not
// only reliable but also ordered. BoringSSL will be expected to deal
// with simulated loss, but there is no point in forcing the test
// driver to.
package main
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"net"
)
// versionToWire converts the internal protocol version vers into the value
// carried on the wire (for example in the record header).
//
// For TLS the version is sent unchanged. For DTLS the wire value is the
// one's complement of (vers - 0x0201), so TLS 1.0 (0x0301) is sent as
// DTLS 1.0 (0xfeff) and TLS 1.2 (0x0303) as DTLS 1.2 (0xfefd).
func versionToWire(vers uint16, isDTLS bool) uint16 {
	if !isDTLS {
		return vers
	}
	const dtlsBias = 0x0201
	return ^(vers - dtlsBias)
}
// wireToVersion converts a version value read off the wire back into the
// internal protocol version.
//
// TLS wire versions are returned unchanged. DTLS wire versions are decoded
// as (^vers + 0x0201), undoing the one's-complement encoding: 0xfeff
// (DTLS 1.0) decodes to 0x0301 and 0xfefd (DTLS 1.2) decodes to 0x0303.
func wireToVersion(vers uint16, isDTLS bool) uint16 {
	if !isDTLS {
		return vers
	}
	const dtlsBias = 0x0201
	return ^vers + dtlsBias
}
// dtlsDoReadRecord returns the type and block of the next DTLS record.
//
// A new datagram is read from c.conn only when c.rawInput is empty. After a
// record is parsed, the bytes following it stay in c.rawInput (presumably
// split off by c.in.splitBlock), so records packed into one datagram are
// returned by successive calls. The want argument is not consulted here.
//
// This is test code, so malformed input from the peer is rejected rather
// than tolerated:
//   - a datagram larger than Bugs.MaxPacketLength (when set) is an error;
//   - a short record header is an error;
//   - an unexpected record version is an error;
//   - a sequence number other than c.in.seq is an error;
//   - an oversized or truncated record is an error.
//
// The version, sequence-number and length failures also send an alert and
// record the error on c.in.
//
// NOTE(review): when decryption fails, an alert is sent and the error is
// recorded on c.in, but the record is still returned with a nil error.
// Callers presumably detect the failure through c.in's error state —
// confirm against readRecord.
func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
	recordHeaderLen := dtlsRecordHeaderLen
	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read a new packet only if the current one is empty.
	if len(b.data) == 0 {
		// Pick some absurdly large buffer size.
		b.resize(maxCiphertext + recordHeaderLen)
		// b and c.rawInput are the same block, so this read fills b.
		n, err := c.conn.Read(c.rawInput.data)
		if err != nil {
			return 0, nil, err
		}
		// Test hook: reject datagrams longer than Bugs.MaxPacketLength.
		if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
			return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
		}
		c.rawInput.resize(n)
	}

	// Read out one record.
	//
	// A real DTLS implementation should be tolerant of errors,
	// but this is test code. We should not be tolerant of our
	// peer sending garbage.
	if len(b.data) < recordHeaderLen {
		return 0, nil, errors.New("dtls: failed to read record header")
	}
	// Record header layout as parsed below: type (byte 0), wire version
	// (bytes 1-2), explicit sequence number (bytes 3-10), big-endian
	// length (bytes 11-12).
	typ := recordType(b.data[0])
	vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
	if c.haveVers {
		// Once a version has been negotiated every record must use it.
		if vers != c.vers {
			c.sendAlert(alertProtocolVersion)
			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
		}
	} else {
		// Before negotiation, only check the version if a test asked for it.
		if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
			c.sendAlert(alertProtocolVersion)
			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
		}
	}
	seq := b.data[3:11]
	// For test purposes, we assume a reliable channel. Require
	// that the explicit sequence number matches the incrementing
	// one we maintain. A real implementation would maintain a
	// replay window and such.
	if !bytes.Equal(seq, c.in.seq[:]) {
		c.sendAlert(alertIllegalParameter)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
	}
	n := int(b.data[11])<<8 | int(b.data[12])
	// Reject records that exceed maxCiphertext as well as records whose
	// declared length runs past the end of the datagram; both are reported
	// with the same alert and message.
	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
		c.sendAlert(alertRecordOverflow)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
	}

	// Process message.
	// Split this record off the datagram; the remainder becomes the new
	// c.rawInput for the next call.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	ok, off, err := c.in.decrypt(b)
	if !ok {
		// The error is recorded on c.in but deliberately not returned here
		// (see NOTE(review) above).
		c.in.setErrorLocked(c.sendAlert(err))
	}
	// off is where decrypt says the plaintext payload starts within b.data.
	b.off = off
	return typ, b, nil
}
// makeFragment builds one DTLS handshake fragment. The result is:
//   - the caller-supplied handshake header;
//   - the current send-side message sequence number (2 bytes, big-endian);
//   - the fragment offset (3 bytes);
//   - the fragment length (3 bytes);
//   - data[fragOffset:fragOffset+fragLen].
func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
	seq := c.sendHandshakeSeq
	out := make([]byte, 0, 12+fragLen)
	out = append(out, header...)
	out = append(out,
		byte(seq>>8), byte(seq),
		byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset),
		byte(fragLen>>16), byte(fragLen>>8), byte(fragLen),
	)
	return append(out, data[fragOffset:fragOffset+fragLen]...)
}
// dtlsWriteRecord writes a record of type typ.
//
// Non-handshake records are sealed and sent immediately by
// dtlsWriteRawRecord. Handshake data is not sent here. It is split into
// DTLS fragments of at most Bugs.MaxHandshakeRecordLength body bytes
// (default 1024) and appended to c.pendingFragments for dtlsFlushHandshake.
// The various Bugs knobs may inject empty, duplicated, overlapping,
// corrupted or complete-message fragments.
//
// For handshake data the returned count is the message body length. The
// leading 4-byte handshake header is not counted. data must hold at least
// those 4 header bytes.
func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
	// Only handshake messages are fragmented.
	if typ != recordTypeHandshake {
		return c.dtlsWriteRawRecord(typ, data)
	}

	maxLen := 1024
	if configured := c.config.Bugs.MaxHandshakeRecordLength; configured > 0 {
		maxLen = configured
	}

	// Every fragment repeats the TLS handshake header, followed by the
	// DTLS offset/length fields, so split the header off here.
	//
	// TODO(davidben): This assumes that data contains exactly one
	// handshake message. This is incompatible with
	// FragmentAcrossChangeCipherSpec. (Which is unfortunate
	// because OpenSSL's DTLS implementation will probably accept
	// such fragmentation and could do with a fix + tests.)
	header, body := data[:4], data[4:]
	isFinished := header[0] == typeFinished

	if c.config.Bugs.SendEmptyFragments {
		c.pendingFragments = append(c.pendingFragments, c.makeFragment(header, body, 0, 0))
	}

	// Emit at least one fragment, so an empty body still produces a
	// zero-length fragment.
	offset := 0
	for {
		fragLen := len(body) - offset
		if fragLen > maxLen {
			fragLen = maxLen
		}

		fragment := c.makeFragment(header, body, offset, fragLen)
		if offset > 0 {
			if c.config.Bugs.FragmentMessageTypeMismatch {
				fragment[0]++
			}
			if c.config.Bugs.FragmentMessageLengthMismatch {
				fragment[3]++
			}
		}

		// Queue the fragment; dtlsFlushHandshake sends (and possibly
		// reorders) it.
		c.pendingFragments = append(c.pendingFragments, fragment)
		if c.config.Bugs.ReorderHandshakeFragments {
			// Duplicate every fragment except Finished's, so the peer
			// cannot mistake a repeated Finished for a retransmit request.
			if !isFinished {
				c.pendingFragments = append(c.pendingFragments, fragment)
			}
			// Advance by at most half a fragment so neighbours overlap.
			if half := (maxLen + 1) / 2; fragLen > half {
				fragLen = half
			}
		}

		offset += fragLen
		n += fragLen
		if offset >= len(body) {
			break
		}
	}

	if c.config.Bugs.MixCompleteMessageWithFragments && !isFinished {
		c.pendingFragments = append(c.pendingFragments, c.makeFragment(header, body, 0, len(body)))
	}

	// The next handshake message uses the next sequence number.
	c.sendHandshakeSeq++
	return n, nil
}
// dtlsFlushHandshake sends every handshake fragment queued in
// c.pendingFragments and empties the queue. It is a no-op for non-DTLS
// connections.
//
// The send proceeds in four steps:
//  1. When Bugs.ReorderHandshakeFragments is set, the fragments are
//     permuted with a fixed seed.
//  2. The fragments are packed into records of up to
//     Bugs.PackHandshakeFragments bytes, or split according to the
//     SplitFragment* bugs.
//  3. Each record is sealed, and the sealed records are packed into
//     packets of up to Bugs.PackHandshakeRecords bytes.
//  4. Each packet is written to c.conn with its own Write call.
func (c *Conn) dtlsFlushHandshake() error {
	if !c.isDTLS {
		return nil
	}

	// Take ownership of the queue. This is a test-only DTLS
	// implementation, so nothing is kept for a future retransmit.
	pending := c.pendingFragments
	c.pendingFragments = nil

	if c.config.Bugs.ReorderHandshakeFragments {
		// A fixed seed keeps the permutation identical across runs.
		order := rand.New(rand.NewSource(0)).Perm(len(pending))
		shuffled := make([][]byte, 0, len(pending))
		for _, idx := range order {
			shuffled = append(shuffled, pending[idx])
		}
		pending = shuffled
	}

	maxRecordLen := c.config.Bugs.PackHandshakeFragments
	maxPacketLen := c.config.Bugs.PackHandshakeRecords

	// Pack handshake fragments into records.
	var records [][]byte
	for _, frag := range pending {
		switch {
		case c.config.Bugs.SplitFragmentHeader:
			records = append(records, frag[:2], frag[2:])
		case c.config.Bugs.SplitFragmentBody && len(frag) > 12:
			records = append(records, frag[:13], frag[13:])
		case c.config.Bugs.SplitFragmentBody:
			records = append(records, frag)
		case len(records) > 0 && len(records[len(records)-1])+len(frag) <= maxRecordLen:
			last := len(records) - 1
			records[last] = append(records[last], frag...)
		default:
			// Later fragments may be appended to this record, so copy
			// rather than alias the fragment's backing array.
			records = append(records, append([]byte{}, frag...))
		}
	}

	// Seal the records and pack them into packets.
	var packets [][]byte
	for _, rec := range records {
		sealed, err := c.dtlsSealRecord(recordTypeHandshake, rec)
		if err != nil {
			return err
		}
		if cnt := len(packets); cnt > 0 && len(packets[cnt-1])+len(sealed.data) <= maxPacketLen {
			packets[cnt-1] = append(packets[cnt-1], sealed.data...)
		} else {
			// The sealed block is handed back to |c.out| for reuse, and
			// this packet may be appended to later, so copy its bytes.
			packets = append(packets, append([]byte{}, sealed.data...))
		}
		c.out.freeBlock(sealed)
	}

	// Send each packet with its own Write call.
	for _, pkt := range packets {
		if _, err := c.conn.Write(pkt); err != nil {
			return err
		}
	}
	return nil
}
// dtlsSealRecord seals a record into a block from |c.out|'s pool.
//
// The block is laid out as follows, and is then passed to c.out.encrypt:
//   - the DTLS record header: type, wire version, the explicit sequence
//     number c.out.seq, and the length of data;
//   - the explicit IV required by the current write cipher, if any;
//   - data.
//
// On success the caller owns the returned block; callers in this file hand
// it back with c.out.freeBlock once the bytes are sent. On failure the
// block is returned to the pool and nil is returned with the error.
func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) {
	recordHeaderLen := dtlsRecordHeaderLen
	b = c.out.newBlock()

	// Work out how much explicit IV the current write cipher needs.
	explicitIVLen := 0
	explicitIVIsSeq := false
	if cbc, ok := c.out.cipher.(cbcMode); ok {
		// Block cipher modes have an explicit IV.
		explicitIVLen = cbc.BlockSize()
	} else if aead, ok := c.out.cipher.(*tlsAead); ok {
		if aead.explicitNonce {
			explicitIVLen = 8
			// The AES-GCM construction in TLS has an explicit nonce so that
			// the nonce can be random. However, the nonce is only 8 bytes
			// which is too small for a secure, random nonce. Therefore we
			// use the sequence number as the nonce.
			explicitIVIsSeq = true
		}
	} else if c.out.cipher != nil {
		panic("Unknown cipher")
	}

	b.resize(recordHeaderLen + explicitIVLen + len(data))
	b.data[0] = byte(typ)
	vers := c.vers
	if vers == 0 {
		// Some TLS servers fail if the record version is greater than
		// TLS 1.0 for the initial ClientHello.
		vers = VersionTLS10
	}
	vers = versionToWire(vers, c.isDTLS)
	b.data[1] = byte(vers >> 8)
	b.data[2] = byte(vers)
	// DTLS records include an explicit sequence number.
	copy(b.data[3:11], c.out.seq[0:])
	// NOTE(review): this is the plaintext length; c.out.encrypt presumably
	// rewrites it once IV/MAC/padding are accounted for — confirm.
	b.data[11] = byte(len(data) >> 8)
	b.data[12] = byte(len(data))

	if explicitIVLen > 0 {
		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
		if explicitIVIsSeq {
			copy(explicitIV, c.out.seq[:])
		} else if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
			// Give the half-built block back to the pool instead of
			// returning it alongside the error.
			c.out.freeBlock(b)
			return nil, err
		}
	}

	copy(b.data[recordHeaderLen+explicitIVLen:], data)
	c.out.encrypt(b, explicitIVLen)
	return b, nil
}
// dtlsWriteRawRecord seals data as a single record of type typ and writes
// it to c.conn in one Write call, without any handshake fragmentation. It
// returns len(data) on success.
//
// After a ChangeCipherSpec record is written, c.out.changeCipherSpec is
// called. If that fails, an alert record is written, the error is recorded
// on c.out, and a *net.OpError wrapping it is returned.
func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) {
	b, err := c.dtlsSealRecord(typ, data)
	if err != nil {
		return 0, err
	}
	_, err = c.conn.Write(b.data)
	// Write does not retain its argument, so the sealed block can go back
	// to |c.out|'s pool whether or not the write succeeded.
	c.out.freeBlock(b)
	if err != nil {
		return 0, err
	}
	n = len(data)

	if typ == recordTypeChangeCipherSpec {
		err = c.out.changeCipherSpec(c.config)
		if err != nil {
			// Cannot call sendAlert directly,
			// because we already hold c.out.Mutex.
			c.tmp[0] = alertLevelError
			// NOTE(review): assumes changeCipherSpec reports failures as an
			// alert value; any other error type makes this assertion panic.
			c.tmp[1] = byte(err.(alert))
			c.writeRecord(recordTypeAlert, c.tmp[0:2])
			return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
		}
	}
	return n, nil
}
// dtlsDoReadHandshake reassembles and returns the next complete handshake
// message, reading more handshake records via c.readRecord as needed.
//
// Each fragment carries a 12-byte header, parsed below as follows:
//   - message type (byte 0);
//   - total message length (bytes 1-3);
//   - message sequence number (bytes 4-5);
//   - fragment offset (bytes 6-8);
//   - fragment length (bytes 9-11).
//
// The returned slice starts with the first fragment's header[:4] (type
// and total length), with the DTLS sequence/fragment fields dropped,
// followed by the reassembled body.
//
// For test purposes the reassembly is strict, and any deviation is
// returned as an error:
//   - a fragment must lie entirely within one record;
//   - it must carry c.recvHandshakeSeq;
//   - it must repeat the first fragment's total length;
//   - it must start exactly where the previous fragment ended, and must
//     not run past the end of the message.
//
// A message longer than maxHandshake sends an internal_error alert.
// Partial state lives in c.handMsg/c.handMsgLen and is not reset when an
// error is returned mid-message.
//
// NOTE(review): the type byte of fragments after the first is not compared
// against the first fragment's type.
func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
	// Assemble a full handshake message. For test purposes, this
	// implementation assumes fragments arrive in order. It may
	// need to be cleverer if we ever test BoringSSL's retransmit
	// behavior.
	//
	// Before the first fragment, c.handMsg is empty and c.handMsgLen is 0,
	// so the condition below is true until at least the 4-byte header has
	// been copied in.
	for len(c.handMsg) < 4+c.handMsgLen {
		// Get a new handshake record if the previous has been
		// exhausted.
		if c.hand.Len() == 0 {
			if err := c.in.err; err != nil {
				return nil, err
			}
			if err := c.readRecord(recordTypeHandshake); err != nil {
				return nil, err
			}
		}
		// Read the next fragment. It must fit entirely within
		// the record.
		if c.hand.Len() < 12 {
			return nil, errors.New("dtls: bad handshake record")
		}
		header := c.hand.Next(12)
		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
		fragSeq := uint16(header[4])<<8 | uint16(header[5])
		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
		if c.hand.Len() < fragLen {
			return nil, errors.New("dtls: fragment length too long")
		}
		fragment := c.hand.Next(fragLen)
		// Check it's a fragment for the right message.
		if fragSeq != c.recvHandshakeSeq {
			return nil, errors.New("dtls: bad handshake sequence number")
		}
		// Check that the length is consistent.
		if c.handMsg == nil {
			c.handMsgLen = fragN
			if c.handMsgLen > maxHandshake {
				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
			}
			// Start with the TLS handshake header,
			// without the DTLS bits.
			// Copied, because header aliases c.hand's storage.
			c.handMsg = append([]byte{}, header[:4]...)
		} else if fragN != c.handMsgLen {
			return nil, errors.New("dtls: bad handshake length")
		}
		// Add the fragment to the pending message.
		// c.handMsg includes the 4-byte header, hence the +4.
		if 4+fragOff != len(c.handMsg) {
			return nil, errors.New("dtls: bad fragment offset")
		}
		if fragOff+fragLen > c.handMsgLen {
			return nil, errors.New("dtls: bad fragment length")
		}
		c.handMsg = append(c.handMsg, fragment...)
	}
	// Message complete: expect the next sequence number and reset the
	// reassembly state for the next call.
	c.recvHandshakeSeq++
	ret := c.handMsg
	c.handMsg, c.handMsgLen = nil, 0
	return ret, nil
}
// DTLSServer returns a new server-side DTLS connection that uses conn as
// its underlying transport. config must be non-nil and must contain at
// least one certificate.
func DTLSServer(conn net.Conn, config *Config) *Conn {
	c := new(Conn)
	c.conn = conn
	c.config = config
	c.isDTLS = true
	c.init()
	return c
}
// DTLSClient returns a new client-side DTLS connection that uses conn as
// its underlying transport. config must be non-nil, and callers must set
// either ServerHostname or InsecureSkipVerify in it.
func DTLSClient(conn net.Conn, config *Config) *Conn {
	c := new(Conn)
	c.conn = conn
	c.config = config
	c.isClient = true
	c.isDTLS = true
	c.init()
	return c
}