
Commit

wip
sfc-gh-cnivera committed Oct 31, 2024
1 parent 8dac1b6 commit 4787e0c
Showing 4 changed files with 90 additions and 62 deletions.
43 changes: 22 additions & 21 deletions environment.yml
@@ -1,24 +1,25 @@
name: sf_env
channels:
-- snowflake
+- snowflake
dependencies:
-- python=3.10.*
-- pandas=2.2.2
-- tqdm=4.66.5
-- streamlit=1.35.0
-- loguru=0.5.3
-- protobuf=3.20.3
-- pydantic=2.8.2
-- pyyaml=6.0.1
-- ruamel.yaml=0.17.21
-- pyarrow=14.0.2
-- sqlglot=25.10.0
-- numpy=1.26.4
-- python-dotenv=0.21.0
-- urllib3=2.2.2
-- requests=2.32.3
-- types-pyyaml=6.0.12.12
-- types-protobuf=4.25.0.20240417
-- snowflake-snowpark-python=1.18.0
-- streamlit-extras=0.4.0
-- cattrs=23.1.2
+- python=3.10.*
+- pandas=2.2.2
+- tqdm=4.66.5
+- joblib=1.4.2
+- streamlit=1.35.0
+- loguru=0.5.3
+- protobuf=3.20.3
+- pydantic=2.8.2
+- pyyaml=6.0.1
+- ruamel.yaml=0.17.21
+- pyarrow=14.0.2
+- sqlglot=25.10.0
+- numpy=1.26.4
+- python-dotenv=0.21.0
+- urllib3=2.2.2
+- requests=2.32.3
+- types-pyyaml=6.0.12.12
+- types-protobuf=4.25.0.20240417
+- snowflake-snowpark-python=1.18.0
+- streamlit-extras=0.4.0
+- cattrs=23.1.2
4 changes: 3 additions & 1 deletion journeys/iteration.py
@@ -1,4 +1,5 @@
-from streamlit import config
+from streamlit import config
+
# Set minCachedMessageSize to 500 MB to disable forward message cache:
# st.set_config would trigger an error, only the set_config from config module works
config.set_option("global.minCachedMessageSize", 500 * 1e6)
@@ -663,6 +664,7 @@ def show() -> None:
if "last_saved_yaml" not in st.session_state:
st.session_state["last_saved_yaml"] = yaml

+st.write(st.session_state["logs"])
left, right = st.columns(2)
yaml_container = left.container(height=760)
chat_container = right.container(height=760)
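The new st.write call dumps the raw logs list straight onto the page. A minimal sketch of how those per-table log lines could be rendered in a collapsible section instead (hypothetical helper, not part of this commit; assumes st.session_state["logs"] holds a list of strings):

import streamlit as st

def render_generation_logs() -> None:
    # Hypothetical helper: show the log lines collected during model generation
    # inside an expander instead of writing the raw list to the page.
    logs = st.session_state.get("logs", [])
    if not logs:
        return
    with st.expander(f"Generation logs ({len(logs)} tables)"):
        for line in logs:
            st.text(line)
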
52 changes: 32 additions & 20 deletions semantic_model_generator/generate_model.py
@@ -1,11 +1,15 @@
import concurrent.futures
+import multiprocessing
+import threading

+import streamlit as st
import os
import time
from datetime import datetime
-from typing import List, Optional
+from typing import List, Optional, Tuple

+from joblib import Parallel, delayed
from loguru import logger
-from snowflake.connector import SnowflakeConnection
+from snowflake.connector import SnowflakeConnection, connect

from semantic_model_generator.data_processing import data_types, proto_utils
from semantic_model_generator.protos import semantic_model_pb2
@@ -164,7 +168,12 @@ def _raw_table_to_semantic_context_table(

def process_table(
table: str, conn: SnowflakeConnection, n_sample_values: int
-) -> semantic_model_pb2.Table:
+) -> Tuple[semantic_model_pb2.Table, str]:
+start_time = time.time()
+start_time_formatted = datetime.fromtimestamp(start_time).strftime(
+"%Y-%m-%d %H:%M:%S"
+)
+
fqn_table = create_fqn_table(table)
valid_schemas_tables_columns_df = get_valid_schemas_tables_columns_df(
conn=conn,
@@ -178,18 +187,26 @@ def process_table(
valid_schemas_tables_columns_df["TABLE_NAME"] == fqn_table.table
]

-raw_table = get_table_representation(
+raw_table, logs = get_table_representation(
conn=conn,
schema_name=fqn_table.database + "." + fqn_table.schema_name,
table_name=fqn_table.table,
table_index=0,
ndv_per_column=n_sample_values,
columns_df=valid_columns_df_this_table,
)
-return _raw_table_to_semantic_context_table(
-database=fqn_table.database,
-schema=fqn_table.schema_name,
-raw_table=raw_table,
+
+end_time = time.time()
+end_time_formatted = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
+
+return (
+_raw_table_to_semantic_context_table(
+database=fqn_table.database,
+schema=fqn_table.schema_name,
+raw_table=raw_table,
+),
+f"Process ID: {threading.current_thread().name}, Finish processing table: {table}, StartTime: {start_time_formatted}, EndTime: {end_time_formatted}"
++ "\n".join(logs),
)


@@ -201,30 +218,25 @@ def raw_schema_to_semantic_context(
allow_joins: Optional[bool] = False,
) -> semantic_model_pb2.SemanticModel:
start_time = time.time()
-table_objects = []
+st.session_state["logs"] = []

# Create a Table object representation for each provided table name.
# This is done concurrently because `process_table` is I/O bound, executing potentially long-running
# queries to fetch column metadata and sample values.
-with concurrent.futures.ThreadPoolExecutor() as executor:
-table_futures = [
-executor.submit(process_table, table, conn, n_sample_values)
-for table in base_tables
-]
-concurrent.futures.wait(table_futures)
-for future in table_futures:
-table_object = future.result()
-table_objects.append(table_object)
+table_objects = Parallel(n_jobs=-1, backend="threading")(
+delayed(process_table)(table, conn, n_sample_values) for table in base_tables
+)

placeholder_relationships = _get_placeholder_joins() if allow_joins else None
context = semantic_model_pb2.SemanticModel(
name=semantic_model_name,
-tables=table_objects,
+tables=[obj[0] for obj in table_objects],
relationships=placeholder_relationships,
)
end_time = time.time()
elapsed_time = end_time - start_time
logger.info(f"Time taken to generate semantic model: {elapsed_time} seconds.")
+st.session_state["logs"] = [obj[1] for obj in table_objects]
return context


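This file swaps the ThreadPoolExecutor fan-out for joblib's threading backend, and each process_table call now returns a (table, log) pair instead of a bare table. A standalone sketch of that pattern, assuming I/O-bound work; fetch_table and the table names below are illustrative, not from the repo:

import threading
import time
from typing import List, Tuple

from joblib import Parallel, delayed


def fetch_table(table: str) -> Tuple[str, str]:
    # Illustrative stand-in for process_table: do some I/O-bound work and
    # return the result together with a timing log line.
    start = time.time()
    time.sleep(0.1)  # pretend to wait on a metadata query
    log = (
        f"Thread: {threading.current_thread().name}, table: {table}, "
        f"took {time.time() - start:.2f}s"
    )
    return f"representation of {table}", log


tables: List[str] = ["db.schema.orders", "db.schema.customers"]

# backend="threading" keeps all workers in one process, which suits work that
# mostly waits on queries and lets the workers share a single connection object.
pairs = Parallel(n_jobs=-1, backend="threading")(
    delayed(fetch_table)(t) for t in tables
)
table_objects = [table for table, _ in pairs]
logs = [log for _, log in pairs]

Like the executor version it replaces, this keeps the per-table work concurrent; the difference is that results come back already ordered, with the log strings carried alongside.
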
53 changes: 33 additions & 20 deletions semantic_model_generator/snowflake_utils/snowflake_connector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import concurrent.futures
+import multiprocessing
+import threading
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, TypeVar
+from typing import Any, Dict, Generator, List, Optional, TypeVar, Tuple

import pandas as pd
+from joblib import Parallel, delayed
from loguru import logger
from snowflake.connector import DictCursor
from snowflake.connector.connection import SnowflakeConnection
@@ -146,37 +151,45 @@ def get_table_representation(
table_index: int,
ndv_per_column: int,
columns_df: pd.DataFrame,
-) -> Table:
-table_comment = _get_table_comment(conn, schema_name, table_name, columns_df)
+) -> Tuple[Table, list[str]]:
+table_comment = (
+"test2 "  # _get_table_comment(conn, schema_name, table_name, columns_df)
+)

-def _get_col(col_index: int, column_row: pd.Series) -> Column:
-return _get_column_representation(
+def _get_col(col_index: int, column_row: pd.Series) -> Tuple[Column, str]:
+start_time = time.time()
+start_time_formatted = time.strftime(
+"%Y-%m-%d %H:%M:%S", time.localtime(start_time)
+)
+repy = _get_column_representation(
conn=conn,
schema_name=schema_name,
table_name=table_name,
column_row=column_row,
column_index=col_index,
ndv=ndv_per_column,
)
+end_time = time.time()
+end_time_formatted = time.strftime(
+"%Y-%m-%d %H:%M:%S", time.localtime(end_time)
+)
+return (
+repy,
+f"Process ID: {threading.current_thread().name}, Finish processing column: {table_name}.{column_row[_COLUMN_NAME_COL]}, StartTime: {start_time_formatted}, EndTime: {end_time_formatted}",
+)

-with concurrent.futures.ThreadPoolExecutor() as executor:
-future_to_col_index = {
-executor.submit(_get_col, col_index, column_row): col_index
-for col_index, (_, column_row) in enumerate(columns_df.iterrows())
-}
-index_and_column = []
-for future in concurrent.futures.as_completed(future_to_col_index):
-col_index = future_to_col_index[future]
-column = future.result()
-index_and_column.append((col_index, column))
-columns = [c for _, c in sorted(index_and_column, key=lambda x: x[0])]
+# Run _get_table_comment and _get_col in parallel
+columns = Parallel(n_jobs=-1, backend="threading")(
+delayed(_get_col)(col_index, column_row)
+for col_index, (_, column_row) in enumerate(columns_df.iterrows())
+)

return Table(
id_=table_index,
name=table_name,
comment=table_comment,
-columns=columns,
-)
+columns=[col[0] for col in columns],
+), [col[1] for col in columns]


def _get_column_representation(
@@ -190,6 +203,8 @@ def _get_column_representation(
column_name = column_row[_COLUMN_NAME_COL]
column_datatype = column_row[_DATATYPE_COL]
column_values = None
column_comment = "test" # _get_column_comment(conn, column_row, column_values)

if ndv > 0:
# Pull sample values.
try:
@@ -215,8 +230,6 @@ def _get_column_representation(
except Exception as e:
logger.error(f"unable to get values: {e}")

-column_comment = _get_column_comment(conn, column_row, column_values)

column = Column(
id_=column_index,
column_name=column_name,
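The same joblib pattern is applied per column here, replacing the as_completed loop that had to re-sort results by column index: Parallel returns results in task-submission order, so the list comprehension preserves the original column order on its own. A small sketch of that ordering behaviour over a DataFrame; describe_column and the sample frame are illustrative stand-ins, not the repo's _get_column_representation:

import pandas as pd
from joblib import Parallel, delayed


def describe_column(index: int, row: pd.Series) -> tuple:
    # Illustrative stand-in for _get_col: return a "column" plus a log line.
    name = row["COLUMN_NAME"]
    return f"column {index}: {name}", f"processed {name}"


columns_df = pd.DataFrame(
    {
        "COLUMN_NAME": ["ID", "CREATED_AT", "AMOUNT"],
        "DATA_TYPE": ["NUMBER", "TIMESTAMP_NTZ", "NUMBER"],
    }
)

# Results come back in the order tasks were submitted, so no explicit
# re-sorting by column index is needed to keep the DataFrame's ordering.
pairs = Parallel(n_jobs=-1, backend="threading")(
    delayed(describe_column)(i, row)
    for i, (_, row) in enumerate(columns_df.iterrows())
)
columns = [col for col, _ in pairs]
logs = [log for _, log in pairs]
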
