diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py
index 6dd57ad8..402f5678 100644
--- a/greenplumpython/dataframe.py
+++ b/greenplumpython/dataframe.py
@@ -1126,20 +1126,7 @@ def from_table(cls, table_name: str, db: Database, schema: Optional[str] = None)
 
         """
         qualified_name = f'"{schema}"."{table_name}"' if schema is not None else f'"{table_name}"'
-        columns_query = f"""
-            SELECT attname AS column_name, atttypid::regtype AS data_type
-            FROM pg_attribute
-            WHERE attrelid = '{qualified_name}'::regclass and attnum > 0;
-        """
-        columns_inf_result = list(db._execute(columns_query, has_results=True))  # type: ignore reportUnknownVariableType
-        assert columns_inf_result, f"Table {qualified_name} does not exists"
-        columns_list: dict[str, str] = {d["column_name"]: d["data_type"] for d in columns_inf_result}  # type: ignore reportUnknownVariableType
-        return cls(
-            f"TABLE {qualified_name}",
-            db=db,
-            qualified_table_name=qualified_name,
-            columns=columns_list,
-        )  # type: ignore reportUnknownVariableType
+        return cls(f"TABLE {qualified_name}", db=db, qualified_table_name=qualified_name)
 
     @classmethod
     def from_rows(
@@ -1277,3 +1264,36 @@ def from_files(cls, files: list[str], parser: "NormalFunction", db: Database) ->
         raise NotImplementedError(
             "Please import greenplumpython.experimental.file to load the implementation."
         )
+
+    def describe(self) -> dict[str, str]:
+        """
+        Return the column names and data types of the table behind this dataframe.
+
+        Only works for a dataframe that has a qualified table name, i.e. one that
+        refers to a table saved in the database. Dropped columns are not listed.
+
+        Returns:
+            Dictionary mapping each column name to its data type, as reported
+            by ``atttypid::regtype``.
+
+        Raises:
+            AssertionError: If the dataframe has no qualified table name or no
+                database, or if the catalog query returns no columns.
+
+        """
+        assert self._qualified_table_name is not None, "Dataframe is not saved in database."
+        # A dropped column keeps its row in pg_attribute (flagged by attisdropped),
+        # so attnum > 0 alone would still report it.
+        columns_query = f"""
+            SELECT attname AS column_name, atttypid::regtype AS data_type
+            FROM pg_attribute
+            WHERE attrelid = '{self._qualified_table_name}'::regclass
+                AND attnum > 0 AND NOT attisdropped;
+        """
+        assert self._db is not None
+        columns_inf_result = list(self._db._execute(columns_query, has_results=True))  # type: ignore reportUnknownVariableType
+        assert columns_inf_result, f"Table {self._qualified_table_name} does not exist."
+        columns_list: dict[str, str] = {
+            d["column_name"]: d["data_type"] for d in columns_inf_result  # type: ignore reportUnknownVariableType
+        }  # type: ignore reportUnknownVariableType
+        return columns_list
diff --git a/tests/test_dataframe.py b/tests/test_dataframe.py
index c7d8e8e3..d427f638 100644
--- a/tests/test_dataframe.py
+++ b/tests/test_dataframe.py
@@ -506,3 +506,20 @@ def test_const_non_ascii(db: gp.Database):
     df = db.create_dataframe(columns={"Ø": ["Ø"]})
     for row in df[["Ø"]]:
         assert row["Ø"] == "Ø"
+
+
+def test_table_describe(db: gp.Database):
+    df = db.create_dataframe(table_name="pg_class")
+    result = df.describe()
+    assert len(result) == 33
+    df_not_exist = db.create_dataframe(table_name="not_exist_table")
+    with pytest.raises(Exception) as exc_info:
+        df_not_exist.describe()
+    assert 'relation "not_exist_table" does not exist' in str(exc_info.value)
+
+
+def test_dataframe_describe(db: gp.Database):
+    df = db.create_dataframe(table_name="pg_class")[["relname", "relnamespace"]]
+    with pytest.raises(Exception) as exc_info:
+        df.describe()
+    assert "Dataframe is not saved in database" in str(exc_info.value)