Skip to content

Commit

Permalink
adding one shot prompting along with multimodal request (#1023)
Browse files Browse the repository at this point in the history
* adding one shot prompting along with multimodal request

* linting fixes

* mypy fixes

---------

Co-authored-by: Soeb Hussain
  • Loading branch information
Soeb-aryn authored Nov 14, 2024
1 parent 766985f commit 9ddcaef
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 87 deletions.
2 changes: 0 additions & 2 deletions lib/sycamore/sycamore/llms/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
TaskIdentifierZeroShotGuidancePrompt,
GraphEntityExtractorPrompt,
GraphRelationshipExtractorPrompt,
ExtractTablePropertiesTablePrompt,
ExtractTablePropertiesPrompt,
)
from sycamore.llms.prompts.default_prompts import _deprecated_prompts
Expand All @@ -26,7 +25,6 @@
"PropertiesZeroShotGuidancePrompt",
"GraphEntityExtractorPrompt",
"GraphRelationshipExtractorPrompt",
"ExtractTablePropertiesTablePrompt",
"ExtractTablePropertiesPrompt",
] + list(_deprecated_prompts.keys())

Expand Down
66 changes: 39 additions & 27 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,36 +119,48 @@ class GraphRelationshipExtractorPrompt(SimplePrompt):

class ExtractTablePropertiesPrompt(SimplePrompt):
user = """
You are given a text string represented as CSV (comma-separated values) where columns are separated by commas representing either a single column,
or a multi-column table each new line is a new row.
Instructions:
1. Parse the table and return a flattened JSON object representing the key-value pairs of properties
defined in the table.
2. Do not return nested objects; keep the dictionary only 1 level deep. The only valid value types
are numbers, strings, and lists.
3. If you find multiple fields defined in a row, feel free to split them into separate properties.
4. Use camelCase for the key names.
5. For fields where the values are in standard measurement units like miles,
nautical miles, knots, or celsius, include the unit in the key name and only set the
numeric value as the value.
- "Wind Speed: 9 knots" should become "windSpeedInKnots": 9
- "Temperature: 3°C" should become "temperatureInC": 3
6. Ensure that key names are enclosed in double quotes.
7. return only the json object between ```
You are given a text string represented as a CSV (comma-separated values) and an image of a table.
Instructions:
Check if the table contains key-value pairs. A key-value pair table is a table where data is structured as key-value pairs. Generally, the first column contains the key and the second column contains the value. However, key-value pairs can also appear in other formats.
If there is a one-to-one mapping between two cells, even if the relationship is not direct, they should be considered key-value pairs.
If the table is a key-value pair table, return its key-value pairs as a JSON object.
If the table is not a key-value pair table, return False.
Parse the table, check the image, and return a flattened JSON object representing the key-value pairs from the table. The extracted key-value pairs should be formatted as a JSON object.
Do not return nested objects; keep the dictionary only one level deep. The only valid value types are numbers, strings, None, and lists.
Use camelCase for the key names.
For fields where the values are in standard measurement units like miles, nautical miles, knots, or Celsius, include the unit in the key name and only set the numeric value as the value:
"Wind Speed: 9 knots" should become "windSpeedInKnots": 9
"Temperature: 3°C" should become "temperatureInC": 3
Ensure that key names are enclosed in double quotes.
Return only the JSON object between ``` if the table is a key-value pair table; otherwise, return False.
example of a key-value pair table:
|---------------------------------|------------------|
| header 1 | header 2 |
|---------------------------------|------------------|
| NEW FIRE ALARM SYSTEMS | $272 TWO HOURS |
| NEW SPRINKLER SYSTEMS | $408 THREE HOURS |
| NEW GASEOUS SUPPRESSION SYSTEMS | $272 TWO HOURS |
|---------------------------------|------------------|
return ```{"NEW FIRE ALARM SYSTEMS": "$272 TWO HOURS", "NEW SPRINKLER SYSTEMS": "$408 THREE HOURS", "NEW GASEOUS SUPPRESSION SYSTEMS": "$272 TWO HOURS"}```
example of a table which is not key-value pair table:
|---------------------------------|------------------|------------------|
| header 1 | header 2 | header 3 |
|---------------------------------|------------------|------------------|
| NEW FIRE ALARM SYSTEMS | $272 TWO HOURS | $2752 ONE HOUR |
| NEW SPRINKLER SYSTEMS | $408 THREE HOURS | $128 FIVE HOURS |
| NEW GASEOUS SUPPRESSION SYSTEMS | $272 TWO HOURS | $652 TEN HOURS |
|---------------------------------|------------------|------------------|
return False
"""


