-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4c71207
commit af4b689
Showing
2 changed files
with
173 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
__merge__: ../../api/comp_metric.yaml | ||
|
||
name: ks_statistic | ||
|
||
info: | ||
metrics: | ||
- name: ks_statistic_frac_zero_genes | ||
label: Fraction of zeros in genes | ||
summary: Ks Statistic of the fraction of zeroes in the genes | ||
description: | | ||
The Kolmogorov-Smirnov statistic comparing the fraction of zeros in the | ||
genes of the real counts versus the fraction of zeros in the genes of | ||
the predicted counts. | ||
# reference: doi? | ||
# documentation_url: https://url.to/the/documentation | ||
# repository_url: https://github.com/organisation/repository | ||
min: -Inf | ||
max: +Inf | ||
maximize: false | ||
|
||
resources: | ||
- type: python_script | ||
path: script.py | ||
|
||
engines: | ||
- type: docker | ||
image: ghcr.io/openproblems-bio/base_images/python:1.1.0 | ||
|
||
runners: | ||
- type: executable | ||
- type: nextflow | ||
directives: | ||
label: [midtime,midmem,midcpu] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import anndata as ad | ||
import numpy as np | ||
from sklearn.preprocessing import OneHotEncoder | ||
from sklearn.preprocessing import LabelEncoder | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
import warnings | ||
import squidpy as sq | ||
import pandas as pd | ||
from scipy.stats import ks_2samp | ||
|
||
## VIASH START | ||
par = { | ||
'input_spatial_dataset': 'resources_test/datasets/MOBNEW/dataset_sp.h5ad', | ||
'input_singlecell_dataset': 'resources_test/datasets/MOBNEW/dataset_sc.h5ad', | ||
'input_simulated_dataset': 'resources_test/datasets/MOBNEW/simulated_dataset.h5ad', | ||
'output': 'output.h5ad' | ||
} | ||
meta = { | ||
'name': 'my_metric' | ||
} | ||
|
||
## VIASH END | ||
|
||
print('Reading input files', flush=True) | ||
input_spatial_dataset = ad.read_h5ad(par['input_spatial_dataset']) | ||
input_singlecell_dataset = ad.read_h5ad(par['input_singlecell_dataset']) | ||
input_simulated_dataset = ad.read_h5ad(par['input_simulated_dataset']) | ||
|
||
def get_spatial_network(num_sample=None, spatial=None, radius=None, coord_type="grid", n_rings=2, set_diag=False): | ||
spatial_adata = ad.AnnData(np.empty((num_sample, 1), dtype="float32")) | ||
spatial_adata.obsm["spatial"] = spatial | ||
# sq.gr.spatial_neighbors(spatial_adata, n_rings=n_rings, coord_type=coord_type, n_neighs=n_neighs, radius=radius,set_diag =set_diag) | ||
sq.gr.spatial_neighbors(spatial_adata, n_rings=n_rings, coord_type=coord_type, radius=radius, set_diag=set_diag, | ||
delaunay=True) | ||
sn = spatial_adata.obsp["spatial_connectivities"] | ||
|
||
return sn | ||
|
||
|
||
def get_onehot_ct(init_assign=None): | ||
label_encoder = LabelEncoder() | ||
integer_encoded = label_encoder.fit_transform(init_assign) | ||
onehot_encoder = OneHotEncoder(sparse_output=False) | ||
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) | ||
onehot_ct = onehot_encoder.fit_transform(integer_encoded) | ||
return onehot_ct.astype(np.float32) | ||
|
||
|
||
# @numba.jit("float32[:, ::1](float32[:, ::1], float32[:, ::1])") | ||
def get_nb_freq(nb_count=None, onehot_ct=None): | ||
# nb_freq = onehot_ct.T @ nb_count | ||
nb_freq = np.dot(onehot_ct.T, nb_count) | ||
res = nb_freq / nb_freq.sum(axis=1).reshape(onehot_ct.shape[1], -1) | ||
return res | ||
|
||
def get_trans(adata=None, ct=None): | ||
sn = get_spatial_network(num_sample=adata.obs.shape[0], | ||
spatial=adata.obsm["spatial"], coord_type="generic") | ||
onehot_ct = get_onehot_ct(init_assign=ct) | ||
nb_count = np.array(sn * onehot_ct, dtype=np.float32) | ||
target_trans = get_nb_freq(nb_count=nb_count, onehot_ct=onehot_ct) | ||
return target_trans | ||
|
||
|
||
input_spatial_dataset.obsm["spatial"] = np.array(input_spatial_dataset.obs[['col', 'row']].values.tolist()) | ||
input_spatial_dataset.obs["celltype"] = input_spatial_dataset.obs["spatial_cluster"] | ||
input_spatial_dataset.obs["celltype"] = input_spatial_dataset.obs["celltype"].astype('category') | ||
|
||
sq.gr.spatial_neighbors(input_spatial_dataset, coord_type="generic", set_diag=False, delaunay=True) | ||
# neighborhood enrichment matrix | ||
sq.gr.nhood_enrichment(input_spatial_dataset, cluster_key="celltype") | ||
# centrality scores matrix | ||
sq.gr.centrality_scores(input_spatial_dataset, cluster_key="celltype") | ||
|
||
input_simulated_dataset.obsm["spatial"] = np.array(input_simulated_dataset.obs[['col', 'row']].values.tolist()) | ||
input_simulated_dataset.obs["celltype"] = input_simulated_dataset.obs["spatial_cluster"] | ||
input_simulated_dataset.obs["celltype"] = input_simulated_dataset.obs["celltype"].astype('category') | ||
|
||
sq.gr.spatial_neighbors(input_simulated_dataset, coord_type="generic", set_diag=False, delaunay=True) | ||
# neighborhood enrichment matrix | ||
sq.gr.nhood_enrichment(input_simulated_dataset, cluster_key="celltype") | ||
# centrality scores matrix | ||
sq.gr.centrality_scores(input_simulated_dataset, cluster_key="celltype") | ||
|
||
target_enrich_real = input_spatial_dataset.uns["celltype_nhood_enrichment"]["zscore"] | ||
target_enrich_scale_real = target_enrich_real/np.max(target_enrich_real) | ||
target_enrich_sim = input_simulated_dataset.uns["celltype_nhood_enrichment"]["zscore"] | ||
target_enrich_scale_sim = target_enrich_sim/np.max(target_enrich_sim) | ||
|
||
error_enrich = np.linalg.norm(target_enrich_sim - target_enrich_real) | ||
error_enrich_scale = np.linalg.norm(target_enrich_scale_sim - target_enrich_scale_real) | ||
|
||
target_enrich_real_ds = target_enrich_real.flatten() | ||
target_enrich_sim_ds = target_enrich_sim.flatten() | ||
ks_enrich, p_value = ks_2samp(target_enrich_real_ds, target_enrich_sim_ds) | ||
|
||
# KS central | ||
|
||
real_central_real = np.array(input_spatial_dataset.uns["celltype_centrality_scores"]) | ||
real_central_sim = np.array(input_simulated_dataset.uns["celltype_centrality_scores"]) | ||
|
||
real_central_real_ds = real_central_real.flatten() | ||
real_central_sim_ds = real_central_sim.flatten() | ||
ks_central, p_value = ks_2samp(real_central_real_ds, real_central_sim_ds) | ||
|
||
# transition matrix | ||
real = np.array(input_spatial_dataset.obs['spatial_cluster'].values.tolist()) | ||
sim = np.array(input_simulated_dataset.obs['spatial_cluster'].values.tolist()) | ||
|
||
transition_matrix_real = get_trans(adata=input_spatial_dataset, ct=real) | ||
transition_matrix_sim = get_trans(adata=input_simulated_dataset, ct=sim) | ||
|
||
error = np.linalg.norm(transition_matrix_sim - transition_matrix_real) | ||
transition_matrix_real_ds = transition_matrix_real.flatten() | ||
transition_matrix_sim_ds = transition_matrix_sim.flatten() | ||
ks_stat_error, p_value = ks_2samp(transition_matrix_real_ds, transition_matrix_sim_ds) | ||
|
||
uns_metric_ids = [ | ||
"ks_statistic_transition_matrix", | ||
"ks_statistic_central_score", | ||
"ks_statistic_enrichment" | ||
] | ||
|
||
uns_metric_values = [ | ||
ks_stat_error, | ||
ks_central, | ||
ks_enrich | ||
] | ||
|
||
print("Write output AnnData to file", flush=True) | ||
output = ad.AnnData( | ||
uns={ | ||
'dataset_id': input_simulated_dataset.uns['dataset_id'], | ||
'method_id': input_simulated_dataset.uns['method_id'], | ||
'metric_ids': uns_metric_ids, | ||
'metric_values': uns_metric_values | ||
} | ||
) | ||
output.write_h5ad(par['output'], compression='gzip') |