Commit: add jinabert

qooba committed Jan 22, 2024
1 parent 89839b5 commit 806b67d
Showing 5 changed files with 172 additions and 0 deletions.
7 changes: 7 additions & 0 deletions yummy-rs/tests/llm/config_embedding_e5.yaml
@@ -0,0 +1,7 @@
kind: embeddings
metadata:
  name: e5-embeddings
spec:
  normalize_embeddings: true
  model_type: E5

7 changes: 7 additions & 0 deletions yummy-rs/tests/llm/config_embedding_jinabert.yaml
@@ -0,0 +1,7 @@
kind: embeddings
metadata:
  name: jina-embeddings
spec:
  normalize_embeddings: true
  model_type: JinaBert
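
The two configs differ only in model_type, which selects the embedding backend, while normalize_embeddings tells the model's forward pass to L2-normalize the pooled vectors. As a rough sketch of how this YAML could map onto the config types used below — the actual LLMConfig/LLMEndpoint/LLMSpec definitions live in yummy-llm/src/config.rs and are not part of this diff, so the field names here are assumptions inferred from usage:

use serde::Deserialize;

// Hypothetical mirror of the config types in yummy-llm/src/config.rs,
// inferred from the YAML above and from how the integration tests
// pattern-match LLMEndpoint::Embeddings { metadata, spec };
// the real definitions may differ.
#[derive(Debug, Deserialize)]
pub struct Metadata {
    pub name: String,
}

#[derive(Debug, Deserialize)]
pub enum ModelType {
    E5,
    JinaBert,
}

#[derive(Debug, Deserialize)]
pub struct LLMSpec {
    pub normalize_embeddings: Option<bool>,
    pub model_type: ModelType,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
pub enum LLMEndpoint {
    Embeddings { metadata: Metadata, spec: LLMSpec },
}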

17 changes: 17 additions & 0 deletions yummy-rs/tests/llm/llm.http
@@ -0,0 +1,17 @@

POST http://localhost:8080/embeddings HTTP/1.1
content-type: application/json

{
"input":["Elixir of Eternal Twilight: Grants visions of realms beyond the veil."]
}

###

POST http://localhost:8080/embeddings HTTP/1.1
content-type: application/json

{
"input":["Elixir of Eternal Twilight: Grants visions of realms beyond the veil.",
"Elixir of Eternal Twilight: Grants visions of realms beyond the veil."]
}
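
For reference, assuming the endpoint serializes the EmbeddingsResponse type from yummy-llm directly, a successful reply would look roughly like this — one data entry per input string, with the vector values illustrative and truncated:

HTTP/1.1 200 OK
content-type: application/json

{
  "data": [
    { "embedding": [0.0123, -0.0456, 0.0789] }
  ]
}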
92 changes: 92 additions & 0 deletions yummy-rs/yummy-llm/src/models/jinabert_model.rs
@@ -0,0 +1,92 @@
use super::{
    tokenizer_pathbuf, weights_pathbuf, EmbeddingData, EmbeddingsModel, EmbeddingsResponse,
};
use crate::config::LLMSpec;
use anyhow::Error as E;
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::jina_bert::{BertModel, Config};
use tokenizers::{PaddingParams, Tokenizer};
use yummy_core::common::Result;

#[derive(thiserror::Error, Debug)]
pub enum JinaBertError {
    #[error("Wrong JinaBert config")]
    WrongConfig,
}

const JINABERT_HF_REPO_ID: &str = "jinaai/jina-embeddings-v2-base-en";
const JINABERT_HF_TOKENIZER_REPO_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

pub struct JinaBertModel {
    model: BertModel,
    tokenizer: Tokenizer,
    normalize_embeddings: Option<bool>,
}

impl JinaBertModel {
    pub fn new(config: &LLMSpec) -> Result<JinaBertModel> {
        let model = weights_pathbuf(config, JINABERT_HF_REPO_ID)?;
        let tokenizer = tokenizer_pathbuf(config, JINABERT_HF_TOKENIZER_REPO_ID)?;
        let candle_config = Config::v2_base();
        let device = &Device::Cpu;
        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;

        // Pad every sequence in a batch to the longest one so they can be stacked.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
        let model = BertModel::new(vb, &candle_config)?;

        Ok(JinaBertModel {
            model,
            tokenizer,
            normalize_embeddings: config.normalize_embeddings,
        })
    }
}

impl EmbeddingsModel for JinaBertModel {
    fn forward(&self, input: Vec<String>) -> Result<EmbeddingsResponse> {
        let device = &Device::Cpu;
        let tokens = self.tokenizer.encode_batch(input, true).map_err(E::msg)?;
        let token_ids: Vec<Tensor> = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<std::result::Result<Vec<_>, _>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        let embeddings = self.model.forward(&token_ids)?;
        // Mean-pool the token embeddings over the sequence dimension.
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        // Optionally L2-normalize each sentence vector.
        let embeddings = if let Some(true) = self.normalize_embeddings {
            embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
        } else {
            embeddings
        };
        let embeddings_data = embeddings
            .to_vec2()?
            .iter()
            .map(|x| EmbeddingData {
                embedding: x.to_vec(),
            })
            .collect();

        Ok(EmbeddingsResponse {
            data: embeddings_data,
        })
    }
}

#[cfg(test)]
mod tests {}
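
The forward implementation above mean-pools the BERT token embeddings over the sequence dimension and, when normalize_embeddings is true, L2-normalizes each sentence vector. A minimal plain-Rust sketch of that arithmetic on nested Vec<f32>s, purely illustrative of what the candle tensor ops compute:

// Illustrative only: re-implements the pooling/normalization math from
// JinaBertModel::forward on plain vectors instead of candle tensors.

/// Mean-pool token embeddings: [n_tokens][hidden_size] -> [hidden_size].
fn mean_pool(tokens: &[Vec<f32>]) -> Vec<f32> {
    let n = tokens.len() as f32;
    let mut out = vec![0.0f32; tokens[0].len()];
    for t in tokens {
        for (o, v) in out.iter_mut().zip(t) {
            *o += *v / n;
        }
    }
    out
}

/// L2-normalize in place; mirrors broadcast_div(sqr().sum_keepdim(1).sqrt()).
fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

fn main() {
    let token_embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let mut sentence = mean_pool(&token_embeddings); // [2.0, 3.0]
    l2_normalize(&mut sentence); // what normalize_embeddings: true adds
    println!("{sentence:?}");
}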
49 changes: 49 additions & 0 deletions yummy-rs/yummy-llm/tests/integration_embeddings.rs
@@ -0,0 +1,49 @@
use yummy_core::common::Result;
use yummy_llm::config::{LLMConfig, LLMEndpoint};
use yummy_llm::models::LLMModelFactory;

#[tokio::test]
async fn embeddings_e5() -> Result<()> {
    let path = "../tests/llm/config_embedding_e5.yaml".to_string();
    embeddings_test(path).await
}

#[tokio::test]
async fn embeddings_jinabert() -> Result<()> {
    let path = "../tests/llm/config_embedding_jinabert.yaml".to_string();
    embeddings_test(path).await
}

async fn embeddings_test(config_path: String) -> Result<()> {
    let config = LLMConfig::new(&config_path).await?;
    println!("{config:?}");

    let input = vec![
        String::from("Elixir of Eternal Twilight: Grants visions of realms beyond the veil."),
        String::from("Potion of Liquid Starlight: Imbues the drinker with celestial clarity."),
    ];

    // Pick the last embeddings endpoint from the config.
    let embeddings_config = config
        .endpoints
        .iter()
        .filter(|x| {
            matches!(
                x,
                LLMEndpoint::Embeddings {
                    metadata: _m,
                    spec: _s,
                }
            )
        })
        .last()
        .expect("Wrong configuration");

    // The helper serves both the E5 and JinaBert tests, so name it generically.
    let embeddings_model = LLMModelFactory::embedding_model(embeddings_config)?;
    let embeddings = embeddings_model.forward(input)?;

    println!("EMBEDDINGS: {embeddings:#?}");
    Ok(())
}

//cargo test --release -- --list --show-output
//cargo test --release -- embeddings_e5 --show-output
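//cargo test --release -- embeddings_jinabert --show-output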
