普通文本  |  164行  |  4.47 KB

#!/usr/bin/python

import unittest, tempfile, os
import common
import MySQLdb
from autotest_lib.client.common_lib import global_config
from autotest_lib.database import database_connection, migrate

# Which section of the global config to pull info from.  We won't actually use
# that DB, we'll use the corresponding test DB (test_<db name>).
CONFIG_DB = 'AUTOTEST_WEB'

NUM_MIGRATIONS = 3

class DummyMigration(object):
    """\
    Dummy migration class that records all migrations done in a class
    varaible.
    """

    migrations_done = []

    def __init__(self, version):
        self.version = version
        self.name = '%03d_test' % version


    @classmethod
    def get_migrations_done(cls):
        return cls.migrations_done


    @classmethod
    def clear_migrations_done(cls):
        cls.migrations_done = []


    @classmethod
    def do_migration(cls, version, direction):
        cls.migrations_done.append((version, direction))


    def migrate_up(self, manager):
        self.do_migration(self.version, 'up')
        if self.version == 1:
            manager.create_migrate_table()


    def migrate_down(self, manager):
        self.do_migration(self.version, 'down')


MIGRATIONS = [DummyMigration(n) for n in xrange(1, NUM_MIGRATIONS + 1)]


class TestableMigrationManager(migrate.MigrationManager):
    def _set_migrations_dir(self, migrations_dir=None):
        pass


    def get_migrations(self, minimum_version=None, maximum_version=None):
        minimum_version = minimum_version or 1
        maximum_version = maximum_version or len(MIGRATIONS)
        return MIGRATIONS[minimum_version-1:maximum_version]


class MigrateManagerTest(unittest.TestCase):
    def setUp(self):
        self._database = (
            database_connection.DatabaseConnection.get_test_database())
        self._database.connect()
        self.manager = TestableMigrationManager(self._database)
        DummyMigration.clear_migrations_done()


    def tearDown(self):
        self._database.disconnect()


    def test_sync(self):
        self.manager.do_sync_db()
        self.assertEquals(self.manager.get_db_version(), NUM_MIGRATIONS)
        self.assertEquals(DummyMigration.get_migrations_done(),
                          [(1, 'up'), (2, 'up'), (3, 'up')])

        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEquals(self.manager.get_db_version(), 0)
        self.assertEquals(DummyMigration.get_migrations_done(),
                          [(3, 'down'), (2, 'down'), (1, 'down')])


    def test_sync_one_by_one(self):
        for version in xrange(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEquals(self.manager.get_db_version(),
                              version)
            self.assertEquals(
                DummyMigration.get_migrations_done()[-1],
                (version, 'up'))

        for version in xrange(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEquals(self.manager.get_db_version(),
                              version)
            self.assertEquals(
                DummyMigration.get_migrations_done()[-1],
                (version + 1, 'down'))


    def test_null_sync(self):
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEquals(DummyMigration.get_migrations_done(), [])


class DummyMigrationManager(object):
    def __init__(self):
        self.calls = []


    def execute_script(self, script):
        self.calls.append(script)


class MigrationTest(unittest.TestCase):
    def setUp(self):
        self.manager = DummyMigrationManager()


    def _do_migration(self, migration_module):
        migration = migrate.Migration('name', 1, migration_module)
        migration.migrate_up(self.manager)
        migration.migrate_down(self.manager)

        self.assertEquals(self.manager.calls, ['foo', 'bar'])


    def test_migration_with_methods(self):
        class DummyMigration(object):
            @staticmethod
            def migrate_up(manager):
                manager.execute_script('foo')


            @staticmethod
            def migrate_down(manager):
                manager.execute_script('bar')

        self._do_migration(DummyMigration)


    def test_migration_with_strings(self):
        class DummyMigration(object):
            UP_SQL = 'foo'
            DOWN_SQL = 'bar'

        self._do_migration(DummyMigration)


if __name__ == '__main__':
    unittest.main()