Skip to content

Commit

Permalink
Merge pull request optuna#5731 from guisp03/fix/future-annotations-te…
Browse files Browse the repository at this point in the history
…st_tree.py

Use `__future__.annotations` in `tests/importance_tests/fanova_tests/test_tree.py`
  • Loading branch information
not522 authored Oct 29, 2024
2 parents c014b9d + e07a7ce commit ef79c31
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions tests/importance_tests/fanova_tests/test_tree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
from typing import Dict
from typing import List
from typing import Tuple
from unittest.mock import Mock

import numpy as np
Expand All @@ -28,7 +27,7 @@ def tree() -> _FanovaTree:


@pytest.fixture
def expected_tree_statistics() -> List[Dict[str, List]]:
def expected_tree_statistics() -> list[dict[str, list]]:
# Statistics the each node in the tree.
return [
{"values": [0.1, 0.2, 0.5], "weights": [0.75, 0.25, 1.0]},
Expand All @@ -39,7 +38,7 @@ def expected_tree_statistics() -> List[Dict[str, List]]:
]


def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[str, List]]) -> None:
def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: list[dict[str, list]]) -> None:
# The root node at node index `0` holds the values and weights for all nodes in the tree.
expected_statistics = expected_tree_statistics[0]
expected_values = expected_statistics["values"]
Expand Down Expand Up @@ -87,9 +86,9 @@ def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[st
)
def test_tree_get_marginal_variance(
tree: _FanovaTree,
features: List[int],
expected: List[Tuple[List[Size], List[Tuple[NodeIndex, Cardinality]]]],
expected_tree_statistics: List[Dict[str, List]],
features: list[int],
expected: list[tuple[list[Size], list[tuple[NodeIndex, Cardinality]]]],
expected_tree_statistics: list[dict[str, list]],
) -> None:
variance = tree.get_marginal_variance(np.array(features))

Expand Down Expand Up @@ -145,9 +144,9 @@ def test_tree_get_marginal_variance(
)
def test_tree_get_marginalized_statistics(
tree: _FanovaTree,
feature_vector: List[float],
expected: List[Tuple[NodeIndex, Cardinality]],
expected_tree_statistics: List[Dict[str, List]],
feature_vector: list[float],
expected: list[tuple[NodeIndex, Cardinality]],
expected_tree_statistics: list[dict[str, list]],
) -> None:
value, weight = tree._get_marginalized_statistics(np.array(feature_vector))

Expand All @@ -167,7 +166,7 @@ def test_tree_get_marginalized_statistics(


def test_tree_statistics(
tree: _FanovaTree, expected_tree_statistics: List[Dict[str, List]]
tree: _FanovaTree, expected_tree_statistics: list[dict[str, list]]
) -> None:
statistics = tree._statistics

Expand All @@ -184,13 +183,13 @@ def test_tree_statistics(

@pytest.mark.parametrize("node_index,expected", [(0, [0.5]), (1, [0.25, 0.75]), (2, [0.75, 1.75])])
def test_tree_split_midpoints(
tree: _FanovaTree, node_index: NodeIndex, expected: List[float]
tree: _FanovaTree, node_index: NodeIndex, expected: list[float]
) -> None:
np.testing.assert_equal(tree._split_midpoints[node_index], expected)


@pytest.mark.parametrize("node_index,expected", [(0, [1.0]), (1, [0.5, 0.5]), (2, [1.5, 0.5])])
def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: List[float]) -> None:
def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: list[float]) -> None:
np.testing.assert_equal(tree._split_sizes[node_index], expected)


Expand All @@ -205,7 +204,7 @@ def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: Li
],
)
def test_tree_subtree_active_features(
tree: _FanovaTree, node_index: NodeIndex, expected: List[bool]
tree: _FanovaTree, node_index: NodeIndex, expected: list[bool]
) -> None:
active_features: np.ndarray = tree._subtree_active_features[node_index] == expected
assert active_features.all()
Expand Down

0 comments on commit ef79c31

Please sign in to comment.