// Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package aes import ( "crypto/cipher" "unsafe" ) // Assert that aesCipherAsm implements the ctrAble interface. var _ ctrAble = (*aesCipherAsm)(nil) // xorBytes xors the contents of a and b and places the resulting values into // dst. If a and b are not the same length then the number of bytes processed // will be equal to the length of shorter of the two. Returns the number // of bytes processed. //go:noescape func xorBytes(dst, a, b []byte) int // streamBufferSize is the number of bytes of encrypted counter values to cache. const streamBufferSize = 32 * BlockSize type aesctr struct { block *aesCipherAsm // block cipher ctr [2]uint64 // next value of the counter (big endian) buffer []byte // buffer for the encrypted counter values storage [streamBufferSize]byte // array backing buffer slice } // NewCTR returns a Stream which encrypts/decrypts using the AES block // cipher in counter mode. The length of iv must be the same as BlockSize. func (c *aesCipherAsm) NewCTR(iv []byte) cipher.Stream { if len(iv) != BlockSize { panic("cipher.NewCTR: IV length must equal block size") } var ac aesctr ac.block = c ac.ctr[0] = *(*uint64)(unsafe.Pointer((&iv[0]))) // high bits ac.ctr[1] = *(*uint64)(unsafe.Pointer((&iv[8]))) // low bits ac.buffer = ac.storage[:0] return &ac } func (c *aesctr) refill() { // Fill up the buffer with an incrementing count. c.buffer = c.storage[:streamBufferSize] c0, c1 := c.ctr[0], c.ctr[1] for i := 0; i < streamBufferSize; i += BlockSize { b0 := (*uint64)(unsafe.Pointer(&c.buffer[i])) b1 := (*uint64)(unsafe.Pointer(&c.buffer[i+BlockSize/2])) *b0, *b1 = c0, c1 // Increment in big endian: c0 is high, c1 is low. c1++ if c1 == 0 { // add carry c0++ } } c.ctr[0], c.ctr[1] = c0, c1 // Encrypt the buffer using AES in ECB mode. cryptBlocks(c.block.function, &c.block.key[0], &c.buffer[0], &c.buffer[0], streamBufferSize) } func (c *aesctr) XORKeyStream(dst, src []byte) { if len(src) > 0 { // Assert len(dst) >= len(src) _ = dst[len(src)-1] } for len(src) > 0 { if len(c.buffer) == 0 { c.refill() } n := xorBytes(dst, src, c.buffer) c.buffer = c.buffer[n:] src = src[n:] dst = dst[n:] } }