From af4b689691fe4513f702ac8781311dfcea9b2dca Mon Sep 17 00:00:00 2001
From: littlecabiria <71769896+littlecabiria@users.noreply.github.com>
Date: Fri, 2 Aug 2024 21:42:13 +1000
Subject: [PATCH] update spatial level metric

---
 .../ks_statistic_spatial/config.vsh.yaml   |  33 +++++
 src/metrics/ks_statistic_spatial/script.py | 140 ++++++++++++++++++
 2 files changed, 173 insertions(+)
 create mode 100644 src/metrics/ks_statistic_spatial/config.vsh.yaml
 create mode 100644 src/metrics/ks_statistic_spatial/script.py

diff --git a/src/metrics/ks_statistic_spatial/config.vsh.yaml b/src/metrics/ks_statistic_spatial/config.vsh.yaml
new file mode 100644
index 00000000..996ba8d8
--- /dev/null
+++ b/src/metrics/ks_statistic_spatial/config.vsh.yaml
@@ -0,0 +1,33 @@
+__merge__: ../../api/comp_metric.yaml
+
+name: ks_statistic_spatial
+
+info:
+  metrics:
+    - name: ks_statistic_transition_matrix
+      label: KS statistic of the cell-type transition matrix
+      summary: Kolmogorov-Smirnov statistic comparing the neighborhood cell-type transition matrices of the real and simulated spatial data
+      description: |
+        The two-sample Kolmogorov-Smirnov statistic comparing the entries of the
+        neighborhood cell-type transition matrix, computed on a Delaunay spatial
+        graph, of the real spatial counts versus the simulated counts.
+      # reference: doi?
+      # documentation_url: https://url.to/the/documentation
+      # repository_url: https://github.com/organisation/repository
+      min: 0
+      max: 1
+      maximize: false
+    - name: ks_statistic_central_score
+      label: KS statistic of the cell-type centrality scores
+      summary: Kolmogorov-Smirnov statistic comparing the per-cell-type centrality scores of the real and simulated spatial data
+      description: |
+        The two-sample Kolmogorov-Smirnov statistic comparing the graph centrality
+        scores (squidpy.gr.centrality_scores) of the real spatial data versus the
+        simulated data.
+      min: 0
+      max: 1
+      maximize: false
+    - name: ks_statistic_enrichment
+      label: KS statistic of the neighborhood enrichment
+      summary: Kolmogorov-Smirnov statistic comparing the neighborhood-enrichment z-scores of the real and simulated spatial data
+      description: |
+        The two-sample Kolmogorov-Smirnov statistic comparing the neighborhood-enrichment
+        z-scores (squidpy.gr.nhood_enrichment) of the real spatial data versus the
+        simulated data.
+      min: 0
+      max: 1
+      maximize: false
+
+resources:
+  - type: python_script
+    path: script.py
+
+engines:
+  - type: docker
+    image: ghcr.io/openproblems-bio/base_images/python:1.1.0
+    # assumption: squidpy is not preinstalled in the base image
+    setup:
+      - type: python
+        packages: [squidpy]
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [midtime, midmem, midcpu]
diff --git a/src/metrics/ks_statistic_spatial/script.py b/src/metrics/ks_statistic_spatial/script.py
new file mode 100644
index 00000000..004861c6
--- /dev/null
+++ b/src/metrics/ks_statistic_spatial/script.py
@@ -0,0 +1,140 @@
+import anndata as ad
+import numpy as np
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.preprocessing import LabelEncoder
+import squidpy as sq
+from scipy.stats import ks_2samp
+
+## VIASH START
+par = {
+    'input_spatial_dataset': 'resources_test/datasets/MOBNEW/dataset_sp.h5ad',
+    'input_singlecell_dataset': 'resources_test/datasets/MOBNEW/dataset_sc.h5ad',
+    'input_simulated_dataset': 'resources_test/datasets/MOBNEW/simulated_dataset.h5ad',
+    'output': 'output.h5ad'
+}
+meta = {
+    'name': 'my_metric'
+}
+## VIASH END
+
+print('Reading input files', flush=True)
+input_spatial_dataset = ad.read_h5ad(par['input_spatial_dataset'])
+input_singlecell_dataset = ad.read_h5ad(par['input_singlecell_dataset'])
+input_simulated_dataset = ad.read_h5ad(par['input_simulated_dataset'])
+
+
+def get_spatial_network(num_sample=None, spatial=None, radius=None, coord_type="grid", n_rings=2, set_diag=False):
+    """Build a Delaunay spatial neighbour graph and return its connectivity matrix."""
+    spatial_adata = ad.AnnData(np.empty((num_sample, 1), dtype="float32"))
+    spatial_adata.obsm["spatial"] = spatial
+    sq.gr.spatial_neighbors(spatial_adata, n_rings=n_rings, coord_type=coord_type, radius=radius,
+                            set_diag=set_diag, delaunay=True)
+    sn = spatial_adata.obsp["spatial_connectivities"]
+    return sn
+
+
+def get_onehot_ct(init_assign=None):
+    """One-hot encode a vector of cell-type labels into a (cells x cell types) matrix."""
+    label_encoder = LabelEncoder()
+    integer_encoded = label_encoder.fit_transform(init_assign)
+    onehot_encoder = OneHotEncoder(sparse_output=False)
+    integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
+    onehot_ct = onehot_encoder.fit_transform(integer_encoded)
+    return onehot_ct.astype(np.float32)
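+
+# Illustrative usage of get_onehot_ct (not executed by the component): three
+# cells assigned to two cell types yield a cells-by-types indicator matrix,
+#   get_onehot_ct(init_assign=["A", "B", "A"])
+#   -> [[1., 0.],
+#       [0., 1.],
+#       [1., 0.]]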
+
+
+def get_nb_freq(nb_count=None, onehot_ct=None):
+    """Row-normalise neighbour counts per cell type into transition frequencies."""
+    nb_freq = np.dot(onehot_ct.T, nb_count)
+    res = nb_freq / nb_freq.sum(axis=1).reshape(onehot_ct.shape[1], -1)
+    return res
+
+
+def get_trans(adata=None, ct=None):
+    """Compute the neighborhood cell-type transition matrix of a spatial dataset."""
+    sn = get_spatial_network(num_sample=adata.obs.shape[0],
+                             spatial=adata.obsm["spatial"], coord_type="generic")
+    onehot_ct = get_onehot_ct(init_assign=ct)
+    nb_count = np.array(sn * onehot_ct, dtype=np.float32)
+    target_trans = get_nb_freq(nb_count=nb_count, onehot_ct=onehot_ct)
+    return target_trans
+
+
+# Prepare the real spatial dataset: coordinates and categorical cell-type labels.
+input_spatial_dataset.obsm["spatial"] = input_spatial_dataset.obs[['col', 'row']].to_numpy()
+input_spatial_dataset.obs["celltype"] = input_spatial_dataset.obs["spatial_cluster"].astype('category')
+
+sq.gr.spatial_neighbors(input_spatial_dataset, coord_type="generic", set_diag=False, delaunay=True)
+# neighborhood enrichment matrix
+sq.gr.nhood_enrichment(input_spatial_dataset, cluster_key="celltype")
+# centrality scores matrix
+sq.gr.centrality_scores(input_spatial_dataset, cluster_key="celltype")
+
+# Prepare the simulated dataset in the same way.
+input_simulated_dataset.obsm["spatial"] = input_simulated_dataset.obs[['col', 'row']].to_numpy()
+input_simulated_dataset.obs["celltype"] = input_simulated_dataset.obs["spatial_cluster"].astype('category')
+
+sq.gr.spatial_neighbors(input_simulated_dataset, coord_type="generic", set_diag=False, delaunay=True)
+# neighborhood enrichment matrix
+sq.gr.nhood_enrichment(input_simulated_dataset, cluster_key="celltype")
+# centrality scores matrix
+sq.gr.centrality_scores(input_simulated_dataset, cluster_key="celltype")
+
+# KS statistic of the neighborhood-enrichment z-scores
+target_enrich_real = input_spatial_dataset.uns["celltype_nhood_enrichment"]["zscore"]
+target_enrich_scale_real = target_enrich_real / np.max(target_enrich_real)
+target_enrich_sim = input_simulated_dataset.uns["celltype_nhood_enrichment"]["zscore"]
+target_enrich_scale_sim = target_enrich_sim / np.max(target_enrich_sim)
+
+# Frobenius-norm errors of the (scaled) enrichment matrices; computed for reference
+# but not reported among the metrics below.
+error_enrich = np.linalg.norm(target_enrich_sim - target_enrich_real)
+error_enrich_scale = np.linalg.norm(target_enrich_scale_sim - target_enrich_scale_real)
+
+target_enrich_real_ds = target_enrich_real.flatten()
+target_enrich_sim_ds = target_enrich_sim.flatten()
+ks_enrich, _ = ks_2samp(target_enrich_real_ds, target_enrich_sim_ds)
+
+# KS statistic of the per-cell-type centrality scores
+central_real = np.array(input_spatial_dataset.uns["celltype_centrality_scores"])
+central_sim = np.array(input_simulated_dataset.uns["celltype_centrality_scores"])
+
+central_real_ds = central_real.flatten()
+central_sim_ds = central_sim.flatten()
+ks_central, _ = ks_2samp(central_real_ds, central_sim_ds)
+
+# KS statistic of the cell-type transition matrices
+real = np.array(input_spatial_dataset.obs['spatial_cluster'].values.tolist())
+sim = np.array(input_simulated_dataset.obs['spatial_cluster'].values.tolist())
+
+transition_matrix_real = get_trans(adata=input_spatial_dataset, ct=real)
+transition_matrix_sim = get_trans(adata=input_simulated_dataset, ct=sim)
+
+error = np.linalg.norm(transition_matrix_sim - transition_matrix_real)
+transition_matrix_real_ds = transition_matrix_real.flatten()
+transition_matrix_sim_ds = transition_matrix_sim.flatten()
+ks_stat_error, _ = ks_2samp(transition_matrix_real_ds, transition_matrix_sim_ds)
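+
+# Sanity check (illustrative only): the two-sample KS statistic lies in [0, 1];
+# identical samples give 0 and fully separated samples give 1, e.g.
+#   ks_2samp(np.array([0., 0., 0.]), np.array([0., 0., 0.])).statistic  # -> 0.0
+#   ks_2samp(np.array([0., 0., 0.]), np.array([1., 1., 1.])).statistic  # -> 1.0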
+
+uns_metric_ids = [
+    "ks_statistic_transition_matrix",
+    "ks_statistic_central_score",
+    "ks_statistic_enrichment"
+]
+
+uns_metric_values = [
+    ks_stat_error,
+    ks_central,
+    ks_enrich
+]
+
+print("Write output AnnData to file", flush=True)
+output = ad.AnnData(
+    uns={
+        'dataset_id': input_simulated_dataset.uns['dataset_id'],
+        'method_id': input_simulated_dataset.uns['method_id'],
+        'metric_ids': uns_metric_ids,
+        'metric_values': uns_metric_values
+    }
+)
+output.write_h5ad(par['output'], compression='gzip')
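+
+# To inspect the result outside the pipeline (illustrative; assumes the default
+# 'output.h5ad' path from the VIASH placeholder above):
+#   scores = ad.read_h5ad('output.h5ad')
+#   print(dict(zip(scores.uns['metric_ids'], scores.uns['metric_values'])))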