普通文本  |  508行  |  13.8 KB

# (c) 2005 Ian Bicking and contributors; written for Paste
# (http://pythonpaste.org) Licensed under the MIT license:
# http://www.opensource.org/licenses/mit-license.php
"""
Gives a multi-value dictionary object (MultiDict) plus several wrappers
"""
from collections import MutableMapping

import binascii
import warnings

from webob.compat import (
    PY3,
    iteritems_,
    itervalues_,
    url_encode,
    )

__all__ = ['MultiDict', 'NestedMultiDict', 'NoVars', 'GetDict']

class MultiDict(MutableMapping):
    """
        An ordered dictionary that can have multiple values for each key.
        Adds the methods getall, getone, mixed and extend and add to the normal
        dictionary interface.
    """

    def __init__(self, *args, **kw):
        if len(args) > 1:
            raise TypeError("MultiDict can only be called with one positional "
                            "argument")
        if args:
            if hasattr(args[0], 'iteritems'):
                items = list(args[0].iteritems())
            elif hasattr(args[0], 'items'):
                items = list(args[0].items())
            else:
                items = list(args[0])
            self._items = items
        else:
            self._items = []
        if kw:
            self._items.extend(kw.items())

    @classmethod
    def view_list(cls, lst):
        """
        Create a dict that is a view on the given list
        """
        if not isinstance(lst, list):
            raise TypeError(
                "%s.view_list(obj) takes only actual list objects, not %r"
                % (cls.__name__, lst))
        obj = cls()
        obj._items = lst
        return obj

    @classmethod
    def from_fieldstorage(cls, fs):
        """
        Create a dict from a cgi.FieldStorage instance
        """
        obj = cls()
        # fs.list can be None when there's nothing to parse
        for field in fs.list or ():
            charset = field.type_options.get('charset', 'utf8')
            transfer_encoding = field.headers.get('Content-Transfer-Encoding', None)
            supported_transfer_encoding = {
                'base64' : binascii.a2b_base64,
                'quoted-printable' : binascii.a2b_qp
                }
            if PY3: # pragma: no cover
                if charset == 'utf8':
                    decode = lambda b: b
                else:
                    decode = lambda b: b.encode('utf8').decode(charset)
            else:
                decode = lambda b: b.decode(charset)
            if field.filename:
                field.filename = decode(field.filename)
                obj.add(field.name, field)
            else:
                value = field.value
                if transfer_encoding in supported_transfer_encoding:
                    if PY3: # pragma: no cover
                        # binascii accepts bytes
                        value = value.encode('utf8')
                    value = supported_transfer_encoding[transfer_encoding](value)
                    if PY3: # pragma: no cover
                        # binascii returns bytes
                        value = value.decode('utf8')
                obj.add(field.name, decode(value))
        return obj

    def __getitem__(self, key):
        for k, v in reversed(self._items):
            if k == key:
                return v
        raise KeyError(key)

    def __setitem__(self, key, value):
        try:
            del self[key]
        except KeyError:
            pass
        self._items.append((key, value))

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        self._items.append((key, value))

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        return [v for k, v in self._items if k == key]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found.
        """
        v = self.getall(key)
        if not v:
            raise KeyError('Key not found: %r' % key)
        if len(v) > 1:
            raise KeyError('Multiple values match %r: %r' % (key, v))
        return v[0]

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        result = {}
        multi = {}
        for key, value in self.items():
            if key in result:
                # We do this to not clobber any lists that are
                # *actual* values in this dictionary:
                if key in multi:
                    result[key].append(value)
                else:
                    result[key] = [result[key], value]
                    multi[key] = None
            else:
                result[key] = value
        return result

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a list of values.
        """
        r = {}
        for key, val in self.items():
            r.setdefault(key, []).append(val)
        return r

    def __delitem__(self, key):
        items = self._items
        found = False
        for i in range(len(items)-1, -1, -1):
            if items[i][0] == key:
                del items[i]
                found = True
        if not found:
            raise KeyError(key)

    def __contains__(self, key):
        for k, v in self._items:
            if k == key:
                return True
        return False

    has_key = __contains__

    def clear(self):
        del self._items[:]

    def copy(self):
        return self.__class__(self)

    def setdefault(self, key, default=None):
        for k, v in self._items:
            if key == k:
                return v
        self._items.append((key, default))
        return default

    def pop(self, key, *args):
        if len(args) > 1:
            raise TypeError("pop expected at most 2 arguments, got %s"
                             % repr(1 + len(args)))
        for i in range(len(self._items)):
            if self._items[i][0] == key:
                v = self._items[i][1]
                del self._items[i]
                return v
        if args:
            return args[0]
        else:
            raise KeyError(key)

    def popitem(self):
        return self._items.pop()

    def update(self, *args, **kw):
        if args:
            lst = args[0]
            if len(lst) != len(dict(lst)):
                # this does not catch the cases where we overwrite existing
                # keys, but those would produce too many warning
                msg = ("Behavior of MultiDict.update() has changed "
                    "and overwrites duplicate keys. Consider using .extend()"
                )
                warnings.warn(msg, UserWarning, stacklevel=2)
        MutableMapping.update(self, *args, **kw)

    def extend(self, other=None, **kwargs):
        if other is None:
            pass
        elif hasattr(other, 'items'):
            self._items.extend(other.items())
        elif hasattr(other, 'keys'):
            for k in other.keys():
                self._items.append((k, other[k]))
        else:
            for k, v in other:
                self._items.append((k, v))
        if kwargs:
            self.update(kwargs)

    def __repr__(self):
        items = map('(%r, %r)'.__mod__, _hide_passwd(self.items()))
        return '%s([%s])' % (self.__class__.__name__, ', '.join(items))

    def __len__(self):
        return len(self._items)

    ##
    ## All the iteration:
    ##

    def iterkeys(self):
        for k, v in self._items:
            yield k
    if PY3: # pragma: no cover
        keys = iterkeys
    else:
        def keys(self):
            return [k for k, v in self._items]

    __iter__ = iterkeys

    def iteritems(self):
        return iter(self._items)

    if PY3: # pragma: no cover
        items = iteritems
    else:
        def items(self):
            return self._items[:]

    def itervalues(self):
        for k, v in self._items:
            yield v

    if PY3: # pragma: no cover
        values = itervalues
    else:
        def values(self):
            return [v for k, v in self._items]

