From 8c3462a27a00a57b07fa2d8ca444a56ec0039af6 Mon Sep 17 00:00:00 2001 From: "Ericson \"Fogo\" Soares" Date: Mon, 26 Feb 2024 16:45:58 -0300 Subject: [PATCH] [ENG-1513] Better integration between Jobs and processing Actors (#1974) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First draft on new task system * Removing save to disk from task system * Bunch of concurrency issues * Solving Future impl issue when pausing tasks * Fix cancel and abort * Bunch of fixes on pause, suspend, resume, cancel and abort Also better error handling on task completion for the user * New capabilities to return an output on a task * Introducing a simple way to linear backoff on failed steal * Sample actor where tasks can dispatch more tasks * Rustfmt * Steal test to make sure * Stale deps cleanup * Removing unused utils * Initial lib docs * Docs ok * Memory cleanup on idle --------- Co-authored-by: VĂ­tor Vasconcellos --- CONTRIBUTING.md | 2 +- Cargo.lock | Bin 263083 -> 263854 bytes Cargo.toml | 1 + core/Cargo.toml | 4 +- crates/ai/Cargo.toml | 2 +- crates/file-path-helper/Cargo.toml | 2 +- crates/task-system/Cargo.toml | 42 + crates/task-system/src/error.rs | 28 + crates/task-system/src/lib.rs | 71 + crates/task-system/src/message.rs | 63 + crates/task-system/src/system.rs | 467 ++++++ crates/task-system/src/task.rs | 484 ++++++ crates/task-system/src/worker/mod.rs | 328 ++++ crates/task-system/src/worker/run.rs | 113 ++ crates/task-system/src/worker/runner.rs | 1408 ++++++++++++++++++ crates/task-system/tests/common/actors.rs | 389 +++++ crates/task-system/tests/common/jobs.rs | 119 ++ crates/task-system/tests/common/mod.rs | 3 + crates/task-system/tests/common/tasks.rs | 278 ++++ crates/task-system/tests/integration_test.rs | 224 +++ crates/utils/Cargo.toml | 2 +- rust-toolchain.toml | 2 +- 22 files changed, 4025 insertions(+), 7 deletions(-) create mode 100644 crates/task-system/Cargo.toml create mode 100644 crates/task-system/src/error.rs create mode 100644 crates/task-system/src/lib.rs create mode 100644 crates/task-system/src/message.rs create mode 100644 crates/task-system/src/system.rs create mode 100644 crates/task-system/src/task.rs create mode 100644 crates/task-system/src/worker/mod.rs create mode 100644 crates/task-system/src/worker/run.rs create mode 100644 crates/task-system/src/worker/runner.rs create mode 100644 crates/task-system/tests/common/actors.rs create mode 100644 crates/task-system/tests/common/jobs.rs create mode 100644 crates/task-system/tests/common/mod.rs create mode 100644 crates/task-system/tests/common/tasks.rs create mode 100644 crates/task-system/tests/integration_test.rs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0f77be7c7..5aef24826 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -74,7 +74,7 @@ To run the landing page: If you encounter any issues, ensure that you are using the following versions of Rust, Node and Pnpm: -- Rust version: **1.73.0** +- Rust version: **1.75.0** - Node version: **18.17** - Pnpm version: **8.0.0** diff --git a/Cargo.lock b/Cargo.lock index f70073d49ee6a31506cf27589dda092150462f47..57150f29f2b6ca8281156511d7b7173696b816b5 100644 GIT binary patch delta 340 zcmX|+KTE?v96-rMT1myme^5F|EvN_<(_E6fTeN-yg-+sfm%E5)t0V#S3$#2Byj~NMeG+XsshS9Kfrr(z!M6r!UjRuyQzp!t?Loq)3kG!7dj%`oKu9ju*eSvYw_=IHKU zSRsl-#Vf6)&nd{IQF=z;xN|j$Ua`|6_FRIT#=NTDr{e;G#agS0^dew%9Vim>k Q{)Xo2?E&hx2l~juH_-rcp8x;= delta 34 scmV+-0NwwtjS#Dc5U`QcvsBb~vWLZo0k_460+&>WqTK?wqTK@wDL4iY-~a#s diff --git a/Cargo.toml b/Cargo.toml index f960ff040..d8739859d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ swift-rs = { version = "1.0.6" } # Third party dependencies used by one or more of our crates anyhow = "1.0.75" async-channel = "2.0.0" +async-trait = "0.1.77" axum = "0.6.20" base64 = "0.21.5" blake3 = "1.5.0" diff --git a/core/Cargo.toml b/core/Cargo.toml index 79f32bd59..08aa83ebd 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -3,7 +3,7 @@ name = "sd-core" version = "0.2.4" description = "Virtual distributed filesystem engine that powers Spacedrive." authors = ["Spacedrive Technology Inc."] -rust-version = "1.73.0" +rust-version = "1.75.0" license = { workspace = true } repository = { workspace = true } edition = { workspace = true } @@ -51,6 +51,7 @@ sd-cloud-api = { version = "0.1.0", path = "../crates/cloud-api" } # Workspace dependencies async-channel = { workspace = true } +async-trait = { workspace = true } axum = { workspace = true } base64 = { workspace = true } blake3 = { workspace = true } @@ -100,7 +101,6 @@ webp = { workspace = true } # Specific Core dependencies async-recursion = "1.0.5" async-stream = "0.3.5" -async-trait = "^0.1.74" bytes = "1.5.0" ctor = "0.2.5" directories = "5.0.1" diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index db8020ae4..dfe97c9f5 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" authors = ["Ericson Soares "] readme = "README.md" description = "A simple library to generate video thumbnails using ffmpeg with the webp format" -rust-version = "1.73.0" +rust-version = "1.75.0" license = { workspace = true } repository = { workspace = true } edition = { workspace = true } diff --git a/crates/file-path-helper/Cargo.toml b/crates/file-path-helper/Cargo.toml index 5f25281c4..3529f40cb 100644 --- a/crates/file-path-helper/Cargo.toml +++ b/crates/file-path-helper/Cargo.toml @@ -3,7 +3,7 @@ name = "sd-file-path-helper" version = "0.1.0" authors = ["Ericson Soares "] readme = "README.md" -rust-version = "1.73.0" +rust-version = "1.75.0" license = { workspace = true } repository = { workspace = true } edition = { workspace = true } diff --git a/crates/task-system/Cargo.toml b/crates/task-system/Cargo.toml new file mode 100644 index 000000000..964076683 --- /dev/null +++ b/crates/task-system/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "sd-task-system" +version = "0.1.0" +authors = ["Ericson \"Fogo\" Soares "] +rust-version = "1.75.0" +license.workspace = true +edition.workspace = true +repository.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +# Workspace deps +async-channel = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +futures-concurrency = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = [ + "sync", + "parking_lot", + "rt-multi-thread", + "time", +] } +tokio-stream = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true, features = ["v4"] } + +# External deps +downcast-rs = "1.2.0" +pin-project = "1.1.4" + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "test-util", "fs"] } +tempfile = { workspace = true } +rand = "0.8.5" +tracing-test = { version = "^0.2.4", features = ["no-env-filter"] } +thiserror = { workspace = true } +lending-stream = "1.0.0" +serde = { workspace = true, features = ["derive"] } +rmp-serde = { workspace = true } +uuid = { workspace = true, features = ["serde"] } diff --git a/crates/task-system/src/error.rs b/crates/task-system/src/error.rs new file mode 100644 index 000000000..626be73c3 --- /dev/null +++ b/crates/task-system/src/error.rs @@ -0,0 +1,28 @@ +use std::{error::Error, fmt}; + +use super::task::TaskId; + +/// Task system's error type definition, representing when internal errors occurs. +#[derive(Debug, thiserror::Error)] +pub enum SystemError { + #[error("task not found ")] + TaskNotFound(TaskId), + #[error("task aborted ")] + TaskAborted(TaskId), + #[error("task join error ")] + TaskJoin(TaskId), + #[error("forced abortion for task timed out")] + TaskForcedAbortTimeout(TaskId), +} + +/// Trait for errors that can be returned by tasks, we use this trait as a bound for the task system generic +/// error type. +/// +///With this trait, we can have a unified error type through all the tasks in the system. +pub trait RunError: Error + fmt::Debug + Send + Sync + 'static {} + +/// We provide a blanket implementation for all types that also implements +/// [`std::error::Error`](https://doc.rust-lang.org/std/error/trait.Error.html) and +/// [`std::fmt::Debug`](https://doc.rust-lang.org/std/fmt/trait.Debug.html). +/// So you will not need to implement this trait for your error type, just implement the `Error` and `Debug` +impl RunError for T {} diff --git a/crates/task-system/src/lib.rs b/crates/task-system/src/lib.rs new file mode 100644 index 000000000..c2808d3d9 --- /dev/null +++ b/crates/task-system/src/lib.rs @@ -0,0 +1,71 @@ +//! +//! # Task System +//! +//! Spacedrive's Task System is a library that provides a way to manage and execute tasks in a concurrent +//! and parallel environment. +//! +//! Just bring your own unified error type and dispatch some tasks, the system will handle enqueueing, +//! parallel execution, and error handling for you. Aside from some niceties like: +//! - Round robin scheduling between workers following the available CPU cores on the user machine; +//! - Work stealing between workers for better load balancing; +//! - Gracefully pause and cancel tasks; +//! - Forced abortion of tasks; +//! - Prioritizing tasks that will suspend running tasks without priority; +//! - When the system is shutdown, it will return all pending and running tasks to theirs dispatchers, so the user can store them on disk or any other storage to be re-dispatched later; +//! +//! +//! ## Basic example +//! +//! ``` +//! use sd_task_system::{TaskSystem, Task, TaskId, ExecStatus, TaskOutput, Interrupter, TaskStatus}; +//! use async_trait::async_trait; +//! use thiserror::Error; +//! +//! #[derive(Debug, Error)] +//! pub enum SampleError { +//! #[error("Sample error")] +//! SampleError, +//! } +//! +//! #[derive(Debug)] +//! pub struct ReadyTask { +//! id: TaskId, +//! } +//! +//! #[async_trait] +//! impl Task for ReadyTask { +//! fn id(&self) -> TaskId { +//! self.id +//! } +//! +//! async fn run(&mut self, _interrupter: &Interrupter) -> Result { +//! Ok(ExecStatus::Done(TaskOutput::Empty)) +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let system = TaskSystem::new(); +//! +//! let handle = system.dispatch(ReadyTask { id: TaskId::new_v4() }).await; +//! +//! assert!(matches!( +//! handle.await, +//! Ok(TaskStatus::Done(TaskOutput::Empty)) +//! )); +//! +//! system.shutdown().await; +//! } +//! ``` +mod error; +mod message; +mod system; +mod task; +mod worker; + +pub use error::{RunError, SystemError as TaskSystemError}; +pub use system::{Dispatcher as TaskDispatcher, System as TaskSystem}; +pub use task::{ + AnyTaskOutput, ExecStatus, Interrupter, InterrupterFuture, InterruptionKind, IntoAnyTaskOutput, + IntoTask, Task, TaskHandle, TaskId, TaskOutput, TaskStatus, +}; diff --git a/crates/task-system/src/message.rs b/crates/task-system/src/message.rs new file mode 100644 index 000000000..dfc86f1e1 --- /dev/null +++ b/crates/task-system/src/message.rs @@ -0,0 +1,63 @@ +use tokio::sync::oneshot; + +use super::{ + error::{RunError, SystemError}, + task::{TaskId, TaskWorkState}, + worker::WorkerId, +}; + +#[derive(Debug)] +pub(crate) enum SystemMessage { + IdleReport(WorkerId), + WorkingReport(WorkerId), + ResumeTask { + task_id: TaskId, + worker_id: WorkerId, + ack: oneshot::Sender>, + }, + PauseNotRunningTask { + task_id: TaskId, + worker_id: WorkerId, + ack: oneshot::Sender>, + }, + CancelNotRunningTask { + task_id: TaskId, + worker_id: WorkerId, + ack: oneshot::Sender>, + }, + ForceAbortion { + task_id: TaskId, + worker_id: WorkerId, + ack: oneshot::Sender>, + }, + NotifyIdleWorkers { + start_from: WorkerId, + task_count: usize, + }, + ShutdownRequest(oneshot::Sender>), +} + +#[derive(Debug)] +pub(crate) enum WorkerMessage { + NewTask(TaskWorkState), + TaskCountRequest(oneshot::Sender), + ResumeTask { + task_id: TaskId, + ack: oneshot::Sender>, + }, + PauseNotRunningTask { + task_id: TaskId, + ack: oneshot::Sender>, + }, + CancelNotRunningTask { + task_id: TaskId, + ack: oneshot::Sender>, + }, + ForceAbortion { + task_id: TaskId, + ack: oneshot::Sender>, + }, + ShutdownRequest(oneshot::Sender<()>), + StealRequest(oneshot::Sender>>), + WakeUp, +} diff --git a/crates/task-system/src/system.rs b/crates/task-system/src/system.rs new file mode 100644 index 000000000..96d3fc610 --- /dev/null +++ b/crates/task-system/src/system.rs @@ -0,0 +1,467 @@ +use std::{ + cell::RefCell, + collections::HashSet, + pin::pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use async_channel as chan; +use futures::StreamExt; +use futures_concurrency::future::Join; +use tokio::{spawn, sync::oneshot, task::JoinHandle}; +use tracing::{error, info, trace, warn}; + +use super::{ + error::{RunError, SystemError}, + message::SystemMessage, + task::{IntoTask, Task, TaskHandle, TaskId}, + worker::{AtomicWorkerId, WorkStealer, Worker, WorkerBuilder, WorkerId}, +}; + +/// The task system is the main entry point for the library, it is responsible for creating and managing the workers +/// and dispatching tasks to them. +/// +/// It also provides a way to shutdown the system returning all pending and running tasks. +/// It uses internal mutability so it can be shared without hassles using [`Arc`]. +pub struct System { + workers: Arc>>, + msgs_tx: chan::Sender, + dispatcher: Dispatcher, + handle: RefCell>>, +} + +impl System { + /// Created a new task system with a number of workers equal to the available parallelism in the user's machine. + pub fn new() -> Self { + let workers_count = std::thread::available_parallelism().map_or_else( + |e| { + error!("Failed to get available parallelism in the job system: {e:#?}"); + 1 + }, + |non_zero| non_zero.get(), + ); + + let (msgs_tx, msgs_rx) = chan::bounded(8); + let system_comm = SystemComm(msgs_tx.clone()); + + let (workers_builders, worker_comms) = (0..workers_count) + .map(WorkerBuilder::new) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let task_stealer = WorkStealer::new(worker_comms); + + let idle_workers = Arc::new((0..workers_count).map(|_| AtomicBool::new(true)).collect()); + + let workers = Arc::new( + workers_builders + .into_iter() + .map(|builder| builder.build(system_comm.clone(), task_stealer.clone())) + .collect::>(), + ); + + let handle = spawn({ + let workers = Arc::clone(&workers); + let msgs_rx = msgs_rx.clone(); + let idle_workers = Arc::clone(&idle_workers); + + async move { + trace!("Task System message processing task starting..."); + while let Err(e) = spawn(Self::run( + Arc::clone(&workers), + Arc::clone(&idle_workers), + msgs_rx.clone(), + )) + .await + { + if e.is_panic() { + error!("Job system panicked: {e:#?}"); + } else { + trace!("Task system received shutdown signal and will exit..."); + break; + } + trace!("Restarting task system message processing task...") + } + + info!("Task system gracefully shutdown"); + } + }); + + trace!("Task system online!"); + + Self { + workers: Arc::clone(&workers), + msgs_tx, + dispatcher: Dispatcher { + workers, + idle_workers, + last_worker_id: Arc::new(AtomicWorkerId::new(0)), + }, + + handle: RefCell::new(Some(handle)), + } + } + + /// Returns the number of workers in the system. + pub fn workers_count(&self) -> usize { + self.workers.len() + } + + /// Dispatches a task to the system, the task will be assigned to a worker and executed as soon as possible. + pub async fn dispatch(&self, into_task: impl IntoTask) -> TaskHandle { + self.dispatcher.dispatch(into_task).await + } + + /// Dispatches many tasks to the system, the tasks will be assigned to workers and executed as soon as possible. + pub async fn dispatch_many(&self, into_tasks: Vec>) -> Vec> { + self.dispatcher.dispatch_many(into_tasks).await + } + + /// Returns a dispatcher that can be used to remotely dispatch tasks to the system. + pub fn get_dispatcher(&self) -> Dispatcher { + self.dispatcher.clone() + } + + async fn run( + workers: Arc>>, + idle_workers: Arc>, + msgs_rx: chan::Receiver, + ) { + let mut msg_stream = pin!(msgs_rx); + + while let Some(msg) = msg_stream.next().await { + match msg { + SystemMessage::IdleReport(worker_id) => { + trace!("Task system received a worker idle report request: "); + idle_workers[worker_id].store(true, Ordering::Relaxed); + } + + SystemMessage::WorkingReport(worker_id) => { + trace!( + "Task system received a working report request: " + ); + idle_workers[worker_id].store(false, Ordering::Relaxed); + } + + SystemMessage::ResumeTask { + task_id, + worker_id, + ack, + } => { + trace!("Task system received a task resume request: "); + workers[worker_id].resume_task(task_id, ack).await; + } + + SystemMessage::PauseNotRunningTask { + task_id, + worker_id, + ack, + } => { + trace!("Task system received a task resume request: "); + workers[worker_id] + .pause_not_running_task(task_id, ack) + .await; + } + + SystemMessage::CancelNotRunningTask { + task_id, + worker_id, + ack, + } => { + trace!("Task system received a task resume request: "); + workers[worker_id] + .cancel_not_running_task(task_id, ack) + .await; + } + + SystemMessage::ForceAbortion { + task_id, + worker_id, + ack, + } => { + trace!( + "Task system received a task force abortion request: \ + " + ); + workers[worker_id].force_task_abortion(task_id, ack).await; + } + + SystemMessage::NotifyIdleWorkers { + start_from, + task_count, + } => { + trace!( + "Task system received a request to notify idle workers: \ + " + ); + + for idx in (0..workers.len()) + .cycle() + .skip(start_from) + .take(usize::min(task_count, workers.len())) + { + if idle_workers[idx].load(Ordering::Relaxed) { + workers[idx].wake().await; + // we don't mark the worker as not idle because we wait for it to + // successfully steal a task and then report it back as active + } + } + } + + SystemMessage::ShutdownRequest(tx) => { + trace!("Task system received a shutdown request"); + tx.send(Ok(())) + .expect("System channel closed trying to shutdown"); + return; + } + } + } + } + + /// Shuts down the system, returning all pending and running tasks to their respective handles. + pub async fn shutdown(&self) { + if let Some(handle) = self + .handle + .try_borrow_mut() + .ok() + .and_then(|mut maybe_handle| maybe_handle.take()) + { + self.workers + .iter() + .map(|worker| async move { worker.shutdown().await }) + .collect::>() + .join() + .await; + + let (tx, rx) = oneshot::channel(); + + self.msgs_tx + .send(SystemMessage::ShutdownRequest(tx)) + .await + .expect("Task system channel closed trying to shutdown"); + + if let Err(e) = rx + .await + .expect("Task system channel closed trying to shutdown") + { + error!("Task system failed to shutdown: {e:#?}"); + } + + if let Err(e) = handle.await { + error!("Task system failed to shutdown on handle await: {e:#?}"); + } + } else { + warn!("Trying to shutdown the tasks system that was already shutdown"); + } + } +} + +/// The default implementation of the task system will create a system with a number of workers equal to the available +/// parallelism in the user's machine. +impl Default for System { + fn default() -> Self { + Self::new() + } +} + +/// SAFETY: Due to usage of refcell we lost `Sync` impl, but we only use it to have a shutdown method +/// receiving `&self` which is called once, and we also use `try_borrow_mut` so we never panic +unsafe impl Sync for System {} + +#[derive(Clone, Debug)] +#[repr(transparent)] +pub(crate) struct SystemComm(chan::Sender); + +impl SystemComm { + pub async fn idle_report(&self, worker_id: usize) { + self.0 + .send(SystemMessage::IdleReport(worker_id)) + .await + .expect("System channel closed trying to report idle"); + } + + pub async fn working_report(&self, worker_id: usize) { + self.0 + .send(SystemMessage::WorkingReport(worker_id)) + .await + .expect("System channel closed trying to report working"); + } + + pub async fn pause_not_running_task( + &self, + task_id: TaskId, + worker_id: WorkerId, + ) -> Result<(), SystemError> { + let (tx, rx) = oneshot::channel(); + + self.0 + .send(SystemMessage::PauseNotRunningTask { + task_id, + worker_id, + ack: tx, + }) + .await + .expect("System channel closed trying to pause not running task"); + + rx.await + .expect("System channel closed trying receive pause not running task response") + } + + pub async fn cancel_not_running_task( + &self, + task_id: TaskId, + worker_id: WorkerId, + ) -> Result<(), SystemError> { + let (tx, rx) = oneshot::channel(); + + self.0 + .send(SystemMessage::CancelNotRunningTask { + task_id, + worker_id, + ack: tx, + }) + .await + .expect("System channel closed trying to cancel a not running task"); + + rx.await + .expect("System channel closed trying receive cancel a not running task response") + } + + pub async fn request_help(&self, worker_id: WorkerId, task_count: usize) { + self.0 + .send(SystemMessage::NotifyIdleWorkers { + start_from: worker_id, + task_count, + }) + .await + .expect("System channel closed trying to request help"); + } + + pub async fn resume_task( + &self, + task_id: TaskId, + worker_id: WorkerId, + ) -> Result<(), SystemError> { + let (tx, rx) = oneshot::channel(); + + self.0 + .send(SystemMessage::ResumeTask { + task_id, + worker_id, + ack: tx, + }) + .await + .expect("System channel closed trying to resume task"); + + rx.await + .expect("System channel closed trying receive resume task response") + } + + pub async fn force_abortion( + &self, + task_id: TaskId, + worker_id: WorkerId, + ) -> Result<(), SystemError> { + let (tx, rx) = oneshot::channel(); + + self.0 + .send(SystemMessage::ForceAbortion { + task_id, + worker_id, + ack: tx, + }) + .await + .expect("System channel closed trying to resume task"); + + rx.await + .expect("System channel closed trying receive resume task response") + } +} + +/// A remote dispatcher of tasks. +/// +/// It can be used to dispatch tasks to the system from other threads or tasks. +/// It uses [`Arc`] internally so it can be cheaply cloned and put inside tasks so tasks can dispatch other tasks. +#[derive(Debug)] +pub struct Dispatcher { + workers: Arc>>, + idle_workers: Arc>, + last_worker_id: Arc, +} + +impl Clone for Dispatcher { + fn clone(&self) -> Self { + Self { + workers: Arc::clone(&self.workers), + idle_workers: Arc::clone(&self.idle_workers), + last_worker_id: Arc::clone(&self.last_worker_id), + } + } +} + +impl Dispatcher { + /// Dispatches a task to the system, the task will be assigned to a worker and executed as soon as possible. + pub async fn dispatch(&self, into_task: impl IntoTask) -> TaskHandle { + let task = into_task.into_task(); + + async fn inner(this: &Dispatcher, task: Box>) -> TaskHandle { + let worker_id = this + .last_worker_id + .fetch_update(Ordering::Release, Ordering::Acquire, |last_worker_id| { + Some((last_worker_id + 1) % this.workers.len()) + }) + .expect("we hardcoded the update function to always return Some(next_worker_id) through dispatcher"); + + trace!( + "Dispatching task to worker: ", + task.id() + ); + let handle = this.workers[worker_id].add_task(task).await; + + this.idle_workers[worker_id].store(false, Ordering::Relaxed); + + handle + } + + inner(self, task).await + } + + /// Dispatches many tasks to the system, the tasks will be assigned to workers and executed as soon as possible. + pub async fn dispatch_many(&self, into_tasks: Vec>) -> Vec> { + let mut workers_task_count = self + .workers + .iter() + .map(|worker| async move { (worker.id, worker.task_count().await) }) + .collect::>() + .join() + .await; + + workers_task_count.sort_by_key(|(_id, count)| *count); + + let (handles, workers_ids_set) = into_tasks + .into_iter() + .map(IntoTask::into_task) + .zip(workers_task_count.into_iter().cycle()) + .map(|(task, (worker_id, _))| async move { + (self.workers[worker_id].add_task(task).await, worker_id) + }) + .collect::>() + .join() + .await + .into_iter() + .unzip::<_, _, Vec<_>, HashSet<_>>(); + + workers_ids_set.into_iter().for_each(|worker_id| { + self.idle_workers[worker_id].store(false, Ordering::Relaxed); + }); + + handles + } + + /// Returns the number of workers in the system. + pub fn workers_count(&self) -> usize { + self.workers.len() + } +} diff --git a/crates/task-system/src/task.rs b/crates/task-system/src/task.rs new file mode 100644 index 000000000..3962214e6 --- /dev/null +++ b/crates/task-system/src/task.rs @@ -0,0 +1,484 @@ +use std::{ + fmt, + future::{Future, IntoFuture}, + pin::Pin, + sync::{ + atomic::{AtomicBool, AtomicU8, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use async_channel as chan; +use async_trait::async_trait; +use chan::{Recv, RecvError}; +use downcast_rs::{impl_downcast, Downcast}; +use tokio::sync::oneshot; +use tracing::{trace, warn}; +use uuid::Uuid; + +use super::{ + error::{RunError, SystemError}, + system::SystemComm, + worker::{AtomicWorkerId, WorkerId}, +}; + +/// A unique identifier for a task using the [`uuid`](https://docs.rs/uuid) crate. +pub type TaskId = Uuid; + +/// A trait that represents any kind of output that a task can return. +/// +/// The user will downcast it to the concrete type that the task returns. Most of the time, +/// tasks will not return anything, so it isn't a costly abstraction, as only a heap allocation +/// is needed when the user wants to return a [`Box`]. +pub trait AnyTaskOutput: Send + fmt::Debug + Downcast + 'static {} + +impl_downcast!(AnyTaskOutput); + +/// Blanket implementation for all types that implements `std::fmt::Debug + Send + 'static` +impl AnyTaskOutput for T {} + +/// A helper trait to convert any type that implements [`AnyTaskOutput`] into a [`TaskOutput`], boxing it. +pub trait IntoAnyTaskOutput { + fn into_output(self) -> TaskOutput; +} + +/// Blanket implementation for all types that implements AnyTaskOutput +impl IntoAnyTaskOutput for T { + fn into_output(self) -> TaskOutput { + TaskOutput::Out(Box::new(self)) + } +} + +/// An enum representing whether a task returned anything or not. +#[derive(Debug)] +pub enum TaskOutput { + Out(Box), + Empty, +} + +/// An enum representing all possible outcomes for a task. +#[derive(Debug)] +pub enum TaskStatus { + /// The task has finished successfully and maybe has some output for the user. + Done(TaskOutput), + /// Task was gracefully cancelled by the user. + Canceled, + /// Task was forcefully aborted by the user. + ForcedAbortion, + /// The task system was shutdown and we give back the task to the user so they can downcast it + /// back to the original concrete type and store it on disk or any other storage to be re-dispatched later. + Shutdown(Box>), + /// Task had and error so we return it back and the user can handle it appropriately. + Error(E), +} + +/// Represents whether the current [`Task::run`] method on a task finished successfully or was interrupted. +/// +/// `Done` and `Canceled` variants can only happen once, while `Paused` can happen multiple times, +/// whenever the user wants to pause the task. +#[derive(Debug)] +pub enum ExecStatus { + Done(TaskOutput), + Paused, + Canceled, +} + +#[derive(Debug)] +pub(crate) enum InternalTaskExecStatus { + Done(TaskOutput), + Paused, + Canceled, + Suspend, + Error(E), +} + +impl From> for InternalTaskExecStatus { + fn from(result: Result) -> Self { + result + .map(|status| match status { + ExecStatus::Done(out) => Self::Done(out), + ExecStatus::Paused => Self::Paused, + ExecStatus::Canceled => Self::Canceled, + }) + .unwrap_or_else(|e| Self::Error(e)) + } +} + +/// A helper trait to convert any type that implements [`Task`] into a [`Box>`], boxing it. +pub trait IntoTask { + fn into_task(self) -> Box>; +} + +/// Blanket implementation for all types that implements [`Task`] and `'static` +impl + 'static, E: RunError> IntoTask for T { + fn into_task(self) -> Box> { + Box::new(self) + } +} + +/// The main trait that represents a task that can be dispatched to the task system. +/// +/// All traits in the task system must return the same generic error type, so we can have a unified +/// error handling. +/// +/// We're currently using the [`async_trait`](https://docs.rs/async-trait) crate to allow dyn async traits, +/// due to a limitation in the Rust language. +#[async_trait] +pub trait Task: fmt::Debug + Downcast + Send + 'static { + /// This method represent the work that should be done by the worker, it will be called by the + /// worker when there is a slot available in its internal queue. + /// We receive a `&mut self` so any internal data can be mutated on each `run` invocation. + /// + /// The [`interrupter`](Interrupter) is a helper object that can be used to check if the user requested a pause or a cancel, + /// so the user can decide the appropriated moment to pause or cancel the task. Avoiding corrupted data or + /// inconsistent states. + async fn run(&mut self, interrupter: &Interrupter) -> Result; + + /// This method defines whether a task should run with priority or not. The task system has a mechanism + /// to suspend non-priority tasks on any worker and run priority tasks ASAP. This is useful for tasks that + /// are more important than others, like a task that should be concluded and show results immediately to the user, + /// as thumbnails being generated for the current open directory or copy/paste operations. + fn with_priority(&self) -> bool { + false + } + + /// An unique identifier for the task, it will be used to identify the task on the system and also to the user. + fn id(&self) -> TaskId; +} + +impl_downcast!(Task where E: RunError); + +/// Intermediate struct to wait until a pause or a cancel commands are sent by the user. +#[must_use = "`InterrupterFuture` does nothing unless polled"] +#[pin_project::pin_project] +pub struct InterrupterFuture<'recv> { + #[pin] + fut: Recv<'recv, InterruptionRequest>, + has_interrupted: &'recv AtomicU8, +} + +impl Future for InterrupterFuture<'_> { + type Output = InterruptionKind; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.fut.poll(cx) { + Poll::Ready(Ok(InterruptionRequest { kind, ack })) => { + if ack.send(Ok(())).is_err() { + warn!("TaskInterrupter ack channel closed"); + } + this.has_interrupted.store(kind as u8, Ordering::Relaxed); + Poll::Ready(kind) + } + Poll::Ready(Err(RecvError)) => { + // In case the task handle was dropped, we can't receive any more interrupt messages + // so we will never interrupt and the task will run freely until ended + warn!("Task interrupter channel closed, will run task until it finishes!"); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } +} + +/// We use an [`IntoFuture`] implementation to allow the user to use the `await` syntax on the [`Interrupter`] object. +/// With this trait, we return an [`InterrupterFuture`] that will await until the user requests a pause or a cancel. +impl<'recv> IntoFuture for &'recv Interrupter { + type Output = InterruptionKind; + + type IntoFuture = InterrupterFuture<'recv>; + + fn into_future(self) -> Self::IntoFuture { + InterrupterFuture { + fut: self.interrupt_rx.recv(), + has_interrupted: &self.has_interrupted, + } + } +} + +/// A helper object that can be used to check if the user requested a pause or a cancel, so the task `run` +/// implementation can decide the appropriated moment to pause or cancel the task. +#[derive(Debug)] +pub struct Interrupter { + interrupt_rx: chan::Receiver, + has_interrupted: AtomicU8, +} + +impl Interrupter { + pub(crate) fn new(interrupt_tx: chan::Receiver) -> Self { + Self { + interrupt_rx: interrupt_tx, + has_interrupted: AtomicU8::new(0), + } + } + + /// Check if the user requested a pause or a cancel, returning the kind of interruption that was requested + /// in a non-blocking manner. + pub fn try_check_interrupt(&self) -> Option { + if let Some(kind) = InterruptionKind::load(&self.has_interrupted) { + Some(kind) + } else if let Ok(InterruptionRequest { kind, ack }) = self.interrupt_rx.try_recv() { + if ack.send(Ok(())).is_err() { + warn!("TaskInterrupter ack channel closed"); + } + + self.has_interrupted.store(kind as u8, Ordering::Relaxed); + + Some(kind) + } else { + None + } + } + + pub(super) fn reset(&self) { + self.has_interrupted + .compare_exchange( + InterruptionKind::Pause as u8, + 0, + Ordering::Release, + Ordering::Relaxed, + ) + .expect("we must only reset paused tasks"); + } +} + +/// The kind of interruption that can be requested by the user, a pause or a cancel +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum InterruptionKind { + Pause = 1, + Cancel = 2, +} + +impl InterruptionKind { + fn load(kind: &AtomicU8) -> Option { + match kind.load(Ordering::Relaxed) { + 1 => Some(Self::Pause), + 2 => Some(Self::Cancel), + _ => None, + } + } +} + +#[derive(Debug)] +pub(crate) struct InterruptionRequest { + kind: InterruptionKind, + ack: oneshot::Sender>, +} + +/// A handle returned when a task is dispatched to the task system, it can be used to pause, cancel, resume, or wait +/// until the task gets completed. +#[derive(Debug)] +pub struct TaskHandle { + pub(crate) worktable: Arc, + pub(crate) done_rx: oneshot::Receiver, SystemError>>, + pub(crate) system_comm: SystemComm, + pub(crate) task_id: TaskId, +} + +impl Future for TaskHandle { + type Output = Result, SystemError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.done_rx) + .poll(cx) + .map(|res| res.expect("TaskHandle done channel unexpectedly closed")) + } +} + +impl TaskHandle { + /// Get the unique identifier of the task + pub fn task_id(&self) -> TaskId { + self.task_id + } + + /// Gracefully pause the task at a safe point defined by the user using the [`Interrupter`] + pub async fn pause(&self) -> Result<(), SystemError> { + let is_paused = self.worktable.is_paused.load(Ordering::Relaxed); + let is_canceled = self.worktable.is_canceled.load(Ordering::Relaxed); + let is_done = self.worktable.is_done.load(Ordering::Relaxed); + + trace!("Received pause command task: "); + + if !is_paused && !is_canceled && !is_done { + if self.worktable.is_running.load(Ordering::Relaxed) { + let (tx, rx) = oneshot::channel(); + + trace!("Task is running, sending pause request"); + + self.worktable.pause(tx).await; + + rx.await.expect("Worker failed to ack pause request")?; + } else { + trace!("Task is not running, setting is_paused flag"); + self.worktable.is_paused.store(true, Ordering::Relaxed); + return self + .system_comm + .pause_not_running_task( + self.task_id, + self.worktable.current_worker_id.load(Ordering::Relaxed), + ) + .await; + } + } + + Ok(()) + } + + /// Gracefully cancel the task at a safe point defined by the user using the [`Interrupter`] + pub async fn cancel(&self) -> Result<(), SystemError> { + let is_canceled = self.worktable.is_canceled.load(Ordering::Relaxed); + let is_done = self.worktable.is_done.load(Ordering::Relaxed); + + trace!("Received cancel command task: "); + + if !is_canceled && !is_done { + if self.worktable.is_running.load(Ordering::Relaxed) { + let (tx, rx) = oneshot::channel(); + + trace!("Task is running, sending cancel request"); + + self.worktable.cancel(tx).await; + + rx.await.expect("Worker failed to ack cancel request")?; + } else { + trace!("Task is not running, setting is_canceled flag"); + self.worktable.is_canceled.store(true, Ordering::Relaxed); + return self + .system_comm + .cancel_not_running_task( + self.task_id, + self.worktable.current_worker_id.load(Ordering::Relaxed), + ) + .await; + } + } + + Ok(()) + } + + /// Forcefully abort the task, this can lead to corrupted data or inconsistent states, so use it with caution. + pub async fn force_abortion(&self) -> Result<(), SystemError> { + self.worktable.set_aborted(); + self.system_comm + .force_abortion( + self.task_id, + self.worktable.current_worker_id.load(Ordering::Relaxed), + ) + .await + } + + /// Marks the task to be resumed by the task system, the worker will start processing it if there is a slot + /// available or will be enqueued otherwise. + pub async fn resume(&self) -> Result<(), SystemError> { + self.system_comm + .resume_task( + self.task_id, + self.worktable.current_worker_id.load(Ordering::Relaxed), + ) + .await + } +} + +#[derive(Debug)] +pub(crate) struct TaskWorktable { + started: AtomicBool, + is_running: AtomicBool, + is_done: AtomicBool, + is_paused: AtomicBool, + is_canceled: AtomicBool, + is_aborted: AtomicBool, + interrupt_tx: chan::Sender, + current_worker_id: AtomicWorkerId, +} + +impl TaskWorktable { + pub fn new(worker_id: WorkerId, interrupt_tx: chan::Sender) -> Self { + Self { + started: AtomicBool::new(false), + is_running: AtomicBool::new(false), + is_done: AtomicBool::new(false), + is_paused: AtomicBool::new(false), + is_canceled: AtomicBool::new(false), + is_aborted: AtomicBool::new(false), + interrupt_tx, + current_worker_id: AtomicWorkerId::new(worker_id), + } + } + + pub fn set_started(&self) { + self.started.store(true, Ordering::Relaxed); + self.is_running.store(true, Ordering::Relaxed); + } + + pub fn set_completed(&self) { + self.is_done.store(true, Ordering::Relaxed); + self.is_running.store(false, Ordering::Relaxed); + } + + pub fn set_unpause(&self) { + self.is_paused.store(false, Ordering::Relaxed); + } + + pub fn set_aborted(&self) { + self.is_aborted.store(true, Ordering::Relaxed); + } + + pub async fn pause(&self, tx: oneshot::Sender>) { + self.is_paused.store(true, Ordering::Relaxed); + self.is_running.store(false, Ordering::Relaxed); + + trace!("Sending pause signal to Interrupter object on task"); + + self.interrupt_tx + .send(InterruptionRequest { + kind: InterruptionKind::Pause, + ack: tx, + }) + .await + .expect("Worker channel closed trying to pause task"); + } + + pub async fn cancel(&self, tx: oneshot::Sender>) { + self.is_canceled.store(true, Ordering::Relaxed); + self.is_running.store(false, Ordering::Relaxed); + + self.interrupt_tx + .send(InterruptionRequest { + kind: InterruptionKind::Cancel, + ack: tx, + }) + .await + .expect("Worker channel closed trying to pause task"); + } + + pub fn is_paused(&self) -> bool { + self.is_paused.load(Ordering::Relaxed) + } + + pub fn is_canceled(&self) -> bool { + self.is_canceled.load(Ordering::Relaxed) + } + + pub fn is_aborted(&self) -> bool { + self.is_aborted.load(Ordering::Relaxed) + } +} + +#[derive(Debug)] +pub(crate) struct TaskWorkState { + pub(crate) task: Box>, + pub(crate) worktable: Arc, + pub(crate) done_tx: oneshot::Sender, SystemError>>, + pub(crate) interrupter: Arc, +} + +impl TaskWorkState { + pub fn change_worker(&self, new_worker_id: WorkerId) { + self.worktable + .current_worker_id + .store(new_worker_id, Ordering::Relaxed); + } +} diff --git a/crates/task-system/src/worker/mod.rs b/crates/task-system/src/worker/mod.rs new file mode 100644 index 000000000..dc8b87b04 --- /dev/null +++ b/crates/task-system/src/worker/mod.rs @@ -0,0 +1,328 @@ +use std::{ + cell::RefCell, + sync::{atomic::AtomicUsize, Arc}, + time::Duration, +}; + +use async_channel as chan; +use tokio::{spawn, sync::oneshot, task::JoinHandle}; +use tracing::{error, info, trace, warn}; + +use super::{ + error::{RunError, SystemError}, + message::WorkerMessage, + system::SystemComm, + task::{ + InternalTaskExecStatus, Interrupter, Task, TaskHandle, TaskId, TaskWorkState, TaskWorktable, + }, +}; + +mod run; +mod runner; + +use run::run; + +const ONE_SECOND: Duration = Duration::from_secs(1); + +pub(crate) type WorkerId = usize; +pub(crate) type AtomicWorkerId = AtomicUsize; + +pub(crate) struct WorkerBuilder { + id: usize, + msgs_tx: chan::Sender>, + msgs_rx: chan::Receiver>, +} + +impl WorkerBuilder { + pub fn new(id: WorkerId) -> (Self, WorkerComm) { + let (msgs_tx, msgs_rx) = chan::bounded(8); + + let worker_comm = WorkerComm { + worker_id: id, + msgs_tx: msgs_tx.clone(), + }; + + ( + Self { + id, + msgs_tx, + msgs_rx, + }, + worker_comm, + ) + } + + pub fn build(self, system_comm: SystemComm, task_stealer: WorkStealer) -> Worker { + let Self { + id, + msgs_tx, + msgs_rx, + } = self; + + let handle = spawn({ + let msgs_rx = msgs_rx.clone(); + let system_comm = system_comm.clone(); + let task_stealer = task_stealer.clone(); + + async move { + trace!("Worker message processing task starting..."); + while let Err(e) = spawn(run( + id, + system_comm.clone(), + task_stealer.clone(), + msgs_rx.clone(), + )) + .await + { + if e.is_panic() { + error!( + "Worker critically failed and will restart: \ + {e:#?}" + ); + } else { + trace!( + "Worker received shutdown signal and will exit..." + ); + break; + } + } + + info!("Worker gracefully shutdown"); + } + }); + + Worker { + id, + system_comm, + msgs_tx, + handle: RefCell::new(Some(handle)), + } + } +} + +#[derive(Debug)] +pub(crate) struct Worker { + pub id: usize, + system_comm: SystemComm, + msgs_tx: chan::Sender>, + handle: RefCell>>, +} + +impl Worker { + pub async fn add_task(&self, new_task: Box>) -> TaskHandle { + let (done_tx, done_rx) = oneshot::channel(); + + let (interrupt_tx, interrupt_rx) = chan::bounded(1); + + let worktable = Arc::new(TaskWorktable::new(self.id, interrupt_tx)); + + let task_id = new_task.id(); + + self.msgs_tx + .send(WorkerMessage::NewTask(TaskWorkState { + task: new_task, + worktable: Arc::clone(&worktable), + interrupter: Arc::new(Interrupter::new(interrupt_rx)), + done_tx, + })) + .await + .expect("Worker channel closed trying to add task"); + + TaskHandle { + worktable, + done_rx, + system_comm: self.system_comm.clone(), + task_id, + } + } + + pub async fn task_count(&self) -> usize { + let (tx, rx) = oneshot::channel(); + + self.msgs_tx + .send(WorkerMessage::TaskCountRequest(tx)) + .await + .expect("Worker channel closed trying to get task count"); + + rx.await + .expect("Worker channel closed trying to receive task count response") + } + + pub async fn resume_task( + &self, + task_id: TaskId, + ack: oneshot::Sender>, + ) { + self.msgs_tx + .send(WorkerMessage::ResumeTask { task_id, ack }) + .await + .expect("Worker channel closed trying to resume task"); + } + + pub async fn pause_not_running_task( + &self, + task_id: TaskId, + ack: oneshot::Sender>, + ) { + self.msgs_tx + .send(WorkerMessage::PauseNotRunningTask { task_id, ack }) + .await + .expect("Worker channel closed trying to pause a not running task"); + } + + pub async fn cancel_not_running_task( + &self, + task_id: TaskId, + ack: oneshot::Sender>, + ) { + self.msgs_tx + .send(WorkerMessage::CancelNotRunningTask { task_id, ack }) + .await + .expect("Worker channel closed trying to cancel a not running task"); + } + + pub async fn force_task_abortion( + &self, + task_id: TaskId, + ack: oneshot::Sender>, + ) { + self.msgs_tx + .send(WorkerMessage::ForceAbortion { task_id, ack }) + .await + .expect("Worker channel closed trying to force task abortion"); + } + + pub async fn shutdown(&self) { + if let Some(handle) = self + .handle + .try_borrow_mut() + .ok() + .and_then(|mut maybe_handle| maybe_handle.take()) + { + let (tx, rx) = oneshot::channel(); + + self.msgs_tx + .send(WorkerMessage::ShutdownRequest(tx)) + .await + .expect("Worker channel closed trying to shutdown"); + + rx.await.expect("Worker channel closed trying to shutdown"); + + if let Err(e) = handle.await { + if e.is_panic() { + error!("Worker {} critically failed: {e:#?}", self.id); + } + } + } else { + warn!("Trying to shutdown a worker that was already shutdown"); + } + } + + pub async fn wake(&self) { + self.msgs_tx + .send(WorkerMessage::WakeUp) + .await + .expect("Worker channel closed trying to wake up"); + } +} + +/// SAFETY: Due to usage of refcell we lost `Sync` impl, but we only use it to have a shutdown method +/// receiving `&self` which is called once, and we also use `try_borrow_mut` so we never panic +unsafe impl Sync for Worker {} + +#[derive(Clone)] +pub(crate) struct WorkerComm { + worker_id: WorkerId, + msgs_tx: chan::Sender>, +} + +impl WorkerComm { + pub async fn steal_task(&self, worker_id: WorkerId) -> Option> { + let (tx, rx) = oneshot::channel(); + + self.msgs_tx + .send(WorkerMessage::StealRequest(tx)) + .await + .expect("Worker channel closed trying to steal task"); + + rx.await + .expect("Worker channel closed trying to steal task") + .map(|task_work_state| { + trace!( + "Worker stole task: \ + ", + self.worker_id, + task_work_state.task.id() + ); + task_work_state.change_worker(worker_id); + task_work_state + }) + } +} + +pub(crate) struct WorkStealer { + worker_comms: Arc>>, +} + +impl Clone for WorkStealer { + fn clone(&self) -> Self { + Self { + worker_comms: Arc::clone(&self.worker_comms), + } + } +} + +impl WorkStealer { + pub fn new(worker_comms: Vec>) -> Self { + Self { + worker_comms: Arc::new(worker_comms), + } + } + + pub async fn steal(&self, worker_id: WorkerId) -> Option> { + let total_workers = self.worker_comms.len(); + + for worker_comm in self + .worker_comms + .iter() + // Cycling over the workers + .cycle() + // Starting from the next worker id + .skip(worker_id) + // Taking the total amount of workers + .take(total_workers) + // Removing the current worker as we can't steal from ourselves + .filter(|worker_comm| worker_comm.worker_id != worker_id) + { + trace!( + "Trying to steal from worker ", + worker_comm.worker_id + ); + + if let Some(task) = worker_comm.steal_task(worker_id).await { + return Some(task); + } else { + trace!( + "Worker has no tasks to steal", + worker_comm.worker_id + ); + } + } + + None + } + + pub fn workers_count(&self) -> usize { + self.worker_comms.len() + } +} + +struct TaskRunnerOutput { + task_work_state: TaskWorkState, + status: InternalTaskExecStatus, +} + +enum RunnerMessage { + TaskOutput(TaskId, Result, ()>), + StoleTask(Option>), +} diff --git a/crates/task-system/src/worker/run.rs b/crates/task-system/src/worker/run.rs new file mode 100644 index 000000000..3f76f2070 --- /dev/null +++ b/crates/task-system/src/worker/run.rs @@ -0,0 +1,113 @@ +use std::pin::pin; + +use async_channel as chan; +use futures::StreamExt; +use futures_concurrency::stream::Merge; +use tokio::time::{interval_at, Instant}; +use tokio_stream::wrappers::IntervalStream; +use tracing::{error, warn}; + +use super::{ + super::{error::RunError, message::WorkerMessage, system::SystemComm}, + runner::Runner, + RunnerMessage, WorkStealer, WorkerId, ONE_SECOND, +}; + +pub(super) async fn run( + id: WorkerId, + system_comm: SystemComm, + work_stealer: WorkStealer, + msgs_rx: chan::Receiver>, +) { + let (mut runner, runner_rx) = Runner::new(id, work_stealer, system_comm); + + let mut idle_checker_interval = interval_at(Instant::now(), ONE_SECOND); + idle_checker_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + enum StreamMessage { + Commands(WorkerMessage), + RunnerMsg(RunnerMessage), + IdleCheck, + } + + let mut msg_stream = pin!(( + msgs_rx.map(StreamMessage::Commands), + runner_rx.map(StreamMessage::RunnerMsg), + IntervalStream::new(idle_checker_interval).map(|_| StreamMessage::IdleCheck), + ) + .merge()); + + while let Some(msg) = msg_stream.next().await { + match msg { + // Worker messages + StreamMessage::Commands(WorkerMessage::NewTask(task_work_state)) => { + runner.abort_steal_task(); + runner.new_task(task_work_state).await; + } + + StreamMessage::Commands(WorkerMessage::TaskCountRequest(tx)) => { + if tx.send(runner.total_tasks()).is_err() { + warn!("Task count request channel closed before sending task count"); + } + } + + StreamMessage::Commands(WorkerMessage::ResumeTask { task_id, ack }) => { + if ack.send(runner.resume_task(task_id).await).is_err() { + warn!("Resume task channel closed before sending ack"); + } + } + + StreamMessage::Commands(WorkerMessage::PauseNotRunningTask { task_id, ack }) => { + if ack + .send(runner.pause_not_running_task(task_id).await) + .is_err() + { + warn!("Resume task channel closed before sending ack"); + } + } + + StreamMessage::Commands(WorkerMessage::CancelNotRunningTask { task_id, ack }) => { + if ack + .send(runner.cancel_not_running_task(task_id).await) + .is_err() + { + warn!("Resume task channel closed before sending ack"); + } + } + + StreamMessage::Commands(WorkerMessage::ForceAbortion { task_id, ack }) => { + if ack.send(runner.force_task_abortion(task_id).await).is_err() { + warn!("Force abortion channel closed before sending ack"); + } + } + + StreamMessage::Commands(WorkerMessage::ShutdownRequest(tx)) => { + return runner.shutdown(tx).await; + } + + StreamMessage::Commands(WorkerMessage::StealRequest(tx)) => runner.steal_request(tx), + + StreamMessage::Commands(WorkerMessage::WakeUp) => runner.wake_up().await, + + // Runner messages + StreamMessage::RunnerMsg(RunnerMessage::TaskOutput(task_id, Ok(output))) => { + runner.process_task_output(task_id, output).await + } + + StreamMessage::RunnerMsg(RunnerMessage::TaskOutput(task_id, Err(()))) => { + error!("Task failed "); + + runner.clean_suspended_task(task_id); + + runner.dispatch_next_task(task_id).await; + } + + StreamMessage::RunnerMsg(RunnerMessage::StoleTask(maybe_new_task)) => { + runner.process_stolen_task(maybe_new_task).await; + } + + // Idle checking to steal some work + StreamMessage::IdleCheck => runner.idle_check().await, + } + } +} diff --git a/crates/task-system/src/worker/runner.rs b/crates/task-system/src/worker/runner.rs new file mode 100644 index 000000000..e42791cf4 --- /dev/null +++ b/crates/task-system/src/worker/runner.rs @@ -0,0 +1,1408 @@ +use std::{ + collections::{HashMap, VecDeque}, + future::pending, + pin::pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_channel as chan; +use futures::StreamExt; +use futures_concurrency::future::Race; +use tokio::{ + spawn, + sync::oneshot, + task::{JoinError, JoinHandle}, + time::{timeout, Instant}, +}; +use tracing::{debug, error, trace, warn}; + +use super::{ + super::{ + error::{RunError, SystemError}, + system::SystemComm, + task::{ + ExecStatus, InternalTaskExecStatus, Task, TaskId, TaskOutput, TaskStatus, TaskWorkState, + }, + }, + RunnerMessage, TaskRunnerOutput, WorkStealer, WorkerId, ONE_SECOND, +}; + +const TEN_SECONDS: Duration = Duration::from_secs(10); +const ONE_MINUTE: Duration = Duration::from_secs(60); + +const TASK_QUEUE_INITIAL_SIZE: usize = 64; +const PRIORITY_TASK_QUEUE_INITIAL_SIZE: usize = 32; +const ABORT_AND_SUSPEND_MAP_INITIAL_SIZE: usize = 8; + +pub(super) enum TaskAddStatus { + Running, + Enqueued, +} + +struct AbortAndSuspendSignalers { + abort_tx: oneshot::Sender>>, + suspend_tx: oneshot::Sender<()>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum PendingTaskKind { + Normal, + Priority, + Suspended, +} + +impl PendingTaskKind { + fn with_priority(has_priority: bool) -> Self { + if has_priority { + Self::Priority + } else { + Self::Normal + } + } +} + +struct RunningTask { + task_id: TaskId, + task_kind: PendingTaskKind, + handle: JoinHandle<()>, +} + +fn dispatch_steal_request( + worker_id: WorkerId, + work_stealer: WorkStealer, + runner_tx: chan::Sender>, +) -> JoinHandle<()> { + spawn(async move { + runner_tx + .send(RunnerMessage::StoleTask( + work_stealer.steal(worker_id).await, + )) + .await + .expect("runner channel closed before send stolen task"); + }) +} + +enum WaitingSuspendedTask { + Task(TaskId), + None, +} + +impl WaitingSuspendedTask { + fn is_waiting(&self) -> bool { + matches!(self, Self::Task(_)) + } +} + +pub(super) struct Runner { + worker_id: WorkerId, + system_comm: SystemComm, + work_stealer: WorkStealer, + task_kinds: HashMap, + tasks: VecDeque>, + paused_tasks: HashMap>, + suspended_task: Option>, + priority_tasks: VecDeque>, + last_requested_help: Instant, + is_idle: bool, + waiting_suspension: WaitingSuspendedTask, + abort_and_suspend_map: HashMap, + runner_tx: chan::Sender>, + current_task_handle: Option, + suspend_on_shutdown_rx: chan::Receiver>, + current_steal_task_handle: Option>, + last_steal_attempt_at: Instant, + steal_attempts_count: u32, +} + +impl Runner { + pub(super) fn new( + worker_id: WorkerId, + work_stealer: WorkStealer, + system_comm: SystemComm, + ) -> (Self, chan::Receiver>) { + let (runner_tx, runner_rx) = chan::bounded(8); + + ( + Self { + worker_id, + system_comm, + work_stealer, + task_kinds: HashMap::with_capacity(TASK_QUEUE_INITIAL_SIZE), + tasks: VecDeque::with_capacity(TASK_QUEUE_INITIAL_SIZE), + paused_tasks: HashMap::new(), + suspended_task: None, + priority_tasks: VecDeque::with_capacity(PRIORITY_TASK_QUEUE_INITIAL_SIZE), + last_requested_help: Instant::now(), + is_idle: true, + waiting_suspension: WaitingSuspendedTask::None, + abort_and_suspend_map: HashMap::with_capacity(ABORT_AND_SUSPEND_MAP_INITIAL_SIZE), + runner_tx, + current_task_handle: None, + suspend_on_shutdown_rx: runner_rx.clone(), + current_steal_task_handle: None, + last_steal_attempt_at: Instant::now(), + steal_attempts_count: 0, + }, + runner_rx, + ) + } + + pub(super) fn total_tasks(&self) -> usize { + let priority_tasks_count = self.priority_tasks.len(); + let current_task_count = if self.current_task_handle.is_some() { + 1 + } else { + 0 + }; + let suspended_task_count = if self.suspended_task.is_some() { 1 } else { 0 }; + let tasks_count = self.tasks.len(); + + trace!( + "Task count: \ + ", + self.worker_id + ); + + priority_tasks_count + current_task_count + suspended_task_count + tasks_count + } + + pub(super) fn spawn_task_runner( + &mut self, + task_id: TaskId, + task_work_state: TaskWorkState, + ) -> JoinHandle<()> { + let (abort_tx, abort_rx) = oneshot::channel(); + let (suspend_tx, suspend_rx) = oneshot::channel(); + + self.abort_and_suspend_map.insert( + task_id, + AbortAndSuspendSignalers { + abort_tx, + suspend_tx, + }, + ); + + let handle = spawn(run_single_task( + self.worker_id, + task_work_state, + self.runner_tx.clone(), + suspend_rx, + abort_rx, + )); + + trace!( + "Task runner spawned: ", + self.worker_id + ); + + handle + } + + pub(super) async fn new_task(&mut self, task_work_state: TaskWorkState) { + let task_id = task_work_state.task.id(); + let new_kind = PendingTaskKind::with_priority(task_work_state.task.with_priority()); + + trace!( + "Received new task: ", + self.worker_id + ); + + self.task_kinds.insert(task_id, new_kind); + + match self + .inner_add_task(task_id, new_kind, task_work_state) + .await + { + TaskAddStatus::Running => trace!( + "Task running: ", + self.worker_id + ), + TaskAddStatus::Enqueued => trace!( + "Task enqueued: ", + self.worker_id + ), + } + } + + pub(super) async fn resume_task(&mut self, task_id: TaskId) -> Result<(), SystemError> { + trace!( + "Resume task request: ", + self.worker_id + ); + if let Some(task_work_state) = self.paused_tasks.remove(&task_id) { + task_work_state.worktable.set_unpause(); + + match self + .inner_add_task( + task_id, + *self + .task_kinds + .get(&task_id) + .expect("we added the task kind before pausing it"), + task_work_state, + ) + .await + { + TaskAddStatus::Running => trace!( + "Resumed task is running: ", + self.worker_id + ), + TaskAddStatus::Enqueued => trace!( + "Resumed task was enqueued: ", + self.worker_id + ), + } + + Ok(()) + } else { + trace!( + "Task not found: ", + self.worker_id + ); + Err(SystemError::TaskNotFound(task_id)) + } + } + + pub(super) async fn pause_not_running_task( + &mut self, + task_id: TaskId, + ) -> Result<(), SystemError> { + trace!( + "Pause not running task request: ", + self.worker_id + ); + + if self.paused_tasks.contains_key(&task_id) { + trace!( + "Task is already paused: ", + self.worker_id + ); + return Ok(()); + } + + if let Some(current_task) = &self.current_task_handle { + if current_task.task_id == task_id { + trace!( + "Task began to run before we managed to pause it, run function will pause it: \ + ", + self.worker_id + ); + return Ok(()); // The task will pause itself + } + } + + if let Some(suspended_task) = &self.suspended_task { + if suspended_task.task.id() == task_id { + trace!( + "Task is already suspended but will be paused: ", + self.worker_id + ); + + self.paused_tasks.insert( + task_id, + self.suspended_task.take().expect("we just checked it"), + ); + + return Ok(()); + } + } + + if let Some(index) = self + .priority_tasks + .iter() + .position(|task_work_state| task_work_state.task.id() == task_id) + { + self.paused_tasks.insert( + task_id, + self.priority_tasks + .remove(index) + .expect("we just checked it"), + ); + + return Ok(()); + } + + if let Some(index) = self + .tasks + .iter() + .position(|task_work_state| task_work_state.task.id() == task_id) + { + self.paused_tasks.insert( + task_id, + self.tasks.remove(index).expect("we just checked it"), + ); + + return Ok(()); + } + + Err(SystemError::TaskNotFound(task_id)) + } + + pub(super) async fn cancel_not_running_task( + &mut self, + task_id: TaskId, + ) -> Result<(), SystemError> { + trace!( + "Cancel not running task request: ", + self.worker_id + ); + + if let Some(current_task) = &self.current_task_handle { + if current_task.task_id == task_id { + trace!( + "Task began to run before we managed to cancel it, run function will cancel it: \ + ", + self.worker_id + ); + return Ok(()); // The task will cancel itself + } + } + + if let Some(suspended_task) = &self.suspended_task { + if suspended_task.task.id() == task_id { + trace!( + "Task is already suspended but will be paused: ", + self.worker_id + ); + + send_cancel_task_response( + self.worker_id, + task_id, + self.suspended_task.take().expect("we just checked it"), + ); + + return Ok(()); + } + } + + if let Some(index) = self + .priority_tasks + .iter() + .position(|task_work_state| task_work_state.task.id() == task_id) + { + send_cancel_task_response( + self.worker_id, + task_id, + self.priority_tasks + .remove(index) + .expect("we just checked it"), + ); + + return Ok(()); + } + + if let Some(index) = self + .tasks + .iter() + .position(|task_work_state| task_work_state.task.id() == task_id) + { + send_cancel_task_response( + self.worker_id, + task_id, + self.tasks.remove(index).expect("we just checked it"), + ); + + return Ok(()); + } + + // If the task is not found, then it's possible that the user already canceled it but still have the handle + Ok(()) + } + + #[inline(always)] + pub(super) async fn inner_add_task( + &mut self, + task_id: TaskId, + task_kind: PendingTaskKind, + task_work_state: TaskWorkState, + ) -> TaskAddStatus { + if self.is_idle { + trace!( + "Idle worker will process the new task: ", + self.worker_id + ); + let handle = self.spawn_task_runner(task_id, task_work_state); + + self.current_task_handle = Some(RunningTask { + task_id, + task_kind, + handle, + }); + + // Doesn't need to report working back to system as it already registered + // that we're not idle anymore when it dispatched the task to this worker + self.is_idle = false; + + TaskAddStatus::Running + } else { + let RunningTask { + task_id: old_task_id, + task_kind: old_kind, + .. + } = self + .current_task_handle + .as_ref() + .expect("Worker isn't idle, but no task is running"); + + trace!( + "Worker is busy: \ + ", + self.worker_id, + ); + + let add_status = match (task_kind, old_kind) { + (PendingTaskKind::Priority, PendingTaskKind::Priority) => { + trace!( + "Old and new tasks have priority, will put new task on priority queue: \ + ", + self.worker_id + ); + self.priority_tasks.push_front(task_work_state); + + TaskAddStatus::Enqueued + } + (PendingTaskKind::Priority, PendingTaskKind::Normal) => { + if !self.waiting_suspension.is_waiting() { + trace!( + "Old task will be suspended: \ + ", + self.worker_id + ); + + // We put the query at the top of the priority queue, so it will be + // dispatched by the run function as soon as the current task is suspended + self.priority_tasks.push_front(task_work_state); + + if self + .abort_and_suspend_map + .remove(old_task_id) + .expect("we always store the abort and suspend signalers") + .suspend_tx + .send(()) + .is_err() + { + warn!( + "Task suspend channel closed before receiving suspend signal. \ + This probably happened because the task finished before we could suspend it." + ); + } + + self.waiting_suspension = WaitingSuspendedTask::Task(*old_task_id); + } else { + trace!( + "Worker is already waiting for a task to be suspended, will enqueue new task: \ + ", + self.worker_id + ); + + self.priority_tasks.push_front(task_work_state); + } + + TaskAddStatus::Running + } + (_, _) => { + trace!( + "New task doesn't have priority and will be enqueued: \ + ", + self.worker_id, + ); + + self.tasks.push_back(task_work_state); + + TaskAddStatus::Enqueued + } + }; + + let task_count = self.total_tasks(); + + trace!( + "Worker with {task_count} pending tasks: ", + self.worker_id + ); + + if task_count > self.work_stealer.workers_count() + && self.last_requested_help.elapsed() > ONE_SECOND + { + trace!( + "Worker requesting help from the system: \ + ", + self.worker_id + ); + + self.system_comm + .request_help(self.worker_id, task_count) + .await; + + self.last_requested_help = Instant::now(); + } + + add_status + } + } + + pub(super) async fn force_task_abortion( + &mut self, + task_id: uuid::Uuid, + ) -> Result<(), SystemError> { + if let Some(AbortAndSuspendSignalers { abort_tx, .. }) = + self.abort_and_suspend_map.remove(&task_id) + { + let (tx, rx) = oneshot::channel(); + + if abort_tx.send(tx).is_err() { + debug!( + "Failed to send force abortion request, the task probably finished before we could abort it: \ + ", + self.worker_id + ); + + Ok(()) + } else { + match timeout(ONE_SECOND, rx).await { + Ok(Ok(res)) => res, + // If the sender was dropped, then the task finished before we could + // abort it which is fine + Ok(Err(_)) => Ok(()), + Err(_) => Err(SystemError::TaskForcedAbortTimeout(task_id)), + } + } + } else { + trace!( + "Forced abortion of a not running task request: ", + self.worker_id + ); + + if let Some(current_task) = &self.current_task_handle { + if current_task.task_id == task_id { + trace!( + "Task began to run before we managed to abort it, run function will abort it: \ + ", + self.worker_id + ); + return Ok(()); // The task will abort itself + } + } + + if let Some(suspended_task) = &self.suspended_task { + if suspended_task.task.id() == task_id { + trace!( + "Task is already suspended but will be paused: ", + self.worker_id + ); + + send_forced_abortion_task_response( + self.worker_id, + task_id, + self.suspended_task.take().expect("we just checked it"), + ); + + return Ok(()); + } + } + + if let Some(index) = self + .priority_tasks + .iter() + .position(|task_work_state| task_work_state.task.id() == task_id) + { + send_forced_abortion_task_response( + self.worker_id, + task_id, + self.priority_tasks + .remove(index) + .expect("we just checked it"), + ); + + return Ok(()); + } + + if let Some(index) = self + .tasks + .iter() + .position(|task_work_state| task_work_state.task.id() == task_id) + { + send_forced_abortion_task_response( + self.worker_id, + task_id, + self.tasks.remove(index).expect("we just checked it"), + ); + + return Ok(()); + } + + // If the task is not found, then it's possible that the user already aborted it but still have the handle + Ok(()) + } + } + + pub(super) async fn shutdown(mut self, tx: oneshot::Sender<()>) { + trace!( + "Worker beginning shutdown process: ", + self.worker_id + ); + + trace!( + "Aborting steal task for shutdown if there is one running: ", + self.worker_id + ); + + self.abort_steal_task(); + + let Runner { + worker_id, + tasks, + paused_tasks, + priority_tasks, + is_idle, + abort_and_suspend_map, + runner_tx, + mut current_task_handle, + suspend_on_shutdown_rx, + .. + } = self; + + let mut suspend_on_shutdown_rx = pin!(suspend_on_shutdown_rx); + + if !is_idle { + trace!("Worker is busy, will shutdown tasks: "); + + if let Some(RunningTask { + task_id, handle, .. + }) = current_task_handle.take() + { + abort_and_suspend_map.into_iter().for_each( + |(task_id, AbortAndSuspendSignalers { suspend_tx, .. })| { + if suspend_tx.send(()).is_err() { + warn!( + "Shutdown request channel closed before sending abort signal: \ + " + ); + } else { + trace!( + "Sent suspend signal for task on shutdown: \ + " + ); + } + }, + ); + + if let Err(e) = handle.await { + error!("Task failed to join: {e:#?}"); + } + + runner_tx.close(); + + while let Some(runner_msg) = suspend_on_shutdown_rx.next().await { + match runner_msg { + RunnerMessage::TaskOutput(task_id, res) => match res { + Ok(TaskRunnerOutput { + task_work_state, + status, + }) => match status { + InternalTaskExecStatus::Done(out) => send_complete_task_response( + self.worker_id, + task_id, + task_work_state, + out, + ), + + InternalTaskExecStatus::Canceled => { + send_cancel_task_response(worker_id, task_id, task_work_state) + } + + InternalTaskExecStatus::Suspend + | InternalTaskExecStatus::Paused => { + send_shutdown_task_response( + worker_id, + task_id, + task_work_state, + ); + } + + InternalTaskExecStatus::Error(e) => { + send_error_task_response( + worker_id, + task_id, + task_work_state, + e, + ); + } + }, + Err(()) => { + error!( + "Task failed to suspend on shutdown" + ); + } + }, + + RunnerMessage::StoleTask(Some(task_work_state)) => { + send_shutdown_task_response(worker_id, task_id, task_work_state); + } + + RunnerMessage::StoleTask(None) => {} + } + } + } + + priority_tasks + .into_iter() + .chain(paused_tasks.into_values()) + .chain(tasks.into_iter()) + .for_each(|task_work_state| { + send_shutdown_task_response( + worker_id, + task_work_state.task.id(), + task_work_state, + ); + }) + } else { + trace!("Worker is idle, no tasks to shutdown: "); + } + + trace!("Worker shutdown process completed: "); + + if tx.send(()).is_err() { + warn!("Shutdown request channel closed before sending ack"); + } + } + + pub(super) fn get_next_task(&mut self) -> Option<(PendingTaskKind, TaskWorkState)> { + if let Some(task) = self.priority_tasks.pop_front() { + return Some((PendingTaskKind::Priority, task)); + } + + if let Some(task) = self.suspended_task.take() { + task.interrupter.reset(); + task.worktable.set_unpause(); + return Some((PendingTaskKind::Suspended, task)); + } + + self.tasks + .pop_front() + .map(|task| (PendingTaskKind::Normal, task)) + } + + pub(super) fn steal_request(&mut self, tx: oneshot::Sender>>) { + trace!("Steal request: ", self.worker_id); + if let Some((kind, task)) = self.get_next_task() { + let task_id = task.task.id(); + self.task_kinds.remove(&task_id); + + trace!( + "Stealing task: ", + self.worker_id + ); + + if let Err(Some(task)) = tx.send(Some(task)) { + warn!( + "Steal request channel closed before sending task: ", + self.worker_id + ); + match kind { + PendingTaskKind::Normal => self.tasks.push_front(task), + PendingTaskKind::Priority => self.priority_tasks.push_front(task), + PendingTaskKind::Suspended => self.suspended_task = Some(task), + } + + self.task_kinds.insert(task_id, kind); + } + } else { + trace!("No task to steal: ", self.worker_id); + if tx.send(None).is_err() { + warn!( + "Steal request channel closed before sending no task response: \ + ", + self.worker_id + ); + } + } + } + + pub(super) async fn wake_up(&mut self) { + if self.is_idle { + trace!( + "Worker is idle, waking up: ", + self.worker_id + ); + + if self.current_steal_task_handle.is_none() { + self.current_steal_task_handle = Some(dispatch_steal_request( + self.worker_id, + self.work_stealer.clone(), + self.runner_tx.clone(), + )); + } else { + trace!( + "Steal task already running, ignoring wake up request: ", + self.worker_id + ); + } + } else { + trace!( + "Worker already working, ignoring wake up request: ", + self.worker_id + ); + } + } + + #[inline(always)] + pub(super) async fn dispatch_next_task(&mut self, finished_task_id: TaskId) { + trace!( + "Task finished and will try to process a new task: \ + ", + self.worker_id + ); + + self.abort_and_suspend_map.remove(&finished_task_id); + + let RunningTask { + task_id: old_task_id, + + handle, + .. + } = self + .current_task_handle + .take() + .expect("Task handle missing, but task output received"); + + assert_eq!(finished_task_id, old_task_id, "Task output id mismatch"); + + trace!( + "Waiting task handle: ", + self.worker_id + ); + if let Err(e) = handle.await { + error!("Task failed to join: {e:#?}"); + } + trace!( + "Waited task handle: ", + self.worker_id + ); + + if let Some((task_kind, task_work_state)) = self.get_next_task() { + let task_id = task_work_state.task.id(); + + trace!( + "Dispatching next task: ", + self.worker_id + ); + + let handle = self.spawn_task_runner(task_id, task_work_state); + + self.current_task_handle = Some(RunningTask { + task_id, + task_kind, + handle, + }); + } else { + trace!( + "No task to dispatch, worker is now idle and will dispatch a steal request: ", + self.worker_id + ); + + self.is_idle = true; + self.system_comm.idle_report(self.worker_id).await; + + if self.current_steal_task_handle.is_none() { + self.current_steal_task_handle = Some(dispatch_steal_request( + self.worker_id, + self.work_stealer.clone(), + self.runner_tx.clone(), + )); + } else { + trace!( + "Steal task already running: ", + self.worker_id + ); + } + } + } + + pub(super) async fn process_task_output( + &mut self, + task_id: TaskId, + TaskRunnerOutput { + task_work_state, + status, + }: TaskRunnerOutput, + ) { + match status { + InternalTaskExecStatus::Done(out) => { + send_complete_task_response(self.worker_id, task_id, task_work_state, out) + } + + InternalTaskExecStatus::Paused => { + self.paused_tasks.insert(task_id, task_work_state); + trace!( + "Task paused: ", + self.worker_id + ); + } + + InternalTaskExecStatus::Canceled => { + send_cancel_task_response(self.worker_id, task_id, task_work_state) + } + + InternalTaskExecStatus::Error(e) => { + send_error_task_response(self.worker_id, task_id, task_work_state, e) + } + + InternalTaskExecStatus::Suspend => { + self.suspended_task = Some(task_work_state); + trace!( + "Task suspended: ", + self.worker_id + ); + + self.clean_suspended_task(task_id); + } + } + + trace!( + "Processing task output completed and will try to dispatch a new task: \ + ", + self.worker_id + ); + + self.dispatch_next_task(task_id).await; + } + + pub(super) async fn idle_check(&mut self) { + if self.is_idle { + trace!( + "Worker is idle for some time and will try to steal a task: ", + self.worker_id + ); + + if self.current_steal_task_handle.is_none() { + let elapsed = self.last_steal_attempt_at.elapsed(); + let required = (TEN_SECONDS * self.steal_attempts_count).min(ONE_MINUTE); + trace!( + "Steal attempt required cool down: ", + self.worker_id, self.steal_attempts_count); + if elapsed > required { + self.current_steal_task_handle = Some(dispatch_steal_request( + self.worker_id, + self.work_stealer.clone(), + self.runner_tx.clone(), + )); + self.last_steal_attempt_at = Instant::now(); + } else { + trace!( + "Steal attempt still cooling down: ", + self.worker_id, + self.steal_attempts_count + ); + } + } else { + trace!( + "Steal task already running, ignoring on this idle check: ", + self.worker_id + ); + } + + // As we're idle, let's check if we need to do some memory cleanup + if self.tasks.capacity() > TASK_QUEUE_INITIAL_SIZE { + assert_eq!(self.tasks.len(), 0); + self.tasks.shrink_to(TASK_QUEUE_INITIAL_SIZE); + } + + if self.task_kinds.capacity() > TASK_QUEUE_INITIAL_SIZE { + assert_eq!(self.task_kinds.len(), 0); + self.task_kinds.shrink_to(TASK_QUEUE_INITIAL_SIZE); + } + + if self.priority_tasks.capacity() > PRIORITY_TASK_QUEUE_INITIAL_SIZE { + assert_eq!(self.priority_tasks.len(), 0); + self.priority_tasks + .shrink_to(PRIORITY_TASK_QUEUE_INITIAL_SIZE); + } + + if self.paused_tasks.capacity() != self.paused_tasks.len() { + self.paused_tasks.shrink_to_fit(); + } + + if self.abort_and_suspend_map.capacity() > ABORT_AND_SUSPEND_MAP_INITIAL_SIZE { + assert!(self.abort_and_suspend_map.len() < ABORT_AND_SUSPEND_MAP_INITIAL_SIZE); + self.abort_and_suspend_map + .shrink_to(ABORT_AND_SUSPEND_MAP_INITIAL_SIZE); + } + } + } + + pub(super) fn abort_steal_task(&mut self) { + if let Some(steal_task_handle) = self.current_steal_task_handle.take() { + steal_task_handle.abort(); + trace!("Aborted steal task: ", self.worker_id); + } else { + trace!("No steal task to abort: ", self.worker_id); + } + } + + pub(super) async fn process_stolen_task(&mut self, maybe_new_task: Option>) { + if let Some(steal_task_handle) = self.current_steal_task_handle.take() { + if let Err(e) = steal_task_handle.await { + error!("Steal task failed to join: {e:#?}"); + } + } + + if let Some(task_work_state) = maybe_new_task { + self.system_comm.working_report(self.worker_id).await; + trace!( + "Stolen task: ", + self.worker_id, + task_work_state.task.id() + ); + self.steal_attempts_count = 0; + self.new_task(task_work_state).await; + } else { + self.steal_attempts_count += 1; + } + } + + pub(crate) fn clean_suspended_task(&mut self, task_id: uuid::Uuid) { + match self.waiting_suspension { + WaitingSuspendedTask::Task(waiting_task_id) if waiting_task_id == task_id => { + trace!( + "Task was suspended and will be cleaned: ", + self.worker_id + ); + self.waiting_suspension = WaitingSuspendedTask::None; + } + WaitingSuspendedTask::Task(_) => { + trace!( + "Task wasn't suspended, ignoring: ", + self.worker_id + ); + } + WaitingSuspendedTask::None => {} + } + } +} + +async fn run_single_task( + worker_id: WorkerId, + TaskWorkState { + mut task, + worktable, + interrupter, + done_tx, + }: TaskWorkState, + runner_tx: chan::Sender>, + suspend_rx: oneshot::Receiver<()>, + abort_rx: oneshot::Receiver>>, +) { + let task_id = task.id(); + + worktable.set_started(); + + trace!("Running task: "); + + let handle = spawn({ + let interrupter = Arc::clone(&interrupter); + + let already_paused = worktable.is_paused(); + let already_canceled = worktable.is_canceled(); + let already_aborted = worktable.is_aborted(); + + async move { + if already_paused { + trace!( + "Task was paused before running: " + ); + + (task, Ok(Ok(ExecStatus::Paused))) + } else if already_canceled { + trace!( + "Task was canceled before running: " + ); + + (task, Ok(Ok(ExecStatus::Canceled))) + } else if already_aborted { + trace!( + "Task was aborted before running: " + ); + + (task, Err(SystemError::TaskAborted(task_id))) + } else { + let res = task.run(&interrupter).await; + + trace!("Ran task: : {res:?}"); + + (task, Ok(res)) + } + } + }); + + let task_abort_handle = handle.abort_handle(); + + let has_suspended = Arc::new(AtomicBool::new(false)); + + let suspender_handle = spawn({ + let has_suspended = Arc::clone(&has_suspended); + let worktable = Arc::clone(&worktable); + async move { + if suspend_rx.await.is_ok() { + let (tx, rx) = oneshot::channel(); + + trace!("Suspend signal received: "); + + // The interrupter only knows about Pause and Cancel commands, we use pause as + // the suspend task feature should be invisible to the user + worktable.pause(tx).await; + + match rx.await { + Ok(Ok(())) => { + trace!("Suspending: "); + has_suspended.store(true, Ordering::Relaxed); + } + Ok(Err(e)) => { + error!( + "Task failed to suspend: {e:#?}", + ); + } + Err(_) => { + // The task probably finished before we could suspend it so the channel was dropped + trace!("Suspend channel closed: "); + } + } + } else { + trace!( + "Suspend channel closed, task probably finished before we could suspend it: \ + " + ); + } + } + }); + + type SpawnedTaskRunOutput = (Box>, Result, SystemError>); + + enum RaceOutput { + Completed(Result, JoinError>), + Abort(oneshot::Sender>), + } + + match (async { RaceOutput::Completed(handle.await) }, async move { + if let Ok(tx) = abort_rx.await { + trace!("Aborting task: "); + RaceOutput::Abort(tx) + } else { + // If the abort channel is closed, we should just ignore it and keep waiting for the task to finish + // as we're being suspended by the worker + trace!( + "Abort channel closed, will wait for task to finish: " + ); + pending().await + } + }) + .race() + .await + { + RaceOutput::Completed(Ok((task, Ok(res)))) => { + trace!( + "Task completed ok: " + ); + runner_tx + .send(RunnerMessage::TaskOutput(task_id, { + let mut internal_status = res.into(); + + if matches!(internal_status, InternalTaskExecStatus::Paused) + && has_suspended.load(Ordering::Relaxed) + { + internal_status = InternalTaskExecStatus::Suspend; + } + + Ok(TaskRunnerOutput { + task_work_state: TaskWorkState { + task, + worktable, + interrupter, + done_tx, + }, + status: internal_status, + }) + })) + .await + .expect("Task runner channel closed while sending task output"); + } + + RaceOutput::Completed(Ok((_, Err(e)))) => { + trace!("Task had an error: "); + + if done_tx + .send(if matches!(e, SystemError::TaskAborted(_)) { + Ok(TaskStatus::ForcedAbortion) + } else { + Err(e) + }) + .is_err() + { + error!("Task done channel closed while sending error response"); + } + + runner_tx + .send(RunnerMessage::TaskOutput(task_id, Err(()))) + .await + .expect("Task runner channel closed while sending task output"); + } + + RaceOutput::Completed(Err(join_error)) => { + error!("Task failed to join: {join_error:#?}",); + if done_tx.send(Err(SystemError::TaskJoin(task_id))).is_err() { + error!("Task done channel closed while sending join error response"); + } + + if runner_tx + .send(RunnerMessage::TaskOutput(task_id, Err(()))) + .await + .is_err() + { + error!("Task runner channel closed while sending join error response"); + } + } + + RaceOutput::Abort(tx) => { + task_abort_handle.abort(); + + trace!("Task aborted: "); + + if done_tx.send(Ok(TaskStatus::ForcedAbortion)).is_err() { + error!("Task done channel closed while sending abort error response"); + } + + if runner_tx + .send(RunnerMessage::TaskOutput(task_id, Err(()))) + .await + .is_err() + { + error!("Task runner channel closed while sending abort error response"); + } + + if tx.send(Ok(())).is_err() { + error!("Task abort channel closed while sending abort error response"); + } + } + } + + if !suspender_handle.is_finished() { + trace!( + "Aborting suspender handler as it isn't needed anymore: " + ); + // if we received a suspend signal this abort will do nothing, as the task finished already + suspender_handle.abort(); + } + + trace!("Run single task finished: "); +} + +fn send_complete_task_response( + worker_id: WorkerId, + task_id: TaskId, + TaskWorkState { + done_tx, worktable, .. + }: TaskWorkState, + out: TaskOutput, +) { + worktable.set_completed(); + if done_tx.send(Ok(TaskStatus::Done(out))).is_err() { + warn!( + "Task done channel closed before sending done response for task: \ + " + ); + } else { + trace!( + "Emitted task done signal on shutdown: \ + " + ); + } +} + +fn send_cancel_task_response( + worker_id: WorkerId, + task_id: TaskId, + TaskWorkState { + done_tx, worktable, .. + }: TaskWorkState, +) { + worktable.set_completed(); + if done_tx.send(Ok(TaskStatus::Canceled)).is_err() { + warn!( + "Task done channel closed before sending canceled response for task: \ + ", + ); + } else { + trace!( + "Emitted task canceled signal on cancel not running task: \ + ", + ); + } +} + +fn send_shutdown_task_response( + worker_id: WorkerId, + task_id: TaskId, + TaskWorkState { task, done_tx, .. }: TaskWorkState, +) { + if done_tx.send(Ok(TaskStatus::Shutdown(task))).is_err() { + warn!( + "Task done channel closed before sending shutdown response for task: \ + " + ); + } else { + trace!( + "Successfully suspended and sent back DynTask on worker shutdown: \ + " + ); + } +} + +fn send_error_task_response( + worker_id: usize, + task_id: uuid::Uuid, + TaskWorkState { + done_tx, worktable, .. + }: TaskWorkState, + e: E, +) { + worktable.set_completed(); + if done_tx.send(Ok(TaskStatus::Error(e))).is_err() { + warn!( + "Task done channel closed before sending error response for task: \ + " + ); + } else { + trace!( + "Emitted task error signal on shutdown: \ + " + ); + } +} + +fn send_forced_abortion_task_response( + worker_id: WorkerId, + task_id: TaskId, + TaskWorkState { + done_tx, worktable, .. + }: TaskWorkState, +) { + worktable.set_completed(); + if done_tx.send(Ok(TaskStatus::ForcedAbortion)).is_err() { + warn!( + "Task done channel closed before sending forced abortion response for task: \ + ", + ); + } else { + trace!( + "Emitted task forced abortion signal on cancel not running task: \ + ", + ); + } +} diff --git a/crates/task-system/tests/common/actors.rs b/crates/task-system/tests/common/actors.rs new file mode 100644 index 000000000..f166b70c1 --- /dev/null +++ b/crates/task-system/tests/common/actors.rs @@ -0,0 +1,389 @@ +use sd_task_system::{ + ExecStatus, Interrupter, Task, TaskDispatcher, TaskHandle, TaskId, TaskOutput, TaskStatus, +}; + +use std::{ + path::{Path, PathBuf}, + sync::Arc, + time::Duration, +}; + +use async_channel as chan; +use async_trait::async_trait; +use futures::stream::{self, FuturesUnordered, StreamExt}; +use futures_concurrency::future::Race; +use serde::{Deserialize, Serialize}; +use tokio::{fs, spawn, sync::broadcast}; +use tracing::{error, info, trace, warn}; + +use crate::common::tasks::TimedTaskOutput; + +use super::tasks::{SampleError, TimeTask}; + +const SAMPLE_ACTOR_SAVE_STATE_FILE_NAME: &str = "sample_actor_save_state.bin"; + +pub struct SampleActor { + data: Arc, // Can hold any kind of actor data, like an AI model + task_dispatcher: TaskDispatcher, + task_handles_tx: chan::Sender>, +} + +impl SampleActor { + pub async fn new( + data_directory: impl AsRef, + data: String, + task_dispatcher: TaskDispatcher, + ) -> (Self, broadcast::Receiver<()>) { + let (task_handles_tx, task_handles_rx) = chan::bounded(8); + + let (idle_tx, idle_rx) = broadcast::channel(1); + + let save_state_file_path = data_directory + .as_ref() + .join(SAMPLE_ACTOR_SAVE_STATE_FILE_NAME); + + let data = Arc::new(data); + + let pending_tasks = fs::read(&save_state_file_path) + .await + .map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + info!("No saved actor tasks found"); + } else { + error!("Failed to read saved actor tasks: {e:#?}"); + } + }) + .ok() + .and_then(|data| { + rmp_serde::from_slice::>(&data) + .map_err(|e| { + error!("Failed to deserialize saved actor tasks: {e:#?}"); + }) + .ok() + }) + .unwrap_or_default(); + + spawn(Self::run(save_state_file_path, task_handles_rx, idle_tx)); + + for SampleActorTaskSaveState { + id, + duration, + has_priority, + paused_count, + } in pending_tasks + { + task_handles_tx + .send(if has_priority { + task_dispatcher + .dispatch(SampleActorTaskWithPriority::with_id( + id, + duration, + Arc::clone(&data), + paused_count, + )) + .await + } else { + task_dispatcher + .dispatch(SampleActorTask::with_id( + id, + duration, + Arc::clone(&data), + paused_count, + )) + .await + }) + .await + .expect("Task handle receiver dropped"); + } + + ( + Self { + data, + task_dispatcher, + task_handles_tx, + }, + idle_rx, + ) + } + + pub fn new_task(&self, duration: Duration) -> SampleActorTask { + SampleActorTask::new(duration, Arc::clone(&self.data)) + } + + pub fn new_priority_task(&self, duration: Duration) -> SampleActorTaskWithPriority { + SampleActorTaskWithPriority::new(duration, Arc::clone(&self.data)) + } + + async fn inner_process(&self, duration: Duration, has_priority: bool) { + self.task_handles_tx + .send(if has_priority { + self.task_dispatcher + .dispatch(self.new_priority_task(duration)) + .await + } else { + self.task_dispatcher.dispatch(self.new_task(duration)).await + }) + .await + .expect("Task handle receiver dropped"); + } + + pub async fn process(&self, duration: Duration) { + self.inner_process(duration, false).await + } + + pub async fn process_with_priority(&self, duration: Duration) { + self.inner_process(duration, true).await + } + + async fn run( + save_state_file_path: PathBuf, + task_handles_rx: chan::Receiver>, + idle_tx: broadcast::Sender<()>, + ) { + let mut handles = FuturesUnordered::>::new(); + + enum RaceOutput { + NewHandle(TaskHandle), + CompletedHandle, + Stop(Option>>), + } + + let mut pending = 0usize; + + loop { + match ( + async { + if let Ok(handle) = task_handles_rx.recv().await { + RaceOutput::NewHandle(handle) + } else { + RaceOutput::Stop(None) + } + }, + async { + if let Some(out) = handles.next().await { + match out { + Ok(TaskStatus::Done(maybe_out)) => { + if let TaskOutput::Out(out) = maybe_out { + info!( + "Task completed: {:?}", + out.downcast::() + .expect("we know the task type") + ); + } + } + Ok(TaskStatus::Canceled) => { + trace!("Task was canceled") + } + Ok(TaskStatus::ForcedAbortion) => { + warn!("Task was forcibly aborted"); + } + Ok(TaskStatus::Shutdown(task)) => { + // If a task was shutdown, it means the task system is shutting down + // so all other tasks will also be shutdown + + return RaceOutput::Stop(Some(task)); + } + Ok(TaskStatus::Error(e)) => { + error!("Task failed: {e:#?}"); + } + Err(e) => { + error!("Task system failed: {e:#?}"); + } + } + + RaceOutput::CompletedHandle + } else { + RaceOutput::Stop(None) + } + }, + ) + .race() + .await + { + RaceOutput::NewHandle(handle) => { + pending += 1; + info!("Received new task handle, total pending tasks: {pending}"); + handles.push(handle); + } + RaceOutput::CompletedHandle => { + pending -= 1; + info!("Task completed, total pending tasks: {pending}"); + if pending == 0 { + info!("All tasks completed, sending idle report..."); + idle_tx.send(()).expect("idle receiver dropped"); + } + } + RaceOutput::Stop(maybe_task) => { + task_handles_rx.close(); + task_handles_rx + .for_each(|handle| async { handles.push(handle) }) + .await; + + let tasks = stream::iter( + maybe_task + .into_iter() + .map(SampleActorTaskSaveState::from_task), + ) + .chain(handles.filter_map(|handle| async move { + match handle { + Ok(TaskStatus::Done(maybe_out)) => { + if let TaskOutput::Out(out) = maybe_out { + info!( + "Task completed: {:?}", + out.downcast::() + .expect("we know the task type") + ); + } + + None + } + Ok(TaskStatus::Canceled) => None, + Ok(TaskStatus::ForcedAbortion) => { + warn!("Task was forcibly aborted"); + None + } + Ok(TaskStatus::Shutdown(task)) => { + Some(SampleActorTaskSaveState::from_task(task)) + } + Ok(TaskStatus::Error(e)) => { + error!("Task failed: {e:#?}"); + None + } + Err(e) => { + error!("Task system failed: {e:#?}"); + None + } + } + })) + .collect::>() + .await; + + if let Err(e) = fs::write( + &save_state_file_path, + rmp_serde::to_vec_named(&tasks).expect("failed to serialize"), + ) + .await + { + error!("Failed to save actor tasks: {e:#?}"); + } + + return; + } + } + } + } +} + +impl Drop for SampleActor { + fn drop(&mut self) { + self.task_handles_tx.close(); + } +} + +#[derive(Debug)] +pub struct SampleActorTask { + timed_task: TimeTask, + actor_data: Arc, // Can hold any kind of actor data +} + +impl SampleActorTask { + pub fn new(duration: Duration, actor_data: Arc) -> Self { + Self { + timed_task: TimeTask::new(duration, false), + actor_data, + } + } + + fn with_id(id: TaskId, duration: Duration, actor_data: Arc, paused_count: u32) -> Self { + Self { + timed_task: TimeTask::with_id(id, duration, false, paused_count), + actor_data, + } + } +} + +#[derive(Debug)] +pub struct SampleActorTaskWithPriority { + timed_task: TimeTask, + actor_data: Arc, // Can hold any kind of actor data +} +impl SampleActorTaskWithPriority { + fn new(duration: Duration, actor_data: Arc) -> SampleActorTaskWithPriority { + Self { + timed_task: TimeTask::new(duration, true), + actor_data, + } + } + + fn with_id(id: TaskId, duration: Duration, actor_data: Arc, paused_count: u32) -> Self { + Self { + timed_task: TimeTask::with_id(id, duration, true, paused_count), + actor_data, + } + } +} + +#[async_trait] +impl Task for SampleActorTask { + fn id(&self) -> TaskId { + self.timed_task.id() + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + info!("Actor data: {:#?}", self.actor_data); + self.timed_task.run(interrupter).await + } + + fn with_priority(&self) -> bool { + self.timed_task.with_priority() + } +} + +#[async_trait] +impl Task for SampleActorTaskWithPriority { + fn id(&self) -> TaskId { + self.timed_task.id() + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + info!("Actor data: {:#?}", self.actor_data); + self.timed_task.run(interrupter).await + } + + fn with_priority(&self) -> bool { + self.timed_task.with_priority() + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct SampleActorTaskSaveState { + id: TaskId, + duration: Duration, + has_priority: bool, + paused_count: u32, +} + +impl SampleActorTaskSaveState { + fn from_task(dyn_task: Box>) -> Self { + match dyn_task.downcast::() { + Ok(concrete_task) => SampleActorTaskSaveState { + id: concrete_task.timed_task.id(), + duration: concrete_task.timed_task.duration, + has_priority: false, + paused_count: concrete_task.timed_task.paused_count, + }, + Err(dyn_task) => { + let concrete_task = dyn_task + .downcast::() + .expect("we know the task type"); + + SampleActorTaskSaveState { + id: concrete_task.timed_task.id(), + duration: concrete_task.timed_task.duration, + has_priority: true, + paused_count: concrete_task.timed_task.paused_count, + } + } + } + } +} diff --git a/crates/task-system/tests/common/jobs.rs b/crates/task-system/tests/common/jobs.rs new file mode 100644 index 000000000..9792fa943 --- /dev/null +++ b/crates/task-system/tests/common/jobs.rs @@ -0,0 +1,119 @@ +use async_trait::async_trait; +use futures_concurrency::future::FutureGroup; +use lending_stream::{LendingStream, StreamExt}; +use sd_task_system::{ + ExecStatus, Interrupter, IntoAnyTaskOutput, Task, TaskDispatcher, TaskHandle, TaskId, + TaskOutput, TaskStatus, +}; +use tracing::trace; + +use super::tasks::SampleError; + +#[derive(Debug)] +pub struct SampleJob { + total_steps: u32, + task_dispatcher: TaskDispatcher, +} + +impl SampleJob { + pub fn new(total_steps: u32, task_dispatcher: TaskDispatcher) -> Self { + Self { + total_steps, + task_dispatcher, + } + } + + pub async fn run(self) -> Result<(), SampleError> { + let Self { + total_steps, + task_dispatcher, + } = self; + + let initial_steps = (0..task_dispatcher.workers_count()) + .map(|_| SampleJobTask { + id: TaskId::new_v4(), + expected_children: total_steps - 1, + task_dispatcher: task_dispatcher.clone(), + }) + .collect::>(); + + let mut group = FutureGroup::from_iter( + task_dispatcher + .dispatch_many(initial_steps) + .await + .into_iter(), + ) + .lend_mut(); + + while let Some((group, res)) = group.next().await { + match res.unwrap() { + TaskStatus::Done(TaskOutput::Out(out)) => { + group.insert( + out.downcast::() + .expect("we know the output type") + .children_handle, + ); + trace!("Received more tasks to wait for ({} left)", group.len()); + } + TaskStatus::Done(TaskOutput::Empty) => { + trace!( + "Step done, waiting for all children to finish ({} left)", + group.len() + ); + } + + TaskStatus::Canceled => { + trace!("Task was canceled"); + } + TaskStatus::ForcedAbortion => { + trace!("Aborted") + } + TaskStatus::Shutdown(task) => { + trace!("Task was shutdown: {:?}", task); + } + TaskStatus::Error(e) => return Err(e), + } + } + + Ok(()) + } +} + +#[derive(Debug)] +struct SampleJobTask { + id: TaskId, + expected_children: u32, + task_dispatcher: TaskDispatcher, +} + +#[derive(Debug)] +struct Output { + children_handle: TaskHandle, +} + +#[async_trait] +impl Task for SampleJobTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, _interrupter: &Interrupter) -> Result { + if self.expected_children > 0 { + Ok(ExecStatus::Done( + Output { + children_handle: self + .task_dispatcher + .dispatch(SampleJobTask { + id: TaskId::new_v4(), + expected_children: self.expected_children - 1, + task_dispatcher: self.task_dispatcher.clone(), + }) + .await, + } + .into_output(), + )) + } else { + Ok(ExecStatus::Done(TaskOutput::Empty)) + } + } +} diff --git a/crates/task-system/tests/common/mod.rs b/crates/task-system/tests/common/mod.rs new file mode 100644 index 000000000..c94169a48 --- /dev/null +++ b/crates/task-system/tests/common/mod.rs @@ -0,0 +1,3 @@ +pub mod actors; +pub mod jobs; +pub mod tasks; diff --git a/crates/task-system/tests/common/tasks.rs b/crates/task-system/tests/common/tasks.rs new file mode 100644 index 000000000..3d556ee07 --- /dev/null +++ b/crates/task-system/tests/common/tasks.rs @@ -0,0 +1,278 @@ +use std::{future::pending, time::Duration}; + +use sd_task_system::{ + ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, Task, TaskId, TaskOutput, +}; + +use async_trait::async_trait; +use futures_concurrency::future::Race; +use thiserror::Error; +use tokio::{ + sync::oneshot, + time::{sleep, Instant}, +}; +use tracing::{error, info}; + +#[derive(Debug, Error)] +pub enum SampleError { + #[error("Sample error")] + SampleError, +} + +#[derive(Debug)] +pub struct NeverTask { + id: TaskId, +} + +impl Default for NeverTask { + fn default() -> Self { + Self { + id: TaskId::new_v4(), + } + } +} + +#[async_trait] +impl Task for NeverTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + match interrupter.await { + InterruptionKind::Pause => { + info!("Pausing NeverTask ", self.id); + Ok(ExecStatus::Paused) + } + InterruptionKind::Cancel => { + info!("Canceling NeverTask ", self.id); + Ok(ExecStatus::Canceled) + } + } + } +} + +#[derive(Debug)] +pub struct ReadyTask { + id: TaskId, +} + +impl Default for ReadyTask { + fn default() -> Self { + Self { + id: TaskId::new_v4(), + } + } +} + +#[async_trait] +impl Task for ReadyTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, _interrupter: &Interrupter) -> Result { + Ok(ExecStatus::Done(TaskOutput::Empty)) + } +} + +#[derive(Debug)] +pub struct BogusTask { + id: TaskId, +} + +impl Default for BogusTask { + fn default() -> Self { + Self { + id: TaskId::new_v4(), + } + } +} + +#[async_trait] +impl Task for BogusTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, _interrupter: &Interrupter) -> Result { + Err(SampleError::SampleError) + } +} + +#[derive(Debug)] +pub struct TimeTask { + id: TaskId, + pub duration: Duration, + priority: bool, + pub paused_count: u32, +} + +impl TimeTask { + pub fn new(duration: Duration, priority: bool) -> Self { + Self { + id: TaskId::new_v4(), + duration, + priority, + paused_count: 0, + } + } + + pub fn with_id(id: TaskId, duration: Duration, priority: bool, paused_count: u32) -> Self { + Self { + id, + duration, + priority, + paused_count, + } + } +} + +#[derive(Debug)] +pub struct TimedTaskOutput { + pub pauses_count: u32, +} + +#[async_trait] +impl Task for TimeTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + let start = Instant::now(); + + info!("Running timed task for {:#?}", self.duration); + + enum RaceOutput { + Paused(Duration), + Canceled, + Completed, + } + + let task_work_fut = async { + sleep(self.duration).await; + RaceOutput::Completed + }; + + let interrupt_fut = async { + let elapsed = start.elapsed(); + match interrupter.await { + InterruptionKind::Pause => RaceOutput::Paused(if elapsed < self.duration { + self.duration - elapsed + } else { + Duration::ZERO + }), + InterruptionKind::Cancel => RaceOutput::Canceled, + } + }; + + Ok(match (task_work_fut, interrupt_fut).race().await { + RaceOutput::Completed | RaceOutput::Paused(Duration::ZERO) => ExecStatus::Done( + TimedTaskOutput { + pauses_count: self.paused_count, + } + .into_output(), + ), + RaceOutput::Paused(remaining_duration) => { + self.duration = remaining_duration; + self.paused_count += 1; + ExecStatus::Paused + } + RaceOutput::Canceled => ExecStatus::Canceled, + }) + } + + fn with_priority(&self) -> bool { + self.priority + } +} + +#[derive(Debug)] +pub struct PauseOnceTask { + id: TaskId, + has_paused: bool, + began_tx: Option>, +} + +impl PauseOnceTask { + pub fn new() -> (Self, oneshot::Receiver<()>) { + let (tx, rx) = oneshot::channel(); + ( + Self { + id: TaskId::new_v4(), + has_paused: false, + began_tx: Some(tx), + }, + rx, + ) + } +} + +#[async_trait] +impl Task for PauseOnceTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + if let Some(began_tx) = self.began_tx.take() { + if began_tx.send(()).is_err() { + error!("Failed to send began signal"); + } + } + + if !self.has_paused { + self.has_paused = true; + match interrupter.await { + InterruptionKind::Pause => { + info!("Pausing PauseOnceTask ", self.id); + self.has_paused = true; + Ok(ExecStatus::Paused) + } + InterruptionKind::Cancel => { + info!("Canceling PauseOnceTask ", self.id); + Ok(ExecStatus::Canceled) + } + } + } else { + Ok(ExecStatus::Done(TaskOutput::Empty)) + } + } +} + +#[derive(Debug)] +pub struct BrokenTask { + id: TaskId, + began_tx: Option>, +} + +impl BrokenTask { + pub fn new() -> (Self, oneshot::Receiver<()>) { + let (tx, rx) = oneshot::channel(); + ( + Self { + id: TaskId::new_v4(), + began_tx: Some(tx), + }, + rx, + ) + } +} + +#[async_trait] +impl Task for BrokenTask { + fn id(&self) -> TaskId { + self.id + } + + async fn run(&mut self, _: &Interrupter) -> Result { + if let Some(began_tx) = self.began_tx.take() { + if began_tx.send(()).is_err() { + error!("Failed to send began signal"); + } + } + + pending().await + } +} diff --git a/crates/task-system/tests/integration_test.rs b/crates/task-system/tests/integration_test.rs new file mode 100644 index 000000000..bf3ce697b --- /dev/null +++ b/crates/task-system/tests/integration_test.rs @@ -0,0 +1,224 @@ +use sd_task_system::{TaskOutput, TaskStatus, TaskSystem}; + +use std::{collections::VecDeque, time::Duration}; + +use futures_concurrency::future::Join; +use rand::Rng; +use tempfile::tempdir; +use tracing::info; +use tracing_test::traced_test; + +mod common; + +use common::{ + actors::SampleActor, + tasks::{BogusTask, BrokenTask, NeverTask, PauseOnceTask, ReadyTask, SampleError}, +}; + +use crate::common::jobs::SampleJob; + +#[tokio::test] +#[traced_test] +async fn test_actor() { + let data_dir = tempdir().unwrap(); + + let system = TaskSystem::new(); + + let (actor, mut actor_idle_rx) = + SampleActor::new(data_dir.path(), "test".to_string(), system.get_dispatcher()).await; + + let mut rng = rand::thread_rng(); + + for i in 0..=250 { + if rng.gen_bool(0.1) { + info!("dispatching priority task {i}"); + actor + .process_with_priority(Duration::from_millis(rng.gen_range(50..150))) + .await; + } else { + info!("dispatching task {i}"); + actor + .process(Duration::from_millis(rng.gen_range(200..500))) + .await; + } + } + + info!("all tasks dispatched, now we wait a bit..."); + + actor_idle_rx.recv().await.unwrap(); + + system.shutdown().await; + + info!("done"); +} + +#[tokio::test] +#[traced_test] +async fn shutdown_test() { + let system = TaskSystem::new(); + + let handle = system.dispatch(NeverTask::default()).await; + + system.shutdown().await; + + assert!(matches!(handle.await, Ok(TaskStatus::Shutdown(_)))); +} + +#[tokio::test] +#[traced_test] +async fn cancel_test() { + let system = TaskSystem::new(); + + let handle = system.dispatch(NeverTask::default()).await; + + info!("issuing cancel"); + handle.cancel().await.unwrap(); + + assert!(matches!(handle.await, Ok(TaskStatus::Canceled))); + + system.shutdown().await; +} + +#[tokio::test] +#[traced_test] +async fn done_test() { + let system = TaskSystem::new(); + + let handle = system.dispatch(ReadyTask::default()).await; + + assert!(matches!( + handle.await, + Ok(TaskStatus::Done(TaskOutput::Empty)) + )); + + system.shutdown().await; +} + +#[tokio::test] +#[traced_test] +async fn abort_test() { + let system = TaskSystem::new(); + + let (task, began_rx) = BrokenTask::new(); + + let handle = system.dispatch(task).await; + + began_rx.await.unwrap(); + + handle.force_abortion().await.unwrap(); + + assert!(matches!(handle.await, Ok(TaskStatus::ForcedAbortion))); + + system.shutdown().await; +} + +#[tokio::test] +#[traced_test] +async fn error_test() { + let system = TaskSystem::new(); + + let handle = system.dispatch(BogusTask::default()).await; + + assert!(matches!( + handle.await, + Ok(TaskStatus::Error(SampleError::SampleError)) + )); + + system.shutdown().await; +} + +#[tokio::test] +#[traced_test] +async fn pause_test() { + let system = TaskSystem::new(); + + let (task, began_rx) = PauseOnceTask::new(); + + let handle = system.dispatch(task).await; + + info!("Task dispatched, now we wait for it to begin..."); + + began_rx.await.unwrap(); + + handle.pause().await.unwrap(); + + info!("Paused task, now we resume it..."); + + handle.resume().await.unwrap(); + + info!("Resumed task, now we wait for it to complete..."); + + assert!(matches!( + handle.await, + Ok(TaskStatus::Done(TaskOutput::Empty)) + )); + + system.shutdown().await; +} + +#[tokio::test] +#[traced_test] +async fn jobs_test() { + let system = TaskSystem::new(); + + let task_dispatcher = system.get_dispatcher(); + + let job = SampleJob::new(256, task_dispatcher.clone()); + + job.run().await.unwrap(); + + system.shutdown().await; +} + +#[tokio::test] +#[traced_test] +async fn steal_test() { + let system = TaskSystem::new(); + + let workers_count = system.workers_count(); + + let (pause_tasks, pause_begans) = (0..workers_count) + .map(|_| PauseOnceTask::new()) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + // With this, all workers will be busy + let mut pause_handles = VecDeque::from(system.dispatch_many(pause_tasks).await); + + let ready_handles = system + .dispatch_many((0..100).map(|_| ReadyTask::default()).collect()) + .await; + + pause_begans + .into_iter() + .map(|began_rx| async move { began_rx.await.unwrap() }) + .collect::>() + .join() + .await; + + let first_paused_handle = pause_handles.pop_front().unwrap(); + + info!("All tasks dispatched, will now release the first one, so the first worker can steal everything..."); + + first_paused_handle.pause().await.unwrap(); + + first_paused_handle.resume().await.unwrap(); + + first_paused_handle.await.unwrap(); + + ready_handles.join().await.into_iter().for_each(|res| { + res.unwrap(); + }); + + pause_handles + .into_iter() + .map(|handle| async move { + handle.pause().await.unwrap(); + handle.resume().await.unwrap(); + handle.await.unwrap(); + }) + .collect::>() + .join() + .await; + + system.shutdown().await; +} diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 2d941149c..53f18aa1a 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -9,6 +9,6 @@ edition = "2021" sd-prisma = { path = "../prisma" } prisma-client-rust = { workspace = true } -rspc = { workspace = true } +rspc = { workspace = true, features = ["unstable"] } thiserror = { workspace = true } uuid = { workspace = true } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 6d833ff50..292fe499e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "1.75" +channel = "stable"