/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H
#define __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// NOTE(review): using-directives at header scope leak into every includer.
// They are kept because existing tests that include this header may rely on
// unqualified std / std::chrono names.
using namespace ::std;
using namespace ::std::chrono;

// Callback name used whenever a caller does not supply one.
constexpr char kVtsHalHidlTargetCallbackDefaultName[] =
    "VtsHalHidlTargetCallbackDefaultName";
// Initial value of the per-instance default wait timeout.
constexpr milliseconds DEFAULT_CALLBACK_WAIT_TIMEOUT_INITIAL = minutes(1);

namespace testing {

/*
 * VTS target side test template for callback.
 *
 * Providing wait and notify for callback functionality.
 *
 * A typical usage looks like this:
 *
 * struct CallbackArgs {
 *   ArgType1 arg1;
 *   ArgType2 arg2;
 * };
 *
 * class MyCallback
 *     : public ::testing::VtsHalHidlTargetCallbackBase<CallbackArgs>,
 *       public CallbackInterface {
 *  public:
 *   CallbackApi1(ArgType1 arg1) {
 *     CallbackArgs data;
 *     data.arg1 = arg1;
 *     NotifyFromCallback("CallbackApi1", data);
 *   }
 *
 *   CallbackApi2(ArgType2 arg2) {
 *     CallbackArgs data;
 *     data.arg2 = arg2;
 *     NotifyFromCallback("CallbackApi2", data);
 *   }
 * };
 *
 * Test(MyTest) {
 *   CallApi1();
 *   CallApi2();
 *   auto result = cb_.WaitForCallback("CallbackApi1");
 *   // cb_ is an instance of MyCallback, result is an instance of
 *   // ::testing::VtsHalHidlTargetCallbackBase<CallbackArgs>::
 *   //     WaitForCallbackResult
 *   EXPECT_TRUE(result.no_timeout); // Check wait did not time out
 *   EXPECT_TRUE(result.args); // Check CallbackArgs is received (not
 *                             // nullptr). This is optional.
 *   // Here check value of args using the pointer result.args;
 *   result = cb_.WaitForCallback("CallbackApi2");
 *   EXPECT_TRUE(result.no_timeout);
 *   // Here check value of args using the pointer result.args;
 *
 *   // Additionally, a test can wait for one of multiple callbacks.
 *   // In this case, wait will return when any of the callbacks in the provided
 *   // name list is called.
 *   result = cb_.WaitForCallbackAny(<vector_of_string>)
 *   // When vector_of_string is not provided, all callback functions will
 *   // be monitored. The name of callback function that was invoked
 *   // is stored in result.name
 * }
 *
 * Note the type of CallbackArgsTemplateClass is the same across the class,
 * which means all WaitForCallback methods return the same data type.
 */
template <class CallbackArgsTemplateClass>
class VtsHalHidlTargetCallbackBase {
 public:
  struct WaitForCallbackResult {
    WaitForCallbackResult() : no_timeout(false), args(nullptr), name() {}

    // True when a callback notification was consumed, i.e. the wait did NOT
    // time out. False when the wait timed out (or nothing was pending).
    bool no_timeout;
    // Arguments data from callback functions. Defaults to nullptr.
    shared_ptr<CallbackArgsTemplateClass> args;
    // Name of the callback. Defaults to empty string.
    string name;
  };

  VtsHalHidlTargetCallbackBase()
      : cb_default_wait_timeout_(DEFAULT_CALLBACK_WAIT_TIMEOUT_INITIAL) {}

  // CallbackLock objects are owned through shared_ptr, so no manual cleanup
  // is required here.
  virtual ~VtsHalHidlTargetCallbackBase() = default;

  /*
   * Wait for a callback function in a test.
   * Returns a WaitForCallbackResult object containing wait results.
   * If callback_function_name is not provided, a default name will be used.
   * Timeout defaults to -1 milliseconds. A negative timeout means: use the
   * timeout set for the callback, or the default callback wait timeout.
   */
  WaitForCallbackResult WaitForCallback(
      const string& callback_function_name =
          kVtsHalHidlTargetCallbackDefaultName,
      milliseconds timeout = milliseconds(-1)) {
    // The temporary shared_ptr keeps the lock object alive for the whole
    // wait, even if ClearForCallback() replaces the map entry meanwhile.
    return GetCallbackLock(callback_function_name)->WaitForCallback(timeout);
  }

