From 3a8c7b1a08b2766a7f8a388eee14442281b4e295 Mon Sep 17 00:00:00 2001
From: Adam Langley <agl@chromium.org>
Date: Thu, 24 Jan 2013 16:27:14 -0500
Subject: [PATCH 19/36] tls12_digests

Fixes a bug with handling TLS 1.2 and digest functions for DSA and ECDSA
keys.
---
 ssl/s3_clnt.c  |  26 +++++++++++++--
 ssl/ssl3.h     |  11 +++++-
 ssl/ssl_cert.c |  20 -----------
 ssl/ssl_lib.c  |  35 +++++++++++--------
 ssl/ssl_locl.h |   4 +--
 ssl/t1_lib.c   | 104 ++++++++++++++++++++-------------------------------------
 6 files changed, 94 insertions(+), 106 deletions(-)

diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index c9196b3..1f3b376 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -1990,12 +1990,13 @@ int ssl3_get_certificate_request(SSL *s)
 			SSLerr(SSL_F_SSL3_GET_CERTIFICATE_REQUEST,SSL_R_DATA_LENGTH_TOO_LONG);
 			goto err;
 			}
-		if ((llen & 1) || !tls1_process_sigalgs(s, p, llen))
+		if (llen & 1)
 			{
 			ssl3_send_alert(s,SSL3_AL_FATAL,SSL_AD_DECODE_ERROR);
 			SSLerr(SSL_F_SSL3_GET_CERTIFICATE_REQUEST,SSL_R_SIGNATURE_ALGORITHMS_ERROR);
 			goto err;
 			}
+		tls1_process_sigalgs(s, p, llen);
 		p += llen;
 		}
 
@@ -3017,7 +3018,28 @@ int ssl3_send_client_verify(SSL *s)
 			{
 			long hdatalen = 0;
 			void *hdata;
-			const EVP_MD *md = s->cert->key->digest;
+			const EVP_MD *md;
+			switch (ssl_cert_type(NULL, pkey))
+				{
+			case SSL_PKEY_RSA_ENC:
+				md = s->s3->digest_rsa;
+				break;
+			case SSL_PKEY_DSA_SIGN:
+				md = s->s3->digest_dsa;
+				break;
+			case SSL_PKEY_ECC:
+				md = s->s3->digest_ecdsa;
+				break;
+			default:
+				md = NULL;
+				}
+			if (!md)
+				/* Unlike with the SignatureAlgorithm extension (sent by clients),
+				 * there are no default algorithms for the CertificateRequest message
+				 * (sent by servers). However, now that we've sent a certificate
+				 * for which we don't really know what hash to use for signing, the
+				 * best we can do is try a default algorithm. */
+				md = EVP_sha1();
 			hdatalen = BIO_get_mem_data(s->s3->handshake_buffer,
 								&hdata);
 			if (hdatalen <= 0 || !tls12_get_sigandhash(p, pkey, md))
diff --git a/ssl/ssl3.h b/ssl/ssl3.h
index 29098e4..3229995 100644
--- a/ssl/ssl3.h
+++ b/ssl/ssl3.h
@@ -550,6 +550,16 @@ typedef struct ssl3_state_st
 	 *     verified Channel ID from the client: a P256 point, (x,y), where
 	 *     each are big-endian values. */
 	unsigned char tlsext_channel_id[64];
+
+	/* These point to the digest function to use for signatures made with
+	 * each type of public key. A NULL value indicates that the default
+	 * digest should be used, which is SHA1 as of TLS 1.2.
+	 *
+	 * (These should be in the tmp member, but we have to put them here to
+	 * ensure binary compatibility with earlier OpenSSL 1.0.* releases.) */
+	const EVP_MD *digest_rsa;
+	const EVP_MD *digest_dsa;
+	const EVP_MD *digest_ecdsa;
 	} SSL3_STATE;
 
 #endif
@@ -700,4 +710,3 @@ typedef struct ssl3_state_st
 }
 #endif
 #endif
-
diff --git a/ssl/ssl_cert.c b/ssl/ssl_cert.c
index 5123a89..bc4150b 100644
--- a/ssl/ssl_cert.c
+++ b/ssl/ssl_cert.c
@@ -160,21 +160,6 @@ int SSL_get_ex_data_X509_STORE_CTX_idx(void)
 	return ssl_x509_store_ctx_idx;
 	}
 
