Skip to content

Commit

Permalink
Merge pull request #3 from nuniz/notebooks
Browse files Browse the repository at this point in the history
Notebooks & Bugfix (all excitatory inputs)
  • Loading branch information
nuniz authored Jul 13, 2024
2 parents b204096 + 1c787c9 commit 9adc2f9
Show file tree
Hide file tree
Showing 13 changed files with 687 additions and 152 deletions.
117 changes: 90 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,60 +1,121 @@
# CD-Network

CD-Network is a Python library designed for the analytical derivation of the stochastic output of coincidence detector (CD) cells.
These cells receive inputs modeled as non-homogeneous Poisson processes (NHPP) with both excitatory and inhibitory components.
CD-Network is a Python library designed for the analytical derivation of the stochastic output of coincidence detector (
CD) cells.
These cells receive inputs modeled as non-homogeneous Poisson processes (NHPP) with both excitatory and inhibitory
components.

![cd_scheme](cd.png)

## Features

### Dynamic Connections (CD Network)

Define how cells are interconnected within the network and how external inputs affect cell
responses.

```python
import numpy as np
from cd_network.network import NeuralNetwork

if __name__ == '__main__':
# Load the neural network configuration from a JSON file
config_path = r'config.json' # Path to the configuration file
network = NeuralNetwork(config_path)

# Define external inputs for the network
external_inputs = {
'external1': np.random.randn(1000),
'external2': np.random.randn(1000),
'external3': np.random.randn(1000)
}

# Run the network with the provided external inputs
outputs = network.run_network(external_inputs)

# Print the outputs of the network
print(outputs)

```

### Configuration File

The CD network simulation uses a JSON configuration file. Below is a breakdown of the configuration structure:

fs: Sampling frequency in Hz. This value is used across all cells for time-based calculations.

cells: An array of objects where each object represents a neural cell and its specific parameters:
type: Specifies the type of the cell (e.g., ei, simple_ee, cd).
id: A unique identifier for the cell.
params: Parameters specific to the cell type, such as delta_s for the time window in seconds and n_spikes for the minimum number of spikes required.

connections: An array defining the connections between cells or from external inputs to cells:
source: Identifier for the source of the input. This can be an external source or another cell.
target: Identifier for the cell receiving the input.
input_type: Specifies whether the input is excitatory or inhibitory.

[Example Configuration File](examplenotebooks/config.yaml)

### CD Cells

#### `ei(excitatory_input, inhibitory_inputs, delta_s, fs)`
Computes the output of an excitatory-inhibitory (EI) neuron model.
The model outputs spikes based on the excitatory inputs, except when inhibited by any preceding spikes within a specified time window from the inhibitory inputs.

Computes the output of an excitatory-inhibitory (EI) neuron model.
The model outputs spikes based on the excitatory inputs, except when inhibited by any preceding spikes within a
specified time window from the inhibitory inputs.

- **Parameters:**
- `excitatory_input (np.ndarray)`: 1D array of spike times or binary spikes from the excitatory neuron.
- `inhibitory_inputs (np.ndarray)`: 1D or 2D array of spike times or binary spikes from one or more inhibitory neurons.
- `delta_s (float)`: Coincidence integration duration in seconds, defining the time window for inhibition.
- `fs (float)`: Sampling frequency in Hz.
- `excitatory_input (np.ndarray)`: 1D array of spike times or binary spikes from the excitatory neuron.
- `inhibitory_inputs (np.ndarray)`: 1D or 2D array of spike times or binary spikes from one or more inhibitory
neurons.
- `delta_s (float)`: Coincidence integration duration in seconds, defining the time window for inhibition.
- `fs (float)`: Sampling frequency in Hz.

- **Returns:**
- `np.ndarray`: Output spike times or binary spike array after applying the excitatory-inhibitory interaction.
- `np.ndarray`: Output spike times or binary spike array after applying the excitatory-inhibitory interaction.

#### `simple_ee(inputs, delta_s, fs)`
Simplifies the model of excitatory-excitatory (EE) interaction where an output spike is generated whenever both inputs spike within a specified time interval.

Simplifies the model of excitatory-excitatory (EE) interaction where an output spike is generated whenever both inputs
spike within a specified time interval.

- **Parameters:**
- `inputs (np.ndarray)`: 2D array of excitatory input spikes.
- `delta_s (float)`: Coincidence integration duration in seconds.
- `fs (float)`: Sampling frequency in Hz.
- `inputs (np.ndarray)`: 2D array of excitatory input spikes.
- `delta_s (float)`: Coincidence integration duration in seconds.
- `fs (float)`: Sampling frequency in Hz.

