| #![cfg(test)] |
| |
| use std::sync::atomic::{AtomicUsize, Ordering}; |
| use std::sync::{Arc, Barrier}; |
| #[allow(deprecated)] |
| use Configuration; |
| use {ThreadPoolBuildError, ThreadPoolBuilder}; |
| |
/// Checks that a pool reports its configured thread count and that
/// `current_thread_index` is `None` on the test thread but yields an
/// in-range index when queried from inside `install`.
#[test]
fn worker_thread_index() {
    let expected_threads = 22;
    let builder = ThreadPoolBuilder::new().num_threads(expected_threads);
    let pool = builder.build().unwrap();

    assert_eq!(pool.current_num_threads(), expected_threads);

    // Queried from the thread running the test, no index is expected.
    assert_eq!(pool.current_thread_index(), None);

    // Queried from within `install`, an index must be present and in range.
    let worker_index = pool.install(|| pool.current_thread_index().unwrap());
    assert!(worker_index < expected_threads);
}
| |
/// Builds a 16-thread pool with a start handler and checks that the handler
/// was invoked exactly once per thread.
#[test]
fn start_callback_called() {
    let n_threads = 16;
    let calls = Arc::new(AtomicUsize::new(0));
    // One barrier slot per pool thread, plus one for the thread running the test.
    let rendezvous = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = {
        let rendezvous = Arc::clone(&rendezvous);
        let calls = Arc::clone(&calls);
        move |_| {
            // Count first, then park until everyone (including the test) arrives.
            calls.fetch_add(1, Ordering::SeqCst);
            rendezvous.wait();
        }
    };

    // The pool handle itself is not needed; only the handler side effects are.
    let _ = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .build()
        .unwrap();

    // Block until every pool thread has reached the barrier in its start handler.
    rendezvous.wait();

    // Every thread bumped the counter before reaching the barrier.
    assert_eq!(calls.load(Ordering::SeqCst), n_threads);
}
| |
/// Builds a 16-thread pool with an exit handler, drops the pool, and checks
/// that the handler was invoked exactly once per thread.
#[test]
fn exit_callback_called() {
    let n_threads = 16;
    let calls = Arc::new(AtomicUsize::new(0));
    // One barrier slot per pool thread, plus one for the thread running the test.
    let rendezvous = Arc::new(Barrier::new(n_threads + 1));

    let exit_handler = {
        let rendezvous = Arc::clone(&rendezvous);
        let calls = Arc::clone(&calls);
        move |_| {
            // Count first, then park until everyone (including the test) arrives.
            calls.fetch_add(1, Ordering::SeqCst);
            rendezvous.wait();
        }
    };

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler(exit_handler);

    // Build the pool and drop it right away so it stops the running threads.
    drop(builder.build().unwrap());

    // Block until every pool thread has reached the barrier in its exit handler.
    rendezvous.wait();

    // Every exiting thread bumped the counter before reaching the barrier.
    assert_eq!(calls.load(Ordering::SeqCst), n_threads);
}
| |
/// Checks that panics raised inside the start and exit handlers are routed to
/// the pool's panic handler: once per thread while starting and once per
/// thread while exiting, for `2 * n_threads` calls in total.
#[test]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Each barrier waits for all the threads in the pool plus the one running tests.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = move |_| {
        panic!("ensure panic handler is called when starting");
    };
    let exit_handler = move |_| {
        panic!("ensure panic handler is called when exiting");
    };

    let sb = start_barrier.clone();
    let eb = exit_barrier.clone();
    let nc = n_called.clone();
    let panic_handler = move |_| {
        // The first `n_threads` invocations rendezvous on the start barrier;
        // every later invocation rendezvous on the exit barrier.
        let val = nc.fetch_add(1, Ordering::SeqCst);
        if val < n_threads {
            sb.wait();
        } else {
            eb.wait();
        }
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .panic_handler(panic_handler);

    // Bind the pool to a real name: `let _ = ...` would drop it immediately,
    // before the start barrier below has been reached.
    let pool = conf.build().unwrap();

    // Wait for all the threads to start, panic in the start handler,
    // and be taken care of by the panic handler.
    start_barrier.wait();

    // Drop the pool so it stops the running threads.
    drop(pool);

    // Wait for all the threads to exit, panic in the exit handler,
    // and be taken care of by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(n_called.load(Ordering::SeqCst), 2 * n_threads);
}
| |
/// A builder configured for 22 threads must produce a pool that reports
/// 22 threads.
#[test]
// NOTE(review): nothing visibly deprecated is used in this test — confirm
// whether this `allow` is still required.
#[allow(deprecated)]
fn check_config_build() {
    let expected_threads = 22;
    let build_result = ThreadPoolBuilder::new()
        .num_threads(expected_threads)
        .build();
    let pool = build_result.unwrap();
    assert_eq!(pool.current_num_threads(), expected_threads);
}
| |
/// Compile-time assertion that `T` implements both `Send` and `Sync`.
///
/// The body is intentionally empty: instantiating this function with a type
/// that lacks either bound is a compile error. `check_error_send_sync` uses it
/// to verify `ThreadPoolBuildError`.
fn _send_sync<T>()
where
    T: Send + Sync,
{
}
| |
/// `ThreadPoolBuildError` must be `Send + Sync`; this test only compiles if
/// it is.
#[test]
fn check_error_send_sync() {
    // Naming the instantiation is what enforces the bounds; the call is a no-op.
    let bound_check: fn() = _send_sync::<ThreadPoolBuildError>;
    bound_check();
}
| |
/// Smoke test for `Configuration` (imported under `#[allow(deprecated)]`):
/// every public builder method chained below must be callable, and the final
/// `build` must succeed.
#[allow(deprecated)]
#[test]
fn configuration() {
    // Ensure we can call all public methods on Configuration
    let config = Configuration::new()
        .thread_name(move |i| format!("thread_name_{}", i))
        .num_threads(5)
        .panic_handler(move |_| {})
        .stack_size(4_000_000)
        .breadth_first()
        .start_handler(move |_| {})
        .exit_handler(move |_| {});

    config.build().unwrap();
}
| |
/// A builder obtained through `Default` must build successfully.
#[test]
fn default_pool() {
    let build_result = ThreadPoolBuilder::default().build();
    // Only success matters; the pool is released right away.
    drop(build_result.unwrap());
}
| |
/// Verifies that threads supplied through a custom spawn handler no longer
/// report a current thread index once the pool has finished with them, so
/// they could be used with rayon again later (e.g. an embedder such as
/// WebAssembly that keeps its own pool of available threads).
#[test]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    let n_threads = 5;
    let mut join_handles = Vec::new();

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .spawn_handler(|worker| {
            join_handles.push(std::thread::spawn(move || {
                worker.run();

                // Afterward, the current thread shouldn't be set anymore.
                assert_eq!(crate::current_thread_index(), None);
            }));
            Ok(())
        });
    let pool = builder.build()?;

    // The spawn handler must have been asked for exactly one thread per worker.
    assert_eq!(join_handles.len(), n_threads);

    // While the pool is alive, work run through it sees a thread index.
    pool.install(|| assert!(crate::current_thread_index().is_some()));
    drop(pool);

    // Wait for all threads to make their assertions and exit
    join_handles
        .into_iter()
        .for_each(|handle| handle.join().unwrap());

    Ok(())
}