/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H
#define __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// NOTE(review): these using-directives leak into every file that includes this
// header. They are kept only because existing includers may rely on
// unqualified std / std::chrono names; new code should not depend on them.
using namespace ::std;
using namespace ::std::chrono;

// Callback name used when a caller does not supply one.
constexpr char kVtsHalHidlTargetCallbackDefaultName[] =
    "VtsHalHidlTargetCallbackDefaultName";
// Initial value of the default wait timeout (see SetWaitTimeoutDefault).
constexpr milliseconds DEFAULT_CALLBACK_WAIT_TIMEOUT_INITIAL = minutes(1);

namespace testing {

/*
 * VTS target side test template for callback.
 *
 * Provides wait and notify functionality for callbacks.
 *
 * A typical usage looks like this:
 *
 * class CallbackArgs {
 *   ArgType1 arg1;
 *   ArgType2 arg2;
 * }
 *
 * class MyCallback
 *     : public ::testing::VtsHalHidlTargetCallbackBase<CallbackArgs>,
 *       public CallbackInterface {
 *  public:
 *   CallbackApi1(ArgType1 arg1) {
 *     CallbackArgs data;
 *     data.arg1 = arg1;
 *     NotifyFromCallback("CallbackApi1", data);
 *   }
 *
 *   CallbackApi2(ArgType2 arg2) {
 *     CallbackArgs data;
 *     data.arg2 = arg2;
 *     NotifyFromCallback("CallbackApi2", data);
 *   }
 * }
 *
 * Test(MyTest) {
 *   CallApi1();
 *   CallApi2();
 *   auto result = cb_.WaitForCallback("CallbackApi1");
 *   // cb_ is an instance of MyCallback; result is an instance of
 *   // ::testing::VtsHalHidlTargetCallbackBase::WaitForCallbackResult
 *   EXPECT_TRUE(result.no_timeout);  // Check wait did not time out
 *   EXPECT_TRUE(result.args);  // Check CallbackArgs is received (not nullptr).
 *                              // This is optional.
 *   // Here check value of args using the pointer result.args;
 *   result = cb_.WaitForCallback("CallbackApi2");
 *   EXPECT_TRUE(result.no_timeout);
 *   // Here check value of args using the pointer result.args;
 *
 *   // Additionally, a test can wait for one of multiple callbacks.
 *   // In this case, wait will return when any of the callbacks in the provided
 *   // name list is called.
 *   result = cb_.WaitForCallbackAny(<vector_of_string>)
 *   // When vector_of_string is not provided, all callback functions will
 *   // be monitored. The name of the callback function that was invoked
 *   // is stored in result.name
 * }
 *
 * Note that CallbackArgsTemplateClass is a single type for the whole class,
 * so every WaitForCallback method returns the same argument data type.
 *
 * Concurrency: notifications and waits are mutex-protected and may come from
 * different threads. The timeout setters (SetWaitTimeout,
 * SetWaitTimeoutDefault) take no lock; call them before waits start.
 */
template <class CallbackArgsTemplateClass>
class VtsHalHidlTargetCallbackBase {
 public:
  struct WaitForCallbackResult {
    WaitForCallbackResult()
        : no_timeout(false),
          args(shared_ptr<CallbackArgsTemplateClass>(nullptr)),
          name("") {}

    // True when a callback was received, false when the wait timed out.
    bool no_timeout;
    // Arguments data from callback functions. Defaults to nullptr.
    shared_ptr<CallbackArgsTemplateClass> args;
    // Name of the callback. Defaults to empty string.
    string name;
  };

  VtsHalHidlTargetCallbackBase()
      : cb_default_wait_timeout_(DEFAULT_CALLBACK_WAIT_TIMEOUT_INITIAL) {}

  // CallbackLock objects are owned through shared_ptr in cb_lock_map_, so
  // the defaulted destructor releases them.
  virtual ~VtsHalHidlTargetCallbackBase() = default;

  /*
   * Wait for a callback function in a test.
   * Returns a WaitForCallbackResult object containing wait results.
   * If callback_function_name is not provided, a default name will be used.
   * Timeout defaults to -1 milliseconds. A negative timeout means the timeout
   * set for the callback (or, if none, the default callback wait timeout)
   * is used.
   */
  WaitForCallbackResult WaitForCallback(
      const string& callback_function_name = kVtsHalHidlTargetCallbackDefaultName,
      milliseconds timeout = milliseconds(-1)) {
    // The returned shared_ptr keeps the lock alive for the whole wait even if
    // ClearForCallback() replaces it concurrently.
    return GetCallbackLock(callback_function_name)->WaitForCallback(timeout);
  }

