__author__ = "raphtee@google.com (Travis Miller)"


import re, collections, StringIO, sys, unittest


class StubNotFoundError(Exception):
    """Raised when god is asked to unstub an attribute that was not stubbed."""


class CheckPlaybackError(Exception):
    """Raised when mock playback does not match recorded calls."""


class SaveDataAfterCloseStringIO(StringIO.StringIO):
    """Saves the contents in a final_data property when close() is called.

    Useful as a mock output file object to test both that the file was
    closed and what was written.

    Properties:
      final_data: Set to the StringIO's getvalue() data when close() is
          called.  None if close() has not been called.
    """
    final_data = None

    def close(self):
        self.final_data = self.getvalue()
        StringIO.StringIO.close(self)



class argument_comparator(object):
    """Base class for objects that decide whether a call argument matches.

    Subclasses override is_satisfied_by(); this base version is abstract
    and always raises NotImplementedError.
    """
    def is_satisfied_by(self, parameter):
        """Return whether parameter is acceptable; subclasses must override."""
        raise NotImplementedError


class equality_comparator(argument_comparator):
    """Matches an argument by recursive comparison with a stored value.

    argument_comparator instances nested anywhere inside the expected value
    decide for themselves, so e.g. [1, anything_comparator()] matches any
    two-element list whose first item is 1.  Types must match exactly,
    except that any two basestring values are treated as compatible.
    """
    def __init__(self, value):
        self.value = value


    @staticmethod
    def _types_match(arg1, arg2):
        """True if both are strings (str/unicode) or share the exact type."""
        both_strings = (isinstance(arg1, basestring) and
                        isinstance(arg2, basestring))
        return both_strings or type(arg1) == type(arg2)


    @classmethod
    def _compare(cls, actual_arg, expected_arg):
        """Recursively compare actual_arg against expected_arg."""
        # A nested comparator is the authority on its own slot.
        if isinstance(expected_arg, argument_comparator):
            return expected_arg.is_satisfied_by(actual_arg)
        if not cls._types_match(expected_arg, actual_arg):
            return False

        if isinstance(expected_arg, (list, tuple)):
            # Sequences: equal length, then element-by-element.
            if len(actual_arg) != len(expected_arg):
                return False
            return all(cls._compare(actual_item, expected_item)
                       for actual_item, expected_item
                       in zip(actual_arg, expected_arg))

        if isinstance(expected_arg, dict):
            # Mappings: identical key sets, then value-by-value.
            if not cls._compare(sorted(actual_arg.keys()),
                                sorted(expected_arg.keys())):
                return False
            return all(cls._compare(value, expected_arg[key])
                       for key, value in actual_arg.iteritems())

        # Leaf values: deliberately phrased with != because Python 2 does
        # not derive __ne__ from __eq__.
        return not actual_arg != expected_arg


    def is_satisfied_by(self, parameter):
        return self._compare(parameter, self.value)


    def __str__(self):
        value = self.value
        if isinstance(value, argument_comparator):
            return str(value)
        return repr(value)


class regex_comparator(argument_comparator):
    """Matches arguments containing a match for a regular expression.

    The pattern is applied with re.search, so it may match anywhere in the
    argument unless it is anchored.  Arguments that re cannot search (None,
    ints, ...) simply do not match instead of aborting playback.
    """
    def __init__(self, pattern, flags=0):
        self.regex = re.compile(pattern, flags)


    def is_satisfied_by(self, parameter):
        try:
            return self.regex.search(parameter) is not None
        except TypeError:
            # re.search rejects non-string/non-buffer arguments; a wrong
            # argument type is a mismatch, not a crash inside the mock.
            return False


    def __str__(self):
        return self.regex.pattern


class is_string_comparator(argument_comparator):
    """Matches any str or unicode argument (anything that is a basestring)."""
    def is_satisfied_by(self, parameter):
        return isinstance(parameter, basestring)


    def __str__(self):
        return "a string"


