Skip to content

Commit

Permalink
Merge pull request #135 from Exabyte-io/feature/SOF-7385
Browse files Browse the repository at this point in the history
feature/SOF 7385
  • Loading branch information
VsevolodX authored Jun 24, 2024
2 parents c1f1c78 + 4b52375 commit a9e1778
Show file tree
Hide file tree
Showing 14 changed files with 529 additions and 72 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ venv/
ENV/
env.bak/
venv.bak/
.python-version

# Spyder project settings
.spyderproject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from mat3ra.utils.mixins import RoundNumericValuesMixin
from pydantic import BaseModel

from ..cell.cell import Cell
from ..utils import ArrayWithIds
from .cell import Cell
from .utils import ArrayWithIds


class Basis(RoundNumericValuesMixin, BaseModel):
Expand All @@ -31,7 +31,7 @@ def from_dict(
elements=ArrayWithIds.from_list_of_dicts(elements),
coordinates=ArrayWithIds.from_list_of_dicts(coordinates),
units=units,
cell=Cell.from_nested_array(cell) if cell else None,
cell=Cell.from_nested_array(cell),
labels=ArrayWithIds.from_list_of_dicts(labels) if labels else ArrayWithIds(values=[]),
constraints=ArrayWithIds.from_list_of_dicts(constraints) if constraints else ArrayWithIds(values=[]),
)
Expand Down
53 changes: 53 additions & 0 deletions src/py/mat3ra/made/cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import List

import numpy as np
from mat3ra.utils.mixins import RoundNumericValuesMixin
from pydantic import BaseModel, Field


class Cell(RoundNumericValuesMixin, BaseModel):
# TODO: figure out how to use ArrayOf3NumberElementsSchema
vector1: List[float] = Field(default_factory=lambda: [1, 0, 0])
vector2: List[float] = Field(default_factory=lambda: [0, 1, 0])
vector3: List[float] = Field(default_factory=lambda: [0, 0, 1])
__round_precision__ = 6

@classmethod
def from_nested_array(cls, nested_array):
if nested_array is None:
nested_array = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
return cls(vector1=nested_array[0], vector2=nested_array[1], vector3=nested_array[2])

@property
def vectors_as_nested_array(self, skip_rounding=False) -> List[List[float]]:
if skip_rounding:
return [self.vector1, self.vector2, self.vector3]
return self.round_array_or_number([self.vector1, self.vector2, self.vector3])

def to_json(self, skip_rounding=False):
_ = self.round_array_or_number
return [
self.vector1 if skip_rounding else _(self.vector1),
self.vector2 if skip_rounding else _(self.vector2),
self.vector3 if skip_rounding else _(self.vector3),
]

def clone(self):
return self.from_nested_array(self.vectors_as_nested_array)

def clone_and_scale_by_matrix(self, matrix):
new_cell = self.clone()
new_cell.scale_by_matrix(matrix)
return new_cell

def convert_point_to_cartesian(self, point):
np_vector = np.array(self.vectors_as_nested_array)
return np.dot(point, np_vector)

def convert_point_to_fractional(self, point):
np_vector = np.array(self.vectors_as_nested_array)
return np.dot(point, np.linalg.inv(np_vector))

def scale_by_matrix(self, matrix):
np_vector = np.array(self.vectors_as_nested_array)
self.vector1, self.vector2, self.vector3 = np.dot(matrix, np_vector).tolist()
58 changes: 0 additions & 58 deletions src/py/mat3ra/made/cell/cell.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mat3ra.utils.mixins import RoundNumericValuesMixin
from pydantic import BaseModel

from ..cell.cell import Cell
from .cell import Cell

HASH_TOLERANCE = 3

Expand Down
4 changes: 2 additions & 2 deletions src/py/mat3ra/made/material.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from mat3ra.code.entity import HasDescriptionHasMetadataNamedDefaultableInMemoryEntity
from mat3ra.esse.models.material import MaterialSchema

from .basis.basis import Basis
from .lattice.lattice import Lattice
from .basis import Basis
from .lattice import Lattice

