# Copyright (c) 2012 The Chromium OS Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """Spins up a trivial HTTP cgi form listener in a thread. This HTTPThread class is a utility for use with test cases that need to call back to the Autotest test case with some form value, e.g. http://localhost:nnnn/?status="Browser started!" """ import cgi, errno, logging, os, posixpath, SimpleHTTPServer, socket, ssl, sys import threading, urllib, urlparse from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from SocketServer import BaseServer, ThreadingMixIn def _handle_http_errors(func): """Decorator function for cleaner presentation of certain exceptions.""" def wrapper(self): try: func(self) except IOError, e: if e.errno == errno.EPIPE or e.errno == errno.ECONNRESET: # Instead of dumping a stack trace, a single line is sufficient. self.log_error(str(e)) else: raise return wrapper class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): """Implements a form handler (for POST requests only) which simply echoes the key=value parameters back in the response. If the form submission is a file upload, the file will be written to disk with the name contained in the 'filename' field. """ SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({ '.webm': 'video/webm', }) # Override the default logging methods to use the logging module directly. def log_error(self, format, *args): logging.warning("(httpd error) %s - - [%s] %s\n" % (self.address_string(), self.log_date_time_string(), format%args)) def log_message(self, format, *args): logging.debug("%s - - [%s] %s\n" % (self.address_string(), self.log_date_time_string(), format%args)) @_handle_http_errors def do_POST(self): form = cgi.FieldStorage( fp=self.rfile, headers=self.headers, environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}) # You'd think form.keys() would just return [], like it does for empty # python dicts; you'd be wrong. It raises TypeError if called when it # has no keys. if form: for field in form.keys(): field_item = form[field] self.server._form_entries[field] = field_item.value path = urlparse.urlparse(self.path)[2] if path in self.server._url_handlers: self.server._url_handlers[path](self, form) else: # Echo back information about what was posted in the form. self.write_post_response(form) self._fire_event() def write_post_response(self, form): """Called to fill out the response to an HTTP POST. Override this class to give custom responses. """ # Send response boilerplate self.send_response(200) self.end_headers() self.wfile.write('Hello from Autotest!\nClient: %s\n' % str(self.client_address)) self.wfile.write('Request for path: %s\n' % self.path) self.wfile.write('Got form data:\n') # See the note in do_POST about form.keys(). if form: for field in form.keys(): field_item = form[field] if field_item.filename: # The field contains an uploaded file upload = field_item.file.read() self.wfile.write('\tUploaded %s (%d bytes)<br>' % (field, len(upload))) # Write submitted file to specified filename. file(field_item.filename, 'w').write(upload) del upload else: self.wfile.write('\t%s=%s<br>' % (field, form[field].value)) def translate_path(self, path): """Override SimpleHTTPRequestHandler's translate_path to serve from arbitrary docroot """ # abandon query parameters path = urlparse.urlparse(path)[2] path = posixpath.normpath(urllib.unquote(path)) words = path.split('/') words = filter(None, words) path = self.server.docroot for word in words: drive, word = os.path.splitdrive(word) head, word = os.path.split(word) if word in (os.curdir, os.pardir): continue path = os.path.join(path, word) logging.debug('Translated path: %s', path) return path def _fire_event(self): wait_urls = self.server._wait_urls if self.path in wait_urls: _, e = wait_urls[self.path] e.set() del wait_urls[self.path] else: logging.debug('URL %s not in watch list' % self.path) @_handle_http_errors def do_GET(self): form = cgi.FieldStorage( fp=self.rfile, headers=self.headers, environ={'REQUEST_METHOD': 'GET'}) split_url = urlparse.urlsplit(self.path) path = split_url[2] # Strip off query parameters to ensure that the url path # matches any registered events. self.path = path args = urlparse.parse_qs(split_url[3]) if path in self.server._url_handlers: self.server._url_handlers[path](self, args) else: SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self) self._fire_event() @_handle_http_errors def do_HEAD(self): SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self) class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): def __init__(self, server_address, HandlerClass): HTTPServer.__init__(self, server_address, HandlerClass) class HTTPListener(object): # Point default docroot to a non-existent directory (instead of None) to # avoid exceptions when page content is served through handlers only. def __init__(self, port=0, docroot='/_', wait_urls={}, url_handlers={}): self._server = ThreadedHTTPServer(('', port), FormHandler) self.config_server(self._server, docroot, wait_urls, url_handlers) def config_server(self, server, docroot, wait_urls, url_handlers): # Stuff some convenient data fields into the server object. self._server.docroot = docroot self._server._wait_urls = wait_urls self._server._url_handlers = url_handlers self._server._form_entries = {} self._server_thread = threading.Thread( target=self._server.serve_forever) def add_wait_url(self, url='/', matchParams={}): e = threading.Event() self._server._wait_urls[url] = (matchParams, e) return e def add_url_handler(self, url, handler_func): self._server._url_handlers[url] = handler_func def clear_form_entries(self): self._server._form_entries = {} def get_form_entries(self): """Returns a dictionary of all field=values recieved by the server. """ return self._server._form_entries def run(self): logging.debug('http server on %s:%d' % (self._server.server_name, self._server.server_port)) self._server_thread.start() def stop(self): self._server.shutdown() self._server.socket.close() self._server_thread.join() class SecureHTTPServer(ThreadingMixIn, HTTPServer): def __init__(self, server_address, HandlerClass, cert_path, key_path): _socket = socket.socket(self.address_family, self.socket_type) self.socket = ssl.wrap_socket(_socket, server_side=True, ssl_version=ssl.PROTOCOL_TLSv1, certfile=cert_path, keyfile=key_path) BaseServer.__init__(self, server_address, HandlerClass) self.server_bind() self.server_activate() class SecureHTTPRequestHandler(FormHandler): def setup(self): self.connection = self.request self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize) self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize) # Override the default logging methods to use the logging module directly. def log_error(self, format, *args): logging.warning("(httpd error) %s - - [%s] %s\n" % (self.address_string(), self.log_date_time_string(), format%args)) def log_message(self, format, *args): logging.debug("%s - - [%s] %s\n" % (self.address_string(), self.log_date_time_string(), format%args)) class SecureHTTPListener(HTTPListener): def __init__(self, cert_path='/etc/login_trust_root.pem', key_path='/etc/mock_server.key', port=0, docroot='/_', wait_urls={}, url_handlers={}): self._server = SecureHTTPServer(('', port), SecureHTTPRequestHandler, cert_path, key_path) self.config_server(self._server, docroot, wait_urls, url_handlers) def getsockname(self): return self._server.socket.getsockname()