class is_instance_comparator(argument_comparator):
    """Matches arguments that are instances of cls.

    cls may be a class or a tuple of classes, exactly as accepted by
    isinstance().
    """
    def __init__(self, cls):
        self.cls = cls


    def is_satisfied_by(self, parameter):
        return isinstance(parameter, self.cls)


    def __str__(self):
        # Wrap in a 1-tuple: when self.cls is itself a tuple of classes a
        # bare `% self.cls` would unpack it as several format arguments.
        return "is a %s" % (self.cls,)


class anything_comparator(argument_comparator):
    """Wildcard comparator: every argument value is accepted."""
    def is_satisfied_by(self, parameter):
        # Unconditional match; the parameter is intentionally ignored.
        return True


    def __str__(self):
        return 'anything'


class base_mapping(object):
    """One expected call: a symbol, its expected arguments and its outcome.

    Every expected positional and keyword argument is wrapped in an
    equality_comparator so match() can test an actual call against it.
    """
    def __init__(self, symbol, return_obj, *args, **dargs):
        self.return_obj = return_obj
        self.symbol = symbol
        self.args = list(map(equality_comparator, args))
        self.dargs = {}
        for key, value in dargs.iteritems():
            self.dargs[key] = equality_comparator(value)
        self.error = None


    def match(self, *args, **dargs):
        """Return True if the actual call (args, dargs) fits this mapping."""
        if len(args) != len(self.args) or len(dargs) != len(self.dargs):
            return False

        for expected, actual in zip(self.args, args):
            if not expected.is_satisfied_by(actual):
                return False

        # Equal keyword counts plus "every actual keyword is expected"
        # already implies the keyword sets are identical, so there is no
        # need for a separate missing-keyword pass.
        for key, value in dargs.iteritems():
            comparator = self.dargs.get(key)
            if comparator is None or not comparator.is_satisfied_by(value):
                return False

        return True


    def __str__(self):
        return _dump_function_call(self.symbol, self.args, self.dargs)


class function_mapping(base_mapping):
    """Expected call to a mock function whose outcome is set after creation.

    Typically obtained from mock_function.expect_call(); use and_return()
    to choose the value the call yields, or and_raises() to make the call
    raise instead.
    """
    def __init__(self, symbol, return_val, *args, **dargs):
        super(function_mapping, self).__init__(
            symbol, return_val, *args, **dargs)


    def and_return(self, return_obj):
        """Make the matching call return return_obj."""
        self.return_obj = return_obj


    def and_raises(self, error):
        """Make the matching call raise error."""
        self.error = error


class function_any_args_mapping(function_mapping):
    """A mock function mapping that doesn't verify its arguments."""
    def match(self, *args, **dargs):
        # Any positional/keyword combination is acceptable.
        return True


class mock_function(object):
    """Callable stand-in that logs every call and optionally plays it back.

    Each invocation increments num_calls and appends its positional args
    tuple to self.args and its keyword dict to self.dargs.  When a playback
    callback is set, the call is forwarded to it together with the symbol;
    otherwise default_return_val is returned.
    """
    def __init__(self, symbol, default_return_val=None,
                 record=None, playback=None):
        self.symbol = self.__name__ = symbol
        self.default_return_val = default_return_val
        self.record = record
        self.playback = playback
        self.num_calls = 0
        self.args = []
        self.dargs = []


    def __call__(self, *args, **dargs):
        self.num_calls += 1
        self.args.append(args)
        self.dargs.append(dargs)
        if not self.playback:
            return self.default_return_val
        return self.playback(self.symbol, *args, **dargs)


    def expect_call(self, *args, **dargs):
        """Register an expected call with exactly these arguments."""
        expectation = function_mapping(self.symbol, None, *args, **dargs)
        if self.record:
            self.record(expectation)
        return expectation


    def expect_any_call(self):
        """Like expect_call but don't give a hoot what arguments are passed."""
        expectation = function_any_args_mapping(self.symbol, None)
        if self.record:
            self.record(expectation)
        return expectation


class mask_function(mock_function):
    """A mock_function that remembers the real function it replaced."""
    def __init__(self, symbol, original_function, default_return_val=None,
                 record=None, playback=None):
        super(mask_function, self).__init__(
            symbol, default_return_val, record, playback)
        self.original_function = original_function


    def run_original_function(self, *args, **dargs):
        """Call the masked (original) function and return its result."""
        return self.original_function(*args, **dargs)


