From c7d68978c7944eb156f47e03a6ce45e9e5e6afd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Tue, 4 Jun 2024 19:33:18 -0400 Subject: [PATCH] ENH: Add gaussian process DWI signal representation notebook Add gaussian process DWI signal representation notebook. --- docs/notebooks/dwi_gp_representation.ipynb | 186 +++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 docs/notebooks/dwi_gp_representation.ipynb diff --git a/docs/notebooks/dwi_gp_representation.ipynb b/docs/notebooks/dwi_gp_representation.ipynb new file mode 100644 index 00000000..0f91ecd8 --- /dev/null +++ b/docs/notebooks/dwi_gp_representation.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "Gaussian process notebook", + "id": "486923b289155658" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n", + "\n", + "from eddymotion import model\n", + "from eddymotion.data.dmri import DWI\n", + "from eddymotion.data.splitting import lovo_split\n", + "\n", + "datadir = Path(\"../../test\") # Adapt to your local path or download to a temp location using wget\n", + "\n", + "kernel = DotProduct() + WhiteKernel()\n", + "\n", + "dwi = DWI.from_filename(datadir / \"dwi.h5\")\n", + "\n", + "_dwi_data = dwi.dataobj\n", + "# Use a subset of the data for now to see that something is written to the\n", + "# output\n", + "# bvecs = dwi.gradients[:3, :].T\n", + "bvecs = dwi.gradients[:3, 10:13].T # b0 values have already been masked\n", + "# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\n", + "dwi_data = _dwi_data[60:63, 60:64, 40:45, 10:13]\n", + "\n", + "# ToDo\n", + "# Provide proper values/estimates for these\n", + "a = 1\n", + "h = 1 # should be a NIfTI image\n", + "\n", + "num_iterations = 5\n", + "gp = model.GaussianProcessModel(\n", + " dwi=dwi, a=a, h=h, kernel=kernel, num_iterations=num_iterations\n", + ")\n", + "indices = list(range(bvecs.shape[0]))\n", + "# ToDo\n", + "# This should be done within the GP model class\n", + "# Apply lovo strategy properly\n", + "# Vectorize and parallelize\n", + "result_mean = np.zeros_like(dwi_data)\n", + "result_stddev = np.zeros_like(dwi_data)\n", + "for idx in indices:\n", + " lovo_idx = np.ones(len(indices), dtype=bool)\n", + " lovo_idx[idx] = False\n", + " X = bvecs[lovo_idx]\n", + " for i in range(dwi_data.shape[0]):\n", + " for j in range(dwi_data.shape[1]):\n", + " for k in range(dwi_data.shape[2]):\n", + " # ToDo\n", + " # Use a mask to avoid traversing background data\n", + " y = dwi_data[i, j, k, lovo_idx]\n", + " gp.fit(X, y)\n", + " pred_mean, pred_stddev = gp.predict(\n", + " bvecs[idx, :][np.newaxis]\n", + " ) # Can take multiple values X[:2, :]\n", + " result_mean[i, j, k, idx] = pred_mean.item()\n", + " result_stddev[i, j, k, idx] = pred_stddev.item()" + ], + "id": "da2274009534db61", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the data", + "id": "77e77cd4c73409d3" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from matplotlib import pyplot as plt \n", + "%matplotlib inline\n", + "\n", + "s = dwi_data[1, 1, 2, :]\n", + "s_hat_mean = result_mean[1, 1, 2, :]\n", + "s_hat_stddev = result_stddev[1, 1, 2, :]\n", + "x = np.asarray(indices)\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.plot(x, s_hat_mean, c=\"orange\", label=\"predicted\")\n", + "plt.fill_between(\n", + " x.ravel(),\n", + " s_hat_mean - 1.96 * s_hat_stddev,\n", + " s_hat_mean + 1.96 * s_hat_stddev,\n", + " alpha=0.5,\n", + " color=\"orange\",\n", + " label=r\"95% confidence interval\",\n", + ")\n", + "plt.scatter(x, s, c=\"b\", label=\"ground truth\")\n", + "ax.set_xlabel(\"bvec indices\")\n", + "ax.set_ylabel(\"signal\")\n", + "ax.legend()\n", + "plt.title(\"Gaussian process regression on dataset\")\n", + "\n", + "plt.show()" + ], + "id": "4e51f22890fb045a", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "Plot the DWI signal for a given voxel\n", + "Compute the DWI signal value wrt the b0 (how much larger/smaller is and add that delta to the unit sphere?) for each bvec direction and plot that?" + ], + "id": "694a4c075457425d" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "# from mpl_toolkits.mplot3d import Axes3D\n", + "# fig, ax = plt.subplots()\n", + "# ax = fig.add_subplot(111, projection='3d')\n", + "# plt.scatter(xx, yy, zz)" + ], + "id": "bb7d2aef53ac99f0", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the DWI signal brain data\n", + "id": "62d7bc609b65c7cf" + }, + { + "metadata": {}, + "cell_type": "code", + "source": "# plot_dwi(dmri_dataset.dataobj, dmri_dataset.affine, gradient=data_test[1], black_bg=True)", + "id": "edb0e9d255516e38", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the predicted DWI signal", + "id": "1a52e2450fc61dc6" + }, + { + "metadata": {}, + "cell_type": "code", + "source": "# plot_dwi(predicted, dmri_dataset.affine, gradient=data_test[1], black_bg=True);", + "id": "66150cf337b395e0", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}