_dummy = object()

class GetDict(MultiDict):
#     def __init__(self, data, tracker, encoding, errors):
#         d = lambda b: b.decode(encoding, errors)
#         data = [(d(k), d(v)) for k,v in data]
    def __init__(self, data, env):
        self.env = env
        MultiDict.__init__(self, data)
    def on_change(self):
        e = lambda t: t.encode('utf8')
        data = [(e(k), e(v)) for k,v in self.items()]
        qs = url_encode(data)
        self.env['QUERY_STRING'] = qs
        self.env['webob._parsed_query_vars'] = (self, qs)
    def __setitem__(self, key, value):
        MultiDict.__setitem__(self, key, value)
        self.on_change()
    def add(self, key, value):
        MultiDict.add(self, key, value)
        self.on_change()
    def __delitem__(self, key):
        MultiDict.__delitem__(self, key)
        self.on_change()
    def clear(self):
        MultiDict.clear(self)
        self.on_change()
    def setdefault(self, key, default=None):
        result = MultiDict.setdefault(self, key, default)
        self.on_change()
        return result
    def pop(self, key, *args):
        result = MultiDict.pop(self, key, *args)
        self.on_change()
        return result
    def popitem(self):
        result = MultiDict.popitem(self)
        self.on_change()
        return result
    def update(self, *args, **kwargs):
        MultiDict.update(self, *args, **kwargs)
        self.on_change()
    def extend(self, *args, **kwargs):
        MultiDict.extend(self, *args, **kwargs)
        self.on_change()
    def __repr__(self):
        items = map('(%r, %r)'.__mod__, _hide_passwd(self.items()))
        # TODO: GET -> GetDict
        return 'GET([%s])' % (', '.join(items))
    def copy(self):
        # Copies shouldn't be tracked
        return MultiDict(self)

