# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Routines to generate root and server certificates.

Certificate Naming Conventions:
  ca_cert:  crypto.X509 for the certificate authority (w/ both the pub &
            priv keys)
  cert:  a crypto.X509 certificate (w/ just the pub key)
  cert_str:  a certificate string (w/ just the pub cert)
  key:  a private crypto.PKey  (from ca or pem)
  ca_cert_str:  a certificate authority string (w/ both the pub cert & the
                priv key)
"""

import logging
import os
import platform
import socket
import subprocess
import time

# Set to the ImportError when pyOpenSSL is unavailable; entry points that need
# OpenSSL re-raise it lazily so that merely importing this module never fails.
openssl_import_error = None

# Aliases for pyOpenSSL names; they stay None when pyOpenSSL is missing.
Error = None
SSL_METHOD = None
SysCallError = None
VERIFY_PEER = None
ZeroReturnError = None
FILETYPE_PEM = None

try:
  from OpenSSL import crypto, SSL

  Error = SSL.Error
  SSL_METHOD = SSL.SSLv23_METHOD
  SysCallError = SSL.SysCallError
  VERIFY_PEER = SSL.VERIFY_PEER
  ZeroReturnError = SSL.ZeroReturnError
  FILETYPE_PEM = crypto.FILETYPE_PEM
except ImportError as e:
  openssl_import_error = e


def get_ssl_context(method=SSL_METHOD):
  """Creates a new SSL.Context.

  Args:
    method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD.

  Returns:
    An SSL.Context for the given method.

  Raises:
    The saved ImportError if pyOpenSSL could not be imported.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type
  return SSL.Context(method)


class WrappedConnection(object):
  """Proxy around a connection object whose recv() maps EOF errors to ''.

  Every attribute other than recv() is forwarded to the wrapped object.
  """

  def __init__(self, obj):
    self._wrapped_obj = obj

  def __getattr__(self, attr):
    # __getattr__ only runs after normal attribute lookup has failed, so
    # anything reaching here belongs to the wrapped object.  _wrapped_obj
    # itself is refused: on an instance created without __init__ (e.g. by
    # copy.copy) looking it up would otherwise re-enter __getattr__ forever.
    if attr == '_wrapped_obj':
      raise AttributeError(attr)
    return getattr(self._wrapped_obj, attr)

  def recv(self, buflen=1024, flags=0):
    """Reads up to buflen bytes; returns '' when the peer has closed.

    Args:
      buflen: maximum number of bytes to read.
      flags: passed through to the wrapped recv().

    Returns:
      The data read, or '' on an unexpected EOF or a clean TLS close.
    """
    try:
      return self._wrapped_obj.recv(buflen, flags)
    except SSL.SysCallError as e:
      # pyOpenSSL reports an abrupt EOF as SysCallError(-1, 'Unexpected EOF').
      # Slice instead of indexing so a shorter args tuple re-raises the
      # original error rather than an IndexError.
      if e.args[1:2] == ('Unexpected EOF',):
        return ''
      raise
    except SSL.ZeroReturnError:
      # The TLS connection was closed by the peer.
      return ''


def get_ssl_connection(context, connection):
  """Wraps `connection` in an SSL.Connection and a WrappedConnection."""
  return WrappedConnection(SSL.Connection(context, connection))


def load_privatekey(key, filetype=FILETYPE_PEM):
  """Loads a private key object from a string."""
  return crypto.load_privatekey(filetype, key)


def load_cert(cert_str, filetype=FILETYPE_PEM):
  """Loads a certificate object from a string."""
  return crypto.load_certificate(filetype, cert_str)


def _dump_privatekey(key, filetype=FILETYPE_PEM):
  """Dumps a private key object to a string."""
  return crypto.dump_privatekey(filetype, key)


def _dump_cert(cert, filetype=FILETYPE_PEM):
  """Dumps a certificate object to a string."""
  return crypto.dump_certificate(filetype, cert)