  /*
   * Wait for any of the callback functions specified.
   * Returns a WaitForCallbackResult object containing wait results.
   * If callback_function_names is not provided, all callback functions will
   * be monitored, and the list of callback functions will be updated
   * dynamically during run time.
   * If timeout_any is not provided (negative), the shortest timeout from the
   * function list will be used; when no callback function is registered yet,
   * the default wait timeout is used.
   */
  WaitForCallbackResult WaitForCallbackAny(
      const vector<string>& callback_function_names = vector<string>(),
      milliseconds timeout_any = milliseconds(-1)) {
    unique_lock<mutex> lock(cb_wait_any_mtx_);

    // A monotonic clock keeps the deadline immune to wall-clock adjustments.
    const auto start_time = steady_clock::now();

    WaitForCallbackResult res = PeekCallbackLocks(callback_function_names);
    while (!res.no_timeout) {
      // Recomputed on every pass because the set of monitored callbacks (and
      // therefore the shortest timeout) may change while waiting.
      const auto expiration =
          GetWaitAnyTimeout(callback_function_names, start_time, timeout_any);
      const bool timed_out =
          cb_wait_any_cv_.wait_until(lock, expiration) == cv_status::timeout;
      // Peek once more even after a timeout: a notification may have been
      // posted while this thread was re-acquiring the mutex.
      res = PeekCallbackLocks(callback_function_names);
      if (timed_out && !res.no_timeout) {
        cerr << "Timed out waiting for callback functions." << endl;
        break;
      }
    }
    return res;
  }

  /*
   * Notify a waiting test when a callback is invoked.
   * If callback_function_name is not provided, a default name will be used.
   */
  void NotifyFromCallback(const string& callback_function_name =
                              kVtsHalHidlTargetCallbackDefaultName) {
    lock_guard<mutex> lock(cb_wait_any_mtx_);
    GetCallbackLock(callback_function_name)->NotifyFromCallback();
    // Wake every WaitForCallbackAny() waiter: they may monitor different
    // name lists, so waking only one could pick an uninterested waiter.
    cb_wait_any_cv_.notify_all();
  }

  /*
   * Notify a waiting test with data when a callback is invoked.
   * The default callback function name is used.
   */
  void NotifyFromCallback(const CallbackArgsTemplateClass& data) {
    NotifyFromCallback(kVtsHalHidlTargetCallbackDefaultName, data);
  }

  /*
   * Notify a waiting test with data when a callback is invoked.
   * A copy of data is queued and handed to the waiter via
   * WaitForCallbackResult::args.
   */
  void NotifyFromCallback(const string& callback_function_name,
                          const CallbackArgsTemplateClass& data) {
    lock_guard<mutex> lock(cb_wait_any_mtx_);
    GetCallbackLock(callback_function_name)->NotifyFromCallback(data);
    // See the comment in the data-less overload for why notify_all is used.
    cb_wait_any_cv_.notify_all();
  }

  /*
   * Clear lock and data for a callback function.
   * Pending notifications, queued data and the per-callback timeout are
   * discarded by replacing the lock object with a fresh one.
   * This function is optional.
   */
  void ClearForCallback(const string& callback_function_name =
                            kVtsHalHidlTargetCallbackDefaultName) {
    GetCallbackLock(callback_function_name, true);
  }

  /*
   * Get wait timeout for a specific callback function.
   * If callback_function_name is not provided, a default name will be used.
   */
  milliseconds GetWaitTimeout(const string& callback_function_name =
                                  kVtsHalHidlTargetCallbackDefaultName) {
    return GetCallbackLock(callback_function_name)->GetWaitTimeout();
  }

  /*
   * Set wait timeout for a specific callback function.
   * To set a default timeout (not for the default function name),
   * use SetWaitTimeoutDefault. default function name callback timeout will
   * also be set by SetWaitTimeoutDefault.
   */
  void SetWaitTimeout(const string& callback_function_name,
                      milliseconds timeout) {
    GetCallbackLock(callback_function_name)->SetWaitTimeout(timeout);
  }

