diff --git a/Cargo.toml b/Cargo.toml index 8adfcac..2a547c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ uuid = { version = "1.8.0", features = ["v4"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } tower-http = { version = "0.5.2", features = ["trace"] } +num_cpus = "1.16.0" diff --git a/src/dto.rs b/src/dto.rs index 963c2bf..6e2cbd4 100644 --- a/src/dto.rs +++ b/src/dto.rs @@ -8,18 +8,19 @@ pub struct Error { } #[derive(Serialize, Deserialize)] -pub struct TaskResponse { +pub struct ConvertResponse { pub id: Option, pub error: Option, } #[derive(TryFromMultipart)] #[try_from_multipart(rename_all = "camelCase")] -pub struct Task { +pub struct ConvertRequest { pub codec: String, pub bit_rate: usize, pub max_bit_rate: usize, pub channel_layout: String, + pub upload_url: String, #[form_data(limit = "25MiB")] pub file: FieldData, diff --git a/src/main.rs b/src/main.rs index bd4e847..4156540 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,14 @@ -use crate::server::serve; use std::env; + use tracing_subscriber::EnvFilter; +use crate::server::Server; +use crate::thread_pool::ThreadPool; + mod dto; mod server; +mod task; +mod thread_pool; #[tokio::main] async fn main() { @@ -12,6 +17,27 @@ async fn main() { .init(); let addr = env::var("LISTEN").unwrap_or_else(|_| "0.0.0.0:8090".to_string()); - - serve(&addr).await.expect("Cannot bind the addr") + let pool = ThreadPool::new(match env::var("NUM_WORKERS") { + Ok(val) => match val.parse::() { + Ok(val) => { + if val > 0 { + Some(val); + } + None + } + Err(_) => None, + }, + Err(_) => None, + }); + let temp_dir = env::var("TEMP_DIR").unwrap_or_else(|_| { + env::temp_dir() + .to_str() + .expect("Cannot get system temp directory") + .parse() + .unwrap() + }); + Server::new(pool, temp_dir) + .serve(&addr) + .await + .expect("Cannot bind the addr") } diff --git a/src/server.rs b/src/server.rs index 168c529..d9049ab 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,67 +1,109 @@ -use axum::extract::DefaultBodyLimit; -use axum::http::StatusCode; -use axum::routing::{get, post}; -use axum::{Json, Router}; -use axum_typed_multipart::TypedMultipart; use std::path::Path; +use std::sync::Arc; + +use axum::{debug_handler, Json, Router}; +use axum::extract::{DefaultBodyLimit, State}; +use axum::http::StatusCode; +use axum::routing::post; +use axum_typed_multipart::TypedMultipart; use tokio::net::TcpListener; use tower_http::trace::TraceLayer; use uuid::Uuid; -use crate::dto; -use crate::dto::{Task, TaskResponse}; +use crate::dto::{ConvertRequest, ConvertResponse}; +use crate::task::Task; +use crate::thread_pool::ThreadPool; const CONTENT_LENGTH_LIMIT: usize = 30 * 1024 * 1024; -pub async fn serve(addr: &str) -> std::io::Result<()> { - let app = Router::new() - .route( - "/enqueue", - post(enqueue_file).layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)), - ) - .route("/download", get(download_file)) - .layer(TraceLayer::new_for_http()); +pub struct Server { + thread_pool: Arc, + work_dir: String, +} - tracing::info!("listening on {addr}"); - let listener = match TcpListener::bind(addr).await { - Ok(listen) => listen, - Err(err) => return Err(err), - }; - axum::serve(listener, app).await +impl Server { + pub(crate) fn new(thread_pool: ThreadPool, work_dir: String) -> Server { + Server { + thread_pool: Arc::new(thread_pool), + work_dir, + } + } + + pub async fn serve(self, addr: &str) -> std::io::Result<()> { + let this = Arc::new(self); + let app = Router::new() + .route( + "/enqueue", + post(enqueue_file) + .layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)), + ) + .with_state(this) + .layer(TraceLayer::new_for_http()); + + tracing::info!("listening on {addr}"); + let listener = match TcpListener::bind(addr).await { + Ok(listen) => listen, + Err(err) => return Err(err), + }; + axum::serve(listener, app).await + } } async fn enqueue_file( - TypedMultipart(Task { file, .. }): TypedMultipart, -) -> (StatusCode, Json) { - let task_id = Uuid::new_v4().to_string(); - let path = Path::new( - std::env::temp_dir() - .to_str() - .expect("Cannot get temporary directory"), - ) - .join(format!("{}.bin", task_id)); + State(server): State>, + TypedMultipart(req): TypedMultipart, +) -> (StatusCode, Json) { + let task_id = Uuid::new_v4(); + let input = + Path::new(&server.work_dir).join(format!("{}.in.atranscoder", task_id.to_string())); + let output = + Path::new(&server.work_dir).join(format!("{}.out.atranscoder", task_id.to_string())); - match file.contents.persist(path) { - Ok(_) => ( - StatusCode::CREATED, - Json::from(TaskResponse { - id: Option::from(task_id), - error: None, - }), - ), + let file = req.file; + + match file.contents.persist(input.clone()) { + Ok(_) => { + let input_path = input.to_str(); + let output_path = output.to_str(); + + if input_path.is_none() || output_path.is_none() { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json::from(ConvertResponse { + id: None, + error: Some(String::from("Input or output paths are not correct")), + }), + ); + } + + let task = Task::new( + task_id, + req.codec, + req.bit_rate, + req.max_bit_rate, + req.channel_layout, + req.upload_url, + input_path.unwrap().to_string(), + output_path.unwrap().to_string(), + ); + + // Enqueue the task to the thread pool + server.thread_pool.enqueue(task); + + ( + StatusCode::CREATED, + Json::from(ConvertResponse { + id: Some(task_id.to_string()), + error: None, + }), + ) + } Err(_) => ( - StatusCode::CREATED, - Json::from(TaskResponse { - id: Option::from(task_id), + StatusCode::INTERNAL_SERVER_ERROR, + Json::from(ConvertResponse { + id: Some(task_id.to_string()), error: Some(String::from("Cannot save the file")), }), ), } -} - -async fn download_file() -> (StatusCode, Json) { - let resp = dto::Error { - error: String::from("Not implemented yet."), - }; - (StatusCode::INTERNAL_SERVER_ERROR, Json(resp)) -} +} \ No newline at end of file diff --git a/src/task.rs b/src/task.rs new file mode 100644 index 0000000..0003ecd --- /dev/null +++ b/src/task.rs @@ -0,0 +1,40 @@ +use tracing::debug; + +pub struct Task { + id: uuid::Uuid, + codec: String, + bit_rate: usize, + max_bit_rate: usize, + channel_layout: String, + input_path: String, + output_path: String, + upload_url: String, +} + +impl Task { + pub fn new( + id: uuid::Uuid, + codec: String, + bit_rate: usize, + max_bit_rate: usize, + channel_layout: String, + upload_url: String, + input_path: String, + output_path: String, + ) -> Self { + Task { + id, + codec, + bit_rate, + max_bit_rate, + channel_layout, + input_path, + output_path, + upload_url, + } + } + + pub fn execute(&self) { + debug!("Executing task with id: {}", self.id.to_string()); + } +} diff --git a/src/thread_pool.rs b/src/thread_pool.rs new file mode 100644 index 0000000..a645304 --- /dev/null +++ b/src/thread_pool.rs @@ -0,0 +1,65 @@ +use std::sync::{Arc, Mutex}; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::thread; +use std::time::Duration; + +use tracing::{error, debug}; + +use crate::task::Task; + +pub struct ThreadPool { + workers: Vec, + sender: Sender, +} + +impl ThreadPool { + pub(crate) fn new(num_threads: Option) -> Self { + let num_threads = num_threads.unwrap_or_else(num_cpus::get); + let (sender, receiver) = mpsc::channel(); + let receiver = Arc::new(Mutex::new(receiver)); + + let workers = (0..num_threads) + .map(|id| Worker::new(id, Arc::clone(&receiver))) + .collect(); + + ThreadPool { workers, sender } + } + + pub fn enqueue(&self, task: Task) { + if let Err(e) = self.sender.send(task) { + error!("failed to send task to the queue: {:?}", e); + } + } +} + +struct Worker { + id: usize, + thread: Option>, +} + +impl Worker { + fn new(id: usize, receiver: Arc>>) -> Self { + let thread = thread::spawn(move || loop { + let task = { + let lock = receiver.lock().unwrap(); + lock.recv() + }; + + match task { + Ok(task) => { + debug!("worker {} got a task; executing.", id); + task.execute(); + } + Err(e) => { + error!("worker {} failed to receive task: {:?}", id, e); + thread::sleep(Duration::from_secs(1)); // sleep to avoid busy-looping + } + } + }); + + Worker { + id, + thread: Some(thread), + } + } +}