// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package aes
import (
	"crypto/cipher"
	"encoding/binary"
	"unsafe"
)
// Compile-time check: converting a nil *aesCipherAsm to ctrAble only
// type-checks if *aesCipherAsm implements every method of ctrAble.
var _ = ctrAble((*aesCipherAsm)(nil))
// xorBytes xors the contents of a and b and places the resulting values into
// dst. If a and b are not the same length then the number of bytes processed
// will be equal to the length of the shorter of the two. Returns the number
// of bytes processed.
//
// The declaration has no Go body; the implementation is supplied outside this
// file (presumably in assembly — confirm). The go:noescape directive tells the
// compiler that the slices passed in do not escape through this call.
// NOTE(review): dst is assumed to be at least as long as the shorter of a and
// b; callers (e.g. XORKeyStream) must guarantee this.
//
//go:noescape
func xorBytes(dst, a, b []byte) int
// streamBufferSize is the size, in bytes, of the key stream cache that refill
// produces in one go: 32 consecutive encrypted counter blocks.
const streamBufferSize = BlockSize * 32
// aesctr is the cipher.Stream returned by NewCTR. It holds the counter value
// for the next block to be encrypted and a buffer of key stream bytes that
// have been generated by refill but not yet consumed by XORKeyStream.
//
// NOTE(review): refill writes uint64 values into storage through unsafe
// pointer casts, so the alignment of storage matters. With the current field
// order it begins at an 8-byte-aligned offset on 64-bit targets; re-check
// this before reordering or adding fields.
type aesctr struct {
	block   *aesCipherAsm          // block cipher
	ctr     [2]uint64              // next value of the counter (big endian)
	buffer  []byte                 // buffer for the encrypted counter values
	storage [streamBufferSize]byte // array backing buffer slice
}
// NewCTR returns a Stream which encrypts/decrypts using the AES block
// cipher in counter mode. The length of iv must be the same as BlockSize;
// NewCTR panics otherwise.
//
// The IV is the initial value of a 128-bit big-endian counter: iv[0:8]
// supplies the high 64 bits and iv[8:16] the low 64 bits.
func (c *aesCipherAsm) NewCTR(iv []byte) cipher.Stream {
	if len(iv) != BlockSize {
		panic("cipher.NewCTR: IV length must equal block size")
	}
	var ac aesctr
	ac.block = c
	// Decode with encoding/binary rather than casting &iv[0]/&iv[8] to
	// *uint64: iv is caller-owned memory with no alignment guarantee (an
	// unaligned typed load trips checkptr), and an explicit big-endian load
	// does not depend on the host byte order.
	ac.ctr[0] = binary.BigEndian.Uint64(iv[0:]) // high bits
	ac.ctr[1] = binary.BigEndian.Uint64(iv[8:]) // low bits
	// Start with an empty buffer so the first XORKeyStream call triggers refill.
	ac.buffer = ac.storage[:0]
	return &ac
}
// refill regenerates the key stream cache.
//
// It first writes streamBufferSize/BlockSize consecutive counter values,
// starting at c.ctr, into c.storage. It then advances c.ctr past the last
// value written. Finally it encrypts the whole buffer in place with
// cryptBlocks, so that c.buffer holds streamBufferSize fresh key stream bytes.
func (c *aesctr) refill() {
	c.buffer = c.storage[:streamBufferSize]
	hi, lo := c.ctr[0], c.ctr[1]
	for off := 0; off < streamBufferSize; off += BlockSize {
		// Each block stores the high word followed by the low word, in host
		// byte order. NOTE(review): this matches the big-endian counter layout
		// only on a big-endian host; presumably this file is built solely for
		// such a target — confirm.
		*(*uint64)(unsafe.Pointer(&c.buffer[off])) = hi
		*(*uint64)(unsafe.Pointer(&c.buffer[off+BlockSize/2])) = lo

		// 128-bit increment: bump the low word and carry into the high word
		// when it wraps around to zero.
		lo++
		if lo == 0 {
			hi++
		}
	}
	c.ctr = [2]uint64{hi, lo}
	// ECB-encrypt every counter block in place to produce the key stream.
	cryptBlocks(c.block.function, &c.block.key[0], &c.buffer[0], &c.buffer[0], streamBufferSize)
}
// XORKeyStream XORs each byte of src with the next byte of the key stream and
// writes the result to dst. Key stream bytes are taken from c.buffer, which is
// regenerated by refill whenever it runs dry, so successive calls continue the
// same stream.
//
// dst must be at least as long as src; XORKeyStream panics otherwise, as the
// cipher.Stream contract requires.
func (c *aesctr) XORKeyStream(dst, src []byte) {
	// Check explicitly so a short dst yields a descriptive panic instead of
	// an opaque index-out-of-range runtime error.
	if len(dst) < len(src) {
		panic("crypto/cipher: output smaller than input")
	}
	// NOTE(review): cipher.Stream also requires dst and src to overlap
	// entirely or not at all; that requirement is not checked here.
	for len(src) > 0 {
		if len(c.buffer) == 0 {
			c.refill()
		}
		// xorBytes processes min(len(src), len(c.buffer)) bytes and reports
		// how many, so consume exactly that much from every slice.
		n := xorBytes(dst, src, c.buffer)
		c.buffer = c.buffer[n:]
		src = src[n:]
		dst = dst[n:]
	}
}