Skip to content

Commit

Permalink
Add Arctic
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed May 7, 2024
1 parent 75cef85 commit 8392eed
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"titleBar.activeBackground": "#51103e",
"titleBar.activeForeground": "#e7e7e7",
"titleBar.inactiveBackground": "#51103e99",
"titleBar.inactiveForeground": "#e7e7e799"
"titleBar.inactiveForeground": "#e7e7e799",
"tab.activeBorder": "#7c185f"
},
"peacock.color": "#51103e"
}
21 changes: 20 additions & 1 deletion chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
    """Validate the requested model backend before the model is set up.

    Runs as a pydantic (v1-style) pre-validator so an unsupported value
    fails fast with a clear error instead of surfacing later in setup().

    Raises:
        ValueError: if ``v`` is not one of the supported backend names.
    """
    # Closed set of supported backends; "arctic" was added in this commit.
    # (The scraped diff showed both the old and new membership lists; this
    # keeps the post-change list only.)
    supported = {"gpt", "llama", "claude", "mixtral8x7b", "arctic"}
    if v not in supported:
        raise ValueError(f"Unsupported model type: {v}")
    return v

Expand All @@ -58,6 +58,8 @@ def setup(self):
self.setup_mixtral_8x7b()
elif self.model_type == "llama":
self.setup_llama()
elif self.model_type == "arctic":
self.setup_arctic()

def setup_gpt(self):
self.llm = ChatOpenAI(
Expand Down Expand Up @@ -111,6 +113,21 @@ def setup_llama(self):
},
)

def setup_arctic(self):
    """Configure the LLM to be Snowflake Arctic (instruct) via OpenRouter.

    OpenRouter exposes an OpenAI-compatible endpoint, so the stock
    ChatOpenAI client is reused with ``base_url`` pointed at OpenRouter
    and the key read from ``self.secrets["OPENROUTER_API_KEY"]``.
    """
    # Attribution headers recommended by OpenRouter for app rankings.
    attribution_headers = {
        "HTTP-Referer": "https://snowchat.streamlit.app/",
        "X-Title": "Snowchat",
    }
    self.llm = ChatOpenAI(
        model_name="snowflake/snowflake-arctic-instruct",
        api_key=self.secrets["OPENROUTER_API_KEY"],
        base_url="https://openrouter.ai/api/v1",
        temperature=0.1,
        max_tokens=700,
        streaming=True,
        callbacks=[self.callback_handler],
        default_headers=attribution_headers,
    )

def get_chain(self, vectorstore):
def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
Expand Down Expand Up @@ -156,6 +173,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
model_type = "claude"
elif "llama" in model_name.lower():
model_type = "llama"
elif "arctic" in model_name.lower():
model_type = "arctic"
else:
raise ValueError(f"Unsupported model name: {model_name}")

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5"],
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5", "Snowflake Arctic"],
index=0,
horizontal=True,
)
Expand Down
18 changes: 10 additions & 8 deletions utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler


image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
gemini_url = image_url + "google-gemini-icon.png?t=2024-03-01T07%3A25%3A59.637Z"
mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png"
gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
openai_url = (
image_url
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-05-07T21%3A18%3A44.079Z"
)
user_url = image_url + "cat-with-sunglasses.png"
claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"
meta_url = image_url + "meta-logo.webp?t=2024-04-18T22%3A43%3A17.775Z"

user_url = image_url + "cat-with-sunglasses.png?t=2024-05-07T21%3A17%3A21.951Z"
claude_url = image_url + "Claude.png?t=2024-05-07T21%3A16%3A17.252Z"
meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z"
snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z"

def get_model_url(model_name):
if "gpt" in model_name.lower():
Expand All @@ -25,6 +26,8 @@ def get_model_url(model_name):
return meta_url
elif "gemini" in model_name.lower():
return gemini_url
elif "arctic" in model_name.lower():
return snow_url
return mistral_url


Expand Down Expand Up @@ -121,7 +124,6 @@ def start_loading_message(self):
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)

def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
print("on llm bnew token ", token)
if not self.has_streaming_started:
self.has_streaming_started = True

Expand Down

0 comments on commit 8392eed

Please sign in to comment.