defaultMaterialConfig = {
"name": "Silicon FCC",
Expand Down
116 changes: 114 additions & 2 deletions src/py/mat3ra/made/tools/analyze.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import List
from typing import List, Optional

import numpy as np
from ase import Atoms
from pymatgen.core import IStructure as PymatgenIStructure

from ..material import Material
from .convert import decorator_convert_material_args_kwargs_to_atoms
from .convert import decorator_convert_material_args_kwargs_to_atoms, to_pymatgen

PymatgenIStructure = PymatgenIStructure


@decorator_convert_material_args_kwargs_to_atoms
Expand Down Expand Up @@ -89,3 +92,112 @@ def get_closest_site_id_from_position(material: Material, position: List[float])
position = np.array(position) # type: ignore
distances = np.linalg.norm(coordinates - position, axis=1)
return int(np.argmin(distances))


def get_atom_indices_within_layer_by_atom_index(material: Material, atom_index: int, layer_thickness: float):
"""
Select all atoms within a specified layer thickness of a central atom along a direction.
This direction will be orthogonal to the AB plane.
Layer thickness is converted from angstroms to fractional units based on the lattice vector length.
Args:
material (Material): Material object
atom_index (int): Index of the central atom
layer_thickness (float): Thickness of the layer in angstroms
Returns:
List[int]: List of indices of atoms within the specified layer
"""
coordinates = material.basis.coordinates.to_array_of_values_with_ids()
vectors = material.lattice.vectors
direction_vector = np.array(vectors[2])

# Normalize the direction vector
direction_length = np.linalg.norm(direction_vector)
direction_norm = direction_vector / direction_length
central_atom_position = coordinates[atom_index]
central_atom_projection = np.dot(central_atom_position.value, direction_norm)

layer_thickness_frac = layer_thickness / direction_length

lower_bound = central_atom_projection - layer_thickness_frac / 2
upper_bound = central_atom_projection + layer_thickness_frac / 2

selected_indices = []
for coord in coordinates:
# Project each position onto the direction vector
projection = np.dot(coord.value, direction_norm)
if lower_bound <= projection <= upper_bound:
selected_indices.append(coord.id)
return selected_indices


def get_atom_indices_within_layer_by_atom_position(material: Material, position: List[float], layer_thickness: float):
"""
Select all atoms within a specified layer thickness of a central atom along a direction.
This direction will be orthogonal to the AB plane.
Layer thickness is converted from angstroms to fractional units based on the lattice vector length.
Args:
material (Material): Material object
position (List[float]): Position of the central atom in crystal coordinates
layer_thickness (float): Thickness of the layer in angstroms
Returns:
List[int]: List of indices of atoms within the specified layer
"""
site_id = get_closest_site_id_from_position(material, position)
return get_atom_indices_within_layer_by_atom_index(material, site_id, layer_thickness)


def get_atom_indices_within_layer(
material: Material,
atom_index: Optional[int] = 0,
position: Optional[List[float]] = None,
layer_thickness: float = 1,
):
"""
Select all atoms within a specified layer thickness of the central atom along the c-vector direction.
Args:
material (Material): Material object
atom_index (int): Index of the central atom
position (List[float]): Position of the central atom in crystal coordinates
layer_thickness (float): Thickness of the layer in angstroms
Returns:
List[int]: List of indices of atoms within the specified layer
"""
if position is not None:
return get_atom_indices_within_layer_by_atom_position(material, position, layer_thickness)
if atom_index is not None:
return get_atom_indices_within_layer_by_atom_index(material, atom_index, layer_thickness)


def get_atom_indices_within_radius_pbc(
material: Material, atom_index: Optional[int] = 0, position: Optional[List[float]] = None, radius: float = 1
):
"""
Select all atoms within a specified radius of a central atom considering periodic boundary conditions.
Args:
material (Material): Material object
atom_index (int): Index of the central atom
position (List[float]): Position of the central atom in crystal coordinates
radius (float): Radius of the sphere in angstroms
Returns:
List[int]: List of indices of atoms within the specified
"""

if position is not None:
atom_index = get_closest_site_id_from_position(material, position)

structure = to_pymatgen(material)
immutable_structure = PymatgenIStructure.from_sites(structure.sites)

central_atom = immutable_structure[atom_index]
sites_within_radius = structure.get_sites_in_sphere(central_atom.coords, radius)

selected_indices = [site.index for site in sites_within_radius]
return selected_indices
82 changes: 82 additions & 0 deletions src/py/mat3ra/made/tools/build/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import List

from mat3ra.made.basis import Basis
from mat3ra.made.material import Material

from ..utils import convert_basis_to_crystal, get_distance_between_coordinates


def resolve_close_coordinates_basis(basis: Basis, distance_tolerance: float = 0.01) -> Basis:
"""
Find all atoms that are within distance tolerance and only keep the last one, remove other sites
"""
coordinates = basis.coordinates.to_array_of_values_with_ids()
ids = set(basis.coordinates.ids)
ids_to_remove = set()

for i in range(1, len(coordinates)):
for j in range(i):
if get_distance_between_coordinates(coordinates[i].value, coordinates[j].value) < distance_tolerance:
ids_to_remove.add(coordinates[j].id)

ids_to_keep = list(ids - ids_to_remove)
basis.filter_atoms_by_ids(ids_to_keep)
return basis


def merge_two_bases(basis1: Basis, basis2: Basis, distance_tolerance: float) -> Basis:
basis1 = convert_basis_to_crystal(basis1)
basis2 = convert_basis_to_crystal(basis2)

merged_elements = basis1.elements
merged_coordinates = basis1.coordinates
merged_labels = basis1.labels

for coordinate in basis2.coordinates.values:
merged_coordinates.add_item(coordinate)

for element in basis2.elements.values:
merged_elements.add_item(element)

if basis2.labels:
for label in basis2.labels.values:
merged_labels.add_item(label)

merged_basis = Basis(
elements=merged_elements,
coordinates=merged_coordinates,
units=basis1.units,
cell=basis1.cell,
labels=merged_labels,
)
resolved_basis = resolve_close_coordinates_basis(merged_basis, distance_tolerance)

return resolved_basis


def merge_two_materials(material1: Material, material2: Material, distance_tolerance: float) -> Material:
"""
Merge two materials with the same lattice into a single material,
replacing colliding atoms with the latest material's atoms.
"""

material1 = material1.clone()
material2 = material2.clone()
if material1.lattice != material2.lattice:
raise ValueError("Lattices of the two materials must be the same.")
merged_lattice = material1.lattice
resolved_basis = merge_two_bases(material1.basis, material2.basis, distance_tolerance)

name = "Merged Material"
new_material = Material.create(
{"name": name, "lattice": merged_lattice.to_json(), "basis": resolved_basis.to_json()}
)
return new_material


def merge_materials(materials: List[Material], distance_tolerance: float = 0.01) -> Material:
merged_material = materials[0]
for material in materials[1:]:
merged_material = merge_two_materials(merged_material, material, distance_tolerance)

return merged_material
Loading

0 comments on commit a9e1778

Please sign in to comment.