-static void ssl_cert_set_default_md(CERT *cert)
-	{
-	/* Set digest values to defaults */
-#ifndef OPENSSL_NO_DSA
-	cert->pkeys[SSL_PKEY_DSA_SIGN].digest = EVP_sha1();
-#endif
-#ifndef OPENSSL_NO_RSA
-	cert->pkeys[SSL_PKEY_RSA_SIGN].digest = EVP_sha1();
-	cert->pkeys[SSL_PKEY_RSA_ENC].digest = EVP_sha1();
-#endif
-#ifndef OPENSSL_NO_ECDSA
-	cert->pkeys[SSL_PKEY_ECC].digest = EVP_sha1();
-#endif
-	}
-
 CERT *ssl_cert_new(void)
 	{
 	CERT *ret;
@@ -189,7 +174,6 @@ CERT *ssl_cert_new(void)
 
 	ret->key= &(ret->pkeys[SSL_PKEY_RSA_ENC]);
 	ret->references=1;
-	ssl_cert_set_default_md(ret);
 	return(ret);
 	}
 
@@ -322,10 +306,6 @@ CERT *ssl_cert_dup(CERT *cert)
 	 * chain is held inside SSL_CTX */
 
 	ret->references=1;
-	/* Set digests to defaults. NB: we don't copy existing values as they
-	 * will be set during handshake.
-	 */
-	ssl_cert_set_default_md(ret);
 
 	return(ret);
 	
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 5f8b0b0..e360550 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -2345,32 +2345,41 @@ EVP_PKEY *ssl_get_sign_pkey(SSL *s,const SSL_CIPHER *cipher, const EVP_MD **pmd)
 	{
 	unsigned long alg_a;
 	CERT *c;
-	int idx = -1;
 
 	alg_a = cipher->algorithm_auth;
 	c=s->cert;
 
+	/* SHA1 is the default for all signature algorithms up to TLS 1.2,
+	 * except RSA which is handled specially in s3_srvr.c */
+	if (pmd)
+		*pmd = EVP_sha1();
+
 	if ((alg_a & SSL_aDSS) &&
-		(c->pkeys[SSL_PKEY_DSA_SIGN].privatekey != NULL))
-		idx = SSL_PKEY_DSA_SIGN;
+	    (c->pkeys[SSL_PKEY_DSA_SIGN].privatekey != NULL))
+		{
+		if (pmd && s->s3 && s->s3->digest_dsa)
+			*pmd = s->s3->digest_dsa;
+		return c->pkeys[SSL_PKEY_DSA_SIGN].privatekey;
+		}
 	else if (alg_a & SSL_aRSA)
 		{
+		if (pmd && s->s3 && s->s3->digest_rsa)
+			*pmd = s->s3->digest_rsa;
 		if (c->pkeys[SSL_PKEY_RSA_SIGN].privatekey != NULL)
-			idx = SSL_PKEY_RSA_SIGN;
-		else if (c->pkeys[SSL_PKEY_RSA_ENC].privatekey != NULL)
-			idx = SSL_PKEY_RSA_ENC;
+			return c->pkeys[SSL_PKEY_RSA_SIGN].privatekey;
+		if (c->pkeys[SSL_PKEY_RSA_ENC].privatekey != NULL)
+			return c->pkeys[SSL_PKEY_RSA_ENC].privatekey;
 		}
 	else if ((alg_a & SSL_aECDSA) &&
 	         (c->pkeys[SSL_PKEY_ECC].privatekey != NULL))
-		idx = SSL_PKEY_ECC;
-	if (idx == -1)
 		{
-		SSLerr(SSL_F_SSL_GET_SIGN_PKEY,ERR_R_INTERNAL_ERROR);
-		return(NULL);
+		if (pmd && s->s3 && s->s3->digest_ecdsa)
+			*pmd = s->s3->digest_ecdsa;
+		return c->pkeys[SSL_PKEY_ECC].privatekey;
 		}
-	if (pmd)
-		*pmd = c->pkeys[idx].digest;
-	return c->pkeys[idx].privatekey;
+
+	SSLerr(SSL_F_SSL_GET_SIGN_PKEY,ERR_R_INTERNAL_ERROR);
+	return(NULL);
 	}
 
 void ssl_update_cache(SSL *s,int mode)
