From 722e0fffa112d93746d8b187fef18b1a79de6804 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Tue, 26 Dec 2023 14:39:28 +0100 Subject: [PATCH 1/3] Allow pass column to plcontainer_apply --- greenplumpython/func.py | 12 +++++++++++- tests/test_plcontainer.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index c92082b7..7dcc490e 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -115,12 +115,21 @@ def apply( isinstance(self._function, NormalFunction) and self._function._language_handler == "plcontainer" ): + input_args = self._args return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType _serialize_to_type_name(return_annotation, db=db, for_return=True) + input_clause = ( + "*" + if ( + len(input_args) == 0 + or (len(input_args) == 1 and isinstance(input_args[0], DataFrame)) + ) + else ",".join([arg._serialize(db=db) for arg in input_args]) + ) return DataFrame( f""" SELECT * FROM plcontainer_apply(TABLE( - SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS + SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS {_defined_types[return_annotation.__args__[0]]._serialize(db=db)} """, db=db, @@ -370,6 +379,7 @@ def _serialize(self, db: Database) -> str: f" import sys as {sys_lib_name}\n" f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n" f" raise ModuleNotFoundError\n" + f" {sys_lib_name}.modules['plpy']=plpy\n" f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n" f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n" f" except ModuleNotFoundError:\n" diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py index 6b16f051..d703283a 100644 --- a/tests/test_plcontainer.py +++ b/tests/test_plcontainer.py @@ -1,14 +1,29 @@ from dataclasses import dataclass +import pytest + import greenplumpython as gp from tests import db -def test_simple_func(db: gp.Database): - @dataclass - class Int: - i: int +@dataclass +class Int: + i: int + + +@dataclass +class Pair: + i: int + j: int + + +@pytest.fixture +def t(db: gp.Database): + rows = [(i, i) for i in range(10)] + return db.create_dataframe(rows=rows, column_names=["a", "b"]) + +def test_simple_func(db: gp.Database): @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") def add_one(x: list[Int]) -> list[Int]: return [{"i": arg["i"] + 1} for arg in x] @@ -23,3 +38,11 @@ def add_one(x: list[Int]) -> list[Int]: ) == 10 ) + + +def test_func_column(db: gp.Database, t: gp.DataFrame): + @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") + def add(x: list[Pair]) -> list[Int]: + return [{"i": arg["i"] + arg["j"]} for arg in x] + + assert len(list(t.apply(lambda t: add(t["a"], t["b"]), expand=True))) == 10 From 4452769630c003dd5d7934824dff3a642d37d2d4 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Tue, 26 Dec 2023 14:42:45 +0100 Subject: [PATCH 2/3] Retrigger From c395bd29ba6f0f2a8b57fe15cdddd58213e9df4e Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Wed, 27 Dec 2023 09:57:23 +0100 Subject: [PATCH 3/3] Do not allow no input data for plcontainer_apply --- greenplumpython/func.py | 9 ++++----- tests/test_plcontainer.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index 7dcc490e..dff81877 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -115,15 +115,14 @@ def apply( isinstance(self._function, NormalFunction) and self._function._language_handler == "plcontainer" ): - input_args = self._args return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType _serialize_to_type_name(return_annotation, db=db, for_return=True) + input_args = self._args + if len(input_args) == 0: + raise Exception("No input data specified, please specify a DataFrame or Columns") input_clause = ( "*" - if ( - len(input_args) == 0 - or (len(input_args) == 1 and isinstance(input_args[0], DataFrame)) - ) + if (len(input_args) == 1 and isinstance(input_args[0], DataFrame)) else ",".join([arg._serialize(db=db) for arg in input_args]) ) return DataFrame( diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py index d703283a..1a1b39e4 100644 --- a/tests/test_plcontainer.py +++ b/tests/test_plcontainer.py @@ -23,16 +23,17 @@ def t(db: gp.Database): return db.create_dataframe(rows=rows, column_names=["a", "b"]) -def test_simple_func(db: gp.Database): - @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") - def add_one(x: list[Int]) -> list[Int]: - return [{"i": arg["i"] + 1} for arg in x] +@gp.create_function(language_handler="plcontainer", runtime="plc_python_example") +def add_one(x: list[Int]) -> list[Int]: + return [{"i": arg["i"] + 1} for arg in x] + +def test_simple_func(db: gp.Database): assert ( len( list( db.create_dataframe(columns={"i": range(10)}).apply( - lambda _: add_one(), expand=True + lambda t: add_one(t), expand=True ) ) ) @@ -40,6 +41,13 @@ def add_one(x: list[Int]) -> list[Int]: ) +def test_func_no_input(db: gp.Database): + + with pytest.raises(Exception) as exc_info: # no input data for func raises Exception + db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True) + assert "No input data specified, please specify a DataFrame or Columns" in str(exc_info.value) + + def test_func_column(db: gp.Database, t: gp.DataFrame): @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") def add(x: list[Pair]) -> list[Int]: