Skip to content

Commit

Permalink
Issue/ml model (#101)
Browse files Browse the repository at this point in the history
* override site model name if site.ml_model is not assigned

* update

* fix
  • Loading branch information
peterdudfield authored Nov 26, 2024
1 parent b79c344 commit c1f831d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"cryptography >= 42.0.7",
"fastapi >= 0.105.0",
"pvsite-datamodel >= 1.0.41",
"pvsite-datamodel >= 1.0.45",
"pyjwt >= 2.8.0",
"pyproj >= 3.3.0",
"pytz >= 2023.3",
Expand Down
11 changes: 11 additions & 0 deletions src/india_api/internal/inputs/indiadb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_pv_generation_by_sites,
get_user_by_email,
get_sites_from_user,
get_site_by_uuid,
)
from pvsite_datamodel.write.generation import insert_generation_values
from pvsite_datamodel.sqlmodels import SiteAssetType, ForecastValueSQL
Expand Down Expand Up @@ -103,6 +104,10 @@ def get_predicted_power_production_for_location(

site = sites[0]

if site.ml_model is not None:
ml_model_name = site.ml_model.name
log.info(f"Using ml model {ml_model_name}")

# read actual generations
values = get_latest_forecast_values_by_site(
session,
Expand Down Expand Up @@ -286,6 +291,12 @@ def get_site_forecast(self, site_uuid: str, email: str) -> list[internal.Predict
with self._get_session() as session:
check_user_has_access_to_site(session=session, email=email, site_uuid=site_uuid)

# get site and the get the ml model name
site = get_site_by_uuid(session=session, site_uuid=site_uuid)
if site.ml_model is not None:
ml_model_name = site.ml_model.name
log.info(f"Using ml model {ml_model_name}")

if isinstance(site_uuid, str):
site_uuid = UUID(site_uuid)

Expand Down
1 change: 1 addition & 0 deletions src/india_api/internal/inputs/indiadb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def make_fake_forecast_values(db_session, sites, model_name):
horizon_minutes=horizon,
)
forecast_value.ml_model = ml_model
site.ml_model = ml_model

forecast_values.append(forecast_value)

Expand Down

0 comments on commit c1f831d

Please sign in to comment.