Skip to content

Commit

Permalink
Merge pull request #61 from savitakartik/nodes_tab
Browse files Browse the repository at this point in the history
First pass at Nodes tab
  • Loading branch information
jeromekelleher authored Aug 21, 2023
2 parents 953c8e6 + e289257 commit 0eb5a3d
Show file tree
Hide file tree
Showing 14 changed files with 275 additions and 316 deletions.
13 changes: 3 additions & 10 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
import sys
import logging
import pathlib
import sys

import numpy as np
import panel as pn
import hvplot.pandas
import holoviews as hv
import pandas as pd

import tskit
import utils

import pathlib
import functools
import model
import pages

Expand All @@ -22,7 +15,6 @@
path = pathlib.Path(sys.argv[1])
tsm = model.TSModel(tskit.load(path), path.name)


pn.extension(sizing_mode="stretch_width")
pn.extension("tabulator")

Expand All @@ -32,6 +24,7 @@
"Edges": pages.edges,
"Edge Explorer": pages.edge_explorer,
"Trees": pages.trees,
"Nodes": pages.nodes,
}


Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Global plot settings
# PLOT_WIDTH / PLOT_HEIGHT are the default plot dimensions that page modules
# hand to the plotting layer (e.g. ``.opts(width=..., height=...)``).
# Presumably measured in screen pixels -- TODO confirm against the backend.
PLOT_WIDTH = 1000
PLOT_HEIGHT = 600
# THRESHOLD is passed by page modules as the ``threshold`` argument of the
# hover-point helper.
# NOTE(review): the assignment below appears twice with the same value; the
# second line simply rebinds the name to 1000 again.
THRESHOLD = 1000 # max number of points to overlay on a plot
THRESHOLD = 1000 # max number of points to overlay on a plot
47 changes: 25 additions & 22 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from functools import cached_property

import tskit
import numpy as np
import numba
import numpy as np
import pandas as pd
import numba
import tskit