  /*
   * Get default wait timeout for a callback function.
   * The default timeout is valid for all callback function names that
   * have not been specified a timeout value, including default function name.
   */
  milliseconds GetWaitTimeoutDefault() { return cb_default_wait_timeout_; }

  /*
   * Set default wait timeout for a callback function.
   * The default timeout is valid for all callback function names that
   * have not been specified a timeout value, including default function name.
   */
  void SetWaitTimeoutDefault(milliseconds timeout) {
    cb_default_wait_timeout_ = timeout;
  }

 private:
  /*
   * A utility class to store the notification count, condition variable and
   * queued data for one callback name.
   */
  class CallbackLock {
   public:
    CallbackLock(VtsHalHidlTargetCallbackBase& parent, const string& name)
        : wait_count_(0),
          parent_(parent),
          timeout_(milliseconds(-1)),
          name_(name) {}

    /*
     * Block until the represented callback has been notified or the timeout
     * expires, then consume one notification.
     * A negative timeout means: use the timeout set for this callback, or the
     * parent's default wait timeout.
     */
    WaitForCallbackResult WaitForCallback(
        milliseconds timeout = milliseconds(-1)) {
      unique_lock<mutex> lock(wait_mtx_);
      WaitForCallbackResult res;
      res.name = name_;
      if (timeout < milliseconds(0)) {
        timeout = GetWaitTimeout();
      }
      const auto expiration = steady_clock::now() + timeout;
      // The predicate form re-checks the count after a timeout, so a
      // notification arriving right at the deadline is not reported as a
      // timeout; it also handles spurious wake-ups.
      const bool notified = wait_cv_.wait_until(
          lock, expiration, [this] { return wait_count_ != 0; });
      if (!notified) {
        cerr << "Timed out waiting for callback" << endl;
        return res;
      }
      Consume(&res);
      return res;
    }

    /*
     * Non-blocking check: consume one pending notification if there is one.
     * The returned result has no_timeout == false when nothing was pending.
     */
    WaitForCallbackResult PeekCallback() {
      lock_guard<mutex> lock(wait_mtx_);
      WaitForCallbackResult res;
      res.name = name_;
      if (wait_count_ != 0) {
        Consume(&res);
      }
      return res;
    }

    /* Notify from represented callback function. */
    void NotifyFromCallback() {
      lock_guard<mutex> lock(wait_mtx_);
      Notify();
    }

    /* Notify from represented callback function with data. */
    void NotifyFromCallback(const CallbackArgsTemplateClass& data) {
      lock_guard<mutex> lock(wait_mtx_);
      arg_data_.push(make_shared<CallbackArgsTemplateClass>(data));
      Notify();
    }

    /* Set wait timeout for represented callback function. */
    void SetWaitTimeout(milliseconds timeout) { timeout_ = timeout; }

    /*
     * Get wait timeout for represented callback function.
     * A negative stored timeout means "not set"; the parent's default is
     * returned in that case.
     */
    milliseconds GetWaitTimeout() {
      if (timeout_ < milliseconds(0)) {
        return parent_.GetWaitTimeoutDefault();
      }
      return timeout_;
    }

   private:
    /*
     * Consume one notification into res. Must be called with wait_mtx_ held
     * and wait_count_ > 0.
     */
    void Consume(WaitForCallbackResult* res) {
      --wait_count_;
      res->no_timeout = true;
      // Notifications posted without data leave the queue untouched, so the
      // queue can be empty even though a notification is pending.
      if (!arg_data_.empty()) {
        res->args = std::move(arg_data_.front());
        arg_data_.pop();
      }
    }

    /* Record one notification and wake a waiter. wait_mtx_ must be held. */
    void Notify() {
      ++wait_count_;
      wait_cv_.notify_one();
    }

