-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
172 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
kind: embeddings | ||
metadata: | ||
name: e5-embeddings | ||
spec: | ||
normalize_embeddings: true | ||
model_type: E5 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
kind: embeddings | ||
metadata: | ||
name: jina-embeddings | ||
spec: | ||
normalize_embeddings: true | ||
model_type: JinaBert | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
|
||
POST http://localhost:8080/embeddings HTTP/1.1 | ||
content-type: application/json | ||
|
||
{ | ||
"input":["Elixir of Eternal Twilight: Grants visions of realms beyond the veil."] | ||
} | ||
|
||
### | ||
|
||
POST http://localhost:8080/embeddings HTTP/1.1 | ||
content-type: application/json | ||
|
||
{ | ||
"input":["Elixir of Eternal Twilight: Grants visions of realms beyond the veil.", | ||
"Elixir of Eternal Twilight: Grants visions of realms beyond the veil."] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
use super::{ | ||
tokenizer_pathbuf, weights_pathbuf, EmbeddingData, EmbeddingsModel, EmbeddingsResponse, | ||
}; | ||
use crate::config::LLMSpec; | ||
use anyhow::Error as E; | ||
use candle::{DType, Device, Module, Tensor}; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::models::jina_bert::{BertModel, Config}; | ||
use tokenizers::{PaddingParams, Tokenizer}; | ||
use yummy_core::common::Result; | ||
|
||
#[derive(thiserror::Error, Debug)] | ||
pub enum E5Error { | ||
#[error("Wrong E5 config")] | ||
WrongConfig, | ||
} | ||
|
||
const JINABERT_HF_REPO_ID: &str = "jinaai/jina-embeddings-v2-base-en"; | ||
const JINABERT_HF_TOKENIZER_REPO_ID: &str = "sentence-transformers/all-MiniLM-L6-v2"; | ||
|
||
pub struct JinaBertModel { | ||
model: BertModel, | ||
tokenizer: Tokenizer, | ||
normalize_embeddings: Option<bool>, | ||
} | ||
|
||
impl JinaBertModel { | ||
pub fn new(config: &LLMSpec) -> Result<JinaBertModel> { | ||
let model = weights_pathbuf(config, JINABERT_HF_REPO_ID)?; | ||
let tokenizer = tokenizer_pathbuf(config, JINABERT_HF_TOKENIZER_REPO_ID)?; | ||
let candle_config = Config::v2_base(); | ||
let device = &Device::Cpu; | ||
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?; | ||
|
||
if let Some(pp) = tokenizer.get_padding_mut() { | ||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest | ||
} else { | ||
let pp = PaddingParams { | ||
strategy: tokenizers::PaddingStrategy::BatchLongest, | ||
..Default::default() | ||
}; | ||
tokenizer.with_padding(Some(pp)); | ||
} | ||
|
||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; | ||
let model = BertModel::new(vb, &candle_config)?; | ||
|
||
Ok(JinaBertModel { | ||
model, | ||
tokenizer, | ||
normalize_embeddings: config.normalize_embeddings, | ||
}) | ||
} | ||
} | ||
|
||
impl EmbeddingsModel for JinaBertModel { | ||
fn forward(&self, input: Vec<String>) -> Result<EmbeddingsResponse> { | ||
let device = &Device::Cpu; | ||
let tokens = self.tokenizer.encode_batch(input, true).unwrap(); | ||
let token_ids: Vec<Tensor> = tokens | ||
.iter() | ||
.map(|tokens| { | ||
let tokens = tokens.get_ids().to_vec(); | ||
Tensor::new(tokens.as_slice(), device) | ||
}) | ||
.collect::<std::result::Result<Vec<_>, _>>()?; | ||
|
||
let token_ids = Tensor::stack(&token_ids, 0)?; | ||
let embeddings = self.model.forward(&token_ids)?; | ||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; | ||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; | ||
let embeddings = if let Some(true) = self.normalize_embeddings { | ||
embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)? | ||
} else { | ||
embeddings | ||
}; | ||
let embeddings_data = embeddings | ||
.to_vec2()? | ||
.iter() | ||
.map(|x| EmbeddingData { | ||
embedding: x.to_vec(), | ||
}) | ||
.collect(); | ||
|
||
Ok(EmbeddingsResponse { | ||
data: embeddings_data, | ||
}) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use yummy_core::common::Result; | ||
use yummy_llm::config::{LLMConfig, LLMEndpoint}; | ||
use yummy_llm::models::LLMModelFactory; | ||
|
||
#[tokio::test] | ||
async fn embeddings_e5() -> Result<()> { | ||
let path = "../tests/llm/config_embedding_e5.yaml".to_string(); | ||
embeddings_test(path).await | ||
} | ||
|
||
#[tokio::test] | ||
async fn embeddings_jinabert() -> Result<()> { | ||
let path = "../tests/llm/config_embedding_jinabert.yaml".to_string(); | ||
embeddings_test(path).await | ||
} | ||
|
||
async fn embeddings_test(config_path: String) -> Result<()> { | ||
let config = LLMConfig::new(&config_path).await?; | ||
println!("{config:?}"); | ||
|
||
let input = vec![ | ||
String::from("Elixir of Eternal Twilight: Grants visions of realms beyond the veil."), | ||
String::from("Potion of Liquid Starlight: Imbues the drinker with celestial clarity."), | ||
]; | ||
|
||
let embeddings_config = config | ||
.endpoints | ||
.iter() | ||
.filter(|x| { | ||
matches!( | ||
x, | ||
LLMEndpoint::Embeddings { | ||
metadata: _m, | ||
spec: _s, | ||
} | ||
) | ||
}) | ||
.last() | ||
.expect("Wrong configuration"); | ||
|
||
let e5_model = LLMModelFactory::embedding_model(embeddings_config)?; | ||
let embeddings = e5_model.forward(input)?; | ||
|
||
println!("EMBEDDINGS: {embeddings:#?}"); | ||
Ok(()) | ||
} | ||
|
||
//cargo test --release -- --list --show-output | ||
//cargo test --release -- embeddings_e5 --show-output |