spec = [
("num_edges", numba.int64),
Expand Down Expand Up @@ -44,7 +43,7 @@ def __init__(
self.in_range = np.zeros(2, dtype=np.int64)
self.out_range = np.zeros(2, dtype=np.int64)

def next(self):
def next(self): # noqa
left = self.interval[1]
j = self.in_range[1]
k = self.out_range[1]
Expand Down Expand Up @@ -221,7 +220,7 @@ def mutations_df(self):
unknown = tskit.is_unknown_time(mutations_time)
mutations_time[unknown] = self.ts.nodes_time[mutations_node[unknown]]

node_flag = ts.nodes_flags[mutations_node]
# node_flag = ts.nodes_flags[mutations_node]
position = ts.sites_position[ts.mutations_site]

tables = self.ts.tables
Expand Down Expand Up @@ -341,18 +340,22 @@ def nodes_df(self):
child_left, child_right = self.child_bounds(
ts.num_nodes, ts.edges_left, ts.edges_right, ts.edges_child
)
is_sample = np.zeros(ts.num_nodes)
is_sample[ts.samples()] = 1
df = pd.DataFrame(
{
"time": ts.nodes_time,
"num_mutations": self.nodes_num_mutations,
"ancestors_span": child_right - child_left,
"is_sample": is_sample,
}
)
return df.astype(
{
"time": "float64",
"num_mutations": "int",
"ancestors_span": "float64",
"is_sample": "bool",
}
)

Expand Down Expand Up @@ -437,6 +440,23 @@ def make_sliding_windows(self, iterable, size, overlap=0):
end += step
yield iterable[start:]

def calc_mean_node_arity(self):
    """
    Return one value per node: the summed span of the edges on which the
    node is the parent, divided by a per-node span statistic taken from
    ``sample_count_stat``.

    The denominator is the first column of a node-mode, polarised,
    non-span-normalised statistic whose summary function is ``count > 0``
    over all samples. Presumably this is the span over which the node has
    at least one sample beneath it -- confirm against the tskit docs.

    The division is not guarded, so nodes whose denominator is zero follow
    NumPy's rules for dividing by zero.
    """
    ts = self.ts

    def has_samples_below(sample_counts):
        # Summary function given to the statistic: true where count > 0.
        return sample_counts > 0

    # Length of every edge, accumulated onto that edge's parent node.
    edge_spans = ts.edges_right - ts.edges_left
    child_span_totals = np.bincount(
        ts.edges_parent, weights=edge_spans, minlength=ts.num_nodes
    )

    per_node_stat = ts.sample_count_stat(
        [ts.samples()],
        has_samples_below,
        1,
        polarised=True,
        span_normalise=False,
        strict=False,
        mode="node",
    )
    # Only one output dimension was requested, so take its single column.
    node_span_totals = per_node_stat[:, 0]
    return child_span_totals / node_span_totals

def calc_site_tree_index(self):
return (
np.searchsorted(
Expand All @@ -459,20 +479,3 @@ def calc_mutations_per_tree(self):
mutations_per_tree = np.zeros(self.ts.num_trees, dtype=np.int64)
mutations_per_tree[unique_values] = counts
return mutations_per_tree

def calc_mean_node_arity(self):
    """
    Return one value per node: the summed span of the edges on which the
    node is the parent, divided by a per-node span statistic taken from
    ``sample_count_stat``.

    The denominator is the first column of a node-mode, polarised,
    non-span-normalised statistic whose summary function is ``count > 0``
    over all samples. Presumably this is the span over which the node has
    at least one sample beneath it -- confirm against the tskit docs.

    The division is not guarded, so nodes whose denominator is zero follow
    NumPy's rules for dividing by zero.
    """
    ts = self.ts

    def has_samples_below(sample_counts):
        # Summary function given to the statistic: true where count > 0.
        return sample_counts > 0

    # Length of every edge, accumulated onto that edge's parent node.
    edge_spans = ts.edges_right - ts.edges_left
    child_span_totals = np.bincount(
        ts.edges_parent, weights=edge_spans, minlength=ts.num_nodes
    )

    per_node_stat = ts.sample_count_stat(
        [ts.samples()],
        has_samples_below,
        1,
        polarised=True,
        span_normalise=False,
        strict=False,
        mode="node",
    )
    # Only one output dimension was requested, so take its single column.
    node_span_totals = per_node_stat[:, 0]
    return child_span_totals / node_span_totals
6 changes: 3 additions & 3 deletions pages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import importlib
import os

# Discover every sibling module of this package and expose its ``page``
# callable under the module's own name, so callers can write e.g.
# ``pages.nodes``.
# The listing is sorted because os.listdir returns entries in arbitrary order;
# sorting keeps the import order the same from run to run.
for module_file in sorted(os.listdir(os.path.dirname(__file__))):
    # Check if it's a python file and not this __init__ file
    if module_file.endswith(".py") and module_file != "__init__.py":
        module_name = module_file[:-3]  # remove the .py extension
        module = importlib.import_module("." + module_name, package=__name__)

        # Add the page function to the current module's namespace.
        # A module without a ``page`` attribute raises AttributeError here.
        globals()[module_name] = module.page
6 changes: 4 additions & 2 deletions pages/edge_explorer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import panel as pn
import holoviews as hv
import bokeh.models as bkm
import holoviews as hv
import panel as pn

import config


def page(tsm):
hv.extension("bokeh")
edges_df = tsm.edges_df
Expand Down
7 changes: 5 additions & 2 deletions pages/edges.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import bokeh.models as bkm
import holoviews as hv
import holoviews.operation.datashader as hd
import panel as pn

import config
import bokeh.models as bkm
from plot_helpers import filter_points, hover_points
from plot_helpers import filter_points
from plot_helpers import hover_points


def page(tsm):
hv.extension("bokeh")
Expand Down
25 changes: 18 additions & 7 deletions pages/mutations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import panel as pn
import holoviews as hv
import config
import holoviews.operation.datashader as hd
import hvplot.pandas # noqa
import numpy as np
from plot_helpers import filter_points, hover_points
import panel as pn

import config
from plot_helpers import filter_points
from plot_helpers import hover_points


def make_hist_on_axis(dimension, points, num_bins=30):
### Make histogram function for a specified axis of a scatter plot
"""
Make histogram function for a specified axis of a scatter plot
"""

def compute_hist(x_range, y_range):
filtered_points = filter_points(points, x_range, y_range)
hist = hv.operation.histogram(
Expand All @@ -17,9 +24,10 @@ def compute_hist(x_range, y_range):
return compute_hist



def make_hist(data, title, bins_range, log_y=True, plot_width=800):
### Make histogram from given count data
"""
Make histogram from given count data
"""
count, bins = np.histogram(data, bins=bins_range)
ylabel = "log(Count)" if log_y else "Count"
np.seterr(divide="ignore")
Expand All @@ -34,7 +42,9 @@ def make_hist(data, title, bins_range, log_y=True, plot_width=800):


def make_hist_panel(tsm, log_y):
### Make row of histograms for holoviews panel
"""
Make row of histograms for holoviews panel
"""
overall_site_hist = make_hist(
tsm.sites_num_mutations,
"Mutations per site",
Expand All @@ -51,6 +61,7 @@ def make_hist_panel(tsm, log_y):
)
return pn.Row(overall_site_hist, overall_node_hist)


def page(tsm):
hv.extension("bokeh")
plot_width = 1000
Expand Down
63 changes: 63 additions & 0 deletions pages/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import functools

import holoviews as hv
import holoviews.operation.datashader as hd
import hvplot.pandas # noqa
import numpy as np
import panel as pn

import config
from plot_helpers import filter_points
from plot_helpers import hover_points
from plot_helpers import make_hist_matplotlib


def page(tsm):
    """
    Build the "Nodes" tab.

    The tab is a column of three parts: a datashaded scatter of
    ``ancestors_span`` against ``time`` for all nodes with hover points
    overlaid, a histogram of ``ancestors_span`` for non-sample nodes, and
    a checkbox that toggles the histogram's log y-axis.

    :param tsm: Model object; only its ``nodes_df`` data frame is read here
        (columns ``is_sample``, ``ancestors_span`` and ``time`` are used).
    :return: A ``pn.Column`` holding the scatter, the histogram and the
        plot options.
    """
    # The histogram helper is matplotlib-based, so that backend is loaded
    # before the histogram is set up; bokeh is loaded again further down.
    hv.extension("matplotlib")
    df_nodes = tsm.nodes_df
    # Keep non-sample nodes, and drop rows whose span is -inf.
    df_internal_nodes = df_nodes[
        (df_nodes.is_sample == 0) & (df_nodes.ancestors_span != -np.inf)
    ]
    # Square-root rule capped at 50 bins. Clamp to at least one bin so that
    # an empty selection does not ask the histogram helper for zero bins.
    bins = max(1, min(50, int(np.sqrt(len(df_internal_nodes)))))

    ancestor_spans_hist_func = functools.partial(
        make_hist_matplotlib,
        df_internal_nodes.ancestors_span,
        "Ancestor spans per node",
        num_bins=bins,
        # Default only: the checkbox binding below passes log_y on each call.
        log_y=True,
    )

    log_y_checkbox = pn.widgets.Checkbox(name="log y-axis of histogram", value=True)

    # Ties the histogram's log_y argument to the checkbox value.
    ancestor_spans_hist_panel = pn.bind(
        ancestor_spans_hist_func,
        log_y=log_y_checkbox,
    )

    hist_panel = pn.Column(
        ancestor_spans_hist_panel,
    )

    # The interactive scatter below uses the bokeh backend.
    hv.extension("bokeh")
    points = df_nodes.hvplot.scatter(
        x="ancestors_span",
        y="time",
        hover_cols=["ancestors_span", "time"],
    ).opts(width=config.PLOT_WIDTH, height=config.PLOT_HEIGHT)

    # The same x/y range stream drives the point filter and the datashader.
    range_stream = hv.streams.RangeXY(source=points)
    streams = [range_stream]
    filtered = points.apply(filter_points, streams=streams)
    hover = filtered.apply(hover_points, threshold=config.THRESHOLD)
    shaded = hd.datashade(filtered, width=400, height=400, streams=streams)

    main = (shaded * hover).opts(
        hv.opts.Points(tools=["hover"], alpha=0.1, hover_alpha=0.2, size=10)
    )

    plot_options = pn.Column(
        pn.pane.Markdown("# Plot Options"),
        log_y_checkbox,
    )
    return pn.Column(main, hist_panel, plot_options)
3 changes: 2 additions & 1 deletion pages/overview.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import panel as pn


def page(tsm):
    """
    Build the overview tab.

    :param tsm: Model object; only its ``ts`` attribute is read here.
    :return: A ``pn.pane.HTML`` pane constructed from ``tsm.ts``.
        Presumably the pane shows the object's HTML representation --
        confirm against the Panel documentation.
    """
    return pn.pane.HTML(tsm.ts)
Loading

0 comments on commit 0eb5a3d

Please sign in to comment.