class mock_class(object):
    """Object mirroring the public attributes of a namespace with mocks.

    Every attribute of cls whose name does not start with "_" is copied
    onto the instance.  Callable ones are replaced by mock_function objects
    named "<name>.<symbol>" that share the record/playback callbacks;
    non-callable ones are copied as-is.
    """
    def __init__(self, cls, name, default_ret_val=None,
                 record=None, playback=None):
        self.__name = name
        self.__record = record
        self.__playback = playback

        public_symbols = [sym for sym in dir(cls) if not sym.startswith("_")]
        for symbol in public_symbols:
            attribute = getattr(cls, symbol)
            if callable(attribute):
                attribute = mock_function("%s.%s" % (name, symbol),
                                          default_ret_val,
                                          record, playback)
            setattr(self, symbol, attribute)


    def __repr__(self):
        return '<mock_class: %s>' % self.__name


class mock_god(object):
    """Records expected calls on mocks and verifies them during playback.

    All mocks created through one mock_god share a single FIFO queue of
    expected calls (self.recording).  expect_call()/expect_new() append to
    that queue; invoking a mock compares the call with the head of the
    queue and either consumes it (returning its configured value or raising
    its configured error) or reports an error.  The stub_* helpers replace
    attributes on arbitrary namespaces and remember how to restore them.
    """
    # Sentinel meaning "the attribute did not exist before it was stubbed".
    # Always compared by identity so no user-defined __eq__ is invoked.
    NONEXISTENT_ATTRIBUTE = object()

    def __init__(self, debug=False, fail_fast=True, ut=None):
        """
        With debug=True, all recorded method calls will be printed as
        they happen.
        With fail_fast=True, unexpected calls will immediately cause an
        exception to be raised.  With False, they will be silently recorded and
        only reported when check_playback() is called.
        If ut is given (an object with a fail() method, e.g. a
        unittest.TestCase), check_playback() reports failures through it.
        """
        self.recording = collections.deque()
        self.errors = []
        self._stubs = []
        self._debug = debug
        self._fail_fast = fail_fast
        self._ut = ut


    def set_fail_fast(self, fail_fast):
        """Change whether unexpected calls raise immediately."""
        self._fail_fast = fail_fast


    def create_mock_class_obj(self, cls, name, default_ret_val=None):
        """Return a subclass of cls whose construction is mocked.

        Calling the returned class plays back the next expected call, which
        must have been registered via its expect_new() classmethod;
        expect_new() pre-builds the instance with make_new(), replacing each
        non-dunder, non-property callable attribute with a mock_function.
        The subclass __init__ is a no-op, so cls.__init__ never runs.
        """
        record = self.__record_call
        playback = self.__method_playback

        class cls_sub(cls):
            cls_count = 0

            # overwrite the initializer
            def __init__(self, *args, **dargs):
                pass


            @classmethod
            def expect_new(typ, *args, **dargs):
                obj = typ.make_new(*args, **dargs)
                mapping = base_mapping(name, obj, *args, **dargs)
                record(mapping)
                return obj


            def __new__(typ, *args, **dargs):
                # Construction consumes the next expectation and returns
                # the object pre-built by expect_new().
                return playback(name, *args, **dargs)


            @classmethod
            def make_new(typ, *args, **dargs):
                obj = super(cls_sub, typ).__new__(typ, *args,
                                                  **dargs)

                typ.cls_count += 1
                obj_name = "%s_%s" % (name, typ.cls_count)
                for symbol in dir(obj):
                    if (symbol.startswith("__") and
                        symbol.endswith("__")):
                        continue

                    # Properties are left alone: reading them via getattr
                    # would run their getter on a half-built object.
                    if isinstance(getattr(typ, symbol, None), property):
                        continue

                    orig_symbol = getattr(obj, symbol)
                    if callable(orig_symbol):
                        f_name = ("%s.%s" %
                                  (obj_name, symbol))
                        func = mock_function(f_name,
                                        default_ret_val,
                                        record,
                                        playback)
                        setattr(obj, symbol, func)
                    else:
                        setattr(obj, symbol,
                                orig_symbol)

                return obj

        return cls_sub


    def create_mock_class(self, cls, name, default_ret_val=None):
        """
        Given something that defines a namespace cls (class, object,
        module), and a (hopefully unique) name, will create a
        mock_class object with that name and that possesses all
        the public attributes of cls.  default_ret_val sets the
        default_ret_val on all methods of the cls mock.
        """
        return mock_class(cls, name, default_ret_val,
                          self.__record_call, self.__method_playback)


    def create_mock_function(self, symbol, default_return_val=None):
        """
        create a mock_function with name symbol and default return
        value of default_return_val.
        """
        return mock_function(symbol, default_return_val,
                             self.__record_call, self.__method_playback)


    def mock_up(self, obj, name, default_ret_val=None):
        """
        Given an object (class instance or module) and a registration
        name, then replace all its methods with mock function objects
        (passing the original functions to the mock functions).
        """
        for symbol in dir(obj):
            if symbol.startswith("__"):
                continue

            orig_symbol = getattr(obj, symbol)
            if callable(orig_symbol):
                f_name = "%s.%s" % (name, symbol)
                func = mask_function(f_name, orig_symbol,
                                     default_ret_val,
                                     self.__record_call,
                                     self.__method_playback)
                setattr(obj, symbol, func)


    def stub_with(self, namespace, symbol, new_attribute):
        """Replace namespace.symbol with new_attribute, remembering the old."""
        original_attribute = getattr(namespace, symbol,
                                     self.NONEXISTENT_ATTRIBUTE)

        # You only want to save the original attribute in cases where it is
        # directly associated with the object in question. In cases where
        # the attribute is actually inherited via some sort of hierarchy
        # you want to delete the stub (restoring the original structure)
        attribute_is_inherited = (hasattr(namespace, '__dict__') and
                                  symbol not in namespace.__dict__)
        if attribute_is_inherited:
            original_attribute = self.NONEXISTENT_ATTRIBUTE

        newstub = (namespace, symbol, original_attribute, new_attribute)
        self._stubs.append(newstub)
        setattr(namespace, symbol, new_attribute)


    def stub_function(self, namespace, symbol):
        """Replace namespace.symbol with a recording mock_function."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(namespace, symbol, mock_attribute)


    def stub_class_method(self, cls, symbol):
        """Replace cls.symbol with a mock_function wrapped in staticmethod."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(cls, symbol, staticmethod(mock_attribute))


    def stub_class(self, namespace, symbol):
        """Replace the class namespace.symbol with a mocked subclass."""
        attr = getattr(namespace, symbol)
        mock_cls = self.create_mock_class_obj(attr, symbol)
        self.stub_with(namespace, symbol, mock_cls)


    def stub_function_to_return(self, namespace, symbol, object_to_return):
        """Stub out a function with one that always returns a fixed value.

        @param namespace The namespace containing the function to stub out.
        @param symbol The attribute within the namespace to stub out.
        @param object_to_return The value that the stub should return whenever
            it is called.
        """
        self.stub_with(namespace, symbol,
                       lambda *args, **dargs: object_to_return)


    def _perform_unstub(self, stub):
        """Restore (or delete) the attribute recorded in a stub tuple."""
        namespace, symbol, orig_attr, new_attr = stub
        # Identity check: `==` would call orig_attr.__eq__, and an object
        # that compares equal to everything would then be deleted.
        if orig_attr is self.NONEXISTENT_ATTRIBUTE:
            delattr(namespace, symbol)
        else:
            setattr(namespace, symbol, orig_attr)


    def unstub(self, namespace, symbol):
        """Undo the most recent stub of namespace.symbol.

        Raises StubNotFoundError if that attribute is not currently stubbed.
        """
        for stub in reversed(self._stubs):
            if (namespace, symbol) == (stub[0], stub[1]):
                self._perform_unstub(stub)
                self._stubs.remove(stub)
                return

        raise StubNotFoundError()


    def unstub_all(self):
        """Undo every outstanding stub, most recent first."""
        for stub in reversed(self._stubs):
            self._perform_unstub(stub)
        self._stubs = []


    def __method_playback(self, symbol, *args, **dargs):
        """Check an actual call against the next expected call.

        When symbol and arguments match the head of the queue, the
        expectation is consumed and its return_obj is returned (or its
        error raised).  Otherwise the problem is reported via
        _append_error() (which raises CheckPlaybackError in fail-fast mode),
        the queue is left untouched and None is returned.
        """
        if self._debug:
            print >> sys.__stdout__, (' * Mock call: ' +
                                      _dump_function_call(symbol, args, dargs))

        if not self.recording:
            msg = ("unexpected call: %s"
                   % (_dump_function_call(symbol, args, dargs)))
            self._append_error(msg)
            return None

        func_call = self.recording[0]
        if func_call.symbol != symbol:
            msg = ("Unexpected call: %s\nExpected: %s"
                % (_dump_function_call(symbol, args, dargs),
                   func_call))
            self._append_error(msg)
            return None

        if not func_call.match(*args, **dargs):
            msg = ("Incorrect call: %s\nExpected: %s"
                % (_dump_function_call(symbol, args, dargs),
                  func_call))
            self._append_error(msg)
            return None

        # this is the expected call so pop it and return
        self.recording.popleft()
        if func_call.error:
            raise func_call.error
        return func_call.return_obj


    def __record_call(self, mapping):
        """Append an expected call to the playback queue."""
        self.recording.append(mapping)


    def _append_error(self, error):
        """Report a playback error: raise if fail-fast, else remember it."""
        if self._debug:
            print >> sys.__stdout__, ' *** ' + error
        if self._fail_fast:
            raise CheckPlaybackError(error)
        self.errors.append(error)


    def check_playback(self):
        """
        Report any errors that were encountered during calls
        to __method_playback(), and any expected calls that never happened.

        Raises CheckPlaybackError (after calling ut.fail() if a ut was
        given) when there is anything to report; otherwise clears the
        expectation queue.
        """
        if self.errors:
            # Diagnostics go to the real stdout so they stay visible even
            # while mock_io() has replaced sys.stdout.
            if self._debug:
                print >> sys.__stdout__, '\nPlayback errors:'
            for error in self.errors:
                print >> sys.__stdout__, error

            if self._ut:
                self._ut.fail('\n'.join(self.errors))

            raise CheckPlaybackError('\n'.join(self.errors))
        elif self.recording:
            errors = []
            for func_call in self.recording:
                error = "%s not called" % (func_call,)
                errors.append(error)
                print >> sys.__stdout__, error

            if self._ut:
                self._ut.fail('\n'.join(errors))

            raise CheckPlaybackError('\n'.join(errors))
        self.recording.clear()


    def mock_io(self):
        """Mocks and saves the stdout & stderr output"""
        self.orig_stdout = sys.stdout
        self.orig_stderr = sys.stderr

        self.mock_streams_stdout = StringIO.StringIO('')
        self.mock_streams_stderr = StringIO.StringIO('')

        sys.stdout = self.mock_streams_stdout
        sys.stderr = self.mock_streams_stderr


    def unmock_io(self):
        """Restores the stdout & stderr, and returns both
        output strings"""
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
        values = (self.mock_streams_stdout.getvalue(),
                  self.mock_streams_stderr.getvalue())

        self.mock_streams_stdout.close()
        self.mock_streams_stderr.close()
        return values


def _arg_to_str(arg):
    """Render an argument for messages: comparators via str(), others repr()."""
    render = str if isinstance(arg, argument_comparator) else repr
    return render(arg)


def _dump_function_call(symbol, args, dargs):
    """Format a call as 'symbol(arg1, arg2, key=val, ...)' for messages.

    Keyword arguments are emitted in sorted key order so that two
    equivalent calls always render identically, regardless of dict
    iteration order.  This keeps "Unexpected call ... / Expected ..."
    messages directly comparable.
    """
    arg_vec = [_arg_to_str(arg) for arg in args]
    arg_vec.extend("%s=%s" % (key, _arg_to_str(dargs[key]))
                   for key in sorted(dargs))
    return "%s(%s)" % (symbol, ', '.join(arg_vec))