- **Returns:**
- `np.ndarray`: Output spike times or binary spike array after applying the EE interaction.
- `np.ndarray`: Output spike times or binary spike array after applying the EE interaction.

#### `ee(inputs, n_spikes, delta_s, fs)`
A general excitatory-excitatory (EE) cell model that generates a spike whenever at least a minimum number of its inputs spike simultaneously within a specific time interval.

A general excitatory-excitatory (EE) cell model that generates a spike whenever at least a minimum number of its inputs
spike simultaneously within a specific time interval.

- **Parameters:**
- `inputs (np.ndarray)`: 2D array of excitatory input spikes.
- `n_spikes (int)`: Minimum number of inputs that must spike simultaneously.
- `delta_s (float)`: Coincidence integration duration in seconds.
- `fs (float)`: Sampling frequency in Hz.
- `inputs (np.ndarray)`: 2D array of excitatory input spikes.
- `n_spikes (int)`: Minimum number of inputs that must spike simultaneously.
- `delta_s (float)`: Coincidence integration duration in seconds.
- `fs (float)`: Sampling frequency in Hz.

- **Returns:**
- `np.ndarray`: Output spike times or binary spike array based on the input conditions.
- `np.ndarray`: Output spike times or binary spike array based on the input conditions.

#### `cd(excitatory_inputs, inhibitory_inputs, n_spikes, delta_s, fs)`
Models the output of a coincidence detector (CD) cell which generates spikes based on the relative timing and number of excitatory and inhibitory inputs within a defined interval.

Models the output of a coincidence detector (CD) cell which generates spikes based on the relative timing and number of
excitatory and inhibitory inputs within a defined interval.

- **Parameters:**
- `excitatory_inputs (np.ndarray)`: 2D array of excitatory input spikes.
- `inhibitory_inputs (np.ndarray)`: 2D array of inhibitory input spikes.
- `n_spikes (int)`: Minimum excess of excitatory spikes over inhibitory spikes required to generate an output spike.
- `delta_s (float)`: Interval length in seconds.
- `fs (float)`: Sampling frequency in Hz.
- `excitatory_inputs (np.ndarray)`: 2D array of excitatory input spikes.
- `inhibitory_inputs (np.ndarray)`: 2D array of inhibitory input spikes.
- `n_spikes (int)`: Minimum excess of excitatory spikes over inhibitory spikes required to generate an output spike.
- `delta_s (float)`: Interval length in seconds.
- `fs (float)`: Sampling frequency in Hz.

- **Returns:**
- `np.ndarray`: Output spike array after applying the CD interaction based on the relative timing and number of inputs.
- `np.ndarray`: Output spike array after applying the CD interaction based on the relative timing and number of
inputs.

## Installation

