Skip to content

Commit

Permalink
Merge pull request #12 from nuniz/8-cli-script-for-running-the-network
Browse files Browse the repository at this point in the history
8 cli script for running the network
  • Loading branch information
nuniz authored Jul 15, 2024
2 parents 80e6d44 + a18ed3c commit d8c5f1a
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 63 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@
/dist/*
/.idea/*
/local_history.patch
/notebooks/.ipynb_checkpoints/*
/notebooks/.ipynb_checkpoints/*

*.txt
PKG-INFO
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# CD-Network

CD-Network is a Python library designed for the analytical derivation of the stochastic output of coincidence detector (
CD) cells.
CD-Network is a Python library designed for the analytical derivation of the stochastic output of coincidence detection
(CD) cells.
These cells receive inputs modeled as non-homogeneous Poisson processes (NHPP) with both excitatory and inhibitory
components.

Expand Down Expand Up @@ -57,7 +57,7 @@ The CD network simulation uses a JSON configuration file. Below is a breakdown o
target: Identifier for the cell receiving the input.
input_type: Specifies whether the input is excitatory or inhibitory.

[Example Configuration File](example_notebooks/config.yaml)
[Example Configuration File](example_notebooks/config.json)

### CD Cells

Expand Down
163 changes: 108 additions & 55 deletions cd_network/network.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import json
from collections import defaultdict, deque
from typing import Any, Dict, Union

import numpy as np

from .cells import cd, ee, ei, simple_ee


class Neuron:
def __init__(self, cell_type, cell_id, params, fs):
def __init__(self, cell_type: str, cell_id: str, params: Dict[str, Any], fs: float):
self.cell_type = cell_type
self.cell_id = cell_id
self.params = params
self.fs = fs

def compute_output(self, inputs):
excitatory_inputs = inputs.get("excitatory", None)
inhibitory_inputs = inputs.get("inhibitory", None)
def __call__(self, inputs: Dict[str, np.ndarray], *args, **kwargs) -> np.ndarray:
excitatory_inputs = inputs.get("excitatory")
inhibitory_inputs = inputs.get("inhibitory")

if self.cell_type == "ei":
return ei(
Expand Down Expand Up @@ -42,14 +44,30 @@ def compute_output(self, inputs):


class CDNetwork:
def __init__(self, config_path):
def __init__(self, config: Union[Dict[str, Any], str]):
"""Initialize the network with a configuration dictionary or a path to a configuration JSON file."""
self.cells = {}
self.connections = []
self.load_config(config_path)
self.load_config(config)

def load_config(self, config: Union[Dict[str, Any], str]):
"""Load and parse the configuration from a dictionary or a JSON file."""
if isinstance(config, str):
try:
with open(config, "r") as file:
config = json.load(file)
except FileNotFoundError:
raise FileNotFoundError(
f"The configuration file {config} was not found."
)
except json.JSONDecodeError:
raise ValueError("Invalid JSON format in the configuration file.")

if not isinstance(config, dict):
raise ValueError(
"Configuration must be a dictionary or a path to a JSON file."
)

def load_config(self, config_path):
with open(config_path, "r") as f:
config = json.load(f)
fs = config["fs"]
for cell_config in config["cells"]:
cell = Neuron(
Expand All @@ -61,15 +79,45 @@ def load_config(self, config_path):
self.cells[cell_config["id"]] = cell
self.connections = config["connections"]

def __call__(self, external_inputs, *args, **kwargs):
def plot_network_connections(self):
"""
Plot the network connections using NetworkX and Matplotlib.
"""
import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph()
for cell_id in self.cells:
cell_type = self.cells[cell_id].cell_type
color_map = {'excitatory': 'green', 'inhibitory': 'red'}
node_color = color_map.get(cell_type, 'skyblue')
G.add_node(cell_id, label=f"({cell_type})", color=node_color)
for conn in self.connections:
G.add_edge(conn["source"], conn["target"], label=conn["input_type"])

pos = nx.spring_layout(G) # Consider changing the layout for large networks
colors = ([G.nodes[node].get('color', 'white') for node in G.nodes()])
nx.draw_networkx_nodes(G, pos, node_size=5000, node_color=colors, alpha=1) # edgecolors='black'

node_labels = {node: f"{node} \n {G.nodes[node].get('label', '')}" for node in G.nodes()}
nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=12)

nx.draw_networkx_edges(G, pos, arrowstyle="-|>", arrowsize=20, edge_color="gray")
edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True)}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red")

plt.title("CDNetwork Connections")
plt.axis("off")
plt.show()

def __call__(self, external_inputs: Dict[str, np.ndarray], *args, **kwargs):
cell_outputs = {}
# Initialize storage for each cell's inputs
cell_inputs = {
cell_id: {"excitatory": [], "inhibitory": []}
for cell_id in self.cells.keys()
}

# Populate initial external inputs
# Process external inputs
ext_data_shape = None
for ext_key, ext_data in external_inputs.items():
if ext_data_shape is None:
Expand All @@ -84,49 +132,54 @@ def __call__(self, external_inputs, *args, **kwargs):
if conn["source"] == ext_key:
cell_inputs[conn["target"]][conn["input_type"]].append(ext_data)

# Process each cell once all inputs are ready
cells_to_process = list(self.cells.keys())
while cells_to_process:
processed_cells = []
for cell_id in cells_to_process:
# Check if all inputs are available
inputs_ready = True
for conn in self.connections:
if conn["target"] == cell_id and not conn["source"].startswith(
"external"
):
if cell_outputs.get(conn["source"]) is None:
inputs_ready = False
break
if inputs_ready:
# Gather inputs from sources
for conn in self.connections:
if conn["target"] == cell_id and not conn["source"].startswith(
"external"
):
cell_inputs[cell_id][conn["input_type"]].append(
cell_outputs[conn["source"]]
)
# Compute outputs
excitatory_input = (
np.vstack(cell_inputs[cell_id]["excitatory"])
if cell_inputs[cell_id]["excitatory"]
else None
)
inhibitory_input = (
np.vstack(cell_inputs[cell_id]["inhibitory"])
if cell_inputs[cell_id]["inhibitory"]
else None
)
cell = self.cells[cell_id]
output = cell.compute_output(
{"excitatory": excitatory_input, "inhibitory": inhibitory_input}
# Build adjacency list and compute in-degree
graph = defaultdict(list)
in_degree = {cell_id: 0 for cell_id in self.cells}
for conn in self.connections:
if conn["source"] in self.cells:
graph[conn["source"]].append((conn["target"], conn["input_type"]))
in_degree[conn["target"]] += 1

# Initialize queue with cells that have no incoming edges
process_queue = deque(
[cell_id for cell_id, degree in in_degree.items() if degree == 0]
)

while process_queue:
cell_id = process_queue.popleft()
# Process each cell's inputs
for conn in self.connections:
if conn["target"] == cell_id and conn["source"] in cell_outputs:
cell_inputs[cell_id][conn["input_type"]].append(
cell_outputs[conn["source"]]
)
cell_outputs[cell_id] = output
processed_cells.append(cell_id)
# Update the list of cells to process by removing those already processed
cells_to_process = [
cell for cell in cells_to_process if cell not in processed_cells
]

# Compute outputs for current cell
excitatory_input = (
np.vstack(cell_inputs[cell_id]["excitatory"])
if cell_inputs[cell_id]["excitatory"]
else None
)
inhibitory_input = (
np.vstack(cell_inputs[cell_id]["inhibitory"])
if cell_inputs[cell_id]["inhibitory"]
else None
)
output = self.cells[cell_id](
{"excitatory": excitatory_input, "inhibitory": inhibitory_input}
)
cell_outputs[cell_id] = output

# Decrement in-degrees of successors and enqueue if in-degree becomes zero
for target, input_type in graph[cell_id]:
in_degree[target] -= 1
if in_degree[target] == 0:
process_queue.append(target)

# Check for unresolved dependencies (indicates a cycle or misconfiguration)
if any(degree > 0 for degree in in_degree.values()):
raise RuntimeError(
"A deadlock was detected in the network due to unresolved dependencies or cycles."
)

return cell_outputs
49 changes: 49 additions & 0 deletions cd_network/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import argparse
import pickle

from cd_network.network import CDNetwork


def load_input_file(input_file):
# Load input file (pickle) as dictionary
with open(input_file, "rb") as f:
inputs = pickle.load(f)
if not isinstance(inputs, dict):
raise TypeError(f"The input file {input_file} should contain a dictionary.")
return inputs


def parse_arguments():
parser = argparse.ArgumentParser(
description="Run CDNetwork with external inputs from pickle file"
)
parser.add_argument(
"config", type=str, help="Path to the network configuration JSON file"
)
parser.add_argument(
"input_file",
type=str,
help="Path to the pickle file containing external inputs (dictionary)",
)
parser.add_argument(
"output_path", type=str, help="Path to save the outputs as a pickle file"
)
return parser.parse_args()


def main():
args = parse_arguments()

network = CDNetwork(args.config)

external_inputs = load_input_file(args.input_file)
outputs = network(external_inputs)

with open(args.output_path, "wb") as f:
pickle.dump(outputs, f)

print(f"Outputs saved to {args.output_path}")


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
numpy
scipy
matplotlib
networkx
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
"Topic :: Scientific/Engineering :: Bio-Informatics",
],
python_requires=">=3.6",
install_requires=["numpy<2", "scipy"],
entry_points={"console_scripts": ["cd_network = cd_network.run:main"]},
install_requires=["numpy<2", "scipy", "matplotlib", "networkx"],
extras_require={
"dev": ["pytest", "check-manifest", "pre-commit", "matplotlib"],
"dev": ["pytest", "check-manifest", "pre-commit"],
"test": ["pytest", "coverage"],
},
)
10 changes: 8 additions & 2 deletions tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
class TestNeuralNetwork(unittest.TestCase):
def setUp(self):
"""Set up the Neural Network with a configuration path."""
self.config_path = "tests/config.json"
self.network = CDNetwork(self.config_path)
try:
self.config_path = "tests/config.json"
self.network = CDNetwork(self.config_path)
except Exception as e:
print(f"Load config locally, {e}")
self.config_path = "config.json"
self.network = CDNetwork(self.config_path)

self.external_inputs = {
"external1": np.random.randn(1000), # Example external excitatory inputs
"external2": np.random.randn(1000), # Example external inhibitory inputs
Expand Down

0 comments on commit d8c5f1a

Please sign in to comment.