task executor
This commit is contained in:
parent
8bbb487c4b
commit
d2724677d9
@ -13,3 +13,4 @@ uuid = { version = "1.8.0", features = ["v4"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
|
||||
tower-http = { version = "0.5.2", features = ["trace"] }
|
||||
num_cpus = "1.16.0"
|
||||
|
@ -8,18 +8,19 @@ pub struct Error {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TaskResponse {
|
||||
pub struct ConvertResponse {
|
||||
pub id: Option<String>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(TryFromMultipart)]
|
||||
#[try_from_multipart(rename_all = "camelCase")]
|
||||
pub struct Task {
|
||||
pub struct ConvertRequest {
|
||||
pub codec: String,
|
||||
pub bit_rate: usize,
|
||||
pub max_bit_rate: usize,
|
||||
pub channel_layout: String,
|
||||
pub upload_url: String,
|
||||
|
||||
#[form_data(limit = "25MiB")]
|
||||
pub file: FieldData<NamedTempFile>,
|
||||
|
32
src/main.rs
32
src/main.rs
@ -1,9 +1,14 @@
|
||||
use crate::server::serve;
|
||||
use std::env;
|
||||
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use crate::server::Server;
|
||||
use crate::thread_pool::ThreadPool;
|
||||
|
||||
mod dto;
|
||||
mod server;
|
||||
mod task;
|
||||
mod thread_pool;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@ -12,6 +17,27 @@ async fn main() {
|
||||
.init();
|
||||
|
||||
let addr = env::var("LISTEN").unwrap_or_else(|_| "0.0.0.0:8090".to_string());
|
||||
|
||||
serve(&addr).await.expect("Cannot bind the addr")
|
||||
let pool = ThreadPool::new(match env::var("NUM_WORKERS") {
|
||||
Ok(val) => match val.parse::<usize>() {
|
||||
Ok(val) => {
|
||||
if val > 0 {
|
||||
Some(val);
|
||||
}
|
||||
None
|
||||
}
|
||||
Err(_) => None,
|
||||
},
|
||||
Err(_) => None,
|
||||
});
|
||||
let temp_dir = env::var("TEMP_DIR").unwrap_or_else(|_| {
|
||||
env::temp_dir()
|
||||
.to_str()
|
||||
.expect("Cannot get system temp directory")
|
||||
.parse()
|
||||
.unwrap()
|
||||
});
|
||||
Server::new(pool, temp_dir)
|
||||
.serve(&addr)
|
||||
.await
|
||||
.expect("Cannot bind the addr")
|
||||
}
|
||||
|
138
src/server.rs
138
src/server.rs
@ -1,67 +1,109 @@
|
||||
use axum::extract::DefaultBodyLimit;
|
||||
use axum::http::StatusCode;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{debug_handler, Json, Router};
|
||||
use axum::extract::{DefaultBodyLimit, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::routing::post;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::dto;
|
||||
use crate::dto::{Task, TaskResponse};
|
||||
use crate::dto::{ConvertRequest, ConvertResponse};
|
||||
use crate::task::Task;
|
||||
use crate::thread_pool::ThreadPool;
|
||||
|
||||
const CONTENT_LENGTH_LIMIT: usize = 30 * 1024 * 1024;
|
||||
|
||||
pub async fn serve(addr: &str) -> std::io::Result<()> {
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/enqueue",
|
||||
post(enqueue_file).layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)),
|
||||
)
|
||||
.route("/download", get(download_file))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
pub struct Server {
|
||||
thread_pool: Arc<ThreadPool>,
|
||||
work_dir: String,
|
||||
}
|
||||
|
||||
tracing::info!("listening on {addr}");
|
||||
let listener = match TcpListener::bind(addr).await {
|
||||
Ok(listen) => listen,
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
axum::serve(listener, app).await
|
||||
impl Server {
|
||||
pub(crate) fn new(thread_pool: ThreadPool, work_dir: String) -> Server {
|
||||
Server {
|
||||
thread_pool: Arc::new(thread_pool),
|
||||
work_dir,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve(self, addr: &str) -> std::io::Result<()> {
|
||||
let this = Arc::new(self);
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/enqueue",
|
||||
post(enqueue_file)
|
||||
.layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)),
|
||||
)
|
||||
.with_state(this)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
tracing::info!("listening on {addr}");
|
||||
let listener = match TcpListener::bind(addr).await {
|
||||
Ok(listen) => listen,
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
axum::serve(listener, app).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_file(
|
||||
TypedMultipart(Task { file, .. }): TypedMultipart<Task>,
|
||||
) -> (StatusCode, Json<TaskResponse>) {
|
||||
let task_id = Uuid::new_v4().to_string();
|
||||
let path = Path::new(
|
||||
std::env::temp_dir()
|
||||
.to_str()
|
||||
.expect("Cannot get temporary directory"),
|
||||
)
|
||||
.join(format!("{}.bin", task_id));
|
||||
State(server): State<Arc<Server>>,
|
||||
TypedMultipart(req): TypedMultipart<ConvertRequest>,
|
||||
) -> (StatusCode, Json<ConvertResponse>) {
|
||||
let task_id = Uuid::new_v4();
|
||||
let input =
|
||||
Path::new(&server.work_dir).join(format!("{}.in.atranscoder", task_id.to_string()));
|
||||
let output =
|
||||
Path::new(&server.work_dir).join(format!("{}.out.atranscoder", task_id.to_string()));
|
||||
|
||||
match file.contents.persist(path) {
|
||||
Ok(_) => (
|
||||
StatusCode::CREATED,
|
||||
Json::from(TaskResponse {
|
||||
id: Option::from(task_id),
|
||||
error: None,
|
||||
}),
|
||||
),
|
||||
let file = req.file;
|
||||
|
||||
match file.contents.persist(input.clone()) {
|
||||
Ok(_) => {
|
||||
let input_path = input.to_str();
|
||||
let output_path = output.to_str();
|
||||
|
||||
if input_path.is_none() || output_path.is_none() {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json::from(ConvertResponse {
|
||||
id: None,
|
||||
error: Some(String::from("Input or output paths are not correct")),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
let task = Task::new(
|
||||
task_id,
|
||||
req.codec,
|
||||
req.bit_rate,
|
||||
req.max_bit_rate,
|
||||
req.channel_layout,
|
||||
req.upload_url,
|
||||
input_path.unwrap().to_string(),
|
||||
output_path.unwrap().to_string(),
|
||||
);
|
||||
|
||||
// Enqueue the task to the thread pool
|
||||
server.thread_pool.enqueue(task);
|
||||
|
||||
(
|
||||
StatusCode::CREATED,
|
||||
Json::from(ConvertResponse {
|
||||
id: Some(task_id.to_string()),
|
||||
error: None,
|
||||
}),
|
||||
)
|
||||
}
|
||||
Err(_) => (
|
||||
StatusCode::CREATED,
|
||||
Json::from(TaskResponse {
|
||||
id: Option::from(task_id),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json::from(ConvertResponse {
|
||||
id: Some(task_id.to_string()),
|
||||
error: Some(String::from("Cannot save the file")),
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_file() -> (StatusCode, Json<dto::Error>) {
|
||||
let resp = dto::Error {
|
||||
error: String::from("Not implemented yet."),
|
||||
};
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(resp))
|
||||
}
|
||||
|
40
src/task.rs
Normal file
40
src/task.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use tracing::debug;
|
||||
|
||||
pub struct Task {
|
||||
id: uuid::Uuid,
|
||||
codec: String,
|
||||
bit_rate: usize,
|
||||
max_bit_rate: usize,
|
||||
channel_layout: String,
|
||||
input_path: String,
|
||||
output_path: String,
|
||||
upload_url: String,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
pub fn new(
|
||||
id: uuid::Uuid,
|
||||
codec: String,
|
||||
bit_rate: usize,
|
||||
max_bit_rate: usize,
|
||||
channel_layout: String,
|
||||
upload_url: String,
|
||||
input_path: String,
|
||||
output_path: String,
|
||||
) -> Self {
|
||||
Task {
|
||||
id,
|
||||
codec,
|
||||
bit_rate,
|
||||
max_bit_rate,
|
||||
channel_layout,
|
||||
input_path,
|
||||
output_path,
|
||||
upload_url,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute(&self) {
|
||||
debug!("Executing task with id: {}", self.id.to_string());
|
||||
}
|
||||
}
|
65
src/thread_pool.rs
Normal file
65
src/thread_pool.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::mpsc::{self, Receiver, Sender};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use tracing::{error, debug};
|
||||
|
||||
use crate::task::Task;
|
||||
|
||||
pub struct ThreadPool {
|
||||
workers: Vec<Worker>,
|
||||
sender: Sender<Task>,
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
pub(crate) fn new(num_threads: Option<usize>) -> Self {
|
||||
let num_threads = num_threads.unwrap_or_else(num_cpus::get);
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let receiver = Arc::new(Mutex::new(receiver));
|
||||
|
||||
let workers = (0..num_threads)
|
||||
.map(|id| Worker::new(id, Arc::clone(&receiver)))
|
||||
.collect();
|
||||
|
||||
ThreadPool { workers, sender }
|
||||
}
|
||||
|
||||
pub fn enqueue(&self, task: Task) {
|
||||
if let Err(e) = self.sender.send(task) {
|
||||
error!("failed to send task to the queue: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Worker {
|
||||
id: usize,
|
||||
thread: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
fn new(id: usize, receiver: Arc<Mutex<Receiver<Task>>>) -> Self {
|
||||
let thread = thread::spawn(move || loop {
|
||||
let task = {
|
||||
let lock = receiver.lock().unwrap();
|
||||
lock.recv()
|
||||
};
|
||||
|
||||
match task {
|
||||
Ok(task) => {
|
||||
debug!("worker {} got a task; executing.", id);
|
||||
task.execute();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("worker {} failed to receive task: {:?}", id, e);
|
||||
thread::sleep(Duration::from_secs(1)); // sleep to avoid busy-looping
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Worker {
|
||||
id,
|
||||
thread: Some(thread),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user