diff --git a/yummy-rs/yummy-delta-py/src/lib.rs b/yummy-rs/yummy-delta-py/src/lib.rs index 38754f8..f911364 100644 --- a/yummy-rs/yummy-delta-py/src/lib.rs +++ b/yummy-rs/yummy-delta-py/src/lib.rs @@ -50,12 +50,24 @@ fn run_apply(config_path: String) -> PyResult { } #[pyfunction] -fn run(config_path: String, host: String, port: u16, log_level: String) -> PyResult { +fn run( + config_path: String, + host: String, + port: u16, + log_level: String, + workers: usize, +) -> PyResult { tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap() - .block_on(run_delta_server(config_path, host, port, log_level)) + .block_on(run_delta_server( + config_path, + host, + port, + Some(&log_level), + Some(&workers), + )) .unwrap(); Ok("Ok".to_string()) } diff --git a/yummy-rs/yummy-delta/Cargo.toml b/yummy-rs/yummy-delta/Cargo.toml index 4ac8772..ee069b8 100644 --- a/yummy-rs/yummy-delta/Cargo.toml +++ b/yummy-rs/yummy-delta/Cargo.toml @@ -36,6 +36,8 @@ anyhow = "1.0.69" url = "2.3" prettytable-rs = "^0.10" once_cell = "1.17.1" +num-traits = "0.2.15" + # Datafusion #datafusion = { version = "14" } #datafusion-expr = { version = "14" } diff --git a/yummy-rs/yummy-delta/src/lib.rs b/yummy-rs/yummy-delta/src/lib.rs index dbb81a3..9258d2b 100644 --- a/yummy-rs/yummy-delta/src/lib.rs +++ b/yummy-rs/yummy-delta/src/lib.rs @@ -11,6 +11,7 @@ use crate::output::PrettyOutput; use actix_web::middleware::Logger; use actix_web::{web, App, HttpServer}; use apply::DeltaApply; +use num_traits::Zero; use server::{ append, create_table, details, health, list_stores, list_tables, optimize, overwrite, query_stream, vacuum, @@ -73,9 +74,13 @@ pub async fn run_delta_server( config_path: String, host: String, port: u16, - log_level: String, + log_level: Option<&String>, + workers: Option<&usize>, ) -> std::io::Result<()> { - env_logger::init_from_env(env_logger::Env::new().default_filter_or(log_level)); + if let Some(v) = log_level { + env_logger::init_from_env(env_logger::Env::new().default_filter_or(v)); + } + println!("Yummy delta server running on http://{host}:{port}"); let delta_manager = DeltaApply::new(&config_path.clone()) .await @@ -83,7 +88,7 @@ pub async fn run_delta_server( .delta_manager() .unwrap(); - let _ = HttpServer::new(move || { + let mut server = HttpServer::new(move || { App::new() .app_data(web::Data::new(delta_manager.clone())) .route("/health", web::get().to(health)) @@ -107,10 +112,13 @@ pub async fn run_delta_server( ), ) .wrap(Logger::default()) - }) - .bind((host, port))? - .run() - .await; + }); - Ok(()) + if let Some(num_workers) = workers { + if !num_workers.is_zero() { + server = server.workers(*num_workers); + } + } + + server.bind((host, port))?.run().await } diff --git a/yummy-rs/yummy-llm/src/lib.rs b/yummy-rs/yummy-llm/src/lib.rs index db27ef9..9029529 100644 --- a/yummy-rs/yummy-llm/src/lib.rs +++ b/yummy-rs/yummy-llm/src/lib.rs @@ -2,12 +2,13 @@ pub mod common; pub mod config; pub mod models; pub mod server; - use actix_web::middleware::Logger; use actix_web::{web, App, HttpServer}; -use config::LLMConfig; +use config::{LLMConfig, LLMSpec, ModelType}; use models::LLMModelFactory; +use num_traits::Zero; use server::{embeddings, health}; +use yummy_core::config::Metadata; use crate::config::LLMEndpoint; @@ -20,6 +21,52 @@ pub async fn serve_llm( ) -> std::io::Result<()> { let config = LLMConfig::new(&model_path).await.unwrap(); + serve(config, host, port, log_level, workers).await +} + +pub async fn serve_embeddings( + model: String, + normalize_embeddings: bool, + host: String, + port: u16, + log_level: Option<&String>, + workers: Option<&usize>, +) -> std::io::Result<()> { + let config = LLMConfig { + endpoints: vec![LLMEndpoint::Embeddings { + metadata: Metadata { + name: "embedding".to_string(), + store: None, + table: None, + }, + spec: LLMSpec { + model: None, + tokenizer: None, + config: None, + hf_model_repo_id: None, + hf_tokenizer_repo_id: None, + hf_config_repo_id: None, + normalize_embeddings: Some(normalize_embeddings), + model_type: if model.to_uppercase() == "E5" { + ModelType::E5 + } else { + ModelType::JinaBert + }, + use_pth: None, + }, + }], + }; + + serve(config, host, port, log_level, workers).await +} + +async fn serve( + config: LLMConfig, + host: String, + port: u16, + log_level: Option<&String>, + workers: Option<&usize>, +) -> std::io::Result<()> { if let Some(v) = log_level { env_logger::init_from_env(env_logger::Env::new().default_filter_or(v)); } @@ -57,7 +104,9 @@ pub async fn serve_llm( }); if let Some(num_workers) = workers { - server = server.workers(*num_workers); + if !num_workers.is_zero() { + server = server.workers(*num_workers); + } } server.bind((host, port))?.run().await diff --git a/yummy-rs/yummy-ml-py/src/lib.rs b/yummy-rs/yummy-ml-py/src/lib.rs index da334b1..9227808 100644 --- a/yummy-rs/yummy-ml-py/src/lib.rs +++ b/yummy-rs/yummy-ml-py/src/lib.rs @@ -2,12 +2,24 @@ use pyo3::prelude::*; use yummy_ml::serve_ml_model; #[pyfunction] -fn serve(model_path: String, host: String, port: u16, log_level: String) -> PyResult { +fn serve( + model_path: String, + host: String, + port: u16, + log_level: String, + workers: usize, +) -> PyResult { tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap() - .block_on(serve_ml_model(model_path, host, port, log_level)) + .block_on(serve_ml_model( + model_path, + host, + port, + Some(&log_level), + Some(&workers), + )) .unwrap(); Ok("Ok".to_string()) } diff --git a/yummy-rs/yummy-ml/Cargo.toml b/yummy-rs/yummy-ml/Cargo.toml index 7511d00..374fce2 100644 --- a/yummy-rs/yummy-ml/Cargo.toml +++ b/yummy-rs/yummy-ml/Cargo.toml @@ -19,6 +19,7 @@ derive_more = "0.99.17" #catboost = { git = "https://github.com/catboost/catboost", optional = true } lightgbm = "0.2.3" thiserror = "1.0" +num-traits = "0.2.15" #reqwest = "0.11.13" [features] diff --git a/yummy-rs/yummy-ml/src/lib.rs b/yummy-rs/yummy-ml/src/lib.rs index 8125708..6cf064a 100644 --- a/yummy-rs/yummy-ml/src/lib.rs +++ b/yummy-rs/yummy-ml/src/lib.rs @@ -7,27 +7,36 @@ use actix_web::middleware::Logger; use actix_web::{web, App, HttpServer}; use config::MLConfig; use models::MLModelFactory; +use num_traits::Zero; use server::{health, invocations}; pub async fn serve_ml_model( model_path: String, host: String, port: u16, - log_level: String, + log_level: Option<&String>, + workers: Option<&usize>, ) -> std::io::Result<()> { let config = MLConfig::new(&model_path).await.unwrap(); + if let Some(v) = log_level { + env_logger::init_from_env(env_logger::Env::new().default_filter_or(v)); + } - env_logger::init_from_env(env_logger::Env::new().default_filter_or(log_level)); println!("Yummy ml server running on http://{host}:{port}"); - HttpServer::new(move || { + let mut server = HttpServer::new(move || { App::new() .app_data(web::Data::new(MLModelFactory::new(config.clone()).unwrap())) .app_data(web::Data::new(config.clone())) .wrap(Logger::default()) .route("/health", web::get().to(health)) .route("/invocations", web::post().to(invocations)) - }) - .bind((host, port))? - .run() - .await + }); + + if let Some(num_workers) = workers { + if !num_workers.is_zero() { + server = server.workers(*num_workers); + } + } + + server.bind((host, port))?.run().await } diff --git a/yummy-rs/yummy/Makefile b/yummy-rs/yummy/Makefile index 02ce9e6..d1c1438 100644 --- a/yummy-rs/yummy/Makefile +++ b/yummy-rs/yummy/Makefile @@ -36,6 +36,9 @@ run-llm-serve-embeddings-e5: run-llm-serve-embeddings-jinabert: cargo run --release --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config_embedding_jinabert.yaml --workers 1 +run-llm-serve-embeddings-jinabert1: + cargo run --release --features yummy-llm llm serve-embeddings --model jinabert --workers 1 + run-delta-serve: cargo run --features yummy-delta delta serve --config /home/jovyan/yummy/yummy-rs/tests/delta/apply.yaml --port 8080 --host 0.0.0.0 --loglevel debug diff --git a/yummy-rs/yummy/src/main.rs b/yummy-rs/yummy/src/main.rs index 6de8d58..6b9b3bf 100644 --- a/yummy-rs/yummy/src/main.rs +++ b/yummy-rs/yummy/src/main.rs @@ -13,17 +13,36 @@ fn cli() -> Command { .subcommand( Command::new("apply") .about("yummy delta apply") - .args(vec![arg!(-f --filename "Apply config file")]) + .args(vec![arg!(-f --filename ) + .required(true) + .help("apply config file")]) .arg_required_else_help(true), ) .subcommand( Command::new("serve") .about("yummy delta serve") .args(vec![ - arg!(--config "config file"), - arg!(--host "host"), - arg!(--port "port"), - arg!(--loglevel "log level"), + arg!(--config ) + .required(true) + .help("config file path"), + arg!(--host "host") + .required(false) + .help("host") + .default_value("0.0.0.0"), + arg!(--port "port") + .required(false) + .help("port") + .default_value("8080") + .value_parser(clap::value_parser!(u16)), + arg!(--loglevel ) + .required(false) + .help("log level") + .default_value("error"), + arg!(--workers ) + .required(false) + .help("number of workers") + .default_value("0") + .value_parser(clap::value_parser!(usize)), ]) .arg_required_else_help(true), ), @@ -36,10 +55,25 @@ fn cli() -> Command { Command::new("serve") .about("yummy ml serve") .args(vec![ - arg!(--model "model path"), - arg!(--host "host"), - arg!(--port "port"), - arg!(--loglevel "log level"), + arg!(--model ).required(true).help("model path"), + arg!(--host "host") + .required(false) + .help("host") + .default_value("0.0.0.0"), + arg!(--port "port") + .required(false) + .help("port") + .default_value("8080") + .value_parser(clap::value_parser!(u16)), + arg!(--loglevel ) + .required(false) + .help("log level") + .default_value("error"), + arg!(--workers ) + .required(false) + .help("number of workers") + .default_value("0") + .value_parser(clap::value_parser!(usize)), ]) .arg_required_else_help(true), ), @@ -71,7 +105,40 @@ fn cli() -> Command { arg!(--workers ) .required(false) .help("number of workers") - .default_value("2") + .default_value("0") + .value_parser(clap::value_parser!(usize)), + ]) + .arg_required_else_help(true), + ) + .subcommand( + Command::new("serve-embeddings") + .about("yummy llm serve embeddings") + .args(vec![ + arg!(--model ) + .required(true) + .help("model type: E5, JinaBert"), + arg!(--normalize ) + .required(false) + .help("normalize embeddings") + .default_value("true") + .value_parser(clap::value_parser!(bool)), + arg!(--host "host") + .required(false) + .help("host") + .default_value("0.0.0.0"), + arg!(--port "port") + .required(false) + .help("port") + .default_value("8080") + .value_parser(clap::value_parser!(u16)), + arg!(--loglevel ) + .required(false) + .help("log level") + .default_value("error"), + arg!(--workers ) + .required(false) + .help("number of workers") + .default_value("0") .value_parser(clap::value_parser!(usize)), ]) .arg_required_else_help(true), @@ -96,16 +163,12 @@ async fn main() -> Result<(), Box> { .get_one::("config") .expect("required"); let host = sub_sub_matches.get_one::("host").expect("required"); - let port = sub_sub_matches - .get_one::("port") - .expect("required") - .parse::()?; + let port = sub_sub_matches.get_one::("port").expect("required"); - let log_level = sub_sub_matches - .get_one::("loglevel") - .expect("required"); + let log_level = sub_sub_matches.get_one::("loglevel"); + let workers = sub_sub_matches.get_one::("workers"); - delta_serve(config.clone(), host.clone(), port, log_level.clone()).await? + delta_serve(config.clone(), host.clone(), *port, log_level, workers).await? } _ => unreachable!(), }, @@ -115,16 +178,12 @@ async fn main() -> Result<(), Box> { .get_one::("model") .expect("required"); let host = sub_sub_matches.get_one::("host").expect("required"); - let port = sub_sub_matches - .get_one::("port") - .expect("required") - .parse::()?; + let port = sub_sub_matches.get_one::("port").expect("required"); - let log_level = sub_sub_matches - .get_one::("loglevel") - .expect("required"); + let log_level = sub_sub_matches.get_one::("loglevel"); + let workers = sub_sub_matches.get_one::("workers"); - ml_serve(model.clone(), host.clone(), port, log_level.clone()).await? + ml_serve(model.clone(), host.clone(), *port, log_level, workers).await? } _ => unreachable!(), }, @@ -141,6 +200,30 @@ async fn main() -> Result<(), Box> { llm_serve(config.clone(), host.clone(), *port, log_level, workers).await? } + Some(("serve-embeddings", sub_sub_matches)) => { + let model = sub_sub_matches + .get_one::("model") + .expect("required"); + let normalize = sub_sub_matches + .get_one::("normalize") + .expect("required"); + let host = sub_sub_matches.get_one::("host").expect("required"); + let port = sub_sub_matches.get_one::("port").expect("required"); + + let log_level = sub_sub_matches.get_one::("loglevel"); + let workers = sub_sub_matches.get_one::("workers"); + + llm_serve_embeddings( + model.clone(), + *normalize, + host.clone(), + *port, + log_level, + workers, + ) + .await? + } + _ => unreachable!(), }, @@ -165,9 +248,10 @@ async fn delta_serve( config: String, host: String, port: u16, - log_level: String, + log_level: Option<&String>, + workers: Option<&usize>, ) -> std::io::Result<()> { - yummy_delta::run_delta_server(config, host, port, log_level).await + yummy_delta::run_delta_server(config, host, port, log_level, workers).await } #[cfg(not(feature = "yummy-delta"))] @@ -175,7 +259,8 @@ async fn delta_serve( _config: String, _host: String, _port: u16, - _log_level: String, + _log_level: Option<&String>, + _workers: Option<&usize>, ) -> std::io::Result<()> { unreachable!() } @@ -185,9 +270,10 @@ async fn ml_serve( model: String, host: String, port: u16, - log_level: String, + log_level: Option<&String>, + workers: Option<&usize>, ) -> std::io::Result<()> { - yummy_ml::serve_ml_model(model, host, port, log_level).await + yummy_ml::serve_ml_model(model, host, port, log_level, workers).await } #[cfg(not(feature = "yummy-ml"))] @@ -195,7 +281,8 @@ async fn ml_serve( _model: String, _host: String, _port: u16, - _log_level: String, + _log_level: Option<&String>, + _workers: Option<&usize>, ) -> std::io::Result<()> { unreachable!() } @@ -221,3 +308,27 @@ async fn llm_serve( ) -> std::io::Result<()> { unreachable!() } + +#[cfg(feature = "yummy-llm")] +async fn llm_serve_embeddings( + model: String, + normalize: bool, + host: String, + port: u16, + log_level: Option<&String>, + workers: Option<&usize>, +) -> std::io::Result<()> { + yummy_llm::serve_embeddings(model, normalize, host, port, log_level, workers).await +} + +#[cfg(not(feature = "yummy-llm"))] +async fn llm_serve_embeddings( + _model: String, + _normalize: bool, + _host: String, + _port: u16, + _log_level: Option<&String>, + _workers: Option<&usize>, +) -> std::io::Result<()> { + unreachable!() +}