# Copyright 2015 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
All of the MBIM messages are created using the MBIMControlMessageMeta metaclass.
The metaclass supports a hierarchy of message definitions so that each message
definition extends the structure of the base class it inherits.

    (mbim_message.py)
    MBIMControlMessage
     |
     |  (mbim_message_request.py)
     |> MBIMControlMessageRequest
     |    |> MBIMOpen
     |    |> MBIMClose
     |    |> MBIMCommand
     |    |    |> MBIMSetConnect
     |    |    |> ...
     |    |
     |    |> MBIMHostError
     |
     |  (mbim_message_response.py)
     |> MBIMControlMessageResponse
          |> MBIMOpenDone
          |> MBIMCloseDone
          |> MBIMCommandDone
          |    |> MBIMConnectInfo
          |    |> ...
          |
          |> MBIMHostError
"""
import array
import logging
import struct
import sys
from collections import namedtuple
from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
# Kinds of message classes. Each field value of a message ends up as an
# attribute of the object that is created.
#
# Request classes are constructed from explicit field values.
MESSAGE_TYPE_REQUEST = 1
# Response classes are constructed from |raw_data|, which is unpacked into
# the object's attributes.
MESSAGE_TYPE_RESPONSE = 2

# Kinds of message fields (the third element of every _FIELDS entry).
#
# Plain field without any special handling.
FIELD_TYPE_NORMAL = 1
# Payload identifier; consulted while parsing responses to pick the child
# message class.
FIELD_TYPE_PAYLOAD_ID = 2
# Total message length, payload_buffer included. Computed automatically when
# not given explicitly.
FIELD_TYPE_TOTAL_LEN = 3
# Length of the data carried in payload_buffer.
FIELD_TYPE_PAYLOAD_LEN = 4
# Number of fragments that make up this message.
FIELD_TYPE_NUM_FRAGMENTS = 5
# Transaction ID of the message. Generated automatically when not given
# explicitly.
FIELD_TYPE_TRANSACTION_ID = 6
def message_class_new(cls, **kwargs):
    """
    Creates a message instance with either the given field name/value
    pairs or raw data buffer.

    Installed as the |__new__| of every message class by
    MBIMControlMessageMeta.

    When building from field values, the total length
    (FIELD_TYPE_TOTAL_LEN) and transaction id (FIELD_TYPE_TRANSACTION_ID)
    fields are automatically calculated if not explicitly provided in the
    message args. Any other missing field falls back to the class defaults.

    @param cls: The message class being instantiated.
    @param kwargs: Dictionary of (field_name, field_value) pairs, optionally
            including payload_buffer=<variable sized trailing data>, or
            raw_data=Packed binary array.
    @returns New message object created.
    @raises MBIMComplianceControlMessageError (via mbim_errors.log_and_raise)
            if |raw_data| is shorter than the message structure, a field
            value is missing, or unexpected fields are passed.
    """
    raw_data = kwargs.get('raw_data')
    if raw_data:
        # We unpack the raw data received into the appropriate fields
        # for this class. If there is some additional data present in
        # |raw_data| that does not fit the format of the structure,
        # they're stored in the variable sized |payload_buffer| field.
        data_format = cls.get_field_format_string(get_all=True)
        unpack_length = cls.get_struct_len(get_all=True)
        data_length = len(raw_data)
        if data_length < unpack_length:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    'Length of Data (%d) to be parsed less than message'
                    ' structure length (%d)' %
                    (data_length, unpack_length))
        obj = super(cls, cls).__new__(
                cls, *struct.unpack_from(data_format, raw_data))
        if data_length > unpack_length:
            obj.payload_buffer = raw_data[unpack_length:]
        else:
            obj.payload_buffer = None
        return obj

    # The optional variable sized payload_buffer is not part of
    # |cls._FIELDS|, so it is handled separately from the struct fields.
    payload_buffer = kwargs.pop('payload_buffer', None)
    fields = cls.get_fields(get_all=True)
    defaults = cls.get_defaults(get_all=True)
    field_values = []
    for _, field_name, field_type in fields:
        if field_name in kwargs:
            field_values.append(kwargs.pop(field_name))
            continue
        # Transaction ID and message length are generated here when the
        # caller did not supply them; everything else needs a default.
        if field_type == FIELD_TYPE_TOTAL_LEN:
            field_value = cls.get_struct_len(get_all=True)
            # Same rule as calculate_total_len(): an absent/empty payload
            # adds nothing (payload_buffer=None used to crash in len()).
            if payload_buffer:
                field_value += len(payload_buffer)
        elif field_type == FIELD_TYPE_TRANSACTION_ID:
            field_value = cls.get_next_transaction_id()
        else:
            field_value = defaults.get(field_name, None)
            if field_value is None:
                mbim_errors.log_and_raise(
                        mbim_errors.MBIMComplianceControlMessageError,
                        'Missing field value (%s) in %s' % (
                                field_name, cls.__name__))
        field_values.append(field_value)

    obj = super(cls, cls).__new__(cls, *field_values)
    obj.payload_buffer = payload_buffer
    if kwargs:
        # Whatever is left over did not match any field of this message.
        mbim_errors.log_and_raise(
                mbim_errors.MBIMComplianceControlMessageError,
                'Unexpected fields (%s) in %s' % (
                        list(kwargs), cls.__name__))
    return obj
class MBIMControlMessageMeta(type):
    """
    Metaclass for all the control message parsing/generation.

    The metaclass creates each class by concatenating all the message fields
    from its base classes to create a hierarchy of messages.
    Thus the payload class of each message class becomes the subclass of that
    message. Every generated class also gets a namedtuple (built from the
    consolidated field names) prepended to its bases, and |message_class_new|
    installed as its |__new__|.

    Message definition attributes->
    _FIELDS(optional): Used to define structure elements, as a sequence of
                       (struct format, field_name, field_type) entries. The
                       fields of a message are the concatenation of the
                       _FIELDS attribute along with all the _FIELDS
                       attributes from its parent classes.
    _DEFAULTS(optional): Field name/value pairs to be assigned to some
                         of the fields if they are fixed for a message type.
                         These are generally used to assign values to fields
                         in the parent class.
    _IDENTIFIERS(optional): Field name/value pairs to be used to identify
                            this message during parsing from raw_data.
    _SECONDARY_FRAGMENT(optional): The secondary fragment class associated
                                   with this message, if it can be
                                   fragmented.
    MESSAGE_TYPE: Used to identify request/response classes.

    Message internal attributes->
    _CONSOLIDATED_FIELDS: Consolidated list of all the fields defining this
                          message.
    _CONSOLIDATED_DEFAULTS: Consolidated dict of all the default field
                            name/value pairs for this message.
    """
    def __new__(mcs, name, bases, attrs):
        # The MBIMControlMessage base class, which inherits from 'object',
        # is merely used to establish the class hierarchy and is never
        # constructed on its own.
        if object in bases:
            return super(MBIMControlMessageMeta, mcs).__new__(
                    mcs, name, bases, attrs)

        # Start from the parent's consolidated fields/defaults (copied, so
        # the parent's own attributes are never shared or modified), then
        # append/override with this class's own definitions.
        fields = []
        defaults = {}
        for base_class in bases:
            if hasattr(base_class, '_CONSOLIDATED_FIELDS'):
                fields = list(base_class._CONSOLIDATED_FIELDS)
            if hasattr(base_class, '_CONSOLIDATED_DEFAULTS'):
                defaults = base_class._CONSOLIDATED_DEFAULTS.copy()
        if '_FIELDS' in attrs:
            # A list comprehension (instead of map()) always yields a list,
            # so the concatenation also works where map() is lazy.
            fields += [list(field) for field in attrs['_FIELDS']]
        if '_DEFAULTS' in attrs:
            defaults.update(attrs['_DEFAULTS'])
        attrs['_CONSOLIDATED_FIELDS'] = fields
        attrs['_CONSOLIDATED_DEFAULTS'] = defaults
        if not fields:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    '%s message must have some fields defined' % name)

        attrs['__new__'] = message_class_new
        _, field_names, _ = zip(*fields)
        message_class = namedtuple(name, field_names)
        # Prepend the class created via namedtuple to |bases| in order to
        # correctly resolve the __new__ method while preserving the class
        # hierarchy.
        return super(MBIMControlMessageMeta, mcs).__new__(
                mcs, name, (message_class,) + bases, attrs)
class MBIMControlMessage(object):
    """
    MBIMControlMessage base class.

    This class should not be instantiated or used directly. Message classes
    deriving from it are built by MBIMControlMessageMeta, which makes each
    of them a namedtuple of its consolidated fields; every instance also
    carries a |payload_buffer| attribute with any variable sized trailing
    data (or None).
    """
    __metaclass__ = MBIMControlMessageMeta

    # Last transaction id handed out by get_next_transaction_id(). It is
    # always accessed through MBIMControlMessage, so all message classes
    # share one counter.
    _NEXT_TRANSACTION_ID = 0x00000000

    @classmethod
    def _find_subclasses(cls):
        """
        Helper function to find all the derived payload classes of this
        class.

        @returns List of the direct subclasses of this class.
        """
        return list(cls.__subclasses__())

    @classmethod
    def get_fields(cls, get_all=False):
        """
        Helper function to find all the fields of this class.

        Returns either the total message fields or only the current
        substructure fields in the nested message.

        @param get_all: Whether to return the total struct fields or sub struct
                fields.
        @returns Fields of the structure as (format, field_name, field_type)
                entries. If neither this class nor its bases define
                |_FIELDS|, the sub struct fields are an empty list.
        """
        if get_all:
            return cls._CONSOLIDATED_FIELDS
        # |_FIELDS| is optional, like |_IDENTIFIERS|/|_SECONDARY_FRAGMENT|.
        return getattr(cls, '_FIELDS', [])

    @classmethod
    def get_defaults(cls, get_all=False):
        """
        Helper function to find all the default field values of this class.

        Returns either the total message default field name/value pairs or only
        the current substructure defaults in the nested message.

        @param get_all: Whether to return the total struct defaults or sub
                struct defaults.
        @returns Defaults of the structure. If neither this class nor its
                bases define |_DEFAULTS|, the sub struct defaults are an
                empty dict.
        """
        if get_all:
            return cls._CONSOLIDATED_DEFAULTS
        # |_DEFAULTS| is optional, so fall back to no defaults.
        return getattr(cls, '_DEFAULTS', {})

    @classmethod
    def _get_identifiers(cls):
        """
        Helper function to find all the identifier field name/value pairs of
        this class.

        @returns All the identifiers of this class, or None if it has none.
        """
        return getattr(cls, '_IDENTIFIERS', None)

    @classmethod
    def _find_field_names_of_type(cls, find_type, get_all=False):
        """
        Helper function to find all the field names which match the field
        type specified.

        @param find_type: One of the FIELD_TYPE_* enum values.
        @param get_all: Whether to search the total struct fields or only the
                sub struct fields.
        @returns List of the matching field names (empty if none match).
        """
        return [field_name
                for _, field_name, field_type in cls.get_fields(
                        get_all=get_all)
                if field_type == find_type]

    @classmethod
    def get_secondary_fragment(cls):
        """
        Helper function to retrieve the associated secondary fragment class.

        @returns |_SECONDARY_FRAGMENT| attribute of the class, or None.
        """
        return getattr(cls, '_SECONDARY_FRAGMENT', None)

    @classmethod
    def get_field_names(cls, get_all=True):
        """
        Helper function to return the field names of the message.

        @param get_all: Whether to use the total struct or sub struct fields.
        @returns The field names of the message structure.
        """
        _, field_names, _ = zip(*cls.get_fields(get_all=get_all))
        return field_names

    @classmethod
    def get_field_formats(cls, get_all=True):
        """
        Helper function to return the field formats of the message.

        @param get_all: Whether to use the total struct or sub struct fields.
        @returns The format of fields of the message structure.
        """
        field_formats, _, _ = zip(*cls.get_fields(get_all=get_all))
        return field_formats

    @classmethod
    def get_field_format_string(cls, get_all=True):
        """
        Helper function to return the field format string of the message.

        @param get_all: Whether to use the total struct or sub struct fields.
        @returns The little-endian struct format string of the message
                structure.
        """
        return '<' + ''.join(cls.get_field_formats(get_all=get_all))

    @classmethod
    def get_struct_len(cls, get_all=False):
        """
        Returns the length of the structure representing the message.

        Returns the length of either the total message or only the current
        substructure in the nested message.

        @param get_all: Whether to return the total struct length or sub struct
                length.
        @returns Length of the structure.
        """
        return struct.calcsize(cls.get_field_format_string(get_all=get_all))

    @classmethod
    def find_primary_parent_fragment(cls):
        """
        Traverses up the message tree to find the primary fragment class
        at the same tree level as the secondary frag class associated with this
        message class. This should only be called on primary fragment derived
        classes!

        @returns Primary frag class associated with the message.
        """
        # __bases__[0] is the namedtuple prepended by MBIMControlMessageMeta,
        # so the real parent class is always __bases__[1].
        secondary_frag_parent_cls = cls.get_secondary_fragment().__bases__[1]
        message_cls = cls
        while message_cls.__bases__[1] != secondary_frag_parent_cls:
            message_cls = message_cls.__bases__[1]
        return message_cls

    @classmethod
    def get_next_transaction_id(cls):
        """
        Returns incrementing transaction ids on successive calls.

        The first id returned is 1; the counter restarts at 1 once it has
        gone past sys.maxint - 2.

        @returns The transaction id for control message delivery.
        """
        if MBIMControlMessage._NEXT_TRANSACTION_ID > (sys.maxint - 2):
            MBIMControlMessage._NEXT_TRANSACTION_ID = 0x00000000
        MBIMControlMessage._NEXT_TRANSACTION_ID += 1
        return MBIMControlMessage._NEXT_TRANSACTION_ID

    def _get_fields_of_type(self, field_type, get_all=False):
        """
        Helper function to find all the field name/value of the specified type
        in the given object.

        @param field_type: One of the FIELD_TYPE_* enum values.
        @param get_all: Whether to search the total struct fields or only the
                sub struct fields.
        @returns Corresponding map of field name/value pairs extracted from the
                object.
        """
        field_names = self.__class__._find_field_names_of_type(
                field_type, get_all=get_all)
        return {f: getattr(self, f) for f in field_names}

    def _get_payload_id_fields(self):
        """
        Helper function to find all the payload id field name/value in the given
        object.

        @returns Corresponding field name/value pairs extracted from the object.
        """
        return self._get_fields_of_type(FIELD_TYPE_PAYLOAD_ID)

    def _get_single_field_value(self, field_type, field_description,
                                get_all=False):
        """
        Helper function to extract the value of the unique field of the given
        type from this object.

        @param field_type: One of the FIELD_TYPE_* enum values.
        @param field_description: Human readable field name for the error
                message.
        @param get_all: Whether to search the total struct fields or only the
                sub struct fields.
        @returns The value of the matching field.
        @raises MBIMComplianceControlMessageError (via
                mbim_errors.log_and_raise) unless exactly one field of
                |field_type| exists.
        """
        fields = self._get_fields_of_type(field_type, get_all=get_all)
        if len(fields) != 1:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    'Error in finding %s field in message: %s' %
                    (field_description, self.__class__.__name__))
        # list() keeps this working whether values() is a list or a view.
        return list(fields.values())[0]

    def get_payload_len(self):
        """
        Helper function to find the payload len field value in the given
        object.

        @returns Corresponding field value extracted from the object.
        """
        return self._get_single_field_value(FIELD_TYPE_PAYLOAD_LEN,
                                            'payload len')

    def get_total_len(self):
        """
        Helper function to find the total len field value in the given
        object.

        @returns Corresponding field value extracted from the object.
        """
        return self._get_single_field_value(FIELD_TYPE_TOTAL_LEN, 'total len',
                                            get_all=True)

    def get_num_fragments(self):
        """
        Helper function to find the fragment num field value in the given
        object.

        @returns Corresponding field value extracted from the object.
        """
        return self._get_single_field_value(FIELD_TYPE_NUM_FRAGMENTS,
                                            'num fragments')

    def find_payload_class(self):
        """
        Helper function to find the derived class which has the default
        |payload_id| fields matching the current message contents.

        @returns Corresponding class if found, else None.
        """
        # The message's own payload ids don't depend on the candidate, so
        # compute them once.
        message_ids = self._get_payload_id_fields()
        for payload_cls in self.__class__._find_subclasses():
            if message_ids == payload_cls._get_identifiers():
                return payload_cls
        return None

    def calculate_total_len(self):
        """
        Helper function to calculate the total len of a given message
        object.

        @returns Total length of the message (struct plus payload_buffer).
        """
        total_len = self.__class__.get_struct_len(get_all=True)
        if self.payload_buffer:
            total_len += len(self.payload_buffer)
        return total_len

    def pack(self, format_string, field_names):
        """
        Packs a list of fields based on their formats.

        @param format_string: The concatenated formats for the fields given in
                |field_names|.
        @param field_names: The name of the fields to be packed.
        @returns The packet in binary array form.
        """
        field_values = [getattr(self, name) for name in field_names]
        return array.array('B', struct.pack(format_string, *field_values))

    def print_all_fields(self):
        """Logs (at debug level) all the field name/value pairs of this message."""
        logging.debug('Class Name: %s', self.__class__.__name__)
        for field_name in self.__class__.get_field_names(get_all=True):
            logging.debug('Field Name: %s, Field Value: %s',
                          field_name, str(getattr(self, field_name)))
        if self.payload_buffer:
            logging.debug('Payload: %s', str(self.payload_buffer))

    def create_raw_data(self):
        """
        Creates the raw binary data corresponding to the message struct.

        All the struct fields are packed first; the variable sized
        |payload_buffer| (if any) is appended at the end of the msg.

        @returns Packed byte array of the message.
        """
        message_class = self.__class__
        packet = self.pack(message_class.get_field_format_string(),
                           message_class.get_field_names())
        if self.payload_buffer:
            packet.extend(self.payload_buffer)
        return packet

    def copy(self, **fields_to_alter):
        """
        Replaces the message tuple with updated field values.

        @param fields_to_alter: Field name/value pairs to be changed.
        @returns Updated message with the field values updated.
        """
        message = self._replace(**fields_to_alter)
        # namedtuple._replace() only rebuilds the tuple fields, so copy the
        # associated payload_buffer to the new tuple explicitly.
        message.payload_buffer = self.payload_buffer
        return message