diff --git a/matcalc/elasticity.py b/matcalc/elasticity.py index 1840635..cd07ff2 100644 --- a/matcalc/elasticity.py +++ b/matcalc/elasticity.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np from pymatgen.analysis.elasticity import DeformedStructureSet, ElasticTensor, Strain @@ -12,6 +12,8 @@ from .relaxation import RelaxCalc if TYPE_CHECKING: + from collections.abc import Sequence + from ase.calculators.calculator import Calculator from pymatgen.core import Structure @@ -22,8 +24,8 @@ class ElasticityCalc(PropCalc): def __init__( self, calculator: Calculator, - norm_strains: tuple[float, ...] | float = (-0.01, -0.005, 0.005, 0.01), - shear_strains: tuple[float, ...] | float = (-0.06, -0.03, 0.03, 0.06), + norm_strains: Sequence[float] | float = (-0.01, -0.005, 0.005, 0.01), + shear_strains: Sequence[float] | float = (-0.06, -0.03, 0.03, 0.06), fmax: float = 0.1, relax_structure: bool = True, use_equilibrium: bool = True, @@ -58,7 +60,7 @@ def __init__( else: self.use_equilibrium = True - def calc(self, structure: Structure) -> dict[str, float | ElasticTensor | Structure]: + def calc(self, structure: Structure) -> dict[str, Any]: """ Calculates elastic properties of Pymatgen structure with units determined by the calculator, (often the stress_weight). diff --git a/tests/test_elasticity.py b/tests/test_elasticity.py index 6a7d9dc..087d87b 100644 --- a/tests/test_elasticity.py +++ b/tests/test_elasticity.py @@ -1,15 +1,21 @@ """Tests for ElasticCalc class""" from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pytest from matcalc.elasticity import ElasticityCalc +if TYPE_CHECKING: + from matgl.ext.ase import M3GNetCalculator + from pymatgen.core import Structure + -def test_elastic_calc(Li2O, M3GNetCalc): +def test_elastic_calc(Li2O: Structure, M3GNetCalc: M3GNetCalculator) -> None: """Tests for ElasticCalc class""" - e_calc = ElasticityCalc( + elast_calc = ElasticityCalc( M3GNetCalc, fmax=0.1, norm_strains=list(np.linspace(-0.004, 0.004, num=4)), @@ -18,7 +24,7 @@ def test_elastic_calc(Li2O, M3GNetCalc): ) # Test Li2O with equilibrium structure - results = e_calc.calc(Li2O) + results = elast_calc.calc(Li2O) assert results["elastic_tensor"].shape == (3, 3, 3, 3) assert results["elastic_tensor"][0][1][1][0] == pytest.approx(0.5014895636122672, rel=1e-3) assert results["bulk_modulus_vrh"] == pytest.approx(0.6737897607182401, rel=1e-3) @@ -28,7 +34,7 @@ def test_elastic_calc(Li2O, M3GNetCalc): assert results["structure"].lattice.a == pytest.approx(3.2885851104196875, rel=1e-4) # Test Li2O without the equilibrium structure - e_calc = ElasticityCalc( + elast_calc = ElasticityCalc( M3GNetCalc, fmax=0.1, norm_strains=list(np.linspace(-0.004, 0.004, num=4)), @@ -36,11 +42,11 @@ def test_elastic_calc(Li2O, M3GNetCalc): use_equilibrium=False, ) - results = e_calc.calc(Li2O) + results = elast_calc.calc(Li2O) assert results["residuals_sum"] == pytest.approx(2.9257237571340992e-08, rel=1e-2) # Test Li2O with float - e_calc = ElasticityCalc( + elast_calc = ElasticityCalc( M3GNetCalc, fmax=0.1, norm_strains=0.004, @@ -48,12 +54,12 @@ def test_elastic_calc(Li2O, M3GNetCalc): use_equilibrium=True, ) - results = e_calc.calc(Li2O) + results = elast_calc.calc(Li2O) assert results["residuals_sum"] == 0.0 assert results["bulk_modulus_vrh"] == pytest.approx(0.6631894154825593, rel=1e-3) -def test_elastic_calc_invalid_states(Li2O, M3GNetCalc): +def test_elastic_calc_invalid_states(M3GNetCalc: M3GNetCalculator): with pytest.raises(ValueError, match="shear_strains must be nonempty"): ElasticityCalc(M3GNetCalc, shear_strains=[]) with pytest.raises(ValueError, match="norm_strains must be nonempty"): diff --git a/tests/test_eos.py b/tests/test_eos.py index 4d49f07..f872cc4 100644 --- a/tests/test_eos.py +++ b/tests/test_eos.py @@ -1,29 +1,39 @@ """Tests for PhononCalc class""" from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from matcalc.eos import EOSCalc +if TYPE_CHECKING: + from matgl.ext.ase import M3GNetCalculator + from pymatgen.core import Structure + -def test_eos_calc(Li2O, LiFePO4, M3GNetCalc): +def test_eos_calc( + Li2O: Structure, + LiFePO4: Structure, + M3GNetCalc: M3GNetCalculator, +) -> None: """Tests for EOSCalc class""" # Note that the fmax is probably too high. This is for testing purposes only. - pcalc = EOSCalc(M3GNetCalc, fmax=0.1) - results = pcalc.calc(Li2O) + eos_calc = EOSCalc(M3GNetCalc, fmax=0.1) + result = eos_calc.calc(Li2O) - assert {*results} == {"eos", "r2_score_bm", "bulk_modulus_bm"} - assert results["bulk_modulus_bm"] == pytest.approx(65.57980045603279, rel=1e-2) - assert {*results["eos"]} == {"volumes", "energies"} - assert results["eos"]["volumes"] == pytest.approx( + assert {*result} == {"eos", "r2_score_bm", "bulk_modulus_bm"} + assert result["bulk_modulus_bm"] == pytest.approx(65.57980045603279, rel=1e-2) + assert {*result["eos"]} == {"volumes", "energies"} + assert result["eos"]["volumes"] == pytest.approx( [18.38, 19.63, 20.94, 22.3, 23.73, 25.21, 26.75, 28.36, 30.02, 31.76, 33.55], rel=1e-3, ) - assert results["eos"]["energies"] == pytest.approx( + assert result["eos"]["energies"] == pytest.approx( [-13.52, -13.77, -13.94, -14.08, -14.15, -14.18, -14.16, -14.11, -14.03, -13.94, -13.83], rel=1e-3, ) - pcalc = EOSCalc(M3GNetCalc, relax_structure=False) - results = list(pcalc.calc_many([Li2O, LiFePO4])) + eos_calc = EOSCalc(M3GNetCalc, relax_structure=False) + results = list(eos_calc.calc_many([Li2O, LiFePO4])) assert len(results) == 2 assert results[1]["bulk_modulus_bm"] == pytest.approx(54.5953851822073, rel=1e-2) diff --git a/tests/test_neb.py b/tests/test_neb.py index 3dc153b..3a400a2 100644 --- a/tests/test_neb.py +++ b/tests/test_neb.py @@ -1,11 +1,19 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from matcalc.neb import NEBCalc +if TYPE_CHECKING: + from pathlib import Path + + from matgl.ext.ase import M3GNetCalculator + from pymatgen.core import Structure + -def test_neb_calc(LiFePO4, M3GNetCalc, tmp_path): +def test_neb_calc(LiFePO4: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path) -> None: """Tests for NEBCalc class""" image_start = LiFePO4.copy() image_start.remove_sites([2]) diff --git a/tests/test_phonon.py b/tests/test_phonon.py index 3457b11..b1dcae1 100644 --- a/tests/test_phonon.py +++ b/tests/test_phonon.py @@ -1,22 +1,30 @@ """Tests for PhononCalc class""" from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from matcalc.phonon import PhononCalc +if TYPE_CHECKING: + from matgl.ext.ase import M3GNetCalculator + from pymatgen.core import Structure + -def test_phonon_calc(Li2O, M3GNetCalc): +def test_phonon_calc(Li2O: Structure, M3GNetCalc: M3GNetCalculator) -> None: """Tests for PhononCalc class""" # Note that the fmax is probably too high. This is for testing purposes only. - pcalc = PhononCalc(M3GNetCalc, supercell_matrix=((2, 0, 0), (0, 2, 0), (0, 0, 2)), fmax=0.1, t_step=50, t_max=1000) - results = pcalc.calc(Li2O) + phonon_calc = PhononCalc( + M3GNetCalc, supercell_matrix=((2, 0, 0), (0, 2, 0), (0, 0, 2)), fmax=0.1, t_step=50, t_max=1000 + ) + result = phonon_calc.calc(Li2O) # Test values at 100 K - ind = results["thermal_properties"]["temperatures"].tolist().index(300) - assert results["thermal_properties"]["heat_capacity"][ind] == pytest.approx(58.42898370395005, rel=1e-2) - assert results["thermal_properties"]["entropy"][ind] == pytest.approx(49.3774618162247, rel=1e-2) - assert results["thermal_properties"]["free_energy"][ind] == pytest.approx(13.245478097108784, rel=1e-2) + ind = result["thermal_properties"]["temperatures"].tolist().index(300) + assert result["thermal_properties"]["heat_capacity"][ind] == pytest.approx(58.42898370395005, rel=1e-2) + assert result["thermal_properties"]["entropy"][ind] == pytest.approx(49.3774618162247, rel=1e-2) + assert result["thermal_properties"]["free_energy"][ind] == pytest.approx(13.245478097108784, rel=1e-2) - results = list(pcalc.calc_many([Li2O, Li2O])) + results = list(phonon_calc.calc_many([Li2O, Li2O])) assert len(results) == 2 diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index a0783d6..87459d6 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -1,22 +1,30 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from matcalc.relaxation import RelaxCalc +if TYPE_CHECKING: + from pathlib import Path + + from matgl.ext.ase import M3GNetCalculator + from pymatgen.core import Structure + -def test_relax_calc(Li2O, M3GNetCalc, tmp_path): - pcalc = RelaxCalc(M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE") - results = pcalc.calc(Li2O) - assert results["a"] == pytest.approx(3.291071792359756, rel=0.002) - assert results["b"] == pytest.approx(3.291071899625086, rel=0.002) - assert results["c"] == pytest.approx(3.291072056855788, rel=0.002) - assert results["alpha"] == pytest.approx(60, abs=1) - assert results["beta"] == pytest.approx(60, abs=1) - assert results["gamma"] == pytest.approx(60, abs=1) - assert results["volume"] == pytest.approx(results["a"] * results["b"] * results["c"] / 2**0.5, abs=0.1) +def test_relax_calc(Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path) -> None: + relax_calc = RelaxCalc(M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE") + result = relax_calc.calc(Li2O) + assert result["a"] == pytest.approx(3.291071792359756, rel=0.002) + assert result["b"] == pytest.approx(3.291071899625086, rel=0.002) + assert result["c"] == pytest.approx(3.291072056855788, rel=0.002) + assert result["alpha"] == pytest.approx(60, abs=1) + assert result["beta"] == pytest.approx(60, abs=1) + assert result["gamma"] == pytest.approx(60, abs=1) + assert result["volume"] == pytest.approx(result["a"] * result["b"] * result["c"] / 2**0.5, abs=0.1) - results = list(pcalc.calc_many([Li2O] * 2)) + results = list(relax_calc.calc_many([Li2O] * 2)) assert len(results) == 2 assert results[-1]["a"] == pytest.approx(3.291071792359756, rel=0.002) diff --git a/tests/test_util.py b/tests/test_util.py index 97fb324..8fc8791 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,7 +6,7 @@ from matcalc.util import UNIVERSAL_CALCULATORS, get_universal_calculator -def test_get_universal_calculator(): +def test_get_universal_calculator() -> None: for name in UNIVERSAL_CALCULATORS: calc = get_universal_calculator(name) assert isinstance(calc, Calculator)