@@ -3138,26 +3160,15 @@ SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX* ctx)

 SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX* ctx)
	{
-	CERT *ocert = ssl->cert;
	if (ssl->ctx == ctx)
		return ssl->ctx;
 #ifndef OPENSSL_NO_TLSEXT
	if (ctx == NULL)
		ctx = ssl->initial_ctx;
 #endif
+	if (ssl->cert != NULL)
+		ssl_cert_free(ssl->cert);
	ssl->cert = ssl_cert_dup(ctx->cert);
-	if (ocert != NULL)
-		{
-		int i;
-		/* Copy negotiated digests from original */
-		for (i = 0; i < SSL_PKEY_NUM; i++)
-			{
-			CERT_PKEY *cpk = ocert->pkeys + i;
-			CERT_PKEY *rpk = ssl->cert->pkeys + i;
-			rpk->digest = cpk->digest;
-			}
-		ssl_cert_free(ocert);
-		}
	CRYPTO_add(&ctx->references,1,CRYPTO_LOCK_SSL_CTX);
	if (ssl->ctx != NULL)
		SSL_CTX_free(ssl->ctx); /* decrement reference count */
diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h
index 6d38f0f..3e89fcb 100644
--- a/ssl/ssl_locl.h
+++ b/ssl/ssl_locl.h
@@ -485,8 +485,6 @@ typedef struct cert_pkey_st
 	{
 	X509 *x509;
 	EVP_PKEY *privatekey;
-	/* Digest to use when signing */
-	const EVP_MD *digest;
 	} CERT_PKEY;
 
 typedef struct cert_st
@@ -1142,7 +1140,7 @@ int ssl_add_clienthello_renegotiate_ext(SSL *s, unsigned char *p, int *len,
 int ssl_parse_clienthello_renegotiate_ext(SSL *s, unsigned char *d, int len,
 					  int *al);
 long ssl_get_algorithm2(SSL *s);
-int tls1_process_sigalgs(SSL *s, const unsigned char *data, int dsize);
+void tls1_process_sigalgs(SSL *s, const unsigned char *data, int dsize);
 int tls12_get_req_sig_algs(SSL *s, unsigned char *p);
 
 int ssl_add_clienthello_use_srtp_ext(SSL *s, unsigned char *p, int *len, int maxlen);
diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c
index 26805e4..6af51a9 100644
--- a/ssl/t1_lib.c
+++ b/ssl/t1_lib.c
@@ -897,6 +897,13 @@ int ssl_parse_clienthello_tlsext(SSL *s, unsigned char **p, unsigned char *d, in
 
 	s->servername_done = 0;
 	s->tlsext_status_type = -1;
+
+	/* Reset TLS 1.2 digest functions to defaults because they don't carry
+	 * over to a renegotiation. */
+	s->s3->digest_rsa = NULL;
+	s->s3->digest_dsa = NULL;
+	s->s3->digest_ecdsa = NULL;
+
 #ifndef OPENSSL_NO_NEXTPROTONEG
 	s->s3->next_proto_neg_seen = 0;
 #endif
@@ -1198,11 +1205,7 @@ int ssl_parse_clienthello_tlsext(SSL *s, unsigned char **p, unsigned char *d, in
 				*al = SSL_AD_DECODE_ERROR;
 				return 0;
 				}
-			if (!tls1_process_sigalgs(s, data, dsize))
-				{
-				*al = SSL_AD_DECODE_ERROR;
-				return 0;
-				}
+			tls1_process_sigalgs(s, data, dsize);
 			}
 		else if (type == TLSEXT_TYPE_status_request &&
 		         s->version != DTLS1_VERSION && s->ctx->tlsext_status_cb)
@@ -2354,18 +2357,6 @@ static int tls12_find_id(int nid, tls12_lookup *table, size_t tlen)
 		}
 	return -1;
 	}
-#if 0
-static int tls12_find_nid(int id, tls12_lookup *table, size_t tlen)
-	{
-	size_t i;
-	for (i = 0; i < tlen; i++)
-		{
-		if (table[i].id == id)
-			return table[i].nid;
-		}
-	return -1;
-	}
-#endif
 
 int tls12_get_sigandhash(unsigned char *p, const EVP_PKEY *pk, const EVP_MD *md)
 	{
@@ -2384,6 +2375,8 @@ int tls12_get_sigandhash(unsigned char *p, const EVP_PKEY *pk, const EVP_MD *md)
 	return 1;
 	}
 
