task executor
This commit is contained in:
parent
8bbb487c4b
commit
d2724677d9
@ -13,3 +13,4 @@ uuid = { version = "1.8.0", features = ["v4"] }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
|
||||||
tower-http = { version = "0.5.2", features = ["trace"] }
|
tower-http = { version = "0.5.2", features = ["trace"] }
|
||||||
|
num_cpus = "1.16.0"
|
||||||
|
@ -8,18 +8,19 @@ pub struct Error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub struct TaskResponse {
|
pub struct ConvertResponse {
|
||||||
pub id: Option<String>,
|
pub id: Option<String>,
|
||||||
pub error: Option<String>,
|
pub error: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(TryFromMultipart)]
|
#[derive(TryFromMultipart)]
|
||||||
#[try_from_multipart(rename_all = "camelCase")]
|
#[try_from_multipart(rename_all = "camelCase")]
|
||||||
pub struct Task {
|
pub struct ConvertRequest {
|
||||||
pub codec: String,
|
pub codec: String,
|
||||||
pub bit_rate: usize,
|
pub bit_rate: usize,
|
||||||
pub max_bit_rate: usize,
|
pub max_bit_rate: usize,
|
||||||
pub channel_layout: String,
|
pub channel_layout: String,
|
||||||
|
pub upload_url: String,
|
||||||
|
|
||||||
#[form_data(limit = "25MiB")]
|
#[form_data(limit = "25MiB")]
|
||||||
pub file: FieldData<NamedTempFile>,
|
pub file: FieldData<NamedTempFile>,
|
||||||
|
32
src/main.rs
32
src/main.rs
@ -1,9 +1,14 @@
|
|||||||
use crate::server::serve;
|
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
use crate::server::Server;
|
||||||
|
use crate::thread_pool::ThreadPool;
|
||||||
|
|
||||||
mod dto;
|
mod dto;
|
||||||
mod server;
|
mod server;
|
||||||
|
mod task;
|
||||||
|
mod thread_pool;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
@ -12,6 +17,27 @@ async fn main() {
|
|||||||
.init();
|
.init();
|
||||||
|
|
||||||
let addr = env::var("LISTEN").unwrap_or_else(|_| "0.0.0.0:8090".to_string());
|
let addr = env::var("LISTEN").unwrap_or_else(|_| "0.0.0.0:8090".to_string());
|
||||||
|
let pool = ThreadPool::new(match env::var("NUM_WORKERS") {
|
||||||
serve(&addr).await.expect("Cannot bind the addr")
|
Ok(val) => match val.parse::<usize>() {
|
||||||
|
Ok(val) => {
|
||||||
|
if val > 0 {
|
||||||
|
Some(val);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Err(_) => None,
|
||||||
|
},
|
||||||
|
Err(_) => None,
|
||||||
|
});
|
||||||
|
let temp_dir = env::var("TEMP_DIR").unwrap_or_else(|_| {
|
||||||
|
env::temp_dir()
|
||||||
|
.to_str()
|
||||||
|
.expect("Cannot get system temp directory")
|
||||||
|
.parse()
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
Server::new(pool, temp_dir)
|
||||||
|
.serve(&addr)
|
||||||
|
.await
|
||||||
|
.expect("Cannot bind the addr")
|
||||||
}
|
}
|
||||||
|
140
src/server.rs
140
src/server.rs
@ -1,67 +1,109 @@
|
|||||||
use axum::extract::DefaultBodyLimit;
|
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::routing::{get, post};
|
|
||||||
use axum::{Json, Router};
|
|
||||||
use axum_typed_multipart::TypedMultipart;
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::{debug_handler, Json, Router};
|
||||||
|
use axum::extract::{DefaultBodyLimit, State};
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::routing::post;
|
||||||
|
use axum_typed_multipart::TypedMultipart;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::dto;
|
use crate::dto::{ConvertRequest, ConvertResponse};
|
||||||
use crate::dto::{Task, TaskResponse};
|
use crate::task::Task;
|
||||||
|
use crate::thread_pool::ThreadPool;
|
||||||
|
|
||||||
const CONTENT_LENGTH_LIMIT: usize = 30 * 1024 * 1024;
|
const CONTENT_LENGTH_LIMIT: usize = 30 * 1024 * 1024;
|
||||||
|
|
||||||
pub async fn serve(addr: &str) -> std::io::Result<()> {
|
pub struct Server {
|
||||||
let app = Router::new()
|
thread_pool: Arc<ThreadPool>,
|
||||||
.route(
|
work_dir: String,
|
||||||
"/enqueue",
|
}
|
||||||
post(enqueue_file).layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)),
|
|
||||||
)
|
|
||||||
.route("/download", get(download_file))
|
|
||||||
.layer(TraceLayer::new_for_http());
|
|
||||||
|
|
||||||
tracing::info!("listening on {addr}");
|
impl Server {
|
||||||
let listener = match TcpListener::bind(addr).await {
|
pub(crate) fn new(thread_pool: ThreadPool, work_dir: String) -> Server {
|
||||||
Ok(listen) => listen,
|
Server {
|
||||||
Err(err) => return Err(err),
|
thread_pool: Arc::new(thread_pool),
|
||||||
};
|
work_dir,
|
||||||
axum::serve(listener, app).await
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn serve(self, addr: &str) -> std::io::Result<()> {
|
||||||
|
let this = Arc::new(self);
|
||||||
|
let app = Router::new()
|
||||||
|
.route(
|
||||||
|
"/enqueue",
|
||||||
|
post(enqueue_file)
|
||||||
|
.layer(DefaultBodyLimit::max(CONTENT_LENGTH_LIMIT)),
|
||||||
|
)
|
||||||
|
.with_state(this)
|
||||||
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|
||||||
|
tracing::info!("listening on {addr}");
|
||||||
|
let listener = match TcpListener::bind(addr).await {
|
||||||
|
Ok(listen) => listen,
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
};
|
||||||
|
axum::serve(listener, app).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn enqueue_file(
|
async fn enqueue_file(
|
||||||
TypedMultipart(Task { file, .. }): TypedMultipart<Task>,
|
State(server): State<Arc<Server>>,
|
||||||
) -> (StatusCode, Json<TaskResponse>) {
|
TypedMultipart(req): TypedMultipart<ConvertRequest>,
|
||||||
let task_id = Uuid::new_v4().to_string();
|
) -> (StatusCode, Json<ConvertResponse>) {
|
||||||
let path = Path::new(
|
let task_id = Uuid::new_v4();
|
||||||
std::env::temp_dir()
|
let input =
|
||||||
.to_str()
|
Path::new(&server.work_dir).join(format!("{}.in.atranscoder", task_id.to_string()));
|
||||||
.expect("Cannot get temporary directory"),
|
let output =
|
||||||
)
|
Path::new(&server.work_dir).join(format!("{}.out.atranscoder", task_id.to_string()));
|
||||||
.join(format!("{}.bin", task_id));
|
|
||||||
|
|
||||||
match file.contents.persist(path) {
|
let file = req.file;
|
||||||
Ok(_) => (
|
|
||||||
StatusCode::CREATED,
|
match file.contents.persist(input.clone()) {
|
||||||
Json::from(TaskResponse {
|
Ok(_) => {
|
||||||
id: Option::from(task_id),
|
let input_path = input.to_str();
|
||||||
error: None,
|
let output_path = output.to_str();
|
||||||
}),
|
|
||||||
),
|
if input_path.is_none() || output_path.is_none() {
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json::from(ConvertResponse {
|
||||||
|
id: None,
|
||||||
|
error: Some(String::from("Input or output paths are not correct")),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let task = Task::new(
|
||||||
|
task_id,
|
||||||
|
req.codec,
|
||||||
|
req.bit_rate,
|
||||||
|
req.max_bit_rate,
|
||||||
|
req.channel_layout,
|
||||||
|
req.upload_url,
|
||||||
|
input_path.unwrap().to_string(),
|
||||||
|
output_path.unwrap().to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Enqueue the task to the thread pool
|
||||||
|
server.thread_pool.enqueue(task);
|
||||||
|
|
||||||
|
(
|
||||||
|
StatusCode::CREATED,
|
||||||
|
Json::from(ConvertResponse {
|
||||||
|
id: Some(task_id.to_string()),
|
||||||
|
error: None,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
Err(_) => (
|
Err(_) => (
|
||||||
StatusCode::CREATED,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Json::from(TaskResponse {
|
Json::from(ConvertResponse {
|
||||||
id: Option::from(task_id),
|
id: Some(task_id.to_string()),
|
||||||
error: Some(String::from("Cannot save the file")),
|
error: Some(String::from("Cannot save the file")),
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file() -> (StatusCode, Json<dto::Error>) {
|
|
||||||
let resp = dto::Error {
|
|
||||||
error: String::from("Not implemented yet."),
|
|
||||||
};
|
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(resp))
|
|
||||||
}
|
|
40
src/task.rs
Normal file
40
src/task.rs
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
pub struct Task {
|
||||||
|
id: uuid::Uuid,
|
||||||
|
codec: String,
|
||||||
|
bit_rate: usize,
|
||||||
|
max_bit_rate: usize,
|
||||||
|
channel_layout: String,
|
||||||
|
input_path: String,
|
||||||
|
output_path: String,
|
||||||
|
upload_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Task {
|
||||||
|
pub fn new(
|
||||||
|
id: uuid::Uuid,
|
||||||
|
codec: String,
|
||||||
|
bit_rate: usize,
|
||||||
|
max_bit_rate: usize,
|
||||||
|
channel_layout: String,
|
||||||
|
upload_url: String,
|
||||||
|
input_path: String,
|
||||||
|
output_path: String,
|
||||||
|
) -> Self {
|
||||||
|
Task {
|
||||||
|
id,
|
||||||
|
codec,
|
||||||
|
bit_rate,
|
||||||
|
max_bit_rate,
|
||||||
|
channel_layout,
|
||||||
|
input_path,
|
||||||
|
output_path,
|
||||||
|
upload_url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute(&self) {
|
||||||
|
debug!("Executing task with id: {}", self.id.to_string());
|
||||||
|
}
|
||||||
|
}
|
65
src/thread_pool.rs
Normal file
65
src/thread_pool.rs
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::sync::mpsc::{self, Receiver, Sender};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use tracing::{error, debug};
|
||||||
|
|
||||||
|
use crate::task::Task;
|
||||||
|
|
||||||
|
pub struct ThreadPool {
|
||||||
|
workers: Vec<Worker>,
|
||||||
|
sender: Sender<Task>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadPool {
|
||||||
|
pub(crate) fn new(num_threads: Option<usize>) -> Self {
|
||||||
|
let num_threads = num_threads.unwrap_or_else(num_cpus::get);
|
||||||
|
let (sender, receiver) = mpsc::channel();
|
||||||
|
let receiver = Arc::new(Mutex::new(receiver));
|
||||||
|
|
||||||
|
let workers = (0..num_threads)
|
||||||
|
.map(|id| Worker::new(id, Arc::clone(&receiver)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
ThreadPool { workers, sender }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enqueue(&self, task: Task) {
|
||||||
|
if let Err(e) = self.sender.send(task) {
|
||||||
|
error!("failed to send task to the queue: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Worker {
|
||||||
|
id: usize,
|
||||||
|
thread: Option<thread::JoinHandle<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Worker {
|
||||||
|
fn new(id: usize, receiver: Arc<Mutex<Receiver<Task>>>) -> Self {
|
||||||
|
let thread = thread::spawn(move || loop {
|
||||||
|
let task = {
|
||||||
|
let lock = receiver.lock().unwrap();
|
||||||
|
lock.recv()
|
||||||
|
};
|
||||||
|
|
||||||
|
match task {
|
||||||
|
Ok(task) => {
|
||||||
|
debug!("worker {} got a task; executing.", id);
|
||||||
|
task.execute();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("worker {} failed to receive task: {:?}", id, e);
|
||||||
|
thread::sleep(Duration::from_secs(1)); // sleep to avoid busy-looping
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Worker {
|
||||||
|
id,
|
||||||
|
thread: Some(thread),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user