Skip to content

Commit

Permalink
Merge branch 'main' of github.com:Exabyte-io/utils
Browse files Browse the repository at this point in the history
  • Loading branch information
timurbazhirov committed Jan 18, 2025
2 parents 1f5b71c + d85535c commit 5d61c51
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ extra = [
]
jupyterlite = [
"pyyaml",
"matplotlib>=3.4.1",
"plotly>=5.18",
"ipython>=8.0",
]
dev = [
"pre-commit",
Expand Down
214 changes: 214 additions & 0 deletions src/py/mat3ra/utils/jupyterlite/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from IPython.display import display


def scatter_plot_2d(
x_values: List[float],
y_values: List[float],
hover_texts: List[str],
settings: Dict[str, Any],
trace_names: Optional[List[str]] = None,
) -> go.Figure:
"""
Create a generic 2D scatter plot.
Args:
x_values: List of x-coordinates
y_values: List of y-coordinates
hover_texts: List of hover texts for each point
settings: Plot settings including scales, height, and titles
trace_names: Optional list of names for each trace
"""
data = []
for i in range(len(x_values)):
trace = go.Scatter(
x=[x_values[i]],
y=[y_values[i]],
text=[hover_texts[i]],
mode="markers",
hoverinfo="text",
name=trace_names[i] if trace_names else f"Point {i}",
)
data.append(trace)

layout = go.Layout(
xaxis=dict(title=settings.get("x_title", "X"), type=settings.get("x_scale", "linear")),
yaxis=dict(title=settings.get("y_title", "Y"), type=settings.get("y_scale", "linear")),
hovermode="closest",
height=settings.get("height", 600),
title=settings.get("title", ""),
legend_title_text=settings.get("legend_title", ""),
)

return go.Figure(data=data, layout=layout)


def create_realtime_plot(
title: str = "Real-time Progress", x_label: str = "Step", y_label: str = "Value"
) -> go.FigureWidget:
"""
Create a real-time updating plot.
"""
fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter"}]])
scatter = go.Scatter(x=[], y=[], mode="lines+markers", name="Progress")
fig.add_trace(scatter)
fig.update_layout(title_text=title, xaxis_title=x_label, yaxis_title=y_label)
widget = go.FigureWidget(fig)
display(widget) # Automatically display the widget
return widget


def create_update_callback(
dynamic_object: Any,
value_getter: Union[Callable, Any],
figure: go.FigureWidget,
steps: List[int],
values: List[float],
step_attr: str = "nsteps",
print_format: str = "Step: {}, Value: {:.4f}",
) -> Callable:
"""
Create a general update callback for real-time plotting.
Args:
dynamic_object: Object containing step information
value_getter: Either a callable function or an object with a getter method
figure: Plotly figure widget to update
steps: List to store step values
values: List to store measured values
step_attr: Attribute name for step count in dynamic_object
print_format: Format string for progress printing
"""

def update():
step = getattr(dynamic_object, step_attr)
# Handle both callable and object with getter method
value = value_getter() if callable(value_getter) else value_getter.get_total_energy()

steps.append(step)
values.append(value)

print(print_format.format(step, value))
with figure.batch_update():
figure.data[0].x = steps
figure.data[0].y = values

return update


def plot_distribution_function(
bin_centers: np.ndarray,
distribution: np.ndarray,
xlabel: str = "Distance",
ylabel: str = "g(r)",
title: str = "Distribution Function",
figsize: Tuple[int, int] = (8, 5),
) -> None:
"""
Plot a generic distribution function.
Args:
bin_centers: The bin centers.
distribution: The distribution values.
xlabel: The x-axis label.
ylabel: The y-axis label.
title: The title of the plot.
figsize: The size of the figure.
"""
plt.figure(figsize=figsize)
plt.plot(bin_centers, distribution, label=title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.legend()
plt.grid()
plt.show()


def plot_3d_surface(
x_matrix: np.ndarray,
y_matrix: np.ndarray,
z_matrix: np.ndarray,
optimal_point: Optional[Tuple[float, float]] = None,
title: str = "Surface Plot",
labels: Optional[Dict[str, str]] = None,
) -> None:
"""
Create a 3D surface plot with optional optimal point.
Args:
x_matrix: The x-axis matrix.
y_matrix: The y-axis matrix.
z_matrix: The z-axis matrix.
optimal_point: The optimal point to highlight.
title: The title of the plot.
labels: The labels for the axes.
"""
if labels is None:
labels = {"x": "X", "y": "Y", "z": "Z"}

fig = go.Figure(data=[go.Surface(x=x_matrix, y=y_matrix, z=z_matrix, colorscale="Viridis")])

if optimal_point is not None:
x_opt, y_opt = optimal_point
z_opt = np.min(z_matrix)
fig.add_trace(
go.Scatter3d(
x=[x_opt], y=[y_opt], z=[z_opt], mode="markers", marker=dict(size=8, color="red"), name="Optimal Point"
)
)

fig.update_layout(
title=title,
scene=dict(xaxis_title=labels["x"], yaxis_title=labels["y"], zaxis_title=labels["z"]),
width=800,
height=800,
)
fig.show()


def plot_2d_heatmap(
x_values: np.ndarray,
y_values: np.ndarray,
z_matrix: np.ndarray,
optimal_point: Optional[Tuple[float, float]] = None,
title: str = "Heatmap",
labels: Optional[Dict[str, str]] = None,
) -> None:
"""
Create a 2D heatmap with optional optimal point.
Args:
x_values: The x-axis values.
y_values: The y-axis values.
z_matrix: The z-axis matrix.
optimal_point: The optimal point to highlight.
title: The title of the plot.
labels: The labels for the axes.
"""
if labels is None:
labels = {"x": "X", "y": "Y", "z": "Z"}

fig = go.Figure(
data=go.Heatmap(x=x_values, y=y_values, z=z_matrix, colorscale="Viridis", colorbar=dict(title=labels["z"]))
)

if optimal_point is not None:
x_opt, y_opt = optimal_point
fig.add_trace(
go.Scatter(
x=[x_opt],
y=[y_opt],
mode="markers",
marker=dict(size=12, color="red", symbol="x"),
name="Optimal Point",
)
)

fig.update_layout(title=title, xaxis_title=labels["x"], yaxis_title=labels["y"], width=800, height=600)
fig.show()

0 comments on commit 5d61c51

Please sign in to comment.