# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Spins up a trivial HTTP cgi form listener in a thread.

   The HTTPListener class (and its HTTPS variant, SecureHTTPListener) is a
   utility for test cases that need to call back to the Autotest test case
   with some form value, e.g.
   http://localhost:nnnn/?status="Browser started!"
"""

import cgi
import errno
import functools
import logging
import os
import posixpath
import SimpleHTTPServer
import socket
import ssl
import sys
import threading
import urllib
import urlparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SocketServer import BaseServer, ThreadingMixIn


def _handle_http_errors(func):
    """Decorator that logs client-disconnect errors instead of raising them.

    An IOError whose errno is EPIPE or ECONNRESET is reported as a single
    self.log_error() line instead of a stack trace; any other IOError (and
    any other exception) propagates unchanged.

    Args:
        func: a request-handler method, called as func(self, *args, **kwargs).

    Returns:
        The wrapped method. It returns func's result, or None when a
        disconnect error was swallowed.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except IOError as e:
            if e.errno in (errno.EPIPE, errno.ECONNRESET):
                # Instead of dumping a stack trace, a single line is sufficient.
                # Pass the text as an argument, not as the format string, so
                # a '%' in the error message cannot break the formatting.
                self.log_error('%s', str(e))
            else:
                raise

    return wrapper


