| // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include <atomic>  // NOLINT
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

#include "../internal/multi_thread_gemm.h"
#include "../profiling/pthread_everywhere.h"
#include "test.h"
| |
| namespace gemmlowp { |
| |
| class Thread { |
| public: |
| Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) |
| : blocking_counter_(blocking_counter), |
| number_of_times_to_decrement_(number_of_times_to_decrement), |
| made_the_last_decrement_(false), |
| finished_(false) { |
| #if defined GEMMLOWP_USE_PTHREAD |
| // Limit the stack size so as not to deplete memory when creating |
| // many threads. |
| pthread_attr_t attr; |
| int err = pthread_attr_init(&attr); |
| if (!err) { |
| size_t stack_size; |
| err = pthread_attr_getstacksize(&attr, &stack_size); |
| if (!err && stack_size > max_stack_size_) { |
| err = pthread_attr_setstacksize(&attr, max_stack_size_); |
| } |
| if (!err) { |
| err = pthread_create(&thread_, &attr, ThreadFunc, this); |
| } |
| } |
| if (err) { |
| std::cerr << "Failed to create a thread.\n"; |
| std::abort(); |
| } |
| #else |
| pthread_create(&thread_, nullptr, ThreadFunc, this); |
| #endif |
| } |
| |
| ~Thread() { Join(); } |
| |
| bool Join() { |
| while (!finished_.load()) { |
| } |
| return made_the_last_decrement_; |
| } |
| |
| private: |
| Thread(const Thread& other) = delete; |
| |
| void ThreadFunc() { |
| for (int i = 0; i < number_of_times_to_decrement_; i++) { |
| Check(!made_the_last_decrement_); |
| made_the_last_decrement_ = blocking_counter_->DecrementCount(); |
| } |
| finished_.store(true); |
| } |
| |
| static void* ThreadFunc(void* ptr) { |
| static_cast<Thread*>(ptr)->ThreadFunc(); |
| return nullptr; |
| } |
| |
| static constexpr size_t max_stack_size_ = 256 * 1024; |
| BlockingCounter* const blocking_counter_; |
| const int number_of_times_to_decrement_; |
| pthread_t thread_; |
| bool made_the_last_decrement_; |
| // finished_ is used to manually implement Join() by busy-waiting. |
| // I wanted to use pthread_join / std::thread::join, but the behavior |
| // observed on Android was that pthread_join aborts when the thread has |
| // already joined before calling pthread_join, making that hard to use. |
| // It appeared simplest to just implement this simple spinlock, and that |
| // is good enough as this is just a test. |
| std::atomic<bool> finished_; |
| }; |
| |
| void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, |
| int num_decrements_per_thread, |
| int num_decrements_to_wait_for) { |
| std::vector<Thread*> threads; |
| blocking_counter->Reset(num_decrements_to_wait_for); |
| for (int i = 0; i < num_threads; i++) { |
| threads.push_back(new Thread(blocking_counter, num_decrements_per_thread)); |
| } |
| blocking_counter->Wait(); |
| |
| int num_threads_that_made_the_last_decrement = 0; |
| for (int i = 0; i < num_threads; i++) { |
| if (threads[i]->Join()) { |
| num_threads_that_made_the_last_decrement++; |
| } |
| delete threads[i]; |
| } |
| Check(num_threads_that_made_the_last_decrement == 1); |
| } |
| |
| void test_blocking_counter() { |
| BlockingCounter* blocking_counter = new BlockingCounter; |
| |
| // repeating the entire test sequence ensures that we test |
| // non-monotonic changes. |
| for (int repeat = 1; repeat <= 2; repeat++) { |
| for (int num_threads = 1; num_threads <= 5; num_threads++) { |
| for (int num_decrements_per_thread = 1; |
| num_decrements_per_thread <= 4 * 1024; |
| num_decrements_per_thread *= 16) { |
| test_blocking_counter(blocking_counter, num_threads, |
| num_decrements_per_thread, |
| num_threads * num_decrements_per_thread); |
| } |
| } |
| } |
| delete blocking_counter; |
| } |
| |
| } // end namespace gemmlowp |
| |
| int main() { gemmlowp::test_blocking_counter(); } |