class NestedMultiDict(MultiDict):
    """
    Wraps several MultiDict objects, treating it as one large MultiDict
    """

    def __init__(self, *dicts):
        self.dicts = dicts

    def __getitem__(self, key):
        for d in self.dicts:
            value = d.get(key, _dummy)
            if value is not _dummy:
                return value
        raise KeyError(key)

    def _readonly(self, *args, **kw):
        raise KeyError("NestedMultiDict objects are read-only")
    __setitem__ = _readonly
    add = _readonly
    __delitem__ = _readonly
    clear = _readonly
    setdefault = _readonly
    pop = _readonly
    popitem = _readonly
    update = _readonly

    def getall(self, key):
        result = []
        for d in self.dicts:
            result.extend(d.getall(key))
        return result

    # Inherited:
    # getone
    # mixed
    # dict_of_lists

    def copy(self):
        return MultiDict(self)

    def __contains__(self, key):
        for d in self.dicts:
            if key in d:
                return True
        return False

    has_key = __contains__

    def __len__(self):
        v = 0
        for d in self.dicts:
            v += len(d)
        return v

    def __nonzero__(self):
        for d in self.dicts:
            if d:
                return True
        return False

    def iteritems(self):
        for d in self.dicts:
            for item in iteritems_(d):
                yield item
    if PY3: # pragma: no cover
        items = iteritems
    else:
        def items(self):
            return list(self.iteritems())

    def itervalues(self):
        for d in self.dicts:
            for value in itervalues_(d):
                yield value
    if PY3: # pragma: no cover
        values = itervalues
    else:
        def values(self):
            return list(self.itervalues())

    def __iter__(self):
        for d in self.dicts:
            for key in d:
                yield key

    iterkeys = __iter__

    if PY3: # pragma: no cover
        keys = iterkeys
    else:
        def keys(self):
            return list(self.iterkeys())

class NoVars(object):
    """
    Represents no variables; used when no variables
    are applicable.

    This is read-only
    """

    def __init__(self, reason=None):
        self.reason = reason or 'N/A'

    def __getitem__(self, key):
        raise KeyError("No key %r: %s" % (key, self.reason))

    def __setitem__(self, *args, **kw):
        raise KeyError("Cannot add variables: %s" % self.reason)

    add = __setitem__
    setdefault = __setitem__
    update = __setitem__

    def __delitem__(self, *args, **kw):
        raise KeyError("No keys to delete: %s" % self.reason)
    clear = __delitem__
    pop = __delitem__
    popitem = __delitem__

    def get(self, key, default=None):
        return default

    def getall(self, key):
        return []

    def getone(self, key):
        return self[key]

    def mixed(self):
        return {}
    dict_of_lists = mixed

    def __contains__(self, key):
        return False
    has_key = __contains__

    def copy(self):
        return self

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__,
                             self.reason)

    def __len__(self):
        return 0

    def __cmp__(self, other):
        return cmp({}, other)

    def iterkeys(self):
        return iter([])

    if PY3: # pragma: no cover
        keys = iterkeys
        items = iterkeys
        values = iterkeys
    else:
        def keys(self):
            return []
        items = keys
        values = keys
        itervalues = iterkeys
        iteritems = iterkeys

    __iter__ = iterkeys

def _hide_passwd(items):
    for k, v in items:
        if ('password' in k
            or 'passwd' in k
            or 'pwd' in k
        ):
            yield k, '******'
        else:
            yield k, v