#!/usr/bin/python

import unittest, time
import common
from autotest_lib.client.common_lib import global_config
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.database import database_connection

# global_config section whose values the tests override.
_CONFIG_SECTION = 'AUTOTEST_WEB'

# Fake connection settings used throughout the tests.
_HOST = 'myhost'
_USER = 'myuser'
_PASS = 'mypass'
_DB_NAME = 'mydb'
_DB_TYPE = 'mydbtype'

# Keyword arguments the mock backend's connect() is expected to receive.
_CONNECT_KWARGS = {
    'host': _HOST,
    'username': _USER,
    'password': _PASS,
    'db_name': _DB_NAME,
}

# Value assigned to reconnect_delay_sec on every connection under test.
_RECONNECT_DELAY = 10

class FakeDatabaseError(Exception):
    """Stand-in exception installed on the mock backend for DB errors."""


class DatabaseConnectionTest(unittest.TestCase):
    """Tests for database_connection.DatabaseConnection using a mock backend.

    The backend returned by the connection's _get_backend is replaced with a
    mock, and time.sleep is stubbed, so each test records the backend calls
    it expects and then verifies them with check_playback().
    """

    def setUp(self):
        """Create the mock god and stub out time.sleep."""
        self.god = mock.mock_god()
        self.god.stub_function(time, 'sleep')


    def tearDown(self):
        """Reset overridden config values and remove all stubs."""
        global_config.global_config.reset_config_values()
        self.god.unstub_all()


    def _get_database_connection(self, config_section=_CONFIG_SECTION):
        """Return a DatabaseConnection whose backend is a mock.

        The test config values are installed only when config_section is
        the default _CONFIG_SECTION.  The mock backend is stored in
        self._fake_backend, and the db_type passed to _get_backend is
        recorded in self._db_type when the stub is invoked.
        """
        if config_section == _CONFIG_SECTION:
            self._override_config()
        db = database_connection.DatabaseConnection(config_section)

        self._fake_backend = self.god.create_mock_class(
            database_connection._GenericBackend, 'fake_backend')
        # Every exception attribute on the backend maps to the same fake
        # error class, so raising FakeDatabaseError matches any of them.
        for exception in database_connection._DB_EXCEPTIONS:
            setattr(self._fake_backend, exception, FakeDatabaseError)
        self._fake_backend.rowcount = 0

        def get_fake_backend(db_type):
            # Remember the requested type so tests can assert on it.
            self._db_type = db_type
            return self._fake_backend
        self.god.stub_with(db, '_get_backend', get_fake_backend)

        db.reconnect_delay_sec = _RECONNECT_DELAY
        return db


    def _override_config(self):
        """Install the fake connection settings into global_config."""
        c = global_config.global_config
        c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
        c.override_config_value(_CONFIG_SECTION, 'user', _USER)
        c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
        c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
        c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)


    def test_connect(self):
        """connect() with explicit arguments forwards them to the backend."""
        db = self._get_database_connection(config_section=None)
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
                   password=_PASS, db_name=_DB_NAME)

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def test_global_config(self):
        """connect() without arguments uses the overridden config values."""
        db = self._get_database_connection()
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect()

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def _expect_reconnect(self, fail=False):
        """Expect a backend disconnect followed by a connect.

        If fail is true, the expected connect raises FakeDatabaseError.
        """
        self._fake_backend.disconnect.expect_call()
        call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        if fail:
            call.and_raises(FakeDatabaseError())


    def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
        """Expect a failing connect followed by num_reconnects reconnects.

        Each reconnect is preceded by an expected sleep of _RECONNECT_DELAY.
        Every reconnect except the last one fails; the last one fails only
        if fail_last is true.
        """
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
            FakeDatabaseError())
        # range() behaves identically to xrange() here and also exists on
        # Python 3.
        for i in range(num_reconnects):
            time.sleep.expect_call(_RECONNECT_DELAY)
            if i < num_reconnects - 1:
                self._expect_reconnect(fail=True)
            else:
                self._expect_reconnect(fail=fail_last)


    def test_connect_retry(self):
        """A failed connect is retried unless reconnecting is disabled."""
        db = self._get_database_connection()
        self._expect_fail_and_reconnect(1)

        db.connect()
        self.god.check_playback()

        # Reconnecting disabled for this call only: a disconnect is
        # expected first, then the single failing connect propagates.
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect,
                          try_reconnecting=False)
        self.god.check_playback()

        # Reconnecting disabled on the connection object itself.
        db.reconnect_enabled = False
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_max_reconnect(self):
        """connect() raises once max_reconnect_attempts reconnects fail."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = 5
        self._expect_fail_and_reconnect(5, fail_last=True)

        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_reconnect_forever(self):
        """With RECONNECT_FOREVER, connect() keeps retrying until success."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
        self._expect_fail_and_reconnect(30)

        db.connect()
        self.god.check_playback()


    def _simple_connect(self, db):
        """Connect db successfully and verify the expected backend call."""
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        db.connect()
        self.god.check_playback()


    def test_disconnect(self):
        """disconnect() is forwarded to the backend."""
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.disconnect.expect_call()

        db.disconnect()
        self.god.check_playback()


    def test_execute(self):
        """execute() passes the query and parameters to the backend."""
        db = self._get_database_connection()
        self._simple_connect(db)
        params = object()
        self._fake_backend.execute.expect_call('query', params)

        db.execute('query', params)
        self.god.check_playback()


    def test_execute_retry(self):
        """A failed execute reconnects and retries unless disabled."""
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.execute.expect_call('query', None).and_raises(
            FakeDatabaseError())
        self._expect_reconnect()
        self._fake_backend.execute.expect_call('query', None)

        db.execute('query')
        self.god.check_playback()

        # With reconnecting disabled the backend error must propagate
        # without any disconnect/connect calls.
        self._fake_backend.execute.expect_call('query', None).and_raises(
            FakeDatabaseError())
        self.assertRaises(FakeDatabaseError, db.execute, 'query',
                          try_reconnecting=False)
        # Verify playback for this final phase too, as every other test
        # phase in this class does.
        self.god.check_playback()


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()