class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Implements a form handler (for POST requests only) which simply
    echoes the key=value parameters back in the response.

    If the form submission is a file upload, the file will be written
    to disk with the name contained in the 'filename' field.

    GET requests go to a handler registered for the path, or else are
    served as static files from server.docroot. The handler relies on
    these attributes of its server object:
      docroot: directory static files are served from.
      _wait_urls: dict of URL path -> (match_params, threading.Event); the
          event is set and the entry removed once that path is requested.
      _url_handlers: dict of URL path -> callable(handler, form_or_args).
      _form_entries: dict into which every POSTed field value is recorded.
    """

    # Note: this updates the map owned by SimpleHTTPRequestHandler itself,
    # so every handler class in the process sees the extra MIME type.
    SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    def log_error(self, format, *args):
        logging.warning("(httpd error) %s - - [%s] %s\n" %
                        (self.address_string(), self.log_date_time_string(),
                         format % args))

    def log_message(self, format, *args):
        logging.debug("%s - - [%s] %s\n" %
                      (self.address_string(), self.log_date_time_string(),
                       format % args))

    @_handle_http_errors
    def do_POST(self):
        """Parse a POSTed form, record its fields and send a response.

        Each field's value is stored in server._form_entries (as a list of
        values for a repeated field). A handler registered for the URL path
        is called as handler(self, form); otherwise write_post_response()
        echoes the form. Any event waiting on self.path is then set.
        """
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            # .get() so a POST without a Content-Type header does not raise
            # KeyError before the body is even parsed.
            environ={'REQUEST_METHOD': 'POST',
                     'CONTENT_TYPE': self.headers.get('Content-Type', '')})
        # You'd think form.keys() would just return [], like it does for empty
        # python dicts; you'd be wrong. It raises TypeError if called when it
        # has no keys.
        if form:
            for field in form.keys():
                # getvalue() also handles repeated fields (returning a list),
                # where form[field].value would raise AttributeError.
                self.server._form_entries[field] = form.getvalue(field)
        path = urlparse.urlparse(self.path)[2]
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        self._fire_event()

    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this method to give custom responses. The default sends a
        200 response listing every form field; uploaded files are also
        written to disk under their client-supplied filename.

        Args:
            form: the cgi.FieldStorage parsed by do_POST.
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write('Hello from Autotest!\nClient: %s\n' %
                         str(self.client_address))
        self.wfile.write('Request for path: %s\n' % self.path)
        self.wfile.write('Got form data:\n')

        # form.keys() raises TypeError on a form without fields, so bail out
        # before iterating an empty form.
        if not form:
            return
        for field in form.keys():
            items = form[field]
            # A repeated field yields a list of items rather than one item.
            if not isinstance(items, list):
                items = [items]
            for field_item in items:
                if field_item.filename:
                    # The field contains an uploaded file
                    upload = field_item.file.read()
                    self.wfile.write('\tUploaded %s (%d bytes)<br>' %
                                     (field, len(upload)))
                    # Write submitted file to specified filename.
                    # NOTE(review): the filename comes from the client and is
                    # used unsanitized, so it may point outside the CWD.
                    with open(field_item.filename, 'wb') as out:
                        out.write(upload)
                    del upload
                else:
                    self.wfile.write('\t%s=%s<br>' % (field, field_item.value))

    def translate_path(self, path):
        """Override SimpleHTTPRequestHandler's translate_path to serve
        from arbitrary docroot.

        The query string is discarded, the path is unquoted and normalized,
        and any drive, '.' or '..' components are dropped before the
        remaining components are joined onto server.docroot.
        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        path = posixpath.normpath(urllib.unquote(path))
        words = filter(None, path.split('/'))
        path = self.server.docroot
        for word in words:
            _, word = os.path.splitdrive(word)
            _, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                continue
            path = os.path.join(path, word)
        logging.debug('Translated path: %s', path)
        return path

    def _fire_event(self):
        """Set and unregister the event waiting on self.path, if any."""
        # pop() instead of check-then-del: with ThreadingMixIn, two requests
        # for the same path could both pass an `in` check and the second
        # `del` would raise KeyError.
        entry = self.server._wait_urls.pop(self.path, None)
        if entry is None:
            logging.debug('URL %s not in watch list', self.path)
        else:
            _, event = entry
            event.set()

    @_handle_http_errors
    def do_GET(self):
        """Dispatch a GET to a registered handler or serve a static file.

        The query string is stripped from self.path so it matches the keys
        of _url_handlers and _wait_urls; a registered handler receives the
        parsed query as a dict of name -> list of values.
        """
        split_url = urlparse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = urlparse.parse_qs(split_url[3])
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, args)
        else:
            SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()

    @_handle_http_errors
    def do_HEAD(self):
        """Serve HEAD via the base class, logging client disconnects."""
        SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each request on its own thread (ThreadingMixIn).

    Args:
        server_address: (host, port) tuple to bind to.
        HandlerClass: request handler class instantiated per request.
    """

    def __init__(self, server_address, HandlerClass):
        # Delegate straight to HTTPServer, which binds and starts listening.
        base_init = HTTPServer.__init__
        base_init(self, server_address, HandlerClass)


class HTTPListener(object):
    """Runs a ThreadedHTTPServer (handling requests with FormHandler) on a
    background thread.

    Typical use: construct, optionally add_wait_url()/add_url_handler(),
    run(), wait on the returned events, then stop().
    """

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None,
                 url_handlers=None):
        """Create (but do not start) the server.

        Args:
            port: TCP port to listen on; 0 lets the OS pick a free port.
            docroot: directory that static GET requests are served from.
            wait_urls: optional dict of URL path -> (match_params, Event).
                Used directly by the server (not copied).
            url_handlers: optional dict of URL path -> callable(handler,
                form_or_args). Used directly by the server (not copied).
        """
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attach the request handler's state to `server` and create (but
        do not start) the thread that will run server.serve_forever().

        None for wait_urls / url_handlers means a fresh empty dict.
        """
        # A fresh dict per listener; a shared default dict would leak wait
        # URLs and handlers from one listener into every other.
        if wait_urls is None:
            wait_urls = {}
        if url_handlers is None:
            url_handlers = {}
        self._server = server
        # Stuff some convenient data fields into the server object.
        server.docroot = docroot
        server._wait_urls = wait_urls
        server._url_handlers = url_handlers
        server._form_entries = {}
        self._server_thread = threading.Thread(target=server.serve_forever)

    def add_wait_url(self, url='/', matchParams=None):
        """Register `url` and return an Event that is set once it is hit."""
        if matchParams is None:
            matchParams = {}
        e = threading.Event()
        self._server._wait_urls[url] = (matchParams, e)
        return e

    def add_url_handler(self, url, handler_func):
        """Route requests for `url` to handler_func(handler, form_or_args)."""
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        """Forget all form field values recorded so far."""
        self._server._form_entries = {}

    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries

    def run(self):
        """Start serving on the background thread."""
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()

    def stop(self):
        """Stop serving, wait for the server thread and close the socket.

        Safe to call when run() was never called: shutdown() would otherwise
        block forever waiting for a serve_forever() loop that never ran.
        """
        if self._server_thread.is_alive():
            self._server.shutdown()
            self._server_thread.join()
        self._server.socket.close()