  /*
   * Wait for any of the callback functions specified.
   * Returns a WaitForCallbackResult object containing wait results.
   * If callback_function_names is not provided, all callback functions will
   * be monitored, and the list of callback functions will be updated
   * dynamically during run time.
   * If timeout_any is not provided (negative), the shortest timeout from the
   * function list is used; if no callback is known yet, the default wait
   * timeout is used.
   */
  WaitForCallbackResult WaitForCallbackAny(
      const vector<string>& callback_function_names = vector<string>(),
      milliseconds timeout_any = milliseconds(-1)) {
    unique_lock<mutex> lock(cb_wait_any_mtx_);
    const auto start_time = steady_clock::now();
    WaitForCallbackResult res = PeekCallbackLocks(callback_function_names);
    while (!res.no_timeout) {
      // Recomputed every iteration: when monitoring all callbacks, newly
      // registered callbacks can change the shortest timeout.
      const auto expiration =
          GetWaitAnyTimeout(callback_function_names, start_time, timeout_any);
      const auto status = cb_wait_any_cv_.wait_until(lock, expiration);
      // Peek even after a timeout so a notification that raced the deadline
      // is not lost.
      res = PeekCallbackLocks(callback_function_names);
      if (status == cv_status::timeout && !res.no_timeout) {
        cerr << "Timed out waiting for callback functions." << endl;
        break;
      }
    }
    return res;
  }

  /*
   * Notify a waiting test when a callback is invoked.
   * If callback_function_name is not provided, a default name will be used.
   */
  void NotifyFromCallback(
      const string& callback_function_name = kVtsHalHidlTargetCallbackDefaultName) {
    unique_lock<mutex> lock(cb_wait_any_mtx_);
    GetCallbackLock(callback_function_name)->NotifyFromCallback();
    // notify_all: concurrent WaitForCallbackAny callers may watch different
    // name lists; waking only one could pick an uninterested waiter.
    cb_wait_any_cv_.notify_all();
  }

  /*
   * Notify a waiting test with data when a callback is invoked.
   * The default callback name is used.
   */
  void NotifyFromCallback(const CallbackArgsTemplateClass& data) {
    NotifyFromCallback(kVtsHalHidlTargetCallbackDefaultName, data);
  }

  /*
   * Notify a waiting test with data when a callback is invoked.
   * The data is copied and queued for the next successful wait on this name.
   */
  void NotifyFromCallback(const string& callback_function_name,
                          const CallbackArgsTemplateClass& data) {
    unique_lock<mutex> lock(cb_wait_any_mtx_);
    GetCallbackLock(callback_function_name)->NotifyFromCallback(data);
    cb_wait_any_cv_.notify_all();
  }

  /*
   * Clear lock and data for a callback function. This function is optional.
   * The callback's state is replaced by a fresh one: pending notifications
   * and queued arguments are discarded, and any timeout set with
   * SetWaitTimeout for this name is reset. A thread already blocked waiting
   * on the old state keeps waiting on it and will not observe later
   * notifications.
   */
  void ClearForCallback(
      const string& callback_function_name = kVtsHalHidlTargetCallbackDefaultName) {
    GetCallbackLock(callback_function_name, true);
  }

  /*
   * Get wait timeout for a specific callback function.
   * If callback_function_name is not provided, a default name will be used.
   */
  milliseconds GetWaitTimeout(
      const string& callback_function_name = kVtsHalHidlTargetCallbackDefaultName) {
    return GetCallbackLock(callback_function_name)->GetWaitTimeout();
  }

  /*
   * Set wait timeout for a specific callback function.
   * To change the timeout used by every callback that has no explicit
   * timeout (including the default callback name), use SetWaitTimeoutDefault.
   */
  void SetWaitTimeout(const string& callback_function_name,
                      milliseconds timeout) {
    GetCallbackLock(callback_function_name)->SetWaitTimeout(timeout);
  }

