# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from mod_pywebsocket import common
from mod_pywebsocket import util
from mod_pywebsocket.http_header_util import quote_if_necessary
# Registry mapping an extension name (as returned by the request's name())
# to the processor class that handles it; consulted by
# get_extension_processor().
_available_processors = dict()
class ExtensionProcessorInterface(object):
    """Base interface for WebSocket extension processors.

    Every method is a do-nothing default; concrete processors override the
    ones they need.
    """

    def name(self):
        """Returns the extension name this processor handles (None here)."""
        return None

    def get_extension_response(self):
        """Returns the response ExtensionParameter, or None to decline."""
        return None

    def setup_stream_options(self, stream_options):
        """Hook for adjusting stream_options; a no-op by default."""
        return None
class DeflateStreamExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket DEFLATE stream extension processor.

    The extension is accepted only when the request carries no parameters
    at all; once accepted, the stream options get deflate_stream enabled.
    """

    def __init__(self, request):
        self._logger = util.get_class_logger(self)
        self._request = request

    def name(self):
        return common.DEFLATE_STREAM_EXTENSION

    def get_extension_response(self):
        requested_names = self._request.get_parameter_names()
        # Any parameter makes the request unacceptable.
        if len(requested_names) > 0:
            return None

        extension_name = common.DEFLATE_STREAM_EXTENSION
        self._logger.debug('Enable %s extension', extension_name)
        return common.ExtensionParameter(extension_name)

    def setup_stream_options(self, stream_options):
        stream_options.deflate_stream = True


_available_processors[common.DEFLATE_STREAM_EXTENSION] = (
    DeflateStreamExtensionProcessor)
def _log_compression_ratio(logger, original_bytes, total_original_bytes,
                           filtered_bytes, total_filtered_bytes):
    """Logs the outgoing compression ratio at debug level.

    Both ratios are filtered (sent) size divided by original size. A ratio
    whose denominator is zero is reported as inf.
    """
    infinity = float('inf')

    if original_bytes == 0:
        ratio = infinity
    else:
        ratio = float(filtered_bytes) / original_bytes

    if total_original_bytes == 0:
        average_ratio = infinity
    else:
        average_ratio = float(total_filtered_bytes) / total_original_bytes

    logger.debug('Outgoing compress ratio: %f (average: %f)' %
                 (ratio, average_ratio))
def _log_decompression_ratio(logger, received_bytes, total_received_bytes,
                             filtered_bytes, total_filtered_bytes):
    """Logs the incoming compression ratio at debug level.

    Both ratios are received size divided by filtered (inflated) size, the
    same compressed/uncompressed orientation as _log_compression_ratio.
    A ratio whose denominator is zero is reported as inf.
    """
    # Print inf when ratio is not available. The guard must be on the
    # denominator (filtered_bytes): a non-empty received payload may
    # inflate to zero bytes, which would otherwise raise ZeroDivisionError.
    ratio = float('inf')
    average_ratio = float('inf')
    if filtered_bytes != 0:
        ratio = float(received_bytes) / filtered_bytes
    if total_filtered_bytes != 0:
        average_ratio = (
            float(total_received_bytes) / total_filtered_bytes)
    logger.debug('Incoming compress ratio: %f (average: %f)' %
                 (ratio, average_ratio))
class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket Per-frame DEFLATE extension processor.

    Deflates the payload of outgoing non-control frames (setting RSV1) and
    inflates incoming non-control frames that arrive with RSV1 set. Byte
    counters are kept for debug-level compression-ratio logging.
    """

    _WINDOW_BITS_PARAM = 'max_window_bits'
    _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover'

    def __init__(self, request):
        self._logger = util.get_class_logger(self)
        self._request = request

        self._response_window_bits = None
        self._response_no_context_takeover = False
        # Off until get_extension_response() accepts the extension or
        # enable_outgoing_compression() is called (same default as
        # DeflateMessageProcessor).
        self._compress_outgoing = False

        # Counters for statistics.

        # Total number of outgoing bytes supplied to this filter.
        self._total_outgoing_payload_bytes = 0
        # Total number of bytes sent to the network after applying this filter.
        self._total_filtered_outgoing_payload_bytes = 0

        # Total number of bytes received from the network.
        self._total_incoming_payload_bytes = 0
        # Total number of incoming bytes obtained after applying this filter.
        self._total_filtered_incoming_payload_bytes = 0

    def name(self):
        return common.DEFLATE_FRAME_EXTENSION

    def get_extension_response(self):
        """Validates the request parameters and builds the response.

        Returns None (declining the extension) when no_context_takeover
        carries a value, or when max_window_bits is not an integer in
        [8, 15]. Any unknown parameter is ignored. On success the deflater
        and inflater are created and outgoing compression is enabled.
        """
        window_bits = self._request.get_parameter_value(
            self._WINDOW_BITS_PARAM)
        no_context_takeover = self._request.has_parameter(
            self._NO_CONTEXT_TAKEOVER_PARAM)
        # no_context_takeover must appear without a value.
        if (no_context_takeover and
                self._request.get_parameter_value(
                    self._NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        if window_bits is not None:
            try:
                window_bits = int(window_bits)
            except ValueError:
                return None
            if window_bits < 8 or window_bits > 15:
                return None

        self._deflater = util._RFC1979Deflater(
            window_bits, no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(self._request.name())

        if self._response_window_bits is not None:
            response.add_parameter(
                self._WINDOW_BITS_PARAM, str(self._response_window_bits))
        if self._response_no_context_takeover:
            response.add_parameter(
                self._NO_CONTEXT_TAKEOVER_PARAM, None)

        self._logger.debug(
            'Enable %s extension ('
            'request: window_bits=%s; no_context_takeover=%r, '
            'response: window_bits=%s; no_context_takeover=%r)',
            self._request.name(),
            window_bits,
            no_context_takeover,
            self._response_window_bits,
            self._response_no_context_takeover)

        return response

    def setup_stream_options(self, stream_options):
        """Installs frame filters that route frames through this object."""

        class _OutgoingFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._outgoing_filter(frame)

        class _IncomingFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._incoming_filter(frame)

        stream_options.outgoing_frame_filters.append(
            _OutgoingFilter(self))
        # Placed at the front of the incoming list; presumably so later
        # incoming frame filters see inflated payloads (confirm against
        # the stream's filter-application order).
        stream_options.incoming_frame_filters.insert(
            0, _IncomingFilter(self))

    def set_response_window_bits(self, value):
        self._response_window_bits = value

    def set_response_no_context_takeover(self, value):
        self._response_no_context_takeover = value

    def enable_outgoing_compression(self):
        self._compress_outgoing = True

    def disable_outgoing_compression(self):
        self._compress_outgoing = False

    def _outgoing_filter(self, frame):
        """Transform outgoing frames. This method is called only by
        an _OutgoingFilter instance.

        Control frames, and all frames while compression is disabled, pass
        through unchanged but are still counted.
        """

        original_payload_size = len(frame.payload)
        self._total_outgoing_payload_bytes += original_payload_size

        if (not self._compress_outgoing or
                common.is_control_opcode(frame.opcode)):
            self._total_filtered_outgoing_payload_bytes += (
                original_payload_size)
            return

        frame.payload = self._deflater.filter(frame.payload)
        # RSV1 marks the payload as compressed for the peer.
        frame.rsv1 = 1

        filtered_payload_size = len(frame.payload)
        self._total_filtered_outgoing_payload_bytes += filtered_payload_size

        _log_compression_ratio(self._logger, original_payload_size,
                               self._total_outgoing_payload_bytes,
                               filtered_payload_size,
                               self._total_filtered_outgoing_payload_bytes)

    def _incoming_filter(self, frame):
        """Transform incoming frames. This method is called only by
        an _IncomingFilter instance.

        Only non-control frames with RSV1 set are inflated; RSV1 is then
        cleared. Other frames pass through unchanged but are counted.
        """

        received_payload_size = len(frame.payload)
        self._total_incoming_payload_bytes += received_payload_size

        if frame.rsv1 != 1 or common.is_control_opcode(frame.opcode):
            self._total_filtered_incoming_payload_bytes += (
                received_payload_size)
            return

        frame.payload = self._inflater.filter(frame.payload)
        frame.rsv1 = 0

        filtered_payload_size = len(frame.payload)
        self._total_filtered_incoming_payload_bytes += filtered_payload_size

        _log_decompression_ratio(self._logger, received_payload_size,
                                 self._total_incoming_payload_bytes,
                                 filtered_payload_size,
                                 self._total_filtered_incoming_payload_bytes)


_available_processors[common.DEFLATE_FRAME_EXTENSION] = (
    DeflateFrameExtensionProcessor)


# Adding vendor-prefixed deflate-frame extension.
# TODO(bashi): Remove this after WebKit stops using vendor prefix.
_available_processors[common.X_WEBKIT_DEFLATE_FRAME_EXTENSION] = (
    DeflateFrameExtensionProcessor)
def _parse_compression_method(data):
    """Parses the value of the "method" extension parameter.

    The value uses the extension-list grammar, with quoted-string parameter
    values allowed. Returns whatever common.parse_extensions returns.
    """
    parsed_methods = common.parse_extensions(data, allow_quoted_string=True)
    return parsed_methods
def _create_accepted_method_desc(method_name, method_params):
    """Creates accepted-method-desc from given method name and parameters.

    method_params is iterated as (name, value) pairs; each pair is added to
    an ExtensionParameter for method_name, which is then formatted.
    """
    accepted = common.ExtensionParameter(method_name)
    for param_name, param_value in method_params:
        accepted.add_parameter(param_name, param_value)
    formatted = common.format_extension(accepted)
    return formatted
class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
    """Base class for Per-frame and Per-message compression extension.

    Subclasses provide name() and _lookup_compression_processor(). This
    class parses the "method" parameter, delegates negotiation to the
    compression processor for the first supported method, and wraps that
    processor's response in its own "method" parameter.
    """

    _METHOD_PARAM = 'method'

    def __init__(self, request):
        self._logger = util.get_class_logger(self)
        self._request = request
        self._compression_method_name = None
        self._compression_processor = None

    def name(self):
        return ''

    def _lookup_compression_processor(self, method_desc):
        """Returns a processor for method_desc, or None if unsupported.

        The base implementation supports nothing; subclasses override it.
        """
        return None

    def _get_compression_processor_response(self):
        """Looks up the compression processor based on the self._request and
        returns the compression processor's response.

        Returns None when the "method" parameter is missing or unparsable,
        when no listed method is supported, or when the chosen processor
        declines.
        """
        method_list = self._request.get_parameter_value(self._METHOD_PARAM)
        if method_list is None:
            return None
        methods = _parse_compression_method(method_list)
        if methods is None:
            return None

        # Must be bound before the loop: an empty method list never enters
        # the loop body, and the check after it reads this name.
        compression_processor = None
        # The current implementation tries only the first method that matches
        # supported algorithm. Following methods aren't tried even if the
        # first one is rejected.
        # TODO(bashi): Need to clarify this behavior.
        for method_desc in methods:
            compression_processor = self._lookup_compression_processor(
                method_desc)
            if compression_processor is not None:
                self._compression_method_name = method_desc.name()
                break
        if compression_processor is None:
            return None

        processor_response = compression_processor.get_extension_response()
        if processor_response is None:
            return None
        self._compression_processor = compression_processor
        return processor_response

    def get_extension_response(self):
        """Builds the response, or returns None if negotiation failed.

        The response carries a single "method" parameter describing the
        accepted method and the parameters its processor returned.
        """
        processor_response = self._get_compression_processor_response()
        if processor_response is None:
            return None

        response = common.ExtensionParameter(self._request.name())
        accepted_method_desc = _create_accepted_method_desc(
            self._compression_method_name,
            processor_response.get_parameters())
        response.add_parameter(self._METHOD_PARAM, accepted_method_desc)
        self._logger.debug(
            'Enable %s extension (method: %s)',
            self._request.name(), self._compression_method_name)
        return response

    def setup_stream_options(self, stream_options):
        """Delegates to the negotiated processor, if any."""
        if self._compression_processor is None:
            return
        self._compression_processor.setup_stream_options(stream_options)

    def get_compression_processor(self):
        return self._compression_processor
class PerFrameCompressionExtensionProcessor(CompressionExtensionProcessorBase):
    """WebSocket Per-frame compression extension processor.

    Supports only the "deflate" method, which is handled by a
    DeflateFrameExtensionProcessor built from the method description.
    """

    _DEFLATE_METHOD = 'deflate'

    def __init__(self, request):
        CompressionExtensionProcessorBase.__init__(self, request)

    def name(self):
        return common.PERFRAME_COMPRESSION_EXTENSION

    def _lookup_compression_processor(self, method_desc):
        if method_desc.name() != self._DEFLATE_METHOD:
            # Unsupported compression method.
            return None
        return DeflateFrameExtensionProcessor(method_desc)


_available_processors[common.PERFRAME_COMPRESSION_EXTENSION] = (
    PerFrameCompressionExtensionProcessor)
class DeflateMessageProcessor(ExtensionProcessorInterface):
    """Per-message deflate processor.

    Deflates whole outgoing messages and flags the next outgoing
    non-control frame with RSV1. When an incoming non-control frame
    arrives with RSV1 set, the next incoming message is inflated.
    Byte counters are kept for debug-level compression-ratio logging.
    """

    _S2C_MAX_WINDOW_BITS_PARAM = 's2c_max_window_bits'
    _S2C_NO_CONTEXT_TAKEOVER_PARAM = 's2c_no_context_takeover'
    _C2S_MAX_WINDOW_BITS_PARAM = 'c2s_max_window_bits'
    _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover'

    def __init__(self, request):
        self._request = request
        self._logger = util.get_class_logger(self)

        self._c2s_max_window_bits = None
        self._c2s_no_context_takeover = False

        self._compress_outgoing = False

        # Counters for statistics.

        # Total number of outgoing bytes supplied to this filter.
        self._total_outgoing_payload_bytes = 0
        # Total number of bytes sent to the network after applying this filter.
        self._total_filtered_outgoing_payload_bytes = 0

        # Total number of bytes received from the network.
        self._total_incoming_payload_bytes = 0
        # Total number of incoming bytes obtained after applying this filter.
        self._total_filtered_incoming_payload_bytes = 0

    def name(self):
        return 'deflate'

    def get_extension_response(self):
        """Validates the s2c_* request parameters and builds the response.

        Returns None (declining) when s2c_max_window_bits is not an integer
        in [8, 15] or s2c_no_context_takeover carries a value. Any unknown
        parameter is ignored. The c2s_* response parameters come from
        set_c2s_window_bits() and set_c2s_no_context_takeover().
        """
        s2c_max_window_bits = self._request.get_parameter_value(
            self._S2C_MAX_WINDOW_BITS_PARAM)
        if s2c_max_window_bits is not None:
            try:
                s2c_max_window_bits = int(s2c_max_window_bits)
            except ValueError:
                return None
            if s2c_max_window_bits < 8 or s2c_max_window_bits > 15:
                return None

        s2c_no_context_takeover = self._request.has_parameter(
            self._S2C_NO_CONTEXT_TAKEOVER_PARAM)
        # s2c_no_context_takeover must appear without a value.
        if (s2c_no_context_takeover and
                self._request.get_parameter_value(
                    self._S2C_NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        self._deflater = util._RFC1979Deflater(
            s2c_max_window_bits, s2c_no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(self._request.name())

        if s2c_max_window_bits is not None:
            response.add_parameter(
                self._S2C_MAX_WINDOW_BITS_PARAM, str(s2c_max_window_bits))

        # has_parameter() presumably returns a bool, so an "is not None"
        # test would always pass; echo the flag only when it was requested.
        if s2c_no_context_takeover:
            response.add_parameter(
                self._S2C_NO_CONTEXT_TAKEOVER_PARAM, None)

        if self._c2s_max_window_bits is not None:
            response.add_parameter(
                self._C2S_MAX_WINDOW_BITS_PARAM,
                str(self._c2s_max_window_bits))
        if self._c2s_no_context_takeover:
            response.add_parameter(
                self._C2S_NO_CONTEXT_TAKEOVER_PARAM, None)

        self._logger.debug(
            'Enable %s extension ('
            'request: s2c_max_window_bits=%s; s2c_no_context_takeover=%r, '
            'response: c2s_max_window_bits=%s; c2s_no_context_takeover=%r)',
            self._request.name(),
            s2c_max_window_bits,
            s2c_no_context_takeover,
            self._c2s_max_window_bits,
            self._c2s_no_context_takeover)

        return response

    def setup_stream_options(self, stream_options):
        """Installs the message and frame filters used by this processor."""

        class _OutgoingMessageFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                return self._parent._process_outgoing_message(
                    message, end, binary)

        class _IncomingMessageFilter(object):

            def __init__(self, parent):
                self._parent = parent
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                message = self._parent._process_incoming_message(
                    message, self._decompress_next_message)
                # The request to decompress applies to one message only.
                self._decompress_next_message = False
                return message

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)

        class _OutgoingFrameFilter(object):

            def __init__(self, parent):
                self._parent = parent
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                self._parent._process_outgoing_frame(
                    frame, self._set_compression_bit)
                # The pending bit is consumed by the first frame filtered
                # after it was set.
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        # _process_outgoing_message() UTF-8-encodes text messages itself
        # before deflating them.
        stream_options.encode_text_message_to_utf8 = False

    def set_c2s_window_bits(self, value):
        self._c2s_max_window_bits = value

    def set_c2s_no_context_takeover(self, value):
        self._c2s_no_context_takeover = value

    def enable_outgoing_compression(self):
        self._compress_outgoing = True

    def disable_outgoing_compression(self):
        self._compress_outgoing = False

    def _process_incoming_message(self, message, decompress):
        """Inflates message when decompress is true; otherwise returns it."""
        if not decompress:
            return message

        received_payload_size = len(message)
        self._total_incoming_payload_bytes += received_payload_size

        message = self._inflater.filter(message)

        filtered_payload_size = len(message)
        self._total_filtered_incoming_payload_bytes += filtered_payload_size

        _log_decompression_ratio(self._logger, received_payload_size,
                                 self._total_incoming_payload_bytes,
                                 filtered_payload_size,
                                 self._total_filtered_incoming_payload_bytes)

        return message

    def _process_outgoing_message(self, message, end, binary):
        """Encodes text as UTF-8 and, if enabled, deflates the message.

        When compressing, the outgoing frame filter is told to set RSV1.
        """
        if not binary:
            message = message.encode('utf-8')

        if not self._compress_outgoing:
            return message

        original_payload_size = len(message)
        self._total_outgoing_payload_bytes += original_payload_size

        message = self._deflater.filter(message)

        filtered_payload_size = len(message)
        self._total_filtered_outgoing_payload_bytes += filtered_payload_size

        _log_compression_ratio(self._logger, original_payload_size,
                               self._total_outgoing_payload_bytes,
                               filtered_payload_size,
                               self._total_filtered_outgoing_payload_bytes)

        self._outgoing_frame_filter.set_compression_bit()
        return message

    def _process_incoming_frame(self, frame):
        """Turns RSV1 on a non-control frame into a decompress request."""
        if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode):
            self._incoming_message_filter.decompress_next_message()
            frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        """Sets RSV1 on a non-control frame when compression_bit is set."""
        if (not compression_bit or
                common.is_control_opcode(frame.opcode)):
            return

        frame.rsv1 = 1
class PerMessageCompressionExtensionProcessor(
        CompressionExtensionProcessorBase):
    """WebSocket Per-message compression extension processor.

    Supports only the "deflate" method, which is handled by a
    DeflateMessageProcessor built from the method description.
    """

    _DEFLATE_METHOD = 'deflate'

    def __init__(self, request):
        CompressionExtensionProcessorBase.__init__(self, request)

    def name(self):
        return common.PERMESSAGE_COMPRESSION_EXTENSION

    def _lookup_compression_processor(self, method_desc):
        """Returns a DeflateMessageProcessor for "deflate", else None."""
        if method_desc.name() == self._DEFLATE_METHOD:
            return DeflateMessageProcessor(method_desc)
        return None


# The per-message extension must map to the per-message processor; mapping
# it to PerFrameCompressionExtensionProcessor would negotiate per-frame
# deflate in response to a per-message compression request.
_available_processors[common.PERMESSAGE_COMPRESSION_EXTENSION] = (
    PerMessageCompressionExtensionProcessor)
class MuxExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket multiplexing extension processor."""

    _QUOTA_PARAM = 'quota'

    def __init__(self, request):
        self._request = request

    def name(self):
        return common.MUX_EXTENSION

    def get_extension_response(self, ws_request,
                               logical_channel_extensions):
        """Accepts mux and records its settings on ws_request.

        Returns None when a frame-dependent extension precedes mux in
        logical_channel_extensions, or when "quota" is not an integer in
        [0, 2**32). On success sets ws_request.mux_quota (0 when no quota
        was given), ws_request.mux and ws_request.mux_extensions.
        """
        # Mux extension cannot be used after extensions that depend on
        # frame boundary, extension data field, or any reserved bits
        # which are attributed to each frame.
        frame_dependent_extensions = (
            common.PERFRAME_COMPRESSION_EXTENSION,
            common.DEFLATE_FRAME_EXTENSION,
            common.X_WEBKIT_DEFLATE_FRAME_EXTENSION)
        for extension in logical_channel_extensions:
            if extension.name() in frame_dependent_extensions:
                return None

        quota = self._request.get_parameter_value(self._QUOTA_PARAM)
        if quota is None:
            ws_request.mux_quota = 0
        else:
            try:
                quota = int(quota)
            except ValueError:
                return None
            if quota < 0 or quota >= 2 ** 32:
                return None
            ws_request.mux_quota = quota

        ws_request.mux = True
        ws_request.mux_extensions = logical_channel_extensions
        return common.ExtensionParameter(common.MUX_EXTENSION)

    def setup_stream_options(self, stream_options):
        pass


_available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor
def get_extension_processor(extension_request):
    """Returns a processor for extension_request, or None if unsupported.

    The processor class is looked up in _available_processors by the
    request's name() and is constructed with the request itself.
    """
    processor_class = _available_processors.get(extension_request.name())
    if processor_class is not None:
        return processor_class(extension_request)
    return None
# vi:sts=4 sw=4 et