class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server whose listening socket is wrapped with SSL."""

    def __init__(self, server_address, HandlerClass, cert_path, key_path):
        """Create, bind and activate the SSL listening socket.

        Args:
            server_address: (host, port) tuple to bind to.
            HandlerClass: request handler class for accepted connections.
            cert_path: path to the server certificate file.
            key_path: path to the server private key file.

        Raises:
            Any socket/ssl error from wrapping, binding or listening; the
            socket is closed before the exception propagates.
        """
        raw_socket = socket.socket(self.address_family, self.socket_type)
        try:
            # NOTE(review): ssl.wrap_socket and PROTOCOL_TLSv1 are deprecated
            # (wrap_socket is removed in Python 3.12); kept because clients of
            # this mock server may rely on TLSv1 — confirm before moving to
            # ssl.SSLContext.
            self.socket = ssl.wrap_socket(raw_socket,
                                          server_side=True,
                                          ssl_version=ssl.PROTOCOL_TLSv1,
                                          certfile=cert_path,
                                          keyfile=key_path)
        except Exception:
            raw_socket.close()
            raise
        # BaseServer.__init__ rather than TCPServer.__init__, which would
        # replace self.socket with a fresh plain (non-SSL) socket.
        BaseServer.__init__(self, server_address, HandlerClass)
        try:
            self.server_bind()
            self.server_activate()
        except Exception:
            # As TCPServer.__init__ does: don't leak the socket when binding
            # (e.g. port already in use) or listening fails.
            self.server_close()
            raise


class SecureHTTPRequestHandler(FormHandler):
    """FormHandler for connections accepted by SecureHTTPServer.

    Logging (log_error / log_message) is inherited from FormHandler.
    """

    def setup(self):
        """Wrap the (SSL) request socket in file objects for the handler.

        Sets self.connection, self.rfile and self.wfile, which the base
        request handler reads the request from and writes the response to.
        """
        self.connection = self.request
        # NOTE(review): socket._fileobject is a private Python 2 API, used
        # here instead of StreamRequestHandler.setup(); presumably to avoid
        # SSLSocket.makefile() issues — confirm before changing.
        self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize)
        self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize)


class SecureHTTPListener(HTTPListener):
    """HTTPListener variant that serves HTTPS through SecureHTTPServer."""

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        """Create (but do not start) the HTTPS server.

        Args:
            cert_path: path to the server certificate file.
            key_path: path to the server private key file.
            port: TCP port to listen on; 0 lets the OS pick a free port.
            docroot: directory that static GET requests are served from.
            wait_urls: optional dict of URL path -> (match_params, Event);
                None means a fresh empty dict.
            url_handlers: optional dict of URL path -> handler callable;
                None means a fresh empty dict.
        """
        # Fresh dicts per instance; shared {} defaults would leak wait URLs
        # and handlers between listeners.
        if wait_urls is None:
            wait_urls = {}
        if url_handlers is None:
            url_handlers = {}
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def getsockname(self):
        """Return the address the listening socket is bound to."""
        return self._server.socket.getsockname()