#!/usr/bin/python
#pylint: disable-msg=C0111

"""Unit tests for the DB-backed objects in autotest_lib.scheduler.scheduler_models.

Covers DelayedCallTask, DBObject, Host, HostQueueEntry and Job.
"""

import datetime

import common

from autotest_lib.frontend import setup_django_environment
from autotest_lib.frontend.afe import frontend_test_utils
from autotest_lib.client.common_lib import host_queue_entry_states
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.client.common_lib.test_utils import unittest
from autotest_lib.database import database_connection
from autotest_lib.frontend.afe import models, model_attributes
from autotest_lib.scheduler import monitor_db
from autotest_lib.scheduler import scheduler_lib
from autotest_lib.scheduler import scheduler_models

# When True, the test database connection is put in debug mode.
_DEBUG = False


class BaseSchedulerModelsTest(unittest.TestCase,
                              frontend_test_utils.FrontendTestMixin):
    """Common fixture: frontend test setup plus a stubbed scheduler DB."""

    _config_section = 'AUTOTEST_WEB'

    def _do_query(self, sql):
        """Execute a raw SQL statement against the test database.

        @param sql: SQL statement string to execute.
        """
        self._database.execute(sql)

    def _set_monitor_stubs(self):
        """Connect a test database and stub it in as scheduler_models._db."""
        # Clear the instance cache as this is a brand new database.
        scheduler_models.DBObject._clear_instance_cache()

        self._database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self._database.connect(db_type='django')
        self._database.debug = _DEBUG

        self.god.stub_with(scheduler_models, '_db', self._database)

    def setUp(self):
        self._frontend_common_setup()
        self._set_monitor_stubs()

    def tearDown(self):
        self._database.disconnect()
        self._frontend_common_teardown()

    def _update_hqe(self, set, where=''):
        """Run an UPDATE on afe_host_queue_entries.

        The parameter name `set` shadows the builtin; it is kept so that
        existing keyword callers keep working.

        @param set: SQL fragment placed after SET, e.g. "status='Pending'".
        @param where: Optional SQL fragment placed after WHERE. When empty,
                the update applies to every row.
        """
        query = 'UPDATE afe_host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)


class DelayedCallTaskTest(unittest.TestCase):
    """Tests for scheduler_models.DelayedCallTask."""

    def setUp(self):
        self.god = mock.mock_god()

    def tearDown(self):
        self.god.unstub_all()

    def test_delayed_call(self):
        """The callback fires only once the mocked clock passes end_time."""
        test_time = self.god.create_mock_function('time')
        test_time.expect_call().and_return(33)
        test_time.expect_call().and_return(34.01)
        test_time.expect_call().and_return(34.99)
        test_time.expect_call().and_return(35.01)

        def test_callback():
            test_callback.calls += 1
        test_callback.calls = 0

        delay_task = scheduler_models.DelayedCallTask(
                delay_seconds=2, callback=test_callback,
                now_func=test_time)  # time 33
        self.assertEqual(35, delay_task.end_time)
        delay_task.poll()  # activates the task and polls it once, time 34.01
        self.assertEqual(0, test_callback.calls, "callback called early")
        delay_task.poll()  # time 34.99
        self.assertEqual(0, test_callback.calls, "callback called early")
        delay_task.poll()  # time 35.01
        self.assertEqual(1, test_callback.calls)
        self.assertTrue(delay_task.is_done())
        self.assertTrue(delay_task.success)
        self.assertFalse(delay_task.aborted)
        self.god.check_playback()

    def test_delayed_call_abort(self):
        """An aborted task is done, marked aborted and not successful."""
        delay_task = scheduler_models.DelayedCallTask(
                delay_seconds=987654, callback=lambda: None)
        delay_task.abort()
        self.assertTrue(delay_task.aborted)
        self.assertTrue(delay_task.is_done())
        self.assertFalse(delay_task.success)
        self.god.check_playback()


