Commit 89839b5

add rest api embeddings - jinabert

qooba committed Jan 22, 2024
1 parent 324465e commit 89839b5
Showing 6 changed files with 53 additions and 73 deletions.
7 changes: 0 additions & 7 deletions yummy-rs/tests/llm/config.yaml

This file was deleted.

21 changes: 14 additions & 7 deletions yummy-rs/yummy-llm/src/lib.rs
@@ -15,13 +15,17 @@ pub async fn serve_llm(
     model_path: String,
     host: String,
     port: u16,
-    log_level: String,
+    log_level: Option<&String>,
+    workers: Option<&usize>,
 ) -> std::io::Result<()> {
     let config = LLMConfig::new(&model_path).await.unwrap();
 
-    env_logger::init_from_env(env_logger::Env::new().default_filter_or(log_level));
+    if let Some(v) = log_level {
+        env_logger::init_from_env(env_logger::Env::new().default_filter_or(v));
+    }
 
     println!("Yummy llm server running on http://{host}:{port}");
-    HttpServer::new(move || {
+    let mut server = HttpServer::new(move || {
         let mut app = App::new()
             .app_data(web::Data::new(config.clone()))
             .wrap(Logger::default())
@@ -50,8 +54,11 @@ pub async fn serve_llm(
         }
 
         app
-    })
-    .bind((host, port))?
-    .run()
-    .await
+    });
+
+    if let Some(num_workers) = workers {
+        server = server.workers(*num_workers);
+    }
+
+    server.bind((host, port))?.run().await
 }
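
For context, a minimal sketch (not part of the commit) of how the updated serve_llm signature might be called; the function name start_embeddings_server and the config path are illustrative, while the crate path yummy_llm::serve_llm and the parameter types come from the diff above:

// Hypothetical caller of the new signature; log level and worker count
// are now optional, so a caller that does not care can pass None.
async fn start_embeddings_server() -> std::io::Result<()> {
    let log_level = Some("debug".to_string());
    let workers = Some(1usize);
    yummy_llm::serve_llm(
        "tests/llm/config_embedding_jinabert.yaml".to_string(), // illustrative path
        "0.0.0.0".to_string(),
        8080,
        log_level.as_ref(),
        workers.as_ref(),
    )
    .await
}

Note that with this change the logger is only initialized when a log level is supplied; since the CLI below now defaults --loglevel to "error", the value is always Some when the server is started through the yummy binary.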
12 changes: 6 additions & 6 deletions yummy-rs/yummy-llm/src/models/mod.rs
@@ -1,9 +1,10 @@
 pub mod e5_model;
-use std::path::PathBuf;
-
+pub mod jinabert_model;
 use crate::config::{LLMEndpoint, LLMSpec, ModelType};
 use e5_model::E5Model;
+use jinabert_model::JinaBertModel;
 use serde::{Deserialize, Serialize};
+use std::path::PathBuf;
 use yummy_core::common::Result;
 
 #[derive(thiserror::Error, Debug)]
@@ -31,10 +32,9 @@ pub struct LLMModelFactory {}
 impl LLMModelFactory {
     pub fn embedding_model(config: &LLMEndpoint) -> Result<Box<dyn EmbeddingsModel>> {
         if let LLMEndpoint::Embeddings { metadata: _, spec } = config {
-            if let ModelType::E5 = spec.model_type {
-                Ok(Box::new(E5Model::new(spec)?))
-            } else {
-                todo!()
+            match spec.model_type {
+                ModelType::E5 => Ok(Box::new(E5Model::new(spec)?)),
+                ModelType::JinaBert => Ok(Box::new(JinaBertModel::new(spec)?)),
             }
         } else {
             Err(Box::new(ModelFactoryError::WrongModelType))
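A brief sketch (illustrative, not from the commit) of how the factory is typically consumed inside this module; the helper name build_model is hypothetical, while LLMEndpoint, EmbeddingsModel, Result and LLMModelFactory are the types shown in the diff:

// Given an embeddings endpoint from the loaded config, the factory now
// dispatches on ModelType exhaustively, so adding a new ModelType variant
// will fail to compile until it is handled in embedding_model.
fn build_model(endpoint: &LLMEndpoint) -> Result<Box<dyn EmbeddingsModel>> {
    LLMModelFactory::embedding_model(endpoint)
}

Replacing the earlier if-let with a match also removes the todo!() runtime panic path for unsupported model types.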
37 changes: 0 additions & 37 deletions yummy-rs/yummy-llm/tests/integration_models_lightgbm.rs

This file was deleted.

4 changes: 3 additions & 1 deletion yummy-rs/yummy/Makefile
@@ -31,8 +31,10 @@ run-ml-serve:
 	cargo run --features yummy-ml ml serve --model /home/jovyan/yummy/yummy-rs/tests/mlflow/lightgbm_model/lightgbm_my_model/ --port 8080 --host 0.0.0.0 --loglevel debug
 
 run-llm-serve-embeddings-e5:
-	cargo run --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config.yaml --port 8080 --host 0.0.0.0 --loglevel debug
+	cargo run --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config_embedding_e5.yaml --port 8080 --host 0.0.0.0 --loglevel debug
 
+run-llm-serve-embeddings-jinabert:
+	cargo run --release --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config_embedding_jinabert.yaml --workers 1
 
 run-delta-serve:
 	cargo run --features yummy-delta delta serve --config /home/jovyan/yummy/yummy-rs/tests/delta/apply.yaml --port 8080 --host 0.0.0.0 --loglevel debug
45 changes: 30 additions & 15 deletions yummy-rs/yummy/src/main.rs
@@ -52,10 +52,27 @@ fn cli() -> Command {
             Command::new("serve")
                 .about("yummy llm serve")
                 .args(vec![
-                    arg!(--config <FILE> "config path"),
-                    arg!(--host <HOST> "host"),
-                    arg!(--port <PORT> "port"),
-                    arg!(--loglevel <LOGLEVEL> "log level"),
+                    arg!(--config <FILE>)
+                        .required(true)
+                        .help("config file path"),
+                    arg!(--host <HOST> "host")
+                        .required(false)
+                        .help("host")
+                        .default_value("0.0.0.0"),
+                    arg!(--port <PORT> "port")
+                        .required(false)
+                        .help("port")
+                        .default_value("8080")
+                        .value_parser(clap::value_parser!(u16)),
+                    arg!(--loglevel <LOGLEVEL>)
+                        .required(false)
+                        .help("log level")
+                        .default_value("error"),
+                    arg!(--workers <WORKERS>)
+                        .required(false)
+                        .help("number of workers")
+                        .default_value("2")
+                        .value_parser(clap::value_parser!(usize)),
                 ])
                 .arg_required_else_help(true),
             ),
@@ -117,16 +134,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                     .get_one::<String>("config")
                     .expect("required");
                 let host = sub_sub_matches.get_one::<String>("host").expect("required");
-                let port = sub_sub_matches
-                    .get_one::<String>("port")
-                    .expect("required")
-                    .parse::<u16>()?;
+                let port = sub_sub_matches.get_one::<u16>("port").expect("required");
 
-                let log_level = sub_sub_matches
-                    .get_one::<String>("loglevel")
-                    .expect("required");
+                let log_level = sub_sub_matches.get_one::<String>("loglevel");
+                let workers = sub_sub_matches.get_one::<usize>("workers");
 
-                llm_serve(config.clone(), host.clone(), port, log_level.clone()).await?
+                llm_serve(config.clone(), host.clone(), *port, log_level, workers).await?
             }
             _ => unreachable!(),
         },
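
As a side note, a small self-contained sketch (assuming clap 4, not part of this commit) of why the typed lookups above work: default_value supplies a string that the registered value_parser converts, so get_one with the target type returns Some even when the flag is omitted on the command line.

use clap::{arg, Command};

fn main() {
    let matches = Command::new("demo")
        .arg(
            arg!(--port <PORT>)
                .required(false)
                .default_value("8080")
                .value_parser(clap::value_parser!(u16)),
        )
        .get_matches();
    // The default string "8080" has already been run through value_parser!(u16),
    // so the lookup is typed and no manual parse::<u16>() is needed.
    let port: u16 = *matches.get_one::<u16>("port").expect("defaulted");
    println!("port={port}");
}

This is what lets the commit drop the manual parse::<u16>()? call and the question-mark error path for the port argument.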
@@ -192,17 +205,19 @@ async fn llm_serve(
     config: String,
     host: String,
     port: u16,
-    log_level: String,
+    log_level: Option<&String>,
+    workers: Option<&usize>,
 ) -> std::io::Result<()> {
-    yummy_llm::serve_llm(config, host, port, log_level).await
+    yummy_llm::serve_llm(config, host, port, log_level, workers).await
 }
 
 #[cfg(not(feature = "yummy-llm"))]
 async fn llm_serve(
     _config: String,
     _host: String,
     _port: u16,
-    _log_level: String,
+    _log_level: Option<&String>,
+    _workers: Option<&usize>,
 ) -> std::io::Result<()> {
     unreachable!()
 }
