diff --git a/usertest/src/main.rs b/usertest/src/main.rs index 2454ac9..d86a94c 100644 --- a/usertest/src/main.rs +++ b/usertest/src/main.rs @@ -1,4 +1,7 @@ -use std::thread; +use std::{ + sync::{Arc, Barrier, Mutex}, + thread, +}; fn test_sync() { print!("Testing sync syscall ..."); @@ -252,6 +255,42 @@ fn test_rust_thread() { println!(" OK"); } +fn test_rust_mutex() { + const THREADS: usize = 32; + const ITERS: usize = 1_000; + + print!("Testing rust mutex ..."); + + let mtx = Arc::new(Mutex::new(0usize)); + let barrier = Arc::new(Barrier::new(THREADS)); + + let mut handles = Vec::with_capacity(THREADS); + + for _ in 0..THREADS { + let mtx = Arc::clone(&mtx); + let barrier = Arc::clone(&barrier); + + handles.push(thread::spawn(move || { + barrier.wait(); + + for _ in 0..ITERS { + let mut guard = mtx.lock().unwrap(); + *guard += 1; + } + })); + } + + for h in handles { + h.join().unwrap(); + } + + let final_val = *mtx.lock().unwrap(); + + assert_eq!(final_val, THREADS * ITERS); + + println!(" OK"); +} + fn run_test(test_fn: fn()) { // Fork a new process to run the test unsafe { @@ -289,6 +328,7 @@ fn main() { run_test(test_rust_file); run_test(test_rust_dir); run_test(test_rust_thread); + run_test(test_rust_mutex); let end = std::time::Instant::now(); println!("All tests passed in {} ms", (end - start).as_millis()); }