diff --git a/yummy-rs/tests/llm/config.yaml b/yummy-rs/tests/llm/config.yaml deleted file mode 100644 index c3af3be..0000000 --- a/yummy-rs/tests/llm/config.yaml +++ /dev/null @@ -1,7 +0,0 @@ -kind: embeddings -metadata: - name: jina-embeddings -spec: - normalize_embeddings: true - model_type: E5 - diff --git a/yummy-rs/yummy-llm/src/lib.rs b/yummy-rs/yummy-llm/src/lib.rs index c10befa..db27ef9 100644 --- a/yummy-rs/yummy-llm/src/lib.rs +++ b/yummy-rs/yummy-llm/src/lib.rs @@ -15,13 +15,17 @@ pub async fn serve_llm( model_path: String, host: String, port: u16, - log_level: String, + log_level: Option<&String>, + workers: Option<&usize>, ) -> std::io::Result<()> { let config = LLMConfig::new(&model_path).await.unwrap(); - env_logger::init_from_env(env_logger::Env::new().default_filter_or(log_level)); + if let Some(v) = log_level { + env_logger::init_from_env(env_logger::Env::new().default_filter_or(v)); + } + println!("Yummy llm server running on http://{host}:{port}"); - HttpServer::new(move || { + let mut server = HttpServer::new(move || { let mut app = App::new() .app_data(web::Data::new(config.clone())) .wrap(Logger::default()) @@ -50,8 +54,11 @@ pub async fn serve_llm( } app - }) - .bind((host, port))? - .run() - .await + }); + + if let Some(num_workers) = workers { + server = server.workers(*num_workers); + } + + server.bind((host, port))?.run().await } diff --git a/yummy-rs/yummy-llm/src/models/mod.rs b/yummy-rs/yummy-llm/src/models/mod.rs index 1e3bb43..bc6a72e 100644 --- a/yummy-rs/yummy-llm/src/models/mod.rs +++ b/yummy-rs/yummy-llm/src/models/mod.rs @@ -1,9 +1,10 @@ pub mod e5_model; -use std::path::PathBuf; - +pub mod jinabert_model; use crate::config::{LLMEndpoint, LLMSpec, ModelType}; use e5_model::E5Model; +use jinabert_model::JinaBertModel; use serde::{Deserialize, Serialize}; +use std::path::PathBuf; use yummy_core::common::Result; #[derive(thiserror::Error, Debug)] @@ -31,10 +32,9 @@ pub struct LLMModelFactory {} impl LLMModelFactory { pub fn embedding_model(config: &LLMEndpoint) -> Result> { if let LLMEndpoint::Embeddings { metadata: _, spec } = config { - if let ModelType::E5 = spec.model_type { - Ok(Box::new(E5Model::new(spec)?)) - } else { - todo!() + match spec.model_type { + ModelType::E5 => Ok(Box::new(E5Model::new(spec)?)), + ModelType::JinaBert => Ok(Box::new(JinaBertModel::new(spec)?)), } } else { Err(Box::new(ModelFactoryError::WrongModelType)) diff --git a/yummy-rs/yummy-llm/tests/integration_models_lightgbm.rs b/yummy-rs/yummy-llm/tests/integration_models_lightgbm.rs deleted file mode 100644 index be445e3..0000000 --- a/yummy-rs/yummy-llm/tests/integration_models_lightgbm.rs +++ /dev/null @@ -1,37 +0,0 @@ -use yummy_core::common::Result; -use yummy_llm::config::{LLMConfig, LLMEndpoint}; -use yummy_llm::models::LLMModelFactory; - -#[tokio::test] -async fn load_model_and_predict() -> Result<()> { - let path = "../tests/llm/config.yaml".to_string(); - //let path = "../tests/mlflow/catboost_model/iris_my_model".to_string(); - let config = LLMConfig::new(&path).await?; - println!("{config:?}"); - - let input = vec![ - String::from("Elixir of Eternal Twilight: Grants visions of realms beyond the veil."), - String::from("Potion of Liquid Starlight: Imbues the drinker with celestial clarity."), - ]; - - let embeddings_config = config - .endpoints - .iter() - .filter(|x| { - matches!( - x, - LLMEndpoint::Embeddings { - metadata: _m, - spec: _s, - } - ) - }) - .last() - .expect("Wrong configuration"); - - let e5_model = LLMModelFactory::embedding_model(embeddings_config)?; - let embeddings = e5_model.forward(input)?; - - println!("EMBEDDINGS: {embeddings:#?}"); - Ok(()) -} diff --git a/yummy-rs/yummy/Makefile b/yummy-rs/yummy/Makefile index 8f5d785..02ce9e6 100644 --- a/yummy-rs/yummy/Makefile +++ b/yummy-rs/yummy/Makefile @@ -31,8 +31,10 @@ run-ml-serve: cargo run --features yummy-ml ml serve --model /home/jovyan/yummy/yummy-rs/tests/mlflow/lightgbm_model/lightgbm_my_model/ --port 8080 --host 0.0.0.0 --loglevel debug run-llm-serve-embeddings-e5: - cargo run --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config.yaml --port 8080 --host 0.0.0.0 --loglevel debug + cargo run --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config_embedding_e5.yaml --port 8080 --host 0.0.0.0 --loglevel debug +run-llm-serve-embeddings-jinabert: + cargo run --release --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config_embedding_jinabert.yaml --workers 1 run-delta-serve: cargo run --features yummy-delta delta serve --config /home/jovyan/yummy/yummy-rs/tests/delta/apply.yaml --port 8080 --host 0.0.0.0 --loglevel debug diff --git a/yummy-rs/yummy/src/main.rs b/yummy-rs/yummy/src/main.rs index 7902e47..6de8d58 100644 --- a/yummy-rs/yummy/src/main.rs +++ b/yummy-rs/yummy/src/main.rs @@ -52,10 +52,27 @@ fn cli() -> Command { Command::new("serve") .about("yummy llm serve") .args(vec![ - arg!(--config "config path"), - arg!(--host "host"), - arg!(--port "port"), - arg!(--loglevel "log level"), + arg!(--config ) + .required(true) + .help("config file path"), + arg!(--host "host") + .required(false) + .help("host") + .default_value("0.0.0.0"), + arg!(--port "port") + .required(false) + .help("port") + .default_value("8080") + .value_parser(clap::value_parser!(u16)), + arg!(--loglevel ) + .required(false) + .help("log level") + .default_value("error"), + arg!(--workers ) + .required(false) + .help("number of workers") + .default_value("2") + .value_parser(clap::value_parser!(usize)), ]) .arg_required_else_help(true), ), @@ -117,16 +134,12 @@ async fn main() -> Result<(), Box> { .get_one::("config") .expect("required"); let host = sub_sub_matches.get_one::("host").expect("required"); - let port = sub_sub_matches - .get_one::("port") - .expect("required") - .parse::()?; + let port = sub_sub_matches.get_one::("port").expect("required"); - let log_level = sub_sub_matches - .get_one::("loglevel") - .expect("required"); + let log_level = sub_sub_matches.get_one::("loglevel"); + let workers = sub_sub_matches.get_one::("workers"); - llm_serve(config.clone(), host.clone(), port, log_level.clone()).await? + llm_serve(config.clone(), host.clone(), *port, log_level, workers).await? } _ => unreachable!(), }, @@ -192,9 +205,10 @@ async fn llm_serve( config: String, host: String, port: u16, - log_level: String, + log_level: Option<&String>, + workers: Option<&usize>, ) -> std::io::Result<()> { - yummy_llm::serve_llm(config, host, port, log_level).await + yummy_llm::serve_llm(config, host, port, log_level, workers).await } #[cfg(not(feature = "yummy-llm"))] @@ -202,7 +216,8 @@ async fn llm_serve( _config: String, _host: String, _port: u16, - _log_level: String, + _log_level: Option<&String>, + _workers: Option<&usize>, ) -> std::io::Result<()> { unreachable!() }