+/* tls12_get_sigid returns the TLS 1.2 SignatureAlgorithm value corresponding
+ * to the given public key, or -1 if not known. */
 int tls12_get_sigid(const EVP_PKEY *pk)
 	{
 	return tls12_find_id(pk->type, tls12_sig,
@@ -2403,47 +2396,49 @@ const EVP_MD *tls12_get_hash(unsigned char hash_alg)
 		return EVP_md5();
 #endif
 #ifndef OPENSSL_NO_SHA
-		case TLSEXT_hash_sha1:
+	case TLSEXT_hash_sha1:
 		return EVP_sha1();
 #endif
 #ifndef OPENSSL_NO_SHA256
-		case TLSEXT_hash_sha224:
+	case TLSEXT_hash_sha224:
 		return EVP_sha224();
 
-		case TLSEXT_hash_sha256:
+	case TLSEXT_hash_sha256:
 		return EVP_sha256();
 #endif
 #ifndef OPENSSL_NO_SHA512
-		case TLSEXT_hash_sha384:
+	case TLSEXT_hash_sha384:
 		return EVP_sha384();
 
-		case TLSEXT_hash_sha512:
+	case TLSEXT_hash_sha512:
 		return EVP_sha512();
 #endif
-		default:
+	default:
 		return NULL;
 
 		}
 	}
 
-/* Set preferred digest for each key type */
-
-int tls1_process_sigalgs(SSL *s, const unsigned char *data, int dsize)
+/* tls1_process_sigalgs processes a signature_algorithms extension and sets the
+ * digest functions accordingly for each key type.
+ *
+ * See RFC 5246, section 7.4.1.4.1.
+ *
+ * data: points to the content of the extension, not including type and length
+ *     headers.
+ * dsize: the number of bytes of |data|. Must be even.
+ */
+void tls1_process_sigalgs(SSL *s, const unsigned char *data, int dsize)
 	{
-	int i, idx;
-	const EVP_MD *md;
-	CERT *c = s->cert;
+	int i;
+	const EVP_MD *md, **digest_ptr;
 	/* Extension ignored for TLS versions below 1.2 */
 	if (TLS1_get_version(s) < TLS1_2_VERSION)
-		return 1;
-	/* Should never happen */
-	if (!c)
-		return 0;
+		return;
 
-	c->pkeys[SSL_PKEY_DSA_SIGN].digest = NULL;
-	c->pkeys[SSL_PKEY_RSA_SIGN].digest = NULL;
-	c->pkeys[SSL_PKEY_RSA_ENC].digest = NULL;
-	c->pkeys[SSL_PKEY_ECC].digest = NULL;
+	s->s3->digest_rsa = NULL;
+	s->s3->digest_dsa = NULL;
+	s->s3->digest_ecdsa = NULL;
 
 	for (i = 0; i < dsize; i += 2)
 		{
@@ -2453,56 +2448,31 @@ int tls1_process_sigalgs(SSL *s, const unsigned char *data, int dsize)
 			{
 #ifndef OPENSSL_NO_RSA
 			case TLSEXT_signature_rsa:
-			idx = SSL_PKEY_RSA_SIGN;
+			digest_ptr = &s->s3->digest_rsa;
 			break;
 #endif
 #ifndef OPENSSL_NO_DSA
 			case TLSEXT_signature_dsa:
-			idx = SSL_PKEY_DSA_SIGN;
+			digest_ptr = &s->s3->digest_dsa;
 			break;
 #endif
 #ifndef OPENSSL_NO_ECDSA
 			case TLSEXT_signature_ecdsa:
-			idx = SSL_PKEY_ECC;
+			digest_ptr = &s->s3->digest_ecdsa;
 			break;
 #endif
 			default:
 			continue;
 			}
 
-		if (c->pkeys[idx].digest == NULL)
+		if (*digest_ptr == NULL)
 			{
 			md = tls12_get_hash(hash_alg);
 			if (md)
-				{
-				c->pkeys[idx].digest = md;
-				if (idx == SSL_PKEY_RSA_SIGN)
-					c->pkeys[SSL_PKEY_RSA_ENC].digest = md;
-				}
+				*digest_ptr = md;
 			}
 
 		}
-
-
-	/* Set any remaining keys to default values. NOTE: if alg is not
-	 * supported it stays as NULL.
-	 */
-#ifndef OPENSSL_NO_DSA
-	if (!c->pkeys[SSL_PKEY_DSA_SIGN].digest)
-		c->pkeys[SSL_PKEY_DSA_SIGN].digest = EVP_sha1();
-#endif
-#ifndef OPENSSL_NO_RSA
-	if (!c->pkeys[SSL_PKEY_RSA_SIGN].digest)
-		{
-		c->pkeys[SSL_PKEY_RSA_SIGN].digest = EVP_sha1();
-		c->pkeys[SSL_PKEY_RSA_ENC].digest = EVP_sha1();
-		}
-#endif
-#ifndef OPENSSL_NO_ECDSA
-	if (!c->pkeys[SSL_PKEY_ECC].digest)
-		c->pkeys[SSL_PKEY_ECC].digest = EVP_sha1();
-#endif
-	return 1;
 	}
 
 #endif
-- 
1.8.2.1