Skip to content

Commit

Permalink
changed EPS in cells.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nuniz committed Jul 13, 2024
1 parent b8429e5 commit 383d23d
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions cd_network/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from .coincidence_integral import cached_coincidence_integral, coincidence_integral

EPS = 1e-15


def ei(
excitatory_input: np.ndarray,
inhibitory_inputs: np.ndarray,
delta_s: float,
fs: float,
excitatory_input: np.ndarray,
inhibitory_inputs: np.ndarray,
delta_s: float,
fs: float,
) -> np.ndarray:
"""
The general EI cell spikes whenever the excitatory input spikes and in the preceding ∆ seconds none of the
Expand All @@ -34,7 +36,7 @@ def ei(
inhibitory_inputs = inhibitory_inputs[np.newaxis, ...]

assert (
len(excitatory_input) == inhibitory_inputs.shape[-1]
len(excitatory_input) == inhibitory_inputs.shape[-1]
), "Length of excitatory input must match the size of inhibitory inputs along the last axis."

output = excitatory_input * np.prod(
Expand Down Expand Up @@ -63,12 +65,12 @@ def _all_spikes_ee(inputs: np.ndarray, delta_s: float, fs: float) -> np.ndarray:
n_inputs, samples = inputs.shape
output = np.zeros(samples)
for i in range(n_inputs):
output += inputs[i] * coincidence_prod / (coincidence_integral_outputs[i] + np.finfo(np.float64))
output += inputs[i] * coincidence_prod / (coincidence_integral_outputs[i] + EPS)
return output


def _exactly_n_spikes_ee(
inputs: np.ndarray, n_spikes: int, delta_s: float, fs: float
inputs: np.ndarray, n_spikes: int, delta_s: float, fs: float
) -> np.ndarray:
"""
An all-spikes EE cell generates a spike whenever exactly n_spikes of its inputs spikes during an interval ∆.
Expand All @@ -86,7 +88,7 @@ def _exactly_n_spikes_ee(

n_inputs, samples = inputs.shape
assert (
n_inputs <= n_inputs
n_inputs <= n_inputs
), "n_spikes should be less than or equal to the number of inputs."

output = np.zeros(samples)
Expand Down Expand Up @@ -152,7 +154,7 @@ def ee(inputs, n_spikes: int, delta_s: float, fs: float) -> np.ndarray:

n_inputs, samples = inputs.shape
assert (
n_inputs <= n_inputs
n_inputs <= n_inputs
), "n_spikes should be less than or equal to the number of inputs."

output = np.zeros(samples)
Expand All @@ -163,11 +165,11 @@ def ee(inputs, n_spikes: int, delta_s: float, fs: float) -> np.ndarray:


def cd(
excitatory_inputs: np.ndarray,
inhibitory_inputs: np.ndarray,
n_spikes: int,
delta_s: float,
fs: float,
excitatory_inputs: np.ndarray,
inhibitory_inputs: np.ndarray,
n_spikes: int,
delta_s: float,
fs: float,
) -> np.ndarray:
"""
A general CD cell is defined as one with n_excitatory_inputs excitatory inputs and n_inhibitory_inputs inhibitory
Expand Down Expand Up @@ -205,7 +207,7 @@ def cd(
n_inhibitory_inputs, inhibitory_samples = inhibitory_inputs.shape

assert (
inhibitory_samples == excitatory_samples
inhibitory_samples == excitatory_samples
), "Number of samples in inhibitory and excitatory inputs must match."

output = np.zeros(excitatory_samples)
Expand Down

0 comments on commit 383d23d

Please sign in to comment.