Expand All @@ -73,7 +134,9 @@ pip install .
```

## Contribution

run pre-commit to check all files in the repo.

```bash
pre-commit run --all-files
```
Expand Down
Binary file added cd.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions cd_network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@

from .cells import cd, ee, ei, simple_ee
from .version import __version__

# import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
48 changes: 31 additions & 17 deletions cd_network/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
EPS = 1e-15


def ei(excitatory_input: np.ndarray, inhibitory_inputs: np.ndarray, delta_s: float, fs: float) -> np.ndarray:
def ei(
excitatory_input: np.ndarray,
inhibitory_inputs: np.ndarray,
delta_s: float,
fs: float,
) -> np.ndarray:
"""
Computes the output of an excitatory-inhibitory (EI) cell. The EI cell generates a spike
based on the excitatory input, provided there are no spikes in the inhibitory inputs
Expand All @@ -28,15 +33,21 @@ def ei(excitatory_input: np.ndarray, inhibitory_inputs: np.ndarray, delta_s: flo
"""
assert excitatory_input.ndim in [1, 2], "Excitatory input must be a 1D or 2D array."
if excitatory_input.ndim == 2:
assert excitatory_input.shape[0] == 1, "If 2D, excitatory input must have a single row."
assert (
excitatory_input.shape[0] == 1
), "If 2D, excitatory input must have a single row."
excitatory_input = excitatory_input[0]

assert inhibitory_inputs.ndim in [1, 2], "Inhibitory inputs must be a 1D or 2D array."
assert inhibitory_inputs.ndim in [
1,
2,
], "Inhibitory inputs must be a 1D or 2D array."
if inhibitory_inputs.ndim == 1:
inhibitory_inputs = inhibitory_inputs[np.newaxis, ...]

assert len(excitatory_input) == inhibitory_inputs.shape[-1], \
"Length of excitatory input must match the size of inhibitory inputs along the last axis."
assert (
len(excitatory_input) == inhibitory_inputs.shape[-1]
), "Length of excitatory input must match the size of inhibitory inputs along the last axis."

output = excitatory_input * np.prod(
1 - coincidence_integral(inhibitory_inputs, delta_s, fs), axis=0
Expand Down Expand Up @@ -69,7 +80,7 @@ def _all_spikes_ee(inputs: np.ndarray, delta_s: float, fs: float) -> np.ndarray:


def _exactly_n_spikes_ee(
inputs: np.ndarray, n_spikes: int, delta_s: float, fs: float
inputs: np.ndarray, n_spikes: int, delta_s: float, fs: float
) -> np.ndarray:
"""
An all-spikes EE cell generates a spike whenever exactly n_spikes of its inputs spikes during an interval ∆.
Expand All @@ -87,7 +98,7 @@ def _exactly_n_spikes_ee(

n_inputs, samples = inputs.shape
assert (
n_inputs <= n_inputs
n_inputs <= n_inputs
), "n_spikes should be less than or equal to the number of inputs."

output = np.zeros(samples)
Expand All @@ -96,16 +107,19 @@ def _exactly_n_spikes_ee(
for comb in binomial_combinations:
indices_spike = np.array(comb)
indices_not_spike = np.array(list(set(range(n_inputs)) - set(indices_spike)))
if len(indices_not_spike) > 0 & len(indices_spike) > 0:
ei_output = ei(
if len(indices_not_spike) == 0:
output += _all_spikes_ee(
inputs=inputs[indices_spike], delta_s=delta_s, fs=fs
)
else:
output += ei(
excitatory_input=_all_spikes_ee(
inputs=inputs[indices_spike], delta_s=delta_s, fs=fs
),
inhibitory_inputs=inputs[indices_not_spike],
delta_s=delta_s,
fs=fs,
)
output += ei_output

return output

Expand Down Expand Up @@ -153,7 +167,7 @@ def ee(inputs, n_spikes: int, delta_s: float, fs: float) -> np.ndarray:

n_inputs, samples = inputs.shape
assert (
n_inputs <= n_inputs
n_inputs <= n_inputs
), "n_spikes should be less than or equal to the number of inputs."

output = np.zeros(samples)
Expand All @@ -164,11 +178,11 @@ def ee(inputs, n_spikes: int, delta_s: float, fs: float) -> np.ndarray:


def cd(
excitatory_inputs: np.ndarray,
inhibitory_inputs: np.ndarray,
n_spikes: int,
delta_s: float,
fs: float,
excitatory_inputs: np.ndarray,
inhibitory_inputs: np.ndarray,
n_spikes: int,
delta_s: float,
fs: float,
) -> np.ndarray:
"""
A general CD cell is defined as one with n_excitatory_inputs excitatory inputs and n_inhibitory_inputs inhibitory
Expand Down Expand Up @@ -206,7 +220,7 @@ def cd(
n_inhibitory_inputs, inhibitory_samples = inhibitory_inputs.shape

assert (
inhibitory_samples == excitatory_samples
inhibitory_samples == excitatory_samples
), "Number of samples in inhibitory and excitatory inputs must match."

output = np.zeros(excitatory_samples)
Expand Down
6 changes: 3 additions & 3 deletions cd_network/coincidence_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_trapezoid_kernel(samples_integral: int) -> np.ndarray:


def apply_filter(
x: np.ndarray, kernel: np.ndarray, dt: float, filter_func: Callable
x: np.ndarray, kernel: np.ndarray, dt: float, filter_func: Callable
) -> np.ndarray:
"""
Apply a filtering function to an input signal using a specified kernel.
Expand All @@ -41,7 +41,7 @@ def apply_filter(


def coincidence_integral(
x: np.ndarray, integration_duration: float, fs: float, method: str = "filtfilt"
x: np.ndarray, integration_duration: float, fs: float, method: str = "filtfilt"
) -> np.ndarray:
"""
Computes the coincidence integral of the input signal.
Expand All @@ -66,7 +66,7 @@ def coincidence_integral(
if method in filter_methods:
return filter_methods[method](x)

raise ValueError(f'method {method} is not supported.')
raise ValueError(f"method {method} is not supported.")


@lru_cache(maxsize=None)
Expand Down
Loading

0 comments on commit 9adc2f9

Please sign in to comment.