Skip to content

Commit

Permalink
improving ExtractTableProperties and standardizer transforms (#773)
Browse files Browse the repository at this point in the history
* ntsb transform changes

* updating ntsb notebook

* removing old file

* lint fixes

* filename fixes

* linting fix

* linting fix

* moving prompts to a common file

* lint fix

---------

Co-authored-by: Soeb <[email protected]>
  • Loading branch information
Soeb-aryn authored Sep 9, 2024
1 parent 7fae584 commit 1c7b31d
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 304 deletions.
4 changes: 4 additions & 0 deletions lib/sycamore/sycamore/llms/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
SchemaZeroShotGuidancePrompt,
PropertiesZeroShotGuidancePrompt,
TaskIdentifierZeroShotGuidancePrompt,
ExtractTablePropertiesTablePrompt,
ExtractTablePropertiesPrompt,
)
from sycamore.llms.prompts.default_prompts import _deprecated_prompts

Expand All @@ -20,6 +22,8 @@
"TextSummarizerGuidancePrompt",
"SchemaZeroShotGuidancePrompt",
"PropertiesZeroShotGuidancePrompt",
"ExtractTablePropertiesTablePrompt",
"ExtractTablePropertiesPrompt",
] + list(_deprecated_prompts.keys())

__all__ = prompts
Expand Down
33 changes: 33 additions & 0 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,39 @@ class TaskIdentifierZeroShotGuidancePrompt(SimplePrompt):
"""


class ExtractTablePropertiesPrompt(SimplePrompt):
user = """
You are given a text string where columns are separated by comma representing either a single column,
or multi-column table each new line is a new row.
Instructions:
1. Parse the table and return a flattened JSON object representing the key-value pairs of properties
defined in the table.
2. Do not return nested objects, keep the dictionary only 1 level deep. The only valid value types
are numbers, strings, and lists.
3. If you find multiple fields defined in a row, feel free to split them into separate properties.
4. Use camelCase for the key names
5. For fields where the values are in standard measurement units like miles,
nautical miles, knots, celsius
6. return only the json object between ```
- include the unit in the key name and only set the numeric value as the value.
- e.g. "Wind Speed: 9 knots" should become windSpeedInKnots: 9,
"Temperature: 3°C" should become temperatureInC: 3
"""


class ExtractTablePropertiesTablePrompt(SimplePrompt):
user = """
You are given a text string where columns are separated by comma representing either a single column,
or multi-column table each new line is a new row.
Instructions:
1. Parse the table and make decision if key, value pair information can be extracted from it.
2. if the table contains multiple cell value corresponding to one key, the key, value pair for such table
cant be extracted.
3. return True if table cant be parsed as key value pair.
4. return only True or False nothing should be added in the response.
"""


class EntityExtractorMessagesPrompt(SimplePrompt):
def __init__(self, question: str, field: str, format: Optional[str], discrete: bool = False):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,18 @@ def setUp(self):
)

def test_datetime(self):
date_standardizer = DateTimeStandardizer()

output = StandardizeProperty(
None, standardizer=date_standardizer, path=[["properties", "entity", "dateTime"]]
None, standardizer=DateTimeStandardizer, path=[["properties", "entity", "dateTime"]]
).run(self.input)

assert "properties" in output.keys()
assert "entity" in output.properties.keys()
assert output.properties.get("entity")["dateTime"] == "March 17, 2023, 14:25 "
assert output.properties.get("entity")["day"] == date(2023, 3, 17)

def test_location(self):
loc_standardizer = LocationStandardizer()
output = StandardizeProperty(
None, standardizer=loc_standardizer, path=[["properties", "entity", "location"]]
None, standardizer=LocationStandardizer, path=[["properties", "entity", "location"]]
).run(self.input)

assert "properties" in output.keys()
Expand Down
2 changes: 1 addition & 1 deletion lib/sycamore/sycamore/transforms/assign_doc_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class AssignDocProperties(SingleThreadUser, NonGPUUser, Map):
"""
The AssignDocProperties transform is used to copy properties from first element pf a specific type
The AssignDocProperties transform is used to copy properties from first element of a specific type
to the parent document. This allows for the consolidation of key attributes at the document level.
Args:
Expand Down
42 changes: 12 additions & 30 deletions lib/sycamore/sycamore/transforms/extract_table_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from sycamore.transforms.llm_query import LLMTextQueryAgent
from sycamore.llms import LLM
from sycamore.llms.prompts import ExtractTablePropertiesPrompt, ExtractTablePropertiesTablePrompt


class ExtractTableProperties(SingleThreadUser, NonGPUUser, Map):
Expand Down Expand Up @@ -52,42 +53,23 @@ def extract_parent_json(input_string: str) -> str:

@staticmethod
@timetrace("ExtrKeyVal")
def extract_table_properties(parent: Document, property_name: str, llm: LLM) -> Document:
def extract_table_properties(
parent: Document, property_name: str, llm: LLM, prompt_find_table: str = "", prompt_LLM: str = ""
) -> Document:
"""
This Method is used to extract key value pair from table using LLM and
populate it as property of that element.
"""
prompt = """
You are given a text string where columns are separated by comma representing either a single column,
or multi-column table each new line is a new row.
Instructions:
1. Parse the table and make decision if key, value pair information can be extracted from it.
2. if the table contains multiple cell value corresponding to one key, the key, value pair for such table
cant be extracted.
3. return True if table cant be parsed as key value pair.
4. return only True or False nothing should be added in the response.
"""
query_agent = LLMTextQueryAgent(prompt=prompt, llm=llm, output_property="keyValueTable", element_type="table")
if prompt_find_table == "":
prompt_find_table = ExtractTablePropertiesTablePrompt().user
query_agent = LLMTextQueryAgent(
prompt=prompt_find_table, llm=llm, output_property="keyValueTable", element_type="table"
)
doc = query_agent.execute_query(parent)

prompt = """
You are given a text string where columns are separated by comma representing either a single column,
or multi-column table each new line is a new row.
Instructions:
1. Parse the table and return a flattened JSON object representing the key-value pairs of properties
defined in the table.
2. Do not return nested objects, keep the dictionary only 1 level deep. The only valid value types
are numbers, strings, and lists.
3. If you find multiple fields defined in a row, feel free to split them into separate properties.
4. Use camelCase for the key names
5. For fields where the values are in standard measurement units like miles,
nautical miles, knots, celsius
6. return only the json object between ```
- include the unit in the key name and only set the numeric value as the value.
- e.g. "Wind Speed: 9 knots" should become windSpeedInKnots: 9,
"Temperature: 3°C" should become temperatureInC: 3
"""
query_agent = LLMTextQueryAgent(prompt=prompt, llm=llm, output_property=property_name, element_type="table")
if prompt_LLM == "":
prompt_LLM = ExtractTablePropertiesPrompt().user
query_agent = LLMTextQueryAgent(prompt=prompt_LLM, llm=llm, output_property=property_name, element_type="table")
doc = query_agent.execute_query(parent)

for ele in doc.elements:
Expand Down
55 changes: 38 additions & 17 deletions lib/sycamore/sycamore/transforms/standardizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def fixer(self, text: str) -> Union[str, Tuple[str, date]]:
"""
pass