class DBObjectTest(BaseSchedulerModelsTest):
    """Tests for generic DBObject behavior, exercised via Host and HQE."""

    def test_compare_fields_in_row(self):
        """_compare_fields_in_row maps each differing field to (old, new)."""
        host = scheduler_models.Host(id=1)
        fields = list(host._fields)
        row_data = [getattr(host, fieldname) for fieldname in fields]
        self.assertEqual({}, host._compare_fields_in_row(row_data))
        row_data[fields.index('hostname')] = 'spam'
        self.assertEqual({'hostname': ('host1', 'spam')},
                         host._compare_fields_in_row(row_data))
        row_data[fields.index('id')] = 23
        self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
                         host._compare_fields_in_row(row_data))

    def test_compare_fields_in_row_datetime_ignores_microseconds(self):
        """Datetimes differing only in microseconds compare as equal."""
        # Plain `7`, not `07`: a leading-zero literal is a SyntaxError on
        # Python 3 and means the same value on Python 2.
        datetime_with_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 7890)
        datetime_without_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 0)

        class TestTable(scheduler_models.DBObject):
            _table_name = 'test_table'
            _fields = ('id', 'test_datetime')

        tt = TestTable(row=[1, datetime_without_us])
        self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))

    def test_always_query(self):
        """always_query=True refreshes the cached instance; False does not."""
        host_a = scheduler_models.Host(id=2)
        self.assertEqual(host_a.hostname, 'host2')
        self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
                       'WHERE id=2')
        host_b = scheduler_models.Host(id=2, always_query=True)
        self.assertIs(host_a, host_b, 'Cached instance not returned.')
        self.assertEqual(host_a.hostname, 'host2-updated',
                         'Database was not re-queried')

        # If either of these are called, a query was made when it shouldn't be.
        host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
        host_a._update_fields_from_row = host_a._compare_fields_in_row
        host_c = scheduler_models.Host(id=2, always_query=False)
        self.assertIs(host_a, host_c, 'Cached instance not returned')

    def test_delete(self):
        """After delete(), constructing the same id raises DBError."""
        host = scheduler_models.Host(id=3)
        host.delete()
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=False)
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=True)

    def test_save(self):
        """A new HostQueueEntry row survives save() and a fresh fetch."""
        # Dummy Job to avoid creating a one in the HostQueueEntry __init__.
        class MockJob(object):
            def __init__(self, id):
                pass

            def tag(self):
                return 'MockJob'

        self.god.stub_with(scheduler_models, 'Job', MockJob)
        hqe = scheduler_models.HostQueueEntry(
                new_record=True,
                row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
                     None])
        hqe.save()
        new_id = hqe.id
        # Force a re-query and verify that the correct data was stored.
        scheduler_models.DBObject._clear_instance_cache()
        hqe = scheduler_models.HostQueueEntry(id=new_id)
        self.assertEqual(hqe.id, new_id)
        self.assertEqual(hqe.job_id, 1)
        self.assertEqual(hqe.host_id, 2)
        self.assertEqual(hqe.status, 'Queued')
        self.assertEqual(hqe.meta_host, None)
        self.assertEqual(hqe.active, False)
        self.assertEqual(hqe.complete, False)
        self.assertEqual(hqe.deleted, False)
        self.assertEqual(hqe.execution_subdir, '.')
        self.assertEqual(hqe.atomic_group_id, None)
        self.assertEqual(hqe.started_on, None)
        self.assertEqual(hqe.finished_on, None)


class HostTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Host."""

    def test_cmp_for_sort(self):
        """Host.cmp_for_sort yields the ordering listed in expected_order."""
        expected_order = [
                'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
                'host10', 'host11', 'yolkfolk']
        hostname_idx = list(scheduler_models.Host._fields).index('hostname')
        row = [None] * len(scheduler_models.Host._fields)
        hosts = []
        for hostname in expected_order:
            row[hostname_idx] = hostname
            hosts.append(scheduler_models.Host(row=row, new_record=True))

        host1 = hosts[expected_order.index('Host1')]
        host010 = hosts[expected_order.index('HOST010')]
        host10 = hosts[expected_order.index('host10')]
        host3 = hosts[expected_order.index('host3')]
        alice = hosts[expected_order.index('alice')]
        cmp_for_sort = scheduler_models.Host.cmp_for_sort
        self.assertEqual(0, cmp_for_sort(host10, host10))
        self.assertEqual(1, cmp_for_sort(host10, host010))
        self.assertEqual(-1, cmp_for_sort(host010, host10))
        self.assertEqual(-1, cmp_for_sort(host1, host10))
        self.assertEqual(-1, cmp_for_sort(host1, host010))
        self.assertEqual(-1, cmp_for_sort(host3, host10))
        self.assertEqual(-1, cmp_for_sort(host3, host010))
        self.assertEqual(1, cmp_for_sort(host3, host1))
        self.assertEqual(-1, cmp_for_sort(host1, host3))
        self.assertEqual(-1, cmp_for_sort(alice, host3))
        self.assertEqual(1, cmp_for_sort(host3, alice))
        self.assertEqual(0, cmp_for_sort(alice, alice))

        # Sorting must give the same order from either starting arrangement.
        hosts.sort(cmp=scheduler_models.Host.cmp_for_sort)
        self.assertEqual(expected_order, [h.hostname for h in hosts])

        hosts.reverse()
        hosts.sort(cmp=scheduler_models.Host.cmp_for_sort)
        self.assertEqual(expected_order, [h.hostname for h in hosts])


class HostQueueEntryTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.HostQueueEntry."""

    def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
        """Create a job and return its single scheduler HostQueueEntry.

        @param dependency_labels: Labels to add to the job's
                dependency_labels.
        @param create_job_kwargs: Keyword arguments passed through to
                self._create_job.

        @return The only HostQueueEntry fetched for the new job.
        """
        job = self._create_job(**create_job_kwargs)
        for label in dependency_labels:
            job.dependency_labels.add(label)
        hqes = list(scheduler_models.HostQueueEntry.fetch(
                where='job_id=%d' % job.id))
        self.assertEqual(1, len(hqes))
        return hqes[0]

    def _check_hqe_labels(self, hqe, expected_labels):
        """Assert that hqe.get_labels() yields exactly the expected names.

        @param hqe: HostQueueEntry whose labels are checked.
        @param expected_labels: Iterable of expected label names.
        """
        expected_labels = set(expected_labels)
        label_names = set(label.name for label in hqe.get_labels())
        self.assertEqual(expected_labels, label_names)

    def test_get_labels_empty(self):
        hqe = self._create_hqe(hosts=[1])
        labels = list(hqe.get_labels())
        self.assertEqual([], labels)

    def test_get_labels_metahost(self):
        hqe = self._create_hqe(metahosts=[2])
        self._check_hqe_labels(hqe, ['label2'])

    def test_get_labels_dependancies(self):
        hqe = self._create_hqe(dependency_labels=(self.label3, self.label4),
                               metahosts=[1])
        self._check_hqe_labels(hqe, ['label1', 'label3', 'label4'])

    def setup_abort_test(self, agent_finished=True):
        """Setup the variables for testing abort method.

        @param agent_finished: True to mock agent is finished before aborting
                               the hqe.
        @return hqe, dispatcher: Mock object of hqe and dispatcher to be used
                               to test abort method.
        """
        hqe = self._create_hqe(hosts=[1])
        hqe.aborted = True
        hqe.complete = False
        hqe.status = models.HostQueueEntry.Status.STARTING
        hqe.started_on = datetime.datetime.now()

        dispatcher = self.god.create_mock_class(monitor_db.BaseDispatcher,
                                                'BaseDispatcher')
        agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
        dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
        agent.is_done.expect_call().and_return(agent_finished)
        return hqe, dispatcher

    def test_abort_fail_with_unfinished_agent(self):
        """abort should fail if the hqe still has agent not finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=False)
        self.assertIsNone(hqe.finished_on)
        with self.assertRaises(AssertionError):
            hqe.abort(dispatcher)
        self.god.check_playback()
        # abort failed, finished_on should not be set
        self.assertIsNone(hqe.finished_on)

    def test_abort_success(self):
        """abort should succeed if all agents for the hqe are finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=True)
        self.assertIsNone(hqe.finished_on)
        hqe.abort(dispatcher)
        self.god.check_playback()
        self.assertIsNotNone(hqe.finished_on)

    def test_set_finished_on(self):
        """Test that finished_on is set when hqe completes."""
        for status in host_queue_entry_states.Status.values:
            hqe = self._create_hqe(hosts=[1])
            hqe.started_on = datetime.datetime.now()
            hqe.job.update_field('shard_id', 3)
            self.assertIsNone(hqe.finished_on)
            hqe.set_status(status)
            if status in host_queue_entry_states.COMPLETE_STATUSES:
                # Completing statuses stamp finished_on and clear shard_id.
                self.assertIsNotNone(hqe.finished_on)
                self.assertIsNone(hqe.job.shard_id)
            else:
                self.assertIsNone(hqe.finished_on)
                self.assertEqual(hqe.job.shard_id, 3)


class JobTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Job and its pre-job special tasks."""

    def setUp(self):
        super(JobTest, self).setUp()
        # Initialized here so _mock_create never hits a missing attribute if
        # a SpecialTask is created before _test_pre_job_tasks_helper runs.
        self._tasks = []

        def _mock_create(**kwargs):
            # Stand-in for SpecialTask.objects.create that also records
            # each task it saves in self._tasks.
            task = models.SpecialTask(**kwargs)
            task.save()
            self._tasks.append(task)

        self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)

    def _test_pre_job_tasks_helper(
            self, reboot_before=model_attributes.RebootBefore.ALWAYS):
        """
        Calls HQE._do_schedule_pre_job_tasks() and returns the created special
        task

        @param reboot_before: Value assigned to the job's reboot_before
                attribute before scheduling.

        @return List of SpecialTask objects recorded by the stubbed create.
        """
        self._tasks = []
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        queue_entry.job.reboot_before = reboot_before
        queue_entry._do_schedule_pre_job_tasks()
        return self._tasks

    def test_job_request_abort(self):
        django_job = self._create_job(hosts=[5, 6], atomic_group=1)
        job = scheduler_models.Job(django_job.id)
        job.request_abort()
        django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
        for hqe in django_hqes:
            self.assertTrue(hqe.aborted)

    def test__atomic_and_has_started__on_atomic(self):
        self._create_job(hosts=[5, 6], atomic_group=1)
        job = scheduler_models.Job.fetch('id = 1')[0]
        self.assertFalse(job._atomic_and_has_started())

        self._update_hqe("status='Pending'")
        self.assertFalse(job._atomic_and_has_started())
        self._update_hqe("status='Verifying'")
        self.assertFalse(job._atomic_and_has_started())
        self.assertFalse(job._atomic_and_has_started())
        self._update_hqe("status='Failed'")
        self.assertFalse(job._atomic_and_has_started())
        self._update_hqe("status='Stopped'")
        self.assertFalse(job._atomic_and_has_started())

        self._update_hqe("status='Starting'")
        self.assertTrue(job._atomic_and_has_started())
        self._update_hqe("status='Completed'")
        self.assertTrue(job._atomic_and_has_started())
        # NOTE(review): no assertion follows this update; the expected result
        # for 'Aborted' is not visible here -- confirm and add one.
        self._update_hqe("status='Aborted'")

    def test__atomic_and_has_started__not_atomic(self):
        self._create_job(hosts=[1, 2])
        job = scheduler_models.Job.fetch('id = 1')[0]
        self.assertFalse(job._atomic_and_has_started())
        self._update_hqe("status='Starting'")
        self.assertFalse(job._atomic_and_has_started())

    def _check_special_tasks(self, tasks, task_types):
        """Assert that tasks match the expected types, host and queue entry.

        @param tasks: List of SpecialTask objects that were created.
        @param task_types: List of (task_type, queue_entry_id) tuples in the
                same order as tasks. A falsy queue_entry_id skips the queue
                entry check.
        """
        self.assertEqual(len(tasks), len(task_types))
        for task, (task_type, queue_entry_id) in zip(tasks, task_types):
            self.assertEqual(task.task, task_type)
            self.assertEqual(task.host.id, 1)
            if queue_entry_id:
                self.assertEqual(task.queue_entry.id, queue_entry_id)

    def test_run_asynchronous(self):
        self._create_job(hosts=[1, 2])

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])

    def test_run_asynchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])

    def test_run_synchronous_verify(self):
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])

    def test_run_synchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])

    def test_run_asynchronous_do_not_reset(self):
        # NOTE(review): a later method of the same name used to replace this
        # one, so it never ran. Confirm it passes against the current
        # scheduler_models now that it is reachable.
        job = self._create_job(hosts=[1, 2])
        job.run_reset = False
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self.assertEqual(tasks, [])

    def test_run_synchronous_do_not_reset_no_RebootBefore(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])

    def test_run_asynchronous_do_not_reset_no_RebootBefore(self):
        # Renamed from test_run_asynchronous_do_not_reset, which collided
        # with the test of that name above and hid it.
        job = self._create_job(hosts=[1, 2], synchronous=False)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])

    def test_run_atomic_group_already_started(self):
        self._create_job(hosts=[5, 6], atomic_group=1, synchronous=True)
        self._update_hqe("status='Starting', execution_subdir=''")

        job = scheduler_models.Job.fetch('id = 1')[0]
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        # A unittest assertion rather than a bare assert, which is stripped
        # when Python runs with -O.
        self.assertIs(queue_entry.job, job)
        self.assertEqual(None, job.run(queue_entry))

        self.god.check_playback()

    def test_reboot_before_always(self):
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.ALWAYS
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [
                (models.SpecialTask.Task.RESET, None)
            ])

    def _test_reboot_before_if_dirty_helper(self):
        """Create an IF_DIRTY job on host 1 and expect one RESET task."""
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
        job.save()

        tasks = self._test_pre_job_tasks_helper()
        task_types = [(models.SpecialTask.Task.RESET, None)]

        self._check_special_tasks(tasks, task_types)

    def test_reboot_before_if_dirty(self):
        models.Host.smart_get(1).update_object(dirty=True)
        self._test_reboot_before_if_dirty_helper()

    def test_reboot_before_not_dirty(self):
        models.Host.smart_get(1).update_object(dirty=False)
        self._test_reboot_before_if_dirty_helper()

    def test_next_group_name(self):
        django_job = self._create_job(metahosts=[1])
        job = scheduler_models.Job(id=django_job.id)
        self.assertEqual('group0', job._next_group_name())

        for hqe in django_job.hostqueueentry_set.filter():
            hqe.execution_subdir = 'my_rack.group0'
            hqe.save()
        self.assertEqual('my_rack.group1', job._next_group_name('my/rack'))


if __name__ == '__main__':
    unittest.main()