def generate_dummy_ca_cert(subject='_WebPageReplayCert'):
  """Generates a dummy, self-signed certificate authority.

  Args:
    subject: a string used as the CN and O of the root cert (and as its
      subjectAltName DNS entry).

  Returns:
    A tuple (ca_cert_str, key_str): the PEM-encoded CA certificate and the
    PEM-encoded private key it was signed with.

  Raises:
    The saved ImportError if pyOpenSSL could not be imported.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  key = crypto.PKey()
  key.generate_key(crypto.TYPE_RSA, 1024)

  ca_cert = crypto.X509()
  ca_cert.set_serial_number(int(time.time()*10000))
  ca_cert.set_version(2)  # X.509 v3, required for the extensions below.
  ca_cert.get_subject().CN = subject
  ca_cert.get_subject().O = subject
  # Valid from two years ago until two years from now.
  ca_cert.gmtime_adj_notBefore(-60 * 60 * 24 * 365 * 2)
  ca_cert.gmtime_adj_notAfter(60 * 60 * 24 * 365 * 2)
  ca_cert.set_issuer(ca_cert.get_subject())  # Self-signed.
  ca_cert.set_pubkey(key)
  ca_cert.add_extensions([
      crypto.X509Extension('basicConstraints', True, 'CA:TRUE'),
      crypto.X509Extension('subjectAltName', False, 'DNS:' + subject),
      crypto.X509Extension('nsCertType', True, 'sslCA'),
      crypto.X509Extension('extendedKeyUsage', True,
                           ('serverAuth,clientAuth,emailProtection,'
                            'timeStamping,msCodeInd,msCodeCom,msCTLSign,'
                            'msSGC,msEFS,nsSGC')),
      crypto.X509Extension('keyUsage', False, 'keyCertSign, cRLSign'),
      crypto.X509Extension('subjectKeyIdentifier', False, 'hash',
                           subject=ca_cert),
      ])
  ca_cert.sign(key, 'sha256')
  key_str = _dump_privatekey(key)
  ca_cert_str = _dump_cert(ca_cert)
  return ca_cert_str, key_str


def get_host_cert(host, port=443):
  """Contacts the host and returns its certificate.

  Args:
    host: host name or address to connect to.
    port: TCP port to connect to.

  Returns:
    The PEM string of the last certificate passed to the verify callback, or
    '' if no certificate was received.

  Raises:
    The saved ImportError if pyOpenSSL could not be imported.
  """
  host_certs = []

  def verify_cb(conn, cert, errnum, depth, ok):
    host_certs.append(cert)
    # Return True to indicate that the certificate was ok: the goal is only to
    # capture the chain, not to validate it.
    return True

  # get_ssl_context() defaults to SSLv23_METHOD and raises the saved
  # ImportError when pyOpenSSL is unavailable.
  context = get_ssl_context()
  context.set_verify(SSL.VERIFY_PEER, verify_cb)  # Demand a certificate
  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  connection = SSL.Connection(context, s)
  try:
    connection.connect((host, port))
    connection.send('')
  except SSL.SysCallError:
    pass
  except socket.gaierror:
    logging.debug('Host name is not valid')
  finally:
    try:
      # shutdown() can raise SSL.Error when the connection never completed a
      # handshake (e.g. the connect above failed).  That must neither replace
      # the outcome of the try block nor prevent the socket from being closed.
      connection.shutdown()
    except SSL.Error:
      pass
    finally:
      connection.close()

  if not host_certs:
    logging.warning('Unable to get host certificate from %s:%s', host, port)
    return ''
  return _dump_cert(host_certs[-1])


def write_dummy_ca_cert(ca_cert_str, key_str, cert_path):
  """Writes four certificate files.

  For example, if cert_path is "mycert.pem":
      mycert.pem - CA plus private key
      mycert-cert.pem - CA in PEM format
      mycert-cert.cer - CA for Android
      mycert-cert.p12 - CA in PKCS12 format for Windows devices

  Args:
    ca_cert_str: certificate string
    key_str: private key string
    cert_path: path string such as "mycert.pem"
  """
  dirname = os.path.dirname(cert_path)
  if dirname and not os.path.exists(dirname):
    os.makedirs(dirname)

  root_path = os.path.splitext(cert_path)[0]
  ca_cert_path = root_path + '-cert.pem'
  android_cer_path = root_path + '-cert.cer'
  windows_p12_path = root_path + '-cert.p12'

  # Dump the CA plus private key
  with open(cert_path, 'w') as f:
    f.write(key_str)
    f.write(ca_cert_str)

  # Dump the certificate in PEM format
  with open(ca_cert_path, 'w') as f:
    f.write(ca_cert_str)

  # Create a .cer file with the same contents for Android
  with open(android_cer_path, 'w') as f:
    f.write(ca_cert_str)

  ca_cert = load_cert(ca_cert_str)
  key = load_privatekey(key_str)
  # Dump the certificate in PKCS12 format for Windows devices.  PKCS12 is a
  # binary format, so the file must be opened in binary mode; text mode would
  # translate newline bytes on Windows and corrupt the archive.
  with open(windows_p12_path, 'wb') as f:
    p12 = crypto.PKCS12()
    p12.set_certificate(ca_cert)
    p12.set_privatekey(key)
    f.write(p12.export())


def generate_cert(root_ca_cert_str, server_cert_str, server_host):
  """Generates a cert_str signed by the root CA, mirroring a server cert.

  The CN and the subjectAltName extension are copied from server_cert_str
  when it is given.

  Args:
    root_ca_cert_str: PEM formatted string holding both the root cert and
      its private key (both are loaded from it)
    server_cert_str: PEM formatted string representing cert
    server_host: host name to use as the CN if there is no server_cert_str,
      or if that cert has no CN

  Returns:
    a PEM formatted certificate string

  Raises:
    The saved ImportError if pyOpenSSL could not be imported.
  """
  # X509Extension.get_short_name() returns bytes; b'' is also plain str on
  # Python 2, so this matches on both.
  EXTENSION_WHITELIST = set([b'subjectAltName'])

  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  common_name = server_host
  reused_extensions = []
  if server_cert_str:
    original_cert = load_cert(server_cert_str)
    # commonName is None when the subject has no CN; keep server_host then
    # instead of failing on the assignment below.
    common_name = original_cert.get_subject().commonName or server_host
    for i in range(original_cert.get_extension_count()):
      original_cert_extension = original_cert.get_extension(i)
      if original_cert_extension.get_short_name() in EXTENSION_WHITELIST:
        reused_extensions.append(original_cert_extension)

  ca_cert = load_cert(root_ca_cert_str)
  ca_key = load_privatekey(root_ca_cert_str)

  cert = crypto.X509()
  # Version value 2 means X.509 v3, the only version allowed to carry
  # extensions (the copied subjectAltName below).
  cert.set_version(2)
  cert.get_subject().CN = common_name
  cert.gmtime_adj_notBefore(-60 * 60)
  cert.gmtime_adj_notAfter(60 * 60 * 24 * 30)
  cert.set_issuer(ca_cert.get_subject())
  cert.set_serial_number(int(time.time()*10000))
  # The generated certificate reuses the CA's own key pair instead of a new
  # one.
  cert.set_pubkey(ca_key)
  cert.add_extensions(reused_extensions)
  cert.sign(ca_key, 'sha256')
  return _dump_cert(cert)


def install_cert_in_nssdb(home_directory_path, certificate_path):
  """Installs a certificate into the ~/.pki/nssdb database.

  Args:
    home_directory_path: Path of the home directory where to install; must
      not be the current user's real home directory.
    certificate_path: Path of a CA in PEM format

  Raises:
    Exception: if home_directory_path resolves to the user's home directory.
    subprocess.CalledProcessError: if a certutil invocation fails.
  """
  assert os.path.isdir(home_directory_path)
  assert platform.system() == 'Linux', \
      'SSL certification authority has only been tested for linux.'
  # Compare fully resolved paths so a symlink to the real home directory
  # cannot bypass this check.  expanduser('~') returns $HOME when it is set
  # and, unlike os.environ['HOME'], does not raise when it is unset.
  if (os.path.realpath(home_directory_path) ==
      os.path.realpath(os.path.expanduser('~'))):
    raise Exception('Modifying $HOME/.pki/nssdb compromises your machine.')

  cert_database_path = os.path.join(home_directory_path, '.pki', 'nssdb')

  def certutil(args):
    # Runs certutil against the sql: database under cert_database_path.
    cmd = ['certutil', '--empty-password', '-d', 'sql:' + cert_database_path]
    cmd.extend(args)
    logging.info(subprocess.list2cmdline(cmd))
    subprocess.check_call(cmd)

  if not os.path.isdir(cert_database_path):
    os.makedirs(cert_database_path)
    certutil(['-N'])  # Create a new, empty database.

  certutil(['-A', '-t', 'PC,,', '-n', certificate_path, '-i', certificate_path])