diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 68b73d48..e52a42e7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,2 @@ -* @sfc-gh-cnivera @sfc-gh-jsummer @sfc-gh-twhite @sfc-gh-yayin -/artifacts/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer -/images/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer -/journeys/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer -/partner/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer -/app.py @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer -/app_utils/ @sfc-gh-cnivera @sfc-gh-yayin @sfc-gh-jsummer +* @sfc-gh-cnivera @sfc-gh-jsummer @sfc-gh-twhite /semantic_model_generator/ @sfc-gh-nsehrawat @sfc-gh-rehuang @sfc-gh-dasilva @sfc-gh-cnivera @sfc-gh-yayin diff --git a/README.md b/README.md index 4f3f810e..cb37c3cf 100644 --- a/README.md +++ b/README.md @@ -86,11 +86,13 @@ To find your Account locator, please execute the following sql command in your a SELECT CURRENT_ACCOUNT_LOCATOR(); ``` -To find the `SNOWFLAKE_HOST` for your -account, [follow these instructions](https://docs.snowflake.com/en/user-guide/organizations-connect#connecting-with-a-url). +B. To find the `SNOWFLAKE_HOST` for your +account, [follow these instructions](https://docs.snowflake.com/en/user-guide/organizations-connect#connecting-with-a-url). The easiest way to find your account URL is to click the `Copy account URL` button from the Account panel in Snowsight: -* Currently we recommend you to look under the `Account locator (legacy)` method of connection for better compatibility - on API. +![CleanShot 2024-10-09 at 14 25 13](https://github.com/user-attachments/assets/b1715c57-9571-4c65-92fb-e5d43afa871b) + +However, if you have trouble authenticating with this URL, you can try building the URL manually: +* Currently we recommend you to look under the `Account locator (legacy)` method of connection for better compatibility on API. * It typically follows format of: `...snowflakecomputing.com`. Ensure that you omit the `https://` prefix. diff --git a/app_utils/shared_utils.py b/app_utils/shared_utils.py index bb70784f..cb751e56 100644 --- a/app_utils/shared_utils.py +++ b/app_utils/shared_utils.py @@ -373,6 +373,12 @@ def init_session_states() -> None: if "last_validated_model" not in st.session_state: st.session_state.last_validated_model = semantic_model_pb2.SemanticModel() + # Chat display settings. + if "chat_debug" not in st.session_state: + st.session_state.chat_debug = False + if "multiturn" not in st.session_state: + st.session_state.multiturn = False + # initialize session states for the chat page. if "messages" not in st.session_state: # messages store all chat histories diff --git a/journeys/iteration.py b/journeys/iteration.py index e1c10fa5..1a3439b5 100644 --- a/journeys/iteration.py +++ b/journeys/iteration.py @@ -11,21 +11,20 @@ from streamlit_extras.row import row from streamlit_extras.stylable_container import stylable_container -from journeys.joins import joins_dialog - from app_utils.shared_utils import ( GeneratorAppScreen, SnowflakeStage, - stage_selector_container, changed_from_last_validated_model, download_yaml, get_snowflake_connection, + get_yamls_from_stage, init_session_states, return_home_button, + stage_selector_container, upload_yaml, validate_and_upload_tmp_yaml, - get_yamls_from_stage, ) +from journeys.joins import joins_dialog from semantic_model_generator.data_processing.cte_utils import ( context_to_column_format, expand_all_logical_tables_as_ctes, @@ -72,15 +71,19 @@ def pretty_print_sql(sql: str) -> str: @st.cache_data(ttl=60, show_spinner=False) -def send_message(_conn: SnowflakeConnection, prompt: str) -> Dict[str, Any]: - """Calls the REST API and returns the response.""" +def send_message( + _conn: SnowflakeConnection, messages: list[dict[str, str]] +) -> Dict[str, Any]: + """ + Calls the REST API with a list of messages and returns the response. + Args: + _conn: SnowflakeConnection, used to grab the token for auth. + messages: list of chat messages to pass to the Analyst API. + + Returns: The raw ChatMessage response from Analyst. + """ request_body = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": prompt}], - }, - ], + "messages": messages, "semantic_model": proto_to_yaml(st.session_state.semantic_model), } @@ -124,17 +127,26 @@ def send_message(_conn: SnowflakeConnection, prompt: str) -> Dict[str, Any]: def process_message(_conn: SnowflakeConnection, prompt: str) -> None: """Processes a message and adds the response to the chat.""" - st.session_state.messages.append( - {"role": "user", "content": [{"type": "text", "text": prompt}]} - ) + user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]} + st.session_state.messages.append(user_message) with st.chat_message("user"): st.markdown(prompt) with st.chat_message("assistant"): with st.spinner("Generating response..."): - response = send_message(_conn=_conn, prompt=prompt) + # Depending on whether multiturn is enabled, we either send just the user message or the entire chat history. + request_messages = ( + st.session_state.messages[1:] # Skip the welcome message + if st.session_state.multiturn + else [user_message] + ) + response = send_message(_conn=_conn, messages=request_messages) content = response["message"]["content"] - display_content(conn=_conn, content=content) - st.session_state.messages.append({"role": "assistant", "content": content}) + # Grab the request ID from the response and stash it in the chat message object. + request_id = response["request_id"] + display_content(conn=_conn, content=content, request_id=request_id) + st.session_state.messages.append( + {"role": "analyst", "content": content, "request_id": request_id} + ) def show_expr_for_ref(message_index: int) -> None: @@ -244,11 +256,11 @@ def add_verified_query(question: str, sql: str) -> None: def display_content( conn: SnowflakeConnection, content: List[Dict[str, Any]], + request_id: Optional[str], message_index: Optional[int] = None, ) -> None: """Displays a content item for a message. For generated SQL, allow user to add to verified queries directly or edit then add.""" message_index = message_index or len(st.session_state.messages) - sql = "" question = "" for item in content: if item["type"] == "text": @@ -302,6 +314,11 @@ def display_content( ): edit_verified_query(conn, sql, question, message_index) + # If debug mode is enabled, we render the request ID. Note that request IDs are currently only plumbed + # through for assistant messages, as we obtain the request ID as part of the Analyst response. + if request_id and st.session_state.chat_debug: + st.caption(f"Request ID: {request_id}") + def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None: messages = st.container(height=600, border=False) @@ -321,7 +338,7 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None: if "messages" not in st.session_state or len(st.session_state.messages) == 0: st.session_state.messages = [ { - "role": "assistant", + "role": "analyst", "content": [ { "type": "text", @@ -333,9 +350,17 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None: for message_index, message in enumerate(st.session_state.messages): with messages: - with st.chat_message(message["role"]): + # To get the handy robot icon on assistant messages, the role needs to be "assistant" or "ai". + # However, the Analyst API uses "analyst" as the role, so we need to convert it at render time. + render_role = "assistant" if message["role"] == "analyst" else "user" + with st.chat_message(render_role): display_content( - conn=_conn, content=message["content"], message_index=message_index + conn=_conn, + content=message["content"], + message_index=message_index, + request_id=message.get( + "request_id" + ), # Safe get since user messages have no request IDs ) chat_placeholder = ( @@ -616,6 +641,30 @@ def set_up_requirements() -> None: st.rerun() +@st.dialog("Chat Settings", width="small") +def chat_settings_dialog() -> None: + """ + Dialog that allows user to toggle on/off certain settings about the chat experience. + """ + + debug = st.toggle( + "Debug mode", + value=st.session_state.chat_debug, + help="Enable debug mode to see additional information (e.g. request ID).", + ) + + multiturn = st.toggle( + "Multiturn", + value=st.session_state.multiturn, + help="Enable multiturn mode to allow the chat to remember context. Note that your account must have the correct parameters enabled to use this feature.", + ) + + if st.button("Save"): + st.session_state.chat_debug = debug + st.session_state.multiturn = multiturn + st.rerun() + + VALIDATE_HELP = """Save and validate changes to the active semantic model in this app. This is useful so you can then play with it in the chat panel on the right side.""" @@ -671,6 +720,9 @@ def show() -> None: st.session_state.working_yml, language="yaml", line_numbers=True ) else: - st.markdown("**Chat**") + header_row = row([0.85, 0.15], vertical_align="center") + header_row.markdown("**Chat**") + if header_row.button("Settings"): + chat_settings_dialog() # We still initialize an empty connector and pass it down in order to propagate the connector auth token. chat_and_edit_vqr(get_snowflake_connection())