// Copyright 2015 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "test.h" #include <pthread.h> #include <vector> #include "../internal/multi_thread_gemm.h" namespace gemmlowp { class Thread { public: Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) : blocking_counter_(blocking_counter), number_of_times_to_decrement_(number_of_times_to_decrement), made_the_last_decrement_(false) { pthread_create(&thread_, nullptr, ThreadFunc, this); } ~Thread() { Join(); } bool Join() const { pthread_join(thread_, nullptr); return made_the_last_decrement_; } private: Thread(const Thread& other) = delete; void ThreadFunc() { for (int i = 0; i < number_of_times_to_decrement_; i++) { Check(!made_the_last_decrement_); made_the_last_decrement_ = blocking_counter_->DecrementCount(); } } static void* ThreadFunc(void* ptr) { static_cast<Thread*>(ptr)->ThreadFunc(); return nullptr; } BlockingCounter* const blocking_counter_; const int number_of_times_to_decrement_; pthread_t thread_; bool made_the_last_decrement_; }; void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, int num_decrements_per_thread, int num_decrements_to_wait_for) { std::vector<Thread*> threads; blocking_counter->Reset(num_decrements_to_wait_for); for (int i = 0; i < num_threads; i++) { threads.push_back(new Thread(blocking_counter, num_decrements_per_thread)); } blocking_counter->Wait(); int num_threads_that_made_the_last_decrement = 0; for (int i = 0; i < num_threads; i++) { if (threads[i]->Join()) { num_threads_that_made_the_last_decrement++; } delete threads[i]; } Check(num_threads_that_made_the_last_decrement == 1); } void test_blocking_counter() { BlockingCounter* blocking_counter = new BlockingCounter; // repeating the entire test sequence ensures that we test // non-monotonic changes. for (int repeat = 1; repeat <= 2; repeat++) { for (int num_threads = 1; num_threads <= 16; num_threads++) { for (int num_decrements_per_thread = 1; num_decrements_per_thread <= 64 * 1024; num_decrements_per_thread *= 4) { test_blocking_counter(blocking_counter, num_threads, num_decrements_per_thread, num_threads * num_decrements_per_thread); } } } delete blocking_counter; } } // end namespace gemmlowp int main() { gemmlowp::test_blocking_counter(); }