Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into PORT_WORK
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-cnivera committed Oct 14, 2024
2 parents 3a1c5fe + b1c4ce7 commit f1e9b80
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 34 deletions.
8 changes: 1 addition & 7 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1,8 +1,2 @@
* @sfc-gh-cnivera @sfc-gh-jsummer @sfc-gh-twhite @sfc-gh-yayin
/artifacts/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer
/images/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer
/journeys/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer
/partner/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer
/app.py @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer
/app_utils/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer
* @sfc-gh-cnivera @sfc-gh-jsummer @sfc-gh-twhite
/semantic_model_generator/ @sfc-gh-nsehrawat @sfc-gh-rehuang @sfc-gh-dasilva @sfc-gh-cnivera @sfc-gh-yayin
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ To find your Account locator, please execute the following sql command in your a
SELECT CURRENT_ACCOUNT_LOCATOR();
```

To find the `SNOWFLAKE_HOST` for your
account, [follow these instructions](https://docs.snowflake.com/en/user-guide/organizations-connect#connecting-with-a-url).
B. To find the `SNOWFLAKE_HOST` for your
account, [follow these instructions](https://docs.snowflake.com/en/user-guide/organizations-connect#connecting-with-a-url). The easiest way to find your account URL is to click the `Copy account URL` button from the Account panel in Snowsight:

* Currently we recommend you to look under the `Account locator (legacy)` method of connection for better compatibility
on API.
![CleanShot 2024-10-09 at 14 25 13](https://github.com/user-attachments/assets/b1715c57-9571-4c65-92fb-e5d43afa871b)

However, if you have trouble authenticating with this URL, you can try building the URL manually:
* Currently we recommend you to look under the `Account locator (legacy)` method of connection for better compatibility on API.
* It typically follows format of: `<accountlocator>.<region>.<cloud>.snowflakecomputing.com`. Ensure that you omit
the `https://` prefix.

Expand Down
6 changes: 6 additions & 0 deletions app_utils/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,12 @@ def init_session_states() -> None:
if "last_validated_model" not in st.session_state:
st.session_state.last_validated_model = semantic_model_pb2.SemanticModel()

# Chat display settings.
if "chat_debug" not in st.session_state:
st.session_state.chat_debug = False
if "multiturn" not in st.session_state:
st.session_state.multiturn = False

# initialize session states for the chat page.
if "messages" not in st.session_state:
# messages store all chat histories
Expand Down
98 changes: 75 additions & 23 deletions journeys/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,20 @@
from streamlit_extras.row import row
from streamlit_extras.stylable_container import stylable_container

from journeys.joins import joins_dialog

from app_utils.shared_utils import (
GeneratorAppScreen,
SnowflakeStage,
stage_selector_container,
changed_from_last_validated_model,
download_yaml,
get_snowflake_connection,
get_yamls_from_stage,
init_session_states,
return_home_button,
stage_selector_container,
upload_yaml,
validate_and_upload_tmp_yaml,
get_yamls_from_stage,
)
from journeys.joins import joins_dialog
from semantic_model_generator.data_processing.cte_utils import (
context_to_column_format,
expand_all_logical_tables_as_ctes,
Expand Down Expand Up @@ -72,15 +71,19 @@ def pretty_print_sql(sql: str) -> str:


@st.cache_data(ttl=60, show_spinner=False)
def send_message(_conn: SnowflakeConnection, prompt: str) -> Dict[str, Any]:
"""Calls the REST API and returns the response."""
def send_message(
_conn: SnowflakeConnection, messages: list[dict[str, str]]
) -> Dict[str, Any]:
"""
Calls the REST API with a list of messages and returns the response.
Args:
_conn: SnowflakeConnection, used to grab the token for auth.
messages: list of chat messages to pass to the Analyst API.
Returns: The raw ChatMessage response from Analyst.
"""
request_body = {
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
},
],
"messages": messages,
"semantic_model": proto_to_yaml(st.session_state.semantic_model),
}

Expand Down Expand Up @@ -124,17 +127,26 @@ def send_message(_conn: SnowflakeConnection, prompt: str) -> Dict[str, Any]:

