From 8392eedf6d75f0a886a7d9fad7a86d4ffd853931 Mon Sep 17 00:00:00 2001
From: kaarthik108
Date: Wed, 8 May 2024 09:26:58 +1200
Subject: [PATCH] Add Arctic

---
 .vscode/settings.json |  3 ++-
 chain.py              | 21 ++++++++++++++++++++-
 main.py               |  2 +-
 utils/snowchat_ui.py  | 18 ++++++++++--------
 4 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 48d8756..9068edb 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -20,7 +20,8 @@
         "titleBar.activeBackground": "#51103e",
         "titleBar.activeForeground": "#e7e7e7",
         "titleBar.inactiveBackground": "#51103e99",
-        "titleBar.inactiveForeground": "#e7e7e799"
+        "titleBar.inactiveForeground": "#e7e7e799",
+        "tab.activeBorder": "#7c185f"
     },
     "peacock.color": "#51103e"
 }
\ No newline at end of file
diff --git a/chain.py b/chain.py
index 178c380..8addf90 100644
--- a/chain.py
+++ b/chain.py
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "llama", "claude", "mixtral8x7b"]:
+        if v not in ["gpt", "llama", "claude", "mixtral8x7b", "arctic"]:
             raise ValueError(f"Unsupported model type: {v}")
         return v
 
@@ -58,6 +58,8 @@ def setup(self):
             self.setup_mixtral_8x7b()
         elif self.model_type == "llama":
             self.setup_llama()
+        elif self.model_type == "arctic":
+            self.setup_arctic()
 
     def setup_gpt(self):
         self.llm = ChatOpenAI(
@@ -111,6 +113,21 @@ def setup_llama(self):
             },
         )
 
+    def setup_arctic(self):
+        self.llm = ChatOpenAI(
+            model_name="snowflake/snowflake-arctic-instruct",
+            temperature=0.1,
+            api_key=self.secrets["OPENROUTER_API_KEY"],
+            max_tokens=700,
+            callbacks=[self.callback_handler],
+            streaming=True,
+            base_url="https://openrouter.ai/api/v1",
+            default_headers={
+                "HTTP-Referer": "https://snowchat.streamlit.app/",
+                "X-Title": "Snowchat",
+            },
+        )
+
     def get_chain(self, vectorstore):
         def _combine_documents(
             docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
@@ -156,6 +173,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         model_type = "claude"
     elif "llama" in model_name.lower():
         model_type = "llama"
+    elif "arctic" in model_name.lower():
+        model_type = "arctic"
     else:
         raise ValueError(f"Unsupported model name: {model_name}")
 
diff --git a/main.py b/main.py
index b527704..275f9c2 100644
--- a/main.py
+++ b/main.py
@@ -34,7 +34,7 @@
 st.caption("Talk your way through data")
 model = st.radio(
     "",
-    options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5"],
+    options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5", "Snowflake Arctic"],
     index=0,
     horizontal=True,
 )
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 72aaff8..46d8f8e 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -4,17 +4,18 @@
 import streamlit as st
 from langchain.callbacks.base import BaseCallbackHandler
 
+
 image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
-gemini_url = image_url + "google-gemini-icon.png?t=2024-03-01T07%3A25%3A59.637Z"
-mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png"
+gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
+mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
 openai_url = (
     image_url
-    + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
+    + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-05-07T21%3A18%3A44.079Z"
 )
-user_url = image_url + "cat-with-sunglasses.png"
"cat-with-sunglasses.png" -claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z" -meta_url = image_url + "meta-logo.webp?t=2024-04-18T22%3A43%3A17.775Z" - +user_url = image_url + "cat-with-sunglasses.png?t=2024-05-07T21%3A17%3A21.951Z" +claude_url = image_url + "Claude.png?t=2024-05-07T21%3A16%3A17.252Z" +meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z" +snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z" def get_model_url(model_name): if "gpt" in model_name.lower(): @@ -25,6 +26,8 @@ def get_model_url(model_name): return meta_url elif "gemini" in model_name.lower(): return gemini_url + elif "arctic" in model_name.lower(): + return snow_url return mistral_url @@ -121,7 +124,6 @@ def start_loading_message(self): self.placeholder.markdown(loading_message_content, unsafe_allow_html=True) def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs): - print("on llm bnew token ", token) if not self.has_streaming_started: self.has_streaming_started = True