Skip to content

Commit

Permalink
add jinabert
Browse files Browse the repository at this point in the history
  • Loading branch information
qooba committed Jan 22, 2024
1 parent 806b67d commit 9cb132f
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 54 deletions.
16 changes: 14 additions & 2 deletions yummy-rs/yummy-delta-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,24 @@ fn run_apply(config_path: String) -> PyResult<String> {
}

#[pyfunction]
fn run(config_path: String, host: String, port: u16, log_level: String) -> PyResult<String> {
fn run(
config_path: String,
host: String,
port: u16,
log_level: String,
workers: usize,
) -> PyResult<String> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(run_delta_server(config_path, host, port, log_level))
.block_on(run_delta_server(
config_path,
host,
port,
Some(&log_level),
Some(&workers),
))
.unwrap();
Ok("Ok".to_string())
}
Expand Down
2 changes: 2 additions & 0 deletions yummy-rs/yummy-delta/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ anyhow = "1.0.69"
url = "2.3"
prettytable-rs = "^0.10"
once_cell = "1.17.1"
num-traits = "0.2.15"

# Datafusion
#datafusion = { version = "14" }
#datafusion-expr = { version = "14" }
Expand Down
24 changes: 16 additions & 8 deletions yummy-rs/yummy-delta/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::output::PrettyOutput;
use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use apply::DeltaApply;
use num_traits::Zero;
use server::{
append, create_table, details, health, list_stores, list_tables, optimize, overwrite,
query_stream, vacuum,
Expand Down Expand Up @@ -73,17 +74,21 @@ pub async fn run_delta_server(
config_path: String,
host: String,
port: u16,
log_level: String,
log_level: Option<&String>,
workers: Option<&usize>,
) -> std::io::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or(log_level));
if let Some(v) = log_level {
env_logger::init_from_env(env_logger::Env::new().default_filter_or(v));
}

println!("Yummy delta server running on http://{host}:{port}");
let delta_manager = DeltaApply::new(&config_path.clone())
.await
.unwrap()
.delta_manager()
.unwrap();

let _ = HttpServer::new(move || {
let mut server = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(delta_manager.clone()))
.route("/health", web::get().to(health))
Expand All @@ -107,10 +112,13 @@ pub async fn run_delta_server(
),
)
.wrap(Logger::default())
})
.bind((host, port))?
.run()
.await;
});

Ok(())
if let Some(num_workers) = workers {
if !num_workers.is_zero() {
server = server.workers(*num_workers);
}
}

server.bind((host, port))?.run().await
}
55 changes: 52 additions & 3 deletions yummy-rs/yummy-llm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ pub mod common;
pub mod config;
pub mod models;
pub mod server;

use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use config::LLMConfig;
use config::{LLMConfig, LLMSpec, ModelType};
use models::LLMModelFactory;
use num_traits::Zero;
use server::{embeddings, health};
use yummy_core::config::Metadata;

use crate::config::LLMEndpoint;

Expand All @@ -20,6 +21,52 @@ pub async fn serve_llm(
) -> std::io::Result<()> {
let config = LLMConfig::new(&model_path).await.unwrap();

serve(config, host, port, log_level, workers).await
}

/// Serves an embeddings endpoint for a named model without reading a config file.
///
/// Builds an in-memory [`LLMConfig`] containing a single
/// [`LLMEndpoint::Embeddings`] endpoint named `"embedding"` and passes it to
/// `serve` together with the network and logging options.
///
/// # Arguments
///
/// * `model` - Model family name, compared ASCII case-insensitively. `"E5"`
///   selects [`ModelType::E5`]; every other value (including unrecognised
///   names) selects [`ModelType::JinaBert`].
/// * `normalize_embeddings` - Stored in the endpoint spec as
///   `normalize_embeddings`.
/// * `host`, `port`, `log_level`, `workers` - Forwarded unchanged to `serve`.
///
/// # Errors
///
/// Returns whatever `serve` returns.
pub async fn serve_embeddings(
    model: String,
    normalize_embeddings: bool,
    host: String,
    port: u16,
    log_level: Option<&String>,
    workers: Option<&usize>,
) -> std::io::Result<()> {
    // Case-insensitive match without allocating an upper-cased copy of `model`.
    let model_type = if model.eq_ignore_ascii_case("E5") {
        ModelType::E5
    } else {
        // Any name other than "E5" falls back to JinaBert.
        ModelType::JinaBert
    };

    // No local paths or Hugging Face repo ids are provided here.
    // NOTE(review): presumably the model layer resolves defaults for the
    // selected `model_type` when these are `None` — confirm.
    let spec = LLMSpec {
        model: None,
        tokenizer: None,
        config: None,
        hf_model_repo_id: None,
        hf_tokenizer_repo_id: None,
        hf_config_repo_id: None,
        normalize_embeddings: Some(normalize_embeddings),
        model_type,
        use_pth: None,
    };

    let metadata = Metadata {
        name: "embedding".to_string(),
        store: None,
        table: None,
    };

    let config = LLMConfig {
        endpoints: vec![LLMEndpoint::Embeddings { metadata, spec }],
    };

    serve(config, host, port, log_level, workers).await
}