  /*
   * Get default wait timeout for a callback function.
   * The default timeout applies to all callback function names that have not
   * been given a timeout value, including the default function name.
   */
  milliseconds GetWaitTimeoutDefault() const { return cb_default_wait_timeout_; }

  /*
   * Set default wait timeout for a callback function.
   * The default timeout applies to all callback function names that have not
   * been given a timeout value, including the default function name.
   */
  void SetWaitTimeoutDefault(milliseconds timeout) {
    cb_default_wait_timeout_ = timeout;
  }

 private:
  /*
   * A utility class to store the wait count and argument data for one
   * callback name.
   */
  class CallbackLock {
   public:
    CallbackLock(VtsHalHidlTargetCallbackBase& parent, const string& name)
        : wait_count_(0),
          parent_(parent),
          timeout_(milliseconds(-1)),
          name_(name) {}

    /*
     * Wait for the represented callback function.
     * Timeout defaults to -1 milliseconds. A negative timeout means the
     * timeout set for this callback (or the default timeout) is used.
     * When no_wait_blocking is true, returns immediately.
     */
    WaitForCallbackResult WaitForCallback(
        milliseconds timeout = milliseconds(-1), bool no_wait_blocking = false) {
      return Wait(timeout, no_wait_blocking);
    }

    /*
     * Wait for the represented callback function using its configured
     * timeout. When no_wait_blocking is true, returns immediately.
     */
    WaitForCallbackResult WaitForCallback(bool no_wait_blocking) {
      return Wait(milliseconds(-1), no_wait_blocking);
    }

    /* Notify from represented callback function. */
    void NotifyFromCallback() {
      unique_lock<mutex> lock(wait_mtx_);
      Notify();
    }

    /* Notify from represented callback function with data. */
    void NotifyFromCallback(const CallbackArgsTemplateClass& data) {
      unique_lock<mutex> wait_lock(wait_mtx_);
      arg_data_.push(make_shared<CallbackArgsTemplateClass>(data));
      Notify();
    }

    /* Set wait timeout for represented callback function (unsynchronized). */
    void SetWaitTimeout(milliseconds timeout) { timeout_ = timeout; }

    /*
     * Get wait timeout for represented callback function; falls back to the
     * parent's default timeout when none was set (negative).
     */
    milliseconds GetWaitTimeout() {
      if (timeout_ < milliseconds(0)) {
        return parent_.GetWaitTimeoutDefault();
      }
      return timeout_;
    }

   private:
    /*
     * Wait for the represented callback function in a test.
     * Returns a WaitForCallbackResult; no_timeout is true when a notification
     * was consumed. A negative timeout means the configured timeout is used.
     * NOTE: arguments are queued separately from the notification count, so
     * mixing data-less and data-carrying notifications for one name can pair
     * queued args with a different wait than the one that notified.
     */
    WaitForCallbackResult Wait(milliseconds timeout, bool no_wait_blocking) {
      unique_lock<mutex> lock(wait_mtx_);
      WaitForCallbackResult res;
      res.name = name_;
      if (no_wait_blocking) {
        if (wait_count_ == 0) {
          return res;
        }
      } else {
        if (timeout < milliseconds(0)) {
          timeout = GetWaitTimeout();
        }
        // Monotonic clock so wall-clock adjustments do not shorten or
        // stretch the wait.
        const auto expiration = steady_clock::now() + timeout;
        // The predicate overload re-checks wait_count_ after a timeout, so a
        // notification racing with the deadline is still consumed.
        const bool notified = wait_cv_.wait_until(
            lock, expiration, [this] { return wait_count_ > 0; });
        if (!notified) {
          cerr << "Timed out waiting for callback" << endl;
          return res;
        }
      }
      --wait_count_;
      res.no_timeout = true;
      if (!arg_data_.empty()) {
        res.args = arg_data_.front();
        arg_data_.pop();
      }
      return res;
    }

    /* Record one notification and wake one waiter; caller holds wait_mtx_. */
    void Notify() {
      wait_count_++;
      wait_cv_.notify_one();
    }