def process_message(_conn: SnowflakeConnection, prompt: str) -> None:
"""Processes a message and adds the response to the chat."""
st.session_state.messages.append(
{"role": "user", "content": [{"type": "text", "text": prompt}]}
)
user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
st.session_state.messages.append(user_message)
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Generating response..."):
response = send_message(_conn=_conn, prompt=prompt)
# Depending on whether multiturn is enabled, we either send just the user message or the entire chat history.
request_messages = (
st.session_state.messages[1:] # Skip the welcome message
if st.session_state.multiturn
else [user_message]
)
response = send_message(_conn=_conn, messages=request_messages)
content = response["message"]["content"]
display_content(conn=_conn, content=content)
st.session_state.messages.append({"role": "assistant", "content": content})
# Grab the request ID from the response and stash it in the chat message object.
request_id = response["request_id"]
display_content(conn=_conn, content=content, request_id=request_id)
st.session_state.messages.append(
{"role": "analyst", "content": content, "request_id": request_id}
)


def show_expr_for_ref(message_index: int) -> None:
Expand Down Expand Up @@ -244,11 +256,11 @@ def add_verified_query(question: str, sql: str) -> None:
def display_content(
conn: SnowflakeConnection,
content: List[Dict[str, Any]],
request_id: Optional[str],
message_index: Optional[int] = None,
) -> None:
"""Displays a content item for a message. For generated SQL, allow user to add to verified queries directly or edit then add."""
message_index = message_index or len(st.session_state.messages)
sql = ""
question = ""
for item in content:
if item["type"] == "text":
Expand Down Expand Up @@ -302,6 +314,11 @@ def display_content(
):
edit_verified_query(conn, sql, question, message_index)

# If debug mode is enabled, we render the request ID. Note that request IDs are currently only plumbed
# through for assistant messages, as we obtain the request ID as part of the Analyst response.
if request_id and st.session_state.chat_debug:
st.caption(f"Request ID: {request_id}")


def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None:
messages = st.container(height=600, border=False)
Expand All @@ -321,7 +338,7 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None:
if "messages" not in st.session_state or len(st.session_state.messages) == 0:
st.session_state.messages = [
{
"role": "assistant",
"role": "analyst",
"content": [
{
"type": "text",
Expand All @@ -333,9 +350,17 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None:

for message_index, message in enumerate(st.session_state.messages):
with messages:
with st.chat_message(message["role"]):
# To get the handy robot icon on assistant messages, the role needs to be "assistant" or "ai".
# However, the Analyst API uses "analyst" as the role, so we need to convert it at render time.
render_role = "assistant" if message["role"] == "analyst" else "user"
with st.chat_message(render_role):
display_content(
conn=_conn, content=message["content"], message_index=message_index
conn=_conn,
content=message["content"],
message_index=message_index,
request_id=message.get(
"request_id"
), # Safe get since user messages have no request IDs
)

chat_placeholder = (
Expand Down Expand Up @@ -616,6 +641,30 @@ def set_up_requirements() -> None:
st.rerun()


@st.dialog("Chat Settings", width="small")
def chat_settings_dialog() -> None:
"""
Dialog that allows user to toggle on/off certain settings about the chat experience.
"""

debug = st.toggle(
"Debug mode",
value=st.session_state.chat_debug,
help="Enable debug mode to see additional information (e.g. request ID).",
)

multiturn = st.toggle(
"Multiturn",
value=st.session_state.multiturn,
help="Enable multiturn mode to allow the chat to remember context. Note that your account must have the correct parameters enabled to use this feature.",
)

if st.button("Save"):
st.session_state.chat_debug = debug
st.session_state.multiturn = multiturn
st.rerun()


VALIDATE_HELP = """Save and validate changes to the active semantic model in this app. This is
useful so you can then play with it in the chat panel on the right side."""

Expand Down Expand Up @@ -671,6 +720,9 @@ def show() -> None:
st.session_state.working_yml, language="yaml", line_numbers=True
)
else:
st.markdown("**Chat**")
header_row = row([0.85, 0.15], vertical_align="center")
header_row.markdown("**Chat**")
if header_row.button("Settings"):
chat_settings_dialog()
# We still initialize an empty connector and pass it down in order to propagate the connector auth token.
chat_and_edit_vqr(get_snowflake_connection())

0 comments on commit f1e9b80

Please sign in to comment.