-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #135 from Exabyte-io/feature/SOF-7385
feature/SOF 7385
- Loading branch information
Showing
14 changed files
with
529 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,6 +135,7 @@ venv/ | |
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
.python-version | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
from mat3ra.utils.mixins import RoundNumericValuesMixin | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
class Cell(RoundNumericValuesMixin, BaseModel): | ||
# TODO: figure out how to use ArrayOf3NumberElementsSchema | ||
vector1: List[float] = Field(default_factory=lambda: [1, 0, 0]) | ||
vector2: List[float] = Field(default_factory=lambda: [0, 1, 0]) | ||
vector3: List[float] = Field(default_factory=lambda: [0, 0, 1]) | ||
__round_precision__ = 6 | ||
|
||
@classmethod | ||
def from_nested_array(cls, nested_array): | ||
if nested_array is None: | ||
nested_array = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] | ||
return cls(vector1=nested_array[0], vector2=nested_array[1], vector3=nested_array[2]) | ||
|
||
@property | ||
def vectors_as_nested_array(self, skip_rounding=False) -> List[List[float]]: | ||
if skip_rounding: | ||
return [self.vector1, self.vector2, self.vector3] | ||
return self.round_array_or_number([self.vector1, self.vector2, self.vector3]) | ||
|
||
def to_json(self, skip_rounding=False): | ||
_ = self.round_array_or_number | ||
return [ | ||
self.vector1 if skip_rounding else _(self.vector1), | ||
self.vector2 if skip_rounding else _(self.vector2), | ||
self.vector3 if skip_rounding else _(self.vector3), | ||
] | ||
|
||
def clone(self): | ||
return self.from_nested_array(self.vectors_as_nested_array) | ||
|
||
def clone_and_scale_by_matrix(self, matrix): | ||
new_cell = self.clone() | ||
new_cell.scale_by_matrix(matrix) | ||
return new_cell | ||
|
||
def convert_point_to_cartesian(self, point): | ||
np_vector = np.array(self.vectors_as_nested_array) | ||
return np.dot(point, np_vector) | ||
|
||
def convert_point_to_fractional(self, point): | ||
np_vector = np.array(self.vectors_as_nested_array) | ||
return np.dot(point, np.linalg.inv(np_vector)) | ||
|
||
def scale_by_matrix(self, matrix): | ||
np_vector = np.array(self.vectors_as_nested_array) | ||
self.vector1, self.vector2, self.vector3 = np.dot(matrix, np_vector).tolist() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from typing import List | ||
|
||
from mat3ra.made.basis import Basis | ||
from mat3ra.made.material import Material | ||
|
||
from ..utils import convert_basis_to_crystal, get_distance_between_coordinates | ||
|
||
|
||
def resolve_close_coordinates_basis(basis: Basis, distance_tolerance: float = 0.01) -> Basis: | ||
""" | ||
Find all atoms that are within distance tolerance and only keep the last one, remove other sites | ||
""" | ||
coordinates = basis.coordinates.to_array_of_values_with_ids() | ||
ids = set(basis.coordinates.ids) | ||
ids_to_remove = set() | ||
|
||
for i in range(1, len(coordinates)): | ||
for j in range(i): | ||
if get_distance_between_coordinates(coordinates[i].value, coordinates[j].value) < distance_tolerance: | ||
ids_to_remove.add(coordinates[j].id) | ||
|
||
ids_to_keep = list(ids - ids_to_remove) | ||
basis.filter_atoms_by_ids(ids_to_keep) | ||
return basis | ||
|
||
|
||
def merge_two_bases(basis1: Basis, basis2: Basis, distance_tolerance: float) -> Basis: | ||
basis1 = convert_basis_to_crystal(basis1) | ||
basis2 = convert_basis_to_crystal(basis2) | ||
|
||
merged_elements = basis1.elements | ||
merged_coordinates = basis1.coordinates | ||
merged_labels = basis1.labels | ||
|
||
for coordinate in basis2.coordinates.values: | ||
merged_coordinates.add_item(coordinate) | ||
|
||
for element in basis2.elements.values: | ||
merged_elements.add_item(element) | ||
|
||
if basis2.labels: | ||
for label in basis2.labels.values: | ||
merged_labels.add_item(label) | ||
|
||
merged_basis = Basis( | ||
elements=merged_elements, | ||
coordinates=merged_coordinates, | ||
units=basis1.units, | ||
cell=basis1.cell, | ||
labels=merged_labels, | ||
) | ||
resolved_basis = resolve_close_coordinates_basis(merged_basis, distance_tolerance) | ||
|
||
return resolved_basis | ||
|
||
|
||
def merge_two_materials(material1: Material, material2: Material, distance_tolerance: float) -> Material: | ||
""" | ||
Merge two materials with the same lattice into a single material, | ||
replacing colliding atoms with the latest material's atoms. | ||
""" | ||
|
||
material1 = material1.clone() | ||
material2 = material2.clone() | ||
if material1.lattice != material2.lattice: | ||
raise ValueError("Lattices of the two materials must be the same.") | ||
merged_lattice = material1.lattice | ||
resolved_basis = merge_two_bases(material1.basis, material2.basis, distance_tolerance) | ||
|
||
name = "Merged Material" | ||
new_material = Material.create( | ||
{"name": name, "lattice": merged_lattice.to_json(), "basis": resolved_basis.to_json()} | ||
) | ||
return new_material | ||
|
||
|
||
def merge_materials(materials: List[Material], distance_tolerance: float = 0.01) -> Material: | ||
merged_material = materials[0] | ||
for material in materials[1:]: | ||
merged_material = merge_two_materials(merged_material, material, distance_tolerance) | ||
|
||
return merged_material |
Oops, something went wrong.