-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1013 from parea-ai/PAI-1405-auto-trace-cohere
feat(cohere): auto trace cohere
- Loading branch information
Showing
13 changed files
with
2,669 additions
and
1,512 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
|
||
import cohere | ||
from dotenv import load_dotenv | ||
|
||
from parea import Parea | ||
|
||
load_dotenv() | ||
|
||
p = Parea(api_key=os.getenv("PAREA_API_KEY")) | ||
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) | ||
p.wrap_cohere_client(co) | ||
|
||
response = co.chat( | ||
model="command-r-plus", | ||
preamble="You are a helpful assistant talking in JSON.", | ||
message="Generate a JSON describing a person, with the fields 'name' and 'age'", | ||
response_format={"type": "json_object"}, | ||
) | ||
print(response) | ||
print("\n\n") | ||
|
||
response = co.chat(message="Who discovered gravity?") | ||
print(response) | ||
print("\n\n") | ||
# | ||
docs = [ | ||
"Carson City is the capital city of the American state of Nevada.", | ||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", | ||
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", | ||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", | ||
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", | ||
] | ||
response = co.rerank( | ||
model="rerank-english-v3.0", | ||
query="What is the capital of the United States?", | ||
documents=docs, | ||
top_n=3, | ||
) | ||
print(response) | ||
print("\n\n") | ||
|
||
|
||
response = co.chat( | ||
model="command-r-plus", | ||
message="Where do the tallest penguins live?", | ||
documents=[ | ||
{"title": "Tall penguins", "snippet": "Emperor penguins are the tallest."}, | ||
{"title": "Penguin habitats", "snippet": "Emperor penguins only live in Antarctica."}, | ||
{"title": "What are animals?", "snippet": "Animals are different from plants."}, | ||
], | ||
) | ||
print(response) | ||
print("\n\n") | ||
|
||
response = co.chat(model="command-r-plus", message="Who is more popular: Nsync or Backstreet Boys?", search_queries_only=True) | ||
print(response) | ||
print("\n\n") | ||
|
||
response = co.chat(model="command-r-plus", message="Who is more popular: Nsync or Backstreet Boys?", connectors=[{"id": "web-search"}]) | ||
print(response) | ||
print("\n\n") | ||
|
||
for event in co.chat_stream(message="Who discovered gravity?"): | ||
print(event) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import os | ||
|
||
import cohere | ||
from dotenv import load_dotenv | ||
|
||
from parea import Parea | ||
from parea.utils.universal_encoder import json_dumps | ||
|
||
load_dotenv() | ||
|
||
p = Parea(api_key=os.getenv("PAREA_API_KEY")) | ||
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) | ||
p.wrap_cohere_client(co) | ||
|
||
|
||
def web_search(query: str) -> list[dict]: | ||
# your code for performing a web search goes here | ||
return [{"url": "https://en.wikipedia.org/wiki/Ontario", "text": "The capital of Ontario is Toronto, ..."}] | ||
|
||
|
||
web_search_tool = { | ||
"name": "web_search", | ||
"description": "performs a web search with the specified query", | ||
"parameter_definitions": {"query": {"description": "the query to look up", "type": "str", "required": True}}, | ||
} | ||
|
||
message = "Who is the mayor of the capital of Ontario?" | ||
model = "command-r-plus" | ||
|
||
# STEP 2: Check what tools the model wants to use and how | ||
|
||
res = co.chat(model=model, message=message, force_single_step=False, tools=[web_search_tool]) | ||
|
||
# as long as the model sends back tool_calls, | ||
# keep invoking tools and sending the results back to the model | ||
while res.tool_calls: | ||
print(res.text) # This will be an observation and a plan with next steps | ||
tool_results = [] | ||
for call in res.tool_calls: | ||
# use the `web_search` tool with the search query the model sent back | ||
web_search_results = {"call": call, "outputs": web_search(call.parameters["query"])} | ||
tool_results.append(web_search_results) | ||
|
||
# call chat again with tool results | ||
res = co.chat(model="command-r-plus", chat_history=res.chat_history, message="", force_single_step=False, tools=[web_search_tool], tool_results=tool_results) | ||
|
||
print(res.text) # "The mayor of Toronto, the capital of Ontario is Olivia Chow" | ||
|
||
|
||
# tool descriptions that the model has access to | ||
tools = [ | ||
{ | ||
"name": "query_daily_sales_report", | ||
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.", | ||
"parameter_definitions": {"day": {"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", "type": "str", "required": True}}, | ||
}, | ||
{ | ||
"name": "query_product_catalog", | ||
"description": "Connects to a a product catalog with information about all the products being sold, including categories, prices, and stock levels.", | ||
"parameter_definitions": {"category": {"description": "Retrieves product information data for all products in this category.", "type": "str", "required": True}}, | ||
}, | ||
] | ||
|
||
# preamble containing instructions about the task and the desired style for the output. | ||
preamble = """ | ||
## Task & Context | ||
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging. | ||
## Style Guide | ||
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling. | ||
""" | ||
|
||
# user request | ||
message = "Can you provide a sales summary for 29th September 2023, and also give me some details about the products in the 'Electronics' category, for example their prices and stock levels?" | ||
|
||
response = co.chat(message=message, force_single_step=True, tools=tools, preamble=preamble, model="command-r") | ||
print("The model recommends doing the following tool calls:") | ||
print("\n".join(str(tool_call) for tool_call in response.tool_calls)) | ||
|
||
tool_results = [] | ||
# Iterate over the tool calls generated by the model | ||
for tool_call in response.tool_calls: | ||
# here is where you would call the tool recommended by the model, using the parameters recommended by the model | ||
output = {"output": f"functions_map[{tool_call.name}]({tool_call.parameters})"} | ||
# store the output in a list | ||
outputs = [output] | ||
# store your tool results in this format | ||
tool_results.append({"call": tool_call, "outputs": outputs}) | ||
|
||
|
||
print("Tool results that will be fed back to the model in step 4:") | ||
print(json_dumps(tool_results, indent=4)) | ||
|
||
response = co.chat(message=message, tools=tools, tool_results=tool_results, preamble=preamble, model="command-r", temperature=0.3, force_single_step=True) | ||
|
||
|
||
print("Final answer:") | ||
print(response.text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from typing import List, Optional | ||
|
||
import os | ||
from datetime import datetime | ||
|
||
import cohere | ||
from dotenv import load_dotenv | ||
|
||
from parea import Parea, trace, trace_insert | ||
|
||
load_dotenv() | ||
|
||
p = Parea(api_key=os.getenv("PAREA_API_KEY")) | ||
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) | ||
p.wrap_cohere_client(co) | ||
|
||
|
||
def call_llm(message: str, chat_history: Optional[List[dict]] = None, system_message: str = "", model: str = "command-r-plus") -> str: | ||
return co.chat( | ||
model=model, | ||
preamble=system_message, | ||
chat_history=chat_history or [], | ||
message=message, | ||
).text | ||
|
||
|
||
@trace | ||
def argumentor(query: str, additional_description: str = "") -> str: | ||
return call_llm( | ||
system_message=f"""You are a debater making an argument on a topic. {additional_description}. | ||
The current time is {datetime.now().strftime("%Y-%m-%d")}""", | ||
message=f"The discussion topic is {query}", | ||
) | ||
|
||
|
||
@trace | ||
def critic(argument: str) -> str: | ||
return call_llm( | ||
system_message="""You are a critic. | ||
What unresolved questions or criticism do you have after reading the following argument? | ||
Provide a concise summary of your feedback.""", | ||
message=argument, | ||
) | ||
|
||
|
||
@trace | ||
def refiner(query: str, additional_description: str, argument: str, criticism: str) -> str: | ||
return call_llm( | ||
system_message=f"""You are a debater making an argument on a topic. {additional_description}. | ||
The current time is {datetime.now().strftime("%Y-%m-%d")}""", | ||
chat_history=[{"role": "USER", "message": f"""The discussion topic is {query}"""}, {"role": "CHATBOT", "message": argument}, {"role": "USER", "message": criticism}], | ||
message="Please generate a new argument that incorporates the feedback from the user.", | ||
) | ||
|
||
|
||
@trace | ||
def argument_chain(query: str, additional_description: str = "") -> str: | ||
trace_insert({"session_id": "cus_1234", "end_user_identifier": "user_1234"}) | ||
argument = argumentor(query, additional_description) | ||
criticism = critic(argument) | ||
refined_argument = refiner(query, additional_description, argument, criticism) | ||
return refined_argument | ||
|
||
|
||
@trace(session_id="cus_1234", end_user_identifier="user_1234") | ||
def json_call() -> str: | ||
completion = co.chat( | ||
model="command-r-plus", | ||
preamble="You are a helpful assistant talking in JSON.", | ||
message="What are you?", | ||
response_format={"type": "json_object"}, | ||
) | ||
return completion.text | ||
|
||
|
||
if __name__ == "__main__": | ||
result = argument_chain( | ||
"Whether sparkling wine is good for you.", | ||
additional_description="Provide a concise, few sentence argument on why sparkling wine is good for you.", | ||
) | ||
print(result) | ||
print(json_call()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.