diff --git a/.gitignore b/.gitignore index c3b177d..249fa4b 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,7 @@ /dist/* /.idea/* /local_history.patch -/notebooks/.ipynb_checkpoints/* \ No newline at end of file +/notebooks/.ipynb_checkpoints/* + +*.txt +PKG-INFO \ No newline at end of file diff --git a/README.md b/README.md index cc0eec7..4ecc4b3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # CD-Network -CD-Network is a Python library designed for the analytical derivation of the stochastic output of coincidence detector ( -CD) cells. +CD-Network is a Python library designed for the analytical derivation of the stochastic output of coincidence detection +(CD) cells. These cells receive inputs modeled as non-homogeneous Poisson processes (NHPP) with both excitatory and inhibitory components. @@ -57,7 +57,7 @@ The CD network simulation uses a JSON configuration file. Below is a breakdown o target: Identifier for the cell receiving the input. input_type: Specifies whether the input is excitatory or inhibitory. -[Example Configuration File](example_notebooks/config.yaml) +[Example Configuration File](example_notebooks/config.json) ### CD Cells diff --git a/cd_network/network.py b/cd_network/network.py index 853fdb2..4f44e83 100644 --- a/cd_network/network.py +++ b/cd_network/network.py @@ -1,4 +1,6 @@ import json +from collections import defaultdict, deque +from typing import Any, Dict, Union import numpy as np @@ -6,15 +8,15 @@ class Neuron: - def __init__(self, cell_type, cell_id, params, fs): + def __init__(self, cell_type: str, cell_id: str, params: Dict[str, Any], fs: float): self.cell_type = cell_type self.cell_id = cell_id self.params = params self.fs = fs - def compute_output(self, inputs): - excitatory_inputs = inputs.get("excitatory", None) - inhibitory_inputs = inputs.get("inhibitory", None) + def __call__(self, inputs: Dict[str, np.ndarray], *args, **kwargs) -> np.ndarray: + excitatory_inputs = inputs.get("excitatory") + inhibitory_inputs = inputs.get("inhibitory") if self.cell_type == "ei": return ei( @@ -42,14 +44,30 @@ def compute_output(self, inputs): class CDNetwork: - def __init__(self, config_path): + def __init__(self, config: Union[Dict[str, Any], str]): + """Initialize the network with a configuration dictionary or a path to a configuration JSON file.""" self.cells = {} self.connections = [] - self.load_config(config_path) + self.load_config(config) + + def load_config(self, config: Union[Dict[str, Any], str]): + """Load and parse the configuration from a dictionary or a JSON file.""" + if isinstance(config, str): + try: + with open(config, "r") as file: + config = json.load(file) + except FileNotFoundError: + raise FileNotFoundError( + f"The configuration file {config} was not found." + ) + except json.JSONDecodeError: + raise ValueError("Invalid JSON format in the configuration file.") + + if not isinstance(config, dict): + raise ValueError( + "Configuration must be a dictionary or a path to a JSON file." + ) - def load_config(self, config_path): - with open(config_path, "r") as f: - config = json.load(f) fs = config["fs"] for cell_config in config["cells"]: cell = Neuron( @@ -61,15 +79,45 @@ def load_config(self, config_path): self.cells[cell_config["id"]] = cell self.connections = config["connections"] - def __call__(self, external_inputs, *args, **kwargs): + def plot_network_connections(self): + """ + Plot the network connections using NetworkX and Matplotlib. + """ + import matplotlib.pyplot as plt + import networkx as nx + + G = nx.DiGraph() + for cell_id in self.cells: + cell_type = self.cells[cell_id].cell_type + color_map = {'excitatory': 'green', 'inhibitory': 'red'} + node_color = color_map.get(cell_type, 'skyblue') + G.add_node(cell_id, label=f"({cell_type})", color=node_color) + for conn in self.connections: + G.add_edge(conn["source"], conn["target"], label=conn["input_type"]) + + pos = nx.spring_layout(G) # Consider changing the layout for large networks + colors = ([G.nodes[node].get('color', 'white') for node in G.nodes()]) + nx.draw_networkx_nodes(G, pos, node_size=5000, node_color=colors, alpha=1) # edgecolors='black' + + node_labels = {node: f"{node} \n {G.nodes[node].get('label', '')}" for node in G.nodes()} + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=12) + + nx.draw_networkx_edges(G, pos, arrowstyle="-|>", arrowsize=20, edge_color="gray") + edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True)} + nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red") + + plt.title("CDNetwork Connections") + plt.axis("off") + plt.show() + + def __call__(self, external_inputs: Dict[str, np.ndarray], *args, **kwargs): cell_outputs = {} - # Initialize storage for each cell's inputs cell_inputs = { cell_id: {"excitatory": [], "inhibitory": []} for cell_id in self.cells.keys() } - # Populate initial external inputs + # Process external inputs ext_data_shape = None for ext_key, ext_data in external_inputs.items(): if ext_data_shape is None: @@ -84,49 +132,54 @@ def __call__(self, external_inputs, *args, **kwargs): if conn["source"] == ext_key: cell_inputs[conn["target"]][conn["input_type"]].append(ext_data) - # Process each cell once all inputs are ready - cells_to_process = list(self.cells.keys()) - while cells_to_process: - processed_cells = [] - for cell_id in cells_to_process: - # Check if all inputs are available - inputs_ready = True - for conn in self.connections: - if conn["target"] == cell_id and not conn["source"].startswith( - "external" - ): - if cell_outputs.get(conn["source"]) is None: - inputs_ready = False - break - if inputs_ready: - # Gather inputs from sources - for conn in self.connections: - if conn["target"] == cell_id and not conn["source"].startswith( - "external" - ): - cell_inputs[cell_id][conn["input_type"]].append( - cell_outputs[conn["source"]] - ) - # Compute outputs - excitatory_input = ( - np.vstack(cell_inputs[cell_id]["excitatory"]) - if cell_inputs[cell_id]["excitatory"] - else None - ) - inhibitory_input = ( - np.vstack(cell_inputs[cell_id]["inhibitory"]) - if cell_inputs[cell_id]["inhibitory"] - else None - ) - cell = self.cells[cell_id] - output = cell.compute_output( - {"excitatory": excitatory_input, "inhibitory": inhibitory_input} + # Build adjacency list and compute in-degree + graph = defaultdict(list) + in_degree = {cell_id: 0 for cell_id in self.cells} + for conn in self.connections: + if conn["source"] in self.cells: + graph[conn["source"]].append((conn["target"], conn["input_type"])) + in_degree[conn["target"]] += 1 + + # Initialize queue with cells that have no incoming edges + process_queue = deque( + [cell_id for cell_id, degree in in_degree.items() if degree == 0] + ) + + while process_queue: + cell_id = process_queue.popleft() + # Process each cell's inputs + for conn in self.connections: + if conn["target"] == cell_id and conn["source"] in cell_outputs: + cell_inputs[cell_id][conn["input_type"]].append( + cell_outputs[conn["source"]] ) - cell_outputs[cell_id] = output - processed_cells.append(cell_id) - # Update the list of cells to process by removing those already processed - cells_to_process = [ - cell for cell in cells_to_process if cell not in processed_cells - ] + + # Compute outputs for current cell + excitatory_input = ( + np.vstack(cell_inputs[cell_id]["excitatory"]) + if cell_inputs[cell_id]["excitatory"] + else None + ) + inhibitory_input = ( + np.vstack(cell_inputs[cell_id]["inhibitory"]) + if cell_inputs[cell_id]["inhibitory"] + else None + ) + output = self.cells[cell_id]( + {"excitatory": excitatory_input, "inhibitory": inhibitory_input} + ) + cell_outputs[cell_id] = output + + # Decrement in-degrees of successors and enqueue if in-degree becomes zero + for target, input_type in graph[cell_id]: + in_degree[target] -= 1 + if in_degree[target] == 0: + process_queue.append(target) + + # Check for unresolved dependencies (indicates a cycle or misconfiguration) + if any(degree > 0 for degree in in_degree.values()): + raise RuntimeError( + "A deadlock was detected in the network due to unresolved dependencies or cycles." + ) return cell_outputs diff --git a/cd_network/run.py b/cd_network/run.py new file mode 100644 index 0000000..1d0d4da --- /dev/null +++ b/cd_network/run.py @@ -0,0 +1,49 @@ +import argparse +import pickle + +from cd_network.network import CDNetwork + + +def load_input_file(input_file): + # Load input file (pickle) as dictionary + with open(input_file, "rb") as f: + inputs = pickle.load(f) + if not isinstance(inputs, dict): + raise TypeError(f"The input file {input_file} should contain a dictionary.") + return inputs + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Run CDNetwork with external inputs from pickle file" + ) + parser.add_argument( + "config", type=str, help="Path to the network configuration JSON file" + ) + parser.add_argument( + "input_file", + type=str, + help="Path to the pickle file containing external inputs (dictionary)", + ) + parser.add_argument( + "output_path", type=str, help="Path to save the outputs as a pickle file" + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + network = CDNetwork(args.config) + + external_inputs = load_input_file(args.input_file) + outputs = network(external_inputs) + + with open(args.output_path, "wb") as f: + pickle.dump(outputs, f) + + print(f"Outputs saved to {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 6bad103..43b27e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ numpy scipy +matplotlib +networkx diff --git a/setup.py b/setup.py index 537a9a7..8dbd75a 100644 --- a/setup.py +++ b/setup.py @@ -30,9 +30,10 @@ "Topic :: Scientific/Engineering :: Bio-Informatics", ], python_requires=">=3.6", - install_requires=["numpy<2", "scipy"], + entry_points={"console_scripts": ["cd_network = cd_network.run:main"]}, + install_requires=["numpy<2", "scipy", "matplotlib", "networkx"], extras_require={ - "dev": ["pytest", "check-manifest", "pre-commit", "matplotlib"], + "dev": ["pytest", "check-manifest", "pre-commit"], "test": ["pytest", "coverage"], }, ) diff --git a/tests/test_network.py b/tests/test_network.py index 63bdf57..1d6f3d7 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -8,8 +8,14 @@ class TestNeuralNetwork(unittest.TestCase): def setUp(self): """Set up the Neural Network with a configuration path.""" - self.config_path = "tests/config.json" - self.network = CDNetwork(self.config_path) + try: + self.config_path = "tests/config.json" + self.network = CDNetwork(self.config_path) + except Exception as e: + print(f"Load config locally, {e}") + self.config_path = "config.json" + self.network = CDNetwork(self.config_path) + self.external_inputs = { "external1": np.random.randn(1000), # Example external excitatory inputs "external2": np.random.randn(1000), # Example external inhibitory inputs