@abstractmethod
def standardize(self, doc: Document, key_path: List[str]) -> Document:
"""
Applies the fixer method to a specific field in the document as defined by the key_path.
Abstract method applies the fixer method to a specific field in the document as defined by the key_path.
Args:
doc (Document): The document to be standardized.
Expand All @@ -45,18 +46,7 @@ def standardize(self, doc: Document, key_path: List[str]) -> Document:
Raises:
KeyError: If any of the keys in key_path are not found in the document.
"""
current = doc
for key in key_path[:-1]:
if current.get(key, None):
current = current[key]
else:
raise KeyError(f"Key {key} not found in the dictionary among {current.keys()}")
target_key = key_path[-1]
if current.get(target_key, None):
current[target_key] = self.fixer(current[target_key])
else:
raise KeyError(f"Key {target_key} not found in the dictionary among {current.keys()}")
return doc
pass


class LocationStandardizer(Standardizer):
Expand Down Expand Up @@ -118,7 +108,8 @@ class LocationStandardizer(Standardizer):
"WY": "Wyoming",
}

def fixer(self, text: str) -> str:
@staticmethod
def fixer(text: str) -> str:
"""
Replaces any US state abbreviations in the text with their full state names.
Expand All @@ -135,13 +126,42 @@ def replacer(match):

return re.sub(r"\b[A-Z]{2}\b", replacer, text)

@staticmethod
def standardize(doc: Document, key_path: List[str]) -> Document:
"""
Applies the fixer method to a specific field in the document as defined by the key_path.
Args:
doc (Document): The document to be standardized.
key_path (List[str]): The path to the field within the document that should be standardized.
Returns:
Document: The document with the standardized field.
Raises:
KeyError: If any of the keys in key_path are not found in the document.
"""
current = doc
for key in key_path[:-1]:
if current.get(key, None):
current = current[key]
else:
raise KeyError(f"Key {key} not found in the dictionary among {current.keys()}")
target_key = key_path[-1]
if current.get(target_key, None):
current[target_key] = LocationStandardizer.fixer(current[target_key])
else:
raise KeyError(f"Key {target_key} not found in the dictionary among {current.keys()}")
return doc


class DateTimeStandardizer(Standardizer):
"""
A standardizer for transforming date and time strings into a consistent format.
"""

def fixer(self, raw_dateTime: str) -> Tuple[str, date]:
@staticmethod
def fixer(raw_dateTime: str) -> Tuple[str, date]:
"""
Converts a date-time string by replacing periods with colons and parsing it into a date object.
Expand Down Expand Up @@ -175,7 +195,8 @@ def fixer(self, raw_dateTime: str) -> Tuple[str, date]:
# Handle any other exceptions
raise RuntimeError(f"Unexpected error occurred while processing: {raw_dateTime}") from e

def standardize(self, doc: Document, key_path: List[str]) -> Document:
@staticmethod
def standardize(doc: Document, key_path: List[str]) -> Document:
"""
Applies the fixer method to a specific date-time field in the document as defined by the key_path,
and adds an additional "day" field with the extracted date.
Expand All @@ -199,7 +220,7 @@ def standardize(self, doc: Document, key_path: List[str]) -> Document:
raise KeyError(f"Key {key} not found in the dictionary among {current.keys()}")
target_key = key_path[-1]
if target_key in current.keys():
current[target_key], current["day"] = self.fixer(current[target_key])
current[target_key], current["day"] = DateTimeStandardizer.fixer(current[target_key])
else:
raise KeyError(f"Key {target_key} not found in the dictionary among {current.keys()}")
return doc
Expand Down
Loading

0 comments on commit 1c7b31d

Please sign in to comment.