diff --git a/.gitignore b/.gitignore index 74f07008..5c16cbe9 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.python-version # Spyder project settings .spyderproject diff --git a/src/py/mat3ra/made/basis/basis.py b/src/py/mat3ra/made/basis.py similarity index 96% rename from src/py/mat3ra/made/basis/basis.py rename to src/py/mat3ra/made/basis.py index b3508f27..288ef0eb 100644 --- a/src/py/mat3ra/made/basis/basis.py +++ b/src/py/mat3ra/made/basis.py @@ -5,8 +5,8 @@ from mat3ra.utils.mixins import RoundNumericValuesMixin from pydantic import BaseModel -from ..cell.cell import Cell -from ..utils import ArrayWithIds +from .cell import Cell +from .utils import ArrayWithIds class Basis(RoundNumericValuesMixin, BaseModel): @@ -31,7 +31,7 @@ def from_dict( elements=ArrayWithIds.from_list_of_dicts(elements), coordinates=ArrayWithIds.from_list_of_dicts(coordinates), units=units, - cell=Cell.from_nested_array(cell) if cell else None, + cell=Cell.from_nested_array(cell), labels=ArrayWithIds.from_list_of_dicts(labels) if labels else ArrayWithIds(values=[]), constraints=ArrayWithIds.from_list_of_dicts(constraints) if constraints else ArrayWithIds(values=[]), ) diff --git a/src/py/mat3ra/made/cell.py b/src/py/mat3ra/made/cell.py new file mode 100644 index 00000000..ade91d3b --- /dev/null +++ b/src/py/mat3ra/made/cell.py @@ -0,0 +1,53 @@ +from typing import List + +import numpy as np +from mat3ra.utils.mixins import RoundNumericValuesMixin +from pydantic import BaseModel, Field + + +class Cell(RoundNumericValuesMixin, BaseModel): + # TODO: figure out how to use ArrayOf3NumberElementsSchema + vector1: List[float] = Field(default_factory=lambda: [1, 0, 0]) + vector2: List[float] = Field(default_factory=lambda: [0, 1, 0]) + vector3: List[float] = Field(default_factory=lambda: [0, 0, 1]) + __round_precision__ = 6 + + @classmethod + def from_nested_array(cls, nested_array): + if nested_array is None: + nested_array = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + return cls(vector1=nested_array[0], vector2=nested_array[1], vector3=nested_array[2]) + + @property + def vectors_as_nested_array(self, skip_rounding=False) -> List[List[float]]: + if skip_rounding: + return [self.vector1, self.vector2, self.vector3] + return self.round_array_or_number([self.vector1, self.vector2, self.vector3]) + + def to_json(self, skip_rounding=False): + _ = self.round_array_or_number + return [ + self.vector1 if skip_rounding else _(self.vector1), + self.vector2 if skip_rounding else _(self.vector2), + self.vector3 if skip_rounding else _(self.vector3), + ] + + def clone(self): + return self.from_nested_array(self.vectors_as_nested_array) + + def clone_and_scale_by_matrix(self, matrix): + new_cell = self.clone() + new_cell.scale_by_matrix(matrix) + return new_cell + + def convert_point_to_cartesian(self, point): + np_vector = np.array(self.vectors_as_nested_array) + return np.dot(point, np_vector) + + def convert_point_to_fractional(self, point): + np_vector = np.array(self.vectors_as_nested_array) + return np.dot(point, np.linalg.inv(np_vector)) + + def scale_by_matrix(self, matrix): + np_vector = np.array(self.vectors_as_nested_array) + self.vector1, self.vector2, self.vector3 = np.dot(matrix, np_vector).tolist() diff --git a/src/py/mat3ra/made/cell/cell.py b/src/py/mat3ra/made/cell/cell.py deleted file mode 100644 index e999ee3a..00000000 --- a/src/py/mat3ra/made/cell/cell.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List - -import numpy as np -from mat3ra.esse.models.core.primitive.array_of_3_numbers import ArrayOf3NumberElementsSchema -from mat3ra.utils.mixins import RoundNumericValuesMixin -from pydantic import BaseModel - - -class Cell(RoundNumericValuesMixin, BaseModel): - # TODO: figure out how to use - vector1: ArrayOf3NumberElementsSchema = [1, 0, 0] - vector2: ArrayOf3NumberElementsSchema = [0, 1, 0] - vector3: ArrayOf3NumberElementsSchema = [0, 0, 1] - __round_precision__ = 1e-6 - - @classmethod - def from_nested_array(cls, nested_array): - if not nested_array: - nested_array = [cls.vector1, cls.vector2, cls.vector3] - return cls(vector1=nested_array[0], vector2=nested_array[1], vector3=nested_array[2]) - - def __init__(self, vector1=[1, 0, 0], vector2=[0, 1, 0], vector3=[0, 0, 1]): - super().__init__(**{"vector1": vector1, "vector2": vector2, "vector3": vector3}) - - @property - def vectors_as_array(self, skip_rounding=False) -> List[ArrayOf3NumberElementsSchema]: - if skip_rounding: - return [self.vector1, self.vector2, self.vector3] - return self.round_array_or_number([self.vector1, self.vector2, self.vector3]) - - def to_json(self, skip_rounding=False): - _ = self.round_array_or_number - if skip_rounding: - return { - "vector1": _(self.vector1) if skip_rounding else self.vector1, - "vector2": _(self.vector2) if skip_rounding else self.vector2, - "vector3": _(self.vector3) if skip_rounding else self.vector3, - } - - def clone(self): - return self.from_nested_array(self.vectors_as_array) - - def clone_and_scale_by_matrix(self, matrix): - new_cell = self.clone() - new_cell.scale_by_matrix(matrix) - return new_cell - - def convert_point_to_cartesian(self, point): - np_vector = np.array(self.vectors_as_array) - return np.dot(point, np_vector) - - def convert_point_to_fractional(self, point): - np_vector = np.array(self.vectors_as_array) - return np.dot(point, np.linalg.inv(np_vector)) - - def scale_by_matrix(self, matrix): - np_vector = np.array(self.vectors_as_array) - self.vector1, self.vector2, self.vector3 = np.dot(matrix, np_vector).tolist() diff --git a/src/py/mat3ra/made/lattice/lattice.py b/src/py/mat3ra/made/lattice.py similarity index 98% rename from src/py/mat3ra/made/lattice/lattice.py rename to src/py/mat3ra/made/lattice.py index 0bf5710c..dcaa89fc 100644 --- a/src/py/mat3ra/made/lattice/lattice.py +++ b/src/py/mat3ra/made/lattice.py @@ -5,7 +5,7 @@ from mat3ra.utils.mixins import RoundNumericValuesMixin from pydantic import BaseModel -from ..cell.cell import Cell +from .cell import Cell HASH_TOLERANCE = 3 diff --git a/src/py/mat3ra/made/material.py b/src/py/mat3ra/made/material.py index 616246c5..c102767d 100644 --- a/src/py/mat3ra/made/material.py +++ b/src/py/mat3ra/made/material.py @@ -4,8 +4,8 @@ from mat3ra.code.entity import HasDescriptionHasMetadataNamedDefaultableInMemoryEntity from mat3ra.esse.models.material import MaterialSchema -from .basis.basis import Basis -from .lattice.lattice import Lattice +from .basis import Basis +from .lattice import Lattice defaultMaterialConfig = { "name": "Silicon FCC", diff --git a/src/py/mat3ra/made/tools/analyze.py b/src/py/mat3ra/made/tools/analyze.py index 6f02c9c8..fcf1dbb3 100644 --- a/src/py/mat3ra/made/tools/analyze.py +++ b/src/py/mat3ra/made/tools/analyze.py @@ -1,10 +1,13 @@ -from typing import List +from typing import List, Optional import numpy as np from ase import Atoms +from pymatgen.core import IStructure as PymatgenIStructure from ..material import Material -from .convert import decorator_convert_material_args_kwargs_to_atoms +from .convert import decorator_convert_material_args_kwargs_to_atoms, to_pymatgen + +PymatgenIStructure = PymatgenIStructure @decorator_convert_material_args_kwargs_to_atoms @@ -89,3 +92,112 @@ def get_closest_site_id_from_position(material: Material, position: List[float]) position = np.array(position) # type: ignore distances = np.linalg.norm(coordinates - position, axis=1) return int(np.argmin(distances)) + + +def get_atom_indices_within_layer_by_atom_index(material: Material, atom_index: int, layer_thickness: float): + """ + Select all atoms within a specified layer thickness of a central atom along a direction. + This direction will be orthogonal to the AB plane. + Layer thickness is converted from angstroms to fractional units based on the lattice vector length. + + Args: + material (Material): Material object + atom_index (int): Index of the central atom + layer_thickness (float): Thickness of the layer in angstroms + + Returns: + List[int]: List of indices of atoms within the specified layer + """ + coordinates = material.basis.coordinates.to_array_of_values_with_ids() + vectors = material.lattice.vectors + direction_vector = np.array(vectors[2]) + + # Normalize the direction vector + direction_length = np.linalg.norm(direction_vector) + direction_norm = direction_vector / direction_length + central_atom_position = coordinates[atom_index] + central_atom_projection = np.dot(central_atom_position.value, direction_norm) + + layer_thickness_frac = layer_thickness / direction_length + + lower_bound = central_atom_projection - layer_thickness_frac / 2 + upper_bound = central_atom_projection + layer_thickness_frac / 2 + + selected_indices = [] + for coord in coordinates: + # Project each position onto the direction vector + projection = np.dot(coord.value, direction_norm) + if lower_bound <= projection <= upper_bound: + selected_indices.append(coord.id) + return selected_indices + + +def get_atom_indices_within_layer_by_atom_position(material: Material, position: List[float], layer_thickness: float): + """ + Select all atoms within a specified layer thickness of a central atom along a direction. + This direction will be orthogonal to the AB plane. + Layer thickness is converted from angstroms to fractional units based on the lattice vector length. + + Args: + material (Material): Material object + position (List[float]): Position of the central atom in crystal coordinates + layer_thickness (float): Thickness of the layer in angstroms + + Returns: + List[int]: List of indices of atoms within the specified layer + """ + site_id = get_closest_site_id_from_position(material, position) + return get_atom_indices_within_layer_by_atom_index(material, site_id, layer_thickness) + + +def get_atom_indices_within_layer( + material: Material, + atom_index: Optional[int] = 0, + position: Optional[List[float]] = None, + layer_thickness: float = 1, +): + """ + Select all atoms within a specified layer thickness of the central atom along the c-vector direction. + + Args: + material (Material): Material object + atom_index (int): Index of the central atom + position (List[float]): Position of the central atom in crystal coordinates + layer_thickness (float): Thickness of the layer in angstroms + + Returns: + List[int]: List of indices of atoms within the specified layer + """ + if position is not None: + return get_atom_indices_within_layer_by_atom_position(material, position, layer_thickness) + if atom_index is not None: + return get_atom_indices_within_layer_by_atom_index(material, atom_index, layer_thickness) + + +def get_atom_indices_within_radius_pbc( + material: Material, atom_index: Optional[int] = 0, position: Optional[List[float]] = None, radius: float = 1 +): + """ + Select all atoms within a specified radius of a central atom considering periodic boundary conditions. + + Args: + material (Material): Material object + atom_index (int): Index of the central atom + position (List[float]): Position of the central atom in crystal coordinates + radius (float): Radius of the sphere in angstroms + + Returns: + List[int]: List of indices of atoms within the specified + """ + + if position is not None: + atom_index = get_closest_site_id_from_position(material, position) + + structure = to_pymatgen(material) + immutable_structure = PymatgenIStructure.from_sites(structure.sites) + + central_atom = immutable_structure[atom_index] + sites_within_radius = structure.get_sites_in_sphere(central_atom.coords, radius) + + selected_indices = [site.index for site in sites_within_radius] + return selected_indices diff --git a/src/py/mat3ra/made/tools/build/utils.py b/src/py/mat3ra/made/tools/build/utils.py new file mode 100644 index 00000000..2b9ab71d --- /dev/null +++ b/src/py/mat3ra/made/tools/build/utils.py @@ -0,0 +1,82 @@ +from typing import List + +from mat3ra.made.basis import Basis +from mat3ra.made.material import Material + +from ..utils import convert_basis_to_crystal, get_distance_between_coordinates + + +def resolve_close_coordinates_basis(basis: Basis, distance_tolerance: float = 0.01) -> Basis: + """ + Find all atoms that are within distance tolerance and only keep the last one, remove other sites + """ + coordinates = basis.coordinates.to_array_of_values_with_ids() + ids = set(basis.coordinates.ids) + ids_to_remove = set() + + for i in range(1, len(coordinates)): + for j in range(i): + if get_distance_between_coordinates(coordinates[i].value, coordinates[j].value) < distance_tolerance: + ids_to_remove.add(coordinates[j].id) + + ids_to_keep = list(ids - ids_to_remove) + basis.filter_atoms_by_ids(ids_to_keep) + return basis + + +def merge_two_bases(basis1: Basis, basis2: Basis, distance_tolerance: float) -> Basis: + basis1 = convert_basis_to_crystal(basis1) + basis2 = convert_basis_to_crystal(basis2) + + merged_elements = basis1.elements + merged_coordinates = basis1.coordinates + merged_labels = basis1.labels + + for coordinate in basis2.coordinates.values: + merged_coordinates.add_item(coordinate) + + for element in basis2.elements.values: + merged_elements.add_item(element) + + if basis2.labels: + for label in basis2.labels.values: + merged_labels.add_item(label) + + merged_basis = Basis( + elements=merged_elements, + coordinates=merged_coordinates, + units=basis1.units, + cell=basis1.cell, + labels=merged_labels, + ) + resolved_basis = resolve_close_coordinates_basis(merged_basis, distance_tolerance) + + return resolved_basis + + +def merge_two_materials(material1: Material, material2: Material, distance_tolerance: float) -> Material: + """ + Merge two materials with the same lattice into a single material, + replacing colliding atoms with the latest material's atoms. + """ + + material1 = material1.clone() + material2 = material2.clone() + if material1.lattice != material2.lattice: + raise ValueError("Lattices of the two materials must be the same.") + merged_lattice = material1.lattice + resolved_basis = merge_two_bases(material1.basis, material2.basis, distance_tolerance) + + name = "Merged Material" + new_material = Material.create( + {"name": name, "lattice": merged_lattice.to_json(), "basis": resolved_basis.to_json()} + ) + return new_material + + +def merge_materials(materials: List[Material], distance_tolerance: float = 0.01) -> Material: + merged_material = materials[0] + for material in materials[1:]: + merged_material = merge_two_materials(merged_material, material, distance_tolerance) + + return merged_material diff --git a/src/py/mat3ra/made/tools/modify.py b/src/py/mat3ra/made/tools/modify.py index fb0c7119..d1a1be77 100644 --- a/src/py/mat3ra/made/tools/modify.py +++ b/src/py/mat3ra/made/tools/modify.py @@ -1,9 +1,10 @@ -from typing import Union +from typing import List, Union from mat3ra.made.material import Material from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer from pymatgen.core.structure import Structure +from .analyze import get_atom_indices_within_layer_by_atom_index, get_atom_indices_within_radius_pbc from .convert import decorator_convert_material_args_kwargs_to_structure from .utils import translate_to_bottom_pymatgen_structure @@ -58,3 +59,68 @@ def wrap_to_unit_cell(structure: Structure): """ structure.make_supercell((1, 1, 1), to_unit_cell=True) return structure + + +def filter_material_by_ids(material: Material, ids: List[int], invert: bool = False) -> Material: + """ + Filter out only atoms corresponding to the ids. + + Args: + material (Material): The material object to filter. + ids (List[int]): The ids to filter by. + invert (bool): Whether to invert the selection. + + Returns: + Material: The filtered material object. + """ + new_material = material.clone() + new_basis = new_material.basis + if invert is True: + ids = list(set(new_basis.elements.ids) - set(ids)) + new_basis.filter_atoms_by_ids(ids) + new_material.basis = new_basis + return new_material + + +def filter_by_layers( + material: Material, central_atom_id: int, layer_thickness: float, invert: bool = False +) -> Material: + """ + Filter out atoms within a specified layer thickness of a central atom along c-vector direction. + + Args: + material (Material): The material object to filter. + central_atom_id (int): Index of the central atom. + layer_thickness (float): Thickness of the layer in angstroms. + invert (bool): Whether to invert the selection. + + Returns: + Material: The filtered material object. + """ + ids = get_atom_indices_within_layer_by_atom_index( + material, + central_atom_id, + layer_thickness, + ) + return filter_material_by_ids(material, ids, invert=invert) + + +def filter_by_sphere(material: Material, central_atom_id: int, radius: float, invert: bool = False) -> Material: + """ + Filter out atoms within a specified radius of a central atom considering periodic boundary conditions. + + Args: + material (Material): The material object to filter. + central_atom_id (int): Index of the central atom. + radius (float): Radius of the sphere in angstroms. + invert (bool): Whether to invert the selection. + + Returns: + Material: The filtered material object. + """ + ids = get_atom_indices_within_radius_pbc( + material=material, + atom_index=central_atom_id, + radius=radius, + ) + return filter_material_by_ids(material, ids, invert=invert) diff --git a/src/py/mat3ra/made/tools/utils.py b/src/py/mat3ra/made/tools/utils.py index 71470605..78fa5e0c 100644 --- a/src/py/mat3ra/made/tools/utils.py +++ b/src/py/mat3ra/made/tools/utils.py @@ -1,6 +1,8 @@ from functools import wraps -from typing import Callable +from typing import Callable, List +import numpy as np +from mat3ra.made.basis import Basis from mat3ra.utils.matrix import convert_2x2_to_3x3 from pymatgen.core.structure import Structure @@ -34,3 +36,62 @@ def wrapper(*args, **kwargs): return func(*new_args, **kwargs) return wrapper + + +def convert_basis_to_cartesian(basis: Basis) -> Basis: + """ + Convert the basis to the Cartesian coordinates. + Args: + basis (Dict): The basis to convert. + + Returns: + Dict: The basis in Cartesian coordinates. + """ + if basis.units == "cartesian": + return basis + unit_cell = np.array(basis.cell) + basis.coordinates = np.multiply(basis.coordinates, unit_cell) + basis.units = "cartesian" + return basis + + +def convert_basis_to_crystal(basis: Basis) -> Basis: + """ + Convert the basis to the crystal coordinates. + Args: + basis (Dict): The basis to convert. + + Returns: + Dict: The basis in crystal coordinates. + """ + if basis.units == "crystal": + return basis + unit_cell = np.array(basis.cell) + basis.coordinates.values = np.multiply(basis.coordinates.values, np.linalg.inv(unit_cell)) + basis.units = "crystal" + return basis + + +def get_distance_between_coordinates(coordinate1: List[float], coordinate2: List[float]) -> float: + """ + Get the distance between two coordinates. + Args: + coordinate1 (List[float]): The first coordinate. + coordinate2 (List[float]): The second coordinate. + + Returns: + float: The distance between the two coordinates. + """ + return float(np.linalg.norm(np.array(coordinate1) - np.array(coordinate2))) + + +def get_norm(vector: List[float]) -> float: + """ + Get the norm of a vector. + Args: + vector (List[float]): The vector. + + Returns: + float: The norm of the vector. + """ + return float(np.linalg.norm(vector)) diff --git a/tests/py/unit/test_material.py b/tests/py/unit/test_material.py index 35d9d853..d343eb32 100644 --- a/tests/py/unit/test_material.py +++ b/tests/py/unit/test_material.py @@ -1,5 +1,5 @@ -from mat3ra.made.basis.basis import Basis -from mat3ra.made.lattice.lattice import Lattice +from mat3ra.made.basis import Basis +from mat3ra.made.lattice import Lattice from mat3ra.made.material import Material from mat3ra.utils import assertion as assertion_utils diff --git a/tests/py/unit/test_tools_build.py b/tests/py/unit/test_tools_build.py new file mode 100644 index 00000000..a3d53b09 --- /dev/null +++ b/tests/py/unit/test_tools_build.py @@ -0,0 +1,54 @@ +from ase.build import bulk +from mat3ra.made.material import Material +from mat3ra.made.tools.build.utils import merge_materials +from mat3ra.made.tools.convert import from_ase +from mat3ra.made.tools.modify import filter_by_layers +from mat3ra.utils import assertion as assertion_utils + +ase_ni = bulk("Ni", "fcc", a=3.52, cubic=True) +material = Material(from_ase(ase_ni)) +section = filter_by_layers(material, 0, 1.0) +cavity = filter_by_layers(material, 0, 1.0, invert=True) + +# Change 0th element +section.basis.elements.values[0] = "Ge" + +# Add element to cavity for collision test +cavity.basis.elements.add_item("S", id=4) +coordinate_value = section.basis.coordinates.values[1] +cavity.basis.coordinates.add_item(coordinate_value, id=4) + +expected_merged_material_basis = { + "elements": [{"id": 0, "value": "Ge"}, {"id": 1, "value": "Ni"}, {"id": 2, "value": "Ni"}, {"id": 4, "value": "S"}], + "coordinates": [ + {"id": 0, "value": [0.0, 0.0, 0.0]}, + {"id": 1, "value": [0.0, 0.5, 0.5]}, + {"id": 2, "value": [0.5, 0.0, 0.5]}, + {"id": 4, "value": [0.5, 0.5, 0.0]}, + ], + "labels": [], +} + + +expected_merged_material_reverse_basis = { + "elements": [ + {"id": 1, "value": "Ni"}, + {"id": 2, "value": "Ni"}, + {"id": 0, "value": "Ge"}, + {"id": 3, "value": "Ni"}, + ], + "coordinates": [ + {"id": 1, "value": [0.0, 0.5, 0.5]}, + {"id": 2, "value": [0.5, 0.0, 0.5]}, + {"id": 0, "value": [0.0, 0.0, 0.0]}, + {"id": 3, "value": [0.5, 0.5, 0.0]}, + ], + "labels": [], +} + + +def test_merge_materials(): + merged_material = merge_materials([section, cavity]) + merged_material_reverse = merge_materials([cavity, section]) + assertion_utils.assert_deep_almost_equal(merged_material.basis, expected_merged_material_basis) + assertion_utils.assert_deep_almost_equal(merged_material_reverse.basis, expected_merged_material_reverse_basis) diff --git a/tests/py/unit/test_tools_convert.py b/tests/py/unit/test_tools_convert.py index 5bf3ea8a..ddee5cc9 100644 --- a/tests/py/unit/test_tools_convert.py +++ b/tests/py/unit/test_tools_convert.py @@ -1,7 +1,7 @@ import numpy as np from ase import Atoms from ase.build import bulk -from mat3ra.made.basis.basis import Basis +from mat3ra.made.basis import Basis from mat3ra.made.material import Material from mat3ra.made.tools.convert import from_ase, from_poscar, from_pymatgen, to_ase, to_poscar, to_pymatgen from mat3ra.utils import assertion as assertion_utils diff --git a/tests/py/unit/test_tools_modify.py b/tests/py/unit/test_tools_modify.py index 9e0e1f51..8813a71b 100644 --- a/tests/py/unit/test_tools_modify.py +++ b/tests/py/unit/test_tools_modify.py @@ -1,7 +1,77 @@ from ase.build import bulk from mat3ra.made.material import Material from mat3ra.made.tools.convert import from_ase -from mat3ra.made.tools.modify import filter_by_label +from mat3ra.made.tools.modify import filter_by_label, filter_by_layers, filter_by_sphere +from mat3ra.utils import assertion as assertion_utils + +from .fixtures import SI_CONVENTIONAL_CELL + +COMMON_PART = { + "units": "crystal", + "cell": [[5.468763846, 0.0, 0.0], [-0.0, 5.468763846, 0.0], [0.0, 0.0, 5.468763846]], + "labels": [], +} + +expected_basis_layers_section = { + "elements": [ + {"id": 0, "value": "Si"}, + {"id": 3, "value": "Si"}, + {"id": 5, "value": "Si"}, + {"id": 6, "value": "Si"}, + ], + "coordinates": [ + {"id": 0, "value": [0.5, 0.0, 0.0]}, + {"id": 3, "value": [0.25, 0.75, 0.25]}, + {"id": 5, "value": [0.75, 0.25, 0.25]}, + {"id": 6, "value": [0.0, 0.5, 0.0]}, + ], + **COMMON_PART, +} + +expected_basis_layers_cavity = { + "elements": [ + {"id": 1, "value": "Si"}, + {"id": 2, "value": "Si"}, + {"id": 4, "value": "Si"}, + {"id": 7, "value": "Si"}, + ], + "coordinates": [ + {"id": 1, "value": [0.25, 0.25, 0.75]}, + {"id": 2, "value": [0.5, 0.5, 0.5]}, + {"id": 4, "value": [0.0, 0.0, 0.5]}, + {"id": 7, "value": [0.75, 0.75, 0.75]}, + ], + **COMMON_PART, +} + + +expected_basis_sphere_cluster = { + "elements": [{"id": 0, "value": "Si"}], + "coordinates": [{"id": 0, "value": [0.5, 0.0, 0.0]}], + **COMMON_PART, +} + +expected_basis_sphere_cavity = { + "elements": [ + {"id": 1, "value": "Si"}, + {"id": 2, "value": "Si"}, + {"id": 3, "value": "Si"}, + {"id": 4, "value": "Si"}, + {"id": 5, "value": "Si"}, + {"id": 6, "value": "Si"}, + {"id": 7, "value": "Si"}, + ], + "coordinates": [ + {"id": 1, "value": [0.25, 0.25, 0.75]}, + {"id": 2, "value": [0.5, 0.5, 0.5]}, + {"id": 3, "value": [0.25, 0.75, 0.25]}, + {"id": 4, "value": [0.0, 0.0, 0.5]}, + {"id": 5, "value": [0.75, 0.25, 0.25]}, + {"id": 6, "value": [0.0, 0.5, 0.0]}, + {"id": 7, "value": [0.75, 0.75, 0.75]}, + ], + **COMMON_PART, +} def test_filter_by_label(): @@ -15,3 +85,19 @@ def test_filter_by_label(): # Ids of filtered elements will be missing, comparing the resulting values assert [el for el in film_material.basis.elements.values] == [el for el in film_extracted.basis.elements.values] + + +def test_filter_by_layers(): + material = Material(SI_CONVENTIONAL_CELL) + section = filter_by_layers(material, 0, 3.0) + cavity = filter_by_layers(material, 0, 3.0, invert=True) + assertion_utils.assert_deep_almost_equal(expected_basis_layers_section, section.basis.to_json()) + assertion_utils.assert_deep_almost_equal(expected_basis_layers_cavity, cavity.basis.to_json()) + + +def test_filter_by_sphere(): + material = Material(SI_CONVENTIONAL_CELL) + cluster = filter_by_sphere(material, 0, 2.0) + cavity = filter_by_sphere(material, 0, 2.0, invert=True) + assertion_utils.assert_deep_almost_equal(expected_basis_sphere_cluster, cluster.basis.to_json()) + assertion_utils.assert_deep_almost_equal(expected_basis_sphere_cavity, cavity.basis.to_json())