From 418a90c53c179c0669226c3fa6f233a250cc5eef Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Mon, 16 Oct 2023 08:22:53 +0200 Subject: [PATCH 01/11] [code-fix] Make positional args to kwargs in suggest_int --- optuna/multi_objective/trial.py | 7 ++++++- optuna/trial/_base.py | 7 ++++++- optuna/trial/_fixed.py | 7 ++++++- optuna/trial/_frozen.py | 7 ++++++- optuna/trial/_trial.py | 7 ++++++- 5 files changed, 30 insertions(+), 5 deletions(-) diff --git a/optuna/multi_objective/trial.py b/optuna/multi_objective/trial.py index 3fc37cead61..928eab56056 100644 --- a/optuna/multi_objective/trial.py +++ b/optuna/multi_objective/trial.py @@ -8,12 +8,14 @@ from typing import Union from optuna import multi_objective +from optuna._convert_positional_args import convert_positional_args from optuna._deprecated import deprecated_class from optuna.distributions import BaseDistribution from optuna.study._study_direction import StudyDirection from optuna.trial import FrozenTrial from optuna.trial import Trial from optuna.trial import TrialState +from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS CategoricalChoiceType = Union[None, bool, int, float, str] @@ -87,7 +89,10 @@ def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) return self._trial.suggest_discrete_uniform(name, low, high, q) - def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: + @convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS) + def suggest_int( + self, name: str, low: int, high: int, *, step: int = 1, log: bool = False + ) -> int: """Suggest a value for the integer parameter. Please refer to the documentation of :func:`optuna.trial.Trial.suggest_int` diff --git a/optuna/trial/_base.py b/optuna/trial/_base.py index 84fb22e6ccb..96e2dec8e6a 100644 --- a/optuna/trial/_base.py +++ b/optuna/trial/_base.py @@ -11,6 +11,9 @@ from optuna.distributions import CategoricalChoiceType +_SUGGEST_INT_POSITIONAL_ARGS = ["self", "name", "low", "high", "step", "log"] + + class BaseTrial(abc.ABC): """Base class for trials. @@ -45,7 +48,9 @@ def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) raise NotImplementedError @abc.abstractmethod - def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: + def suggest_int( + self, name: str, low: int, high: int, *, step: int = 1, log: bool = False + ) -> int: raise NotImplementedError @overload diff --git a/optuna/trial/_fixed.py b/optuna/trial/_fixed.py index aae0539934b..5db23b0d5ea 100644 --- a/optuna/trial/_fixed.py +++ b/optuna/trial/_fixed.py @@ -7,12 +7,14 @@ import warnings from optuna import distributions +from optuna._convert_positional_args import convert_positional_args from optuna._deprecated import deprecated_func from optuna.distributions import BaseDistribution from optuna.distributions import CategoricalChoiceType from optuna.distributions import CategoricalDistribution from optuna.distributions import FloatDistribution from optuna.distributions import IntDistribution +from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS from optuna.trial._base import BaseTrial @@ -89,7 +91,10 @@ def suggest_loguniform(self, name: str, low: float, high: float) -> float: def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float: return self.suggest_float(name, low, high, step=q) - def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: + @convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS) + def suggest_int( + self, name: str, low: int, high: int, *, step: int = 1, log: bool = False + ) -> int: return int(self._suggest(name, IntDistribution(low, high, log=log, step=step))) @overload diff --git a/optuna/trial/_frozen.py b/optuna/trial/_frozen.py index 5b946e73538..309b91accfe 100644 --- a/optuna/trial/_frozen.py +++ b/optuna/trial/_frozen.py @@ -11,6 +11,7 @@ from optuna import distributions from optuna import logging +from optuna._convert_positional_args import convert_positional_args from optuna._deprecated import deprecated_func from optuna._typing import JSONSerializable from optuna.distributions import _convert_old_distribution_to_new_distribution @@ -19,6 +20,7 @@ from optuna.distributions import CategoricalDistribution from optuna.distributions import FloatDistribution from optuna.distributions import IntDistribution +from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS from optuna.trial._base import BaseTrial from optuna.trial._state import TrialState @@ -225,7 +227,10 @@ def suggest_loguniform(self, name: str, low: float, high: float) -> float: def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float: return self.suggest_float(name, low, high, step=q) - def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: + @convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS) + def suggest_int( + self, name: str, low: int, high: int, *, step: int = 1, log: bool = False + ) -> int: return int(self._suggest(name, IntDistribution(low, high, log=log, step=step))) @overload diff --git a/optuna/trial/_trial.py b/optuna/trial/_trial.py index 8d2bdacb9e6..fad8e8f267a 100644 --- a/optuna/trial/_trial.py +++ b/optuna/trial/_trial.py @@ -14,6 +14,7 @@ from optuna import distributions from optuna import logging from optuna import pruners +from optuna._convert_positional_args import convert_positional_args from optuna._deprecated import deprecated_func from optuna.distributions import BaseDistribution from optuna.distributions import CategoricalChoiceType @@ -21,6 +22,7 @@ from optuna.distributions import FloatDistribution from optuna.distributions import IntDistribution from optuna.trial import FrozenTrial +from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS from optuna.trial._base import BaseTrial @@ -235,7 +237,10 @@ def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) return self.suggest_float(name, low, high, step=q) - def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: + @convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS) + def suggest_int( + self, name: str, low: int, high: int, *, step: int = 1, log: bool = False + ) -> int: """Suggest a value for the integer parameter. The value is sampled from the integers in :math:`[\\mathsf{low}, \\mathsf{high}]`. From dbe38d6d22b1374db81ff86a4aee14c6f0641772 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Mon, 16 Oct 2023 08:37:39 +0200 Subject: [PATCH 02/11] [fix] Resolve mypy errors --- tests/trial_tests/test_trial.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trial_tests/test_trial.py b/tests/trial_tests/test_trial.py index 9a9dfa1f7fe..6069afbd6e5 100644 --- a/tests/trial_tests/test_trial.py +++ b/tests/trial_tests/test_trial.py @@ -128,11 +128,11 @@ def test_check_distribution_suggest_discrete_uniform(storage_mode: str) -> None: assert len([r for r in record if r.category != FutureWarning]) == 1 with pytest.raises(ValueError): - trial.suggest_int("x", 10, 20, 2) + trial.suggest_int("x", 10, 20, step=2) trial = Trial(study, study._storage.create_new_trial(study._study_id)) with pytest.raises(ValueError): - trial.suggest_int("x", 10, 20, 2) + trial.suggest_int("x", 10, 20, step=2) @pytest.mark.parametrize("storage_mode", STORAGE_MODES) From 0f15a0f3f626b14b6242d1fe960f256756452413 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Mon, 16 Oct 2023 08:46:04 +0200 Subject: [PATCH 03/11] [test] Add a test for the old-style suggest_int --- tests/trial_tests/test_trial.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/trial_tests/test_trial.py b/tests/trial_tests/test_trial.py index 6069afbd6e5..f065abbb71a 100644 --- a/tests/trial_tests/test_trial.py +++ b/tests/trial_tests/test_trial.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import math from typing import Any @@ -704,3 +706,14 @@ def test_lazy_trial_system_attrs(storage_mode: str) -> None: system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage) assert set(system_attrs.items()) == {("int", 0), ("str", "A")} assert set(system_attrs.items()) == {("int", 0), ("str", "A")} + + +@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]]) +def test_suggest_int_positional_args(positional_args_names: list[str]): + # If log is specified as positional, step must also be provided as positional. + study = optuna.create_study() + trial = study.ask() + kwargs = dict(step=1, log=False) + args = [kwargs[name] for name in positional_args_names] + # No error should not be raised even if the coding style is old. + trial.suggest_int("x", -1, 1, *args) From 29b63dab417a81cffa13f69660ccdd0536576de2 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Mon, 16 Oct 2023 08:55:54 +0200 Subject: [PATCH 04/11] [test] Add tests for the old-style suggest_int of fixed and frozen trials --- tests/trial_tests/test_fixed.py | 15 +++++++++++++++ tests/trial_tests/test_frozen.py | 26 ++++++++++++++++++++++++++ tests/trial_tests/test_trial.py | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/trial_tests/test_fixed.py b/tests/trial_tests/test_fixed.py index 6d0a46cca30..0c0c0c02243 100644 --- a/tests/trial_tests/test_fixed.py +++ b/tests/trial_tests/test_fixed.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import pytest + from optuna.trial import FixedTrial @@ -10,6 +14,17 @@ def test_params() -> None: assert trial.params == params +@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]]) +def test_suggest_int_positional_args(positional_args_names: list[str]) -> None: + # If log is specified as positional, step must also be provided as positional. + params = {"x": 1} + trial = FixedTrial(params) + kwargs = dict(step=1, log=False) + args = [kwargs[name] for name in positional_args_names] + # No error should not be raised even if the coding style is old. + trial.suggest_int("x", -1, 1, *args) + + def test_number() -> None: params = {"x": 1} trial = FixedTrial(params, 2) diff --git a/tests/trial_tests/test_frozen.py b/tests/trial_tests/test_frozen.py index 925c151e87c..7d479c10768 100644 --- a/tests/trial_tests/test_frozen.py +++ b/tests/trial_tests/test_frozen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import datetime from typing import Any @@ -11,6 +13,7 @@ from optuna import create_study from optuna.distributions import BaseDistribution from optuna.distributions import FloatDistribution +from optuna.distributions import IntDistribution from optuna.testing.storages import STORAGE_MODES from optuna.testing.storages import StorageSupplier import optuna.trial @@ -385,3 +388,26 @@ def test_create_trial_distribution_conversion_noop() -> None: # Check fixed_distributions doesn't change. assert trial.distributions == fixed_distributions + + +@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]]) +def test_suggest_int_positional_args(positional_args_names: list[str]) -> None: + # If log is specified as positional, step must also be provided as positional. + trial = FrozenTrial( + number=0, + trial_id=0, + state=TrialState.COMPLETE, + value=0.0, + values=None, + datetime_start=datetime.datetime.now(), + datetime_complete=datetime.datetime.now(), + params={"x": 1}, + distributions={"x": IntDistribution(-1, 1)}, + user_attrs={}, + system_attrs={}, + intermediate_values={}, + ) + kwargs = dict(step=1, log=False) + args = [kwargs[name] for name in positional_args_names] + # No error should not be raised even if the coding style is old. + trial.suggest_int("x", -1, 1, *args) diff --git a/tests/trial_tests/test_trial.py b/tests/trial_tests/test_trial.py index f065abbb71a..88de52e1a2b 100644 --- a/tests/trial_tests/test_trial.py +++ b/tests/trial_tests/test_trial.py @@ -709,7 +709,7 @@ def test_lazy_trial_system_attrs(storage_mode: str) -> None: @pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]]) -def test_suggest_int_positional_args(positional_args_names: list[str]): +def test_suggest_int_positional_args(positional_args_names: list[str]) -> None: # If log is specified as positional, step must also be provided as positional. study = optuna.create_study() trial = study.ask() From 18c4b3a331d0f01a224a91cdbaa1df788f4a11d2 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Mon, 16 Oct 2023 09:18:54 +0200 Subject: [PATCH 05/11] [test] Add a test for the old-style suggest_int of MO trial --- tests/multi_objective_tests/test_trial.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/multi_objective_tests/test_trial.py b/tests/multi_objective_tests/test_trial.py index f5ed15bddb0..17d2f90135e 100644 --- a/tests/multi_objective_tests/test_trial.py +++ b/tests/multi_objective_tests/test_trial.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from typing import List from typing import Tuple @@ -205,3 +207,13 @@ def create_trial( else: # If `t1` isn't COMPLETE, it doesn't dominate others. assert not t1._dominates(t0, directions) + + +@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]]) +def test_suggest_int_positional_args(positional_args_names: list[str]) -> None: + # If log is specified as positional, step must also be provided as positional. + study = optuna.multi_objective.create_study(["maximize"]) + kwargs = dict(step=1, log=False) + args = [kwargs[name] for name in positional_args_names] + # No error should not be raised even if the coding style is old. + study.optimize(lambda trial: [trial.suggest_int("x", -1, 1, *args)], n_trials=1) From 92a4c6a1a03b6c7857f74a1362b06c599bee3560 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Wed, 25 Oct 2023 08:34:38 +0200 Subject: [PATCH 06/11] Add the future warning for inappropriate positional args usages --- optuna/_convert_positional_args.py | 10 ++++++++++ tests/test_convert_positional_args.py | 10 +++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/optuna/_convert_positional_args.py b/optuna/_convert_positional_args.py index caf9d80bcfd..535d49bc1b3 100644 --- a/optuna/_convert_positional_args.py +++ b/optuna/_convert_positional_args.py @@ -50,6 +50,7 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": f" arguments but {len(args)} were given." ) + sig = signature(func).parameters for val, arg_name in zip(args, previous_positional_arg_names): # When specifying a positional argument that is not located at the end of args as # a keyword argument, raise TypeError as follows by imitating the Python standard @@ -58,6 +59,15 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": raise TypeError( f"{func.__name__}() got multiple values for argument '{arg_name}'." ) + + if sig[arg_name].kind == sig[arg_name].KEYWORD_ONLY: + warnings.warn( + f"{func.__name__}() takes '{arg_name}' as a keyword argument" + " but it was given as a positional argument.", + FutureWarning, + stacklevel=warning_stacklevel, + ) + kwargs[arg_name] = val return func(**kwargs) diff --git a/tests/test_convert_positional_args.py b/tests/test_convert_positional_args.py index 3a104f10d63..7c13b7f6bc5 100644 --- a/tests/test_convert_positional_args.py +++ b/tests/test_convert_positional_args.py @@ -33,11 +33,19 @@ def test_convert_positional_args_future_warning() -> None: decorated_func(1, b=2, c=3) # type: ignore decorated_func(a=1, b=2, c=3) # No warning. - assert len(record) == 2 + assert len(record) == 5 + count_give_all = 0 + count_give_kwargs = 0 for warn in record.list: + msg = warn.message.args[0] + count_give_all += ("give all" in msg) + count_give_kwargs += ("as a keyword argument" in msg) assert isinstance(warn.message, FutureWarning) assert _sample_func.__name__ in str(warn.message) + assert count_give_all == 2 + assert count_give_kwargs == 3 + def test_convert_positional_args_mypy_type_inference() -> None: previous_positional_arg_names: List[str] = [] From f02c7fe4bb32d19ce3d93f0b139870f921c57199 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Wed, 25 Oct 2023 11:07:09 +0200 Subject: [PATCH 07/11] Apply black formatting --- tests/test_convert_positional_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_convert_positional_args.py b/tests/test_convert_positional_args.py index 7c13b7f6bc5..b32ef6e26fa 100644 --- a/tests/test_convert_positional_args.py +++ b/tests/test_convert_positional_args.py @@ -38,8 +38,8 @@ def test_convert_positional_args_future_warning() -> None: count_give_kwargs = 0 for warn in record.list: msg = warn.message.args[0] - count_give_all += ("give all" in msg) - count_give_kwargs += ("as a keyword argument" in msg) + count_give_all += "give all" in msg + count_give_kwargs += "as a keyword argument" in msg assert isinstance(warn.message, FutureWarning) assert _sample_func.__name__ in str(warn.message) From 070cc6146e524da389cc86497dc27d5c4072f490 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Fri, 27 Oct 2023 08:13:33 +0200 Subject: [PATCH 08/11] Remove unnecessary warnings based on the comment by not522 --- optuna/_convert_positional_args.py | 55 +++++++++++++++++---------- tests/test_convert_positional_args.py | 12 +----- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/optuna/_convert_positional_args.py b/optuna/_convert_positional_args.py index 535d49bc1b3..d33015b7885 100644 --- a/optuna/_convert_positional_args.py +++ b/optuna/_convert_positional_args.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from functools import wraps from inspect import signature +from inspect import Parameter from typing import Any from typing import TYPE_CHECKING from typing import TypeVar @@ -17,6 +18,23 @@ _T = TypeVar("_T") +def _get_positional_arg_names(func: "Callable[_P, _T]") -> list[str]: + params = signature(func).parameters + positional_arg_names = [ + name for name, p in params.items() + if p.default == Parameter.empty and p.kind == p.POSITIONAL_OR_KEYWORD + ] + return positional_arg_names + + +def _infer_given_args(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]: + inferred_args = { + arg_name: val + for val, arg_name in zip(args, previous_positional_arg_names) + } + return inferred_args + + def convert_positional_args( *, previous_positional_arg_names: Sequence[str], @@ -37,9 +55,13 @@ def converter_decorator(func: "Callable[_P, _T]") -> "Callable[_P, _T]": @wraps(func) def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": - if len(args) >= 1: + positional_arg_names = _get_positional_arg_names(func) + inferred_args = _infer_given_args(previous_positional_arg_names, *args) + if len(inferred_args) > len(positional_arg_names): + kwargs_expected = set(inferred_args) - set(positional_arg_names) warnings.warn( - f"{func.__name__}(): Please give all values as keyword arguments." + f"{func.__name__}() got {kwargs_expected} as positional arguments " + "but they were expected to be given as keyword arguments." " See https://github.com/optuna/optuna/issues/3324 for details.", FutureWarning, stacklevel=warning_stacklevel, @@ -50,25 +72,16 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": f" arguments but {len(args)} were given." ) - sig = signature(func).parameters - for val, arg_name in zip(args, previous_positional_arg_names): - # When specifying a positional argument that is not located at the end of args as - # a keyword argument, raise TypeError as follows by imitating the Python standard - # behavior. - if arg_name in kwargs: - raise TypeError( - f"{func.__name__}() got multiple values for argument '{arg_name}'." - ) - - if sig[arg_name].kind == sig[arg_name].KEYWORD_ONLY: - warnings.warn( - f"{func.__name__}() takes '{arg_name}' as a keyword argument" - " but it was given as a positional argument.", - FutureWarning, - stacklevel=warning_stacklevel, - ) - - kwargs[arg_name] = val + duplicated_arg_names = set(kwargs).intersection(inferred_args) + if len(duplicated_arg_names): + # When specifying positional arguments that are not located at the end of args as + # keyword arguments, raise TypeError as follows by imitating the Python standard + # behavior + raise TypeError( + f"{func.__name__}() got multiple values for arguments {duplicated_arg_names}." + ) + + kwargs.update(inferred_args) return func(**kwargs) diff --git a/tests/test_convert_positional_args.py b/tests/test_convert_positional_args.py index b32ef6e26fa..ab29a7094bf 100644 --- a/tests/test_convert_positional_args.py +++ b/tests/test_convert_positional_args.py @@ -33,19 +33,11 @@ def test_convert_positional_args_future_warning() -> None: decorated_func(1, b=2, c=3) # type: ignore decorated_func(a=1, b=2, c=3) # No warning. - assert len(record) == 5 - count_give_all = 0 - count_give_kwargs = 0 + assert len(record) == 2 for warn in record.list: - msg = warn.message.args[0] - count_give_all += "give all" in msg - count_give_kwargs += "as a keyword argument" in msg assert isinstance(warn.message, FutureWarning) assert _sample_func.__name__ in str(warn.message) - assert count_give_all == 2 - assert count_give_kwargs == 3 - def test_convert_positional_args_mypy_type_inference() -> None: previous_positional_arg_names: List[str] = [] @@ -113,4 +105,4 @@ def test_convert_positional_args_invalid_positional_args() -> None: with pytest.raises(TypeError) as record: decorated_func(1, 3, b=2) # type: ignore - assert str(record.value) == "_sample_func() got multiple values for argument 'b'." + assert str(record.value) == "_sample_func() got multiple values for arguments {'b'}." From 7009e0ffe6997eade395f36fe4c1564eb52bbca2 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Fri, 27 Oct 2023 08:39:25 +0200 Subject: [PATCH 09/11] Add a test using a simple method --- optuna/_convert_positional_args.py | 10 ++++------ tests/test_convert_positional_args.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/optuna/_convert_positional_args.py b/optuna/_convert_positional_args.py index d33015b7885..03cec74c5d7 100644 --- a/optuna/_convert_positional_args.py +++ b/optuna/_convert_positional_args.py @@ -3,8 +3,8 @@ from collections.abc import Callable from collections.abc import Sequence from functools import wraps -from inspect import signature from inspect import Parameter +from inspect import signature from typing import Any from typing import TYPE_CHECKING from typing import TypeVar @@ -21,17 +21,15 @@ def _get_positional_arg_names(func: "Callable[_P, _T]") -> list[str]: params = signature(func).parameters positional_arg_names = [ - name for name, p in params.items() + name + for name, p in params.items() if p.default == Parameter.empty and p.kind == p.POSITIONAL_OR_KEYWORD ] return positional_arg_names def _infer_given_args(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]: - inferred_args = { - arg_name: val - for val, arg_name in zip(args, previous_positional_arg_names) - } + inferred_args = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)} return inferred_args diff --git a/tests/test_convert_positional_args.py b/tests/test_convert_positional_args.py index ab29a7094bf..442cce22b0c 100644 --- a/tests/test_convert_positional_args.py +++ b/tests/test_convert_positional_args.py @@ -10,6 +10,12 @@ def _sample_func(*, a: int, b: int, c: int) -> int: return a + b + c +class _SimpleClass: + @convert_positional_args(previous_positional_arg_names=["self", "a", "b"]) + def simple_method(self, a: int, *, b: int, c: int = 1) -> None: + pass + + def test_convert_positional_args_decorator() -> None: previous_positional_arg_names: List[str] = [] decorator_converter = convert_positional_args( @@ -20,6 +26,19 @@ def test_convert_positional_args_decorator() -> None: assert decorated_func.__name__ == _sample_func.__name__ +def test_convert_positional_args_future_warning_for_methods() -> None: + simple_class = _SimpleClass() + with pytest.warns(FutureWarning) as record: + simple_class.simple_method(1, 2, c=3) # type: ignore + simple_class.simple_method(1, b=2, c=3) # No warning. + simple_class.simple_method(a=1, b=2, c=3) # No warning. + + assert len(record) == 1 + for warn in record.list: + assert isinstance(warn.message, FutureWarning) + assert "simple_method" in str(warn.message) + + def test_convert_positional_args_future_warning() -> None: previous_positional_arg_names: List[str] = ["a", "b"] decorator_converter = convert_positional_args( From e763c2c7f3bff3bb3c4a47cac736bfbf001a535a Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Tue, 31 Oct 2023 06:10:25 +0100 Subject: [PATCH 10/11] Remove ref to 3324 as it is not related anymore --- optuna/_convert_positional_args.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optuna/_convert_positional_args.py b/optuna/_convert_positional_args.py index 03cec74c5d7..c38e3aef9d6 100644 --- a/optuna/_convert_positional_args.py +++ b/optuna/_convert_positional_args.py @@ -59,8 +59,7 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": kwargs_expected = set(inferred_args) - set(positional_arg_names) warnings.warn( f"{func.__name__}() got {kwargs_expected} as positional arguments " - "but they were expected to be given as keyword arguments." - " See https://github.com/optuna/optuna/issues/3324 for details.", + "but they were expected to be given as keyword arguments.", FutureWarning, stacklevel=warning_stacklevel, ) From 6a44734458831d757c4b1aa1d8718044a5911fa6 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Tue, 31 Oct 2023 06:15:36 +0100 Subject: [PATCH 11/11] Change variable names for readability --- optuna/_convert_positional_args.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/optuna/_convert_positional_args.py b/optuna/_convert_positional_args.py index c38e3aef9d6..d496a288526 100644 --- a/optuna/_convert_positional_args.py +++ b/optuna/_convert_positional_args.py @@ -28,9 +28,9 @@ def _get_positional_arg_names(func: "Callable[_P, _T]") -> list[str]: return positional_arg_names -def _infer_given_args(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]: - inferred_args = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)} - return inferred_args +def _infer_kwargs(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]: + inferred_kwargs = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)} + return inferred_kwargs def convert_positional_args( @@ -54,11 +54,11 @@ def converter_decorator(func: "Callable[_P, _T]") -> "Callable[_P, _T]": @wraps(func) def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": positional_arg_names = _get_positional_arg_names(func) - inferred_args = _infer_given_args(previous_positional_arg_names, *args) - if len(inferred_args) > len(positional_arg_names): - kwargs_expected = set(inferred_args) - set(positional_arg_names) + inferred_kwargs = _infer_kwargs(previous_positional_arg_names, *args) + if len(inferred_kwargs) > len(positional_arg_names): + expected_kwds = set(inferred_kwargs) - set(positional_arg_names) warnings.warn( - f"{func.__name__}() got {kwargs_expected} as positional arguments " + f"{func.__name__}() got {expected_kwds} as positional arguments " "but they were expected to be given as keyword arguments.", FutureWarning, stacklevel=warning_stacklevel, @@ -69,16 +69,16 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": f" arguments but {len(args)} were given." ) - duplicated_arg_names = set(kwargs).intersection(inferred_args) - if len(duplicated_arg_names): + duplicated_kwds = set(kwargs).intersection(inferred_kwargs) + if len(duplicated_kwds): # When specifying positional arguments that are not located at the end of args as # keyword arguments, raise TypeError as follows by imitating the Python standard # behavior raise TypeError( - f"{func.__name__}() got multiple values for arguments {duplicated_arg_names}." + f"{func.__name__}() got multiple values for arguments {duplicated_kwds}." ) - kwargs.update(inferred_args) + kwargs.update(inferred_kwargs) return func(**kwargs)