async fn serve(
config: LLMConfig,
host: String,
port: u16,
log_level: Option<&String>,
workers: Option<&usize>,
) -> std::io::Result<()> {
if let Some(v) = log_level {
env_logger::init_from_env(env_logger::Env::new().default_filter_or(v));
}
Expand Down Expand Up @@ -57,7 +104,9 @@ pub async fn serve_llm(
});

if let Some(num_workers) = workers {
server = server.workers(*num_workers);
if !num_workers.is_zero() {
server = server.workers(*num_workers);
}
}

server.bind((host, port))?.run().await
Expand Down
16 changes: 14 additions & 2 deletions yummy-rs/yummy-ml-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,24 @@ use pyo3::prelude::*;
use yummy_ml::serve_ml_model;

#[pyfunction]
fn serve(model_path: String, host: String, port: u16, log_level: String) -> PyResult<String> {
fn serve(
model_path: String,
host: String,
port: u16,
log_level: String,
workers: usize,
) -> PyResult<String> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(serve_ml_model(model_path, host, port, log_level))
.block_on(serve_ml_model(
model_path,
host,
port,
Some(&log_level),
Some(&workers),
))
.unwrap();
Ok("Ok".to_string())
}
Expand Down
1 change: 1 addition & 0 deletions yummy-rs/yummy-ml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ derive_more = "0.99.17"
#catboost = { git = "https://github.com/catboost/catboost", optional = true }
lightgbm = "0.2.3"
thiserror = "1.0"
num-traits = "0.2.15"
#reqwest = "0.11.13"

[features]
Expand Down
23 changes: 16 additions & 7 deletions yummy-rs/yummy-ml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,36 @@ use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use config::MLConfig;
use models::MLModelFactory;
use num_traits::Zero;
use server::{health, invocations};

pub async fn serve_ml_model(
model_path: String,
host: String,
port: u16,
log_level: String,
log_level: Option<&String>,
workers: Option<&usize>,
) -> std::io::Result<()> {
let config = MLConfig::new(&model_path).await.unwrap();
if let Some(v) = log_level {
env_logger::init_from_env(env_logger::Env::new().default_filter_or(v));
}

env_logger::init_from_env(env_logger::Env::new().default_filter_or(log_level));
println!("Yummy ml server running on http://{host}:{port}");
HttpServer::new(move || {
let mut server = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(MLModelFactory::new(config.clone()).unwrap()))
.app_data(web::Data::new(config.clone()))
.wrap(Logger::default())
.route("/health", web::get().to(health))
.route("/invocations", web::post().to(invocations))
})
.bind((host, port))?
.run()
.await
});

if let Some(num_workers) = workers {
if !num_workers.is_zero() {
server = server.workers(*num_workers);
}
}

server.bind((host, port))?.run().await
}
3 changes: 3 additions & 0 deletions yummy-rs/yummy/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ run-llm-serve-embeddings-e5:
run-llm-serve-embeddings-jinabert:
cargo run --release --features yummy-llm llm serve --config /home/jovyan/yummy/yummy-rs/tests/llm/config_embedding_jinabert.yaml --workers 1

run-llm-serve-embeddings-jinabert1:
cargo run --release --features yummy-llm llm serve-embeddings --model jinabert --workers 1

run-delta-serve:
cargo run --features yummy-delta delta serve --config /home/jovyan/yummy/yummy-rs/tests/delta/apply.yaml --port 8080 --host 0.0.0.0 --loglevel debug

Expand Down
Loading

0 comments on commit 9cb132f

Please sign in to comment.