Skip to content

Commit

Permalink
Add speciation plot
Browse files Browse the repository at this point in the history
  • Loading branch information
orionarcher committed May 8, 2024
1 parent 74b0395 commit aa0ce51
Showing 1 changed file with 240 additions and 0 deletions.
240 changes: 240 additions & 0 deletions solvation_analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,246 @@ def plot_co_occurrence(
return fig


def plot_speciation_bar(speciation: Union[Speciation, Solute]) -> go.Figure:
    """
    Plot the speciation data as a stacked bar chart.

    One bar trace is added per column of ``speciation.speciation_fraction``
    (excluding the ``"fraction"`` column), with the DataFrame index on the
    x-axis, and the traces are stacked on top of each other.

    Parameters
    ----------
    speciation : Speciation or Solute
        Either a Speciation object or a Solute exposing one through its
        ``speciation`` attribute.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The stacked bar chart. The figure is returned rather than displayed;
        call ``fig.show()`` to render it.

    Raises
    ------
    ValueError
        If a Solute without a ``speciation`` attribute is passed.
    """
    if isinstance(speciation, Solute):
        if not hasattr(speciation, "speciation"):
            raise ValueError("Solute speciation analysis class must be instantiated.")
        speciation = speciation.speciation

    # Every column other than "fraction" is plotted as its own bar trace
    # (presumably one column per solvent -- confirm against Speciation).
    df = speciation.speciation_fraction.drop("fraction", axis=1)

    fig = go.Figure(layout=dict(barcornerradius=15))
    for solvent in df.columns:
        fig.add_trace(go.Bar(x=df.index, y=df[solvent], name=solvent))

    fig.update_layout(
        barmode="stack",
        title="Stacked Bar Chart of Solvents per Solute_ix",
        xaxis_title="Solute Index",
        yaxis_title="Count of Solvents",
    )

    # Return the figure (as the signature promises) instead of showing it, so
    # callers can customise, save or display it themselves.
    return fig


def plot_speciation(speciation: Union[Speciation, Solute]) -> go.Figure:
    """
    Plot the first five rows of the speciation data as stacks of solvent markers.

    For each of the first five rows of ``speciation.speciation_fraction``, one
    coloured marker (with a matching rectangle drawn behind it) is added per
    unit of each non-``"fraction"`` column, stacked vertically at the row's
    index. The ``"fraction"`` column is overlaid as a black line on a
    secondary y-axis.

    Parameters
    ----------
    speciation : Speciation or Solute
        Either a Speciation object or a Solute exposing one through its
        ``speciation`` attribute.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The assembled figure. It is returned rather than displayed; call
        ``fig.show()`` to render it.

    Raises
    ------
    ValueError
        If a Solute without a ``speciation`` attribute is passed.

    Notes
    -----
    The rectangles are positioned at ``x - 0.15`` / ``x + 0.15``, so the
    DataFrame index must be numeric.
    """
    if isinstance(speciation, Solute):
        if not hasattr(speciation, "speciation"):
            raise ValueError("Solute speciation analysis class must be instantiated.")
        speciation = speciation.speciation

    # Only the first five rows are plotted.
    df = speciation.speciation_fraction.head(5)
    fraction_data = df["fraction"]
    df = df.drop("fraction", axis=1)

    # Assign one colour per remaining column; modulo indexing wraps around the
    # palette when there are more columns than colours.
    solvents = df.columns.tolist()
    palette = px.colors.qualitative.Plotly
    color_map = {
        solvent: palette[i % len(palette)] for i, solvent in enumerate(solvents)
    }

    x_vals = []
    y_vals = []
    solvent_names = []
    marker_colors = []

    # Build one marker per unit count. The y position uses the running total
    # for the row, so markers of different solvents stack on top of each other
    # instead of overlapping at the same heights.
    for index, row in df.iterrows():
        stack_height = 0
        for solvent, count in row.items():
            for _ in range(int(count)):
                x_vals.append(index)
                y_vals.append(0.5 + stack_height)
                solvent_names.append(solvent)
                marker_colors.append(color_map[solvent])
                stack_height += 1

    solvent_trace = go.Scatter(
        x=x_vals,
        y=y_vals,
        mode="markers",
        marker=dict(size=10, color=marker_colors),
        text=solvent_names,
        hoverinfo="text",
        name="Solvents",
        legendgroup="solvents",
        showlegend=False,
    )

    fraction_trace = go.Scatter(
        x=df.index,
        y=fraction_data,
        mode="lines+markers",
        name="Fraction",
        yaxis="y2",
        line=dict(color="black"),
    )

    fig = go.Figure(data=[solvent_trace, fraction_trace])

    # The marker trace is a single trace with per-point colours, so it cannot
    # produce per-solvent legend entries; add one empty trace per solvent
    # purely to populate the legend.
    for solvent, color in color_map.items():
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=10, color=color),
                name=solvent,
                legendgroup="solvents",
                showlegend=True,
            )
        )

    # Draw a filled rectangle around every marker position.
    for x, y, color in zip(x_vals, y_vals, marker_colors):
        fig.add_shape(
            type="rect",
            xref="x",
            yref="y",
            x0=x - 0.15,
            y0=y - 0.4,
            x1=x + 0.15,
            y1=y + 0.4,
            line=dict(color=color, width=2),
            fillcolor=color,
            layer="between",
        )

    # Guard against empty data, where a bare max() would raise ValueError.
    max_level = max(y_vals, default=0.0)
    max_fraction = max(fraction_data, default=0.0)
    fraction_top = max_fraction * 1.1 if max_fraction > 0 else 1.0

    fig.update_layout(
        title="Categorical Scatter Plot of Solvents per Solute_ix",
        xaxis_title="Solute Index",
        yaxis=dict(
            title="Stacked Points",
            tickmode="array",
            tickvals=list(range(1, int(max_level) + 1)),
            range=[0, max_level + 1],  # leave headroom above the tallest stack
        ),
        xaxis=dict(tickmode="linear", tick0=0, dtick=1),  # integer x ticks
        template="plotly_white",
        margin=dict(l=20, r=20, t=60, b=20),
        yaxis2=dict(
            title="Fraction",
            overlaying="y",
            side="right",
            range=[0, fraction_top],
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),  # horizontal legend above the plot
    )

    # Return the figure instead of showing it, to allow for more flexibility.
    return fig

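# NOTE(review): plot_rdfs below is an unfinished draft (it still contains a
# TODO and a "need to work on this" branch) and is kept commented out.
# def plot_rdfs(solute, x_axis, clasp=None):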


# rdf_data = solute.rdf_data
# df = pd.DataFrame()
# dataframes = {}
# for atom_solute_name in rdf_data:
# for solvent in rdf_data[atom_solute_name]:
# temp = pd.DataFrame(data=rdf_data[atom_solute_name][solvent]).transpose()
# temp.columns = ["bins", "rdf"]
# temp["solvent"] = solvent
# temp["atom solute"] = atom_solute_name
# dataframes[(atom_solute_name, solvent)] = temp[["bins", "rdf"]]
# df = pd.concat([df, temp])
#
# atom_solutes, solvents = zip(*dataframes.keys());
#
# atom_solutes = set(atom_solutes)
# solvents = set(solvents)
#
# # TODO: check subset of df
# fig = go.Figure()
# # 2x3 grid --> 2x1 grid, each graph has three traces
# if clasp == "x" and x_axis == "atom solute":
# fig = make_subplots(rows=len(atom_solutes), cols=1, x_title="Atom Solute")
# r = 1
# for atom_solute in atom_solutes:
# temp = df.loc[df["atom solute"] == atom_solute]
# fig.add_trace(go.Scatter(x=temp["bins"], y=temp["rdf"]), row=r, col=1)
# fig.update_xaxes(title_text=atom_solute, row=r, col=1)
# r += 1
# # 3x2 grid --> 3x1 grid, each graph has two traces
# elif clasp == "x" and x_axis == "solvent":
# fig = make_subplots(rows=1, cols=len(atom_solutes), x_title="Solvent")
# c = 1
# for solvent in solvents:
# temp = df.loc[df["solvent"] == solvent]
# fig.add_trace(go.Scatter(x=temp["bins"], y=temp["rdf"]), row=1, col=c)
# fig.update_xaxes(title_text=solvent, row=1, col=c)
# c += 1
# # 2x3 grid --> 1x3 grid, each graph has two traces
# elif clasp == "y" and x_axis == "atom solute":
# fig = make_subplots(rows=len(solvents), cols=len(atom_solutes), x_title="Atom Solute")
# c = 1
# for atom_solute in atom_solutes:
# temp = df.loc[df["atom solute"] == atom_solute]
# fig.add_trace(go.Scatter(x=temp["bins"], y=temp["rdf"]), row=1, col=c)
# fig.update_xaxes(title_text=atom_solute, row=1, col=c)
# c += 1
# # 3x2 grid --> 1x2 grid, each graph has three traces
# elif clasp == "y" and x_axis == "solvent":
# print("need to work on this")
# elif x_axis == "atom solute":
# fig = make_subplots(rows=len(solvents), cols=len(atom_solutes), shared_xaxes=True, shared_yaxes=True, x_title="Atom Solute", y_title="Solvent")
#
# r = 1
# for solvent in solvents:
# c = 1
# for atom_solute in atom_solutes:
# data = dataframes[(atom_solute, solvent)]
# fig.add_trace(go.Scatter(x=data["bins"], y=data["rdf"]), row=r, col=c)
# fig.update_xaxes(title_text=atom_solute, row=r, col=c)
# fig.update_yaxes(title_text=solvent, row=r, col=c)
# c += 1
# r += 1
# elif x_axis == "solvent":
# fig = make_subplots(rows=len(atom_solutes), cols=len(solvents), shared_xaxes=True, shared_yaxes=True, x_title="Solvent", y_title="Atom Solute")
#
# r = 1
# for atom_solute in atom_solutes:
# c = 1
# for solvent in solvents:
# data = dataframes[(atom_solute, solvent)]
# fig.add_trace(go.Scatter(x=data["bins"], y=data["rdf"]), row=r, col=c)
# fig.update_xaxes(title_text=solvent, row=r, col=c)
# fig.update_yaxes(title_text=atom_solute, row=r, col=c)
# c += 1
# r += 1
#
# fig.update_layout(showlegend=False, template="simple_white")
# return fig


def compare_solvent_dicts(
property_dict: dict[str, dict[str, float]],
rename_solvent_dict: dict[str, str],
Expand Down

0 comments on commit aa0ce51

Please sign in to comment.