diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py index d52e596..79ad145 100755 --- a/third_party/tlslite/tlslite/constants.py +++ b/third_party/tlslite/tlslite/constants.py @@ -31,6 +31,7 @@ class HandshakeType: client_key_exchange = 16 finished = 20 next_protocol = 67 + encrypted_extensions = 203 class ContentType: change_cipher_spec = 20 @@ -45,6 +46,7 @@ class ExtensionType: # RFC 6066 / 4366 cert_type = 9 # RFC 6091 tack = 0xF300 supports_npn = 13172 + channel_id = 30032 class NameType: host_name = 0 diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py index 7ef4e3f..246082e 100755 --- a/third_party/tlslite/tlslite/messages.py +++ b/third_party/tlslite/tlslite/messages.py @@ -112,6 +112,7 @@ class ClientHello(HandshakeMsg): self.tack = False self.supports_npn = False self.server_name = bytearray(0) + self.channel_id = False def create(self, version, random, session_id, cipher_suites, certificate_types=None, srpUsername=None, @@ -179,6 +180,8 @@ class ClientHello(HandshakeMsg): if name_type == NameType.host_name: self.server_name = hostNameBytes break + elif extType == ExtensionType.channel_id: + self.channel_id = True else: _ = p.getFixBytes(extLength) index2 = p.index @@ -243,6 +246,7 @@ class ServerHello(HandshakeMsg): self.tackExt = None self.next_protos_advertised = None self.next_protos = None + self.channel_id = False def create(self, version, random, session_id, cipher_suite, certificate_type, tackExt, next_protos_advertised): @@ -329,6 +333,9 @@ class ServerHello(HandshakeMsg): w2.add(ExtensionType.supports_npn, 2) w2.add(len(encoded_next_protos_advertised), 2) w2.addFixSeq(encoded_next_protos_advertised, 1) + if self.channel_id: + w2.add(ExtensionType.channel_id, 2) + w2.add(0, 2) if len(w2.bytes): w.add(len(w2.bytes), 2) w.bytes += w2.bytes @@ -656,6 +663,28 @@ class Finished(HandshakeMsg): w.addFixSeq(self.verify_data, 1) return self.postWrite(w) +class EncryptedExtensions(HandshakeMsg): + def __init__(self): + self.channel_id_key = None + self.channel_id_proof = None + + def parse(self, p): + p.startLengthCheck(3) + soFar = 0 + while soFar != p.lengthCheck: + extType = p.get(2) + extLength = p.get(2) + if extType == ExtensionType.channel_id: + if extLength != 32*4: + raise SyntaxError() + self.channel_id_key = p.getFixBytes(64) + self.channel_id_proof = p.getFixBytes(64) + else: + p.getFixBytes(extLength) + soFar += 4 + extLength + p.stopLengthCheck() + return self + class ApplicationData(object): def __init__(self): self.contentType = ContentType.application_data diff --git a/third_party/tlslite/tlslite/tlsconnection.py b/third_party/tlslite/tlslite/tlsconnection.py index 8415592..e7c5140 100755 --- a/third_party/tlslite/tlslite/tlsconnection.py +++ b/third_party/tlslite/tlslite/tlsconnection.py @@ -1155,6 +1155,7 @@ class TLSConnection(TLSRecordLayer): serverHello.create(self.version, getRandomBytes(32), sessionID, \ cipherSuite, CertificateType.x509, tackExt, nextProtos) + serverHello.channel_id = clientHello.channel_id # Perform the SRP key exchange clientCertChain = None @@ -1191,7 +1192,7 @@ class TLSConnection(TLSRecordLayer): for result in self._serverFinished(premasterSecret, clientHello.random, serverHello.random, cipherSuite, settings.cipherImplementations, - nextProtos): + nextProtos, clientHello.channel_id): if result in (0,1): yield result else: break masterSecret = result @@ -1609,7 +1610,8 @@ class TLSConnection(TLSRecordLayer): def _serverFinished(self, premasterSecret, clientRandom, serverRandom, - cipherSuite, cipherImplementations, nextProtos): + cipherSuite, cipherImplementations, nextProtos, + doingChannelID): masterSecret = calcMasterSecret(self.version, premasterSecret, clientRandom, serverRandom) @@ -1620,7 +1622,8 @@ class TLSConnection(TLSRecordLayer): #Exchange ChangeCipherSpec and Finished messages for result in self._getFinished(masterSecret, - expect_next_protocol=nextProtos is not None): + expect_next_protocol=nextProtos is not None, + expect_channel_id=doingChannelID): yield result for result in self._sendFinished(masterSecret): @@ -1657,7 +1660,8 @@ class TLSConnection(TLSRecordLayer): for result in self._sendMsg(finished): yield result - def _getFinished(self, masterSecret, expect_next_protocol=False, nextProto=None): + def _getFinished(self, masterSecret, expect_next_protocol=False, nextProto=None, + expect_channel_id=False): #Get and check ChangeCipherSpec for result in self._getMsg(ContentType.change_cipher_spec): if result in (0,1): @@ -1690,6 +1694,20 @@ class TLSConnection(TLSRecordLayer): if nextProto: self.next_proto = nextProto + #Server Finish - Are we waiting for a EncryptedExtensions? + if expect_channel_id: + for result in self._getMsg(ContentType.handshake, HandshakeType.encrypted_extensions): + if result in (0,1): + yield result + if result is None: + for result in self._sendError(AlertDescription.unexpected_message, + "Didn't get EncryptedExtensions message"): + yield result + encrypted_extensions = result + self.channel_id = result.channel_id_key + else: + self.channel_id = None + #Calculate verification data verifyData = self._calcFinished(masterSecret, False) diff --git a/third_party/tlslite/tlslite/tlsrecordlayer.py b/third_party/tlslite/tlslite/tlsrecordlayer.py index b0833fe..ff08cbf 100755 --- a/third_party/tlslite/tlslite/tlsrecordlayer.py +++ b/third_party/tlslite/tlslite/tlsrecordlayer.py @@ -800,6 +800,8 @@ class TLSRecordLayer(object): yield Finished(self.version).parse(p) elif subType == HandshakeType.next_protocol: yield NextProtocol().parse(p) + elif subType == HandshakeType.encrypted_extensions: + yield EncryptedExtensions().parse(p) else: raise AssertionError()