Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Allow passing columns to plcontainer_apply #229

Merged
merged 3 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion greenplumpython/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,18 @@ def apply(
):
return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType
_serialize_to_type_name(return_annotation, db=db, for_return=True)
input_args = self._args
if len(input_args) == 0:
raise Exception("No input data specified, please specify a DataFrame or Columns")
input_clause = (
"*"
if (len(input_args) == 1 and isinstance(input_args[0], DataFrame))
else ",".join([arg._serialize(db=db) for arg in input_args])
)
return DataFrame(
f"""
SELECT * FROM plcontainer_apply(TABLE(
SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS
SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS
{_defined_types[return_annotation.__args__[0]]._serialize(db=db)}
""",
db=db,
Expand Down Expand Up @@ -370,6 +378,7 @@ def _serialize(self, db: Database) -> str:
f" import sys as {sys_lib_name}\n"
f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n"
f" raise ModuleNotFoundError\n"
f" {sys_lib_name}.modules['plpy']=plpy\n"
f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n"
f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n"
f" except ModuleNotFoundError:\n"
Expand Down
47 changes: 39 additions & 8 deletions tests/test_plcontainer.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,56 @@
from dataclasses import dataclass

import pytest

import greenplumpython as gp
from tests import db


def test_simple_func(db: gp.Database):
@dataclass
class Int:
i: int
@dataclass
class Int:
i: int


@dataclass
class Pair:
    """Record with two integer fields; element type of the ``add`` UDF's input list."""

    # First operand.
    i: int
    # Second operand.
    j: int


@pytest.fixture
def t(db: gp.Database):
    """Build a DataFrame of ten rows ``(n, n)`` for n in 0..9, with columns ``a`` and ``b``."""
    return db.create_dataframe(
        rows=[(n, n) for n in range(10)],
        column_names=["a", "b"],
    )


@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
def add_one(x: list[Int]) -> list[Int]:
return [{"i": arg["i"] + 1} for arg in x]

@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
def add_one(x: list[Int]) -> list[Int]:
return [{"i": arg["i"] + 1} for arg in x]

def test_simple_func(db: gp.Database):
assert (
len(
list(
db.create_dataframe(columns={"i": range(10)}).apply(
lambda _: add_one(), expand=True
lambda t: add_one(t), expand=True
)
)
)
== 10
)


def test_func_no_input(db: gp.Database):
    """Applying a plcontainer function that is called with no arguments must raise."""
    expected_message = "No input data specified, please specify a DataFrame or Columns"

    # add_one() receives neither a DataFrame nor Columns, so apply() is expected to fail.
    with pytest.raises(Exception) as exc_info:
        db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True)

    assert expected_message in str(exc_info.value)


def test_func_column(db: gp.Database, t: gp.DataFrame):
    """Passing individual columns (instead of a whole DataFrame) to a plcontainer UDF works."""

    # The UDF is kept local to the test; it sums the two fields of each input record.
    @gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
    def add(x: list[Pair]) -> list[Int]:
        return [{"i": arg["i"] + arg["j"]} for arg in x]

    # One output row is expected per input row of the fixture (ten rows).
    result_rows = list(t.apply(lambda df: add(df["a"], df["b"]), expand=True))
    assert len(result_rows) == 10