    // Mutex for protecting operations on wait count and conditional variable
    mutex wait_mtx_;
    // Conditional variable for callback wait and notify
    condition_variable wait_cv_;
    // Number of notifications not yet consumed by a waiter
    unsigned int wait_count_;
    // A queue of callback arg data
    queue<shared_ptr<CallbackArgsTemplateClass>> arg_data_;
    // Reference to parent class
    VtsHalHidlTargetCallbackBase& parent_;
    // Wait time out; negative means "use the parent's default"
    milliseconds timeout_;
    // Name of the represented callback function
    string name_;
  };

  /*
   * Get CallbackLock object using callback function name.
   * If callback_function_name does not exist in the map yet, a new
   * CallbackLock object will be created.
   * If auto_clear is true, the old CallbackLock is replaced by a fresh one.
   * Threads still holding the old object keep it alive through their
   * shared_ptr and simply stop receiving notifications.
   */
  shared_ptr<CallbackLock> GetCallbackLock(const string& callback_function_name,
                                           bool auto_clear = false) {
    lock_guard<mutex> lock(cb_lock_map_mtx_);
    shared_ptr<CallbackLock>& slot = cb_lock_map_[callback_function_name];
    if (!slot || auto_clear) {
      slot = make_shared<CallbackLock>(*this, callback_function_name);
    }
    return slot;
  }

  /*
   * Get the wait deadline for a list of function names.
   * If timeout_any is not negative, start_time + timeout_any will be returned.
   * Otherwise, start_time plus the shortest timeout from the list will be
   * returned. When there is no callback lock to inspect, the default wait
   * timeout is used so that the deadline computation cannot overflow.
   */
  steady_clock::time_point GetWaitAnyTimeout(
      const vector<string>& callback_function_names,
      steady_clock::time_point start_time, milliseconds timeout_any) {
    if (timeout_any >= milliseconds(0)) {
      return start_time + timeout_any;
    }

    const auto locks = GetWaitAnyCallbackLocks(callback_function_names);
    if (locks.empty()) {
      return start_time + GetWaitTimeoutDefault();
    }

    milliseconds timeout_min = locks.front()->GetWaitTimeout();
    for (const auto& lock : locks) {
      const milliseconds timeout = lock->GetWaitTimeout();
      if (timeout < timeout_min) {
        timeout_min = timeout;
      }
    }

    return start_time + timeout_min;
  }

  /*
   * Get a list of CallbackLock pointers from provided function name list.
   * An empty name list selects every callback currently registered.
   */
  vector<shared_ptr<CallbackLock>> GetWaitAnyCallbackLocks(
      const vector<string>& callback_function_names) {
    vector<shared_ptr<CallbackLock>> res;
    if (callback_function_names.empty()) {
      // The map can be modified concurrently through GetCallbackLock(), so
      // it must be snapshotted under its mutex.
      lock_guard<mutex> lock(cb_lock_map_mtx_);
      res.reserve(cb_lock_map_.size());
      for (const auto& entry : cb_lock_map_) {
        res.push_back(entry.second);
      }
    } else {
      // GetCallbackLock() takes cb_lock_map_mtx_ itself; do not hold it here.
      res.reserve(callback_function_names.size());
      for (const auto& name : callback_function_names) {
        res.push_back(GetCallbackLock(name));
      }
    }
    return res;
  }

  /*
   * Peek into the list of callback locks to check whether any of the
   * callback functions has been called. The first pending notification found
   * is consumed and returned; otherwise a default-constructed result
   * (no_timeout == false, empty name) is returned.
   */
  WaitForCallbackResult PeekCallbackLocks(
      const vector<string>& callback_function_names) {
    for (const auto& lock : GetWaitAnyCallbackLocks(callback_function_names)) {
      WaitForCallbackResult peeked = lock->PeekCallback();
      if (peeked.no_timeout) {
        return peeked;
      }
    }
    return WaitForCallbackResult();
  }

  // A map of function name and CallbackLock objects
  unordered_map<string, shared_ptr<CallbackLock>> cb_lock_map_;
  // Mutex for protecting operations on lock map
  mutex cb_lock_map_mtx_;
  // Mutex for protecting waiting any callback
  mutex cb_wait_any_mtx_;
  // Default wait timeout
  milliseconds cb_default_wait_timeout_;
  // Conditional variable for any callback notify
  condition_variable cb_wait_any_cv_;
};

}  // namespace testing

#endif  // __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H