From e07a7cea97bbad8ce8d541a2c1b76dd55ab867c3 Mon Sep 17 00:00:00 2001 From: Guilherme Panza Date: Mon, 28 Oct 2024 10:13:57 -0300 Subject: [PATCH] Refactor: use postponed evaluation of type annotations --- .../fanova_tests/test_tree.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/importance_tests/fanova_tests/test_tree.py b/tests/importance_tests/fanova_tests/test_tree.py index af5871ee0b7..83312590d28 100644 --- a/tests/importance_tests/fanova_tests/test_tree.py +++ b/tests/importance_tests/fanova_tests/test_tree.py @@ -1,7 +1,6 @@ +from __future__ import annotations + import math -from typing import Dict -from typing import List -from typing import Tuple from unittest.mock import Mock import numpy as np @@ -28,7 +27,7 @@ def tree() -> _FanovaTree: @pytest.fixture -def expected_tree_statistics() -> List[Dict[str, List]]: +def expected_tree_statistics() -> list[dict[str, list]]: # Statistics the each node in the tree. return [ {"values": [0.1, 0.2, 0.5], "weights": [0.75, 0.25, 1.0]}, @@ -39,7 +38,7 @@ def expected_tree_statistics() -> List[Dict[str, List]]: ] -def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[str, List]]) -> None: +def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: list[dict[str, list]]) -> None: # The root node at node index `0` holds the values and weights for all nodes in the tree. expected_statistics = expected_tree_statistics[0] expected_values = expected_statistics["values"] @@ -87,9 +86,9 @@ def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[st ) def test_tree_get_marginal_variance( tree: _FanovaTree, - features: List[int], - expected: List[Tuple[List[Size], List[Tuple[NodeIndex, Cardinality]]]], - expected_tree_statistics: List[Dict[str, List]], + features: list[int], + expected: list[tuple[list[Size], list[tuple[NodeIndex, Cardinality]]]], + expected_tree_statistics: list[dict[str, list]], ) -> None: variance = tree.get_marginal_variance(np.array(features)) @@ -145,9 +144,9 @@ def test_tree_get_marginal_variance( ) def test_tree_get_marginalized_statistics( tree: _FanovaTree, - feature_vector: List[float], - expected: List[Tuple[NodeIndex, Cardinality]], - expected_tree_statistics: List[Dict[str, List]], + feature_vector: list[float], + expected: list[tuple[NodeIndex, Cardinality]], + expected_tree_statistics: list[dict[str, list]], ) -> None: value, weight = tree._get_marginalized_statistics(np.array(feature_vector)) @@ -167,7 +166,7 @@ def test_tree_get_marginalized_statistics( def test_tree_statistics( - tree: _FanovaTree, expected_tree_statistics: List[Dict[str, List]] + tree: _FanovaTree, expected_tree_statistics: list[dict[str, list]] ) -> None: statistics = tree._statistics @@ -184,13 +183,13 @@ def test_tree_statistics( @pytest.mark.parametrize("node_index,expected", [(0, [0.5]), (1, [0.25, 0.75]), (2, [0.75, 1.75])]) def test_tree_split_midpoints( - tree: _FanovaTree, node_index: NodeIndex, expected: List[float] + tree: _FanovaTree, node_index: NodeIndex, expected: list[float] ) -> None: np.testing.assert_equal(tree._split_midpoints[node_index], expected) @pytest.mark.parametrize("node_index,expected", [(0, [1.0]), (1, [0.5, 0.5]), (2, [1.5, 0.5])]) -def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: List[float]) -> None: +def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: list[float]) -> None: np.testing.assert_equal(tree._split_sizes[node_index], expected) @@ -205,7 +204,7 @@ def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: Li ], ) def test_tree_subtree_active_features( - tree: _FanovaTree, node_index: NodeIndex, expected: List[bool] + tree: _FanovaTree, node_index: NodeIndex, expected: list[bool] ) -> None: active_features: np.ndarray = tree._subtree_active_features[node_index] == expected assert active_features.all()