From 07f3dcc3905a25fc05acf656dd532fcad135d56c Mon Sep 17 00:00:00 2001 From: aisha-partha <153170327+aisha-partha@users.noreply.github.com> Date: Wed, 22 May 2024 00:03:15 +0530 Subject: [PATCH 1/3] Replaced import List, import Tuple with __future__.annotations in storages/_rdb/models.py --- optuna/storages/_rdb/models.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/optuna/storages/_rdb/models.py b/optuna/storages/_rdb/models.py index c9320b4f889..524f34b7931 100644 --- a/optuna/storages/_rdb/models.py +++ b/optuna/storages/_rdb/models.py @@ -1,9 +1,9 @@ +from __future__ import annotations + import enum import math from typing import Any -from typing import List from typing import Optional -from typing import Tuple from sqlalchemy import asc from sqlalchemy import case @@ -104,7 +104,7 @@ class StudyDirectionModel(BaseModel): ) @classmethod - def where_study_id(cls, study_id: int, session: orm.Session) -> List["StudyDirectionModel"]: + def where_study_id(cls, study_id: int, session: orm.Session) -> list["StudyDirectionModel"]: return session.query(cls).filter(cls.study_id == study_id).all() @@ -136,7 +136,7 @@ def find_by_study_and_key( @classmethod def where_study_id( cls, study_id: int, session: orm.Session - ) -> List["StudyUserAttributeModel"]: + ) -> list["StudyUserAttributeModel"]: return session.query(cls).filter(cls.study_id == study_id).all() @@ -168,7 +168,7 @@ def find_by_study_and_key( @classmethod def where_study_id( cls, study_id: int, session: orm.Session - ) -> List["StudySystemAttributeModel"]: + ) -> list["StudySystemAttributeModel"]: return session.query(cls).filter(cls.study_id == study_id).all() @@ -307,7 +307,7 @@ def find_by_trial_and_key( @classmethod def where_trial_id( cls, trial_id: int, session: orm.Session - ) -> List["TrialUserAttributeModel"]: + ) -> list["TrialUserAttributeModel"]: return session.query(cls).filter(cls.trial_id == trial_id).all() @@ -339,7 +339,7 @@ def find_by_trial_and_key( @classmethod def where_trial_id( cls, trial_id: int, session: orm.Session - ) -> List["TrialSystemAttributeModel"]: + ) -> list["TrialSystemAttributeModel"]: return session.query(cls).filter(cls.trial_id == trial_id).all() @@ -403,7 +403,7 @@ def find_or_raise_by_trial_and_param_name( return param_distribution @classmethod - def where_trial_id(cls, trial_id: int, session: orm.Session) -> List["TrialParamModel"]: + def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialParamModel"]: trial_params = session.query(cls).filter(cls.trial_id == trial_id).all() return trial_params @@ -431,7 +431,7 @@ class TrialValueType(enum.Enum): def value_to_stored_repr( cls, value: float, - ) -> Tuple[Optional[float], TrialValueType]: + ) -> tuple[Optional[float], TrialValueType]: if value == float("inf"): return (None, cls.TrialValueType.INF_POS) elif value == float("-inf"): @@ -466,7 +466,7 @@ def find_by_trial_and_objective( return trial_value @classmethod - def where_trial_id(cls, trial_id: int, session: orm.Session) -> List["TrialValueModel"]: + def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialValueModel"]: trial_values = ( session.query(cls).filter(cls.trial_id == trial_id).order_by(asc(cls.objective)).all() ) @@ -497,7 +497,7 @@ class TrialIntermediateValueType(enum.Enum): def intermediate_value_to_stored_repr( cls, value: float, - ) -> Tuple[Optional[float], TrialIntermediateValueType]: + ) -> tuple[Optional[float], TrialIntermediateValueType]: if math.isnan(value): return (None, cls.TrialIntermediateValueType.NAN) elif value == float("inf"): @@ -541,7 +541,7 @@ def find_by_trial_and_step( @classmethod def where_trial_id( cls, trial_id: int, session: orm.Session - ) -> List["TrialIntermediateValueModel"]: + ) -> list["TrialIntermediateValueModel"]: trial_intermediate_values = session.query(cls).filter(cls.trial_id == trial_id).all() return trial_intermediate_values From c1e136b80a04fbb145a41c34f1f7109b810bb36b Mon Sep 17 00:00:00 2001 From: aisha-partha <153170327+aisha-partha@users.noreply.github.com> Date: Wed, 22 May 2024 00:20:14 +0530 Subject: [PATCH 2/3] Replaced import Optional with __future__.annotations in storages/_rdb/models.py --- optuna/storages/_rdb/models.py | 35 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/optuna/storages/_rdb/models.py b/optuna/storages/_rdb/models.py index 524f34b7931..41b9f1c138f 100644 --- a/optuna/storages/_rdb/models.py +++ b/optuna/storages/_rdb/models.py @@ -3,7 +3,6 @@ import enum import math from typing import Any -from typing import Optional from sqlalchemy import asc from sqlalchemy import case @@ -77,7 +76,7 @@ def find_or_raise_by_id( return study @classmethod - def find_by_name(cls, study_name: str, session: orm.Session) -> Optional["StudyModel"]: + def find_by_name(cls, study_name: str, session: orm.Session) -> "StudyModel" | None: study = session.query(cls).filter(cls.study_name == study_name).one_or_none() return study @@ -123,7 +122,7 @@ class StudyUserAttributeModel(BaseModel): @classmethod def find_by_study_and_key( cls, study: StudyModel, key: str, session: orm.Session - ) -> Optional["StudyUserAttributeModel"]: + ) -> "StudyUserAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.study_id == study.study_id) @@ -155,7 +154,7 @@ class StudySystemAttributeModel(BaseModel): @classmethod def find_by_study_and_key( cls, study: StudyModel, key: str, session: orm.Session - ) -> Optional["StudySystemAttributeModel"]: + ) -> "StudySystemAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.study_id == study.study_id) @@ -261,8 +260,8 @@ def find_or_raise_by_id( def count( cls, session: orm.Session, - study: Optional[StudyModel] = None, - state: Optional[TrialState] = None, + study: StudyModel | None = None, + state: TrialState | None = None, ) -> int: trial_count = session.query(func.count(cls.trial_id)) if study is not None: @@ -294,7 +293,7 @@ class TrialUserAttributeModel(BaseModel): @classmethod def find_by_trial_and_key( cls, trial: TrialModel, key: str, session: orm.Session - ) -> Optional["TrialUserAttributeModel"]: + ) -> "TrialUserAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -326,7 +325,7 @@ class TrialSystemAttributeModel(BaseModel): @classmethod def find_by_trial_and_key( cls, trial: TrialModel, key: str, session: orm.Session - ) -> Optional["TrialSystemAttributeModel"]: + ) -> "TrialSystemAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -381,7 +380,7 @@ def _check_compatibility_with_previous_trial_param_distributions( @classmethod def find_by_trial_and_param_name( cls, trial: TrialModel, param_name: str, session: orm.Session - ) -> Optional["TrialParamModel"]: + ) -> "TrialParamModel" | None: param_distribution = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -431,7 +430,7 @@ class TrialValueType(enum.Enum): def value_to_stored_repr( cls, value: float, - ) -> tuple[Optional[float], TrialValueType]: + ) -> tuple[float | None, TrialValueType]: if value == float("inf"): return (None, cls.TrialValueType.INF_POS) elif value == float("-inf"): @@ -440,7 +439,7 @@ def value_to_stored_repr( return (value, cls.TrialValueType.FINITE) @classmethod - def stored_repr_to_value(cls, value: Optional[float], float_type: TrialValueType) -> float: + def stored_repr_to_value(cls, value: float | None, float_type: TrialValueType) -> float: if float_type == cls.TrialValueType.INF_POS: assert value is None return float("inf") @@ -455,7 +454,7 @@ def stored_repr_to_value(cls, value: Optional[float], float_type: TrialValueType @classmethod def find_by_trial_and_objective( cls, trial: TrialModel, objective: int, session: orm.Session - ) -> Optional["TrialValueModel"]: + ) -> "TrialValueModel" | None: trial_value = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -497,7 +496,7 @@ class TrialIntermediateValueType(enum.Enum): def intermediate_value_to_stored_repr( cls, value: float, - ) -> tuple[Optional[float], TrialIntermediateValueType]: + ) -> tuple[float | None, TrialIntermediateValueType]: if math.isnan(value): return (None, cls.TrialIntermediateValueType.NAN) elif value == float("inf"): @@ -509,7 +508,7 @@ def intermediate_value_to_stored_repr( @classmethod def stored_repr_to_intermediate_value( - cls, value: Optional[float], float_type: TrialIntermediateValueType + cls, value: float | None, float_type: TrialIntermediateValueType ) -> float: if float_type == cls.TrialIntermediateValueType.NAN: assert value is None @@ -528,7 +527,7 @@ def stored_repr_to_intermediate_value( @classmethod def find_by_trial_and_step( cls, trial: TrialModel, step: int, session: orm.Session - ) -> Optional["TrialIntermediateValueModel"]: + ) -> "TrialIntermediateValueModel" | None: trial_intermediate_value = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -559,9 +558,7 @@ class TrialHeartbeatModel(BaseModel): ) @classmethod - def where_trial_id( - cls, trial_id: int, session: orm.Session - ) -> Optional["TrialHeartbeatModel"]: + def where_trial_id(cls, trial_id: int, session: orm.Session) -> "TrialHeartbeatModel" | None: return session.query(cls).filter(cls.trial_id == trial_id).one_or_none() @@ -574,6 +571,6 @@ class VersionInfoModel(BaseModel): library_version = _Column(String(MAX_VERSION_LENGTH)) @classmethod - def find(cls, session: orm.Session) -> Optional["VersionInfoModel"]: + def find(cls, session: orm.Session) -> "VersionInfoModel" | None: version_info = session.query(cls).one_or_none() return version_info From 9b073d63bc8f367097d636d257ad94c12796fc3b Mon Sep 17 00:00:00 2001 From: aisha-partha <153170327+aisha-partha@users.noreply.github.com> Date: Wed, 22 May 2024 18:57:12 +0530 Subject: [PATCH 3/3] Source code formatting change in optuna/storages/_rdb/models.py --- optuna/storages/_rdb/models.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/optuna/storages/_rdb/models.py b/optuna/storages/_rdb/models.py index 41b9f1c138f..ef435e9ba7b 100644 --- a/optuna/storages/_rdb/models.py +++ b/optuna/storages/_rdb/models.py @@ -258,10 +258,7 @@ def find_or_raise_by_id( @classmethod def count( - cls, - session: orm.Session, - study: StudyModel | None = None, - state: TrialState | None = None, + cls, session: orm.Session, study: StudyModel | None = None, state: TrialState | None = None ) -> int: trial_count = session.query(func.count(cls.trial_id)) if study is not None: @@ -427,10 +424,7 @@ class TrialValueType(enum.Enum): ) @classmethod - def value_to_stored_repr( - cls, - value: float, - ) -> tuple[float | None, TrialValueType]: + def value_to_stored_repr(cls, value: float) -> tuple[float | None, TrialValueType]: if value == float("inf"): return (None, cls.TrialValueType.INF_POS) elif value == float("-inf"): @@ -494,8 +488,7 @@ class TrialIntermediateValueType(enum.Enum): @classmethod def intermediate_value_to_stored_repr( - cls, - value: float, + cls, value: float ) -> tuple[float | None, TrialIntermediateValueType]: if math.isnan(value): return (None, cls.TrialIntermediateValueType.NAN)