class ExtractTablePropertiesTablePrompt(SimplePrompt):
user = """
You are given a table represented as CSV (comma-separated values),
Instructions:
1. Parse the table to determine if key-value pair information can be extracted from it.
2. A key cell may correspond to multiple value cells.
3. Return True if the table can be parsed as key-value pairs.
4. Return only True or False; nothing should be added into the response.
"""


class EntityExtractorMessagesPrompt(SimplePrompt):
def __init__(self, question: str, field: str, format: Optional[str], discrete: bool = False):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,42 +1,63 @@
from sycamore.data import Document
from sycamore.data import Document, Element, Table, TableCell
from sycamore.llms import OpenAI
from sycamore.transforms.extract_table_properties import ExtractTableProperties
from sycamore.data.table import Table, TableCell
from PIL import Image
from io import BytesIO


class TestExtractTableProperties:
def table(self, str1, str2) -> Table:
return Table(
[
TableCell(content="head1", rows=[0], cols=[0], is_header=True),
TableCell(content="head2", rows=[0], cols=[1], is_header=True),
TableCell(content=str1, rows=[1], cols=[0], is_header=False),
TableCell(content=str2, rows=[1], cols=[1], is_header=False),
]
def test_extract_key_value_pair(self, mocker):
table_cells = [
TableCell(content="head1", rows=[0], cols=[0], is_header=True),
TableCell(content="head2", rows=[0], cols=[1], is_header=True),
TableCell(content="key1", rows=[1], cols=[0], is_header=False),
TableCell(content="val1", rows=[1], cols=[1], is_header=False),
]
table = Table(table_cells)
table_bbox = (0.1, 0.2, 0.8, 0.9)
table_element = Element(
type="table",
bbox=table_bbox,
properties={"page_number": 1, "title": {"rows": None, "columns": None}},
table=table,
tokens=None,
)

def test_extract_key_value_pair(self, mocker):
self.doc = Document(
{
"doc_id": "doc_id",
"type": "pdf",
"text_representation": "text_representation",
"bbox": (1, 2.3, 3.4, 4.5),
"elements": [
{
"type": "table",
"bbox": (1, 2, 3, 4.0),
"properties": {"title": {"rows": None, "columns": None}},
"table": self.table("key1", "val1"),
"tokens": None,
},
],
"properties": {"int": 0, "float": 3.14, "list": [1, 2, 3, 4], "tuple": (1, "tuple")},
}
doc_id="doc_id",
type="pdf",
text_representation="text_representation",
elements=[table_element],
properties={"int": 0, "float": 3.14, "list": [1, 2, 3, 4], "tuple": (1, "tuple")},
binary_representation=b"<dummy>",
)
# print(self.doc)

mock_split_and_convert_to_image = mocker.patch(
"sycamore.transforms.extract_table_properties.split_and_convert_to_image"
)
image = Image.new("RGB", (100, 100), color="white")
img_byte_arr = BytesIO()
image.save(img_byte_arr, format="PNG")
img_data = img_byte_arr.getvalue()

mock_image = Document()
mock_image.binary_representation = img_data
mock_image.properties = {
"size": (100, 100),
"mode": "RGB",
}
mock_image.binary_representation = img_data

mock_split_and_convert_to_image.return_value = [mock_image]

mock_frombytes = mocker.patch("PIL.Image.frombytes")
mock_frombytes.return_value = image

llm = mocker.Mock(spec=OpenAI)
_ = mocker.patch.object(llm, "generate", side_effect=["True", '{"key1":"val1"}'])
doc1 = ExtractTableProperties(None, parameters=["llm_response", llm]).run(self.doc)
print(doc1)
llm.generate.return_value = '{"key1":"val1"}'
llm.format_image.return_value = {"type": "image", "data": "dummy"}

property_name = "llm_response"
doc1 = ExtractTableProperties(None, parameters=[property_name, llm]).run(self.doc)

assert (doc1.elements[0].properties.get("llm_response")) == {"key1": "val1"}
58 changes: 31 additions & 27 deletions lib/sycamore/sycamore/transforms/extract_table_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import timetrace
from sycamore.transforms.llm_query import LLMTextQueryAgent
from sycamore.llms import LLM
from sycamore.llms.prompts import ExtractTablePropertiesPrompt, ExtractTablePropertiesTablePrompt
from sycamore.llms.prompts import ExtractTablePropertiesPrompt
from PIL import Image
from sycamore.functions.document import split_and_convert_to_image


class ExtractTableProperties(SingleThreadUser, NonGPUUser, Map):
Expand Down Expand Up @@ -65,30 +66,33 @@ def extract_table_properties(
This method is used to extract key/value pairs from tables, using the LLM,
and populate them as a property of that element.
"""
prompt_find_table = prompt_find_table or ExtractTablePropertiesTablePrompt().user
query_agent = LLMTextQueryAgent(
prompt=prompt_find_table, llm=llm, output_property="keyValueTable", element_type="table"
)
query_agent.execute_query(parent)
doc1 = split_and_convert_to_image(parent)
img_list = []
for img in doc1:
# print(img['properties'])
size = tuple(img.properties["size"])
mode = img.properties["mode"]
image = Image.frombytes(mode=mode, size=size, data=img.binary_representation)
img_list.append((image, size, mode))

prompt_llm = prompt_LLM or ExtractTablePropertiesPrompt().user
query_agent = LLMTextQueryAgent(prompt=prompt_llm, llm=llm, output_property=property_name, element_type="table")
query_agent.execute_query(parent)

for ele in parent.elements:
if ele.type == "table" and property_name in ele.properties.keys():
if ele.properties.get("keyValueTable", False) != "True":
del ele.properties[property_name]
continue
jsonstring_llm = ele.properties.get(property_name)
assert isinstance(
jsonstring_llm, str
), f"Expected string, got {type(jsonstring_llm).__name__}: {jsonstring_llm}"
json_string = ExtractTableProperties.extract_parent_json(jsonstring_llm)
assert isinstance(json_string, str)
keyValue = json.loads(json_string)
if isinstance(keyValue, dict):
ele.properties[property_name] = keyValue
else:
raise ValueError(f"Extracted JSON string is not a dictionary: {keyValue}")
for idx, ele in enumerate(parent.elements):
if ele is not None and ele.type == "table" and ele.bbox is not None:
image, size, mode = img_list[ele.properties["page_number"] - 1] # output of APS is one indexed
bbox = ele.bbox.coordinates
img = image.crop((bbox[0] * size[0], bbox[1] * size[1], bbox[2] * size[0], bbox[3] * size[1]))
content = [
{
"type": "text",
"text": prompt_LLM if prompt_LLM is not None else ExtractTablePropertiesPrompt.user,
},
llm.format_image(img),
]
messages = [
{"role": "user", "content": content},
]
prompt_kwargs = {"messages": messages}
raw_answer = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={})
parsed_json = ExtractTableProperties.extract_parent_json(raw_answer)
if parsed_json:
ele.properties[property_name] = json.loads(parsed_json)
return parent

0 comments on commit 9ddcaef

Please sign in to comment.