task executor

This commit is contained in:
Pavel 2024-05-25 19:22:08 +03:00
parent 8bbb487c4b
commit d2724677d9
6 changed files with 229 additions and 54 deletions

View File

@ -13,3 +13,4 @@ uuid = { version = "1.8.0", features = ["v4"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
tower-http = { version = "0.5.2", features = ["trace"] }
num_cpus = "1.16.0"

View File

@ -8,18 +8,19 @@ pub struct Error {
}
#[derive(Serialize, Deserialize)]
pub struct TaskResponse {
pub struct ConvertResponse {
pub id: Option<String>,
pub error: Option<String>,
}
#[derive(TryFromMultipart)]
#[try_from_multipart(rename_all = "camelCase")]
pub struct Task {
pub struct ConvertRequest {
pub codec: String,
pub bit_rate: usize,
pub max_bit_rate: usize,
pub channel_layout: String,
pub upload_url: String,
#[form_data(limit = "25MiB")]
pub file: FieldData<NamedTempFile>,

View File

@ -1,9 +1,14 @@
use crate::server::serve;
use std::env;
use tracing_subscriber::EnvFilter;
use crate::server::Server;
use crate::thread_pool::ThreadPool;
mod dto;
mod server;
mod task;
mod thread_pool;
#[tokio::main]
async fn main() {
@ -12,6 +17,27 @@ async fn main() {
.init();
let addr = env::var("LISTEN").unwrap_or_else(|_| "0.0.0.0:8090".to_string());
serve(&addr).await.expect("Cannot bind the addr")
let pool = ThreadPool::new(match env::var("NUM_WORKERS") {
Ok(val) => match val.parse::<usize>() {
Ok(val) => {
if val > 0 {
Some(val);
}
None
}
Err(_) => None,
},
Err(_) => None,
});
let temp_dir = env::var("TEMP_DIR").unwrap_or_else(|_| {
env::temp_dir()
.to_str()
.expect("Cannot get system temp directory")
.parse()
.unwrap()
});
Server::new(pool, temp_dir)
.serve(&addr)
.await
.expect("Cannot bind the addr")
}

View File

@ -1,67 +1,109 @@
use axum::extract::DefaultBodyLimit;
use axum::http::StatusCode;
use axum::routing::{get, post};
use axum::{Json, Router};
use axum_typed_multipart::TypedMultipart;
use std::path::Path;
use std::sync::Arc;
use axum::{debug_handler, Json, Router};
use axum::extract::{DefaultBodyLimit, State};
use axum::http::StatusCode;
use axum::routing::post;
use axum_typed_multipart::TypedMultipart;
use tokio::net::TcpListener;
use tower_http::trace::TraceLayer;
use uuid::Uuid;
use crate::dto;
use crate::dto::{Task, TaskResponse};
use crate::dto::{ConvertRequest, ConvertResponse};
use crate::task::Task;
use crate::thread_pool::ThreadPool;
const CONTENT_LENGTH_LIMIT: usize = 30 * 1024 * 1024;
pub async fn serve(addr: &str) -> std::io::Result<()> {
let app = Router::new()
.route(
"/enqueue",
post(enqueue_file).layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)),
)
.route("/download", get(download_file))
.layer(TraceLayer::new_for_http());
pub struct Server {
thread_pool: Arc<ThreadPool>,
work_dir: String,
}
tracing::info!("listening on {addr}");
let listener = match TcpListener::bind(addr).await {
Ok(listen) => listen,
Err(err) => return Err(err),
};
axum::serve(listener, app).await
impl Server {
pub(crate) fn new(thread_pool: ThreadPool, work_dir: String) -> Server {
Server {
thread_pool: Arc::new(thread_pool),
work_dir,
}
}
pub async fn serve(self, addr: &str) -> std::io::Result<()> {
let this = Arc::new(self);
let app = Router::new()
.route(
"/enqueue",
post(enqueue_file)
.layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)),
)
.with_state(this)
.layer(TraceLayer::new_for_http());
tracing::info!("listening on {addr}");
let listener = match TcpListener::bind(addr).await {
Ok(listen) => listen,
Err(err) => return Err(err),
};
axum::serve(listener, app).await
}
}
async fn enqueue_file(
TypedMultipart(Task { file, .. }): TypedMultipart<Task>,
) -> (StatusCode, Json<TaskResponse>) {
let task_id = Uuid::new_v4().to_string();
let path = Path::new(
std::env::temp_dir()
.to_str()
.expect("Cannot get temporary directory"),
)
.join(format!("{}.bin", task_id));
State(server): State<Arc<Server>>,
TypedMultipart(req): TypedMultipart<ConvertRequest>,
) -> (StatusCode, Json<ConvertResponse>) {
let task_id = Uuid::new_v4();
let input =
Path::new(&server.work_dir).join(format!("{}.in.atranscoder", task_id.to_string()));
let output =
Path::new(&server.work_dir).join(format!("{}.out.atranscoder", task_id.to_string()));
match file.contents.persist(path) {
Ok(_) => (
StatusCode::CREATED,
Json::from(TaskResponse {
id: Option::from(task_id),
error: None,
}),
),
let file = req.file;
match file.contents.persist(input.clone()) {
Ok(_) => {
let input_path = input.to_str();
let output_path = output.to_str();
if input_path.is_none() || output_path.is_none() {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(ConvertResponse {
id: None,
error: Some(String::from("Input or output paths are not correct")),
}),
);
}
let task = Task::new(
task_id,
req.codec,
req.bit_rate,
req.max_bit_rate,
req.channel_layout,
req.upload_url,
input_path.unwrap().to_string(),
output_path.unwrap().to_string(),
);
// Enqueue the task to the thread pool
server.thread_pool.enqueue(task);
(
StatusCode::CREATED,
Json::from(ConvertResponse {
id: Some(task_id.to_string()),
error: None,
}),
)
}
Err(_) => (
StatusCode::CREATED,
Json::from(TaskResponse {
id: Option::from(task_id),
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(ConvertResponse {
id: Some(task_id.to_string()),
error: Some(String::from("Cannot save the file")),
}),
),
}
}
async fn download_file() -> (StatusCode, Json<dto::Error>) {
let resp = dto::Error {
error: String::from("Not implemented yet."),
};
(StatusCode::INTERNAL_SERVER_ERROR, Json(resp))
}
}

40
src/task.rs Normal file
View File

@ -0,0 +1,40 @@
use tracing::debug;
pub struct Task {
id: uuid::Uuid,
codec: String,
bit_rate: usize,
max_bit_rate: usize,
channel_layout: String,
input_path: String,
output_path: String,
upload_url: String,
}
impl Task {
pub fn new(
id: uuid::Uuid,
codec: String,
bit_rate: usize,
max_bit_rate: usize,
channel_layout: String,
upload_url: String,
input_path: String,
output_path: String,
) -> Self {
Task {
id,
codec,
bit_rate,
max_bit_rate,
channel_layout,
input_path,
output_path,
upload_url,
}
}
pub fn execute(&self) {
debug!("Executing task with id: {}", self.id.to_string());
}
}

65
src/thread_pool.rs Normal file
View File

@ -0,0 +1,65 @@
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use std::time::Duration;
use tracing::{error, debug};
use crate::task::Task;
pub struct ThreadPool {
workers: Vec<Worker>,
sender: Sender<Task>,
}
impl ThreadPool {
pub(crate) fn new(num_threads: Option<usize>) -> Self {
let num_threads = num_threads.unwrap_or_else(num_cpus::get);
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver));
let workers = (0..num_threads)
.map(|id| Worker::new(id, Arc::clone(&receiver)))
.collect();
ThreadPool { workers, sender }
}
pub fn enqueue(&self, task: Task) {
if let Err(e) = self.sender.send(task) {
error!("failed to send task to the queue: {:?}", e);
}
}
}
struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
fn new(id: usize, receiver: Arc<Mutex<Receiver<Task>>>) -> Self {
let thread = thread::spawn(move || loop {
let task = {
let lock = receiver.lock().unwrap();
lock.recv()
};
match task {
Ok(task) => {
debug!("worker {} got a task; executing.", id);
task.execute();
}
Err(e) => {
error!("worker {} failed to receive task: {:?}", id, e);
thread::sleep(Duration::from_secs(1)); // sleep to avoid busy-looping
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}