Spatial and temporal extent can now be passed to the chatbot API endpoint to adjust search results
simeonwetzel committed Aug 20, 2024
1 parent df7c010 commit c1a3b0a
Showing 4 changed files with 101 additions and 31 deletions.
17 changes: 12 additions & 5 deletions search-app/server/app/server.py
@@ -24,6 +24,8 @@
from pydantic import BaseModel
import json
import logging
from typing import Optional, Dict, Any



logging.basicConfig()
@@ -62,7 +64,7 @@ async def get_current_session(session_id: UUID = Depends(cookie), session_data:
"pygeoapi": Indexer(index_name="pygeoapi",
score_treshold= 0.4,
k = 20),
"geojson_osm_indexer": Indexer(index_name="geojson", # Add indexer for local geojson with OSM features
"geojson": Indexer(index_name="geojson", # Add indexer for local geojson with OSM features
score_treshold=-400.0,
k = 20,
use_hf_model=True,
@@ -128,7 +130,7 @@ async def create_session(response: Response):

graph = SpatialRetrieverGraph(state=State(messages=[],
search_criteria="",
spatial_context="",
spatio_temporal_context={},
search_results=[],
ready_to_retrieve=""),
thread_id=session_id,
@@ -147,12 +149,17 @@ async def create_session(response: Response):

class Query(BaseModel):
query: str
spatio_temporal_context: Optional[Dict[str, Any]] = None

@app.post("/data")
async def call_graph(query_data: Query, session_id: UUID = Depends(cookie)):
if graph is not None:
print(f"-#-#--Running graph---- Using session_id: {str(session_id)}")
inputs = {"messages": [HumanMessage(content=query_data.query)]}

if query_data.spatio_temporal_context:
inputs['spatio_temporal_context'] = query_data.spatio_temporal_context

graph.graph.thread_id = str(session_id)
response = await graph.ainvoke(inputs)
else:
@@ -183,7 +190,7 @@ async def index_geojson_osm(api_key: APIKey = Depends(get_api_key)):
# await local_file_connector.add_descriptions_to_features()
feature_docs = await geojson_osm_connector._features_to_docs()
logging.info(f"Converted {len(feature_docs)} Features or FeatureGroups to documents")
res_local = indexes['geojson_osm_indexer']._index(documents=feature_docs)
res_local = indexes['geojson']._index(documents=feature_docs)
return res_local

def generate_combined_feature_collection(doc_list: List[Document]):
@@ -205,7 +212,7 @@ def generate_combined_feature_collection(doc_list: List[Document]):

@app.get("/retrieve_geojson")
async def retrieve_geojson(query: str):
features = indexes['geojson_osm_indexer'].retriever.invoke(query)
features = indexes['geojson'].retriever.invoke(query)

feature_collection = generate_combined_feature_collection(features)

@@ -228,7 +235,7 @@ async def clear_index(index_name: str, api_key: APIKey = Depends(get_api_key)):

if index_name == 'geojson':
logging.info("Clearing geojson index")
indexes['geojson_osm_indexer']._clear()
indexes['geojson']._clear()
else:
logging.info(f"Clearing index: {index_name}")
indexes[index_name]._clear()
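For reference, a minimal client-side sketch of calling the updated /data endpoint with the new spatio_temporal_context field. The host/port, cookie name, bbox values and temporal string are illustrative assumptions, not part of this commit; the bbox order mirrors is_within_bbox() in spatial_utilities.py ([min_lon, max_lat, max_lon, min_lat]).

import requests

# Hypothetical request against a locally running instance of the search app
payload = {
    "query": "climate data for Berlin",
    "spatio_temporal_context": {
        "extent": [13.088, 52.675, 13.761, 52.338],  # assumed [min_lon, max_lat, max_lon, min_lat] for Berlin
        "temporal": "2015-01-01/2020-12-31",         # accepted by the endpoint, but not yet used as a filter
    },
}

response = requests.post(
    "http://localhost:8000/data",                    # assumed base URL
    json=payload,
    cookies={"session_id": "<session-cookie>"},      # placeholder; obtain the cookie from the session endpoint
)
print(response.json())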
79 changes: 59 additions & 20 deletions search-app/server/graph/graph.py
@@ -13,7 +13,9 @@
generate_search_tool,
spatial_context_extraction_tool
)
from .spatial_utilities import check_within_bbox
import logging
import ast

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
Expand All @@ -28,7 +30,7 @@ def is_valid_json(myjson):
class State(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
search_criteria: str
spatial_context: str
spatio_temporal_context: dict
search_results: List
ready_to_retrieve: str
index_name: str
@@ -58,7 +60,7 @@ def __init__(self,

def setup_graph(self):
self.add_node("conversation", self.run_conversation)
self.add_node("extract_spatial_context", self.extract_spatial_context)
self.add_node("extract_spatio_temporal_context", self.extract_spatio_temporal_context)
self.add_node("search", self.run_search)
self.add_node("final_answer", self.final_answer)
self.add_node("save_state", self.save_state)
@@ -68,11 +70,11 @@ def setup_graph(self):
self.should_continue,
{
"human": "save_state",
"extract_spatial_context": "extract_spatial_context"
"extract_spatio_temporal_context": "extract_spatio_temporal_context"
}
)

self.add_edge("extract_spatial_context", "search")
self.add_edge("extract_spatio_temporal_context", "search")
self.add_edge("search", "final_answer")
self.add_edge("final_answer", "save_state")
self.add_edge("save_state", END)
@@ -119,12 +121,24 @@ async def run_conversation(self, state: State):
state["ready_to_retrieve"] = parsed_dict.get("ready_to_retrieve", "no")
return state

def extract_spatial_context(self, state: State):
def extract_spatio_temporal_context(self, state: State):
print("---extracting spatial context of search")
spatial_context = spatial_context_extraction_tool.invoke({"query": str(state['search_criteria'])})
state['spatial_context'] = spatial_context
spatio_temporal_context = state.get('spatio_temporal_context', None)

logging.info(f"Extracted following spatial context: {spatial_context}")
if spatio_temporal_context:
    # Spatio-temporal context was passed in via the API request
    spatial_extent = spatio_temporal_context.get('extent', [])
    temporal_extent = spatio_temporal_context.get('temporal', "")
else:
    # Otherwise derive the spatial extent from the search criteria via the extraction tool
    spatial_context_str = spatial_context_extraction_tool.invoke({"query": str(state['search_criteria'])})
    spatial_extent = ast.literal_eval(spatial_context_str).get("extent", [])

    # Todo: also try to derive the temporal extent from the inputs
    temporal_extent = ""

    # Store the derived context in the same dict shape as the API payload
    state['spatio_temporal_context'] = {"extent": spatial_extent, "temporal": temporal_extent}

logging.info(f"Extracted following spatial context: {spatial_extent} and following temporal extent: {temporal_extent}")
return state

def run_search(self, state: State):
@@ -139,9 +153,9 @@ def run_search(self, state: State):
logging.info(f"Starting search in index: {index_name} using this tool: {search_tool.name}")

search_results = search_tool.invoke({"query": str(state['search_criteria']),
"search_index": search_index,
"search_type": "similarity",
"k": 3})
"search_index": search_index,
"search_type": "similarity",
"k": 10})
else:
tavily_search = TavilySearchResults()
search_results = tavily_search.invoke(state["search_criteria"])
@@ -154,20 +168,45 @@ def run_search(self, state: State):
def should_continue(self, state: State) -> str:
if state.get("ready_to_retrieve") == "yes":
print("---routing to spatial context extractor, then to search")
return "extract_spatial_context"
return "extract_spatio_temporal_context"
else:
return "human"

def final_answer(self, state: State) -> str:
if state["index_name"] == "geojson":
logging.info(f"I found: {state['search_results'][-1]}")
context = state["search_results"][-1]
else:
context = state["search_results"]
query = state["search_criteria"]
answer = final_answer_chain.invoke({"query": query,
"context": context}).strip()
for c in self.collection_info_dict:
if c['collection_name'] == state['index_name']:
search_index_info = c.pop('sample_docs', None)

#Todo: use the temporal constraint (e.g. as filter)
try:
if state["index_name"] == "geojson":
# Check if results match spatial context of query
query_bbox = state.get('spatio_temporal_context', {}).get('extent', [])
search_results = check_within_bbox(search_results=state["search_results"],
bbox=query_bbox)

logging.info(f"I found: {len(search_results)} using the query-bbox {query_bbox}")

doc_contents = "\n\n".join(doc.page_content for doc in search_results)

context = f"""I searched in this search index: {search_index_info}.
The top-{len(search_results)} search results in the specified spatial extent are: {doc_contents}"""

else:
search_results = state["search_results"][:10]
doc_contents = "\n\n".join(doc.page_content for doc in search_results)
context = f"""I searched in this search index: {search_index_info}.
The top-{len(search_results)} search results are: {doc_contents}"""

if len(search_results) < 1:
context = "No search results found using the current search criteria"
state['search_results'] = []
query = state["search_criteria"]
answer = final_answer_chain.invoke({"query": query,
"context": context}).strip()
except Exception:
    answer = "Sorry, I was not able to process your input. Can you please try again?"

state["messages"].append(AIMessage(content=answer))
return state

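As a complement, a short sketch of invoking the graph directly with a pre-set spatio-temporal context, mirroring what /data in server.py now forwards into the State; the query and bbox values are assumptions for illustration only.

import asyncio
from langchain_core.messages import HumanMessage

async def demo(graph):
    # "graph" is assumed to be the SpatialRetrieverGraph instance built in server.py
    inputs = {
        "messages": [HumanMessage(content="Which flood-related datasets are available for Saxony?")],
        "spatio_temporal_context": {
            "extent": [11.87, 51.68, 15.04, 50.17],  # assumed bbox, [min_lon, max_lat, max_lon, min_lat]
            "temporal": "2021",
        },
    }
    final_state = await graph.ainvoke(inputs)
    print(final_state["messages"][-1].content)

# asyncio.run(demo(graph))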
7 changes: 3 additions & 4 deletions search-app/server/graph/prompts.py
@@ -94,12 +94,11 @@ def generate_conversation_prompt(system_prompt=None):
def generate_final_answer_prompt():
final_answer_prompt = PromptTemplate(
template="""
You are an assistant for question-answering tasks related to data search.
The question wil be a query and the context either the found datasets or a summary of the recieved data.
Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
You describe the results of a data search given a certain query.
The search results are either the found datasets or a summary of the received data.
Use three sentences maximum and keep the answer concise.
Question: {query}
Context: {context}
Found data: {context}
Answer:""",
input_variables=["query", "context"],
)
29 changes: 27 additions & 2 deletions search-app/server/graph/spatial_utilities.py
@@ -86,5 +86,30 @@ def generate_spatial_context_chain(llm):
return spatial_context_chain


# response = spatial_context_chain.invoke({"query": "I climate data for Berlin"})
# print(response)
### Functions to check whether GeoJSON search results fall within a given spatial extent.
# Used to check if search results match the query bbox.
import json
def is_within_bbox(lon, lat, bbox):
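# bbox is expected as [min_lon, max_lat, max_lon, min_lat], i.e. (west, north, east, south)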
min_lon, max_lat, max_lon, min_lat = bbox
return min_lon <= lon <= max_lon and min_lat <= lat <= max_lat

def check_within_bbox(search_results, bbox):
if not bbox:
return search_results

results_within_bbox = []

for result in search_results:
feature_str = result.metadata.get('feature', '{}')
feature = json.loads(feature_str)
coordinates = feature.get('coordinates', [])

# Check if any coordinate of the feature falls within the bbox
matched = False
for poly in coordinates:
    for coord in poly:  # Assuming a polygon with one ring
        lon, lat = coord
        if is_within_bbox(lon, lat, bbox):
            results_within_bbox.append(result)
            matched = True
            break  # Stop checking further coordinates of this polygon
    if matched:
        break  # Avoid appending the same result more than once

return results_within_bbox
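A hedged usage sketch of check_within_bbox, assuming the geojson indexer stores each feature's GeoJSON geometry as a JSON string under the 'feature' metadata key (as the loop above expects); the document content and coordinates are made up for illustration.

import json
from langchain_core.documents import Document

sample_doc = Document(
    page_content="Brandenburg Gate",
    metadata={"feature": json.dumps({
        "type": "Polygon",
        "coordinates": [[[13.3776, 52.5163], [13.3779, 52.5165], [13.3776, 52.5163]]],
    })},
)

berlin_bbox = [13.088, 52.675, 13.761, 52.338]  # [min_lon, max_lat, max_lon, min_lat]
print(check_within_bbox([sample_doc], berlin_bbox))  # -> [sample_doc]; its coordinates fall inside the bbox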
