diff --git a/src/bin/test_simple.rs b/src/bin/test_simple.rs index 655d8c8..b7e8236 100644 --- a/src/bin/test_simple.rs +++ b/src/bin/test_simple.rs @@ -1,6 +1,6 @@ use rand::Rng; use rustmp::par_for; -use rustmp::SystemObject; +use rustmp::sysinfo::SystemObject; use std::time; #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index e991011..d8c8198 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,8 @@ -mod sysinfo; +pub mod sysinfo; +pub mod threadpool; use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; -pub use sysinfo::SystemObject; - pub struct Capture { value: Arc>, } diff --git a/src/threadpool.rs b/src/threadpool.rs new file mode 100644 index 0000000..2ac10aa --- /dev/null +++ b/src/threadpool.rs @@ -0,0 +1,97 @@ +use crate::sysinfo::SystemObject; +use lazy_static::lazy_static; +use std::panic; +use std::process; +use std::sync::mpsc::{channel, Receiver, Sender}; +use std::sync::{Arc, Barrier, Mutex}; +use std::thread::{current, Builder, JoinHandle}; + +lazy_static! { + static ref INSTANCE: Arc> = + Arc::new(Mutex::new(ThreadPoolManager::new())); +} + +pub type Job = Arc; + +pub fn as_static_job(capture: T) -> Job +where + T: Fn() + Send + Sync + 'static, +{ + Arc::new(capture) +} + +pub struct ThreadPoolManager { + pub num_threads: usize, + task_barrier: Arc, + task_comms: Vec>, + _thread_pool: Vec>, +} + +impl ThreadPoolManager { + fn new() -> ThreadPoolManager { + let master_hook = panic::take_hook(); + // Crash the program if any of our threads panic + panic::set_hook(Box::new(move |info| { + // Only panic on our own threads, leave application programmer's threads alone + if current().name().unwrap_or_default().starts_with("RMP_PAR_THREAD_") { + master_hook(info); + process::exit(1); + } else { + master_hook(info); + } + })); + + let num_threads = SystemObject::get_instance().max_num_threads; + let task_barrier = Arc::new(Barrier::new(num_threads + 1)); + let mut _thread_pool = Vec::new(); + let mut task_comms = Vec::new(); + + for tid in 0..num_threads { + let task_barrier = task_barrier.clone(); + let builder = Builder::new() // Thread builder configuration + .name(format!("RMP_PAR_THREAD_{}", tid)) // Name: RMP_PAR_THREAD_tid + .stack_size(8 << 20); // Stack size: 8MB (Linux default) + let (sender, receiver) = channel::(); + task_comms.push(sender); + _thread_pool.push( + builder + .spawn(move || routine_wrapper(tid, task_barrier, receiver)) + .unwrap(), + ); + } + + ThreadPoolManager { + num_threads, + task_barrier, + task_comms, + _thread_pool, + } + } + + pub fn get_instance_guard() -> Arc> { + return INSTANCE.clone(); + } + + pub fn exec(&self, tasks: Vec) { + // Used to wake up threads + self.task_barrier.wait(); + assert_eq!(SystemObject::get_instance().max_num_threads, tasks.len()); + for i in 0..tasks.len() { + self.task_comms[i].send(tasks[i].clone()).unwrap(); + } + // Used to return main thread from exec + self.task_barrier.wait(); + } +} + +fn routine_wrapper(tid: usize, task_barrier: Arc, receiver: Receiver) { + SystemObject::get_instance() + .set_affinity(tid) + .unwrap_or_else(|e| eprintln!("Failed to bind process #{} to hwthread: {:?}", tid, e)); + loop { + task_barrier.wait(); + let func = receiver.recv().unwrap(); + func(); + task_barrier.wait(); + } +}