Cleanup dataframes #1360

Merged

Changes from 9 commits of 31 total

Commits
78a967a
Update internals of AnalysisResultTable
nkanazawa1989 Jan 10, 2024
7a83179
Update internals of ScatterTable
nkanazawa1989 Jan 18, 2024
3bc9525
Removed unused mixin
nkanazawa1989 Jan 18, 2024
68013a8
Fix index mismatch issue after JSON serialization
nkanazawa1989 Jan 31, 2024
5032e86
Add more tests
nkanazawa1989 Jan 31, 2024
7736f19
Bug fixes
nkanazawa1989 Jan 31, 2024
8bbaa15
Merge branch 'main' of github.com:Qiskit/qiskit-experiments into clea…
nkanazawa1989 Jan 31, 2024
fc9273e
Unpin pandas 2.2
nkanazawa1989 Jan 31, 2024
0cae116
Update old pattern
nkanazawa1989 Jan 31, 2024
2fb28dc
Fix cross-reference
nkanazawa1989 Feb 1, 2024
ac972fd
Update curve analysis tutorial
nkanazawa1989 Feb 2, 2024
01471bb
Add shortcut methods
nkanazawa1989 Feb 2, 2024
8dc6c4f
Bugfix autosave
nkanazawa1989 Feb 2, 2024
144127a
Raise user warning when numbers contain multiple series
nkanazawa1989 Feb 2, 2024
a81f97c
Merge branch 'main' into cleanup/more_composition
nkanazawa1989 Feb 2, 2024
7c0662c
Bugfix: Missing circuit metadata in composite analysis
nkanazawa1989 Feb 2, 2024
92cfc92
Replace class_id with data_uid
nkanazawa1989 Feb 5, 2024
346d23a
Add documentation for filtering triplet
nkanazawa1989 Feb 5, 2024
ee03161
Apply review comments
nkanazawa1989 Feb 5, 2024
ee5b34d
Wording suggestions
nkanazawa1989 Feb 6, 2024
38abdff
Remove DEFAULT_
nkanazawa1989 Feb 6, 2024
9e27f16
Reorganize the doc
nkanazawa1989 Feb 6, 2024
b870be3
Remove _data
nkanazawa1989 Feb 6, 2024
cc905c6
Remove key from add_data
nkanazawa1989 Feb 6, 2024
0dc4eb2
Remove type cast depending on the entry number
nkanazawa1989 Feb 6, 2024
f8c1efe
Minor docs formatting
nkanazawa1989 Feb 6, 2024
ee92f1d
Add more tests for result table
nkanazawa1989 Feb 6, 2024
03aac67
Performance optimization
nkanazawa1989 Feb 6, 2024
ac5bdd8
name, data_uid -> series_name, series_id
nkanazawa1989 Feb 6, 2024
58671eb
Add more tests for construction
nkanazawa1989 Feb 6, 2024
7ff2c6a
Update Ramsey analysis
nkanazawa1989 Feb 6, 2024
1 change: 1 addition & 0 deletions qiskit_experiments/curve_analysis/__init__.py
@@ -39,6 +39,7 @@
 .. autosummary::
     :toctree: ../stubs/
 
+    ScatterTable
     SeriesDef
     CurveData
     CurveFitResult
72 changes: 38 additions & 34 deletions qiskit_experiments/curve_analysis/composite_curve_analysis.py
@@ -230,34 +230,35 @@ def _create_figures(
             A list of figures.
         """
         for analysis in self.analyses():
-            sub_data = curve_data[curve_data.group == analysis.name]
-            for name, data in list(sub_data.groupby("name")):
-                full_name = f"{name}_{analysis.name}"
+            group_data = curve_data.filter(analysis=analysis.name)
+            model_names = analysis.model_names()
+            for uid, sub_data in group_data.iter_by_class():
+                full_name = f"{model_names[uid]}_{analysis.name}"
                 # Plot raw data scatters
                 if analysis.options.plot_raw_data:
-                    raw_data = data[data.category == "raw"]
+                    raw_data = sub_data.filter(category="raw")
                     self.plotter.set_series_data(
                         series_name=full_name,
-                        x=raw_data.xval.to_numpy(),
-                        y=raw_data.yval.to_numpy(),
+                        x=raw_data.x,
+                        y=raw_data.y,
                     )
                 # Plot formatted data scatters
-                formatted_data = data[data.category == analysis.options.fit_category]
+                formatted_data = sub_data.filter(category=analysis.options.fit_category)
                 self.plotter.set_series_data(
                     series_name=full_name,
-                    x_formatted=formatted_data.xval.to_numpy(),
-                    y_formatted=formatted_data.yval.to_numpy(),
-                    y_formatted_err=formatted_data.yerr.to_numpy(),
+                    x_formatted=formatted_data.x,
+                    y_formatted=formatted_data.y,
+                    y_formatted_err=formatted_data.y_err,
                 )
                 # Plot fit lines
-                line_data = data[data.category == "fitted"]
+                line_data = sub_data.filter(category="fitted")
                 if len(line_data) == 0:
                     continue
-                fit_stdev = line_data.yerr.to_numpy()
+                fit_stdev = line_data.y_err
                 self.plotter.set_series_data(
                     series_name=full_name,
-                    x_interp=line_data.xval.to_numpy(),
-                    y_interp=line_data.yval.to_numpy(),
+                    x_interp=line_data.x,
+                    y_interp=line_data.y,
                     y_interp_err=fit_stdev if np.isfinite(fit_stdev).all() else None,
                 )

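The hunk above replaces ad-hoc pandas selections (boolean masks on columns, groupby("name"), and .to_numpy() calls) with the new ScatterTable interface: filter selects rows by the analysis/series/category triplet, iter_by_class yields one sub-table per fit model, and the x, y, and y_err properties return numpy arrays directly. A minimal sketch of that access pattern, assuming a populated ScatterTable named table (variable names here are illustrative only, not from the diff):

    # Hypothetical standalone usage of the API adopted in this hunk.
    raw = table.filter(category="raw")             # rows holding unformatted points
    for class_id, series in raw.iter_by_class():   # one sub-table per fit model
        x_arr = series.x                           # numpy array, no .to_numpy() needed
        y_arr = series.y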
@@ -354,7 +355,7 @@ def _run_analysis(
             metadata["group"] = analysis.name
 
             table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
-            formatted_subset = table[table.category == analysis.options.fit_category]
+            formatted_subset = table.filter(category=analysis.options.fit_category)
             fit_data = analysis._run_curve_fit(formatted_subset)
             fit_dataset[analysis.name] = fit_data
@@ -376,32 +377,35 @@
 
             if fit_data.success:
                 # Add fit data to curve data table
-                fit_curves = []
-                columns = list(table.columns)
                 model_names = analysis.model_names()
-                for i, sub_data in list(formatted_subset.groupby("class_id")):
-                    xval = sub_data.xval.to_numpy()
+                for i, sub_data in formatted_subset.iter_by_class():
+                    xval = sub_data.x
                     if len(xval) == 0:
                         # If data is empty, skip drawing this model.
                         # This is the case when fit model exist but no data to fit is provided.
                         continue
                     # Compute X, Y values with fit parameters.
-                    xval_fit = np.linspace(np.min(xval), np.max(xval), num=100)
-                    yval_fit = eval_with_uncertainties(
-                        x=xval_fit,
+                    xval_arr_fit = np.linspace(np.min(xval), np.max(xval), num=100, dtype=float)
+                    uval_arr_fit = eval_with_uncertainties(
+                        x=xval_arr_fit,
                         model=analysis.models[i],
                         params=fit_data.ufloat_params,
                     )
-                    model_fit = np.full((100, len(columns)), np.nan, dtype=object)
-                    fit_curves.append(model_fit)
-                    model_fit[:, columns.index("xval")] = xval_fit
-                    model_fit[:, columns.index("yval")] = unp.nominal_values(yval_fit)
+                    yval_arr_fit = unp.nominal_values(uval_arr_fit)
                     if fit_data.covar is not None:
-                        model_fit[:, columns.index("yerr")] = unp.std_devs(yval_fit)
-                    model_fit[:, columns.index("name")] = model_names[i]
-                    model_fit[:, columns.index("class_id")] = i
-                    model_fit[:, columns.index("category")] = "fitted"
-                table = table.append_list_values(other=np.vstack(fit_curves))
+                        yerr_arr_fit = unp.std_devs(uval_arr_fit)
+                    else:
+                        yerr_arr_fit = np.zeros_like(xval_arr_fit)
+                    for xval, yval, yerr in zip(xval_arr_fit, yval_arr_fit, yerr_arr_fit):
Collaborator:

It still surprises me that it is better to iterate over numpy arrays point by point and add them to lists to add to a new dataframe, rather than just adding the numpy arrays to a new dataframe.

Collaborator (author):

Handling of the empty column is expensive because it requires careful handling of missing values. Without this, the shots column may be accidentally typecast to float, because numpy doesn't support nullable integers. This means we first need to create a 2D object-dtype ndarray and populate the values, then convert it into a dataframe. Since the current _lazy_add_rows buffer assumes a row-wise data list, arrays need to be converted into this form internally.

+                        table.add_row(
+                            name=model_names[i],
+                            class_id=i,
+                            category="fitted",
+                            x=xval,
+                            y=yval,
+                            y_err=yerr,
+                            analysis=analysis.name,
+                        )
                 analysis_results.extend(
                     analysis._create_analysis_results(
                         fit_data=fit_data,
@@ -416,11 +420,11 @@
                     analysis._create_curve_data(curve_data=formatted_subset, **metadata)
                 )
 
-                # Add extra column to identify the fit model
-                table["group"] = analysis.name
                 curve_data_set.append(table)
 
-        combined_curve_data = pd.concat(curve_data_set)
+        combined_curve_data = ScatterTable.from_dataframe(
+            pd.concat([d.dataframe for d in curve_data_set])
+        )
         total_quality = self._evaluate_quality(fit_dataset)
 
         # After the quality is determined, plot can become a boolean flag for whether
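A note on the combining step above: pd.concat both consumes and returns plain pandas dataframes, so each ScatterTable is unwrapped through its dataframe attribute and the concatenated result is wrapped back up with ScatterTable.from_dataframe. A minimal sketch of the round trip, assuming a list of populated tables (illustrative names only):

    # Hypothetical: merge per-analysis ScatterTables into a single table.
    frames = [t.dataframe for t in curve_data_set]        # unwrap to pandas
    combined = ScatterTable.from_dataframe(pd.concat(frames))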