From c1a3b0aba98cfe3634c0ad0a0e030c0d2e6bb9c9 Mon Sep 17 00:00:00 2001 From: simeonwetzel Date: Tue, 20 Aug 2024 09:46:56 +0200 Subject: [PATCH] Spatial and temporal extent can now be passed to the chatbot API endpoint to adjust search results --- search-app/server/app/server.py | 17 +++-- search-app/server/graph/graph.py | 79 +++++++++++++++----- search-app/server/graph/prompts.py | 7 +- search-app/server/graph/spatial_utilities.py | 29 ++++++- 4 files changed, 101 insertions(+), 31 deletions(-) diff --git a/search-app/server/app/server.py b/search-app/server/app/server.py index 2b663c5..613acc0 100644 --- a/search-app/server/app/server.py +++ b/search-app/server/app/server.py @@ -24,6 +24,8 @@ from pydantic import BaseModel import json import logging +from typing import Optional, Dict, Any + logging.basicConfig() @@ -62,7 +64,7 @@ async def get_current_session(session_id: UUID = Depends(cookie), session_data: "pygeoapi": Indexer(index_name="pygeoapi", score_treshold= 0.4, k = 20), - "geojson_osm_indexer": Indexer(index_name="geojson", # Add indexer for local geojson with OSM features + "geojson": Indexer(index_name="geojson", # Add indexer for local geojson with OSM features score_treshold=-400.0, k = 20, use_hf_model=True, @@ -128,7 +130,7 @@ async def create_session(response: Response): graph = SpatialRetrieverGraph(state=State(messages=[], search_criteria="", - spatial_context="", + spatio_temporal_context={}, search_results=[], ready_to_retrieve=""), thread_id=session_id, @@ -147,12 +149,17 @@ async def create_session(response: Response): class Query(BaseModel): query: str + spatio_temporal_context: Optional[Dict[str, Any]] = None @app.post("/data") async def call_graph(query_data: Query, session_id: UUID = Depends(cookie)): if graph is not None: print(f"-#-#--Running graph---- Using session_id: {str(session_id)}") inputs = {"messages": [HumanMessage(content=query_data.query)]} + + if query_data.spatio_temporal_context: + inputs['spatio_temporal_context'] = query_data.spatio_temporal_context + graph.graph.thread_id = str(session_id) response = await graph.ainvoke(inputs) else: @@ -183,7 +190,7 @@ async def index_geojson_osm(api_key: APIKey = Depends(get_api_key)): # await local_file_connector.add_descriptions_to_features() feature_docs = await geojson_osm_connector._features_to_docs() logging.info(f"Converted {len(feature_docs)} Features or FeatureGroups to documents") - res_local = indexes['geojson_osm_indexer']._index(documents=feature_docs) + res_local = indexes['geojson']._index(documents=feature_docs) return res_local def generate_combined_feature_collection(doc_list: List[Document]): @@ -205,7 +212,7 @@ def generate_combined_feature_collection(doc_list: List[Document]): @app.get("/retrieve_geojson") async def retrieve_geojson(query: str): - features = indexes['geojson_osm_indexer'].retriever.invoke(query) + features = indexes['geojson'].retriever.invoke(query) feature_collection = generate_combined_feature_collection(features) @@ -228,7 +235,7 @@ async def clear_index(index_name: str, api_key: APIKey = Depends(get_api_key)): if index_name == 'geojson': logging.info("Clearing geojson index") - indexes['geojson_osm_indexer']._clear() + indexes['geojson']._clear() else: logging.info(f"Clearing index: {index_name}") indexes[index_name]._clear() diff --git a/search-app/server/graph/graph.py b/search-app/server/graph/graph.py index 32a3585..87d2094 100644 --- a/search-app/server/graph/graph.py +++ b/search-app/server/graph/graph.py @@ -13,7 +13,9 @@ generate_search_tool, spatial_context_extraction_tool ) +from .spatial_utilities import check_within_bbox import logging +import ast logging.basicConfig() logging.getLogger().setLevel(logging.INFO) @@ -28,7 +30,7 @@ def is_valid_json(myjson): class State(TypedDict): messages: Annotated[Sequence[BaseMessage], operator.add] search_criteria: str - spatial_context: str + spatio_temporal_context: dict search_results: List ready_to_retrieve: str index_name: str @@ -58,7 +60,7 @@ def __init__(self, def setup_graph(self): self.add_node("conversation", self.run_conversation) - self.add_node("extract_spatial_context", self.extract_spatial_context) + self.add_node("extract_spatio_temporal_context", self.extract_spatio_temporal_context) self.add_node("search", self.run_search) self.add_node("final_answer", self.final_answer) self.add_node("save_state", self.save_state) @@ -68,11 +70,11 @@ def setup_graph(self): self.should_continue, { "human": "save_state", - "extract_spatial_context": "extract_spatial_context" + "extract_spatio_temporal_context": "extract_spatio_temporal_context" } ) - self.add_edge("extract_spatial_context", "search") + self.add_edge("extract_spatio_temporal_context", "search") self.add_edge("search", "final_answer") self.add_edge("final_answer", "save_state") self.add_edge("save_state", END) @@ -119,12 +121,24 @@ async def run_conversation(self, state: State): state["ready_to_retrieve"] = parsed_dict.get("ready_to_retrieve", "no") return state - def extract_spatial_context(self, state: State): + def extract_spatio_temporal_context(self, state: State): print("---extracting spatial context of search") - spatial_context = spatial_context_extraction_tool.invoke({"query": str(state['search_criteria'])}) - state['spatial_context'] = spatial_context + spatio_temporal_context = state.get('spatio_temporal_context', None) - logging.info(f"Extracted following spatial context: {spatial_context}") + if spatio_temporal_context: + spatial_extent = spatio_temporal_context.get('extent', []) + temporal_extent = spatio_temporal_context.get('temporal', "") + + if not spatio_temporal_context: + spatial_context_str = spatial_context_extraction_tool.invoke({"query": str(state['search_criteria'])}) + spatial_extent = ast.literal_eval(spatial_context_str).get("extent", []) + + state['spatio_temporal_context'] = spatial_extent + + #Todo: also try to derive temporal extent from inputs + temporal_extent = "" + + logging.info(f"Extracted following spatial context: {spatial_extent} and following temporal extent: {temporal_extent}") return state def run_search(self, state: State): @@ -139,9 +153,9 @@ def run_search(self, state: State): logging.info(f"Starting search in index: {index_name} using this tool: {search_tool.name}") search_results = search_tool.invoke({"query": str(state['search_criteria']), - "search_index": search_index, - "search_type": "similarity", - "k": 3}) + "search_index": search_index, + "search_type": "similarity", + "k": 10}) else: tavily_search = TavilySearchResults() search_results = tavily_search.invoke(state["search_criteria"]) @@ -154,20 +168,45 @@ def run_search(self, state: State): def should_continue(self, state: State) -> str: if state.get("ready_to_retrieve") == "yes": print("---routing to spatial context extractor, then to search") - return "extract_spatial_context" + return "extract_spatio_temporal_context" else: return "human" def final_answer(self, state: State) -> str: - if state["index_name"] == "geojson": - logging.info(f"I found: {state['search_results'][-1]}") - context = state["search_results"][-1] - else: - context = state["search_results"] - query = state["search_criteria"] - answer = final_answer_chain.invoke({"query": query, - "context": context}).strip() + for c in self.collection_info_dict: + if c['collection_name'] == state['index_name']: + search_index_info = c.pop('sample_docs', None) + #Todo: use the temporal constraint (e.g. as filter) + try: + if state["index_name"] == "geojson": + # Check if results match spatial context of query + query_bbox = state['spatio_temporal_context'] + search_results = check_within_bbox(search_results=state["search_results"], + bbox=query_bbox) + + logging.info(f"I found: {len(search_results)} using the query-bbox {query_bbox}") + + doc_contents = "\n\n".join(doc.page_content for doc in search_results) + + context = f"""I searched in this search index: {search_index_info}. + The top-{len(search_results)} search results in the specified spatial extent are: {doc_contents}""" + + else: + search_results = state["search_results"][:10] + doc_contents = "\n\n".join(doc.page_content for doc in search_results) + context = f"""I searched in this search index: {search_index_info}. + The top-{len(search_results)} search results are: {doc_contents}""" + + if len(search_results) < 1: + context = "No search results found using the current search criteria" + state['search_results'] = [] + query = state["search_criteria"] + answer = final_answer_chain.invoke({"query": query, + "context": context}).strip() + except: + answer = "Sorry I was not able to process your input. Can you please try again?" + state["messages"].append(AIMessage(content=answer)) return state diff --git a/search-app/server/graph/prompts.py b/search-app/server/graph/prompts.py index 1c80dd5..ac256e8 100644 --- a/search-app/server/graph/prompts.py +++ b/search-app/server/graph/prompts.py @@ -94,12 +94,11 @@ def generate_conversation_prompt(system_prompt=None): def generate_final_answer_prompt(): final_answer_prompt = PromptTemplate( template=""" - You are an assistant for question-answering tasks related to data search. - The question wil be a query and the context either the found datasets or a summary of the recieved data. - Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. + You describe the results of a data search given a certain query. + The search results are either the found datasets or a summary of the recieved data. Use three sentences maximum and keep the answer concise Question: {query} - Context: {context} + Found data: {context} Answer:""", input_variables=["query", "context"], ) diff --git a/search-app/server/graph/spatial_utilities.py b/search-app/server/graph/spatial_utilities.py index 1b7f742..ba305e1 100644 --- a/search-app/server/graph/spatial_utilities.py +++ b/search-app/server/graph/spatial_utilities.py @@ -86,5 +86,30 @@ def generate_spatial_context_chain(llm): return spatial_context_chain -# response = spatial_context_chain.invoke({"query": "I climate data for Berlin"}) -# print(response) +### Functions to check if results with geojson are within a certain spatial extent +# Use this to check if search results match query bbox. +import json +def is_within_bbox(lon, lat, bbox): + min_lon, max_lat, max_lon, min_lat = bbox + return min_lon <= lon <= max_lon and min_lat <= lat <= max_lat + +def check_within_bbox(search_results, bbox): + if not bbox: + return search_results + + results_within_bbox = [] + + for result in search_results: + feature_str = result.metadata.get('feature', '{}') + feature = json.loads(feature_str) + coordinates = feature.get('coordinates', []) + + # Flatten the coordinates (if needed) and check if any coordinate is within the bbox + for poly in coordinates: + for coord in poly: # Assuming polygon with one ring + lon, lat = coord + if is_within_bbox(lon, lat, bbox): + results_within_bbox.append(result) + break # No need to check other coordinates of this polygon if already within bbox + + return results_within_bbox