Merge pull request #31 from 52North/feature/generic-prompts
Feature/generic prompts
simeonwetzel authored Aug 19, 2024
2 parents 6b1ab1c + dfc1a73 commit df7c010
Showing 11 changed files with 649 additions and 132 deletions.
100 changes: 51 additions & 49 deletions search-app/server/app/server.py
@@ -1,32 +1,31 @@
from fastapi import FastAPI, HTTPException
import asyncio
from fastapi import FastAPI, HTTPException
import asyncio
from fastapi.responses import RedirectResponse
from langserve import add_routes
from graph.graph import SpatialRetrieverGraph, State
from langchain_core.runnables import chain
from graph.routers import CollectionRouter
from config.config import Config
from indexing.indexer import Indexer
from connectors.pygeoapi_retriever import PyGeoAPI
from connectors.geojson_osm import GeoJSON
from langchain_core.runnables.graph import MermaidDrawMethod
from langchain.schema import Document
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.messages import HumanMessage
from fastapi.middleware.cors import CORSMiddleware
from .utils import SessionData, cookie, verifier, backend
from .utils import (SessionData, cookie, verifier, backend,
calculate_bounding_box, summarize_feature_collection_properties,
load_conversational_prompts)

from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver
from fastapi import HTTPException, FastAPI, Depends, Response, Security
from fastapi.security.api_key import APIKeyHeader, APIKey
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver
from uuid import UUID, uuid4
from typing import List, Optional
from typing import List
import geojson
from pydantic import BaseModel
import json

import logging


logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

@@ -36,7 +35,9 @@
"http://localhost:5173", # Frontend app origin
]

# Init memory:
memory = AsyncSqliteSaver.from_conn_string(":memory:")

### Get session info via cookie
async def get_current_session(session_id: UUID = Depends(cookie), session_data: SessionData = Depends(verifier)):
return session_data
@@ -54,40 +55,38 @@ async def get_current_session(session_id: UUID = Depends(cookie), session_data:
# graph = SpatialRetrieverGraph(State(messages=[], search_criteria="", search_results=[], ready_to_retrieve="")).compile()
graph = None
session_id = None
"""
# Generate a visualization of the current dialog-module workflow
graph_visualization = graph.get_graph().draw_mermaid_png(
draw_method=MermaidDrawMethod.API,
)
with open("./graph/current_workflow.png", "wb") as f:
f.write(graph_visualization)
"""


# Create a dictionary of indexes
indexes = {
"pygeoapi": Indexer(index_name="pygeoapi",
score_treshold= 0.4,
k = 20),
}

# Add indexer for local geojson with OSM features
geojson_osm_indexer = Indexer(index_name="geojson",
"geojson_osm_indexer": Indexer(index_name="geojson", # Add indexer for local geojson with OSM features
score_treshold=-400.0,
k = 20,
use_hf_model=True,
embedding_model="Alibaba-NLP/gte-large-en-v1.5"
)

}

# Add a connection to a local file containing building features
# Replace the value of the tag_name argument if your data uses a different tag
geojson_osm_connector = GeoJSON(tag_name="building")

"""
# We can also use an OSM/GeoJSON file that comes from a web resource
local_file_connector = GeoJSON(file_dir="https://webais.demo.52north.org/pygeoapi/collections/dresden_buildings/items",
tag_name="building")
"""

# Add conversational routes. We do this here to avoid time-expensive LLM calls during inference:
collection_router = CollectionRouter()

# Check whether custom prompts have already been generated and, if so, whether they match the existing search indexes
conversational_prompts = load_conversational_prompts(collection_router=collection_router)



app = FastAPI()

app.add_middleware(
@@ -127,8 +126,17 @@ async def create_session(response: Response):

global graph

graph = SpatialRetrieverGraph(state=State(messages=[], search_criteria="", search_results=[], ready_to_retrieve=""),
thread_id=session_id, memory=memory).compile()
graph = SpatialRetrieverGraph(state=State(messages=[],
search_criteria="",
spatial_context="",
search_results=[],
ready_to_retrieve=""),
thread_id=session_id,
memory=memory,
search_indexes=indexes,
collection_router=collection_router,
conversational_prompts=conversational_prompts
).compile()

data = SessionData(session_id=session_id)

@@ -137,23 +145,6 @@ async def create_session(response: Response):

return {"message": f"created session for {session}"}

"""
@chain
async def call_graph(query: str, session_id: UUID = Depends(cookie), session_data: SessionData = Depends(verifier)):
if graph is not None:
print(f"-#-#--Running graph---- Using session_id: {str(session_id)}")
print(f"session_data: {session_data}")
inputs = {"messages": [HumanMessage(content=query)]}
graph.graph.thread_id = "test"
response = await graph.ainvoke(inputs)
else:
raise HTTPException(status_code=400, detail="No session created")
return response
"""
@app.get("/test_api_key")
async def test_api_key(api_key: APIKey = Depends(get_api_key)):
return f"Entered API KEY: {api_key}"

class Query(BaseModel):
query: str

@@ -192,7 +183,7 @@ async def index_geojson_osm(api_key: APIKey = Depends(get_api_key)):
# await local_file_connector.add_descriptions_to_features()
feature_docs = await geojson_osm_connector._features_to_docs()
logging.info(f"Converted {len(feature_docs)} Features or FeatureGroups to documents")
res_local = geojson_osm_indexer._index(documents=feature_docs)
res_local = indexes['geojson_osm_indexer']._index(documents=feature_docs)
return res_local

def generate_combined_feature_collection(doc_list: List[Document]):
@@ -208,15 +199,26 @@ def generate_combined_feature_collection(doc_list: List[Document]):
features.extend(feature_list)

combined_feature_collection = geojson.FeatureCollection(features)
geojson_str = geojson.dumps(combined_feature_collection, sort_keys=True, indent=2)
# geojson_str = geojson.dumps(combined_feature_collection, sort_keys=True, indent=2)

return geojson_str
return combined_feature_collection

@app.get("/retrieve_geojson")
async def retrieve_geojson(query: str):
features = geojson_osm_indexer.retriever.invoke(query)
features = indexes['geojson_osm_indexer'].retriever.invoke(query)

feature_collection = generate_combined_feature_collection(features)

spatial_extent = calculate_bounding_box(feature_collection)
properties = summarize_feature_collection_properties(feature_collection)

summary = f"""Summary of found features:
{properties}
Spatial Extent of all features: {spatial_extent}
"""

return generate_combined_feature_collection(features)
return feature_collection, summary


@app.get("/clear_index")
@@ -226,7 +228,7 @@ async def clear_index(index_name: str, api_key: APIKey = Depends(get_api_key)):

if index_name == 'geojson':
logging.info("Clearing geojson index")
geojson_osm_indexer._clear()
indexes['geojson_osm_indexer']._clear()
else:
logging.info(f"Clearing index: {index_name}")
indexes[index_name]._clear()
@@ -256,4 +258,4 @@ async def remove_doc_from_index(index_name: str, _id: str, api_key: APIKey = Dep

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", reload=False, port=8000)
106 changes: 106 additions & 0 deletions search-app/server/app/utils.py
@@ -4,6 +4,17 @@
from fastapi_sessions.backends.implementations import InMemoryBackend
from fastapi_sessions.session_verifier import SessionVerifier
from fastapi_sessions.frontends.implementations import SessionCookie, CookieParameters
import os
import importlib
import json
import sys
from pathlib import Path
import logging


logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)


class SessionData(BaseModel):
session_id: str
@@ -52,3 +63,98 @@ def verify_session(self, model: SessionData) -> bool:
backend=backend,
auth_http_exception=HTTPException(status_code=403, detail="Invalid session"),
)

### Geojson utilities
def calculate_bounding_box(geojson):
min_lng, min_lat = float('inf'), float('inf')
max_lng, max_lat = float('-inf'), float('-inf')

def extract_coordinates(geometry):
if geometry['type'] == 'Point':
return [geometry['coordinates']]
elif geometry['type'] in ['MultiPoint', 'LineString']:
return geometry['coordinates']
elif geometry['type'] in ['MultiLineString', 'Polygon']:
return [coord for line in geometry['coordinates'] for coord in line]
elif geometry['type'] == 'MultiPolygon':
return [coord for poly in geometry['coordinates'] for line in poly for coord in line]
else:
return []

for feature in geojson['features']:
coords = extract_coordinates(feature['geometry'])
for coord in coords:
lng, lat = coord
min_lng = min(min_lng, lng)
min_lat = min(min_lat, lat)
max_lng = max(max_lng, lng)
max_lat = max(max_lat, lat)

return [min_lng, min_lat, max_lng, max_lat]


def summarize_feature_collection_properties(feature_collection):

data = list(map(lambda f: f['properties'], feature_collection['features']))

summary = {}

for item in data:
item_type = item.get('type', '')
description = item.get('description', '')

if item_type not in summary:
summary[item_type] = {'count': 0, 'descriptions': []}

summary[item_type]['count'] += 1

if description and description not in summary[item_type]['descriptions']:
summary[item_type]['descriptions'].append(description)

summary_text = ""
for item_type, details in summary.items():
summary_text += f"Type: {item_type} (Count: {details['count']})\nDescriptions:\n"
for desc in details['descriptions']:
summary_text += f"- {desc}\n"
summary_text += "\n"

return summary_text.strip()


### Custom prompt utilities

def save_conversational_prompts(file_name, conversational_prompts):
with open(file_name, 'w') as f:
json.dump(conversational_prompts, f, indent=4) # Pretty print with indentation


def read_dict_from_module(module_path):
module_name = Path(module_path).stem
if os.path.exists(f"{module_path}"):
try:
from graph.custom_prompts.custom_prompts import prompts
return prompts
except ImportError:
return None
else:
logging.info(f"Module '{module_name}.py' does not exist.")
return None

def write_dict_to_file(dictionary, filename):
with open(filename, 'w') as file:
file.write(f"prompts = {repr(dictionary)}\n")

def load_conversational_prompts(collection_router):
loaded_dict = read_dict_from_module('./graph/custom_prompts/custom_prompts.py')
collection_names = [c['collection_name'] for c in collection_router.coll_dicts]

if loaded_dict and set(loaded_dict.keys()) == set(collection_names):
logging.info("Custom prompts already generated for current collections. Reading it from file...")
conversational_prompts = loaded_dict
else:
conversational_prompts = collection_router.generate_conversation_prompts()
write_dict_to_file(conversational_prompts,'./graph/custom_prompts/custom_prompts.py')

return conversational_prompts
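A minimal sketch of how the new GeoJSON helpers in utils.py behave, using a hand-made FeatureCollection. The coordinates and properties are invented for illustration, and the import path assumes the app package layout used by server.py above.

from app.utils import calculate_bounding_box, summarize_feature_collection_properties

fc = {
    "features": [
        {"geometry": {"type": "Point", "coordinates": [13.73, 51.05]},
         "properties": {"type": "building", "description": "a school building"}},
        {"geometry": {"type": "LineString",
                      "coordinates": [[13.70, 51.03], [13.75, 51.06]]},
         "properties": {"type": "building", "description": "a row of houses"}},
    ]
}

print(calculate_bounding_box(fc))
# -> [13.7, 51.03, 13.75, 51.06], i.e. [min_lng, min_lat, max_lng, max_lat]

print(summarize_feature_collection_properties(fc))
# -> "Type: building (Count: 2)" followed by the two distinct descriptions

Note that load_conversational_prompts persists the generated prompts as a plain Python module (graph/custom_prompts/custom_prompts.py) and only regenerates them when the stored keys no longer match the current collection names.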


Empty file.
13 changes: 9 additions & 4 deletions search-app/server/connectors/geojson_osm.py
@@ -59,15 +59,19 @@ class GeoJSON():
_features_to_docs() -> List[Document]:
Converts features into a list of Document objects for further use.
"""
def __init__(self, file_dir: str = None, tag_name: str = "building"):
def __init__(self, file_dir: str = None, tag_name: str = None):
if file_dir and is_url(file_dir):
"""We assume the online resource to be a collection published via a PyGeoAPI instance"""
logging.info("Getting features from online resource")
params = {"f": "json", "limit": 10000}
gj = self._fetch_features_from_online_resource(file_dir, params)
print(f"Retrieved {len(gj)} features")

self.features = self._filter_meaningful_features(gj, tag_name)

self.tag_name = tag_name
if self.tag_name:
self.features = self._filter_meaningful_features(gj, self.tag_name)
else:
self.features = gj
else:
if not file_dir:
file_dir = config.local_geojson_files
@@ -180,7 +184,8 @@ def _get_feature_description(self, feature):
return "\n".join(description_parts)

async def _features_to_docs(self) -> List[Document]:
await self.add_descriptions_to_features()
if self.tag_name:
await self.add_descriptions_to_features()

# Part 1: Create documents for features with names
features_with_names = list(filter(lambda feature: feature if feature["properties"].get("name", "") else None, self.features))
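A short sketch of the connector's new optional tag_name behaviour. The constructor signature and the URL are taken from the diff above; whether this exact usage works end-to-end depends on parts of the connector not shown here.

from connectors.geojson_osm import GeoJSON

# With a tag (as server.py does above): the connector filters for that tag and
# _features_to_docs() will first add generated descriptions to the features.
buildings = GeoJSON(tag_name="building")

# Without a tag: all fetched features are kept and the description step is skipped.
everything = GeoJSON(
    file_dir="https://webais.demo.52north.org/pygeoapi/collections/dresden_buildings/items"
)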
