diff --git a/requirements.txt b/requirements.txt index 3b0668e..8929d73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ +baidu-aip +chardet +ffmpeg-python langchain langchain_community langchain-openai @@ -5,6 +8,8 @@ pytz qianfan requests streamlit +streamlit-chat-widget +streamlit_extras tiktoken toml xata \ No newline at end of file diff --git a/src/Chat.py b/src/Chat.py index 0810a30..5535c07 100644 --- a/src/Chat.py +++ b/src/Chat.py @@ -7,6 +7,8 @@ import streamlit as st from langchain.schema import AIMessage, HumanMessage from streamlit.web.server.websocket_headers import _get_websocket_headers +from streamlit_extras.bottom_container import bottom +from streamlit_chat_widget import chat_input_widget import ui_config import utils @@ -86,7 +88,10 @@ st.image(ui.sidebar_image) with col_text: st.title(ui.sidebar_title) - st.subheader("环境生态领域智能助手",help="Environment and Ecology Intelligent Assistant") + st.subheader( + "环境生态领域智能助手", + help="Environment and Ecology Intelligent Assistant", + ) if "subsription" in st.session_state: st.markdown( @@ -210,8 +215,13 @@ # search_arxiv_top_k = top_k_values.get("search_arxiv_top_k", 0) # search_docs_top_k = top_k_values.get("search_docs_top_k", 0) - st.markdown("🔥 限时限量免费开放", help="Limited time and quantity free access") - st.markdown("🏹 如需更佳体验,请前往 [Kaiwu](https://www.kaiwu.info)", help="ChatGPT 4o and chat history archives") + st.markdown( + "🔥 限时限量免费开放", help="Limited time and quantity free access" + ) + st.markdown( + "🏹 如需更佳体验,请前往 [Kaiwu](https://www.kaiwu.info)", + help="ChatGPT 4o and chat history archives", + ) col_newchat, col_delete = st.columns([1, 1]) with col_newchat: @@ -325,191 +335,212 @@ def main(): st.session_state["chat_disabled"] = False if "xata_history_refresh" not in st.session_state: - user_query = st.chat_input( - placeholder=ui.chat_human_placeholder, - disabled=st.session_state["chat_disabled"], - ) - if user_query: - beginDatetime = get_begin_datetime() - if ( - "count_chat_history" not in st.session_state - or "begin_hour" not in st.session_state - ): - st.session_state["begin_hour"] = beginDatetime.hour - st.session_state["count_chat_history"] = count_chat_history( - st.session_state["username"], beginDatetime - ) - else: + # user_query = st.chat_input( + # placeholder=ui.chat_human_placeholder, + # disabled=st.session_state["chat_disabled"], + # ) + with bottom(): + user_input = chat_input_widget() + if user_input: + if "text" in user_input: + user_query = user_input["text"] + elif "audioFile" in user_input: + audio_bytes = bytes(user_input["audioFile"]) + # st.audio(audio_bytes, format="audio/wav") + voice_result = utils.voice_to_text(audio_bytes)["result"] + user_query = "".join(voice_result) + + if user_query: + beginDatetime = get_begin_datetime() if ( - st.session_state["begin_hour"] != beginDatetime.hour - or st.session_state["count_chat_history"] % 10 == 0 + "count_chat_history" not in st.session_state + or "begin_hour" not in st.session_state ): st.session_state["begin_hour"] = beginDatetime.hour st.session_state["count_chat_history"] = count_chat_history( st.session_state["username"], beginDatetime ) - - if ( - not ( - "subsription" in st.session_state - and st.session_state["subsription"] == "Elite" - ) - ) and st.session_state["count_chat_history"] > 39: - time_range_str = ( - str(beginDatetime.hour) - + ":00 - " - + str(beginDatetime.hour + 3) - + ":00" - ) - st.chat_message("ai", avatar=ui.chat_ai_avatar).markdown( - "You have reached the usage limit for this time range (UTC " - + time_range_str - + "). Please try again later. (您已达到 UTC " - + time_range_str - + " 时间范围的使用限制,请稍后再试。)" - ) - - else: - st.chat_message("human", avatar=ui.chat_user_avatar).markdown( - user_query - ) - st.session_state["messages"].append( - {"role": "human", "content": user_query} - ) - human_message = HumanMessage( - content=user_query, - additional_kwargs={"id": st.session_state["username"]}, - ) - st.session_state["xata_history"].add_message(human_message) - - # check text sensitivity - answer = check_text_sensitivity(user_query)["answer"] - if answer is not None: - with st.chat_message("ai", avatar=ui.chat_ai_avatar): - st.markdown(answer) - st.session_state["messages"].append( - { - "role": "ai", - "content": answer, - } - ) - ai_message = AIMessage( - content=answer, - additional_kwargs={ - "id": st.session_state["username"] - }, - ) - st.session_state["xata_history"].add_message(ai_message) - st.session_state["count_chat_history"] += 1 else: - current_message = st.session_state["messages"][-8:][1:][:-1] - for item in current_message: - item.pop("avatar", None) - - chat_history_recent = str(current_message) - if ( - search_sci - or search_online - or search_report - or search_patent - or search_standard + st.session_state["begin_hour"] != beginDatetime.hour + or st.session_state["count_chat_history"] % 10 == 0 ): - formatted_messages = str( - [ - (msg["role"], msg["content"]) - for msg in st.session_state["messages"][1:] - ] - ) - - func_calling_response = func_calling_chain( - api_key, llm_model, openai_api_base - ).invoke({"input": formatted_messages}) - - query = func_calling_response.get("query") - - # try: - # created_at = json.loads( - # func_calling_response.get("created_at", None) - # ) - # except TypeError: - # created_at = None - - # source = func_calling_response.get("source", None) - - # filters = {} - # if created_at: - # filters["created_at"] = created_at - # if source: - # filters["source"] = source - - # docs_response = [] - # docs_response.extend( - # search_sci_service( - # query=query, - # filters=filters, - # top_k=3, - # ) - # ) - # docs_response.extend( - # search_internet(query, top_k=3) - # ) - docs_response = asyncio.run( - concurrent_search_service( - urls=search_list, query=query + st.session_state["begin_hour"] = beginDatetime.hour + st.session_state["count_chat_history"] = ( + count_chat_history( + st.session_state["username"], beginDatetime ) ) - input = f"""必须遵循: -- 使用“{docs_response}”(如果有)和您自己的知识回应“{user_query}”,以用户相同的语言提供逻辑清晰、经过批判性分析的回复。 -- 如果有“{chat_history_recent}”,请利用聊天上下文调整回复的详细程度。 -- 如果没有提供参考或没有上下文的情况,不要要求用户提供,直接回应用户的问题。 -- 有选择地使用项目符号,以提高清晰度或组织性。 -- 在适用情况下,使用 作者-日期 的引用风格在正文中引用来源。 -- 在末尾以Markdown格式提供一个参考文献列表,格式为[标题.期刊.作者.日期.](链接)(或仅文件名),仅包括文本中提到的参考文献。 -- 在Markdown中使用 '$' 或 '$$' 引用LaTeX以渲染数学公式。 - -必须避免: -- 重复用户的查询。 -- 将引用的参考文献翻译成用户查询的语言。 -- 在回复前加上任何标识,如“AI:”。 -""" + if ( + not ( + "subsription" in st.session_state + and st.session_state["subsription"] == "Elite" + ) + ) and st.session_state["count_chat_history"] > 39: + time_range_str = ( + str(beginDatetime.hour) + + ":00 - " + + str(beginDatetime.hour + 3) + + ":00" + ) + st.chat_message("ai", avatar=ui.chat_ai_avatar).markdown( + "You have reached the usage limit for this time range (UTC " + + time_range_str + + "). Please try again later. (您已达到 UTC " + + time_range_str + + " 时间范围的使用限制,请稍后再试。)" + ) + else: + st.chat_message( + "human", avatar=ui.chat_user_avatar + ).markdown(user_query) + st.session_state["messages"].append( + {"role": "human", "content": user_query} + ) + human_message = HumanMessage( + content=user_query, + additional_kwargs={"id": st.session_state["username"]}, + ) + st.session_state["xata_history"].add_message(human_message) + + # check text sensitivity + answer = check_text_sensitivity(user_query)["answer"] + if answer is not None: + with st.chat_message("ai", avatar=ui.chat_ai_avatar): + st.markdown(answer) + st.session_state["messages"].append( + { + "role": "ai", + "content": answer, + } + ) + ai_message = AIMessage( + content=answer, + additional_kwargs={ + "id": st.session_state["username"] + }, + ) + st.session_state["xata_history"].add_message( + ai_message + ) + st.session_state["count_chat_history"] += 1 else: - input = f"""回应“{user_query}”。如果“{chat_history_recent}”不为空,请使用其作为聊天上下文。""" - - with st.chat_message("ai", avatar=ui.chat_ai_avatar): - st_callback = StreamHandler(st.empty()) - response = main_chain( - api_key, llm_model, openai_api_base, baidu_llm - ).invoke( - {"input": input}, - {"callbacks": [st_callback]}, - ) + current_message = st.session_state["messages"][-8:][1:][ + :-1 + ] + for item in current_message: + item.pop("avatar", None) + + chat_history_recent = str(current_message) + + if ( + search_sci + or search_online + or search_report + or search_patent + or search_standard + ): + formatted_messages = str( + [ + (msg["role"], msg["content"]) + for msg in st.session_state["messages"][1:] + ] + ) - st.session_state["messages"].append( - { - "role": "ai", - "content": response, - } - ) - ai_message = AIMessage( - content=response, - additional_kwargs={ - "id": st.session_state["username"] - }, - ) - st.session_state["xata_history"].add_message(ai_message) - st.session_state["count_chat_history"] += 1 + func_calling_response = func_calling_chain( + api_key, llm_model, openai_api_base + ).invoke({"input": formatted_messages}) + + query = func_calling_response.get("query") + + # try: + # created_at = json.loads( + # func_calling_response.get("created_at", None) + # ) + # except TypeError: + # created_at = None + + # source = func_calling_response.get("source", None) + + # filters = {} + # if created_at: + # filters["created_at"] = created_at + # if source: + # filters["source"] = source + + # docs_response = [] + # docs_response.extend( + # search_sci_service( + # query=query, + # filters=filters, + # top_k=3, + # ) + # ) + # docs_response.extend( + # search_internet(query, top_k=3) + # ) + docs_response = asyncio.run( + concurrent_search_service( + urls=search_list, query=query + ) + ) + + input = f"""必须遵循: + - 使用“{docs_response}”(如果有)和您自己的知识回应“{user_query}”,以用户相同的语言提供逻辑清晰、经过批判性分析的回复。 + - 如果有“{chat_history_recent}”,请利用聊天上下文调整回复的详细程度。 + - 如果没有提供参考或没有上下文的情况,不要要求用户提供,直接回应用户的问题。 + - 有选择地使用项目符号,以提高清晰度或组织性。 + - 在适用情况下,使用 作者-日期 的引用风格在正文中引用来源。 + - 在末尾以Markdown格式提供一个参考文献列表,格式为[标题.期刊.作者.日期.](链接)(或仅文件名),仅包括文本中提到的参考文献。 + - 在Markdown中使用 '$' 或 '$$' 引用LaTeX以渲染数学公式。 + + 必须避免: + - 重复用户的查询。 + - 将引用的参考文献翻译成用户查询的语言。 + - 在回复前加上任何标识,如“AI:”。 + """ + + else: + input = f"""回应“{user_query}”。如果“{chat_history_recent}”不为空,请使用其作为聊天上下文。""" + + with st.chat_message("ai", avatar=ui.chat_ai_avatar): + st_callback = StreamHandler(st.empty()) + response = main_chain( + api_key, llm_model, openai_api_base, baidu_llm + ).invoke( + {"input": input}, + {"callbacks": [st_callback]}, + ) + + st.session_state["messages"].append( + { + "role": "ai", + "content": response, + } + ) + ai_message = AIMessage( + content=response, + additional_kwargs={ + "id": st.session_state["username"] + }, + ) + st.session_state["xata_history"].add_message( + ai_message + ) + st.session_state["count_chat_history"] += 1 - if len(st.session_state["messages"]) == 3: - st.session_state["xata_history_refresh"] = True - st.rerun() + if len(st.session_state["messages"]) == 3: + st.session_state["xata_history_refresh"] = True + st.rerun() else: - user_query = st.chat_input( - placeholder=ui.chat_human_placeholder, - disabled=st.session_state["chat_disabled"], - ) + # user_query = st.chat_input( + # placeholder=ui.chat_human_placeholder, + # disabled=st.session_state["chat_disabled"], + # ) + with bottom(): + user_input = chat_input_widget() del st.session_state["xata_history_refresh"] except Exception as e: diff --git a/src/utils.py b/src/utils.py index 3516515..4e11eb3 100644 --- a/src/utils.py +++ b/src/utils.py @@ -39,6 +39,9 @@ from langchain_community.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint from xata.client import XataClient +from aip import AipSpeech +import ffmpeg + import ui_config ui = ui_config.create_ui_from_config() @@ -118,6 +121,44 @@ def password_entered(): return True +def convert_audio_in_memory(input_bytes): + try: + input_stream = ffmpeg.input('pipe:0', format='webm') # 或 format='matroska' + output_stream = ffmpeg.output( + input_stream, + 'pipe:1', + format='wav', + acodec='pcm_s16le', + ar='16000', + ac='1' + ) + + process = ffmpeg.run_async( + output_stream, + pipe_stdin=True, + pipe_stdout=True, + pipe_stderr=True + ) + + stdout, stderr = process.communicate(input=input_bytes) + + if process.returncode != 0: + raise ffmpeg.Error('FFmpeg转换失败', stdout, stderr) + + return stdout + + except ffmpeg.Error as e: + print('转换失败!错误信息:') + print(e.stderr.decode()) + return None + +def voice_to_text(audio_bytes): + audio_bytes = convert_audio_in_memory(audio_bytes) + client = AipSpeech(st.secrets["voice_app_id"], st.secrets["voice_app_key"], st.secrets["voice_app_secret"]) + text = client.asr(audio_bytes, "wav", 16000, {"dev_pid": 1537}) + return text + + def func_calling_chain(api_key, llm_model, openai_api_base): """ Creates and returns a function calling chain for extracting query and filter information from a chat history.