    // Mutex protecting wait_count_, arg_data_ and the condition variable.
    mutex wait_mtx_;
    // Condition variable for callback wait and notify.
    condition_variable wait_cv_;
    // Number of notifications not yet consumed by a wait.
    unsigned int wait_count_;
    // A queue of callback arg data.
    queue<shared_ptr<CallbackArgsTemplateClass>> arg_data_;
    // Reference to the owning VtsHalHidlTargetCallbackBase.
    VtsHalHidlTargetCallbackBase& parent_;
    // Wait timeout; negative means "use the parent's default".
    milliseconds timeout_;
    // Name of the represented callback function.
    string name_;
  };

  /*
   * Get the CallbackLock for a callback function name, creating it if the
   * name is not in the map yet. If auto_clear is true, an existing entry is
   * replaced by a fresh CallbackLock. The old object stays alive for as long
   * as any other thread still holds a shared_ptr to it.
   */
  shared_ptr<CallbackLock> GetCallbackLock(const string& callback_function_name,
                                           bool auto_clear = false) {
    unique_lock<mutex> lock(cb_lock_map_mtx_);
    auto found = cb_lock_map_.find(callback_function_name);
    if (found == cb_lock_map_.end()) {
      auto created = make_shared<CallbackLock>(*this, callback_function_name);
      cb_lock_map_.emplace(callback_function_name, created);
      return created;
    }
    if (auto_clear) {
      found->second = make_shared<CallbackLock>(*this, callback_function_name);
    }
    return found->second;
  }

  /*
   * Get the wait deadline for a list of function names.
   * If timeout_any is not negative, start_time + timeout_any is returned.
   * Otherwise the shortest per-callback timeout from the list is used, or the
   * default timeout when the list resolves to no callbacks at all.
   */
  steady_clock::time_point GetWaitAnyTimeout(
      const vector<string>& callback_function_names,
      steady_clock::time_point start_time, milliseconds timeout_any) {
    if (timeout_any >= milliseconds(0)) {
      return start_time + timeout_any;
    }
    const auto locks = GetWaitAnyCallbackLocks(callback_function_names);
    if (locks.empty()) {
      // Only possible when monitoring all callbacks before any has been
      // registered; avoid adding an unbounded duration to start_time.
      return start_time + GetWaitTimeoutDefault();
    }
    milliseconds timeout_min = locks.front()->GetWaitTimeout();
    for (const auto& cb_lock : locks) {
      const milliseconds timeout = cb_lock->GetWaitTimeout();
      if (timeout < timeout_min) {
        timeout_min = timeout;
      }
    }
    return start_time + timeout_min;
  }

  /*
   * Get the CallbackLocks for the provided function name list, or for every
   * known callback when the list is empty.
   */
  vector<shared_ptr<CallbackLock>> GetWaitAnyCallbackLocks(
      const vector<string>& callback_function_names) {
    vector<shared_ptr<CallbackLock>> res;
    if (callback_function_names.empty()) {
      // Snapshot under the map mutex: other threads may insert or replace
      // entries concurrently through GetCallbackLock().
      unique_lock<mutex> lock(cb_lock_map_mtx_);
      res.reserve(cb_lock_map_.size());
      for (const auto& entry : cb_lock_map_) {
        res.push_back(entry.second);
      }
    } else {
      res.reserve(callback_function_names.size());
      for (const auto& name : callback_function_names) {
        res.push_back(GetCallbackLock(name));
      }
    }
    return res;
  }

  /*
   * Check, without blocking, whether any of the callback functions has a
   * pending notification; consumes and returns the first one found.
   */
  WaitForCallbackResult PeekCallbackLocks(
      const vector<string>& callback_function_names) {
    for (const auto& cb_lock : GetWaitAnyCallbackLocks(callback_function_names)) {
      auto test = cb_lock->WaitForCallback(true);
      if (test.no_timeout) {
        return test;
      }
    }
    return WaitForCallbackResult();
  }

  // Map from callback function name to its CallbackLock. Entries are held
  // by shared_ptr so a lock replaced by ClearForCallback() stays valid for
  // threads still using it.
  unordered_map<string, shared_ptr<CallbackLock>> cb_lock_map_;
  // Mutex for protecting operations on lock map.
  mutex cb_lock_map_mtx_;
  // Mutex for protecting waiting any callback.
  mutex cb_wait_any_mtx_;
  // Default wait timeout.
  milliseconds cb_default_wait_timeout_;
  // Conditional variable for any callback notify.
  condition_variable cb_wait_any_cv_;
};

}  // namespace testing

#endif  // __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H