# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test routines to generate dummy certificates."""

import BaseHTTPServer
import os
import shutil
import ssl
import tempfile
import threading
import unittest

import certutils


class Server(BaseHTTPServer.HTTPServer):
  """Minimal HTTPS server on an ephemeral localhost port, for tests.

  The listening socket is wrapped with ssl using the given PEM file as the
  server certificate. No keyfile is passed, so ssl takes the private key from
  the same file; certutils.write_dummy_ca_cert writes key + cert together.

  Use as a context manager: entering starts serve_forever() on a daemon
  thread; exiting stops that thread and closes the listening socket.
  """

  def __init__(self, https_root_ca_cert_path):
    # Port 0 lets the OS pick a free port; callers read it via server_port.
    BaseHTTPServer.HTTPServer.__init__(
        self, ('localhost', 0), BaseHTTPServer.BaseHTTPRequestHandler)
    try:
      self.socket = ssl.wrap_socket(
          self.socket, certfile=https_root_ca_cert_path, server_side=True,
          do_handshake_on_connect=False)
    except Exception:
      # Don't leak the already-bound plain socket if wrapping fails
      # (e.g. unreadable or malformed cert file).
      self.server_close()
      raise
    # Serving thread; set by __enter__, cleared by cleanup().
    self._serve_thread = None

  def __enter__(self):
    """Start serving requests on a background daemon thread."""
    self._serve_thread = threading.Thread(target=self.serve_forever)
    # Daemon so a stuck server never blocks interpreter exit.
    self._serve_thread.daemon = True
    self._serve_thread.start()
    return self

  def cleanup(self):
    """Stop the serving thread (if started) and close the listening socket."""
    try:
      if self._serve_thread is not None:
        # shutdown() waits for serve_forever() to return; calling it when
        # serve_forever() was never started would block forever.
        self.shutdown()
        self._serve_thread.join()
        self._serve_thread = None
    except KeyboardInterrupt:
      # Best-effort teardown: keep the original behavior of ignoring Ctrl-C
      # here so the socket still gets closed below.
      pass
    finally:
      self.server_close()

  def __exit__(self, type_, value_, traceback_):
    self.cleanup()


class CertutilsTest(unittest.TestCase):
  """Tests for certutils certificate generation, writing and fetching."""

  def _check_cert_file(self, cert_file_path, cert_str, key_str=None):
    """Assert that a file's contents equal key_str + cert_str exactly.

    Args:
      cert_file_path: path of the file to read.
      cert_str: expected certificate text.
      key_str: optional key text expected immediately before cert_str;
        when falsy, the file must contain only cert_str.
    """
    with open(cert_file_path, 'r') as cert_file:
      cert_load = cert_file.read()
    expected_cert = key_str + cert_str if key_str else cert_str
    self.assertEqual(expected_cert, cert_load)

  def setUp(self):
    # Use the platform temp dir (honors TMPDIR) rather than a hard-coded
    # '/tmp', which does not exist on every platform.
    self._temp_dir = tempfile.mkdtemp(prefix='certutils_')

  def tearDown(self):
    if self._temp_dir:
      shutil.rmtree(self._temp_dir)

  def test_generate_dummy_ca_cert(self):
    subject = 'testSubject'
    c, _ = certutils.generate_dummy_ca_cert(subject)
    c = certutils.load_cert(c)
    self.assertEqual(c.get_subject().commonName, subject)

  def test_get_host_cert(self):
    ca_cert_path = os.path.join(self._temp_dir, 'rootCA.pem')
    issuer = 'testCA'
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(issuer),
                                  cert_path=ca_cert_path)

    # Serve the dummy CA cert locally and check that get_host_cert fetches it.
    with Server(ca_cert_path) as server:
      cert_str = certutils.get_host_cert('localhost', server.server_port)
      cert = certutils.load_cert(cert_str)
      self.assertEqual(issuer, cert.get_subject().commonName)

  def test_get_host_cert_gives_empty_for_bad_host(self):
    cert_str = certutils.get_host_cert('not_a_valid_host_name_2472341234234234')
    self.assertEqual('', cert_str)

  def test_write_dummy_ca_cert(self):
    base_path = os.path.join(self._temp_dir, 'testCA')
    ca_cert_path = base_path + '.pem'
    cert_path = base_path + '-cert.pem'
    ca_cert_android = base_path + '-cert.cer'
    ca_cert_windows = base_path + '-cert.p12'

    self.assertFalse(os.path.exists(ca_cert_path))
    self.assertFalse(os.path.exists(cert_path))
    self.assertFalse(os.path.exists(ca_cert_android))
    self.assertFalse(os.path.exists(ca_cert_windows))
    c, k = certutils.generate_dummy_ca_cert()
    certutils.write_dummy_ca_cert(c, k, ca_cert_path)

    # The .pem holds key + cert; the derived files hold only the cert.
    self._check_cert_file(ca_cert_path, c, k)
    self._check_cert_file(cert_path, c)
    self._check_cert_file(ca_cert_android, c)
    self.assertTrue(os.path.exists(ca_cert_windows))

  def test_generate_cert(self):
    ca_cert_path = os.path.join(self._temp_dir, 'testCA.pem')
    issuer = 'testIssuer'
    certutils.write_dummy_ca_cert(
        *certutils.generate_dummy_ca_cert(issuer), cert_path=ca_cert_path)

    # Read the written CA file once; it is the signing root for both calls.
    with open(ca_cert_path, 'r') as ca_cert_file:
      ca_cert_str = ca_cert_file.read()

    subject = 'testSubject'
    cert_string = certutils.generate_cert(ca_cert_str, '', subject)
    cert = certutils.load_cert(cert_string)
    self.assertEqual(issuer, cert.get_issuer().commonName)
    self.assertEqual(subject, cert.get_subject().commonName)

    # Regenerate from the existing cert with a different host argument; the
    # assertions expect the subject to be carried over from cert_string.
    cert_string = certutils.generate_cert(ca_cert_str, cert_string, 'host')
    cert = certutils.load_cert(cert_string)
    self.assertEqual(issuer, cert.get_issuer().commonName)
    self.assertEqual(subject, cert.get_subject().commonName)


def _main():
  """Command-line entry point: run every test case defined in this module."""
  unittest.main()


if __name__ == '__main__':
  _main()