diff --git a/.travis.yml b/.travis.yml index 0236087f1..e470e8e28 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,68 +4,28 @@ matrix: include: # OSX - os: osx - name: "MacOS sierra" - osx_image: xcode9.2 - env: - - PYTHON_VERSION=3.7 - - NEURON_VERSION=7.7 - - - os: osx - name: "MacOS el capitan" - osx_image: xcode8 - env: - - PYTHON_VERSION=3.7 - - NEURON_VERSION=7.7 - - - os: osx - name: "MacOS mojave" - osx_image: xcode11.3 - env: - - PYTHON_VERSION=3.7 - - NEURON_VERSION=7.7 + name: "MacOS Catalina" + osx_image: xcode12 - - os: osx - name: "MacOS high sierra" - osx_image: xcode10.1 + # WSL + - os: windows + name: "WSL" env: - - PYTHON_VERSION=3.7 - - NEURON_VERSION=7.7 + - WSL_INSTALL=1 + - USE_CONDA=0 # Windows - os: windows name: "Windows" env: - - PYTHON_VERSION=3.7 - - NEURON_VERSION=7.7 + - WSL_INSTALL=0 # Linux - - os: linux - dist: xenial - name: "Ubuntu xenial" - env: - - NEURON_VERSION=7.7 - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - xvfb - - os: linux dist: bionic - name: "Ubuntu bionic" + name: "Ubuntu Bionic" env: - - NEURON_VERSION=7.7 - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - xvfb - - - os: linux - dist: disco - name: "Ubuntu disco" - env: - - NEURON_VERSION=7.7 + - USE_CONDA=1 apt: sources: - ubuntu-toolchain-r-test @@ -75,8 +35,6 @@ matrix: - os: linux dist: focal name: "Ubuntu focal" - env: - - NEURON_VERSION=7.7 apt: sources: - ubuntu-toolchain-r-test @@ -84,103 +42,67 @@ matrix: - xvfb before_install: - - set -e # error on any command failure - - | # function exports - export TRAVIS_TESTING=1 - - if [[ "${TRAVIS_OS_NAME}" == "windows" ]]; then - # for start_vcxsrv_print and stop_vcxsrv - source "scripts/docker_functions.sh" - set_globals - fi - - # source utility functions - export LOGFILE="hnn_travis.log" - source scripts/utils.sh - export -f cleanup - - | - if [ "${TRAVIS_OS_NAME}" == "osx" ]; then # install osx prerequisites - echo "Installing macOS prerequisites" - - scripts/setup-travis-mac.sh - export PATH=${HOME}/Miniconda3/bin:$PATH - export PATH=$PATH:/Applications/NEURON-${NEURON_VERSION}/nrn/x86_64/bin - export PYTHONPATH=/Applications/NEURON-${NEURON_VERSION}/nrn/lib/python:$PYTHONPATH - export PYTHON=python3 - - source activate hnn && echo "activated conda HNN environment" - fi - - | # windows - if [ "${TRAVIS_OS_NAME}" == "windows" ]; then - echo "Installing windows prerequisites" - - scripts/setup-travis-windows.sh + # Step 0: set common environment variables - # add miniconda python to the path - export PATH=$PATH:$HOME/Miniconda3/Scripts - export PATH=$HOME/Miniconda3/envs/hnn/:$PATH - export PATH=$HOME/Miniconda3/envs/hnn/Scripts:$PATH - export PATH=$HOME/Miniconda3/envs/hnn/Library/bin:$PATH - - # for using X server - export PATH="$PATH:/c/Program\ Files/VcXsrv" - - # for MESA dll's - export PATH=$PATH:/c/tools/msys64/mingw64/bin + set -e + export TRAVIS_TESTING=1 + export DISPLAY=:0 + export LOGFILE="hnn_travis.log" + export PATH="$PATH:$HOME/.local/bin" + if [[ "${WSL_INSTALL}" -eq 1 ]]; then # for sharing with WSL environment - export WSLENV=TRAVIS_TESTING/u - - # set other variables for neuron and HNN - export PATH=$PATH:/c/nrn/bin - export DISPLAY="localhost:0" - export NEURONHOME=/c/nrn - export PYTHON=python + export OMPI_MCA_btl_vader_single_copy_mechanism=none + export WSLENV=TRAVIS_TESTING/u:DISPLAY/u:OMPI_MCA_btl_vader_single_copy_mechanism/u:WSL_INSTALL/u fi - - | # Linux - if [ "${TRAVIS_OS_NAME}" == "linux" ]; then - echo "Installing Linux prerequisites" - - export DISPLAY=:0 - export 
PATH=/usr/bin:/usr/local/bin:$PATH - - echo "Starting fake Xserver" - Xvfb $DISPLAY -listen tcp -screen 0 1024x768x24 > /dev/null & - - echo "Starting Ubuntu install script" - installer/ubuntu/hnn-ubuntu.sh + - | + # Step 1: install prerequisites - NLOPT_LIB=$(ls -d $HOME/.local/lib/python*/site-packages) - echo $NLOPT_LIB - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NLOPT_LIB - export PYTHON=python3 + if [[ "${WSL_INSTALL}" -eq 1 ]]; then + powershell.exe -ExecutionPolicy Bypass -File ./scripts/setup-travis-wsl.ps1 + # scripts/setup-travis-wsl.sh + else + echo "Installing ${TRAVIS_OS_NAME} prerequisites" + scripts/setup-travis-${TRAVIS_OS_NAME}.sh - # test X server - xset -display $DISPLAY -q > /dev/null; + source "$HOME/Miniconda3/etc/profile.d/conda.sh" + conda activate hnn fi install: - - | # for mac build HNN .mod files - if [[ "${TRAVIS_OS_NAME}" == "osx" ]]; then - make -j2 + - | + # Step 2: install hnn Python module and modules for testing + + if [[ "${WSL_INSTALL}" -ne 1 ]]; then + pip install flake8 pytest pytest-cov coverage coveralls mne pytest-qt + python setup.py install + else + wsl -- pip install flake8 pytest pytest-cov coverage coveralls mne \ + pytest-qt + wsl -- python3 setup.py install --user fi - - | # testing packages - pip install flake8 pytest pytest-cov coverage coveralls mne script: - - | # Check that the GUI starts on host OS - echo "Testing GUI on host OS..." - $PYTHON hnn.py - - | # Run py.test that includes running a simulation and verifying results - echo "Running Python tests on host OS..." - py.test --cov=. tests/ - - | # Test WSL-based version on windows (needs VcXsrv) - if [[ "${TRAVIS_OS_NAME}" == "windows" ]]; then - find_command_suggested_path "vcxsrv" "/c/Program Files/VcXsrv" && \ - start_vcxsrv_print || script_fail - wsl -- bash -e //home/hnn_user/hnn/scripts/run-travis-wsl.sh - stop_vcxsrv || script_fail + - | + # Step 3: run CI tests with py.test + + if [[ "${WSL_INSTALL}" -eq 1 ]]; then + wsl -- //home/hnn_user/hnn/scripts/run-pytest.sh + else + if [[ "${TRAVIS_OS_NAME}" == "osx" ]]; then + # NEURON will fail to import if DISPLAY is set + unset DISPLAY + elif [[ "${TRAVIS_OS_NAME}" == "windows" ]]; then + # Python will search path to find neuron dll's + export PATH=$PATH:/c/nrn/bin + + # run tests first as a user with a space (TODO) + # runas //user:"test user" //wait "bash" "scripts/run-pytest.sh" < "$HOME/test_user_creds" + # echo "Finished test with 'test user'" + fi + + scripts/run-pytest.sh fi after_success: diff --git a/DataViewGUI.py b/DataViewGUI.py deleted file mode 100644 index b01c16f15..000000000 --- a/DataViewGUI.py +++ /dev/null @@ -1,139 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox, QInputDialog -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from gutils import getmplDPI -import matplotlib.pyplot as plt -from conf import dconf - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: plt.rcParams['font.size'] = dconf['fontsize'] = 10 - -# GUI for viewing data from individual/all 
trials -class DataViewGUI (QMainWindow): - def __init__ (self, CanvasType, paramf, ntrial,title): - super().__init__() - self.fontsize = dconf['fontsize'] - self.linewidth = plt.rcParams['lines.linewidth'] = 1 - self.markersize = plt.rcParams['lines.markersize'] = 5 - self.CanvasType = CanvasType - self.paramf = paramf - self.ntrial = ntrial - self.title = title - self.initUI() - - def initMenu (self): - exitAction = QAction(QIcon.fromTheme('exit'), 'Exit', self) - exitAction.setShortcut('Ctrl+Q') - exitAction.setStatusTip('Exit ' + self.title + '.') - exitAction.triggered.connect(qApp.quit) - - menubar = self.menuBar() - self.fileMenu = menubar.addMenu('&File') - menubar.setNativeMenuBar(False) - self.fileMenu.addAction(exitAction) - - viewMenu = menubar.addMenu('&View') - changeFontSizeAction = QAction('Change Font Size',self) - changeFontSizeAction.setStatusTip('Change Font Size.') - changeFontSizeAction.triggered.connect(self.changeFontSize) - viewMenu.addAction(changeFontSizeAction) - changeLineWidthAction = QAction('Change Line Width',self) - changeLineWidthAction.setStatusTip('Change Line Width.') - changeLineWidthAction.triggered.connect(self.changeLineWidth) - viewMenu.addAction(changeLineWidthAction) - changeMarkerSizeAction = QAction('Change Marker Size',self) - changeMarkerSizeAction.setStatusTip('Change Marker Size.') - changeMarkerSizeAction.triggered.connect(self.changeMarkerSize) - viewMenu.addAction(changeMarkerSizeAction) - - def changeFontSize (self): - i, okPressed = QInputDialog.getInt(self, "Set Font Size","Font Size:", plt.rcParams['font.size'], 1, 100, 1) - if okPressed: - self.fontsize = plt.rcParams['font.size'] = dconf['fontsize'] = i - self.initCanvas() - self.m.plot() - - def changeLineWidth (self): - i, okPressed = QInputDialog.getInt(self, "Set Line Width","Line Width:", plt.rcParams['lines.linewidth'], 1, 20, 1) - if okPressed: - self.linewidth = plt.rcParams['lines.linewidth'] = i - self.initCanvas() - self.m.plot() - - def changeMarkerSize (self): - i, okPressed = QInputDialog.getInt(self, "Set Marker Size","Font Size:", self.markersize, 1, 100, 1) - if okPressed: - self.markersize = plt.rcParams['lines.markersize'] = i - self.initCanvas() - self.m.plot() - - def printStat (self,s): - print(s) - self.statusBar().showMessage(s) - - def initCanvas (self): - try: # to avoid memory leaks remove any pre-existing widgets before adding new ones - self.grid.removeWidget(self.m) - self.grid.removeWidget(self.toolbar) - self.m.setParent(None) - self.toolbar.setParent(None) - self.m = self.toolbar = None - except: - pass - self.m = self.CanvasType(self.paramf, self.index, parent = self, width=12, height=10, dpi=getmplDPI()) - # this is the Navigation widget - # it takes the Canvas widget and a parent - self.toolbar = NavigationToolbar(self.m, self) - self.grid.addWidget(self.toolbar, 0, 0, 1, 4); - self.grid.addWidget(self.m, 1, 0, 1, 4); - - def updateCB (self): - self.cb.clear() - if self.ntrial > 1: - self.cb.addItem('Show All Trials') - for i in range(self.ntrial): - self.cb.addItem('Show Trial ' + str(i+1)) - else: - self.cb.addItem('All Trials') - self.cb.activated[int].connect(self.onActivated) - - def initUI (self): - self.initMenu() - self.statusBar() - self.setGeometry(300, 300, 1300, 1100) - self.setWindowTitle(self.title + ' - ' + self.paramf) - self.grid = grid = QGridLayout() - self.index = 0 - self.initCanvas() - self.cb = QComboBox(self) - self.grid.addWidget(self.cb,2,0,1,4) - - self.updateCB() - - # need a separate widget to put grid on - widget = 
QWidget(self) - widget.setLayout(grid) - self.setCentralWidget(widget); - - try: self.setWindowIcon(QIcon(os.path.join('res','icon.png'))) - except: pass - - self.show() - - def onActivated(self, idx): - if idx != self.index: - self.index = idx - if self.index == 0: - self.statusBar().showMessage('Loading data from all trials.') - else: - self.statusBar().showMessage('Loading data from trial ' + str(self.index) + '.') - self.m.index = self.index - self.initCanvas() - self.m.plot() - self.statusBar().showMessage('') diff --git a/L2_basket.py b/L2_basket.py deleted file mode 100644 index 801192d55..000000000 --- a/L2_basket.py +++ /dev/null @@ -1,215 +0,0 @@ -# L2_basket.py - establish class def for layer 2 basket cells -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed dependence on it.izip) -# last rev: (SL: toward python3) - -from neuron import h as nrn -from cell import BasketSingle - -# Units for e: mV -# Units for gbar: S/cm^2 unless otherwise noted - -# Layer 2 basket cell class -class L2Basket(BasketSingle): - def __init__(self, gid = -1, pos = -1): - # BasketSingle.__init__(self, pos, L, diam, Ra, cm) - # Note: Basket cell properties set in BasketSingle()) - BasketSingle.__init__(self, gid, pos, 'L2Basket') - self.celltype = 'L2_basket' - - self.__synapse_create() - self.__biophysics() - - # creation of synapses - def __synapse_create(self): - # creates synapses onto this cell - self.soma_ampa = self.syn_ampa_create(self.soma(0.5)) - self.soma_gabaa = self.syn_gabaa_create(self.soma(0.5)) - self.soma_nmda = self.syn_nmda_create(self.soma(0.5)) - - def __biophysics(self): - self.soma.insert('hh2') - - # insert IClamps in all situations - def create_all_IClamp(self, p): - # list of sections for this celltype - sect_list_IClamp = [ - 'soma', - ] - - # some parameters - t_delay = p['Itonic_t0_L2Basket'] - - # T = -1 means use nrn.tstop - if p['Itonic_T_L2Basket'] == -1: - t_dur = nrn.tstop - t_delay - - else: - t_dur = p['Itonic_T_L2Basket'] - t_delay - - # t_dur must be nonnegative, I imagine - if t_dur < 0.: - t_dur = 0. 
- - # properties of the IClamp - props_IClamp = { - 'loc': 0.5, - 'delay': t_delay, - 'dur': t_dur, - 'amp': p['Itonic_A_L2Basket'] - } - - # iterate through list of sect_list_IClamp to create a persistent IClamp object - # the insert_IClamp procedure is in Cell() and checks on names - # so names must be actual section names, or else it will fail silently - # self.list_IClamp as a variable is guaranteed in Cell() - self.list_IClamp = [self.insert_IClamp(sect_name, props_IClamp) for sect_name in sect_list_IClamp] - - # par connect between all presynaptic cells - # no connections from L5Pyr or L5Basket to L2Baskets - def parconnect(self, gid, gid_dict, pos_dict, p): - # FROM L2 pyramidals TO this cell - for gid_src, pos in zip(gid_dict['L2_pyramidal'], pos_dict['L2_pyramidal']): - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Pyr_L2Basket'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L2_pyramidal' - } - - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.soma_ampa)) - - # FROM other L2Basket cells - for gid_src, pos in zip(gid_dict['L2_basket'], pos_dict['L2_basket']): - # no autapses - # if gid_src != gid: - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Basket_L2Basket'], - 'A_delay': 1., - 'lamtha': 20., - 'threshold': p['threshold'], - 'type_src' : 'L2_basket' - } - - self.ncfrom_L2Basket.append(self.parconnect_from_src(gid_src, nc_dict, self.soma_gabaa)) - - # this function might make more sense as a method of net? - # par: receive from external inputs - def parreceive(self, gid, gid_dict, pos_dict, p_ext): - # for some gid relating to the input feed: - for gid_src, p_src, pos in zip(gid_dict['extinput'], p_ext, pos_dict['extinput']): - # check if AMPA params are defined in the p_src - if 'L2Basket_ampa' in p_src.keys(): - # create an nc_dict - nc_dict_ampa = { - 'pos_src': pos, - 'A_weight': p_src['L2Basket_ampa'][0], - 'A_delay': p_src['L2Basket_ampa'][1], - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src' : 'ext' - } - - # AMPA synapse - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.soma_ampa)) - - # Check if NMDA params are defined in p_src - if 'L2Basket_nmda' in p_src.keys(): - nc_dict_nmda = { - 'pos_src': pos, - 'A_weight': p_src['L2Basket_nmda'][0], - 'A_delay': p_src['L2Basket_nmda'][1], - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src' : 'ext' - } - - # NMDA synapse - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.soma_nmda)) - - # one parreceive function to handle all types of external parreceives - # types must be defined explicitly here - def parreceive_ext(self, type, gid, gid_dict, pos_dict, p_ext): - if type.startswith(('evprox', 'evdist')): - if self.celltype in p_ext.keys(): - gid_ev = gid + gid_dict[type][0] - - nc_dict_ampa = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 is ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 is delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - nc_dict_nmda = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][1], # index 1 is nmda weight - 'A_delay': p_ext[self.celltype][2], # index 2 is delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - # connections depend on location of input - why only for L2 basket and not L5 basket? 
- if p_ext['loc'] is 'proximal': - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.soma_ampa)) - # NEW: note that default/original is 0 nmda weight for the soma (for prox evoked) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.soma_nmda)) - - elif p_ext['loc'] is 'distal': - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.soma_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.soma_nmda)) - - elif type == 'extgauss': - # gid is this cell's gid - # gid_dict is the whole dictionary, including the gids of the extgauss - # pos_list is also the pos of the extgauss (net origin) - # p_ext_gauss are the params (strength, etc.) - # I recognize this is ugly (hack) - if self.celltype in p_ext.keys(): - # since gid ids are unique, then these will all be shifted. - # if order of extgauss random feeds ever matters (likely) - # then will have to preserve order - # of creation based on gid ids of the cells - # this is a dumb place to put this information - gid_extgauss = gid + gid_dict['extgauss'][0] - - # gid works here because there are as many pos items in pos_dict['extgauss'] as there are cells - nc_dict = { - 'pos_src': pos_dict['extgauss'][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 is ampa weight - 'A_delay': p_ext[self.celltype][1], # index 2 is delay - 'lamtha': p_ext['lamtha'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss, nc_dict, self.soma_ampa)) - - elif type == 'extpois': - if self.celltype in p_ext.keys(): - gid_extpois = gid + gid_dict['extpois'][0] - - nc_dict = { - 'pos_src': pos_dict['extpois'][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 is ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 is delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.soma_ampa)) - - if p_ext[self.celltype][1] > 0.0: - nc_dict['A_weight'] = p_ext[self.celltype][1] # index 1 for nmda weight - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.soma_nmda)) - - else: - print("Warning, type def not specified in L2Basket") diff --git a/L2_pyramidal.py b/L2_pyramidal.py deleted file mode 100644 index ce5d95d42..000000000 --- a/L2_pyramidal.py +++ /dev/null @@ -1,618 +0,0 @@ -# L2_pyramidal.py - est class def for layer 2 pyramidal cells -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed dep on it.izip) -# last rev: (SL: toward python3, moved cells) - -import sys -import os -import numpy as np - -from neuron import h -from cell import Pyr -import paramrw -import params_default as p_default - -# Units for e: mV -# Units for gbar: S/cm^2 unless otherwise noted - -# Layer 2 pyramidal cell class -class L2Pyr(Pyr): - - def __init__(self, gid = -1, pos = -1, p={}): - # Get default L2Pyr params and update them with any corresponding params in p - p_all_default = p_default.get_L2Pyr_params_default() - self.p_all = paramrw.compare_dictionaries(p_all_default, p) - - # Get somatic, dendritic, and synapse properties - p_soma = self.__get_soma_props(pos) - p_dend = self.__get_dend_props() - p_syn = self.__get_syn_props() - # p_dend_props, dend_names = self.__get_dend_props() - - # usage: Pyr.__init__(self, soma_props) - Pyr.__init__(self, gid, p_soma) - self.celltype = 'L2_pyramidal' - - # geometry - # creates dict of dends: self.dends - 
self.create_dends(p_dend) - self.topol() # sets the connectivity between sections - self.geom(p_dend) # sets geom properties; adjusted after translation from hoc (2009 model) - - # biophysics - self.__biophys_soma() - self.__biophys_dends() - - # dipole_insert() comes from Cell() - self.yscale = self.get_sectnames() - self.dipole_insert(self.yscale) - - # create synapses - self.__synapse_create(p_syn) - # self.__synapse_create() - - # run record_current_soma(), defined in Cell() - self.record_current_soma() - - # insert IClamps in all situations - # temporarily an external function taking the p dict - def create_all_IClamp(self, p): - # list of sections for this celltype - sect_list_IClamp = [ - 'soma', - ] - - # some parameters - t_delay = p['Itonic_t0_L2Pyr_soma'] - - # T = -1 means use h.tstop - if p['Itonic_T_L2Pyr_soma'] == -1: - # t_delay = 50. - t_dur = h.tstop - t_delay - - else: - t_dur = p['Itonic_T_L2Pyr_soma'] - t_delay - - # t_dur must be nonnegative, I imagine - if t_dur < 0.: - t_dur = 0. - - # properties of the IClamp - props_IClamp = { - 'loc': 0.5, - 'delay': t_delay, - 'dur': t_dur, - 'amp': p['Itonic_A_L2Pyr_soma'] - } - - # iterate through list of sect_list_IClamp to create a persistent IClamp object - # the insert_IClamp procedure is in Cell() and checks on names - # so names must be actual section names, or else it will fail silently - self.list_IClamp = [self.insert_IClamp(sect_name, props_IClamp) for sect_name in sect_list_IClamp] - - # Returns hardcoded somatic properties - def __get_soma_props(self, pos): - return { - 'pos': pos, - 'L': self.p_all['L2Pyr_soma_L'], - 'diam': self.p_all['L2Pyr_soma_diam'], - 'cm': self.p_all['L2Pyr_soma_cm'], - 'Ra': self.p_all['L2Pyr_soma_Ra'], - 'name': 'L2Pyr', - } - - # Returns hardcoded dendritic properties - def __get_dend_props(self): - return { - 'apical_trunk': { - 'L': self.p_all['L2Pyr_apicaltrunk_L'] , - 'diam': self.p_all['L2Pyr_apicaltrunk_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - 'apical_1': { - 'L': self.p_all['L2Pyr_apical1_L'], - 'diam': self.p_all['L2Pyr_apical1_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - 'apical_tuft': { - 'L': self.p_all['L2Pyr_apicaltuft_L'], - 'diam': self.p_all['L2Pyr_apicaltuft_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - 'apical_oblique': { - 'L': self.p_all['L2Pyr_apicaloblique_L'], - 'diam': self.p_all['L2Pyr_apicaloblique_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - 'basal_1': { - 'L': self.p_all['L2Pyr_basal1_L'], - 'diam': self.p_all['L2Pyr_basal1_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - 'basal_2': { - 'L': self.p_all['L2Pyr_basal2_L'], - 'diam': self.p_all['L2Pyr_basal2_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - 'basal_3': { - 'L': self.p_all['L2Pyr_basal3_L'], - 'diam': self.p_all['L2Pyr_basal3_diam'], - 'cm': self.p_all['L2Pyr_dend_cm'], - 'Ra': self.p_all['L2Pyr_dend_Ra'], - }, - } - - # This order matters! 
- # dend_order = ['apical_trunk', 'apical_1', 'apical_tuft', 'apical_oblique', - # 'basal_1', 'basal_2', 'basal_3'] - - # return dend_props, dend_order - - def __get_syn_props(self): - return { - 'ampa': { - 'e': self.p_all['L2Pyr_ampa_e'], - 'tau1': self.p_all['L2Pyr_ampa_tau1'], - 'tau2': self.p_all['L2Pyr_ampa_tau2'], - }, - 'nmda': { - 'e': self.p_all['L2Pyr_nmda_e'], - 'tau1': self.p_all['L2Pyr_nmda_tau1'], - 'tau2': self.p_all['L2Pyr_nmda_tau2'], - }, - 'gabaa': { - 'e': self.p_all['L2Pyr_gabaa_e'], - 'tau1': self.p_all['L2Pyr_gabaa_tau1'], - 'tau2': self.p_all['L2Pyr_gabaa_tau2'], - }, - 'gabab': { - 'e': self.p_all['L2Pyr_gabab_e'], - 'tau1': self.p_all['L2Pyr_gabab_tau1'], - 'tau2': self.p_all['L2Pyr_gabab_tau2'], - } - } - - def geom (self, p_dend): - soma = self.soma; dend = self.list_dend; - # increased by 70% for human - soma.L = 22.1 - dend[0].L = 59.5 - dend[1].L = 340 - dend[2].L = 306 - dend[3].L = 238 - dend[4].L = 85 - dend[5].L = 255 - dend[6].L = 255 - soma.diam = 23.4 - dend[0].diam = 4.25 - dend[1].diam = 3.91 - dend[2].diam = 4.08 - dend[3].diam = 3.4 - dend[4].diam = 4.25 - dend[5].diam = 2.72 - dend[6].diam = 2.72 - self.set_dend_props(p_dend) # resets length,diam,etc. based on param specification - - # Connects sections of THIS cell together - def topol (self): - """ original topol - connect dend(0), soma(1) - for i = 1, 2 connect dend[i](0), dend(1) - connect dend[3](0), dend[2](1) - connect dend[4](0), soma(0) //was soma(1), 0 is correct! - for i = 5, 6 connect dend[i](0), dend[4](1) - - """ - - # child.connect(parent, parent_end, {child_start=0}) - # Distal (Apical) - self.dends['apical_trunk'].connect(self.soma, 1, 0) - self.dends['apical_1'].connect(self.dends['apical_trunk'], 1, 0) - self.dends['apical_tuft'].connect(self.dends['apical_1'], 1, 0) - - # apical_oblique comes off distal end of apical_trunk - self.dends['apical_oblique'].connect(self.dends['apical_trunk'], 1, 0) - - # Proximal (basal) - self.dends['basal_1'].connect(self.soma, 0, 0) - self.dends['basal_2'].connect(self.dends['basal_1'], 1, 0) - self.dends['basal_3'].connect(self.dends['basal_1'], 1, 0) - - self.basic_shape() # translated from original hoc (2009 model) - - def basic_shape (self): - # THESE AND LENGHTHS MUST CHANGE TOGETHER!!! 
- pt3dclear=h.pt3dclear; pt3dadd=h.pt3dadd; soma = self.soma; dend = self.list_dend - pt3dclear(sec=soma); pt3dadd(-50, 765, 0, 1,sec=soma); pt3dadd(-50, 778, 0, 1,sec=soma) - pt3dclear(sec=dend[0]); pt3dadd(-50, 778, 0, 1,sec=dend[0]); pt3dadd(-50, 813, 0, 1,sec=dend[0]) - pt3dclear(sec=dend[1]); pt3dadd(-50, 813, 0, 1,sec=dend[1]); pt3dadd(-250, 813, 0, 1,sec=dend[1]) - pt3dclear(sec=dend[2]); pt3dadd(-50, 813, 0, 1,sec=dend[2]); pt3dadd(-50, 993, 0, 1,sec=dend[2]) - pt3dclear(sec=dend[3]); pt3dadd(-50, 993, 0, 1,sec=dend[3]); pt3dadd(-50, 1133, 0, 1,sec=dend[3]) - pt3dclear(sec=dend[4]); pt3dadd(-50, 765, 0, 1,sec=dend[4]); pt3dadd(-50, 715, 0, 1,sec=dend[4]) - pt3dclear(sec=dend[5]); pt3dadd(-50, 715, 0, 1,sec=dend[5]); pt3dadd(-156, 609, 0, 1,sec=dend[5]) - pt3dclear(sec=dend[6]); pt3dadd(-50, 715, 0, 1,sec=dend[6]); pt3dadd(56, 609, 0, 1,sec=dend[6]) - - # Adds biophysics to soma - def __biophys_soma (self): - # set soma biophysics specified in Pyr - # self.pyr_biophys_soma() - - # Insert 'hh2' mechanism - self.soma.insert('hh2') - self.soma.gkbar_hh2 = self.p_all['L2Pyr_soma_gkbar_hh2'] - self.soma.gl_hh2 = self.p_all['L2Pyr_soma_gl_hh2'] - self.soma.el_hh2 = self.p_all['L2Pyr_soma_el_hh2'] - self.soma.gnabar_hh2 = self.p_all['L2Pyr_soma_gnabar_hh2'] - - # Insert 'km' mechanism - # Units: pS/um^2 - self.soma.insert('km') - self.soma.gbar_km = self.p_all['L2Pyr_soma_gbar_km'] - - # Defining biophysics for dendrites - def __biophys_dends (self): - # set dend biophysics - # iterate over keys in self.dends and set biophysics for each dend - for key in self.dends: - # neuron syntax is used to set values for mechanisms - # sec.gbar_mech = x sets value of gbar for mech to x for all segs - # in a section. This method is significantly faster than using - # a for loop to iterate over all segments to set mech values - - # Insert 'hh' mechanism - self.dends[key].insert('hh2') - self.dends[key].gkbar_hh2 = self.p_all['L2Pyr_dend_gkbar_hh2'] - self.dends[key].gl_hh2 = self.p_all['L2Pyr_dend_gl_hh2'] - self.dends[key].gnabar_hh2 = self.p_all['L2Pyr_dend_gnabar_hh2'] - self.dends[key].el_hh2 = self.p_all['L2Pyr_dend_el_hh2'] - - # Insert 'km' mechanism - # Units: pS/um^2 - self.dends[key].insert('km') - self.dends[key].gbar_km = self.p_all['L2Pyr_dend_gbar_km'] - - def __synapse_create (self, p_syn): - # creates synapses onto this cell - # Somatic synapses - self.synapses = { - 'soma_gabaa': self.syn_create(self.soma(0.5), p_syn['gabaa']), - 'soma_gabab': self.syn_create(self.soma(0.5), p_syn['gabab']), - } - - # Dendritic synapses - self.apicaloblique_ampa = self.syn_create(self.dends['apical_oblique'](0.5), p_syn['ampa']) - self.apicaloblique_nmda = self.syn_create(self.dends['apical_oblique'](0.5), p_syn['nmda']) - - self.basal2_ampa = self.syn_create(self.dends['basal_2'](0.5), p_syn['ampa']) - self.basal2_nmda = self.syn_create(self.dends['basal_2'](0.5), p_syn['nmda']) - - self.basal3_ampa = self.syn_create(self.dends['basal_3'](0.5), p_syn['ampa']) - self.basal3_nmda = self.syn_create(self.dends['basal_3'](0.5), p_syn['nmda']) - - self.apicaltuft_ampa = self.syn_create(self.dends['apical_tuft'](0.5), p_syn['ampa']) - self.apicaltuft_nmda = self.syn_create(self.dends['apical_tuft'](0.5), p_syn['nmda']) - - # self.synapses = { - # 'soma_gabaa': self.syn_gabaa_create(self.soma(0.5)), - # 'soma_gabab': self.syn_gabab_create(self.soma(0.5)), - # } - - # Dendritic synapses - # self.apicaloblique_ampa = self.syn_ampa_create(self.dends['apical_oblique'](0.5), p_syn['ampa']) - # 
self.apicaloblique_nmda = self.syn_create(self.dends['apical_oblique'](0.5), p_syn['nmda']) - - # self.basal2_ampa = self.syn_ampa_create(self.dends['basal_2'](0.5)) - # self.basal2_nmda = self.syn_nmda_create(self.dends['basal_2'](0.5)) - - # self.basal3_ampa = self.syn_ampa_create(self.dends['basal_3'](0.5)) - # self.basal3_nmda = self.syn_nmda_create(self.dends['basal_3'](0.5)) - - # self.apicaltuft_ampa = self.syn_ampa_create(self.dends['apical_tuft'](0.5)) - # self.apicaltuft_nmda = self.syn_nmda_create(self.dends['apical_tuft'](0.5)) - - # collect receptor-type-based connections here - def parconnect (self, gid, gid_dict, pos_dict, p): - # init dict of dicts - # nc_dict for ampa and nmda may be the same for this cell type - nc_dict = { - 'ampa': None, - 'nmda': None, - } - - # Connections FROM all other L2 Pyramidal cells to this one - for gid_src, pos in zip(gid_dict['L2_pyramidal'], pos_dict['L2_pyramidal']): - # don't be redundant, this is only possible for LIKE cells, but it might not hurt to check - if gid_src != gid: - nc_dict['ampa'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Pyr_L2Pyr_ampa'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L2_pyramidal' - } - - # parconnect_from_src(gid_presyn, nc_dict, postsyn) - # ampa connections - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict['ampa'], self.apicaloblique_ampa)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict['ampa'], self.basal2_ampa)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict['ampa'], self.basal3_ampa)) - - nc_dict['nmda'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Pyr_L2Pyr_nmda'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L2_pyramidal' - } - - # parconnect_from_src(gid_presyn, nc_dict, postsyn) - # nmda connections - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict['nmda'], self.apicaloblique_nmda)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict['nmda'], self.basal2_nmda)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict['nmda'], self.basal3_nmda)) - - # connections FROM L2 basket cells TO this L2Pyr cell - for gid_src, pos in zip(gid_dict['L2_basket'], pos_dict['L2_basket']): - nc_dict['gabaa'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Basket_L2Pyr_gabaa'], - 'A_delay': 1., - 'lamtha': 50., - 'threshold': p['threshold'], - 'type_src' : 'L2_basket' - } - - nc_dict['gabab'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Basket_L2Pyr_gabab'], - 'A_delay': 1., - 'lamtha': 50., - 'threshold': p['threshold'], - 'type_src' : 'L2_basket' - } - - self.ncfrom_L2Basket.append(self.parconnect_from_src(gid_src, nc_dict['gabaa'], self.synapses['soma_gabaa'])) - self.ncfrom_L2Basket.append(self.parconnect_from_src(gid_src, nc_dict['gabab'], self.synapses['soma_gabab'])) - - # connections FROM L5 basket cells TO this L2Pyr cell - # for gid_src in gid_dict['L5_basket']: - # nc_dict = { - # 'pos_src': pos_list[gid_src], - # 'A_weight': 2.5e-2, - # 'A_delay': 1., - # 'lamtha': 70. 
- # } - - # self.ncfrom_L5Basket.append(self.parconnect_from_src(gid_src, nc_dict, self.synapes['soma_gabaa'])) - # self.ncfrom_L5Basket.append(self.parconnect_from_src(gid_src, nc_dict, self.synapes['soma_gabab'])) - - # may be reorganizable - def parreceive (self, gid, gid_dict, pos_dict, p_ext): - for gid_src, p_src, pos in zip(gid_dict['extinput'], p_ext, pos_dict['extinput']): - # Check if AMPA params defined in p_src - if 'L2Pyr_ampa' in p_src.keys(): - nc_dict_ampa = { - 'pos_src': pos, - 'A_weight': p_src['L2Pyr_ampa'][0], - 'A_delay': p_src['L2Pyr_ampa'][1], - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src': 'ext' - } - - # Proximal feed AMPA synapses - if p_src['loc'] is 'proximal': - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.basal2_ampa)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.basal3_ampa)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.apicaloblique_ampa)) - # Distal feed AMPA synapses - elif p_src['loc'] is 'distal': - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.apicaltuft_ampa)) - - # Check is NMDA params defined in p_src - if 'L2Pyr_nmda' in p_src.keys(): - nc_dict_nmda = { - 'pos_src': pos, - 'A_weight': p_src['L2Pyr_nmda'][0], - 'A_delay': p_src['L2Pyr_nmda'][1], - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src': 'ext' - } - - # Proximal feed NMDA synapses - if p_src['loc'] is 'proximal': - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.basal2_nmda)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.basal3_nmda)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.apicaloblique_nmda)) - # Distal feed NMDA synapses - elif p_src['loc'] is 'distal': - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.apicaltuft_nmda)) - - # one parreceive function to handle all types of external parreceives - # types must be defined explicitly here - # this function handles evoked, gaussian, and poisson inputs - def parreceive_ext (self, type, gid, gid_dict, pos_dict, p_ext): - if type.startswith(('evprox', 'evdist')): - if self.celltype in p_ext.keys(): - gid_ev = gid + gid_dict[type][0] - - # separate dictionaries for ampa and nmda evoked inputs - nc_dict_ampa = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 for ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 for delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src': type - } - - nc_dict_nmda = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][1], # index 1 for nmda weight - 'A_delay': p_ext[self.celltype][2], # index 2 for delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src': type - } - - if p_ext['loc'] is 'proximal': - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.basal2_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.basal3_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.apicaloblique_ampa)) - - # NEW: note that default/original is 0 nmda weight for these proximal dends - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.basal2_nmda)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.basal3_nmda)) - 
self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.apicaloblique_nmda)) - - elif p_ext['loc'] is 'distal': - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.apicaltuft_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.apicaltuft_nmda)) - - elif type == 'extgauss': - # gid is this cell's gid - # gid_dict is the whole dictionary, including the gids of the extgauss - # pos_list is also the pos of the extgauss (net origin) - # p_ext_gauss are the params (strength, etc.) - - # gid shift is based on L2_pyramidal cells NOT L5 - # I recognize this is ugly (hack) - # gid_shift = gid_dict['extgauss'][0] - gid_dict['L2_pyramidal'][0] - if 'L2_pyramidal' in p_ext.keys(): - gid_extgauss = gid + gid_dict['extgauss'][0] - - nc_dict = { - 'pos_src': pos_dict['extgauss'][gid], - 'A_weight': p_ext['L2_pyramidal'][0], # index 0 for ampa weight (nmda not yet used in Gauss) - 'A_delay': p_ext['L2_pyramidal'][2], # index 2 for delay - 'lamtha': p_ext['lamtha'], - 'threshold': p_ext['threshold'], - 'type_src': type - } - - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss,nc_dict,self.basal2_ampa)) - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss,nc_dict,self.basal3_ampa)) - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss,nc_dict,self.apicaloblique_ampa)) - - elif type == 'extpois': - if self.celltype in p_ext.keys(): - gid_extpois = gid + gid_dict['extpois'][0] - - nc_dict = { - 'pos_src': pos_dict['extpois'][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 for ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 for delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src': type - } - - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois,nc_dict,self.basal2_ampa)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois,nc_dict,self.basal3_ampa)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois,nc_dict,self.apicaloblique_ampa)) - - if p_ext[self.celltype][1] > 0.0: - nc_dict['A_weight'] = p_ext[self.celltype][1] # index 1 for nmda weight - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois,nc_dict,self.basal2_nmda)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois,nc_dict,self.basal3_nmda)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois,nc_dict,self.apicaloblique_nmda)) - - else: - print("Warning, ext type def does not exist in L2Pyr") - - # Define 3D shape and position of cell. By default neuron uses xy plane for - # height and xz plane for depth. This is opposite for model as a whole, but - # convention is followed in this function for ease use of gui. - def __set_3Dshape (self): - # set 3d shape of soma by calling shape_soma from class Cell - # print("Warning: You are setiing 3d shape geom. 
You better be doing") - # print("gui analysis and not numerical analysis!!") - self.shape_soma() - - # soma proximal coords - x_prox = 0 - y_prox = 0 - - # soma distal coords - x_distal = 0 - y_distal = self.soma.L - - # dend 0-2 are major axis, dend 3 is branch - # deal with distal first along major cable axis - # the way this is assigning variables is ugly/lazy right now - for i in range(0, 3): - h.pt3dclear(sec=self.list_dend[i]) - - # x_distal and y_distal are the starting points for each segment - # these are updated at the end of the loop - h.pt3dadd(0, y_distal, 0, self.dend_diam[i], sec=self.list_dend[i]) - - # update x_distal and y_distal after setting them - # x_distal += dend_dx[i] - y_distal += self.dend_L[i] - - # add next point - h.pt3dadd(0, y_distal, 0, self.dend_diam[i], sec=self.list_dend[i]) - - # now deal with dend 3 - # dend 3 will ALWAYS be positioned at the end of dend[0] - h.pt3dclear(sec=self.list_dend[3]) - - # activate this section with 'sec =' notation - # self.list_dend[0].push() - x_start = h.x3d(1, sec = self.list_dend[0]) - y_start = h.y3d(1, sec = self.list_dend[0]) - # h.pop_section() - - h.pt3dadd(x_start, y_start, 0, self.dend_diam[3], sec=self.list_dend[3]) - # self.dend_L[3] is subtracted because lengths always positive, - # and this goes to negative x - h.pt3dadd(x_start-self.dend_L[3], y_start, 0, self.dend_diam[3], sec=self.list_dend[3]) - - # now deal with proximal dends - for i in range(4, 7): - h.pt3dclear(sec=self.list_dend[i]) - - # deal with dend 4, ugly. sorry. - h.pt3dadd(x_prox, y_prox, 0, self.dend_diam[i], sec=self.list_dend[4]) - y_prox += -self.dend_L[4] - - h.pt3dadd(x_prox, y_prox, 0, self.dend_diam[4], sec=self.list_dend[4]) - - # x_prox, y_prox are now the starting points for BOTH last 2 sections - - # dend 5 - # Calculate x-coordinate for end of dend - dend5_x = -self.dend_L[5] * np.sqrt(2) / 2. - h.pt3dadd(x_prox, y_prox, 0, self.dend_diam[5], sec=self.list_dend[5]) - h.pt3dadd(dend5_x, y_prox-self.dend_L[5] * np.sqrt(2) / 2., - 0, self.dend_diam[5], sec=self.list_dend[5]) - - # dend 6 - # Calculate x-coordinate for end of dend - dend6_x = self.dend_L[6] * np.sqrt(2) / 2. - h.pt3dadd(x_prox, y_prox, 0, self.dend_diam[6], sec=self.list_dend[6]) - h.pt3dadd(dend6_x, y_prox-self.dend_L[6] * np.sqrt(2) / 2., - 0, self.dend_diam[6], sec=self.list_dend[6]) - - # set 3D position - # z grid position used as y coordinate in h.pt3dchange() to satisfy - # gui convention that y is height and z is depth. 
In h.pt3dchange() - # x and z components are scaled by 100 for visualization clarity - self.soma.push() - for i in range(0, int(h.n3d())): - h.pt3dchange(i, self.pos[0]*100 + h.x3d(i), self.pos[2] + - h.y3d(i), self.pos[1] * 100 + h.z3d(i), - h.diam3d(i)) - - h.pop_section() diff --git a/L5_basket.py b/L5_basket.py deleted file mode 100644 index b69fe1a70..000000000 --- a/L5_basket.py +++ /dev/null @@ -1,213 +0,0 @@ -# L5_basket.py - establish class def for layer 5 basket cells -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed izip dep) -# last rev: (SL: toward python3) - -from neuron import h as nrn -from cell import BasketSingle - -# Units for e: mV -# Units for gbar: S/cm^2 unless otherwise noted - -# Layer 5 basket cell class -class L5Basket(BasketSingle): - def __init__(self, gid = -1, pos = -1): - # Note: Cell properties are set in BasketSingle() - BasketSingle.__init__(self, gid, pos, 'L5Basket') - self.celltype = 'L5_basket' - - self.__synapse_create() - self.__biophysics() - - # creates synapses - def __synapse_create(self): - # creates synapses onto this cell - self.soma_ampa = self.syn_ampa_create(self.soma(0.5)) - self.soma_nmda = self.syn_nmda_create(self.soma(0.5)) - self.soma_gabaa = self.syn_gabaa_create(self.soma(0.5)) - - # insert IClamps in all situations - def create_all_IClamp(self, p): - """ temporarily an external function taking the p dict - """ - # list of sections for this celltype - sect_list_IClamp = [ - 'soma', - ] - - # some parameters - t_delay = p['Itonic_t0_L5Basket'] - - # T = -1 means use nrn.tstop - if p['Itonic_T_L5Basket'] == -1: - t_dur = nrn.tstop - t_delay - - else: - t_dur = p['Itonic_T_L5Basket'] - t_delay - - # t_dur must be nonnegative, I imagine - if t_dur < 0.: - t_dur = 0. - - # properties of the IClamp - props_IClamp = { - 'loc': 0.5, - 'delay': t_delay, - 'dur': t_dur, - 'amp': p['Itonic_A_L5Basket'] - } - - # iterate through list of sect_list_IClamp to create a persistent IClamp object - # the insert_IClamp procedure is in Cell() and checks on names - # so names must be actual section names, or else it will fail silently - self.list_IClamp = [self.insert_IClamp(sect_name, props_IClamp) for sect_name in sect_list_IClamp] - - # defines biophysics - def __biophysics(self): - self.soma.insert('hh2') - - # connections FROM other cells TO this cell - # there are no connections from the L2Basket cells. congrats! 
- def parconnect(self, gid, gid_dict, pos_dict, p): - # FROM other L5Basket cells TO this cell - for gid_src, pos in zip(gid_dict['L5_basket'], pos_dict['L5_basket']): - if gid_src != gid: - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L5Basket_L5Basket'], - 'A_delay': 1., - 'lamtha': 20., - 'threshold': p['threshold'], - 'type_src' : 'L5_basket' - } - - self.ncfrom_L5Basket.append(self.parconnect_from_src(gid_src, nc_dict, self.soma_gabaa)) - - # FROM other L5Pyr cells TO this cell - for gid_src, pos in zip(gid_dict['L5_pyramidal'], pos_dict['L5_pyramidal']): - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L5Pyr_L5Basket'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L5_pyramidal' - } - - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.soma_ampa)) - - # FROM other L2Pyr cells TO this cell - for gid_src, pos in zip(gid_dict['L2_pyramidal'], pos_dict['L2_pyramidal']): - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Pyr_L5Basket'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L2_pyramidal' - } - - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.soma_ampa)) - - # parallel receive function parreceive() - def parreceive(self, gid, gid_dict, pos_dict, p_ext): - for gid_src, p_src, pos in zip(gid_dict['extinput'], p_ext, pos_dict['extinput']): - # Check if AMPA params are define in p_src - if 'L5Basket_ampa' in p_src.keys(): - nc_dict_ampa = { - 'pos_src': pos, - 'A_weight': p_src['L5Basket_ampa'][0], - 'A_delay': p_src['L5Basket_ampa'][1], # right index?? - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src' : 'ext' - } - - # AMPA synapse - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.soma_ampa)) - - # Check if nmda params are define in p_src - if 'L5Basket_nmda' in p_src.keys(): - nc_dict_nmda = { - 'pos_src': pos, - 'A_weight': p_src['L5Basket_nmda'][0], - 'A_delay': p_src['L5Basket_nmda'][1], # right index?? - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src' : 'ext' - } - - # NMDA synapse - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.soma_nmda)) - - # one parreceive function to handle all types of external parreceives - # types must be defined explicitly here - def parreceive_ext(self, type, gid, gid_dict, pos_dict, p_ext): - if type.startswith(('evprox', 'evdist')): # shouldn't this just check for evprox? 
- if self.celltype in p_ext.keys(): - gid_ev = gid + gid_dict[type][0] - - nc_dict_ampa = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 is ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 is delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - nc_dict_nmda = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][1], # index 1 is nmda weight - 'A_delay': p_ext[self.celltype][2], # index 2 is delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.soma_ampa)) - - # NEW: note that default/original is 0 nmda weight for the soma (both prox and distal evoked) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.soma_nmda)) - - elif type == 'extgauss': - # gid is this cell's gid - # gid_dict is the whole dictionary, including the gids of the extgauss - # pos_dict is also the pos of the extgauss (net origin) - # p_ext_gauss are the params (strength, etc.) - if 'L5_basket' in p_ext.keys(): - gid_extgauss = gid + gid_dict['extgauss'][0] - - nc_dict = { - 'pos_src': pos_dict['extgauss'][gid], - 'A_weight': p_ext['L5_basket'][0], # index 0 is ampa weight - 'A_delay': p_ext['L5_basket'][2], # index 2 is delay - 'lamtha': p_ext['lamtha'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss, nc_dict, self.soma_ampa)) - - elif type == 'extpois': - if self.celltype in p_ext.keys(): - gid_extpois = gid + gid_dict['extpois'][0] - - nc_dict = { - 'pos_src': pos_dict['extpois'][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 is ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 is delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.soma_ampa)) - - if p_ext[self.celltype][1] > 0.0: - nc_dict['A_weight'] = p_ext[self.celltype][1] # index 1 for nmda weight - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.soma_nmda)) - - else: - print("Warning, type def not specified in L2Basket") diff --git a/L5_pyramidal.py b/L5_pyramidal.py deleted file mode 100644 index 2f338e61f..000000000 --- a/L5_pyramidal.py +++ /dev/null @@ -1,720 +0,0 @@ -# L5_pyramidal.py - establish class def for layer 5 pyramidal cells -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed it.izip dep) -# last rev: (SL: toward python3, moved cells) - -import sys -import numpy as np - -from neuron import h -from cell import Pyr -import paramrw -import params_default as p_default - -# Units for e: mV -# Units for gbar: S/cm^2 unless otherwise noted -# units for taur: ms - -class L5Pyr(Pyr): - - def basic_shape (self): - # THESE AND LENGHTHS MUST CHANGE TOGETHER!!! 
- pt3dclear=h.pt3dclear; pt3dadd=h.pt3dadd; dend = self.list_dend - pt3dclear(sec=self.soma); pt3dadd(0, 0, 0, 1, sec=self.soma); pt3dadd(0, 23, 0, 1, sec=self.soma) - pt3dclear(sec=dend[0]); pt3dadd(0, 23, 0, 1,sec=dend[0]); pt3dadd(0, 83, 0, 1,sec=dend[0]) - pt3dclear(sec=dend[1]); pt3dadd(0, 83, 0, 1,sec=dend[1]); pt3dadd(-150, 83, 0, 1,sec=dend[1]) - pt3dclear(sec=dend[2]); pt3dadd(0, 83, 0, 1,sec=dend[2]); pt3dadd(0, 483, 0, 1,sec=dend[2]) - pt3dclear(sec=dend[3]); pt3dadd(0, 483, 0, 1,sec=dend[3]); pt3dadd(0, 883, 0, 1,sec=dend[3]) - pt3dclear(sec=dend[4]); pt3dadd(0, 883, 0, 1,sec=dend[4]); pt3dadd(0, 1133, 0, 1,sec=dend[4]) - pt3dclear(sec=dend[5]); pt3dadd(0, 0, 0, 1,sec=dend[5]); pt3dadd(0, -50, 0, 1,sec=dend[5]) - pt3dclear(sec=dend[6]); pt3dadd(0, -50, 0, 1,sec=dend[6]); pt3dadd(-106, -156, 0, 1,sec=dend[6]) - pt3dclear(sec=dend[7]); pt3dadd(0, -50, 0, 1,sec=dend[7]); pt3dadd(106, -156, 0, 1,sec=dend[7]) - - def geom (self, p_dend): - soma = self.soma; dend = self.list_dend; - # soma.L = 13 # BUSH 1999 spike amp smaller - soma.L=39 # Bush 1993 - dend[0].L = 102 - dend[1].L = 255 - dend[2].L = 680 # default 400 - dend[3].L = 680 # default 400 - dend[4].L = 425 - dend[5].L = 85 - dend[6].L = 255 # default 150 - dend[7].L = 255 # default 150 - # soma.diam = 18.95 # Bush 1999 - soma.diam = 28.9 # Bush 1993 - dend[0].diam = 10.2 - dend[1].diam = 5.1 - dend[2].diam = 7.48 # default 4.4 - dend[3].diam = 4.93 # default 2.9 - dend[4].diam = 3.4 - dend[5].diam = 6.8 - dend[6].diam = 8.5 - dend[7].diam = 8.5 - self.set_dend_props(p_dend) # resets length,diam,etc. based on param specification - - def __init__(self, gid = -1, pos = -1, p={}): - # Get default L5Pyr params and update them with corresponding params in p - p_all_default = p_default.get_L5Pyr_params_default() - self.p_all = paramrw.compare_dictionaries(p_all_default, p) - - # Get somatic, dendirtic, and synapse properties - p_soma = self.__get_soma_props(pos) - p_dend = self.__get_dend_props() - p_syn = self.__get_syn_props() - - Pyr.__init__(self, gid, p_soma) - self.celltype = 'L5_pyramidal' - - # Geometry - # dend Cm and dend Ra set using soma Cm and soma Ra - self.create_dends(p_dend) # just creates the sections - self.topol() # sets the connectivity between sections - self.geom(p_dend) # sets geom properties; adjusted after translation from hoc (2009 model) - - # biophysics - self.__biophys_soma() - self.__biophys_dends() - - # Dictionary of length scales to calculate dipole without 3d shape. Comes from Pyr(). - # dipole_insert() comes from Cell() - self.yscale = self.get_sectnames() - self.dipole_insert(self.yscale) - - # create synapses - self.__synapse_create(p_syn) - - # insert iclamp - self.list_IClamp = [] - - # run record current soma, defined in Cell() - self.record_current_soma() - - # insert IClamps in all situations - # temporarily an external function taking the p dict - def create_all_IClamp(self, p): - # list of sections for this celltype - sect_list_IClamp = ['soma',] - - # some parameters - t_delay = p['Itonic_t0_L5Pyr_soma'] - - # T = -1 means use h.tstop - if p['Itonic_T_L5Pyr_soma'] == -1: - # t_delay = 50. - t_dur = h.tstop - t_delay - else: - t_dur = p['Itonic_T_L5Pyr_soma'] - t_delay - - # t_dur must be nonnegative, I imagine - if t_dur < 0.: - t_dur = 0. 
- - # properties of the IClamp - props_IClamp = { - 'loc': 0.5, - 'delay': t_delay, - 'dur': t_dur, - 'amp': p['Itonic_A_L5Pyr_soma'] - } - - # iterate through list of sect_list_IClamp to create a persistent IClamp object - # the insert_IClamp procedure is in Cell() and checks on names - # so names must be actual section names, or else it will fail silently - self.list_IClamp = [self.insert_IClamp(sect_name, props_IClamp) for sect_name in sect_list_IClamp] - - # Sets somatic properties. Returns dictionary. - def __get_soma_props(self, pos): - return { - 'pos': pos, - 'L': self.p_all['L5Pyr_soma_L'], - 'diam': self.p_all['L5Pyr_soma_diam'], - 'cm': self.p_all['L5Pyr_soma_cm'], - 'Ra': self.p_all['L5Pyr_soma_Ra'], - 'name': 'L5Pyr', - } - - # Returns dictionary of dendritic properties and list of dendrite names - def __get_dend_props(self): - # def __set_dend_props(self): - # Hard coded dend properties - # dend_props = { - return { - 'apical_trunk': { - 'L': self.p_all['L5Pyr_apicaltrunk_L'] , - 'diam': self.p_all['L5Pyr_apicaltrunk_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'apical_1': { - 'L': self.p_all['L5Pyr_apical1_L'], - 'diam': self.p_all['L5Pyr_apical1_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'apical_2': { - 'L': self.p_all['L5Pyr_apical2_L'], - 'diam': self.p_all['L5Pyr_apical2_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'apical_tuft': { - 'L': self.p_all['L5Pyr_apicaltuft_L'], - 'diam': self.p_all['L5Pyr_apicaltuft_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'apical_oblique': { - 'L': self.p_all['L5Pyr_apicaloblique_L'], - 'diam': self.p_all['L5Pyr_apicaloblique_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'basal_1': { - 'L': self.p_all['L5Pyr_basal1_L'], - 'diam': self.p_all['L5Pyr_basal1_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'basal_2': { - 'L': self.p_all['L5Pyr_basal2_L'], - 'diam': self.p_all['L5Pyr_basal2_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - 'basal_3': { - 'L': self.p_all['L5Pyr_basal3_L'], - 'diam': self.p_all['L5Pyr_basal3_diam'], - 'cm': self.p_all['L5Pyr_dend_cm'], - 'Ra': self.p_all['L5Pyr_dend_Ra'], - }, - } - - # These MUST match order the above keys in exact order! 
- # dend_names = [ - # 'apical_trunk', 'apical_1', 'apical_2', - # 'apical_tuft', 'apical_oblique', 'basal_1', - # 'basal_2', 'basal_3' - # ] - - # return dend_props, dend_names - - # self.dend_L = [102, 680, 680, 425, 255, 85, 255, 255] - # self.dend_diam = [10.2, 7.48, 4.93, 3.4, 5.1, 6.8, 8.5, 8.5] - - # # check lengths for congruity - # if len(self.dend_L) == len(self.dend_diam): - # # Zip above lists together - # self.dend_props = zip(self.dend_names, self.dend_L, self.dend_diam) - # else: - # print "self.dend_L and self.dend_diam are not the same length" - # print "please fix in L5_pyramidal.py" - # sys.exit() - - def __get_syn_props(self): - return { - 'ampa': { - 'e': self.p_all['L5Pyr_ampa_e'], - 'tau1': self.p_all['L5Pyr_ampa_tau1'], - 'tau2': self.p_all['L5Pyr_ampa_tau2'], - }, - 'nmda': { - 'e': self.p_all['L5Pyr_nmda_e'], - 'tau1': self.p_all['L5Pyr_nmda_tau1'], - 'tau2': self.p_all['L5Pyr_nmda_tau2'], - }, - 'gabaa': { - 'e': self.p_all['L5Pyr_gabaa_e'], - 'tau1': self.p_all['L5Pyr_gabaa_tau1'], - 'tau2': self.p_all['L5Pyr_gabaa_tau2'], - }, - 'gabab': { - 'e': self.p_all['L5Pyr_gabab_e'], - 'tau1': self.p_all['L5Pyr_gabab_tau1'], - 'tau2': self.p_all['L5Pyr_gabab_tau2'], - } - } - - # connects sections of this cell together - def topol (self): - - """ original topol - connect dend(0), soma(1) // dend[0] is apical trunk - for i = 1, 2 connect dend[i](0), dend(1) // dend[1] is oblique, dend[2] is apic1 - for i = 3, 4 connect dend[i](0), dend[i-1](1) // dend[3],dend[4] are apic2,apic tuft - connect dend[5](0), soma(0) //was soma(1)this is correct! - for i = 6, 7 connect dend[i](0), dend[5](1) - """ - - # child.connect(parent, parent_end, {child_start=0}) - # Distal (apical) - self.dends['apical_trunk'].connect(self.soma, 1, 0) - self.dends['apical_1'].connect(self.dends['apical_trunk'], 1, 0) - self.dends['apical_2'].connect(self.dends['apical_1'], 1, 0) - self.dends['apical_tuft'].connect(self.dends['apical_2'], 1, 0) - - # apical_oblique comes off distal end of apical_trunk - self.dends['apical_oblique'].connect(self.dends['apical_trunk'], 1, 0) - - # Proximal (basal) - self.dends['basal_1'].connect(self.soma, 0, 0) - self.dends['basal_2'].connect(self.dends['basal_1'], 1, 0) - self.dends['basal_3'].connect(self.dends['basal_1'], 1, 0) - - self.basic_shape() # translated from original hoc (2009 model) - - # # Distal - # self.list_dend[0].connect(self.soma, 1, 0) - # self.list_dend[1].connect(self.list_dend[0], 1, 0) - - # self.list_dend[2].connect(self.list_dend[1], 1, 0) - # self.list_dend[3].connect(self.list_dend[2], 1, 0) - - # # dend[4] comes off of dend[0](1) - # self.list_dend[4].connect(self.list_dend[0], 1, 0) - - # # Proximal - # self.list_dend[5].connect(self.soma, 0, 0) - # self.list_dend[6].connect(self.list_dend[5], 1, 0) - # self.list_dend[7].connect(self.list_dend[5], 1, 0) - - # adds biophysics to soma - def __biophys_soma(self): - # set soma biophysics specified in Pyr - # self.pyr_biophys_soma() - - # Insert 'hh2' mechanism - self.soma.insert('hh2') - self.soma.gkbar_hh2 = self.p_all['L5Pyr_soma_gkbar_hh2'] - self.soma.gnabar_hh2 = self.p_all['L5Pyr_soma_gnabar_hh2'] - self.soma.gl_hh2 = self.p_all['L5Pyr_soma_gl_hh2'] - self.soma.el_hh2 = self.p_all['L5Pyr_soma_el_hh2'] - - # insert 'ca' mechanism - # Units: pS/um^2 - self.soma.insert('ca') - self.soma.gbar_ca = self.p_all['L5Pyr_soma_gbar_ca'] - - # insert 'cad' mechanism - # units of tau are ms - self.soma.insert('cad') - self.soma.taur_cad = self.p_all['L5Pyr_soma_taur_cad'] - - # insert 'kca' mechanism 
- # units are S/cm^2? - self.soma.insert('kca') - self.soma.gbar_kca = self.p_all['L5Pyr_soma_gbar_kca'] - - # Insert 'km' mechanism - # Units: pS/um^2 - self.soma.insert('km') - self.soma.gbar_km = self.p_all['L5Pyr_soma_gbar_km'] - - # insert 'cat' mechanism - self.soma.insert('cat') - self.soma.gbar_cat = self.p_all['L5Pyr_soma_gbar_cat'] - - # insert 'ar' mechanism - self.soma.insert('ar') - self.soma.gbar_ar = self.p_all['L5Pyr_soma_gbar_ar'] - - def __biophys_dends(self): - # set dend biophysics specified in Pyr() - # self.pyr_biophys_dends() - - # set dend biophysics not specified in Pyr() - for key in self.dends: - # Insert 'hh2' mechanism - self.dends[key].insert('hh2') - self.dends[key].gkbar_hh2 = self.p_all['L5Pyr_dend_gkbar_hh2'] - self.dends[key].gl_hh2 = self.p_all['L5Pyr_dend_gl_hh2'] - self.dends[key].gnabar_hh2 = self.p_all['L5Pyr_dend_gnabar_hh2'] - self.dends[key].el_hh2 = self.p_all['L5Pyr_dend_el_hh2'] - - # Insert 'ca' mechanims - # Units: pS/um^2 - self.dends[key].insert('ca') - self.dends[key].gbar_ca = self.p_all['L5Pyr_dend_gbar_ca'] - - # Insert 'cad' mechanism - self.dends[key].insert('cad') - self.dends[key].taur_cad = self.p_all['L5Pyr_dend_taur_cad'] - - # Insert 'kca' mechanism - self.dends[key].insert('kca') - self.dends[key].gbar_kca = self.p_all['L5Pyr_dend_gbar_kca'] - - # Insert 'km' mechansim - # Units: pS/um^2 - self.dends[key].insert('km') - self.dends[key].gbar_km = self.p_all['L5Pyr_dend_gbar_km'] - - # insert 'cat' mechanism - self.dends[key].insert('cat') - self.dends[key].gbar_cat = self.p_all['L5Pyr_dend_gbar_cat'] - - # insert 'ar' mechanism - self.dends[key].insert('ar') - - # set gbar_ar - # Value depends on distance from the soma. Soma is set as - # origin by passing self.soma as a sec argument to h.distance() - # Then iterate over segment nodes of dendritic sections - # and set gbar_ar depending on h.distance(seg.x), which returns - # distance from the soma to this point on the CURRENTLY ACCESSED - # SECTION!!! 
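The comment above describes the one distance-dependent conductance in this cell: `gbar_ar` is set per segment from its path distance to the soma, using the rule `1e-6 * exp(3e-3 * d)` applied in the loop immediately below. As a minimal sketch of just that arithmetic (no NEURON dependency; the helper name and sample distances are ours, the constants come from the deleted loop):

```python
import numpy as np

def gbar_ar_at(distance_um, g0=1e-6, k=3e-3):
    """Exponential distance rule used for gbar_ar in the loop below:
    the conductance grows e-fold roughly every 1/k ~= 333 um from the soma."""
    return g0 * np.exp(k * distance_um)

# sample the rule at a few hypothetical path distances (in microns)
for d in (0.0, 100.0, 500.0, 1000.0):
    print(f"{d:6.1f} um -> gbar_ar = {gbar_ar_at(d):.3e}")
```

In the deleted code the distance itself comes from `h.distance(seg.x)` after the soma has been set as the origin with `h.distance(sec=self.soma)`.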
- h.distance(sec=self.soma) - - for key in self.dends: - self.dends[key].push() - for seg in self.dends[key]: - seg.gbar_ar = 1e-6 * np.exp(3e-3 * h.distance(seg.x)) - - h.pop_section() - - def __synapse_create(self, p_syn): - # creates synapses onto this cell - # Somatic synapses - self.synapses = { - 'soma_gabaa': self.syn_create(self.soma(0.5), p_syn['gabaa']), - 'soma_gabab': self.syn_create(self.soma(0.5), p_syn['gabab']), - } - - # Dendritic synapses - self.apicaltuft_gabaa = self.syn_create(self.dends['apical_tuft'](0.5), p_syn['gabaa']) - #self.apicaltuft_gabaa = self.syn_create(self.dends['apical_tuft'](0.5), p_syn['gabab'])#RL version - - self.apicaltuft_ampa = self.syn_create(self.dends['apical_tuft'](0.5), p_syn['ampa']) - self.apicaltuft_nmda = self.syn_create(self.dends['apical_tuft'](0.5), p_syn['nmda']) - - self.apicaloblique_ampa = self.syn_create(self.dends['apical_oblique'](0.5), p_syn['ampa']) - self.apicaloblique_nmda = self.syn_create(self.dends['apical_oblique'](0.5), p_syn['nmda']) - - self.basal2_ampa = self.syn_create(self.dends['basal_2'](0.5), p_syn['ampa']) - self.basal2_nmda = self.syn_create(self.dends['basal_2'](0.5), p_syn['nmda']) - - self.basal3_ampa = self.syn_create(self.dends['basal_3'](0.5), p_syn['ampa']) - self.basal3_nmda = self.syn_create(self.dends['basal_3'](0.5), p_syn['nmda']) - - # parallel connection function FROM all cell types TO here - def parconnect(self, gid, gid_dict, pos_dict, p): - # init dict of dicts - # nc_dict for ampa and nmda may be the same for this cell type - nc_dict = { - 'ampa': None, - 'nmda': None, - } - - # connections FROM L5Pyr TO here - for gid_src, pos in zip(gid_dict['L5_pyramidal'], pos_dict['L5_pyramidal']): - # no autapses - if gid_src != gid: - nc_dict['ampa'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L5Pyr_L5Pyr_ampa'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L5_pyramidal' - } - - # ampa connections - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict['ampa'], self.apicaloblique_ampa)) - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict['ampa'], self.basal2_ampa)) - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict['ampa'], self.basal3_ampa)) - - nc_dict['nmda'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L5Pyr_L5Pyr_nmda'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L5_pyramidal' - } - - # nmda connections - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict['nmda'], self.apicaloblique_nmda)) - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict['nmda'], self.basal2_nmda)) - self.ncfrom_L5Pyr.append(self.parconnect_from_src(gid_src, nc_dict['nmda'], self.basal3_nmda)) - - # connections FROM L5Basket TO here - for gid_src, pos in zip(gid_dict['L5_basket'], pos_dict['L5_basket']): - nc_dict['gabaa'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L5Basket_L5Pyr_gabaa'], - 'A_delay': 1., - 'lamtha': 70., - 'threshold': p['threshold'], - 'type_src' : 'L5_basket' - } - - nc_dict['gabab'] = { - 'pos_src': pos, - 'A_weight': p['gbar_L5Basket_L5Pyr_gabab'], - 'A_delay': 1., - 'lamtha': 70., - 'threshold': p['threshold'], - 'type_src' : 'L5_basket' - } - - # soma synapses are defined in Pyr() - self.ncfrom_L5Basket.append(self.parconnect_from_src(gid_src, nc_dict['gabaa'], self.synapses['soma_gabaa'])) - self.ncfrom_L5Basket.append(self.parconnect_from_src(gid_src, nc_dict['gabab'], self.synapses['soma_gabab'])) - - # connections FROM L2Pyr TO here - for gid_src, 
pos in zip(gid_dict['L2_pyramidal'], pos_dict['L2_pyramidal']): - # this delay is longer than most - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Pyr_L5Pyr'], - 'A_delay': 1., - 'lamtha': 3., - 'threshold': p['threshold'], - 'type_src' : 'L2_pyramidal' - } - - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.basal2_ampa)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.basal3_ampa)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.apicaltuft_ampa)) - self.ncfrom_L2Pyr.append(self.parconnect_from_src(gid_src, nc_dict, self.apicaloblique_ampa)) - - # connections FROM L2Basket TO here - for gid_src, pos in zip(gid_dict['L2_basket'], pos_dict['L2_basket']): - nc_dict = { - 'pos_src': pos, - 'A_weight': p['gbar_L2Basket_L5Pyr'], - 'A_delay': 1., - 'lamtha': 50., - 'threshold': p['threshold'], - 'type_src' : 'L2_basket' - } - - self.ncfrom_L2Basket.append(self.parconnect_from_src(gid_src, nc_dict, self.apicaltuft_gabaa)) - - # receive from external inputs - def parreceive(self, gid, gid_dict, pos_dict, p_ext): - for gid_src, p_src, pos in zip(gid_dict['extinput'], p_ext, pos_dict['extinput']): - # Check if AMPA params defined in p_src - if 'L5Pyr_ampa' in p_src.keys(): - nc_dict_ampa = { - 'pos_src': pos, - 'A_weight': p_src['L5Pyr_ampa'][0], - 'A_delay': p_src['L5Pyr_ampa'][1], - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src' : 'ext' - } - - # Proximal feed AMPA synapses - if p_src['loc'] is 'proximal': - # basal2_ampa, basal3_ampa, apicaloblique_ampa - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.basal2_ampa)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.basal3_ampa)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.apicaloblique_ampa)) - # Distal feed AMPA synsapes - elif p_src['loc'] is 'distal': - # apical tuft - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_ampa, self.apicaltuft_ampa)) - - # Check if NMDA params defined in p_src - if 'L5Pyr_nmda' in p_src.keys(): - nc_dict_nmda = { - 'pos_src': pos, - 'A_weight': p_src['L5Pyr_nmda'][0], - 'A_delay': p_src['L5Pyr_nmda'][1], - 'lamtha': p_src['lamtha'], - 'threshold': p_src['threshold'], - 'type_src' : 'ext' - } - - # Proximal feed NMDA synapses - if p_src['loc'] is 'proximal': - # basal2_nmda, basal3_nmda, apicaloblique_nmda - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.basal2_nmda)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.basal3_nmda)) - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.apicaloblique_nmda)) - # Distal feed NMDA synsapes - elif p_src['loc'] is 'distal': - # apical tuft - self.ncfrom_extinput.append(self.parconnect_from_src(gid_src, nc_dict_nmda, self.apicaltuft_nmda)) - - # one parreceive function to handle all types of external parreceives - # types must be defined explicitly here - def parreceive_ext(self, type, gid, gid_dict, pos_dict, p_ext): - if type.startswith(('evprox', 'evdist')): - if self.celltype in p_ext.keys(): - gid_ev = gid + gid_dict[type][0] - - nc_dict_ampa = { - 'pos_src': pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 for ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 for delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - nc_dict_nmda = { - 'pos_src': 
pos_dict[type][gid], - 'A_weight': p_ext[self.celltype][1], # index 1 for nmda weight - 'A_delay': p_ext[self.celltype][2], # index 2 for delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - #print('L5pyr:',type,'w:',nc_dict['A_weight']) - - if p_ext['loc'] is 'proximal': - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.basal2_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.basal3_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.apicaloblique_ampa)) - - # NEW: note that default/original is 0 nmda weight for these proximal dends - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.basal2_nmda)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.basal3_nmda)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.apicaloblique_nmda)) - - elif p_ext['loc'] is 'distal': - # apical tuft - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_ampa, self.apicaltuft_ampa)) - self.ncfrom_ev.append(self.parconnect_from_src(gid_ev, nc_dict_nmda, self.apicaltuft_nmda)) - - elif type == 'extgauss': - # gid is this cell's gid - # gid_dict is the whole dictionary, including the gids of the extgauss - # pos_dict is also the pos of the extgauss (net origin) - # p_ext_gauss are the params (strength, etc.) - # doesn't matter if this doesn't do anything - - # gid shift is based on L2_pyramidal cells NOT L5 - # I recognize this is ugly (hack) - # gid_shift = gid_dict['extgauss'][0] - gid_dict['L2_pyramidal'][0] - if 'L5_pyramidal' in p_ext.keys(): - gid_extgauss = gid + gid_dict['extgauss'][0] - - nc_dict = { - 'pos_src': pos_dict['extgauss'][gid], - 'A_weight': p_ext['L5_pyramidal'][0], # index 0 for ampa weight - 'A_delay': p_ext['L5_pyramidal'][2], # index 2 for delay - 'lamtha': p_ext['lamtha'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss, nc_dict, self.basal2_ampa)) - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss, nc_dict, self.basal3_ampa)) - self.ncfrom_extgauss.append(self.parconnect_from_src(gid_extgauss, nc_dict, self.apicaloblique_ampa)) - - elif type == 'extpois': - if self.celltype in p_ext.keys(): - gid_extpois = gid + gid_dict['extpois'][0] - - nc_dict = { - 'pos_src': pos_dict['extpois'][gid], - 'A_weight': p_ext[self.celltype][0], # index 0 for ampa weight - 'A_delay': p_ext[self.celltype][2], # index 2 for delay - 'lamtha': p_ext['lamtha_space'], - 'threshold': p_ext['threshold'], - 'type_src' : type - } - - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.basal2_ampa)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.basal3_ampa)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.apicaloblique_ampa)) - - if p_ext[self.celltype][1] > 0.0: - nc_dict['A_weight'] = p_ext[self.celltype][1] # index 1 for nmda weight - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.basal2_nmda)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.basal3_nmda)) - self.ncfrom_extpois.append(self.parconnect_from_src(gid_extpois, nc_dict, self.apicaloblique_nmda)) - - # Define 3D shape and position of cell. By default neuron uses xy plane for - # height and xz plane for depth. 
This is opposite for model as a whole, but - # convention is followed in this function for ease use of gui. - def __set_3Dshape(self): - # set 3D shape of soma by calling shape_soma from class Cell - # print "WARNING: You are setting 3d shape geom. You better be doing" - # print "gui analysis and not numerical analysis!!" - self.shape_soma() - - # soma proximal coords - x_prox = 0 - y_prox = 0 - - # soma distal coords - x_distal = 0 - y_distal = self.soma.L - - # dend 0-3 are major axis, dend 4 is branch - # deal with distal first along major cable axis - # the way this is assigning variables is ugly/lazy right now - for i in range(0, 4): - h.pt3dclear(sec=self.list_dend[i]) - - # x_distal and y_distal are the starting points for each segment - # these are updated at the end of the loop - sec=self.list_dend[i] - h.pt3dadd(0, y_distal, 0, sec.diam, sec=sec) - - # update x_distal and y_distal after setting them - # x_distal += dend_dx[i] - y_distal += sec.L - - # add next point - h.pt3dadd(0, y_distal, 0, sec.diam, sec=sec) - - # now deal with dend 4 - # dend 4 will ALWAYS be positioned at the end of dend[0] - h.pt3dclear(sec=self.list_dend[4]) - - # activate this section with 'sec=self.list_dend[i]' notation - x_start = h.x3d(1, sec=self.list_dend[0]) - y_start = h.y3d(1, sec=self.list_dend[0]) - - sec=self.list_dend[4] - h.pt3dadd(x_start, y_start, 0, sec.diam, sec=sec) - # self.dend_L[4] is subtracted because lengths always positive, - # and this goes to negative x - h.pt3dadd(x_start-sec.L, y_start, 0, sec.diam, sec=sec) - - # now deal with proximal dends - for i in range(5, 8): - h.pt3dclear(sec=self.list_dend[i]) - - # deal with dend 5, ugly. sorry. - sec=self.list_dend[5] - h.pt3dadd(x_prox, y_prox, 0, sec.diam, sec=sec) - y_prox += -sec.L - - h.pt3dadd(x_prox, y_prox, 0, sec.diam,sec=sec) - - # x_prox, y_prox are now the starting points for BOTH of last 2 sections - # dend 6 - # Calculate x-coordinate for end of dend - sec=self.list_dend[6] - dend6_x = -sec.L * np.sqrt(2) / 2. - h.pt3dadd(x_prox, y_prox, 0, sec.diam, sec=sec) - h.pt3dadd(dend6_x, y_prox-sec.L * np.sqrt(2) / 2., - 0, sec.diam, sec=sec) - - # dend 7 - # Calculate x-coordinate for end of dend - sec=self.list_dend[7] - dend7_x = sec.L * np.sqrt(2) / 2. - h.pt3dadd(x_prox, y_prox, 0, sec.diam, sec=sec) - h.pt3dadd(dend7_x, y_prox-sec.L * np.sqrt(2) / 2., - 0, sec.diam, sec=sec) - - # set 3D position - # z grid position used as y coordinate in h.pt3dchange() to satisfy - # gui convention that y is height and z is depth. 
In h.pt3dchange() - # x and z components are scaled by 100 for visualization clarity - self.soma.push() - for i in range(0, int(h.n3d())): - h.pt3dchange(i, self.pos[0]*100 + h.x3d(i), -self.pos[2] + h.y3d(i), - self.pos[1] * 100 + h.z3d(i), h.diam3d(i)) - - h.pop_section() diff --git a/Makefile b/Makefile deleted file mode 100644 index 5023c771c..000000000 --- a/Makefile +++ /dev/null @@ -1,16 +0,0 @@ -# Makefile for model - compiles mod files for use by NEURON -# first rev: (SL: created) - -# macros -UNAME := $(shell uname) - -vpath %.mod mod/ - -# make rules -x86_64/special : mod - nrnivmodl $< - -# clean -.PHONY: clean -clean : - rm -f x86_64/* diff --git a/PT_example.py b/PT_example.py deleted file mode 100644 index 58f8b63e6..000000000 --- a/PT_example.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python -# PT_example.py - Plot Template example -# -# v 1.9.4 -# rev 2016-02-02 (SL: created) -# last major: () - -import numpy as np -import matplotlib as mpl -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -import axes_create as ac - -# spec plus dipole -class FigExample(ac.FigBase): - def __init__(self): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(8, 6)) - - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # the right margin is a hack and NOT guaranteed! - # it's making space for the stupid colorbar that creates a new grid to replace gs1 - # when called, and it doesn't update the params of gs1 - self.gs = { - 'dpl': gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.85, top=0.95, left=0.1, right=0.82), - 'spec': gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.30, top=0.80, left=0.1, right=1.), - 'pgram': gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.05, top=0.25, left=0.1, right=0.82), - } - - self.ax = { - 'dipole': self.f.add_subplot(self.gs['dpl'][:, :]), - 'spec': self.f.add_subplot(self.gs['spec'][:, :]), - 'pgram': self.f.add_subplot(self.gs['pgram'][:, :]), - } - -if __name__ == '__main__': - fig = FigExample() - fig.ax['dipole'].plot(np.random.rand(1000)) - fig.savepng('testing.png') - fig.close() diff --git a/README.md b/README.md index 84368ff5c..197dafc81 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,14 @@ potentials (ERPs) and low frequency rhythms (alpha/beta/gamma). Please follow the links on our [installation page](installer) to find instructions for your operating system. +## Quickstart + +Just do: + + $ python hnn.py + +to start the HNN graphical user interface + ## Command-line usage HNN is not designed to be invoked from the command line, but we have started diff --git a/__init__.py b/__init__.py deleted file mode 100644 index 3189e6389..000000000 --- a/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '0.1.3' diff --git a/ac_manu_gamma.py b/ac_manu_gamma.py deleted file mode 100644 index aea803804..000000000 --- a/ac_manu_gamma.py +++ /dev/null @@ -1,1166 +0,0 @@ -# ac_manu_gamma.py - axes for gamma manuscript paper figs -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed izip dep) -# last major: (MS: commented out mpl.use('agg') to prevent conflict ...) - -import matplotlib as mpl -import axes_create as ac -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -import numpy as np - -class FigSimpleSpec(ac.FigBase): - def __init__(self): - self.f = plt.figure(figsize=(8, 6)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # the right margin is a hack and NOT guaranteed! 
- # it's making space for the stupid colorbar that creates a new grid to replace gs1 - # when called, and it doesn't update the params of gs1 - self.gspec = { - 'dpl': gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.85, top=0.95, left=0.1, right=0.82), - 'spec': gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.10, top=0.80, left=0.1, right=1.), - } - - self.ax = {} - self.ax['dipole'] = self.f.add_subplot(self.gspec['dpl'][:, :]) - self.ax['spec'] = self.f.add_subplot(self.gspec['spec'][:, :]) - -class FigLaminarComparison(ac.FigBase): - def __init__(self, runtype='debug'): - # ac.FigBase.__init__() - self.f = plt.figure(figsize=(9, 7)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(8, 50), - 'middle': gridspec.GridSpec(8, 50), - 'right': gridspec.GridSpec(8, 50), - 'bottom_left': gridspec.GridSpec(1, 50), - 'bottom_middle': gridspec.GridSpec(1, 50), - 'bottom_right': gridspec.GridSpec(1, 50), - } - - # reposition the gridspecs - l = np.arange(0.12, 0.80, 0.27) - r = l + 0.24 - - # update the gridspecs - # um, left is going right ... - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.25, top=0.94, left=l[2], right=r[2]) - self.gspec['middle'].update(wspace=0, hspace=0.30, bottom=0.25, top=0.94, left=l[0], right=r[0]) - self.gspec['right'].update(wspace=0, hspace=0.30, bottom=0.25, top=0.94, left=l[1], right=r[1]) - - # bottom are going to mirror the top, despite the names - self.gspec['bottom_left'].update(wspace=0, hspace=0.0, bottom=0.1, top=0.22, left=l[2], right=r[2]) - self.gspec['bottom_middle'].update(wspace=0, hspace=0.0, bottom=0.1, top=0.22, left=l[0], right=r[0]) - self.gspec['bottom_right'].update(wspace=0, hspace=0.0, bottom=0.1, top=0.22, left=l[1], right=r[1]) - - # create axes and handles - self.ax = { - 'dpl_L': self.f.add_subplot(self.gspec['left'][3:5, :40]), - 'dpl_M': self.f.add_subplot(self.gspec['middle'][3:5, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['right'][3:5, :40]), - - 'spk_M': self.f.add_subplot(self.gspec['middle'][:2, :40]), - 'spk_R': self.f.add_subplot(self.gspec['right'][:2, :40]), - - 'current_M': self.f.add_subplot(self.gspec['middle'][2:3, :40]), - 'current_R': self.f.add_subplot(self.gspec['right'][2:3, :40]), - - 'spec_L': None, - 'spec_M': None, - 'spec_R': self.f.add_subplot(self.gspec['right'][5:7, :]), - - 'pgram_L': self.f.add_subplot(self.gspec['bottom_left'][:, :40]), - 'pgram_M': self.f.add_subplot(self.gspec['bottom_middle'][:, :40]), - 'pgram_R': self.f.add_subplot(self.gspec['bottom_right'][:, :40]), - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][5:7, :]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][5:7, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][5:7, :]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][5:7, :]) - - # remove xtick labels - list_ax_noxtick = [ax_handle for ax_handle in self.ax.keys() if ax_handle.startswith(('dpl', 'current', 'spk'))] - - # function defined in FigBase() - self.remove_tick_labels(list_ax_noxtick, 'x') - - # remove ytick labels - self.ax['spk_M'].set_yticklabels('') - self.ax['spk_R'].set_yticklabels('') - list_ax_noytick = [] - - # write list of no y tick axes - # if runtype == 'pub': - # list_ax_noytick.extend([ax_h for ax_h in self.ax.keys() if ax_h.startswith('spk')]) - # list_ax_noytick.extend(['spec_R', 'spec_L']) - - # function defined in FigBase() - 
self.remove_tick_labels(list_ax_noytick, 'y') - self.create_ax_bounds_dict() - self.create_y_centers_dict() - self.__add_labels_subfig(l) - self.__change_formatting() - - def __change_formatting(self): - list_axes = ['pgram_L', 'pgram_M', 'pgram_R'] - self.set_notation_scientific(list_axes, 2) - - # add text labels - def __add_labels_subfig(self, l): - # top labels - self.f.text(self.ax_bounds['spk_M'][0], self.ax_bounds['spk_M'][-1] + 0.005, 'A.') - self.f.text(self.ax_bounds['spk_R'][0], self.ax_bounds['spk_R'][-1] + 0.005, 'B.') - self.f.text(self.ax_bounds['dpl_L'][0], self.ax_bounds['dpl_L'][-1] + 0.005, 'C.') - - # left labels - labels_left = { - 'va': 'center', - 'ma': 'center', - 'rotation': 90, - } - self.f.text(0.025, self.y_centers['spec_M'], 'Frequency (Hz)', **labels_left) - self.f.text(0.025, self.y_centers['dpl_M'], 'Current Dipole \n (nAm)', **labels_left) - self.f.text(0.025, self.y_centers['pgram_M'], 'Welch Spectral \n Power ((nAm)$^2$)', **labels_left) - self.f.text(0.025, self.y_centers['spk_M'], 'Cells', **labels_left) - self.f.text(0.025, self.y_centers['current_M'], 'Current \n ($\mu$A)', **labels_left) - - # bottom labels - self.f.text(self.ax_bounds['spec_M'][0], self.ax_bounds['spec_M'][1] - 0.05, 'Time (ms)', ha='left') - self.f.text(self.ax_bounds['pgram_M'][0], self.ax_bounds['pgram_M'][1] - 0.05, 'Frequency (Hz)', ha='left') - - # right labels - self.f.text(0.95, self.y_centers['spec_L'], 'Spectral Power \n ((nAm)$^2$)', rotation=270, ma='center', va='center') - - def set_axes_pingping(self): - self.ax['current_R'].set_ylim((-2000., 0.)) - -# strong ping and weak ping examples in Layer 5: Fig 2 -class FigL5PingExample(ac.FigBase): - def __init__(self, runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(7, 8)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(7, 50), - 'right': gridspec.GridSpec(7, 50), - 'left_welch': gridspec.GridSpec(1, 50), - 'right_welch': gridspec.GridSpec(1, 50), - } - - # repositioning the gspec - l = np.arange(0.125, 0.90, 0.45) - r = l + 0.33 - - # create the gridspec - if runtype.startswith('pub'): - hspace_set = 0.30 - - else: - hspace_set = 0.30 - - self.gspec['left'].update(wspace=0, hspace=hspace_set, bottom=0.29, top=0.94, left=l[0], right=r[0]) - self.gspec['right'].update(wspace=0, hspace=hspace_set, bottom=0.29, top=0.94, left=l[1], right=r[1]) - self.gspec['left_welch'].update(wspace=0, hspace=0, bottom=0.1, top=0.2, left=l[0], right=r[0]) - self.gspec['right_welch'].update(wspace=0, hspace=0, bottom=0.1, top=0.2, left=l[1], right=r[1]) - - # create axes and handles - # spec_L will be conditional on debug or production - self.ax = { - 'raster_L': self.f.add_subplot(self.gspec['left'][:2, :40]), - 'hist_L': self.f.add_subplot(self.gspec['left'][2:3, :40]), - 'current_L': self.f.add_subplot(self.gspec['left'][3:4, :40]), - 'dpl_L': self.f.add_subplot(self.gspec['left'][4:5, :40]), - 'spec_L': None, - 'pgram_L': self.f.add_subplot(self.gspec['left_welch'][:, :]), - - 'raster_R': self.f.add_subplot(self.gspec['right'][:2, :40]), - 'hist_R': self.f.add_subplot(self.gspec['right'][2:3, :40]), - 'current_R': self.f.add_subplot(self.gspec['right'][3:4, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['right'][4:5, :40]), - 'spec_R': self.f.add_subplot(self.gspec['right'][5:7, :]), - 'pgram_R': self.f.add_subplot(self.gspec['right_welch'][:, :]), - } - - # different spec_L depending on mode - if runtype in ('debug', 'pub2'): - 
self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][5:7, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][5:7, :40]) - - # print dir(self.ax['pgram_L'].get_position()) - # print self.ax['pgram_L'].get_position().get_points() - # print self.ax['pgram_L'].get_position().y0 - - # create twinx for the hist - for key in self.ax.keys(): - if key.startswith('hist_'): - # this creates an f.ax_twinx dict with the appropriate key names - self.create_axis_twinx(key) - - # pgram_axes - list_handles_pgram = [h for h in self.ax.keys() if h.startswith('pgram_')] - # self.fmt = self.set_notation_scientific(list_handles_pgram) - self.fmt = None - - # remove ytick labels for the rasters - for h in [ax for ax in self.ax if ax.startswith('raster_')]: - self.ax[h].set_yticklabels('') - - if runtype.startswith('pub'): - self.no_xaxis = ['raster', 'hist', 'dpl', 'current'] - self.no_yaxis = ['raster', 'spec', 'current'] - self.__remove_labels() - - self.__add_labels_subfig(l) - - self.list_sci = ['pgram_L', 'pgram_R'] - self.set_notation_scientific(self.list_sci) - - # add text labels - # ma is 'multialignment' for multiple lines - def __add_labels_subfig(self, l): - self.ax_pos = dict.fromkeys(self.ax) - - # create the ycenter dict - self.create_y_centers_dict() - - # first get all of the positions - for ax_h in self.ax.keys(): - self.ax_pos[ax_h] = self.return_axis_bounds(ax_h) - - # raster is on top - y_pad = 0.005 - label_y = self.ax_pos['raster_L'][-1] + y_pad - self.f.text(l[0], label_y, 'A.') - self.f.text(l[1], label_y, 'B.') - - # ylabel x pos - x_pos = 0.025 - label_props = { - 'fontsize': 7, - 'rotation': 90, - 'va': 'center', - 'ma': 'center', - } - - # cool, this uses a dict to fill in props. cool, cool, cool. - self.f.text(x_pos, self.y_centers['raster_L'], 'Cell no.', **label_props) - self.f.text(x_pos, self.y_centers['hist_L'], 'Spike\nhistogram\n(Left: I, Right: E)', **label_props) - self.f.text(x_pos, self.y_centers['spec_L'], 'Frequency (Hz)', **label_props) - self.f.text(x_pos, self.y_centers['dpl_L'], 'Current dipole \n (nAm)', **label_props) - self.f.text(x_pos, self.y_centers['pgram_L'], 'Welch Spectral \n Power ((nAm)$^2$)', **label_props) - self.f.text(x_pos, self.y_centers['current_L'], 'Total network \n GABA$_A$ current \n ($\mu$A)', **label_props) - - # xlabels - self.f.text(l[0], 0.05, 'Frequency (Hz)') - self.f.text(l[0], 0.25, 'Time (ms)') - - # find the spec_R coords and the associated center - # and create the text label there - coords_spec_R = self.return_axis_bounds('spec_R') - ycenter = coords_spec_R[1] + (coords_spec_R[-1] - coords_spec_R[1]) / 2. 
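These figure classes repeatedly place rotated, figure-level labels by reading an axes' bounds in figure coordinates and taking its vertical midpoint, exactly as the `ycenter` line above does for `spec_R`. A self-contained sketch of that idiom, assuming an arbitrary figure size, label text, and output filename:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))

# axes bounds in figure coordinates: [[x0, y0], [x1, y1]]
(x0, y0), (x1, y1) = ax.get_position().get_points()
y_center = y0 + (y1 - y0) / 2.0

# rotated figure-level y-label centered on the axes, as in __add_labels_subfig()
fig.text(0.02, y_center, 'Current dipole\n(nAm)',
         rotation=90, va='center', ma='center')

fig.savefig('label_demo.png', dpi=150)
plt.close(fig)
```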
- self.f.text(0.97, ycenter, 'Spectral Power \n ((nAm)$^2$)', rotation=270, va='center', ha='center') - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - for label_prefix in self.no_xaxis: - # if ((ax.startswith('dpl')) or (ax.startswith('current'))): - if ax.startswith(label_prefix): - self.ax[ax].set_xticklabels('') - - if ax.endswith('_R'): - for label_prefix in self.no_yaxis: - if ax.startswith(label_prefix): - self.ax[ax].set_yticklabels('') - -# 3 examples of different phases and the aggregate spectral power as a function of delay -class FigDistalPhase(ac.FigBase): - def __init__(self): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(15, 4)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left0': gridspec.GridSpec(4, 50), - 'left1': gridspec.GridSpec(4, 50), - 'middle': gridspec.GridSpec(4, 50), - 'right': gridspec.GridSpec(1, 1), - } - - # number of cols are the number of gridspecs - n_cols = len(self.gspec.keys()) - - # find the start values by making a linspace from L margin to R margin - # and then remove the R margin's element - # this is why you need n_cols+1 - l = np.linspace(0.1, 0.95, n_cols+1)[:-1] - - # ensure first element of the unique on the diff of l to find - # the width of each panel - # remove the width of some margin - w_margin = 0.05 - w = np.unique(np.diff(l))[0] - w_margin - - # to find the right position, just add w to the l - r = l + w - - # create the gridspecs - self.gspec['left0'].update(wspace=0, hspace=0.15, bottom=0.1, top=0.91, left=l[0], right=r[0]) - self.gspec['left1'].update(wspace=0, hspace=0.15, bottom=0.1, top=0.91, left=l[1], right=r[1]) - self.gspec['middle'].update(wspace=0, hspace=0.15, bottom=0.1, top=0.91, left=l[2], right=r[2]) - self.gspec['right'].update(wspace=0, hspace=0.15, bottom=0.1, top=0.91, left=l[3], right=r[3]) - - # create axes and handles - self.ax = { - 'spec_L': self.f.add_subplot(self.gspec['left0'][:2, :]), - 'spec_M': self.f.add_subplot(self.gspec['left1'][:2, :]), - 'spec_R': self.f.add_subplot(self.gspec['middle'][:2, :]), - - 'dpl_L': self.f.add_subplot(self.gspec['left0'][2:3, :40]), - 'dpl_M': self.f.add_subplot(self.gspec['left1'][2:3, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['middle'][2:3, :40]), - - 'hist_L': self.f.add_subplot(self.gspec['left0'][3:, :40]), - 'hist_M': self.f.add_subplot(self.gspec['left1'][3:, :40]), - 'hist_R': self.f.add_subplot(self.gspec['middle'][3:, :40]), - - 'aggregate': self.f.add_subplot(self.gspec['right'][:, :]), - } - - self.__create_hist_twinx() - self.__add_labels_subfig(l) - - def __create_hist_twinx(self): - # ax_handles_hist = [ax for ax in self.ax.keys() if ax.startswith('hist')] - for ax in self.ax.keys(): - if ax.startswith('hist'): - self.create_axis_twinx(ax) - - # add text labels - def __add_labels_subfig(self, l): - self.f.text(l[0], 0.95, 'A.') - self.f.text(l[1], 0.95, 'B.') - self.f.text(l[2], 0.95, 'C.') - self.f.text(l[3], 0.95, 'D.') - -class FigStDev(ac.FigBase): - def __init__(self, runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(8, 3)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(4, 50), - 'middle': gridspec.GridSpec(4, 50), - 'right': gridspec.GridSpec(4, 50), - 'farright': gridspec.GridSpec(4, 50), - } - - # reposition the gridspecs - l = np.arange(0.1, 0.9, 0.2) - # l = np.arange(0.05, 0.95, 0.3) - r = l + 
0.175 - - # create the gridspecs - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[0], right=r[0]) - self.gspec['middle'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[1], right=r[1]) - self.gspec['right'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[2], right=r[2]) - self.gspec['farright'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[3], right=r[3]) - - self.ax = { - 'hist_L': self.f.add_subplot(self.gspec['left'][:1, :40]), - 'hist_M': self.f.add_subplot(self.gspec['middle'][:1, :40]), - 'hist_R': self.f.add_subplot(self.gspec['right'][:1, :40]), - 'hist_FR': self.f.add_subplot(self.gspec['farright'][:1, :40]), - - 'dpl_L': self.f.add_subplot(self.gspec['left'][1:2, :40]), - 'dpl_M': self.f.add_subplot(self.gspec['middle'][1:2, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['right'][1:2, :40]), - 'dpl_FR': self.f.add_subplot(self.gspec['farright'][1:2, :40]), - - # these are set differently depending on runtype, below - 'spec_L': None, - 'spec_M': None, - 'spec_R': None, - 'spec_FR': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][2:, :]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][2:, :]) - self.ax['spec_FR'] = self.f.add_subplot(self.gspec['farright'][2:, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :40]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][2:, :40]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][2:, :40]) - self.ax['spec_FR'] = self.f.add_subplot(self.gspec['farright'][2:, :]) - - if runtype.startswith('pub'): - self.__remove_labels() - - self.__create_twinx() - - # methods come from the FigBase() - self.create_ax_bounds_dict() - self.create_y_centers_dict() - self.__add_labels_subfig(l) - - def __create_twinx(self): - for ax_handle in self.ax.keys(): - if ax_handle.startswith('hist'): - self.create_axis_twinx(ax_handle) - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - if ax.startswith(('dpl', 'hist')): - self.ax[ax].set_xticklabels('') - - if ax.endswith(('_M', '_R', '_FR')): - self.ax[ax].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_subfig(self, l): - self.f.text(self.ax_bounds['hist_L'][0], 0.95, 'A.') - self.f.text(self.ax_bounds['hist_M'][0], 0.95, 'B.') - self.f.text(self.ax_bounds['hist_R'][0], 0.95, 'C.') - self.f.text(self.ax_bounds['hist_FR'][0], 0.95, 'D.') - - labels_left = { - 'va': 'center', - 'ma': 'center', - 'rotation': 90, - } - self.f.text(0.025, self.y_centers['spec_L'], 'Frequency \n (Hz)', **labels_left) - self.f.text(0.025, self.y_centers['dpl_L'], 'Current dipole \n (nAm)', **labels_left) - self.f.text(0.025, self.y_centers['hist_L'], 'EPSPs', **labels_left) - - self.f.text(self.ax_bounds['spec_L'][0], 0.025, 'Time (ms)', ha='left') - self.f.text(self.ax_bounds['spec_FR'][2] + 0.05, self.y_centers['spec_FR'], 'Power spectral density \n ((nAm)$^2$/Hz)', rotation=270, va='center', ma='center') - -class FigPanel4(ac.FigBase): - def __init__(self, runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(8, 3)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(4, 50), - 
'middle': gridspec.GridSpec(4, 50), - 'right': gridspec.GridSpec(4, 50), - 'farright': gridspec.GridSpec(4, 50), - } - - # reposition the gridspecs - l = np.arange(0.1, 0.9, 0.2) - # l = np.arange(0.05, 0.95, 0.3) - r = l + 0.175 - - # create the gridspecs - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[0], right=r[0]) - self.gspec['middle'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[1], right=r[1]) - self.gspec['right'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[2], right=r[2]) - self.gspec['farright'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[3], right=r[3]) - - self.ax = { - 'hist_L': self.f.add_subplot(self.gspec['left'][:1, :40]), - 'hist_M': self.f.add_subplot(self.gspec['middle'][:1, :40]), - 'hist_R': self.f.add_subplot(self.gspec['right'][:1, :40]), - 'hist_FR': self.f.add_subplot(self.gspec['farright'][:1, :40]), - - 'dpl_L': self.f.add_subplot(self.gspec['left'][1:2, :40]), - 'dpl_M': self.f.add_subplot(self.gspec['middle'][1:2, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['right'][1:2, :40]), - 'dpl_FR': self.f.add_subplot(self.gspec['farright'][1:2, :40]), - - # these are set differently depending on runtype, below - 'spec_L': None, - 'spec_M': None, - 'spec_R': None, - 'spec_FR': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][2:, :]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][2:, :]) - self.ax['spec_FR'] = self.f.add_subplot(self.gspec['farright'][2:, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :40]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][2:, :40]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][2:, :40]) - self.ax['spec_FR'] = self.f.add_subplot(self.gspec['farright'][2:, :]) - - if runtype.startswith('pub'): - self.__remove_labels() - - self.__create_twinx() - - # methods come from the FigBase() - self.create_ax_bounds_dict() - self.create_y_centers_dict() - self.__add_labels_subfig(l) - - def __create_twinx(self): - for ax_handle in self.ax.keys(): - if ax_handle.startswith('hist'): - self.create_axis_twinx(ax_handle) - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - if ax.startswith(('dpl', 'hist')): - self.ax[ax].set_xticklabels('') - - if ax.endswith(('_M', '_R', '_FR')): - self.ax[ax].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_subfig(self, l): - self.f.text(self.ax_bounds['hist_L'][0], 0.95, 'A.') - self.f.text(self.ax_bounds['hist_M'][0], 0.95, 'B.') - self.f.text(self.ax_bounds['hist_R'][0], 0.95, 'C.') - self.f.text(self.ax_bounds['hist_FR'][0], 0.95, 'D.') - - labels_left = { - 'va': 'center', - 'ma': 'center', - 'rotation': 90, - } - self.f.text(0.025, self.y_centers['spec_L'], 'Frequency \n (Hz)', **labels_left) - self.f.text(0.025, self.y_centers['dpl_L'], 'Current dipole \n (nAm)', **labels_left) - self.f.text(0.025, self.y_centers['hist_L'], 'EPSPs', **labels_left) - - self.f.text(self.ax_bounds['spec_L'][0], 0.025, 'Time (ms)', ha='left') - self.f.text(self.ax_bounds['spec_FR'][2] + 0.05, self.y_centers['spec_FR'], 'Power spectral density \n ((nAm)$^2$/Hz)', rotation=270, va='center', ma='center') - -class Fig3PanelPlusAgg(ac.FigBase): - def __init__(self, 
runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(10, 3)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(6) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(5, 50), - 'middle': gridspec.GridSpec(5, 50), - 'right': gridspec.GridSpec(5, 50), - 'farright': gridspec.GridSpec(5, 50), - } - - # reposition the gridspecs - # l = np.arange(0.075, 0.8, 0.22) - l = np.array([0.075, 0.295, 0.515, 0.740]) - # l = np.arange(0.05, 0.95, 0.3) - r = l + 0.2 - # r[-1] += 0.025 - - # create the gridspecs - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[0], right=r[0]) - self.gspec['middle'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[1], right=r[1]) - self.gspec['right'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[2], right=r[2]) - self.gspec['farright'].update(wspace=0, hspace=0.30, bottom=0.15, top=0.94, left=l[3], right=r[3]) - - self.ax = { - 'hist_L': self.f.add_subplot(self.gspec['left'][:1, :40]), - 'hist_M': self.f.add_subplot(self.gspec['middle'][:1, :40]), - 'hist_R': self.f.add_subplot(self.gspec['right'][:1, :40]), - - 'dpl_L': self.f.add_subplot(self.gspec['left'][1:3, :40]), - 'dpl_M': self.f.add_subplot(self.gspec['middle'][1:3, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['right'][1:3, :40]), - - # aggregate welch - 'pgram': self.f.add_subplot(self.gspec['farright'][:, :]), - - # these are set differently depending on runtype, below - 'spec_L': None, - 'spec_M': None, - 'spec_R': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][3:, :]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][3:, :]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][3:, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][3:, :40]) - self.ax['spec_M'] = self.f.add_subplot(self.gspec['middle'][3:, :40]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][3:, :]) - - if runtype.startswith('pub'): - self.__remove_labels() - - # periodogram hold on - self.ax['pgram'].hold(True) - self.ax['pgram'].yaxis.tick_right() - - self.__create_twinx() - self.create_ax_bounds_dict() - self.create_y_centers_dict() - self.__add_labels_subfig(l) - - def __create_twinx(self): - for ax_handle in self.ax.keys(): - if ax_handle.startswith('hist'): - self.create_axis_twinx(ax_handle) - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - if ax.startswith(('dpl', 'hist')): - self.ax[ax].set_xticklabels('') - - if ax.endswith(('_M', '_R')): - self.ax[ax].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_subfig(self, l): - self.f.text(l[0], 0.95, 'A.') - self.f.text(l[1], 0.95, 'B.') - self.f.text(l[2], 0.95, 'C.') - self.f.text(l[3], 0.95, 'D.') - - ylabel_props = { - 'rotation': 90, - 'va': 'center', - 'ma': 'center', - } - - ylabel_right_props = { - 'rotation': 270, - 'va': 'center', - 'ma': 'center', - 'ha': 'center', - } - - xoffset = 0.0675 - - # y labels - self.f.text(l[0] - xoffset, self.y_centers['hist_L'], 'EPSP \n Count', **ylabel_props) - self.f.text(l[0] - xoffset, self.y_centers['dpl_L'], 'Current Dipole \n (nAm)', **ylabel_props) - self.f.text(l[0] - xoffset, self.y_centers['spec_L'], 'Frequency \n (Hz)', **ylabel_props) - - # self.ax['spec_L'].set_ylabel('Frequency (Hz)') - # 
self.ax['dpl_L'].set_ylabel('Current dipole (nAm)') - # self.ax['hist_L'].set_ylabel('EPSP count') - - self.f.text(l[0], 0.025, 'Time (ms)') - self.f.text(l[-1], 0.025, 'Frequency (Hz)') - self.f.text(0.975, self.y_centers['pgram'], 'Welch Spectral Power \n ((nAm)$^2$ x10$^{-7}$)', **ylabel_right_props) - self.f.text(self.ax_bounds['spec_R'][2] + 0.005, self.y_centers['spec_R'], 'Spectral Power \n ((nAm)$^2$)', fontsize=5, **ylabel_right_props) - # self.f.text(self.ax_bounds['hist_L'][-2]+0.05, self.y_centers['hist_L'], 'Distal EPSP Count', rotation=270, va='center', ma='center', ha='center') - # self.f.text(0.925, 0.40, 'Power spectral density ((nAm)$^2$/Hz)', rotation=270) - -class FigSubDistExample(ac.FigBase): - def __init__(self, runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(6, 5)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(4, 50), - 'right': gridspec.GridSpec(4, 50), - } - - # reposition the gridspecs - l = np.array([0.1, 0.52]) - # l = np.arange(0.1, 0.9, 0.45) - r = l + 0.39 - - # create the gridspecs - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.1, top=0.94, left=l[0], right=r[0]) - self.gspec['right'].update(wspace=0, hspace=0.30, bottom=0.1, top=0.94, left=l[1], right=r[1]) - - self.ax = { - 'hist_L': self.f.add_subplot(self.gspec['left'][:1, :40]), - 'hist_R': self.f.add_subplot(self.gspec['right'][:1, :40]), - - 'dpl_L': self.f.add_subplot(self.gspec['left'][1:2, :40]), - 'dpl_R': self.f.add_subplot(self.gspec['right'][1:2, :40]), - - # these are set differently depending on runtype, below - 'spec_L': None, - 'spec_R': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][2:, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :40]) - self.ax['spec_R'] = self.f.add_subplot(self.gspec['right'][2:, :]) - - if runtype.startswith('pub'): - self.__remove_labels() - - self.__create_twinx() - self.create_ax_bounds_dict() - self.create_y_centers_dict() - self.__add_labels_subfig(l) - - def __create_twinx(self): - for ax_handle in self.ax.keys(): - if ax_handle.startswith('hist'): - self.create_axis_twinx(ax_handle) - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - if ax.startswith(('dpl', 'hist')): - self.ax[ax].set_xticklabels('') - - if ax.endswith('_R'): - self.ax[ax].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - if ax.endswith('_R'): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_subfig(self, l): - # left labels - labels_left = { - 'va': 'center', - 'ma': 'center', - 'rotation': 90, - } - self.f.text(0.02, self.y_centers['spec_L'], 'Frequency (Hz)', **labels_left) - self.f.text(0.02, self.y_centers['dpl_L'], 'Current Dipole \n (nAm)', **labels_left) - self.f.text(0.02, self.y_centers['hist_L'], 'Proximal EPSP Count', **labels_left) - # self.f.text(self.ax_bounds['spec_M'][0], self.ax_bounds['spec_M'][1] - 0.05, 'Time (ms)', ha='left') - self.f.text(self.ax_bounds['hist_L'][-2]+0.05, self.y_centers['hist_L'], 'Distal EPSP Count', rotation=270, va='center', ma='center', ha='center') - self.f.text(0.95, self.y_centers['spec_R'], 'Power spectral density \n ((nAm)$^2$/Hz)', rotation=270, ha='center', ma='center', va='center') - - self.f.text(l[0], 0.95, 
'A.') - self.f.text(l[1], 0.95, 'B.') - - # self.ax['spec_L'].set_ylabel('Frequency (Hz)') - # self.ax['dpl_L'].set_ylabel('Current dipole (nAm)') - # self.ax['hist_L'].set_ylabel('Proximal EPSP count') - # self.ax_twinx['hist_L'].set_ylabel('Distal EPSP count') - - self.f.text(l[0], 0.025, 'Time (ms)') - -class FigPeaks(ac.FigBase): - def __init__(self, runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(4, 5)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(4, 50), - 'right': gridspec.GridSpec(4, 50), - } - - # reposition the gridspecs - l = np.arange(0.1, 0.9, 0.45) - r = l + 0.8 - - # create the gridspecs - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.1, top=0.94, left=l[0], right=r[0]) - self.gspec['right'].update(wspace=0, hspace=0.30, bottom=0.1, top=0.94, left=l[1], right=r[1]) - - self.ax = { - 'dpl_L': self.f.add_subplot(self.gspec['left'][:1, :40]), - 'hist_L': self.f.add_subplot(self.gspec['left'][1:2, :40]), - - # these are set differently depending on runtype, below - 'spec_L': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][2:, :40]) - - if runtype.startswith('pub'): - self.__remove_labels() - - # self.__create_twinx() - # self.__add_labels_subfig(l) - - def __create_twinx(self): - for ax_handle in self.ax.keys(): - if ax_handle.startswith('hist'): - self.create_axis_twinx(ax_handle) - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - if ax.startswith(('dpl', 'hist')): - self.ax[ax].set_xticklabels('') - - if ax.endswith('_R'): - self.ax[ax].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_subfig(self, l): - self.f.text(l[0], 0.95, 'A.') - self.f.text(l[1], 0.95, 'B.') - - self.ax['spec_L'].set_ylabel('Frequency (Hz)') - self.ax['dpl_L'].set_ylabel('Current dipole (nAm)') - self.ax['hist_L'].set_ylabel('EPSP count') - - self.f.text(l[0], 0.025, 'Time (ms)') - self.f.text(0.95, 0.40, 'Power spectral density ((nAm)$^2$/Hz)', rotation=270) - -class FigHF(ac.FigBase): - def __init__(self, runtype='debug'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(5, 7)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(8) - - # various gridspecs - self.gspec = { - 'left': gridspec.GridSpec(5, 50), - } - - # reposition the gridspecs - # l = np.arange(0.1, 0.9, 0.28) - # l = np.arange(0.05, 0.95, 0.3) - # r = l + 0.275 - - # create the gridspecs - self.gspec['left'].update(wspace=0, hspace=0.30, bottom=0.1, top=0.94, left=0.2, right=0.95) - l = 0.1 - - self.ax = { - 'spk': self.f.add_subplot(self.gspec['left'][:1, :40]), - 'hist_L': self.f.add_subplot(self.gspec['left'][1:2, :40]), - 'dpl_L': self.f.add_subplot(self.gspec['left'][2:4, :40]), - - # these are set differently depending on runtype, below - 'spec_L': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][4:, :]) - - elif runtype == 'pub': - self.ax['spec_L'] = self.f.add_subplot(self.gspec['left'][4:, :40]) - - # if runtype.startswith('pub'): - # self.__remove_labels() - - self.__create_twinx() - self.__add_labels_subfig(l) - - def __create_twinx(self): - for ax_handle in self.ax.keys(): - if 
ax_handle.startswith('dpl'): - self.create_axis_twinx(ax_handle) - - # function to remove labels when not testing - def __remove_labels(self): - for ax in self.ax.keys(): - if ax.startswith(('dpl', 'hist')): - self.ax[ax].set_xticklabels('') - - if ax.endswith(('_M', '_R')): - self.ax[ax].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_subfig(self, l): - # self.f.text(l, 0.95, 'A.') - - self.ax['spk'].set_ylabel('Cells') - self.ax['spec_L'].set_ylabel('Frequency (Hz)') - self.ax['spec_L'].set_xlabel('Time (ms)') - self.ax['dpl_L'].set_ylabel('Current dipole (nAm)') - self.ax['hist_L'].set_ylabel('L5 Pyramidal Spikes') - self.ax_twinx['dpl_L'].set_ylabel('Current (nA)', rotation=270) - - # self.f.text(l, 0.025, 'Time (ms)') - # self.f.text(0.925, 0.40, 'Power spectral density ((nAm)$^2$/Hz)', rotation=270) - -# high frequency epochs fig -class FigHFEpochs(ac.FigBase): - def __init__(self, runtype='pub'): - ac.FigBase.__init__(self) - self.f = plt.figure(figsize=(9.5, 7)) - - # set_fontsize() is part of FigBase() - self.set_fontsize(9) - - # called L_gspec so the ax keys can be (a) sortably grouped and (b) congruent with gspec - # I want it to be called gspec_L, but you can't have everything you want in life. - self.L_gspec = gridspec.GridSpec(5, 50) - self.L_gspec.update(wspace=0, hspace=0.1, bottom=0.1, top=0.95, left=0.07, right=0.32) - - self.gspec_ex = [ - gridspec.GridSpec(7, 50), - gridspec.GridSpec(7, 50), - gridspec.GridSpec(7, 50), - gridspec.GridSpec(7, 50), - ] - - # reposition the gridspecs. there are cleverer ways of doing this - w = 0.25 - l_split = np.array([0.05, 0.36, 0.63]) - # l_split = np.arange(0.05, 0.95, 0.3) - self.l_ex = np.array([l_split[1], l_split[2], l_split[1], l_split[2]]) - self.r_ex = self.l_ex + w - - # bottom and tops - h = 0.4 - b_ex = np.array([0.55, 0.55, 0.10, 0.10]) - # b_ex = np.array([0.1, 0.1, 0.55, 0.55]) - t_ex = b_ex + h - - # l = np.arange(0.25, 0.9, 0.2) - # r = l + 0.1 - # r = l + 0.275 - - # create the gridspecs - for i in range(len(self.gspec_ex)): - self.gspec_ex[i].update(wspace=0, hspace=0.30, bottom=b_ex[i], top=t_ex[i], left=self.l_ex[i], right=self.r_ex[i]) - - self.ax = { - 'L_spk': self.f.add_subplot(self.L_gspec[:2, :40]), - 'L_dpl': self.f.add_subplot(self.L_gspec[2:3, :40]), - 'L_spec': self.f.add_subplot(self.L_gspec[3:, :40]), - 'hist': [ - self.f.add_subplot(self.gspec_ex[0][:1, :40]), - self.f.add_subplot(self.gspec_ex[1][:1, :40]), - self.f.add_subplot(self.gspec_ex[2][:1, :40]), - self.f.add_subplot(self.gspec_ex[3][:1, :40]), - ], - - 'spk': [ - self.f.add_subplot(self.gspec_ex[0][1:3, :40]), - self.f.add_subplot(self.gspec_ex[1][1:3, :40]), - self.f.add_subplot(self.gspec_ex[2][1:3, :40]), - self.f.add_subplot(self.gspec_ex[3][1:3, :40]), - ], - - 'dpl': [ - self.f.add_subplot(self.gspec_ex[0][3:5, :40]), - self.f.add_subplot(self.gspec_ex[1][3:5, :40]), - self.f.add_subplot(self.gspec_ex[2][3:5, :40]), - self.f.add_subplot(self.gspec_ex[3][3:5, :40]), - ], - - # these are set differently depending on runtype, below - 'spec': None, - } - - if runtype in ('debug', 'pub2'): - self.ax['spec'] = [ - self.f.add_subplot(self.gspec_ex[0][5:, :]), - self.f.add_subplot(self.gspec_ex[1][5:, :]), - self.f.add_subplot(self.gspec_ex[2][5:, :]), - self.f.add_subplot(self.gspec_ex[3][5:, :]), - ] - - elif runtype == 'pub': - self.ax['spec'] = [ - self.f.add_subplot(self.gspec_ex[0][5:, :40]), - 
self.f.add_subplot(self.gspec_ex[1][5:, :40]), - self.f.add_subplot(self.gspec_ex[2][5:, :40]), - self.f.add_subplot(self.gspec_ex[3][5:, :]), - ] - - self.__create_twinx() - - if runtype.startswith('pub'): - self.__remove_labels() - - self.create_ax_bounds_dict() - self.create_y_centers_dict() - self.__add_labels_left() - self.__add_labels_subfig() - - # takes a specific ax and a well-formed props_dict - def add_sine(self, ax, props_dict): - # will create fake sine waves with various properties - # t_center = props_dict['t_center'] - f = props_dict['f'] - # t_half = 0.5 * (1000. / f) - t0 = props_dict['t'][0] - T = props_dict['t'][1] - # t0 = t_center - t_half - # T = t_center + t_half - t = np.arange(t0, T, props_dict['dt']) - x = props_dict['A'] * np.sin(2 * np.pi * f * (t - t0) / 1000.) - # x = 0.01 * np.sin(2 * np.pi * f * (t - t0) / 1000.) - ax.plot(t, x, 'r--') - - # the general twinx function cannot be used, due to the structure of the axes here - def __create_twinx(self): - # just create the twinx list for the dipole - self.ax_twinx['dpl'] = [ax_h.twinx() for ax_h in self.ax['dpl']] - - # function to remove labels when not testing - def __remove_labels(self): - for ax_h in self.ax.keys(): - if ax_h.startswith('L_'): - if ax_h.endswith(('spk', 'dpl')): - self.ax[ax_h].set_xticklabels('') - - if ax_h.endswith('spk'): - self.ax[ax_h].set_yticklabels('') - - for ax_h in self.ax.keys(): - if not ax_h.startswith(('L', 'spec')): - for i in range(1, 4): - self.ax[ax_h][i].set_xticklabels('') - - if not ax_h.startswith('L'): - if not ax_h.startswith('spec'): - for i in range(0, 2): - self.ax[ax_h][i].set_xticklabels('') - - for i in range(1, 4): - self.ax[ax_h][i].set_yticklabels('') - - # remove the spk ones - if ax_h == 'spk': - self.ax[ax_h][0].set_yticklabels('') - - for i in range(1, 4): - self.ax_twinx['dpl'][i].set_yticklabels('') - - def remove_twinx_labels(self): - for ax in self.ax_twinx.keys(): - self.ax_twinx[ax].set_yticklabels('') - - # add text labels - def __add_labels_left(self): - label_props = { - 'rotation': 90, - 'va': 'center', - 'ma': 'center', - } - self.f.text(0.018, self.y_centers['L_spec'], 'Frequency (Hz)', **label_props) - self.f.text(0.018, self.y_centers['L_dpl'], 'Current Dipole (nAm)', **label_props) - self.f.text(0.018, self.y_centers['L_spk'], 'Cells', **label_props) - self.f.text(0.07, 0.04, 'Time (ms)', ha='left') - - x = self.ax_bounds['L_spk'][0] - y = self.ax_bounds['L_spk'][-1] + 0.005 - self.f.text(x, y, 'A.', ha='left') - - def __add_labels_subfig(self): - list_labels = ['B.', 'C.', 'D.', 'E.'] - - label_props = { - 'ha': 'left', - } - - # for ax_h, lbl in zip(self.ax['hist'], list_labels): - for i in range(len(self.ax['hist'])): - # get the x coord - x = self.ax_bounds['hist'][i][0] - y = self.ax_bounds['hist'][i][-1] + 0.005 - self.f.text(x, y, list_labels[i], **label_props) - - # for x_l, lbl in zip(self.l_ex, list_labels): - # self.f.text(x_l, 0.95, lbl+'.') - - ylabel_props = { - 'rotation': 90, - 'va': 'center', - 'ma': 'center', - } - - xoffset = 0.0675 - - # y labels - self.f.text(self.l_ex[0] - xoffset, self.y_centers['spec'][0], 'Frequency \n (Hz)', **ylabel_props) - self.f.text(self.l_ex[0] - xoffset, self.y_centers['dpl'][0], 'Current Dipole \n (nAm)', **ylabel_props) - self.f.text(self.l_ex[0] - xoffset, self.y_centers['hist'][0], 'L5 Pyramidal \n Spike Count', **ylabel_props) - self.f.text(self.l_ex[0] - xoffset, self.y_centers['spk'][0], 'Cells', **ylabel_props) - self.f.text(self.r_ex[0] - 0.02, self.y_centers['dpl'][0], 'Somatic 
current \n ($\mu$A)', rotation=270, ma='center', va='center') - self.f.text(0.925, self.y_centers['spec'][-1], 'Spectral Power \n ((nAm)$^2$)', rotation=270, ma='center', va='center') - - # time label - self.f.text(self.l_ex[0], 0.04, 'Time (ms)') - -if __name__ == '__main__': - x = np.random.rand(100) - - f_test = 'testing.png' - - print mpl.get_backend() - - # testfig for FigDipoleExp() - # testfig = ac.FigDipoleExp(ax_handles) - # testfig.create_colorbar_axis('spec') - # testfig.ax['spec'].plot(x) - - # testfig = FigTest() - # testfig = FigStDev() - # testfig = FigL5PingExample() - # testfig = FigSubDistExample() - # testfig = FigLaminarComparison() - # testfig = FigDistalPhase() - # testfig = FigPeaks() - testfig = FigSimpleSpec() - testfig.savepng(f_test) - testfig.close() diff --git a/averaging.py b/averaging.py deleted file mode 100644 index a7a74b9ac..000000000 --- a/averaging.py +++ /dev/null @@ -1,67 +0,0 @@ -# averaging.py - routine to perform averaging -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: using new return_data_dir()) -# last major: (SL: pushed for CSM and CB) - -import fileio as fio -import dipolefn -import matplotlib.pyplot as plt -import numpy as np -import os - -# routine to average the dipoles found in the dsim directory -def average_dipole(dsim): - dproj = fio.return_data_dir() - - ddata = fio.SimulationPaths() - ddata.read_sim(dproj, dsim) - - # grab the first experimental group - expmt_group = ddata.expmt_groups[0] - - flist = ddata.file_match(expmt_group, 'rawdpl') - N_dpl = len(flist) - - # grab the time and the length - dpl_time = dipolefn.Dipole(flist[0]).t - length_dpl = dipolefn.Dipole(flist[0]).N - - # preallocation of the total dipole - # dpl_agg = np.zeros((N_dpl, length_dpl)) - dpl_sum = np.zeros(length_dpl) - - # the specific dipole to use - dpl_specific = 'agg' - - for f in flist: - dpl_f = dipolefn.Dipole(f) - dpl_sum = dpl_sum + dpl_f.dpl[dpl_specific] - - dpl_scaled = dpl_sum * 1e-6 - dpl_mean = dpl_scaled / N_dpl - - print dpl_sum - print ' ' - print dpl_scaled - print ' ' - print dpl_mean - - figure_create(dpl_time, dpl_mean) - -# simple plot of the mean dipole -def figure_create(dpl_time, dpl_agg): - fig = plt.figure() - ax = { - 'dpl_agg': fig.add_subplot(1, 1, 1), - } - - # example - ax['dpl_agg'].plot(dpl_time, dpl_agg, linewidth=0.5, color='k') - fig.savefig('testing_dpl.png', dpi=200) - plt.close(fig) - -if __name__ == '__main__': - droot = fio.return_data_dir() - dsim = os.path.join(droot, '2015-12-02/tonic_L5Pyr-000') - average_dipole(dsim) diff --git a/axes_create.py b/axes_create.py deleted file mode 100644 index a92690e46..000000000 --- a/axes_create.py +++ /dev/null @@ -1,832 +0,0 @@ -# axes_create.py - simple axis creation -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: checked all the divides for compatibility) -# last major: (SL: toward python3) - -# usage: -# testfig = FigStd() -# testfig.ax0.plot(somedata) -# plt.savefig('testfig.png') -# testfig.close() - -import paramrw -import matplotlib as mpl -from matplotlib import ticker -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -import itertools as it -import numpy as np -import os, sys - -# Base figure class -class FigBase(): - def __init__(self): - # self.f is typically set by the super class - self.f = None - - # only use LaTeX (latex) if on Mac - # kind of kludgy temporary fix for now. 
- # if sys.platform.startswith('darwin'): - # mpl.rc('text', usetex=True) - # elif sys.platform.startswith('linux'): - # pass - - # axis dicts are guaranteed to exist at least, sheesh - self.ax = {} - self.ax_twinx = {} - - # creates a twinx axis for the specified axis - def create_axis_twinx(self, ax_name): - if ax_name in self.ax.keys(): - self.ax_twinx[ax_name] = self.ax[ax_name].twinx() - - # returns the index of most recently added element (now the length) - return ax_name - - else: - # returns valid axis name ONLY if it existed - # otherwise, None will break other code - return None - - # returns axis bounds for an arbitrary axis handle - # ax_h must be defined as a key in self.ax - def return_axis_bounds(self, ax_h): - if ax_h in self.ax.keys(): - # check to see if this axis handle is actually a list - if isinstance(self.ax[ax_h], list): - # create a list of coords - list_coords_bbox = [] - - # iterate through axes in the list - for ax_item in self.ax[ax_h]: - # get the coords for the axis - coords = ax_item.get_position().get_points() - - # append the cleaned up version - list_coords_bbox.append(np.reshape(coords, (1, 4))[0]) - - return list_coords_bbox - - else: - # these are *not* beatbox coordinates - coords_bbox = self.ax[ax_h].get_position().get_points() - - # reshape the coords - return np.reshape(coords_bbox, (1, 4))[0] - - else: - print("Axis not found by return_axis_bounds()") - return 0 - - # needs to be run externally, after self.ax is established - def create_ax_bounds_dict(self): - # make a dict - self.ax_bounds = dict.fromkeys(self.ax) - - # iterate through keys and use return_axis_bounds() to get the axis - # this is now working for lists - for ax_h in self.ax_bounds.keys(): - self.ax_bounds[ax_h] = self.return_axis_bounds(ax_h) - - # creates a dict of axes that gives the center y pos - # needs to be run externally, after self.ax is established - # can utilize create_ax_bounds_dict() in the future - def create_y_centers_dict(self): - self.y_centers = dict.fromkeys(self.ax) - - for ax_h in self.y_centers.keys(): - if isinstance(self.ax[ax_h], list): - list_ax_bounds = self.return_axis_bounds(ax_h) - list_y_top = [ax_bounds[-1] for ax_bounds in list_ax_bounds] - list_y_bot = [ax_bounds[1] for ax_bounds in list_ax_bounds] - self.y_centers[ax_h] = [y_bot + (y_top - y_bot) / 2. for y_top, y_bot in zip(list_y_top, list_y_bot)] - - else: - # get the axis bounds - ax_bounds = self.return_axis_bounds(ax_h) - y_top = ax_bounds[-1] - y_bot = ax_bounds[1] - - self.y_centers[ax_h] = y_bot + (y_top - y_bot) / 2. 
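For readers unfamiliar with the layout trick behind create_ax_bounds_dict() and create_y_centers_dict() above, the same idea can be shown in a few lines of standalone matplotlib. This is only a sketch: the two-panel layout, the Agg backend, and the label strings are illustrative assumptions, not part of the removed module.

import matplotlib
matplotlib.use('Agg')  # assumed headless backend, purely for illustration
import matplotlib.pyplot as plt

fig, (ax_top, ax_bot) = plt.subplots(2, 1)

def y_center(ax):
    # get_position() gives the axes bounding box in figure-fraction
    # coordinates as [[x0, y0], [x1, y1]]
    (_x0, y0), (_x1, y1) = ax.get_position().get_points()
    return y0 + (y1 - y0) / 2.

# place a rotated label at the vertical midpoint of each axes, the way the
# removed __add_labels_left() did with self.f.text()
for ax, label in zip((ax_top, ax_bot), ('Current Dipole (nAm)', 'Cells')):
    fig.text(0.02, y_center(ax), label, rotation=90, va='center')

fig.savefig('label_demo.png', dpi=150)
plt.close(fig)

Working in figure-fraction coordinates is what let these classes label whole rows or columns of axes with a single fig.text() call instead of per-axes set_ylabel() calls.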
- - # function to set the scientific notation limits - def set_notation_scientific(self, list_ax_handles, n=3): - # set the formatter - fmt = ticker.ScalarFormatter() - fmt.set_powerlimits((-n, n)) - for h in list_ax_handles: - self.ax[h].yaxis.set_major_formatter(fmt) - - return fmt - - # generic function to take an axis handle and make the y-axis even - def ysymmetry(self, ax): - ylim = ax.get_ylim() - yabs_max = np.max(np.abs(ylim)) - ylim_new = (-yabs_max, yabs_max) - ax.set_ylim(ylim_new) - - return ylim_new - - # equalizes SIZE of the ylim but keeps the center of the axis - # whatever makes sense for the data - def equalize_ylim_size(self, list_handles): - list_ylim_size = [] - - # create a dict of ylims from list_handles - ylim = dict.fromkeys(list_handles) - - # grab the current sizes - for h in list_handles: - # outputs of tuples for dict entries in ylim - ylim[h] = f.ax[h].get_ylim() - ylim_size = np.abs(ylim[h][-1] - ylim[h][0]) - list_ylim_size.append(ylim_size) - - # figure out which was the biggest - ylim_size_max = np.max(list_ylim_size) - - # iterate through the handles again, if the size is less than the max size, - # then adjust it appropriately - - # checks on all yaxes and then sets them - def equalize_ylim(self, list_handles): - list_ylim = [] - - # assumes axes are in a self.ax dictionary, and the keys of the dict - # are the names given in list_handles - for h in list_handles: - ymin_local, ymax_local = self.ax[h].get_ylim() - - # append to list - list_ylim.extend([ymin_local, ymax_local]) - - # calculate the ylim - ylim = [np.min(list_ylim), np.max(list_ylim)] - - # now set for all handles - for h in list_handles: - self.ax[h].set_ylim(ylim) - - return ylim - - # equalizing the color lims is a slightly different process that requires the pc_dict to be passed - def equalize_speclim(self, pc_dict): - list_lim_spec = [] - - # assume that the handles have clims that will be assigned by the pc_dict - for h in pc_dict.keys(): - vmin, vmax = pc_dict[h].get_clim() - list_lim_spec.extend([vmin, vmax]) - - # create a ylim from list_lim_spec - ylim = (np.min(list_lim_spec), np.max(list_lim_spec)) - - for h in pc_dict.keys(): - # can this be done: set_clim(ylim) - pc_dict[h].set_clim(ylim[0], ylim[1]) - - return ylim - - # set the font size globally - def set_fontsize(self, s): - font_prop = { - 'size': s, - } - mpl.rc('font', **font_prop) - - # sets the FIRST line found to black for a given axis or list of axes - def set_linecolor(self, ax_key, str_color): - if ax_key in self.ax.keys(): - if isinstance(self.ax[ax_key], list): - for item in self.ax[ax_key]: - item.lines[0].set_color(str_color) - else: - self.ax[ax_key].lines[0].set_color(str_color) - - # creates title string based on params that change during simulation - # title_str = ac.create_title(blah) - # title_str = f.create_title(blah) - def set_title(self, fparam, key_types): - # get param dict - p_dict = paramrw.read(fparam)[1] - - # create_title() is external fn - title = create_title(p_dict, key_types) - self.f.suptitle(title) - - # turns off top and right frame of an axis - def set_frame_off(self, ax_handle): - self.ax[ax_handle].spines['right'].set_visible(False) - self.ax[ax_handle].spines['top'].set_visible(False) - - self.ax[ax_handle].xaxis.set_ticks_position('bottom') - self.ax[ax_handle].yaxis.set_ticks_position('left') - - # generic function to remove xticklabels from a bunch of axes based on handle - def remove_tick_labels(self, list_ax_handles, ax_xy='x'): - for ax_handle in list_ax_handles: - if 
ax_handle in self.ax.keys(): - if ax_xy == 'x': - self.ax[ax_handle].set_xticklabels('') - elif ax_xy == 'y': - self.ax[ax_handle].set_yticklabels('') - - # generic save png function to file_name at dpi=dpi_set - def savepng(self, file_name, dpi_set=300): - self.f.savefig(file_name, dpi=dpi_set) - - # new png save - def savepng_new(self, dpng, fprefix, dpi_set=300): - # add png - fname = os.path.join(dpng, fprefix+'.png') - self.f.savefig(fname, dpi=dpi_set) - - # generic save, works for png but supposed to be for eps - def saveeps(self, deps, fprefix): - fname = os.path.join(deps, fprefix+'.eps') - self.f.savefig(fname) - - # obligatory close function - def close(self): - plt.close(self.f) - -# Simple one axis window -class FigStd(FigBase): - def __init__(self): - FigBase.__init__(self) - - self.f = plt.figure(figsize=(12, 6)) - self.set_fontsize(8) - - gs0 = gridspec.GridSpec(1, 1) - self.ax = { - 'ax0': self.f.add_subplot(gs0[:]), - } - - # this is a bad way of ensuring backward compatibility - self.ax0 = self.ax['ax0'] - -class FigDplWithHist(FigBase): - def __init__(self): - self.f = plt.figure(figsize=(12, 6)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # dipole gridpec - self.gs0 = gridspec.GridSpec(1, 1, wspace=0.05, hspace=0, bottom=0.10, top=0.55, left = 0.1, right = 0.90) - - # hist gridspec - self.gs1 = gridspec.GridSpec(2, 1, hspace=0.14 , bottom=0.60, top=0.95, left = 0.1, right = 0.90) - - # create axes - self.ax = {} - self.ax['dipole'] = self.f.add_subplot(self.gs0[:, :]) - self.ax['feed_prox'] = self.f.add_subplot(self.gs1[1, :]) - self.ax['feed_dist'] = self.f.add_subplot(self.gs1[0, :]) - - # setting the properties of a histogram - def set_hist_props(self, hist_data): - for key in self.ax.keys(): - if 'feed' in key: - if hist_data[key] is not None: - max_n = max(hist_data[key][0]) - self.ax[key].set_yticks(np.arange(0, max_n+2, np.ceil((max_n+2.) / 4.))) - - if 'feed_dist' in key: - self.ax[key].set_xticklabels('') - - def save(self, file_name): - self.f.savefig(file_name) - -# spec plus dipole plus alpha feed histograms -class FigSpecWithHist(FigBase): - def __init__(self): - self.f = plt.figure(figsize=(8, 8)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # the right margin is a hack and NOT guaranteed! - # it's making space for the stupid colorbar that creates a new grid to replace gs1 - # when called, and it doesn't update the params of gs1 - self.gs0 = gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.05, top=0.45, left=0.1, right=1.) - self.gs1 = gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.50, top=0.70, left=0.1, right=0.82) - self.gs2 = gridspec.GridSpec(2, 1, hspace=0.14, bottom=0.75, top=0.95, left = 0.1, right = 0.82) - - self.ax = {} - self.ax['spec'] = self.f.add_subplot(self.gs0[:, :]) - self.ax['dipole'] = self.f.add_subplot(self.gs1[:, :]) - self.ax['feed_prox'] = self.f.add_subplot(self.gs2[1, :]) - self.ax['feed_dist'] = self.f.add_subplot(self.gs2[0, :]) - - # self.__set_hist_props() - - def set_hist_props(self, hist_data): - for key in self.ax.keys(): - if 'feed' in key: - if hist_data[key] is not None: - max_n = max(hist_data[key][0]) - self.ax[key].set_yticks(np.arange(0, max_n+2, np.ceil((max_n+2.) 
/ 4.))) - - if 'feed_dist' in key: - self.ax[key].set_xticklabels('') - -# spec plus dipole plus alpha feed histograms -class FigPhase(FigBase): - def __init__(self): - self.f = plt.figure(figsize=(8, 12)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # the right margin is a hack and NOT guaranteed! - # it's making space for the stupid colorbar that creates a new grid to replace gs1 - # when called, and it doesn't update the params of gs1 - self.gs0 = gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.05, top=0.3, left=0.1, right=1.) - self.gs1 = gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.35, top=0.6, left=0.1, right=1.) - self.gs2 = gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.65, top=0.775, left=0.1, right=0.82) - self.gs3 = gridspec.GridSpec(2, 1, hspace=0.14, bottom=0.825, top=0.95, left = 0.1, right = 0.82) - - self.ax = {} - self.ax['phase'] = self.f.add_subplot(self.gs0[:, :]) - self.ax['spec'] = self.f.add_subplot(self.gs1[:, :]) - self.ax['dipole'] = self.f.add_subplot(self.gs2[:, :]) - self.ax['input'] = self.f.add_subplot(self.gs3[:, :]) - -# spec plus dipole -class FigSpec(FigBase): - def __init__(self): - self.f = plt.figure(figsize=(8, 6)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # the right margin is a hack and NOT guaranteed! - # it's making space for the stupid colorbar that creates a new grid to replace gs1 - # when called, and it doesn't update the params of gs1 - self.gspec = { - 'dpl': gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.85, top=0.95, left=0.1, right=0.82), - 'spec': gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.30, top=0.80, left=0.1, right=1.), - 'pgram': gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.05, top=0.25, left=0.1, right=0.82), - } - - self.ax = {} - self.ax['dipole'] = self.f.add_subplot(self.gspec['dpl'][:, :]) - self.ax['spec'] = self.f.add_subplot(self.gspec['spec'][:, :]) - self.ax['pgram'] = self.f.add_subplot(self.gspec['pgram'][:, :]) - -class FigInterval(FigBase): - def __init__(self, N_trials): - self.f = plt.figure(figsize=(4, N_trials)) - self.set_fontsize(12) - - self.gspec = gridspec.GridSpec(1, 1, right=0.5) - - self.ax = {} - self.ax['ts'] = self.f.add_subplot(self.gspec[:, :]) - self.ax['ts'].hold(True) - self.ax['ts'].set_yticklabels([]) - self.set_frame_off('ts') - -class FigFreqpwrWithHist(FigBase): - def __init__(self): - self.f = plt.figure(figsize = (12, 6)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # One gridspec for both plots - self.gs0 = gridspec.GridSpec(1, 2, bottom=0.20, top = 0.80, left=0.1, right=0.90, wspace = 0.1) - - self.ax = {} - self.ax['freqpwr'] = self.f.add_subplot(self.gs0[0, 1]) - self.ax['hist'] = self.f.add_subplot(self.gs0[0, 0]) - - def set_hist_props(self, hist_data): - max_n = max(hist_data) - self.ax['hist'].set_yticks(np.arange(0, max_n+2, np.ceil((max_n+2.) 
/ 4.))) - - def save(self, file_name): - self.f.savefig(file_name) - -class FigRaster(FigBase): - def __init__(self, tstop): - self.tstop = tstop - self.f = plt.figure(figsize=(6, 8)) - - grid0 = gridspec.GridSpec(5, 1) - grid0.update(wspace=0.05, hspace=0., bottom=0.05, top=0.45) - - grid1 = gridspec.GridSpec(5, 1) - grid1.update(wspace=0.05, hspace=0., bottom=0.50, top=0.95) - - self.ax = {} - - self.__panel_create(grid1, 'L2') - self.__panel_create(grid0, 'L5') - - for key in self.ax.keys(): - if key == 'L5_extinput': - self.__bottom_panel_prop(self.ax[key]) - - else: - self.__raster_prop(self.ax[key]) - - def __panel_create(self, grid, layer): - self.ax[layer] = self.f.add_subplot(grid[:2, :]) - self.ax[layer+'_extgauss'] = self.f.add_subplot(grid[2:3, :]) - self.ax[layer+'_extpois'] = self.f.add_subplot(grid[3:4, :]) - self.ax[layer+'_extinput'] = self.f.add_subplot(grid[4:, :]) - - def __bottom_panel_prop(self, ax): - ax.set_yticklabels('') - ax.set_xlim(0, self.tstop) - - def __raster_prop(self, ax): - ax.set_yticklabels('') - ax.set_xticklabels('') - ax.set_xlim(0, self.tstop) - -class FigPSTH(FigBase): - def __init__(self, tstop): - self.tstop = tstop - self.f = plt.figure(figsize=(6, 5)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - grid0 = gridspec.GridSpec(6, 2) - grid0.update(wspace=0.05, hspace=0., bottom=0.05, top=0.95) - - self.ax = {} - - self.ax['L2'] = self.f.add_subplot(grid0[:2, :1], title='Layer 2') - self.ax['L2_psth'] = self.f.add_subplot(grid0[2:4, :1]) - self.ax['L2_extgauss'] = self.f.add_subplot(grid0[4:5, :1]) - self.ax['L2_extpois'] = self.f.add_subplot(grid0[5:, :1], xlabel='Time (ms)') - - self.ax['L5'] = self.f.add_subplot(grid0[:2, 1:], title='Layer 5') - self.ax['L5_psth'] = self.f.add_subplot(grid0[2:4, 1:]) - self.ax['L5_extgauss'] = self.f.add_subplot(grid0[4:5, 1:]) - self.ax['L5_extpois'] = self.f.add_subplot(grid0[5:, 1:], xlabel='Time (ms)') - - for key in self.ax.keys(): - if key.endswith('_extpois'): - self.__bottom_panel_prop(self.ax[key]) - - elif key.endswith('_psth'): - self.__psth_prop(self.ax[key]) - - else: - self.__raster_prop(self.ax[key]) - - grid0.tight_layout(self.f, rect=[0, 0, 1, 1], h_pad=0., w_pad=1) - - def __bottom_panel_prop(self, ax): - ax.set_yticklabels('') - ax.set_xlim(0, self.tstop) - ax.get_xticklabels() - # locs, labels = plt.xticks() - # plt.setp(labels, rotation=45) - - def __psth_prop(self, ax): - # ax.set_yticklabels('') - ax.set_xticklabels('') - ax.set_xlim(0, self.tstop) - - for tick in ax.yaxis.get_major_ticks(): - tick.label1On = False - tick.label2On = True - - def __raster_prop(self, ax): - ax.set_yticklabels('') - ax.set_xticklabels('') - ax.set_xlim(0, self.tstop) - -# create a grid of psth figures, and rasters(?) -class FigGrid(FigBase): - def __init__(self, N_rows, N_cols, tstop): - self.tstop = tstop - - # changes over rows and cols to inches (?) 
and scales - self.f = plt.figure(figsize=(2*N_cols, 2*N_rows)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - self.grid_list = [] - self.__create_grids(N_rows, N_cols) - - # axes are a list of lists here - self.ax = [] - self.__create_axes() - - def __create_grids(self, N_rows, gs_cols): - gs_rows = 3 - self.grid_list = [gridspec.GridSpec(gs_rows, gs_cols) for i in range(N_rows)] - ytop = 0.075 - ybottom = 0.05 - ypad = 0.02 - ypanel = (1 - ytop - ybottom - ypad*(N_rows-1)) / N_rows - # print ypanel - - i = 0 - ystart = 1-ytop - - # used to pre-calculate this, but whatever - for grid in self.grid_list: - # start at the top to order the rows down - grid.update(wspace=0.05, hspace=0., bottom=ystart-ypanel, top=ystart) - # grid.update(wspace=0.05, hspace=0., bottom=0.05, top=0.95) - ystart -= ypanel+ypad - i += 1 - - # creates a list of lists of axes - def __create_axes(self): - for grid in self.grid_list: - ax_list = [] - for i in range(grid._ncols): - ax_list.append(self.f.add_subplot(grid[:, i:i+1])) - ax_list[-1].set_yticks([0, 100., 200., 300., 400., 500.]) - - # clear y-tick labels for everyone but the bottom - for ax in ax_list: - ax.set_xticklabels('') - - # clear y-tick labels for everyone but the left side - for ax in ax_list[1:]: - ax.set_yticklabels('') - self.ax.append(ax_list) - - # set a timescale for just the last axis - self.ax[-1][-1].set_xticks([0., 250., 500.]) - self.ax[-1][-1].set_xticklabels([0., 250., 500.]) - - # testing usage of string in title - # self.ax[0][0].set_title(r'$\lambda_i$ = %d' % 0) - -class FigAggregateSpecWithHist(FigBase): - def __init__(self, N_rows, N_cols): - self.N_rows = N_rows - self.N_cols = N_cols - - self.f = plt.figure(figsize=(2+8*N_cols, 1+8*N_rows), dpi=300) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # margins - self.top_margin = 1. / (2 + 8 * N_rows) - self.left_margin = 2. / (2 + 8 * N_cols) - - # Height is measured from top of figure - # i.e. row at top of figure is considered row 0 - # This is the opposite of matplotlib conventions - # White space accounting is kind of wierd. Sorry. - self.gap_height = 0.1 / (N_rows + 1) - height = (0.9 - self.top_margin) / N_rows - top = 1. - self.top_margin - self.gap_height - bottom = top - height - - # Width is measured from left of figure - # This is inline with matplotlib conventions - # White space accounting it kind of wierd. Sorry - self.gap_width = 0.15 / (N_cols + 1.) - width = (0.85 - self.left_margin) / N_cols - left = self.left_margin + self.gap_width - right = left + width - - # Preallocate some lists - self.gs0_list = [] - self.gs1_list = [] - self.gs2_list = [] - self.ax_list = [] - - # iterate over all rows/cols and create axes for each location - for row, col in it.product(range(0, N_rows), range(0, N_cols)): - # left and right margins for this set of axes - tmp_left = left + width * col + self.gap_width * col - tmp_right = right + width * col + self.gap_width * col - - # top and bottom margins for this set of axes - bottom_spec = bottom - height * row - self.gap_height * row - top_spec = bottom_spec + (0.4 - self.top_margin / 5.) / N_rows - - bottom_dpl = top_spec + (0.05 - self.top_margin / 5.) / N_rows - top_dpl = bottom_dpl + (0.2 - self.top_margin / 5.) / N_rows - - bottom_hist = top_dpl + (0.05 - self.top_margin / 5.) / N_rows - top_hist = bottom_hist + (0.2 - self.top_margin / 5.) 
/ N_rows - - # tmp_top = top - height * row - self.gap_height * row - # tmp_bottom = bottom - height * row - self.gap_height * row - - # Create gridspecs - self.gs0_list.append(gridspec.GridSpec(1, 4, wspace=0., hspace=0., bottom=bottom_spec, top=top_spec, left=tmp_left, right=tmp_right)) - self.gs1_list.append(gridspec.GridSpec(2, 1, bottom=bottom_dpl, top=top_dpl, left=tmp_left, right=tmp_right-0.18 / N_cols)) - self.gs2_list.append(gridspec.GridSpec(2, 1, hspace=0.14, bottom=bottom_hist, top=top_hist, left=tmp_left, right = tmp_right-0.18 / N_cols)) - - # create axes - ax = {} - ax['spec'] = self.f.add_subplot(self.gs0_list[-1][:, :]) - ax['dipole'] = self.f.add_subplot(self.gs1_list[-1][:, :]) - ax['feed_prox'] = self.f.add_subplot(self.gs2_list[-1][1, :]) - ax['feed_dist'] = self.f.add_subplot(self.gs2_list[-1][0, :]) - - # store axes - # SUPER IMPORTANT: this list iterates across rows!!!!! - self.ax_list.append(ax) - - def set_hist_props(self, ax, hist_data): - for key in ax.keys(): - if 'feed' in key: - max_n = max(hist_data[key][0]) - ax[key].set_yticks(np.arange(0, max_n+2, np.ceil((max_n+2.) / 4.))) - - if 'feed_dist' in key: - ax[key].set_xticklabels('') - - # def add_column_labels(self, param_list): - def add_column_labels(self, param_list, key): - # override = {'fontsize': 8*self.N_cols} - - gap = (0.85 - self.left_margin) / self.N_cols + self.gap_width - - for i in range(0, self.N_cols): - p_dict = paramrw.read(param_list[i])[1] - - x = self.left_margin + gap / 2. + gap * i - y = 1 - self.top_margin / 2. - - self.f.text(x, y, key+' :%2.1f' %p_dict[key], fontsize=36, horizontalalignment='center', verticalalignment='top') - - # self.ax_list[i]['feed_dist'].set_title(key + ': %2.1f' %p_dict[key], **override) - - def add_row_labels(self, param_list, key): - gap = (0.9 - self.top_margin) / self.N_rows + self.gap_height - - for i in range(0, self.N_rows): - ind = self.N_cols * i - p_dict = paramrw.read(param_list[ind])[1] - - # place text in middle of each row of axes - x = self.left_margin / 2. - y = 1. - self.top_margin - self.gap_height - gap / 2. - gap * i - - # self.f.text(x, y, key+': %s' %p_dict[key], fontsize=36, rotation='vertical', horizontalalignment='left', verticalalignment='center') - - # try using key as a key in param dict - try: - self.f.text(x, y, key+': %s' %p_dict[key], fontsize=36, rotation='vertical', horizontalalignment='left', verticalalignment='center') - - # if this doesn't work, use individual parts of key as labels - except: - # check to see if there are enough args in key - if len(key) == self.N_rows: - self.f.text(x, y, key[i], fontsize=36, rotation='vertical', horizontalalignment='left', verticalalignment='center') - - # if not, do nothing - else: - print("Dude, the number of labels don't match the number of rows. 
I can't do nothing now.") - return 0 - - def save(self, file_name): - self.f.savefig(file_name, dpi=250) - -# aggregate figures for the experiments -class FigDipoleExp(FigBase): - def __init__(self, ax_handles): - FigBase.__init__(self) - # ax_handles is a list of axis handles in order - # previously called N_expmt_groups for legacy reasons (original intention) - # now generally repurposed for arbitrary numbers of axes with these handle names - self.ax_handles = ax_handles - self.N_expmt_groups = len(ax_handles) - self.f = plt.figure(figsize=(8, 2*self.N_expmt_groups)) - font_prop = {'size': 8} - mpl.rc('font', **font_prop) - - # create a gridspec that has width of "50" - # there is some dark magic here whereby colorbars change the original axis by some - # unspecified dimension. Rescaling non-spec axes to 40/50 is not the same as rescaling - # non-spec axes 4/5, for reason that's unclear to me at the time of this writing - # 40/50 works though - # 'spec' must be specified in the name of the spec - self.gspec = gridspec.GridSpec(self.N_expmt_groups, 55) - self.__create_axes() - self.__set_ax_props() - - def __create_axes(self): - # self.ax = [self.f.add_subplot(self.gspec[i:(i+1)]) for i in range(self.N_expmt_groups)] - self.ax = dict.fromkeys(self.ax_handles) - - # iterating like this because indices are useful in defining the recursive gspec locations - i = 0 - for ax in self.ax_handles: - if 'spec' not in ax: - self.ax[ax] = self.f.add_subplot(self.gspec[i:(i+1), :20]) - self.ax[ax+'_L5'] = self.f.add_subplot(self.gspec[i:(i+1), 30:50]) - # self.ax[ax] = self.f.add_subplot(self.gspec[i:(i+1), :40]) - else: - self.ax[ax] = self.f.add_subplot(self.gspec[i:(i+1), :25]) - self.ax[ax+'_L5'] = self.f.add_subplot(self.gspec[i:(i+1), 30:]) - # self.ax[ax] = self.f.add_subplot(self.gspec[i:(i+1), :]) - - i += 1 - - # if ax_twinx keys exist, the keys will mirror those in self.ax - self.ax_twinx = {} - - # extern function to create a colorbar on an arbitrary axis - # creates and rescales the specified axis and then scales down the rest of the axes accordingly - # I hope - def create_colorbar_axis(self, ax_name): - # print self.ax[N_ax] - cax, kw = mpl.colorbar.make_axes_gridspec(self.ax[ax_name]) - # a = self.ax[N_ax].get_axes() - # for item in dir(self.ax[N_ax]): - # if not item.startswith('__'): - # print item - - # take an external list of dipoles and plot them - # such a list is created externally - def plot(self, t, dpl_list): - if len(dpl_list) == self.N_expmt_groups: - # list of max and min dipoles for each in dpl_list - dpl_max = [] - dpl_min = [] - - # check on all the mins and maxes - for dpl, ax_name in zip(dpl_list, self.ax_handles): - self.ax[ax_name].plot(t, dpl) - ylim_tmp = ax.get_ylim() - - dpl_min.append(ylim_tmp[0]) - dpl_max.append(ylim_tmp[1]) - - # find the overall min and max - ymin = np.min(dpl_min) - ymax = np.max(dpl_max) - - # set the ylims for all, the same - for ax_name in self.ax.keys(): - self.ax[ax_name].set_ylim((ymin, ymax)) - - def __set_ax_props(self): - # remove xtick labels for everyone but the last axis - for ax_name in self.ax_handles[:-1]: - self.ax[ax_name].set_xticklabels('') - # for ax in self.ax[:-1]: - # ax.set_xticklabels('') - -# creates title string based on params that change during simulation -# title_str = ac.create_title(blah) -# title_str = f.create_title(blah) -def create_title(p_dict, key_types): - title = [] - - for key in key_types['dynamic_keys']: - # Rules for when to use scientific notation - if p_dict[key] >= 0.1 or p_dict[key] == 0: - 
title.append(key + ': %2.1f' %p_dict[key]) - else: - title.append(key + ': %2.1e' %p_dict[key]) - - # Return string in alphabetical order - title.sort() - return title - -# just a quick test for running this function -def testfn(): - x = np.random.rand(100) - - # testfig = FigStd() - # testfig.ax0.plot(x) - - ax_handles = [ - 'spec', - 'test1', - 'test2', - ] - - # testfig = FigDipoleExp(ax_handles) - # testfig.create_colorbar_axis('spec') - # testfig.create_colorbar_axis('spectest') - # testfig.ax['spec'].plot(x) - - # testfig = FigSpecWithHist() - # testfig = FigAggregateSpecWithHist(3, 3) - # testfig.ax['spec'].plot(x) - - # testfig = FigSpecWithHist() - # testfig.ax['spec'].plot(x) - # testfig.ax0.plot(x) - - # testfig = FigGrid(3, 3, 100) - - # testfig = FigPSTH(100) - # testfig.ax['L5_extpois'].plot(x) - - testfig = FigPhase() - # testfig.ax['dipole'].plot(x) - - plt.savefig('testing.png', dpi=250) - testfig.close() - -if __name__ == '__main__': - testfn() diff --git a/cartesian.py b/cartesian.py deleted file mode 100644 index 509aa1007..000000000 --- a/cartesian.py +++ /dev/null @@ -1,70 +0,0 @@ -# cartesian.py - returns cartesian product of a list of np arrays -# from StackOverflow.com -# -# v 1.10.0-py35 -# rev 2015-05-01 (SL: deprecated, tested) -# last major: (SL: imported from Stack Overflow) - -import numpy as np -import itertools as it - -def cartesian(arrays, out=None): - """ Parameters - ---------- - arrays : list of array-like - 1-D arrays to form the cartesian product of. - out : ndarray - Array to place the cartesian product in. - - Returns - ------- - out : ndarray - 2-D array of shape (M, len(arrays)) containing cartesian products - formed of input arrays. - - Examples - -------- - >>> cartesian(([1, 2], [4, 5], [6, 7])) - array([[1, 4, 6], - [1, 4, 7], - [1, 5, 6], - [1, 5, 7], - [2, 4, 6], - [2, 4, 7], - [2, 5, 6], - [2, 5, 7]) - """ - try: - xrange - - except: - xrange = range - - arrays = [np.asarray(x) for x in arrays] - - n = np.prod([x.size for x in arrays]) - if out is None: - out = np.zeros([n, len(arrays)], dtype='float64') - - m = int(n / arrays[0].size) - - out[:, 0] = np.repeat(arrays[0], m) - - if arrays[1:]: - cartesian(arrays[1:], out=out[:m, 1:]) - - for j in xrange(1, arrays[0].size): - out[j*m:(j+1)*m,1:] = out[:m, 1:] - - return out - -if __name__ == '__main__': - arrs = [np.arange(3), np.array([0]), np.array([5, 6, 7]), np.array([0, 1])] - out = cartesian(arrs) - print('old way') - print(out) - - out_it = [item for item in it.product(*arrs)] - print('it way') - for item in out_it: - print(item) diff --git a/cell.py b/cell.py deleted file mode 100644 index be1839ce3..000000000 --- a/cell.py +++ /dev/null @@ -1,402 +0,0 @@ -# cell.py - establish class def for general cell features -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: python3 compatibility) -# last rev: (SL: added list_IClamp as a pre-defined variable) - -import numpy as np -from neuron import h - -# global variables, should be node-independent -h("dp_total_L2 = 0."); h("dp_total_L5 = 0.") # put here since these variables used in cells - -# Units for e: mV -# Units for gbar: S/cm^2 - -# Create a cell class -class Cell (): - - def __init__ (self, gid, soma_props): - self.gid = gid - self.pc = h.ParallelContext() # Parallel methods - # make L_soma and diam_soma elements of self - # Used in shape_change() b/c func clobbers self.soma.L, self.soma.diam - self.L = soma_props['L'] - self.diam = soma_props['diam'] - self.pos = soma_props['pos'] - # create soma and set geometry - self.soma = 
h.Section(cell=self, name=soma_props['name']+'_soma') - self.soma.L = soma_props['L'] - self.soma.diam = soma_props['diam'] - self.soma.Ra = soma_props['Ra'] - self.soma.cm = soma_props['cm'] - # variable for the list_IClamp - self.list_IClamp = None - # par: create arbitrary lists of connections FROM other cells - # TO this cell instantiation - # these lists are allowed to be empty - # this should be a dict - self.ncfrom_L2Pyr = [] - self.ncfrom_L2Basket = [] - self.ncfrom_L5Pyr = [] - self.ncfrom_L5Basket = [] - self.ncfrom_extinput = [] - self.ncfrom_extgauss = [] - self.ncfrom_extpois = [] - self.ncfrom_ev = [] - - def record_volt_soma (self): - self.vsoma = h.Vector() - self.vsoma.record(self.soma(0.5)._ref_v) - - def get_sections (self): return [self.soma] - - def get3dinfo (self): - ls = self.get_sections() - lx,ly,lz,ldiam=[],[],[],[] - for s in ls: - for i in range(s.n3d()): - lx.append(s.x3d(i)) - ly.append(s.y3d(i)) - lz.append(s.z3d(i)) - ldiam.append(s.diam3d(i)) - return lx,ly,lz,ldiam - - # get cell's bounding box - def getbbox (self): - lx,ly,lz,ldiam = self.get3dinfo() - minx,miny,minz = 1e9,1e9,1e9 - maxx,maxy,maxz = -1e9,-1e9,-1e9 - for x,y,z in zip(lx,ly,lz): - minx = min(x,minx) - miny = min(y,miny) - minz = min(z,minz) - maxx = max(x,maxx) - maxy = max(y,maxy) - maxz = max(z,maxz) - return ((minx,maxx), (miny,maxy), (minz,maxz)) - - def translate3d (self, dx, dy, dz): - #s = self.soma - #for i in range(s.n3d()): - # h.pt3dchange(i,s.x3d(i)+dx,s.y3d(i)+dy,s.z3d(i)+dz,s.diam3d(i),sec=s) - for s in self.get_sections(): - for i in range(s.n3d()): - #print(s,i,s.x3d(i)+dx,s.y3d(i)+dy,s.z3d(i)+dz,s.diam3d(i)) - h.pt3dchange(i,s.x3d(i)+dx,s.y3d(i)+dy,s.z3d(i)+dz,s.diam3d(i),sec=s) - - def translateto (self, x, y, z): - x0 = self.soma.x3d(0) - y0 = self.soma.y3d(0) - z0 = self.soma.z3d(0) - dx = x - x0 - dy = y - y0 - dz = z - z0 - # print('dx:',dx,'dy:',dy,'dz:',dz) - self.translate3d(dx,dy,dz) - - def movetopos (self): - self.translateto(self.pos[0]*100,self.pos[2],self.pos[1]*100) - - # two things need to happen here for h: - # 1. dipole needs to be inserted into each section - # 2. a list needs to be created with a Dipole (Point Process) in each section at position 1 - # In Cell() and not Pyr() for future possibilities - def dipole_insert (self, yscale): - # insert dipole into each section of this cell - # dends must have already been created!! 
- # it's easier to use wholetree here, this includes soma - seclist = h.SectionList() - seclist.wholetree(sec=self.soma) - # create a python section list list_all - self.list_all = [sec for sec in seclist] - for sect in self.list_all: - sect.insert('dipole') - # Dipole is defined in dipole_pp.mod - self.dipole_pp = [h.Dipole(1, sec=sect) for sect in self.list_all] - # setting pointers and ztan values - for sect, dpp in zip(self.list_all, self.dipole_pp): - # assign internal resistance values to dipole point process (dpp) - dpp.ri = h.ri(1, sec=sect) - # sets pointers in dipole mod file to the correct locations - # h.setpointer(ref, ptr, obj) - h.setpointer(sect(0.99)._ref_v, 'pv', dpp) - if self.celltype.startswith('L2'): - h.setpointer(h._ref_dp_total_L2, 'Qtotal', dpp) - elif self.celltype.startswith('L5'): - h.setpointer(h._ref_dp_total_L5, 'Qtotal', dpp) - # gives INTERNAL segments of the section, non-endpoints - # creating this because need multiple values simultaneously - loc = np.array([seg.x for seg in sect]) - # these are the positions, including 0 but not L - pos = np.array([seg.x for seg in sect.allseg()]) - # diff in yvals, scaled against the pos np.array. y_long as in longitudinal - y_scale = (yscale[sect.name()] * sect.L) * pos - # y_long = (h.y3d(1, sec=sect) - h.y3d(0, sec=sect)) * pos - # diff values calculate length between successive section points - y_diff = np.diff(y_scale) - # y_diff = np.diff(y_long) - # doing range to index multiple values of the same np.array simultaneously - for i in range(len(loc)): - # assign the ri value to the dipole - sect(loc[i]).dipole.ri = h.ri(loc[i], sec=sect) - # range variable 'dipole' - # set pointers to previous segment's voltage, with boundary condition - if i: - h.setpointer(sect(loc[i-1])._ref_v, 'pv', sect(loc[i]).dipole) - else: - h.setpointer(sect(0)._ref_v, 'pv', sect(loc[i]).dipole) - # set aggregate pointers - h.setpointer(dpp._ref_Qsum, 'Qsum', sect(loc[i]).dipole) - if self.celltype.startswith('L2'): - h.setpointer(h._ref_dp_total_L2, 'Qtotal', sect(loc[i]).dipole) - elif self.celltype.startswith('L5'): - h.setpointer(h._ref_dp_total_L5, 'Qtotal', sect(loc[i]).dipole) - # add ztan values - sect(loc[i]).dipole.ztan = y_diff[i] - # set the pp dipole's ztan value to the last value from y_diff - dpp.ztan = y_diff[-1] - - # Add IClamp to a segment - def insert_IClamp (self, sect_name, props_IClamp): - # def insert_iclamp(self, sect_name, seg_loc, tstart, tstop, weight): - # gather list of all sections - seclist = h.SectionList() - seclist.wholetree(sec=self.soma) - # find specified sect in section list, insert IClamp, set props - for sect in seclist: - if sect_name in sect.name(): - stim = h.IClamp(sect(props_IClamp['loc'])) - stim.delay = props_IClamp['delay'] - stim.dur = props_IClamp['dur'] - stim.amp = props_IClamp['amp'] - # stim.dur = tstop - tstart - # stim = h.IClamp(sect(seg_loc)) - # object must exist for NEURON somewhere and needs to be saved - return stim - - # simple function to record current - # for now only at the soma - def record_current_soma (self): - # a soma exists at self.soma - self.rec_i = h.Vector() - try: - # assumes that self.synapses is a dict that exists - list_syn_soma = [key for key in self.synapses.keys() if key.startswith('soma_')] - # matching dict from the list_syn_soma keys - self.dict_currents = dict.fromkeys(list_syn_soma) - # iterate through keys and record currents appropriately - for key in self.dict_currents: - self.dict_currents[key] = h.Vector() - 
self.dict_currents[key].record(self.synapses[key]._ref_i) - except: - print("Warning in Cell(): record_current_soma() was called, but no self.synapses dict was found") - pass - - # General fn that creates any Exp2Syn synapse type - # requires dictionary of synapse properties - def syn_create (self, secloc, p): - syn = h.Exp2Syn(secloc) - syn.e = p['e'] - syn.tau1 = p['tau1'] - syn.tau2 = p['tau2'] - return syn - - # For all synapses, section location 'secloc' is being explicitly supplied - # for clarity, even though they are (right now) always 0.5. Might change in future - # creates a RECEIVING inhibitory synapse at secloc - def syn_gabaa_create (self, secloc): - syn_gabaa = h.Exp2Syn(secloc) - syn_gabaa.e = -80 - syn_gabaa.tau1 = 0.5 - syn_gabaa.tau2 = 5. - return syn_gabaa - - # creates a RECEIVING slow inhibitory synapse at secloc - # called: self.soma_gabab = syn_gabab_create(self.soma(0.5)) - def syn_gabab_create (self, secloc): - syn_gabab = h.Exp2Syn(secloc) - syn_gabab.e = -80 - syn_gabab.tau1 = 1 - syn_gabab.tau2 = 20. - return syn_gabab - - # creates a RECEIVING excitatory synapse at secloc - # def syn_ampa_create(self, secloc, tau_decay, prng_obj): - def syn_ampa_create (self, secloc): - syn_ampa = h.Exp2Syn(secloc) - syn_ampa.e = 0. - syn_ampa.tau1 = 0.5 - syn_ampa.tau2 = 5. - return syn_ampa - - # creates a RECEIVING nmda synapse at secloc - # this is a pretty fast NMDA, no? - def syn_nmda_create (self, secloc): - syn_nmda = h.Exp2Syn(secloc) - syn_nmda.e = 0. - syn_nmda.tau1 = 1. - syn_nmda.tau2 = 20. - return syn_nmda - - # connect_to_target created for pc, used in Network() - # these are SOURCES of spikes - def connect_to_target (self, target, threshold): - nc = h.NetCon(self.soma(0.5)._ref_v, target, sec=self.soma) - nc.threshold = threshold - return nc - - # parallel receptor-centric connect FROM presyn TO this cell, based on GID - def parconnect_from_src (self, gid_presyn, nc_dict, postsyn): - # nc_dict keys are: {pos_src, A_weight, A_delay, lamtha} - nc = self.pc.gid_connect(gid_presyn, postsyn) - # calculate distance between cell positions with pardistance() - d = self.__pardistance(nc_dict['pos_src']) - # set props here - nc.threshold = nc_dict['threshold'] - nc.weight[0] = nc_dict['A_weight'] * np.exp(-(d**2) / (nc_dict['lamtha']**2)) - nc.delay = nc_dict['A_delay'] / (np.exp(-(d**2) / (nc_dict['lamtha']**2))) - # print("parconnect_from_src in cell.py, weight = ",nc.weight[0]) - #fp = open('delays.txt','a'); fp.write(str(d)+' '+str(nc_dict['A_delay'])+' ' +str(nc.delay)+'\n'); fp.close() - #fp = open('weights.txt','a'); fp.write(str(d)+' '+str(nc_dict['A_weight'])+' ' +str(nc.weight[0])+'\n'); fp.close() - #fp = open('prepostty.txt','a'); fp.write(nc_dict['type_src']+' '+self.celltype+'\n'); fp.close() - - return nc - - # pardistance function requires pre position, since it is calculated on POST cell - def __pardistance (self, pos_pre): - dx = self.pos[0] - pos_pre[0] - dy = self.pos[1] - pos_pre[1] - #dz = self.pos[2] - pos_pre[2] - return np.sqrt(dx**2 + dy**2) - - # Define 3D shape of soma -- is needed for gui representation of cell - # DO NOT need to call h.define_shape() explicitly!! 
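As an aside on parconnect_from_src() above: the weight and delay it assigns are pure functions of the planar distance between the two somata, so the scaling can be sketched without NEURON at all. The parameter values below are made up for illustration and are not taken from any HNN .param file.

import numpy as np

def scale_connection(d, A_weight, A_delay, lamtha):
    # weight decays and delay grows with the squared soma-to-soma distance,
    # using the same exp(-d^2 / lamtha^2) factor as parconnect_from_src()
    atten = np.exp(-(d ** 2) / (lamtha ** 2))
    return A_weight * atten, A_delay / atten

# hypothetical values: two cells 2 grid units apart, space constant of 3
weight, delay = scale_connection(d=2.0, A_weight=5e-4, A_delay=1.0, lamtha=3.0)
print('weight = %.2e, delay = %.2f ms' % (weight, delay))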
- def shape_soma (self): - h.pt3dclear(sec=self.soma) - # h.ptdadd(x, y, z, diam) -- if this function is run, clobbers - # self.soma.diam set above - h.pt3dadd(0, 0, 0, self.diam, sec=self.soma) - h.pt3dadd(0, self.L, 0, self.diam, sec=self.soma) - -# Inhibitory cell class -class BasketSingle (Cell): - def __init__ (self, gid, pos, cell_name='Basket'): - self.props = self.__set_props(cell_name, pos) - Cell.__init__(self, gid, self.props) - # store cell name for later - self.name = cell_name - # set 3D shape - unused for now but a prototype - self.__shape_change() - - def __set_props (self, cell_name, pos): - return { - 'pos': pos, - 'L': 39., - 'diam': 20., - 'cm': 0.85, - 'Ra': 200., - 'name': cell_name, - } - - # Define 3D shape and position of cell. By default neuron uses xy plane for - # height and xz plane for depth. This is opposite for model as a whole, but - # convention is followed in this function ease use of gui. - def __shape_change (self): - self.shape_soma() - """ - s = self.soma - for i in range(int(s.n3d())): - h.pt3dchange(i, self.pos[0]*100 + s.x3d(i), -self.pos[2] + s.y3d(i), - self.pos[1] * 100 + s.z3d(i), s.diam3d(i), sec=s) - """ - -# General Pyramidal cell class -class Pyr (Cell): - def __init__ (self, gid, soma_props): - Cell.__init__(self, gid, soma_props) - # store cell_name as self variable for later use - self.name = soma_props['name'] - # preallocate dict to store dends - self.dends = {} - # for legacy use with L5Pyr - self.list_dend = [] - - # Create dictionary of section names with entries to scale section lengths to length along z-axis - def get_sectnames (self): - seclist = h.SectionList() - seclist.wholetree(sec=self.soma) - d = dict((sect.name(), 1.) for sect in seclist) - for key in d.keys(): - # basal_2 and basal_3 at 45 degree angle to z-axis. - if 'basal_2' in key: - d[key] = np.sqrt(2) / 2. - elif 'basal_3' in key: - d[key] = np.sqrt(2) / 2. - # apical_oblique at 90 perpendicular to z-axis - elif 'apical_oblique' in key: - d[key] = 0. - # All basalar dendrites extend along negative z-axis - if 'basal' in key: - d[key] = -d[key] - return d - - def create_dends (self, p_dend_props): - for key in p_dend_props: self.dends[key] = h.Section(name=self.name+'_'+key) # create dend - # apical: 0--4; basal: 5--7 - self.list_dend = [self.dends[key] for key in ['apical_trunk', 'apical_oblique', 'apical_1', 'apical_2', 'apical_tuft', 'basal_1', 'basal_2', 'basal_3'] if key in self.dends] - - def set_dend_props (self, p_dend_props): - # iterate over keys in p_dend_props. Create dend for each key. - for key in p_dend_props: - # set dend props - self.dends[key].L = p_dend_props[key]['L'] - self.dends[key].diam = p_dend_props[key]['diam'] - self.dends[key].Ra = p_dend_props[key]['Ra'] - self.dends[key].cm = p_dend_props[key]['cm'] - # set dend nseg - if p_dend_props[key]['L'] > 100.: - self.dends[key].nseg = int(p_dend_props[key]['L'] / 50.) - # make dend.nseg odd for all sections - if not self.dends[key].nseg % 2: - self.dends[key].nseg += 1 - - # Creates dendritic sections based only on dictionary of dendrite props - def create_dends_new (self, p_dend_props): - # iterate over keys in p_dend_props. Create dend for each key. 
- for key in p_dend_props: - # create dend - self.dends[key] = h.Section(name=self.name+'_'+key) - - # set dend props - self.dends[key].L = p_dend_props[key]['L'] - self.dends[key].diam = p_dend_props[key]['diam'] - self.dends[key].Ra = p_dend_props[key]['Ra'] - self.dends[key].cm = p_dend_props[key]['cm'] - - # set dend nseg - if p_dend_props[key]['L'] > 100.: - self.dends[key].nseg = int(p_dend_props[key]['L'] / 50.) - - # make dend.nseg odd for all sections - if not self.dends[key].nseg % 2: - self.dends[key].nseg += 1 - - # apical: 0--4 - # basal: 5--7 - self.list_dend = [self.dends[key] for key in ['apical_trunk', 'apical_oblique', 'apical_1', 'apical_2', 'apical_tuft', 'basal_1', 'basal_2', 'basal_3'] if key in self.dends] - - - def get_sections (self): - ls = [self.soma] - for key in ['apical_trunk', 'apical_1', 'apical_2', 'apical_tuft', 'apical_oblique', 'basal_1', 'basal_2', 'basal_3']: - if key in self.dends: - ls.append(self.dends[key]) - return ls - - def get_section_names (self): - ls = ['soma'] - for key in ['apical_trunk', 'apical_1', 'apical_2', 'apical_tuft', 'apical_oblique', 'basal_1', 'basal_2', 'basal_3']: - if key in self.dends: - ls.append(key) - return ls diff --git a/cfg.py b/cfg.py deleted file mode 100644 index ea5148d4b..000000000 --- a/cfg.py +++ /dev/null @@ -1,76 +0,0 @@ -# cfg.py - Simulation configuration -from netpyne import specs - -cfg = specs.SimConfig() - -cfg.checkErrors = False # True # leave as False to avoid extra printouts - - -############################################################################### -# -# SIMULATION CONFIGURATION -# -############################################################################### - -############################################################################### -# Run parameters -############################################################################### -cfg.duration = 1.0*1e3 -cfg.dt = 0.05 -cfg.seeds = {'conn': 4321, 'stim': 1234, 'loc': 4321} -cfg.hParams = {'celsius': 34, 'v_init': -80} -cfg.verbose = 0 -cfg.cvode_active = False -cfg.printRunTime = 0.1 -cfg.printPopAvgRates = True - - -############################################################################### -# Recording -############################################################################### -cfg.recordTraces = {'V_soma': {'sec': 'soma', 'loc': 0.5, 'var': 'v'}} -cfg.recordStims = False -cfg.recordStep = 0.1 - - -############################################################################### -# Saving -############################################################################### -cfg.simLabel = 'sim1' -cfg.saveFolder = 'data' -cfg.savePickle = False -cfg.saveJson = True -cfg.saveDataInclude = ['simData', 'simConfig', 'netParams', 'net'] - - -############################################################################### -# Analysis and plotting -############################################################################### -cfg.analysis['plotTraces'] = {'include': ['L2Pyr','L5Pyr'], 'oneFigPer': 'cell', 'saveFig': True, - 'showFig': False, 'figSize': (10,8), 'timeRange': [0,cfg.duration]} - - -############################################################################### -# Parameters -############################################################################### -cfg.dendNa = 0.0345117294903 -cfg.tau1NMDA = 15 - - -############################################################################### -# Current inputs -############################################################################### -cfg.addIClamp = 0 - -cfg.IClamp1 = 
{'pop': 'PT5B', 'sec': 'soma', 'loc': 0.5, 'start': 1, 'dur': 1000, 'amp': 0.0} - - -############################################################################### -# NetStim inputs -############################################################################### -cfg.addNetStim = 0 - -cfg.NetStim1 = {'pop': 'PT5B', 'sec': 'soma', 'loc': 0.5, 'synMech': 'NMDA', 'start': 500, - 'interval': 1000, 'noise': 0.0, 'number': 1, 'weight': 0.0, 'delay': 1} - - diff --git a/cli.py b/cli.py deleted file mode 100644 index 94dbb2e7f..000000000 --- a/cli.py +++ /dev/null @@ -1,1064 +0,0 @@ -# cli.py - routines for the command line interface console s1sh.py -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: return_data_dir()) -# last major: (SL: reorganized show and pngv) - -from cmd import Cmd -from datetime import datetime -import ast, multiprocessing, os, signal, subprocess, time -import readline as rl -import itertools as it - -import clidefs -import fileio as fio -import paramrw -import specfn -import dipolefn -import ppsth -import praw - -class Console(Cmd): - def __init__(self, file_input=""): - Cmd.__init__(self) - self.prompt = '\033[93m' + "[s1] " + '\033[0m' - self.intro = "\nThis is the SomatoSensory SHell\n" - self.dproj = fio.return_data_dir() - self.server_default = self.__check_server() - self.f_history = '.s1sh_hist_local' - self.ddate = '' - self.dlast = [] - self.dlist = [] - self.dsim = [] - self.expmts = [] - self.sim_list = [] - self.param_list = [] - self.var_list = [] - self.N_sims = 0 - - # check to see if file_input is legit - if os.path.isfile(file_input): - self.file_input = file_input - - else: - # use a default - self.file_input = 'param/debug.param' - - # get initial count of avail processors for subprocess/multiprocessing routines - self.nprocs = multiprocessing.cpu_count() - - # Create the initial datelist - self.datelist = clidefs.get_subdir_list(self.dproj) - - # create the initial paramfile list - self.__get_paramfile_list() - - # set the date, grabs a dlist - self.do_setdate(datetime.now().strftime("%Y-%m-%d")) - - # splits argstring in format of --opt0=val0 --opt1=val1 - def __split_args(self, args): - # split based on leading -- - args_tmp = args.split(' --') - - # only take the args that start with -- and include a = - # drop the leading -- - # args_opt = [arg[2:] for arg in args_tmp if arg.startswith('--')] - args_opt = [arg for arg in args_tmp if '=' in arg] - arg_list = [] - for arg in args_opt: - # getting rid of first case, ugh, hack! - if arg.startswith('--'): - arg_list.append(arg[2:].split('=')) - else: - arg_list.append(arg.split('=')) - - return arg_list - - def __create_dict_from_args(self, args): - # split based on leading -- - args_tmp = args.split(' --') - - # only take the args that start with -- and include a = - # drop the leading -- - # args_opt = [arg[2:] for arg in args_tmp if arg.startswith('--')] - args_opt = [arg for arg in args_tmp if '=' in arg] - arg_dict = {} - for arg in args_opt: - # getting rid of first case, ugh, hack! 
- if arg.startswith('--'): - args_tmp = arg[2:].split('=') - arg_dict[args_tmp[0]] = args_tmp[1] - - else: - args_tmp = arg.split('=') - arg_dict[args_tmp[0]] = args_tmp[1] - - return arg_dict - - # generalized function for checking and assigning args - def __check_args(self, dict_opts, list_opts): - # assumes list_opts comes from __split_args() - if len(list_opts): - # iterate through possible key vals in list_opts - keys_missing = [] - for key, val in list_opts: - # check to see if the possible keys are in dict_opts - if key in dict_opts.keys(): - # assign the key/val pair in place - # this operation acts IN PLACE on the supplied dict_opts!! - # therefore, no return value necessary - dict_opts[key] = ast.literal_eval(val) - else: - keys_missing.append(key) - - # if there are any keys missing - if keys_missing: - print "Options were not recognized: " - fio.prettyprint(keys_missing) - - # checks to see if a default server file is found, if not, ask for one - def __check_server(self): - f_server = os.path.join(self.dproj, '.server_default') - - if os.path.isfile(f_server): - # read the file and set the default server - lines_f = fio.clean_lines(f_server) - - # there should only be one thing in this file, so assume that's the server name - return lines_f[0] - else: - return '' - - # create a list of parameter files - def __get_paramfile_list(self): - dparam_default = os.path.join(os.getcwd(), 'param') - self.paramfile_list = [f for f in os.listdir(dparam_default) if f.endswith('.param')] - - def do_debug(self, args): - """Qnd function to test other functions - """ - self.do_setdate('2013-12-04') - self.do_load('ftremor-003') - clidefs.exec_pgamma_spec_fig() - # self.do_pgamma_sub_example2('') - # self.do_setdate('pub') - # self.do_spec_current("--runtype='debug' --f_max=250.") - # self.do_load('2013-06-28_gamma_weak_L5-000') - # self.do_pgamma_hf_epochs('') - # self.do_load('2013-08-12_gamma_sub_50Hz-001') - # self.do_pgamma_spikephase('') - # self.do_pgamma_prox_dist_new('') - # self.do_throwaway('--n_trial=-1') - # self.do_load('2013-08-07_gamma_sub_50Hz_stdev-000') - # self.do_pgamma_stdev_new('--f_max_welch=150.') - # self.do_load('2013-06-28_gamma_sub_f-000') - # self.do_pgamma_stdev_new('--f_max_welch=150.') - # self.do_load('2013-07-15_gamma_L5weak_L2weak-000') - # self.do_pgamma_laminar('') - # self.do_pgamma_compare_ping('') - # self.do_show_dpl_max('') - # self.do_pgamma_peaks('') - # self.do_welch_max('') - # self.do_pgamma_sub_examples('') - # self.do_pgamma_distal_phase('--spec0=5 --spec1=9 --spec2=15') - # self.do_specmax('--expmt_group="weak" --f_interval=[50., 75] --t_interval=[50., 550.]') - # self.do_spike_rates('') - # self.do_save('') - # self.do_calc_dpl_regression('') - # self.do_calc_dpl_mean("--t0=100. --tstop=1000. --layer='L2'") - # self.do_praw('') - # self.do_pdipole('grid') - # self.do_pngv('') - # self.do_show('testing in (0, 0)') - # self.do_calc_dipole_avg('') - # self.do_pdipole('evaligned') - # self.do_avgtrials('dpl') - # self.do_replot('') - # self.do_spec_regenerate('--f_max=50.') - # self.do_addalphahist('--xmin=0 --xmax=500') - # self.do_avgtrials('dpl') - # self.do_dipolemin('in (mu, 0, 2) on [400., 410.]') - # self.epscompress('spk') - # self.do_psthgrid() - - def do_throwaway(self, args): - ''' This is a throwaway dipole saving function. Usage: - [s1] throwaway {--n_sim=12} {--n_trial=3} - ''' - dict_opts = self.__create_dict_from_args(args) - - # run the throwaway save function! 
- clidefs.exec_throwaway(self.ddata, dict_opts) - # clidefs.exec_throwaway(self.ddata, opts['n_sim'], opts['n_trial']) - - def do_spike_rates(self, args): - opts = { - 'expmt_group': 'weak', - 'celltype': 'L5_pyramidal', - } - l_opts = self.__split_args(args) - self.__check_args(opts, l_opts) - - clidefs.exec_spike_rates(self.ddata, opts) - - def do_calc_dpl_mean(self, args): - '''Returns the mean dipole to screen. Usage: - [s1] calc_dpl_mean - ''' - opts = { - 't0': 50., - 'tstop': -1, - 'layer': 'agg', - } - - l_opts = self.__split_args(args) - self.__check_args(opts, l_opts) - - # run the function - clidefs.exec_calc_dpl_mean(self.ddata, opts) - - def do_calc_dpl_regression(self, args): - clidefs.exec_calc_dpl_regression(self.ddata) - - def do_show_dpl_max(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_show_dpl_max(self.ddata, dict_opts) - - def do_pgamma_peaks(self, args): - clidefs.exec_pgamma_peaks() - - def do_pgamma_sub_examples(self, args): - clidefs.exec_pgamma_sub_examples() - - def do_pgamma_sub_example2(self, args): - clidefs.exec_pgamma_sub_example2() - - def do_pgamma_hf(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_pgamma_hf(self.ddata, dict_opts) - - def do_pgamma_hf_epochs(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_pgamma_hf_epochs(self.ddata, dict_opts) - - def do_pgamma_distal_phase(self, args): - '''Generates gamma fig for distal phase. Requires spec data for layers to exist. Usage: - [s1] spec_current - [s1] pgamma_distal_phase {--spec0=0 --spec1=1, --spec2=2} - ''' - - opts = { - 'spec0': 0, - 'spec1': 1, - 'spec2': 1, - } - - l_opts = self.__split_args(args) - self.__check_args(opts, l_opts) - - clidefs.exec_pgamma_distal_phase(self.ddata, opts) - - def do_pgamma_stdev(self, args): - '''Generates gamma fig for standard deviation. Requires spec data for layers to exist. Usage: - [s1] spec_current - [s1] pgamma_stdev - ''' - clidefs.exec_pgamma_stdev(self.ddata) - - def do_pgamma_stdev_new(self, args): - '''Generates gamma fig for standard deviation. Requires spec data for layers to exist. Usage: - [s1] spec_current - [s1] pgamma_stdev_new - ''' - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_pgamma_stdev_new(self.ddata, dict_opts) - - def do_pgamma_prox_dist_new(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_pgamma_prox_dist_new(self.ddata, dict_opts) - - def do_pgamma_laminar(self, args): - '''Generates a comparison figure of aggregate, L2 and L5 data. Taken from 1 data set. Usage: - [s1] pgamma_laminar - ''' - clidefs.exec_pgamma_laminar(self.ddata) - - def do_pgamma_compare_ping(self, args): - '''Generates gamma fig for comparison of PING and weak PING. Will need 2 data sets! Usage: - [s1] pgamma_compare_ping - ''' - clidefs.exec_pgamma_compare_ping() - - def do_pgamma_spikephase(self, args): - clidefs.exec_pgamma_spikephase() - - def do_spec_current(self, args): - # parse list of opts - dict_opts = self.__create_dict_from_args(args) - - # actually run the analysis - self.spec_current_tmp = clidefs.exec_spec_current(self.ddata, dict_opts) - - def do_praw(self, args): - '''praw is a fully automated function to replace the dipole plots with aggregate dipole/spec/spikes plots. 
Usage: - [s1] praw - ''' - praw.praw(self.ddata) - - # update the dlist - def __update_dlist(self): - if os.path.exists(self.ddate): - self.dlist = [d for d in os.listdir(self.ddate) if os.path.isdir(os.path.join(self.ddate, d))] - - def do_setdate(self, args): - """Sets the date string to the specified date - """ - if args: - if args == 'today': - dcheck = os.path.join(self.dproj, datetime.now().strftime("%Y-%m-%d")) - else: - dcheck = os.path.join(self.dproj, args) - - if os.path.exists(dcheck): - self.ddate = dcheck - - else: - self.ddate = os.path.join(self.dproj, 'pub') - - self.__update_dlist() - - print "Date set to", self.ddate - - def complete_setdate(self, text, line, j0, J): - """complete function for setdate - """ - if text: - print text - x = [item for item in self.datelist if item.startswith(text)] - if x: - return x - else: - return self.datelist - - def do_load(self, args): - """Load parameter file and regens all vars - Date needs to be set correctly for this to work. See 'help setdate' - Usage example: - [s1sh] setdate 2013-01-01 - [s1sh] load mucomplex-000 - - Running without arguments will load the last modified directory found in the date dir: - [s1] load - """ - if not args: - # attempt to load the most recent in the dproj/ddate - # find the most recent directory in this folder - list_d = [] - - for dsim_short in os.listdir(self.ddate): - # check to see if dsim_tmp is actually a dir - dsim_tmp = os.path.join(self.ddate, dsim_short) - - # append to list along with its modified time (mtime) - if os.path.isdir(dsim_tmp): - list_d.append((dsim_tmp, time.ctime(os.path.getmtime(dsim_tmp)))) - - # sort by mtime - list_d.sort(key=lambda x: x[1]) - - # grab the directory name of the most recent dir - dcheck = list_d[-1][0] - - else: - # dir_check is the attempt at creating this directory - dcheck = os.path.join(self.dproj, self.ddate, args) - - # check existence of the path - if os.path.exists(dcheck): - # create blank ddata structure from SimPaths - self.ddata = fio.SimulationPaths() - - # set dsim after using ddata's readsim method - self.dsim = self.ddata.read_sim(self.dproj, dcheck) - self.p_exp = paramrw.ExpParams(self.ddata.fparam) - print self.ddata.fparam - self.var_list = paramrw.changed_vars(self.ddata.fparam) - - else: - print dcheck - print "Could not find that dir, maybe check your date?" - - def complete_load(self, text, line, j0, J): - """complete function for load - """ - if text: - return [item for item in self.dlist if item.startswith(text)] - - else: - return self.dlist - - def do_sync(self, args): - """Sync with specified remote server. If 'exclude' is unspecified, by default will use the exclude_eps.txt file in the data dir. If exclude is specified, it will look in the root data dir. Usage examples: - [s1] sync 2013-03-25 - [s1] sync 2013-03-25 --exclude=somefile.txt - """ - try: - fshort_exclude = '' - list_args = args.split('--') - - # expect first arg to be the dsubdir - dsubdir = list_args.pop(0) - - for arg in list_args: - if arg: - opt, val = arg.split('=') - - if opt == 'exclude': - fshort_exclude = val - - if not self.server_default: - server_remote = raw_input("Server address: ") - else: - server_remote = self.server_default - print "Attempting to use default server ..." 
- - # run the command - if fshort_exclude: - clidefs.exec_sync(self.dproj, server_remote, dsubdir, fshort_exclude) - else: - clidefs.exec_sync(self.dproj, server_remote, dsubdir) - - # path - newdir = os.path.join('from_remote', dsubdir) - self.do_setdate(newdir) - - except: - print "Something went wrong here." - - def do_giddict(self, args): - pass - - def do_welch_max(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_welch_max(self.ddata, dict_opts) - - def do_specmax(self, args): - """Find the max spectral power, report value and time. - Usage: specmax {--expmt_group=0 --simrun=0 --trial=0 --t_interval=[0, 1000] --f_interval=[0, 100.]} - """ - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_specmax(self.ddata, dict_opts) - - def do_specmax_dpl_match(self, args): - """Plots dpl around max spectral power over specified time and frequency intervals - usage: specmax_dpl_match --t_interval=[0, 1000] --f_interval=[0, 100] --f_sorted=[0, 100] - """ - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_specmax_dpl_match(self.ddata, dict_opts) - - def do_specmax_dpl_tmpl(self, args): - """Isolates dpl waveforms producing specified spectral frequencies - across trails and averages them to produce a stereotypical waveform - Usage: specmax_dpl_tmpl --expmt_group --n_sim --trials --t_interval - --f_interval --f_sort - """ - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_specmax_dpl_tmpl(self.ddata, dict_opts) - - def do_plot_dpl_tmpl(self, args): - """Plots stereotypical waveforms produced by do_specmax_dpl_tmpl - usage: plot_dpl_tmpl --expmt_group - """ - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_plot_dpl_tmpl(self.ddata, dict_opts) - - def do_dipolemin(self, args): - """Find the minimum of a particular dipole - Usage: dipolemin in (, , ) on [interval] - """ - # look for first keyword - if args.startswith("in"): - try: - # split by 'in' to get the interval - s = args.split(" on ") - - # values are then in first part of s - # yeah, this is gross, sorry. just parsing between parens for params - expmt_group, n_sim_str, n_trial_str = s[0][s[0].find("(")+1:s[0].find(")")].split(", ") - n_sim = int(n_sim_str) - n_trial = int(n_trial_str) - - t_interval = ast.literal_eval(s[-1]) - clidefs.exec_dipolemin(self.ddata, expmt_group, n_sim, n_trial, t_interval) - - except ValueError: - self.do_help('dipolemin') - - else: - self.do_help('dipolemin') - - def do_file(self, args): - """Attempts to open a new file of params - """ - if not args: - print self.file_input - elif os.path.isfile(args): - self.file_input = args - print "New file is:", self.file_input - else: - # try searching specifcally in param dir - f_tmp = os.path.join('param', args) - if os.path.isfile(f_tmp): - self.file_input = f_tmp - else: - print "Does not appear to exist" - return 0 - - # tab complete rules for file - def complete_file(self, text, line, j0, J): - return [item for item in self.paramfile_list if item.startswith(text)] - - def do_diff(self, args): - """Runs a diff on various data types - """ - pass - - def do_testls(self, args): - # file_list = fio.file_match('../param', '*.param') - print "dlist is:", self.dlist - print "datelist is:", self.datelist - print "expmts is:", self.expmts - - def do_expmts(self, args): - """Show list of experiments for active directory. - """ - try: - clidefs.prettyprint(self.ddata.expmt_groups) - except AttributeError: - self.do_help('expmts') - print "No active directory?" 
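Most of the do_* commands removed above hand their raw argument string to private helpers (__split_args, __check_args, __create_dict_from_args) whose definitions sit outside this hunk; the shared idea is to split on '--', coerce values with ast.literal_eval, and fall back to the raw string, much like args_check() in the deleted clidefs.py further below. The following is only a minimal sketch of that pattern under those assumptions — parse_cli_opts and its defaults argument are invented names, not part of the removed code.

import ast

def parse_cli_opts(args, defaults):
    """Turn a '--key=value --other=[0, 100]' string into an options dict.

    Sketch only: approximates the deleted __split_args/__check_args helpers;
    unrecognized keys are reported rather than silently added.
    """
    opts = dict(defaults)
    for chunk in args.split('--'):
        chunk = chunk.strip()
        if not chunk:
            continue
        key, _, val = chunk.partition('=')
        key = key.strip()
        if key not in opts:
            print("Option not recognized: %s" % key)
            continue
        try:
            # literal_eval coerces "50." to 50.0 and "[0, 100]" to a list
            opts[key] = ast.literal_eval(val.strip())
        except (ValueError, SyntaxError):
            # keep plain strings (e.g. --layer=agg) as-is
            opts[key] = val.strip()
    return opts

# e.g. parse_cli_opts("--t0=50. --layer=agg", {'t0': 0., 'tstop': -1, 'layer': 'agg'})

The ValueError fallback is what lets a single parsing path accept both numeric intervals and bare string options, which is the behaviour args_check() in clidefs.py implements for the exec_* functions.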
- - def do_vars(self, args): - """Show changed variables in loaded simulation and their values. vars comes from p_exp. Usage: - [s1] vars - """ - print "\nVars changed in this simulation:" - - # iterate through params and print them raw - for var in self.var_list: - print " %s: %s" % (var[0], var[1]) - - # also print experimental groups - print "\nExperimental groups:" - self.do_expmts('') - - # cheap newline - print "" - - # this is an old function obsolete for this project - def do_view(self, args): - """Views the changes in the .params file. Use like 'load' - but does not commit variables to workspace - """ - dcheck = os.path.join(self.dproj, self.ddate, args) - - if os.path.exists(dcheck): - # get a list of the .params files - sim_list = fio.gen_sim_list(dcheck) - expmts = gen_expmts(sim_list[0]) - var_list = changed_vars(sim_list) - - clidefs.prettyprint(sim_list) - clidefs.prettyprint(expmts) - for var in var_list: - print var[0]+":", var[1] - - def complete_view(self, text, line, j0, J): - """complete function for view - """ - if text: - x = [item for item in self.dlist if item.startswith(text)] - if x: - return x - else: - return 0 - else: - return self.dlist - - def do_list(self, args): - """Lists simulations on a given date - 'args' is a date - """ - if not args: - dcheck = os.path.join(self.dproj, self.ddate) - - else: - dcheck = os.path.join(self.dproj, args) - - if os.path.exists(dcheck): - self.__update_dlist() - - # dir_list = [name for name in os.listdir(dcheck) if os.path.isdir(os.path.join(dcheck, name))] - clidefs.prettyprint(self.dlist) - - else: - print "Cannot find directory" - return 0 - - def do_pngoptimize(self, args): - """Optimizes png figures based on current directory - """ - fio.pngoptimize(self.simpaths.dsim) - - def do_avgtrials(self, args): - """Averages raw data over all trials for each simulation. - Usage: - [s1] avgtrials - where is either dpl or spec - """ - if not args: - print "You did not specify whether to avgerage dpl or spec data. Try again." - - else: - datatype = args - clidefs.exec_avgtrials(self.ddata, datatype) - - def do_spec_regenerate(self, args): - """Regenerates spec data and saves it to proper expmt directories. Usage: - [s1] spec_regenerate {--f_max=80.} - """ - - # use __split_args() - l_opts = self.__split_args(args) - - # these are the opts for which we are looking - opts = { - 'f_max': None, - } - - # parse the opts - self.__check_args(opts, l_opts) - - # use exec_spec_regenerate to regenerate spec data - clidefs.exec_spec_regenerate(self.ddata, opts['f_max']) - # self.spec_results = clidefs.exec_spec_regenerate(self.ddata, opts['f_max']) - - def do_spec_stationary_avg(self, args): - """Averages spec power over time and plots freq vs power. Fn can act per expmt or over entire simulation. If maxpwr supplied as arg, also plots freq at which max avg pwr occurs v.s input freq - """ - if args == 'maxpwr': - clidefs.exec_spec_stationary_avg(self.ddata, self.dsim, maxpwr=1) - - else: - clidefs.exec_spec_stationary_avg(self.ddata, self.dsim, maxpwr=0) - - def do_spec_avg_stationary_avg(self, args): - """Performs time-averaged stationarity analysis on avg'ed spec data. - Sorry for the terrible name... 
- """ - # parse args - l_opts = self.__split_args(args) - - # "default" opts - opts = { - 'errorbars': None - } - - # parse opts - self.__check_args(opts, l_opts) - - clidefs.exec_spec_avg_stationary_avg(self.ddata, self.dsim, opts) - - def do_freqpwrwithhist(self, args): - clidefs.freqpwr_with_hist(self.ddata, self.dsim) - - def do_calc_dipole_avg(self, args): - """Calculates average dipole using dipolefn.calc_avgdpl_stimevoked: - Usage: [s1] calc_dipole_avg - """ - dipolefn.calc_avgdpl_stimevoked(self.ddata) - - def do_pdipole(self, args): - """Regenerates plots in given directory. Usage: - To run on current working directory and regenerate each individual plot: 'pdipole' - To run aggregates for all simulations (across all trials/conditions) in a directory: 'pdipole exp' - To run aggregates with lines marking evoked times, run: 'pdipole evoked' - """ - # temporary arg split - arg_tmp = args.split(' ') - - # list of acceptable runtypes - runtype_list = [ - 'exp', - 'exp2', - 'evoked', - 'evaligned', - 'avg', - 'grid', - ] - - # minimal checks in this function - # assume that no ylim argument was specified - if len(arg_tmp) == 1: - runtype = arg_tmp[0] - ylim = [] - - else: - # set the runtype to the first - if arg_tmp[0] in runtype_list: - runtype = arg_tmp[0] - - # get the list of optional args - arg_list = self.__split_args(args) - - # default values for various params - # i_ctrl = 0 - for opt, val in arg_list: - # currently not being used - if opt == 'i_ctrl': - i_ctrl = int(val) - - # assume the first arg is correct, split on that - # arg_ylim_tmp = args.split(runtype) - - # if len(arg_ylim_tmp) == 2: - # ylim_read = ast.literal_eval(arg_ylim_tmp[-1].strip()) - # ylim = ylim_read - - # else: - # ylim = [] - - if runtype == 'exp': - # run the avg dipole per experiment (across all trials/simulations) - # using simpaths (ddata) - dipolefn.pdipole_exp(self.ddata, ylim) - - elif runtype == 'exp2': - dipolefn.pdipole_exp2(self.ddata) - # dipolefn.pdipole_exp2(self.ddata, i_ctrl) - - elif runtype == 'evoked': - # add the evoked lines to the pdipole individual simulations - clidefs.exec_pdipole_evoked(self.ddata, ylim) - - elif runtype == 'evaligned': - dipolefn.pdipole_evoked_aligned(self.ddata) - - elif runtype == 'avg': - # plot average over all TRIALS of a param regime - # requires that avg dipole data exist - clidefs.exec_plotaverages(self.ddata, ylim) - - elif runtype == 'grid': - dipolefn.pdipole_grid(self.ddata) - - def do_replot(self, args): - """Regenerates plots in given directory. Usage: - Usage: replot --xlim=[0, 1000] --ylim=[0, 100] - xlim is a time interval - ylim is a frequency interval - """ - # preallocate variables so they always exist - # xmin = 0. - # xmax = 'tstop' - - # # Parse args if they exist - # if args: - # arg_list = [arg for arg in args.split('--') if arg is not ''] - - # # Assign value to above variables if the value exists as input - # for arg in arg_list: - # if arg.startswith('xmin'): - # xmin = float(arg.split('=')[-1]) - - # elif arg.startswith('xmax'): - # xmax = float(arg.split('=')[-1]) - - # else: - # print "Did not recognize argument %s. Not doing anything with it" % arg - - # # Check to ensure xmin less than xmax - # if xmin and xmax: - # if xmin > xmax: - # print "xmin greater than xmax. Defaulting to sim parameters" - # xmin = 0. 
- # xmax = 'tstop' - - dict_opts = self.__create_dict_from_args(args) - - # check for spec data, create it if didn't exist, and then run the plots - clidefs.exec_replot(self.ddata, dict_opts) - # clidefs.regenerate_plots(self.ddata, [xmin, xmax]) - - def do_addalphahist(self, args): - """Adds histogram of alpha feed input times to dpl and spec plots. Usage: - [s1] addalphahist {--xlim=[0, 1000] --ylim=[0, 100]} - xlim is a time interval - ylim is a frequency interval - """ - # # preallocate variables so they always exist - # xmin = 0. - # xmax = 'tstop' - - # # Parse args if they exist - # if args: - # arg_list = [arg for arg in args.split('--') if arg is not ''] - - # # Assign value to above variables if the value exists as input - # for arg in arg_list: - # if arg.startswith('xmin'): - # xmin = float(arg.split('=')[-1]) - - # elif arg.startswith('xmax'): - # xmax = float(arg.split('=')[-1]) - - # else: - # print "Did not recognize argument %s. Not doing anything with it" %arg - - # # Check to ensure xmin less than xmax - # if xmin and xmax: - # if xmin > xmax: - # print "xmin greater than xmax. Defaulting to sim parameters" - # xmin = 0. - # xmax = 'tstop' - - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_addalphahist(self.ddata, dict_opts) - # clidefs.exec_addalphahist(self.ddata, [xmin, xmax]) - - def do_aggregatespec(self, args): - """Creates aggregates all spec data with histograms into one massive fig. - Must supply column label and row label as --row_label:param --column_label:param" - row_label should be param that changes only over experiments - column_label should be a param that changes trial to trial - """ - arg_list = [arg for arg in args.split('--') if arg is not ''] - - # Parse args - for arg in arg_list: - if arg.startswith('row'): - row_label = arg.split(':')[-1] - - # See if a list is being passed in - if row_label.startswith('['): - row_label = arg.split('[')[-1].split(']')[0].split(', ') - - else: - row_label = arg.split(':')[-1].split(' ')[0] - - elif arg.startswith('column'): - column_label = arg.split(':')[-1].split(' ')[0] - - else: - print "Did not recongnize argument. Going to break now." - - clidefs.exec_aggregatespec(self.ddata, [row_label, column_label]) - - def do_plotaverages(self, args): - """Creates plots of averaged dipole or spec data. Automatically checks if data exists. Usage: - 'plotaverages' - """ - - clidefs.exec_plotaverages(self.ddata) - - def do_phaselock(self, args): - """Calculates phaselock values between dipole and inputs - """ - args_dict = self.__create_dict_from_args(args) - clidefs.exec_phaselock(self.ddata, args_dict) - - def do_epscompress(self, args): - """Runs the eps compress utils on the specified fig type (currently either spk or spec) - """ - for expmt_group in self.ddata.expmt_groups: - if args == 'figspk': - d_eps = self.ddata.dfig[expmt_group]['figspk'] - elif args == 'figspec': - d_eps = self.ddata.dfig[expmt_group]['figspec'] - - try: - fio.epscompress(d_eps, '.eps') - except UnboundLocalError: - print "oy, this is embarrassing." 
- - def do_psthgrid(self, args): - """Aggregate plot of psth - """ - ppsth.ppsth_grid(self.simpaths) - - # save currently fails when no dir is loaded - def do_save(self, args): - """Copies the entire current directory over to the cppub directory - """ - clidefs.exec_save(self.dproj, self.ddate, self.dsim) - - # currently doesn't work with mpi interfacing - # def do_runsim(self, args): - # """Run the simulation code - # """ - # try: - # cmd_list = [] - # cmd_list.append('mpiexec -n %i ./s1run.py %s' % (self.nprocs, self.file_input)) - - # for cmd in cmd_list: - # subprocess.call(cmd, shell=True) - - # except (KeyboardInterrupt): - # print "Caught a break" - - def do_hist(self, args): - """Print a list of commands that have been entered""" - print self._hist - - def do_pwd(self, args): - """Displays active dir_data""" - print self.dsim - - def do_ls(self, args): - """Displays active param list""" - clidefs.prettyprint(self.param_list) - - def do_show(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_show(self.ddata, dict_opts) - - def complete_show(self, text, line, j0, J): - """Completion function for show - """ - if text: - return [expmt for expmt in self.expmts if expmt.startswith(text)] - else: - return self.expmts - - def do_showf(self, args): - """Show frequency information from rate files - """ - vars = args.split(' in ') - expmt = vars[0] - n = int(vars[1]) - - if n < self.N_sims: - drates = os.path.join(self.dsim, expmt, 'rates') - ratefile_list = fio.file_match(drates, '*.rates') - - with open(ratefile_list[n]) as frates: - lines = (line.rstrip() for line in frates) - lines = [line for line in lines if line] - - clidefs.prettyprint(lines) - else: - print "In do_showf in cli: out of range?" - return 0 - - def complete_showf(self, text, line, j0, J): - """Completion function for showf - """ - if text: - return [expmt for expmt in self.expmts if expmt.startswith(text)] - else: - return self.expmts - - def do_Nsims(self, args): - """Show number of simulations in each 'experiment' - """ - print self.N_sims - - def do_pngv(self, args): - dict_opts = self.__create_dict_from_args(args) - clidefs.exec_pngv(self.ddata, dict_opts) - - def complete_pngv(self, text, line, j0, J): - if text: - return [expmt for expmt in self.expmts if expmt.startswith(text)] - else: - return self.expmts - - ## Command definitions to support Cmd object functionality ## - def do_exit(self, args): - """Exits from the console - """ - return -1 - - def do_EOF(self, args): - """Exit on system end of file character - """ - return self.do_exit(args) - - def do_shell(self, args): - """Pass command to a system shell when line begins with '!' - """ - os.system(args) - - def do_help(self, args): - """Get help on commands - 'help' or '?' with no arguments prints a list of commands for which help is available - 'help ' or '? ' gives help on - """ - ## The only reason to define this method is for the help text in the doc string - Cmd.do_help(self, args) - - ## Override methods in Cmd object ## - def preloop(self): - """Initialization before prompting user for commands. - Despite the claims in the Cmd documentaion, Cmd.preloop() is not a stub. - """ - Cmd.preloop(self) ## sets up command completion - self._hist = self.load_history() - self._locals = {} ## Initialize execution namespace for user - self._globals = {} - - def postloop(self): - """Take care of any unfinished business. - Despite the claims in the Cmd documentaion, Cmd.postloop() is not a stub. 
- """ - self.write_history() - Cmd.postloop(self) ## Clean up command completion - print "Exiting..." - - def precmd(self, line): - """ This method is called after the line has been input but before - it has been interpreted. If you want to modify the input line - before execution (for example, variable substitution) do it here. - """ - self._hist += [ line.strip() ] - return line - - def postcmd(self, stop, line): - """If you want to stop the console, return something that evaluates to true. - If you want to do some post command processing, do it here. - """ - return stop - - def emptyline(self): - """Do nothing on empty input line""" - pass - - def default(self, line): - """Called on an input line when the command prefix is not recognized. - In that case we execute the line as Python code. - """ - try: - exec(line) in self._locals, self._globals - except Exception, e: - print e.__class__, ":", e - - # Function to read the history file - def load_history(self): - with open(self.f_history) as f_in: - lines = (line.rstrip() for line in f_in) - lines = [line for line in lines if line] - - return lines - - def history_remove_dupes(self): - unique_set = set() - return [x for x in self._hist if x not in unique_set and not unique_set.add(x)] - - # function to write the history file - def write_history(self): - # first we will clean the list of dupes - unique_history = self.history_remove_dupes() - with open(self.f_history, 'w') as f_out: - for line in unique_history[-100:]: - f_out.write(line+'\n') diff --git a/clidefs.py b/clidefs.py deleted file mode 100644 index 71a1c23f9..000000000 --- a/clidefs.py +++ /dev/null @@ -1,1499 +0,0 @@ -# clidefs.py - these are all of the function defs for the cli -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: python3 compatibility) -# last major: (SL: minor) - -# Standard modules -import fnmatch, os, re, sys -import numpy as np -from scipy import stats -from multiprocessing import Pool -from subprocess import call -from glob import iglob -from time import time -import ast -import matplotlib.pyplot as plt -import matplotlib as mpl - -# local modules -import spikefn -import plotfn -import fileio as fio -import paramrw -import specfn -import pspec -import dipolefn -import axes_create as ac -import pmanu_gamma as pgamma -import subprocess - -# Returns length of any list -def number_of_sims(some_list): - return len(some_list) - -# Just a simple thing to print parts of a list -def prettyprint(lines): - for line in lines: - print line - -# gets a subdir list -def get_subdir_list(dcheck): - if os.path.exists(dcheck): - return [name for name in os.listdir(dcheck) if os.path.isdir(os.path.join(dcheck, name))] - - else: - return [] - -# generalized function for checking and assigning args -def args_check(dict_default, dict_check): - if len(dict_check): - keys_missing = [] - - # iterate through possible key vals in dict_check - for key, val in dict_check.items(): - # check to see if the possible keys are in dict_default - if key in dict_default.keys(): - # assign the key/val pair in place - # this operation acts IN PLACE on the supplied dict_default!! 
- # therefore, no return value necessary - try: - dict_default[key] = ast.literal_eval(val) - - except ValueError: - dict_default[key] = val - - else: - keys_missing.append(key) - - # if there are any keys missing - if keys_missing: - print "Options were not recognized: " - fio.prettyprint(keys_missing) - -def exec_pngv(ddata, dict_opts={}): - """Attempt to find the PNGs and open them - [aushnew] pngv {--run=0 --expmt_group='testing' --type='fig_spec'} - """ - file_viewer(ddata, dict_opts) - -# returns average spike data -def exec_spike_rates(ddata, opts): - # opts should be: - # opts_default = { - # expmt_group: 'something', - # celltype: 'L5_pyramidal', - # } - expmt_group = opts['expmt_group'] - celltype = opts['celltype'] - - list_f_spk = ddata.file_match(expmt_group, 'rawspk') - list_f_param = ddata.file_match(expmt_group, 'param') - - # note! this is NOT ignoring first 50 ms - for fspk, fparam in zip(list_f_spk, list_f_param): - s_all = spikefn.spikes_from_file(fparam, fspk) - _, p_dict = paramrw.read(fparam) - T = p_dict['tstop'] - - # check if the celltype is in s_all - if celltype in s_all.keys(): - s = s_all[celltype].spike_list - n_cells = len(s) - - # grab all the sp_counts - sp_counts = np.array([len(spikes_cell) for spikes_cell in s]) - - # calc mean and stdev - sp_count_mean = np.mean(sp_counts) - sp_count_stdev = np.std(sp_counts) - - # calc rate in Hz, assume T in ms - sp_rates = sp_counts * 1000. / T - sp_rate_mean = np.mean(sp_rates) - sp_rate_stdev = np.std(sp_rates) - - # direct - sp_rate = sp_count_mean * 1000. / T - - print "Sim No. %i, Trial %i, celltype is %s:" % (p_dict['Sim_No'], p_dict['Trial'], celltype) - print " spike count mean is: %4.3f" % sp_count_mean - print " spike count stdev is: %4.3f" % sp_count_stdev - print " spike rate over %4.3f ms is %4.3f Hz +/- %4.3f" % (T, sp_rate_mean, sp_rate_stdev) - print " spike rate over %4.3f ms is %4.3f Hz" % (T, sp_rate) - -def exec_welch_max(ddata, opts): - p = { - 'f_min': 0., - } - - args_check(p, opts) - - # assume first expmt_group for now - expmt_group = ddata.expmt_groups[0] - - # grab list of dipoles - list_dpl = ddata.file_match(expmt_group, 'rawdpl') - list_param = ddata.file_match(expmt_group, 'param') - - # iterate through dipoles - for fdpl, fparam in zip(list_dpl, list_param): - # grab the dt (needed for the Welch) - dt = paramrw.find_param(fparam, 'dt') - - # grab the dipole - dpl = dipolefn.Dipole(fdpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - # create empty pgram - pgram = dict.fromkeys(dpl.dpl) - pgram_max = dict.fromkeys(dpl.dpl) - - # perform stationary Welch, since we're not saving this data yet - for key in pgram.keys(): - pgram[key] = specfn.Welch(dpl.t, dpl.dpl[key], dt) - - # create a mask based on f min - fmask = (pgram[key].f > p['f_min']) - P_cut = pgram[key].P[fmask] - f_cut = pgram[key].f[fmask] - - p_max = np.max(P_cut) - f_max = f_cut[P_cut == p_max] - # p_max = np.max(pgram[key].P) - # f_max = pgram[key].f[pgram[key].P == p_max] - - # not clear why saving for now - pgram_max[key] = (f_max, p_max) - print "Max power for %s was %.3e at %4.2f Hz, with f min set to %4.2f" % (key, p_max, f_max, p['f_min']) - -# throwaway save method for now -# trial is currently undefined -# function is broken for N_trials > 1 -def exec_throwaway(ddata, opts): - p = { - 'n_sim': 0, - 'n_trial': 0, - } - args_check(p, opts) - - p_exp = paramrw.ExpParams(ddata.fparam) - N_trials = p_exp.N_trials - print opts, p - - if p['n_sim'] == -1: - for i in range(p_exp.N_sims): - if 
p['n_trial'] == -1: - for j in range(N_trials): - dipolefn.dpl_convert_and_save(ddata, i, j) - else: - j = p['n_trial'] - dipolefn.dpl_convert_and_save(ddata, i, j) - - else: - i = p['n_sim'] - if p['n_trial'] == -1: - for j in range(N_trials): - dipolefn.dpl_convert_and_save(ddata, i, j) - else: - j = p['n_trial'] - dipolefn.dpl_convert_and_save(ddata, i, j) - - # # take the ith sim, jth trial, do some stuff to it, resave it - # # only uses first expmt_group - # expmt_group = ddata.expmt_groups[0] - - # # need n_trials - # p_exp = paramrw.ExpParams(ddata.fparam) - # if not p_exp.N_trials: - # N_trials = 1 - # else: - # N_trials = p_exp.N_trials - - # # absolute number - # n = i*N_trials + j - - # # grab the correct files - # f_dpl = ddata.file_match(expmt_group, 'rawdpl')[n] - # f_param = ddata.file_match(expmt_group, 'param')[n] - - # # print ddata.sim_prefix, ddata.dsim - # f_name_short = '%s-%03d-T%02d-dpltest.txt' % (ddata.sim_prefix, i, j) - # f_name = os.path.join(ddata.dsim, expmt_group, f_name_short) - # print f_name - - # dpl = dipolefn.Dipole(f_dpl) - # dpl.baseline_renormalize(f_param) - # print "baseline renormalized" - - # dpl.convert_fAm_to_nAm() - # print "converted to nAm" - - # dpl.write(f_name) - -def exec_show(ddata, dict_opts): - dict_opts_default = { - 'run': 0, - 'trial': 0, - 'expmt_group': '', - 'key': 'changed', - 'var_list': [], - } - - # hack for now to get backward compatibility with this original function - var_list = dict_opts_default['var_list'] - - exclude_list = [ - 'sim_prefix', - 'N_trials', - 'Run_Date', - ] - - args_check(dict_opts_default, dict_opts) - if dict_opts_default['expmt_group'] not in ddata.expmt_groups: - # print "Warning: expmt_group %s not found" % dict_opts_default['expmt_group'] - dict_opts_default['expmt_group'] = ddata.expmt_groups[0] - - # output the expmt group used - print "expmt_group: %s" % dict_opts_default['expmt_group'] - - # find the params - p_exp = paramrw.ExpParams(ddata.fparam) - - if dict_opts_default['key'] == 'changed': - print "Showing changed ... \n" - # create a list - var_list = [val[0] for val in paramrw.changed_vars(ddata.fparam)] - - elif dict_opts_default['key'] in p_exp.keys(): - # create a list with just this element - var_list = [dict_opts_default['key']] - - else: - key_part = dict_opts_default['key'] - var_list = [key for key in p_exp.keys() if key_part in key] - - if not var_list: - print "Keys were not found by exec_show()" - return 0 - - # files - fprefix = ddata.trial_prefix_str % (dict_opts_default['run'], dict_opts_default['trial']) - fparam = ddata.create_filename(dict_opts_default['expmt_group'], 'param', fprefix) - - list_param = ddata.file_match(dict_opts_default['expmt_group'], 'param') - - if fparam in list_param: - # this version of read returns the gid dict as well ... - _, p = paramrw.read(fparam) - - # use var_list to print values - for key in var_list: - if key not in exclude_list: - try: - print '%s: %s' % (key, p[key]) - - except KeyError: - print "Value %s not found in file %s!" 
% (key, fparam) - -def exec_show_dpl_max(ddata, opts={}): - p = { - 'layer': 'L5', - 'n_sim': 0, - 'n_trial': 0, - } - args_check(p, opts) - - expmt_group = ddata.expmt_groups[0] - - n = p['n_sim'] + p['n_sim']*p['n_trial'] - - fdpl = ddata.file_match(expmt_group, 'rawdpl')[n] - fparam = ddata.file_match(expmt_group, 'param')[n] - - T = paramrw.find_param(fparam, 'tstop') - xlim = (50., T) - - dpl = dipolefn.Dipole(fdpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - # add this data to the dict for the string output mapping - p['dpl_max'] = dpl.lim(p['layer'], xlim)[1] - p['units'] = dpl.units - - print "The maximal value for the dipole is %(dpl_max)4.3f %(units)s for sim=%(n_sim)i, trial=%(n_trial)i in layer %(layer)s" % (p) - # print "The maximal value for the dipole is %4.3f %s for sim=%i, trial=%i" % (dpl_max, dpl.units, n_sim, n_trial) - -# calculates the mean dipole over a specified range -def exec_calc_dpl_mean(ddata, opts={}): - for expmt_group in ddata.expmt_groups: - list_fdpl = ddata.file_match(expmt_group, 'rawdpl') - - # order of l_dpl is same as list_fdpl - l_dpl = [dipolefn.Dipole(f) for f in list_fdpl] - - for dpl in l_dpl: - print dpl.mean_stationary(opts) - -# calculates the linear regression, shows values of slope (m) and int (b) -# and plots line to dipole fig (in place) -def exec_calc_dpl_regression(ddata, opts={}): - for expmt_group in ddata.expmt_groups: - list_fdpl = ddata.file_match(expmt_group, 'rawdpl') - list_figdpl = ddata.file_match(expmt_group, 'figdpl') - - # this is to overwrite the fig - for f, ffig in zip(list_fdpl, list_figdpl): - dipolefn.plinear_regression(ffig, f) - -def exec_pdipole_evoked(ddata, ylim=[]): - # runtype = 'parallel' - runtype = 'debug' - - expmt_group = ddata.expmt_groups[0] - - # grab just the first element of the dipole list - dpl_list = ddata.file_match(expmt_group, 'rawdpl') - param_list = ddata.file_match(expmt_group, 'param') - spk_list = ddata.file_match(expmt_group, 'rawspk') - - # fig dir will be that of the original dipole - dfig = ddata.dfig[expmt_group]['figdpl'] - - # first file names - f_dpl = dpl_list[0] - f_spk = spk_list[0] - f_param = param_list[0] - - if runtype == 'parallel': - pl = Pool() - for f_dpl, f_spk, f_param in zip(dpl_list, spk_list, param_list): - pl.apply_async(dipolefn.pdipole_evoked, (dfig, f_dpl, f_spk, f_param, ylim)) - - pl.close() - pl.join() - - elif runtype == 'debug': - for f_dpl, f_spk, f_param in zip(dpl_list, spk_list, param_list): - dipolefn.pdipole_evoked(dfig, f_dpl, f_spk, f_param, ylim) - -# timer function wrapper returns WALL CLOCK time (more or less) -def timer(fn, args): - t0 = time() - x = eval(fn + args) - t1 = time() - - print "%s took %4.4f s" % (fn, t1-t0) - - return x - -def exec_pcompare(ddata, cli_args): - vars = cli_args.split(" ") - - # find any expmt and just take the first one. (below) - expmt = [arg.split("=")[1] for arg in vars if arg.startswith("expmt")] - sim0 = int([arg.split("=")[1] for arg in vars if arg.startswith("sim0")][0]) - sim1 = int([arg.split("=")[1] for arg in vars if arg.startswith("sim1")][0]) - - sims = [sim0, sim1] - - labels = ['A. Control E$_g$-I$_s$', 'B. Increased E$_g$-I$_s$'] - - if expmt: - psum.pcompare2(ddata, sims, labels, [expmt[0], expmt[0]]) - else: - psum.pcompare2(ddata, sims, labels) - # print "not found" - -def exec_pcompare3(ddata, cli_args): - # the args will be the 3 sim numbers. - # these will be strings out of the split! 
- vars = cli_args.split(' ') - sim_no = int(vars[0]) - # expmt_last = int(vars[1]) - - psum.pcompare3(ddata, sim_no) - -# executes the function plotvar in psummary -# At some point, replace 'vars' with a non-standard variable name -def exec_plotvars(cli_args, ddata): - # split the cli args based on options - vars = cli_args.split(' --') - - # first part is always the first 2 options (required, no checks) - vars_to_plot = vars[0].split() - - # grab the experiment handle - # vars_expmt = [arg.split()[1] for arg in vars if arg.startswith('expmt')] - vars_opts = [arg.split()[1:] for arg in vars if arg.startswith('opts')] - - # just pass the first of these - if vars_opts: - psum.plotvars(ddata, vars_to_plot[0], vars_opts[0]) - # psum.plotvars(ddata, vars_to_plot[0], vars_to_plot[1], vars_opts[0]) - # else: - # run the plotvar function with the cli args - # psum.plotvars(ddata, vars_to_plot[0]) - # psum.plotvars(ddata, vars_to_plot[0], vars_to_plot[1]) - -def exec_pphase(ddata, args): - args_split = args.split(" ") - expmt = args_split[0] - N_sim = int(args_split[1]) - - N_bins = 20 - - psum.pphase(ddata, expmt, N_sim, N_bins) - -# do_phist -def exec_phist(ddata, args): - # somehow create these plots - args_split = args.split(" ") - N_sim = args_split[0] - N_bins = int(args_split[1]) - psum.pphase_hist(ddata, N_sim, N_bins) - -# find the spectral max over an interval, for a particular sim -def exec_specmax(ddata, opts): - p = { - 'expmt_group': '', - 'n_sim': 0, - 'n_trial': 0, - 't_interval': None, - 'f_interval': None, - 'f_sort': None, - # 't_interval': [0., -1], - # 'f_interval': [0., -1], - } - - args_check(p, opts) - - p_exp = paramrw.ExpParams(ddata.fparam) - # trial_prefix = p_exp.trial_prefix_str % (p['n_sim'], p['n_trial']) - - if not p['expmt_group']: - p['expmt_group'] = ddata.expmt_groups[0] - - # Get the associated dipole and spec file - fspec = ddata.return_specific_filename(p['expmt_group'], 'rawspec', p['n_sim'], p['n_trial']) - - # Load the spec data - spec = specfn.Spec(fspec) - - # get max data - data_max = spec.max('agg', p['t_interval'], p['f_interval'], p['f_sort']) - - if data_max: - print "Max power of %4.2e at f of %4.2f Hz at %4.3f ms" % (data_max['pwr'], data_max['f_at_max'], data_max['t_at_max']) - - # # data_max = specfn.specmax(fspec, p) - # data = specfn.read(fspec) - # print data.keys() - - # # grab the min and max f - # f_min, f_max = p['f_interval'] - - # # set f_max - # if f_max < 0: - # f_max = data['freq'][-1] - - # # create an f_mask for the bounds of f, inclusive - # f_mask = (data['freq']>=f_min) & (data['freq']<=f_max) - - # # do the same for t - # t_min, t_max = p['t_interval'] - # if t_max < 0: - # t_max = data['time'][-1] - - # t_mask = (data['time']>=t_min) & (data['time']<=t_max) - - # # use the masks truncate these appropriately - # TFR_key = 'TFR' - - # if p['layer'] in ('L2', 'L5'): - # TFR_key += '_%s' % p['layer'] - - # TFR_fcut = data[TFR_key][f_mask, :] - # # TFR_fcut = data['TFR'][f_mask, :] - # TFR_tfcut = TFR_fcut[:, t_mask] - - # f_fcut = data['freq'][f_mask] - # t_tcut = data['time'][t_mask] - - # # find the max power over this new range - # # the max_mask is for the entire TFR - # pwr_max = TFR_tfcut.max() - # max_mask = (TFR_tfcut==pwr_max) - - # # find the t and f at max - # # these are slightly crude and do not allow for the possibility of multiple maxes (rare?) 
- # t_at_max = t_tcut[max_mask.sum(axis=0)==1] - # f_at_max = f_fcut[max_mask.sum(axis=1)==1] - - # # friendly printout - # print "Max power of %4.2e at f of %4.2f Hz at %4.3f ms" % (pwr_max, f_at_max, t_at_max) - - # pd_at_max = 1000./f_at_max - # t_start = t_at_max - pd_at_max/2. - # t_end = t_at_max + pd_at_max/2. - - # print "Symmetric interval at %4.2f Hz (T=%4.3f ms) about %4.3f ms is (%4.3f, %4.3f)" % (f_at_max, pd_at_max, t_at_max, t_start, t_end) - - # # output structure - # data_max = { - # 'pwr': pwr_max, - # 't': t_at_max, - # 'f': f_at_max, - # } - -def exec_specmax_dpl_match(ddata, opts): - p = { - 'expmt_group': '', - 'n_sim': 0, - 'trials': [0, -1], - 't_interval': None, - 'f_interval': None, - 'f_sort': None, - } - - args_check(p, opts) - - # set expmt group - if not p['expmt_group']: - p['expmt_group'] = ddata.expmt_groups[0] - - # set directory to save fig in and check that it exists - dir_fig = os.path.join(ddata.dsim, p['expmt_group'], 'figint') - fio.dir_create(dir_fig) - - # if p['trials'][1] is -1, assume all trials are wanted - # 1 is subtracted from N_trials to be consistent with manual entry of trial range - if p['trials'][1] == -1: - p_exp = paramrw.ExpParams(ddata.fparam) - p['trials'][1] = p_exp.N_trials - 1 - - # Get spec, dpl, and param files - # Sorry for lack of readability - spec_list = [ddata.return_specific_filename(p['expmt_group'], 'rawspec', p['n_sim'], i) for i in range(p['trials'][0], p['trials'][1]+1)] - dpl_list = [ddata.return_specific_filename(p['expmt_group'], 'rawdpl', p['n_sim'], i) for i in range(p['trials'][0], p['trials'][1]+1)] - param_list = [ddata.return_specific_filename(p['expmt_group'], 'param', p['n_sim'], i) for i in range(p['trials'][0], p['trials'][1]+1)] - - # Get max spectral data - data_max_list = [] - - for fspec in spec_list: - spec = specfn.Spec(fspec) - data_max_list.append(spec.max('agg', p['t_interval'], p['f_interval'], p['f_sort'])) - - # create fig name - if p['f_sort']: - fname_short = "sim-%03i-T%03i-T%03d-sort-%i-%i" %(p['n_sim'], p['trials'][0], p['trials'][1], p['f_sort'][0], p['f_sort'][1]) - - else: - fname_short = "sim-%03i-T%03i-T%03i" %(p['n_sim'], p['trials'][0], p['trials'][1]) - - fname = os.path.join(dir_fig, fname_short) - - # plot time-series over proper intervals - dipolefn.plot_specmax_interval(fname, dpl_list, param_list, data_max_list) - -def exec_specmax_dpl_tmpl(ddata, opts): - p = { - 'expmt_group': '', - 'n_sim': 0, - 'trials': [0, -1], - 't_interval': None, - 'f_interval': None, - 'f_sort': None, - } - - args_check(p, opts) - - # set expmt group - if not p['expmt_group']: - p['expmt_group'] = ddata.expmt_groups[0] - - # set directory to save template in and check that it exists - dir_out = os.path.join(ddata.dsim, p['expmt_group'], 'tmpldpl') - fio.dir_create(dir_out) - - # if p['trials'][1] is -1, assume all trials are wanted - # 1 is subtracted from N_trials to be consistent with manual entry of trial range - if p['trials'][1] == -1: - p_exp = paramrw.ExpParams(ddata.fparam) - p['trials'][1] = p_exp.N_trials - 1 - - # Get spec, dpl, and param files - # Sorry for lack of readability - spec_list = [ddata.return_specific_filename(p['expmt_group'], 'rawspec', p['n_sim'], i) for i in range(p['trials'][0], p['trials'][1]+1)] - dpl_list = [ddata.return_specific_filename(p['expmt_group'], 'rawdpl', p['n_sim'], i) for i in range(p['trials'][0], p['trials'][1]+1)] - param_list = [ddata.return_specific_filename(p['expmt_group'], 'param', p['n_sim'], i) for i in range(p['trials'][0], p['trials'][1]+1)] 
- - # Get max spectral data - data_max_list = [] - - for fspec in spec_list: - spec = specfn.Spec(fspec) - data_max_list.append(spec.max('agg', p['t_interval'], p['f_interval'], p['f_sort'])) - - # Get time intervals of max spectral pwr - t_interval_list = [dmax['t_int'] for dmax in data_max_list if dmax is not None] - - # truncate dpl_list to include only sorted trials - # kind of crazy that this works. Just sayin'... - dpl_list = [fdpl for fdpl, dmax in zip(dpl_list, data_max_list) if dmax is not None] - - # create file name - if p['f_sort']: - fname_short = "sim-%03i-T%03i-T%03d-sort-%i-%i-tmpldpl.txt" %(p['n_sim'], p['trials'][0], p['trials'][1], p['f_sort'][0], p['f_sort'][1]) - - else: - fname_short = "sim-%03i-T%03i-T%03i-tmpldpl.txt" %(p['n_sim'], p['trials'][0], p['trials'][1]) - - fname = os.path.join(dir_out, fname_short) - - # Create dpl template - dipolefn.create_template(fname, dpl_list, param_list, t_interval_list) - -def exec_plot_dpl_tmpl(ddata, opts): - p = { - 'expmt_group': '', - } - - args_check(p, opts) - - # set expmt group - if not p['expmt_group']: - p['expmt_group'] = ddata.expmt_groups[0] - - # set directory to save template in and check that it exists - dir_out = os.path.join(ddata.dsim, p['expmt_group'], 'figtmpldpl') - fio.dir_create(dir_out) - - # get template dpl data - dpl_list = fio.file_match(os.path.join(ddata.dsim, p['expmt_group']), '-tmpldpl.txt') - - # create file name list - # prefix_list = [fdpl.split('/')[-1].split('-tmpldpl')[0] for fdpl in dpl_list] - # fname_list = [os.path.join(dir_out, prefix+'-tmpldpl.png') for prefix in prefix_list] - - plot_dict = { - 'xlim': None, - 'ylim': None, - } - - for fdpl in dpl_list: - print fdpl - dipolefn.pdipole(fdpl, dir_out, plot_dict) - -# search for the min in a dipole over specified interval -def exec_dipolemin(ddata, expmt_group, n_sim, n_trial, t_interval): - p_exp = paramrw.ExpParams(ddata.fparam) - trial_prefix = p_exp.trial_prefix_str % (n_sim, n_trial) - - # list of all the dipoles - dpl_list = ddata.file_match(expmt_group, 'rawdpl') - - # load the associated dipole file - # find the specific file - # assume just the first file - fdpl = [file for file in dpl_list if trial_prefix in file][0] - - data = np.loadtxt(open(fdpl, 'r')) - t_vec = data[:, 0] - data_dpl = data[:, 1] - - data_dpl_range = data_dpl[(t_vec >= t_interval[0]) & (t_vec <= t_interval[1])] - dpl_min_range = data_dpl_range.min() - t_min_range = t_vec[data_dpl == dpl_min_range] - - print "Minimum value over t range %s was %4.4f at %4.4f." % (str(t_interval), dpl_min_range, t_min_range) - -# averages raw dipole or raw spec over all trials -def exec_avgtrials(ddata, datatype): - # create the relevant key for the data - datakey = 'raw' + datatype - datakey_avg = 'avg' + datatype - - # assumes N_Trials are the same in both - p_exp = paramrw.ExpParams(ddata.fparam) - sim_prefix = p_exp.sim_prefix - N_trials = p_exp.N_trials - - # fix for N_trials=0 - if not N_trials: - N_trials = 1 - - # prefix strings - exp_prefix_str = p_exp.exp_prefix_str - trial_prefix_str = p_exp.trial_prefix_str - - # Averaging must be done per expmt - for expmt_group in ddata.expmt_groups: - ddatatype = ddata.dfig[expmt_group][datakey] - dparam = ddata.dfig[expmt_group]['param'] - - param_list = ddata.file_match(expmt_group, 'param') - rawdata_list = ddata.file_match(expmt_group, datakey) - - # if nothing in the raw data list, then generate it for spec - if datakey == 'rawspec': - if not len(rawdata_list): - # generate the data! 
- exec_spec_regenerate(ddata) - rawdata_list = ddata.file_match(expmt_group, datakey) - - # simple length check, but will proceed bluntly anyway. - # this will result in truncated lists, per zip function - if len(param_list) != len(rawdata_list): - print "warning, some weirdness detected in list length in exec_avgtrials. Check yo' lengths!" - - # number of unique simulations, per trial - # this had better be equivalent as an integer or a float! - N_unique = len(param_list) / N_trials - - # go through the unique simulations - for i in range(N_unique): - # fills in the correct int for the experimental prefix string formatter 'exp_prefix_str' - prefix_unique = exp_prefix_str % i - fprefix_long = os.path.join(ddatatype, prefix_unique) - fprefix_long_param = os.path.join(dparam, prefix_unique) - - # create the sublist of just these trials - unique_list = [rawdatafile for rawdatafile in rawdata_list if rawdatafile.startswith(fprefix_long)] - unique_param_list = [pfile for pfile in param_list if pfile.startswith(fprefix_long_param)] - - # one filename per unique - # length of the unique list is the number of trials for this sim, should match N_trials - fname_unique = ddata.create_filename(expmt_group, datakey_avg, prefix_unique) - - # Average data for each trial - # average dipole data - if datakey == 'rawdpl': - for f_dpl, f_param in zip(unique_list, unique_param_list): - dpl = dipolefn.Dipole(f_dpl) - # dpl = dipolefn.Dipole(f_dpl, f_param) - - # ah, this is required becaused the dpl *file* still contains the raw, un-normalized data - dpl.baseline_renormalize(f_param) - - # initialize and use x_dpl - if f_dpl is unique_list[0]: - # assume time vec stays the same throughout - t_vec = dpl.t - x_dpl_agg = dpl.dpl['agg'] - x_dpl_L2 = dpl.dpl['L2'] - x_dpl_L5 = dpl.dpl['L5'] - - else: - x_dpl_agg += dpl.dpl['agg'] - x_dpl_L2 += dpl.dpl['L2'] - x_dpl_L5 += dpl.dpl['L5'] - - # poor man's mean - x_dpl_agg /= len(unique_list) - x_dpl_L2 /= len(unique_list) - x_dpl_L5 /= len(unique_list) - - # write this data to the file - # np.savetxt(fname_unique, avg_data, '%5.4f') - with open(fname_unique, 'w') as f: - for t, x_agg, x_L2, x_L5 in zip(t_vec, x_dpl_agg, x_dpl_L2, x_dpl_L5): - f.write("%03.3f\t%5.4f\t%5.4f\t%5.4f\n" % (t, x_agg, x_L2, x_L5)) - - # average spec data - elif datakey == 'rawspec': - specfn.average(fname_unique, unique_list) - # # load TFR data into np array and avg by summing and dividing by n_trials - # data_for_avg = np.array([np.load(file)['TFR'] for file in unique_list]) - # spec_avg = data_for_avg.sum(axis=0)/data_for_avg.shape[0] - - # # load time and freq vectors from the first item on the list, assume all same - # timevec = np.load(unique_list[0])['time'] - # freqvec = np.load(unique_list[0])['freq'] - - # # save the aggregate info - # np.savez_compressed(fname_unique, time=timevec, freq=freqvec, TFR=spec_avg) - -# run the spectral analyses on the somatic current time series -def exec_spec_current(ddata, opts_in=None): - # p_exp = paramrw.ExpParams(ddata.fparam) - - opts = { - 'type': 'dpl_laminar', - 'f_max': 150., - 'save_data': 1, - 'runtype': 'parallel', - } - - if opts_in: - args_check(opts, opts_in) - - specfn.analysis_typespecific(ddata, opts) - -# this function can now use specfn.generate_missing_spec(ddata, f_max) -def exec_spec_regenerate(ddata, f_max=None): - # regenerate and save spec data - opts = { - 'type': 'dpl_laminar', - 'f_max': 60., - 'save_data': 1, - 'runtype': 'parallel', - } - - # set f_max if provided - if f_max: - opts['f_max'] = f_max - - 
specfn.analysis_typespecific(ddata, opts) - -# Time-averaged stationarity analysis - averages spec power over time and plots it -def exec_spec_stationary_avg(ddata, dsim, maxpwr): - - # Prompt user for type of analysis (per expmt or whole sim) - analysis_type = raw_input('Would you like analysis per expmt or for whole sim? (expmt or sim): ') - - fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - fparam_list = fio.file_match(ddata.dsim, '-param.txt') - # fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - # fparam_list = fio.file_match(ddata.dsim, '-param.txt') - - p_exp = paramrw.ExpParams(ddata.fparam) - key_types = p_exp.get_key_types() - - # If no saved spec results exist, redo spec analysis - if not fspec_list: - print "No saved spec data found. Performing spec analysis...", - exec_spec_regenerate(ddata) - fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - # spec_results = exec_spec_regenerate(ddata) - - print "now doing spec freq-pwr analysis" - - # perform time-averaged stationary analysis - # specpwr_results = [specfn.specpwr_stationary_avg(fspec) for fspec in fspec_list] - specpwr_results = [] - - for fspec in fspec_list: - spec = specfn.Spec(fspec) - specpwr_results.append(spec.stationary_avg()) - - # plot for whole simulation - if analysis_type == 'sim': - - file_name = os.path.join(dsim, 'specpwr.eps') - pspec.pspecpwr(file_name, specpwr_results, fparam_list, key_types) - - # if maxpwr plot indicated - if maxpwr: - f_name = os.path.join(dsim, 'maxpwr.png') - specfn.pmaxpwr(f_name, specpwr_results, fparam_list) - - # plot per expmt - if analysis_type == 'expmt': - for expmt_group in ddata.expmt_groups: - # create name for figure. Figure saved to expmt directory - file_name = os.path.join(dsim, expmt_group, 'specpwr.png') - - # compile list of freqpwr results and param pathways for expmt - partial_results_list = [result for result in specpwr_results if result['expmt']==expmt_group] - partial_fparam_list = [fparam for fparam in fparam_list if expmt_group in fparam] - - # plot results - pspec.pspecpwr(file_name, partial_results_list, partial_fparam_list, key_types) - - # if maxpwr plot indicated - if maxpwr: - f_name = os.path.join(dsim, expmt_group, 'maxpwr.png') - specfn.pmaxpwr(f_name, partial_results_list, partial_fparam_list) - -# Time-averaged Spectral-power analysis/plotting of avg spec data -def exec_spec_avg_stationary_avg(ddata, dsim, opts): - - # Prompt user for type of analysis (per expmt or whole sim) - analysis_type = raw_input('Would you like analysis per expmt or for whole sim? (expmt or sim): ') - - spec_results_avged = fio.file_match(ddata.dsim, '-specavg.npz') - fparam_list = fio.file_match(ddata.dsim, '-param.txt') - - p_exp = paramrw.ExpParams(ddata.fparam) - key_types = p_exp.get_key_types() - - # If no avg spec data found, generate it. 
- if not spec_results_avged: - exec_avgtrials(ddata, 'spec') - spec_results_avged = fio.file_match(ddata.dsim, '-specavg.npz') - - # perform time-averaged stationarity analysis - # specpwr_results = [specfn.specpwr_stationary_avg(dspec) for dspec in spec_results_avged] - specpwr_results = [] - - for fspec in spec_results_avged: - spec = specfn.Spec(fspec) - specpwr_results.append(spec.stationary_avg()) - - # create fparam list to match avg'ed data - N_trials = p_exp.N_trials - nums = np.arange(0, len(fparam_list), N_trials) - fparam_list = [fparam_list[num] for num in nums] - - # plot for whole simulation - if analysis_type == 'sim': - - # if error bars indicated - if opts['errorbars']: - # get raw (non avg'ed) spec data - raw_spec_data = fio.file_match(ddata.dsim, '-spec.npz') - - # perform freqpwr analysis on raw data - # raw_specpwr = [specfn.specpwr_stationary_avg(dspec)['p_avg'] for dspec in raw_spec_data] - raw_specpwr = [] - - for fspec in raw_spec_data: - spec = specfn.Spec(fspec) - raw_specpwr.append(spec.stationary_avg()['p_avg']) - - # calculate standard error - error_vec = specfn.calc_stderror(raw_specpwr) - - else: - error_vec = [] - - file_name = os.path.join(dsim, 'specpwr-avg.eps') - pspec.pspecpwr(file_name, specpwr_results, fparam_list, key_types, error_vec) - - # # if maxpwr plot indicated - # if maxpwr: - # f_name = os.path.join(dsim, 'maxpwr-avg.png') - # specfn.pmaxpwr(f_name, freqpwr_results_list, fparam_list) - - # plot per expmt - if analysis_type == 'expmt': - for expmt_group in ddata.expmt_groups: - # if error bars indicated - if opts['errorbars']: - # get exmpt group raw spec data - raw_spec_data = ddata.file_match(expmt_group, 'rawspec') - - # perform stationary analysis on raw data - raw_specpwr = [specfn.specpwr_stationary_avg(dspec)['p_avg'] for dspec in raw_spec_data] - - # calculate standard error - error_vec = specfn.calc_stderror(raw_specpwr) - - else: - error_vec = [] - - # create name for figure. Figure saved to expmt directory - file_name = os.path.join(dsim, expmt_group, 'specpwr-avg.png') - - # compile list of specpwr results and param pathways for expmt - partial_results_list = [result for result in specpwr_results if result['expmt']==expmt_group] - partial_fparam_list = [fparam for fparam in fparam_list if expmt_group in fparam] - - # plot results - pspec.pspecpwr(file_name, partial_results_list, partial_fparam_list, key_types, error_vec) - - # # if maxpwr plot indicated - # if maxpwr: - # f_name = os.path.join(dsim, expmt_group, 'maxpwr-avg.png') - # specfn.pmaxpwr(f_name, partial_results_list, partial_fparam_list) - -# Averages spec pwr over time and plots it with histogram of alpha feeds per simulation -# Currently not completed -def freqpwr_with_hist(ddata, dsim): - fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - spk_list = fio.file_match(ddata.dsim, '-spk.txt') - fparam_list = fio.file_match(ddata.dsim, '-param.txt') - - p_exp = paramrw.ExpParams(ddata.fparam) - key_types = p_exp.get_key_types() - - # If no save spec reslts exist, redo spec analysis - if not fspec_list: - print "No saved spec data found. 
Performing spec analysis...", - exec_spec_regenerate(ddata) - fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - # spec_results = exec_spec_regenerate(ddata) - - print "now doing spec freq-pwr analysis" - - # perform freqpwr analysis - freqpwr_results_list = [specfn.freqpwr_analysis(fspec) for fspec in fspec_list] - - # Plot - for freqpwr_result, f_spk, fparam in zip(freqpwr_results_list, spk_list, fparam_list): - gid_dict, p_dict = paramrw.read(fparam) - file_name = 'freqpwr.png' - - specfn.pfreqpwr_with_hist(file_name, freqpwr_result, f_spk, gid_dict, p_dict, key_types) - -# runs plotfn.pall *but* checks to make sure there are spec data -def exec_replot(ddata, opts): -# def regenerate_plots(ddata, xlim=[0, 'tstop']): - p = { - 'xlim': None, - 'ylim': None, - } - - args_check(p, opts) - - # recreate p_exp ... don't like this - # ** should be guaranteed to be identical ** - p_exp = paramrw.ExpParams(ddata.fparam) - - # grab the list of spec results that exists - # there is a method in SimulationPaths/ddata for this specifically, this should be deprecated - # fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - - # generate data if no spec exists here - if not fio.file_match(ddata.dsim, '-spec.npz'): - # if not fspec_list: - print "No saved spec data found. Performing spec anaylsis ... " - exec_spec_regenerate(ddata) - # spec_results = exec_spec_regenerate(ddata) - - # run our core pall plot - plotfn.pall(ddata, p_exp, p['xlim'], p['ylim']) - -# function to add alpha feed hists -def exec_addalphahist(ddata, opts): -# def exec_addalphahist(ddata, xlim=[0, 'tstop']): - p = { - 'xlim': None, - 'ylim': None, - } - - args_check(p, opts) - - p_exp = paramrw.ExpParams(ddata.fparam) - - # generate data if no spec exists here - if not fio.file_match(ddata.dsim, '-spec.npz'): - print "No saved spec data found. Performing spec anaylsis ... " - exec_spec_regenerate(ddata) - - plotfn.pdpl_pspec_with_hist(ddata, p_exp, p['xlim'], p['ylim']) - # plotfn.pdpl_pspec_with_hist(ddata, p_exp, spec_list, xlim) - -def exec_aggregatespec(ddata, labels): - p_exp = paramrw.ExpParams(ddata.fparam) - - fspec_list = fio.file_match(ddata.dsim, '-spec.npz') - - # generate data if no spec exists here - if not fspec_list: - print "No saved spec data found. Performing spec anaylsis ... 
" - exec_spec_regenerate(ddata) - - plotfn.aggregate_spec_with_hist(ddata, p_exp, labels) - -def exec_pgamma_spec_fig(): - pgamma.spec_fig() - -def exec_pgamma_spikephase(): - # the directory here is hardcoded for now, inside the function - pgamma.spikephase() - -def exec_pgamma_peaks(): - pgamma.peaks() - -def exec_pgamma_sub_examples(): - pgamma.sub_dist_examples() - -def exec_pgamma_sub_example2(): - pgamma.sub_dist_example2() - -def exec_phaselock(ddata, opts): - p = { - 't_interval': [50, 1000], - 'f_max': 60., - } - args_check(p, opts) - - # Do this per expmt group - for expmt_group in ddata.expmt_groups: - # Get paths to relevant files - list_dpl = ddata.file_match(expmt_group, 'rawdpl') - list_spk = ddata.file_match(expmt_group, 'rawspk') - list_param = ddata.file_match(expmt_group, 'param') - - avg_spec = ddata.file_match(expmt_group, 'avgspec')[0] - - tmp_array_dpl = [] - tmp_array_spk = [] - - for f_dpl, f_spk, f_param in zip(list_dpl, list_spk, list_param): - # load Dpl data, do stuff, and store it - print f_dpl - dpl = dipolefn.Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - t, dp = dpl.truncate_ext(p['t_interval'][0], p['t_interval'][1]) - dp = dp['agg'] - tmp_array_dpl.append(dp) - - # Load extinput data, do stuff, and store it - try: - extinput = spikefn.ExtInputs(f_spk, f_param) - except ValueError: - print("Error: could not load spike timings from %s" % f_spk) - return - - extinput.add_delay_times() - extinput.get_envelope(dpl.t, feed='dist', bins=150) - inputs, t = extinput.truncate_ext('env', p['t_interval']) - tmp_array_spk.append(inputs) - - # Convert tmp arrays (actually lists) to numpy nd arrays - array_dpl = np.array(tmp_array_dpl) - array_spk = np.array(tmp_array_spk) - - # Phase-locking analysis - phase = specfn.PhaseLock(array_dpl, array_spk, list_param[0], p['f_max']) - - fname_d = os.path.join(ddata.dsim, expmt_group, 'phaselock-%iHz.npz' %p['f_max']) - np.savez_compressed(fname_d, t=phase.data['t'], f=phase.data['f'], B=phase.data['B']) - - # Plotting - # Should be moved elsewhere - avg_dpl = np.mean(array_dpl, axis=0) - avg_spk = np.mean(array_spk, axis=0) - - f = ac.FigPhase() - - extent_xy = [t[0], t[-1], phase.data['f'][-1], 0] - pc1 = f.ax['phase'].imshow(phase.data['B'], extent=extent_xy, aspect='auto', origin='upper',cmap=plt.get_cmap('jet')) - pc1.set_clim([0, 1]) - cb1 = f.f.colorbar(pc1, ax=f.ax['phase']) - # cb1.set_clim([0, 1]) - - spec = specfn.Spec(avg_spec) - pc2 = spec.plot_TFR(f.ax['spec'], xlim=[t[0], t[-1]]) - pc2.set_clim([0, 3.8e-7]) - cb2 = f.f.colorbar(pc2, ax=f.ax['spec']) - # cb2.set_clim([0, 3.6e-7]) - - f.ax['dipole'].plot(t, avg_dpl) - f.ax['dipole'].set_xlim([t[0], t[-1]]) - f.ax['dipole'].set_ylim([-0.0015, 0.0015]) - - f.ax['input'].plot(t, avg_spk) - f.ax['input'].set_xlim([t[0], t[-1]]) - f.ax['input'].set_ylim([-1, 5]) - f.ax['input'].invert_yaxis() - - f.ax['phase'].set_xlabel('Time (ms)') - f.ax['phase'].set_ylabel('Frequency (Hz)') - - fname = os.path.join(ddata.dsim, expmt_group, 'phaselock-%iHz.png' %p['f_max']) - print fname - - f.savepng(fname) - -# runs the gamma plot for a comparison of the high frequency -def exec_pgamma_hf(ddata, opts): - p = { - 'xlim_window': [0., -1], - 'n_sim': 0, - 'n_trial': 0, - } - args_check(p, opts) - pgamma.hf(ddata, p['xlim_window'], p['n_sim'], p['n_trial']) - -def exec_pgamma_hf_epochs(ddata, opts): - p = {} - args_check(p, opts) - pgamma.hf_epochs(ddata) - -# comparison of all layers and aggregate data -def exec_pgamma_laminar(ddata): - 
pgamma.laminar(ddata) - -# comparison between a PING (ddata0) and a weak PING (ddata1) data set -def exec_pgamma_compare_ping(): - # def exec_pgamma_compare_ping(ddata0, ddata1, opts): - pgamma.compare_ping() - -# plot for gamma stdev on a given ddata -def exec_pgamma_stdev(ddata): - pgamma.pgamma_stdev(ddata) - -def exec_pgamma_prox_dist_new(ddata, opts): - p = { - 'f_max_welch': 80., - } - - args_check(p, opts) - pgamma.prox_dist_new(ddata, p) - -def exec_pgamma_stdev_new(ddata, opts): - p = { - 'f_max_welch': 80., - } - - args_check(p, opts) - pgamma.pgamma_stdev_new(ddata, p) - -# plot for gamma distal phase on a given ddata -def exec_pgamma_distal_phase(ddata, opts): - pgamma.pgamma_distal_phase(ddata, opts['spec0'], opts['spec1'], opts['spec2']) - -# plot data averaged over trials -# dipole and spec should be split up at some point (soon) -# ylim specified here is ONLY for the dipole -def exec_plotaverages(ddata, ylim=[]): - # runtype = 'parallel' - runtype = 'debug' - - # this is a qnd check to create the fig dir if it doesn't already exist - # backward compatibility check for sims that didn't auto-create these dirs - for expmt_group in ddata.expmt_groups: - dfig_avgdpl = ddata.dfig[expmt_group]['figavgdpl'] - dfig_avgspec = ddata.dfig[expmt_group]['figavgspec'] - - # create them if they did not previously exist - fio.dir_create(dfig_avgdpl) - fio.dir_create(dfig_avgspec) - - # presumably globally true information - p_exp = paramrw.ExpParams(ddata.fparam) - key_types = p_exp.get_key_types() - - # empty lists to be used/appended - dpl_list = [] - spec_list = [] - dfig_list = [] - dfig_dpl_list = [] - dfig_spec_list = [] - pdict_list = [] - - # by doing all file operations sequentially by expmt_group in this iteration - # trying to guarantee order better than before - for expmt_group in ddata.expmt_groups: - # print expmt_group, ddata.dfig[expmt_group] - - # avgdpl and avgspec data paths - # fio.file_match() returns lists sorted - # dpl_list_expmt is so i can iterate through them in a sec - dpl_list_expmt = fio.file_match(ddata.dfig[expmt_group]['avgdpl'], '-dplavg.txt') - dpl_list += dpl_list_expmt - spec_list += fio.file_match(ddata.dfig[expmt_group]['avgspec'], '-specavg.npz') - - # create redundant list of avg dipole dirs and avg spec dirs - # unique parts are expmt group names - # create one entry for each in dpl_list - dfig_list_expmt = [ddata.dfig[expmt_group] for path in dpl_list_expmt] - dfig_list += dfig_list_expmt - dfig_dpl_list += [dfig['figavgdpl'] for dfig in dfig_list_expmt] - dfig_spec_list += [dfig['figavgspec'] for dfig in dfig_list_expmt] - - # param list to match avg data lists - fparam_list = fio.fparam_match_minimal(ddata.dfig[expmt_group]['param'], p_exp) - pdict_list += [paramrw.read(f_param)[1] for f_param in fparam_list] - - if dpl_list: - # new input to dipolefn - pdipole_dict = { - 'xlim': None, - 'ylim': None, - # 'xmin': 0., - # 'xmax': None, - # 'ymin': None, - # 'ymax': None, - } - - # if there is a length, assume it's 2 (it should be!) 
- if len(ylim): - pdipole_dict['ymin'] = ylim[0] - pdipole_dict['ymax'] = ylim[1] - - if runtype == 'debug': - for f_dpl, f_param, dfig_dpl in zip(dpl_list, fparam_list, dfig_dpl_list): - dipolefn.pdipole(f_dpl, dfig_dpl, pdipole_dict, f_param, key_types) - - elif runtype == 'parallel': - pl = Pool() - for f_dpl, f_param, dfig_dpl in zip(dpl_list, fparam_list, dfig_dpl_list): - pl.apply_async(dipolefn.pdipole, (f_dpl, f_param, dfig_dpl, key_types, pdipole_dict)) - - pl.close() - pl.join() - - else: - print "No avg dipole data found." - return 0 - - # if avg spec data exists - if spec_list: - if runtype == 'debug': - for f_spec, f_dpl, f_param, dfig_spec, pdict in zip(spec_list, dpl_list, fparam_list, dfig_spec_list, pdict_list): - pspec.pspec_dpl(f_spec, f_dpl, dfig_spec, pdict, key_types, f_param=f_param) - - elif runtype == 'parallel': - pl = Pool() - for f_spec, f_dpl, dfig_spec, pdict in zip(spec_list, dpl_list, dfig_spec_list, pdict_list): - pl.apply_async(pspec.pspec_dpl, (f_spec, f_dpl, dfig_spec, pdict, key_types)) - - pl.close() - pl.join() - - else: - print "No averaged spec data found. Run avgtrials()." - return 0 - -# rsync command with excludetype input -def exec_sync(droot, server_remote, dsubdir, fshort_exclude='exclude_eps.txt'): - # make up the local exclude file name - # f_exclude = os.path.join(droot, 'exclude_eps.txt') - f_exclude = os.path.join(droot, fshort_exclude) - - # create remote and local directories, they should look similar - dremote = os.path.join(droot, dsubdir) - dlocal = os.path.join(droot, 'from_remote') - - # creat the rsync command - cmd_rsync = "rsync -ruv --exclude-from '%s' -e ssh %s:%s %s" % (f_exclude, server_remote, dremote, dlocal) - - call(cmd_rsync, shell=True) - -# save to cppub -def exec_save(dproj, ddate, dsim): - if fio.dir_check(dsim): - dsave_root = os.path.join(dproj, 'pub') - - # check to see if this dir exists or not, and create it if not - fio.dir_create(dsave_root) - - dsave_short = '%s_%s' % (ddate.split('/')[-1], dsim.split('/')[-1]) - dsave = os.path.join(dsave_root, dsave_short) - - # use fileio routine to non-destructively copy dir - fio.dir_copy(dsim, dsave) - - else: - print "Not sure I can find that directory." - return 1 - -# Creates a pdf from a file list and saves it generically to ddata -def pdf_create(ddata, fprefix, flist): - file_out = os.path.join(ddata, fprefix + '-summary.pdf') - - # create the beginning of the call to ghostscript - gscmd = 'gs -dNumRenderingThreads=8 -dBATCH -dNOPAUSE -sDEVICE=pdfwrite -sOutputFile=' + file_out + ' -f ' - - for file in flist: - gscmd += file + ' ' - - # print gscmd - call(gscmd, shell=True) - - return file_out - -# PDF Viewer -def view_pdf(pdffile): - if sys.platform.startswith('darwin'): - app_pdf = 'open -a skim ' - elif sys.platform.startswith('linux'): - app_pdf = 'evince ' - - call([app_pdf + pdffile + ' &'], shell=True) - -# PDF finder ... (this is starting to get unwieldy) -def find_pdfs(ddata, expmt): - if expmt == 'all': - # This is recursive - # find the ONE pdf in the root dir - # all refers to the aggregated pdf file - pdf_list = [f for f in iglob(os.path.join(ddata, '*.pdf'))] - - elif expmt == 'each': - # get each and every one of these (syntax matches below) - pdf_list = fio.file_match(ddata, '*.pdf') - else: - # do this non-recursively (i.e. just for this directory) - dexpmt = os.path.join(ddata, expmt, '*.pdf') - pdf_list = [f for f in iglob(dexpmt)] - - # Check the length of pdf_list - if len(pdf_list) > 3: - print "There are", len(pdf_list), "files here." 
- str_open = raw_input("Do you want to open them all? ") - else: - # just set to open the files if fewer than 3 - str_open = 'y' - - # now check for a yes and go - if str_open == 'y': - for file in pdf_list: - view_pdf(file) - else: - print "Okay, good call. Here's the consolation prize:\n" - prettyprint(pdf_list) - -# Cross-platform file viewing using eog or xee, cmd is pngv in cli.py -def view_img(dir_data, ext): - # platform and extension specific opening - if sys.platform.startswith('darwin'): - ext_img = '/*' + ext - app_img = 'open -a xee ' - - elif sys.platform.startswith('linux'): - if ext == 'png': - app_img = 'eog ' - elif ext == 'eps': - app_img = 'evince ' - ext_img = '/*' + ext + '&' - - call([app_img + os.path.join(dir_data, 'spec') + ext_img], shell=True) - -# Cross platform file viewing over all exmpts -def file_viewer(ddata, dict_opts): - opts_default = { - 'expmt_group': ddata.expmt_groups[0], - 'type': 'figspec', - 'run': 'all', - } - args_check(opts_default, dict_opts) - - # return a list of files by run - if opts_default['run'] == 'all': - flist = ddata.file_match(opts_default['expmt_group'], opts_default['type']) - - else: - flist = ddata.file_match_by_run(**opts_default) - - # sort the list in place - flist.sort() - - # create a list of files for the argument to the program - files_arg = ' '.join(flist) - - if sys.platform.startswith('darwin'): - app_img = 'open -a preview ' - subprocess.call([app_img + files_arg], shell=True) - - elif sys.platform.startswith('linux'): - app_img = 'eog ' - subprocess.call([app_img + files_arg + '&'], shell=True) - -# a really simple image viewer, views images in dimg -def png_viewer_simple(dimg): - list_fpng = fio.file_match(dimg, '*.png') - - # Create an empty file argument - files_arg = '' - for file in list_fpng: - files_arg += file + ' ' - - # uses xee - if sys.platform.startswith('darwin'): - app_img = 'open -a xee ' - call([app_img + files_arg], shell=True) - - # uses eye of gnome (eog) - elif sys.platform.startswith('linux'): - app_img = 'eog ' - call([app_img + files_arg + '&'], shell=True) diff --git a/conf.py b/conf.py deleted file mode 100644 index 009468e10..000000000 --- a/conf.py +++ /dev/null @@ -1,188 +0,0 @@ -from configparser import ConfigParser -import io -import pickle -import os -import sys -from fileio import safemkdir -from collections import OrderedDict - -try: - from StringIO import StringIO -except ImportError: - from io import StringIO - -# default config as string -def_config = """ -[run] -dorun = 1 -doquit = 1 -debug = 0 -testlfp = 0 -testlaminarlfp = 0 -nsgrun = 0 -[paths] -paramindir = param -homeout = 1 -[sim] -simf = run.py -paramf = param/default.param -[draw] -drawindivdpl = 1 -drawindivrast = 1 -fontsize = 0 -drawavgdpl = 0 -[tips] -tstop = Simulation duration; Evoked response simulations typically take 170 ms while ongoing rhythms are run for longer. -dt = Simulation timestep - shorter timesteps mean more accuracy but longer runtimes. 
-[opt] -decay_multiplier = 1.6 -""" - -# parameter used for optimization -class param: - def __init__ (self, origval, minval, maxval, bounded, var, bestval=None): - self.origval = origval - self.minval = minval - self.maxval = maxval - self.bounded = bounded - self.bestval = bestval - if var.count(',') > 0: self.var = var.split(',') - else: self.var = var - def __str__ (self): - sout = '' - for s in [self.var, self.minval, self.maxval, self.origval, self.bounded, self.bestval]: - sout += str(s) - sout += ' ' - return sout - def assignstr (self, val): # generates string for execution - if type(self.var) == list: - astr = '' - for var in self.var: astr += var + ' = ' + str(val) + ';' - return astr - else: - return self.var + ' = ' + str(val) - def inbounds (self,val): # check if value is within bounds - if not bounded: return True - return val >= self.minval and val <= self.maxval - # only return assignstr if val is within bounds - def checkassign (self,val): - if self.inbounds(val): - return self.assignstr(val) - else: - return None - -# write config file starting with defaults and new entries -# specified in section (sec) , option (opt), and value (val) -# saves to output filepath fn -def writeconf (fn,sec,opt,val): - conf = ConfigParser() - conf.readfp(io.BytesIO(def_config)) # start with defaults - # then change entries by user-specs - for i in range(len(sec)): conf.set(sec[i],opt[i],val[i]) - # write config file - with open(fn, 'wb') as cfile: conf.write(cfile) - -def str2bool (v): return v.lower() in ("true", "t", "1") - -# read config file -def readconf (fn="hnn.cfg",nohomeout=False): - config = ConfigParser() - config.optionxform = str - - with open(fn, 'r') as cfg_file: - cfg_txt = os.path.expandvars(cfg_file.read()) - - config.readfp(StringIO(cfg_txt)) - - def conffloat (base,var,defa): # defa is default value - val = defa - try: val=config.getfloat(base,var) - except: pass - return val - - def confint (base,var,defa): - val = defa - try: val=config.getint(base,var) - except: pass - return val - - def confstr (base,var,defa): - val = defa - try: val = config.get(base,var) - except: pass - return val - - def confbool (base,var,defa): - return str2bool(confstr(base,var,defa)) - - def readtips (d): - if not config.has_section('tips'): return None - ltips = config.options('tips') - for i,prm in enumerate(ltips): - d[prm] = config.get('tips',prm).strip() - - d = {} - - d['homeout'] = confint("paths","homeout",1) # whether user home directory for output - if nohomeout: d['homeout'] = 0 # override config file with commandline - - d['simf'] = confstr('sim','simf','run.py') - d['paramf'] = confstr('sim','paramf',os.path.join('param','default.param')) - - - # dbase - optional config setting to change base output directory - if config.has_option('paths','dbase'): - dbase = config.get('paths','dbase').strip() - if not safemkdir(dbase): sys.exit(1) # check existence of base hnn output dir - else: - if d['homeout']: # user home directory for output - if 'SYSTEM_USER_DIR' in os.environ: - dbase = os.path.join(os.environ["SYSTEM_USER_DIR"],'hnn_out') # user home directory - else: - dbase = os.path.join(os.path.expanduser('~'),'hnn_out') # user home directory - if not safemkdir(dbase): sys.exit(1) # check existence of base hnn output dir - else: # cwd for output - dbase = os.getcwd() # use os.getcwd instead for better compatability with NSG - - d['dbase'] = dbase - d['datdir'] = os.path.join(dbase,'data') # data output directory - d['paramoutdir'] = os.path.join(dbase, 'param') - d['paramindir'] = 
confstr('paths','paramindir','param') # this depends on hnn install location - d['dataf'] = confstr('paths','dataf','') - - for k in ['datdir', 'paramindir', 'paramoutdir']: # need these directories - if not safemkdir(d[k]): sys.exit(1) - - d['dorun'] = confint("run","dorun",1) - d['doquit'] = confint("run","doquit",1) - d['debug'] = confint("run","debug",0) - d['testlfp'] = confint("run","testlfp",0) - d['testlaminarlfp'] = confint("run","testlaminarlfp",0) - d['nsgrun'] = confint("run","nsgrun",0) - - d['drawindivdpl'] = confint("draw","drawindivdpl",1) - d['drawavgdpl'] = confint("draw","drawavgdpl",0) - d['drawindivrast'] = confint("draw","drawindivrast",1) - d['fontsize'] = confint("draw","fontsize",0) - - d['decay_multiplier'] = conffloat('opt','decay_multiplier',1.6) - - readtips(d) # read tooltips for parameters - - return d - -# determine config file name -def setfcfg (): - fcfg = "hnn.cfg" # default config file name - for i in range(len(sys.argv)): - if sys.argv[i].endswith(".cfg") and os.path.exists(sys.argv[i]): - fcfg = sys.argv[i] - # print("hnn config file is " , fcfg) - return fcfg - -fcfg = setfcfg() # config file name -nohomeout = False -for i in range(len(sys.argv)): # override homeout option through commandline flag - if sys.argv[i] == '-nohomeout' or sys.argv[i] == 'nohomeout': nohomeout = True -dconf = readconf(fcfg,nohomeout) - diff --git a/ctune.py b/ctune.py deleted file mode 100644 index 42ce68c04..000000000 --- a/ctune.py +++ /dev/null @@ -1,361 +0,0 @@ -from math import log, exp - -""" -from neuron import h -# h.load_file("stdrun.hoc") -import numpy -from pylab import * -from time import time, clock -import os -from conf import dconf -import pickle - -dprm = dconf['params'] -sampr = dconf['sampr'] # 10KHz sampling rate in txt,npy file (data/15jun12_BS0284_subthreshandspikes_A0.npy) - -vtime = h.Vector() -vtime.record(h._ref_t) - -tinit = 0.0 -tstop = h.tstop - -gtmp=h.Vector() - -# -def myrun (reconfig=True,inj=0.0,prtime=False): - if reconfig: safereconfig() # makes sure params set within cell - stim.amp = inj - if prtime: clockStart = time() - # h.run() - if prtime: - clockEnd = time() - print('\nsim runtime:',str(round(clockEnd-clockStart,2)),'secs') - -y = h.Vector() -drawOut = False - -# -def readdat (sampr=10e3): - dat = numpy.load(dconf['evolts']) #numpy.load('data/15jun12_BS0284_subthreshandspikes_A0.npy') - etime = numpy.linspace(0,dat.shape[0]*1e3/sampr,dat.shape[0]) - return dat,etime - -dat,etime = readdat(sampr) # dim 1 is voltage - -intert = 3000 # 3000 ms in between clamps -offt = 500 # 500 ms before start of first clamp -durt = 1000 # 1000 ms current clamp -padt = 500 # pad around clamp - -# -def getindices (tdx): - sidx = sampr*(offt/1e3+tdx*(intert+durt)/1e3) - sampr * padt / 1e3 - eidx = sidx + durt * sampr / 1e3 + 2 * sampr * padt / 1e3 - return sidx,eidx - -# -def cuttrace (dat,tdx): - sidx,eidx = getindices(tdx) - if catdat: - return dat[sidx:eidx,1], etime[sidx:eidx] - else: - return dat[:,tdx],etime - -Vector = h.Vector -Iexp = lstimamp = numpy.load(dconf['lstimamp']) # [-0.15+j*0.05 for j in xrange(nstims_fi)] -nstims_fi = len(Iexp) # was 7 -alltrace = [i for i in xrange(nstims_fi)] -targSpikes = numpy.load(dconf['spiket']) # just using this for spike frequency - not timing!! 
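(A hedged aside, not part of the diff: a minimal standalone sketch of the index arithmetic in getindices() above, assuming the protocol stated in the comments — 10 kHz sampling, a 500 ms offset, 1000 ms clamps separated by 3000 ms, and 500 ms of padding on either side. The function name clamp_window_samples is illustrative only.)

sampr, offt, intert, durt, padt = 10e3, 500., 3000., 1000., 500.

def clamp_window_samples(tdx):
    # same arithmetic as getindices(): start/end sample of the padded window around clamp tdx
    sidx = sampr * (offt / 1e3 + tdx * (intert + durt) / 1e3) - sampr * padt / 1e3
    eidx = sidx + durt * sampr / 1e3 + 2 * sampr * padt / 1e3
    return int(sidx), int(eidx)

print(clamp_window_samples(0))  # (0, 20000)     -> 0-2000 ms around the first clamp (500-1500 ms)
print(clamp_window_samples(1))  # (40000, 60000) -> 4000-6000 ms around the second clamp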
-Fexp = [len(arr) for arr in targSpikes] # assumes 1 s stimulus duration -ltracedxsubth = [i for i in xrange(len(Fexp)) if Fexp[i] <= 0.0 and Iexp[i] != 0.0] -ltrace=ltracedxsubth - -def issubth (tdx): return ltracedxsubth.count(tdx) > 0 -def issuperth (tdx): return not issubth(tdx) - -# simple normalization - with a maximum cap -def getnormval (val,maxval,scale=1.0): - if val > maxval: return scale - return scale * val / maxval - - -# interpolate voltage recorded in simulation to a fixed grid (dt millisecond spacing) -# output time,voltage is returned -def interpvolt (tsrc=vtime,vsrc=vsoma,dt=0.1,tshift=tinit,tend=tstop): - tdest = h.Vector(); tdest.indgen(tshift,tend,dt) - vdest = h.Vector(); vdest.interpolate(tdest,tsrc,vsrc) - tdest.sub(tshift) - return tdest, vdest - -tracedx = 0 # which trace to fit (trace index) - -# -def plotinterp (vtime,vval,clr): - it,ival = interpvolt(vtime,vval,1e3/sampr) - plot(it.as_numpy(),ival.as_numpy(),clr) - -# -def voltcompare (tdx,interponly=True,dcurr=None,xl=None): - if dcurr is not None: subplot(2,1,1) - dd,tt = cuttrace(dat,tdx) - tt = linspace(0,tt[-1]-tt[0],len(tt)) - plot(tt,dd,'b') - if not interponly: plot(vtime.as_numpy(),vsoma.as_numpy(),'r') - it,iv = interpvolt(vtime,vsoma,1e3/sampr) - plot(it.as_numpy(),iv.as_numpy(),'r') - legend(['experiment','simulation'],loc='best') - xlabel('Time (ms)',fontsize=16); ylabel('Vm',fontsize=16); - if xl is not None: xlim(xl) - if dcurr is not None: - subplot(2,1,2) - plotinterp(vtime,dcurr['ina'],'r') - plotinterp(vtime,dcurr['ik'],'b') - plotinterp(vtime,dcurr['ica'],'g') - plotinterp(vtime,dcurr['cai'],'y') - plotinterp(vtime,dcurr['ih'],'k') - legend(['ina','ik','ica','cai','ih'],loc='best') - if xl is not None: xlim(xl) - -prtime = True # print simulation duration? 
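(A hedged aside, not part of the diff: interpvolt() above resamples the simulated voltage onto the fixed experimental grid with h.Vector().interpolate(); the sketch below shows the equivalent operation with plain numpy. The name interpvolt_np and the example spacing are assumptions for illustration.)

import numpy as np

def interpvolt_np(t_sim, v_sim, dt=0.1, tshift=0.0, tend=None):
    # resample a (possibly variable-step) trace onto a uniform grid of spacing dt (ms),
    # then shift the time axis back by tshift, as interpvolt() does
    t_sim = np.asarray(t_sim)
    v_sim = np.asarray(v_sim)
    if tend is None:
        tend = t_sim[-1]
    t_fixed = np.arange(tshift, tend + dt, dt)
    v_fixed = np.interp(t_fixed, t_sim, v_sim)
    return t_fixed - tshift, v_fixed

# e.g. to match the 10 kHz experimental sampling (0.1 ms spacing):
# t_new, v_new = interpvolt_np(vtime, vsoma, dt=1e3 / sampr)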
- -lparam = [p for p in dprm.values()]; - -def lparamindex (lp,s): - for i,p in enumerate(lp): - if p.var == s: return i - return -1 - -# -def prmnames (): return [prm.var for prm in lparam] - -# clamps nval (which is between 0,1) to valid param range -def clampval (prm, nval): - if nval < 0.0: return prm.minval - elif nval > 1.0: return prm.maxval - else: return prm.minval + (prm.maxval - prm.minval) * nval - -# -def clampvals (vec,lparam): return [clampval(prm,x) for prm,x in zip(lparam,vec)] -""" - -# exponentiates value -def expval (prm, val): - if prm.minval > 0: return exp(val) - elif prm.maxval < 0: return -exp(val) - else: return val - -# -def expvals (vec,lparam): return [expval(prm,x) for prm,x in zip(lparam,vec)] - -# -def logval (prm, val): - if prm.minval > 0: return log(val) - elif prm.maxval < 0: return log(-val) - else: return val - -# -def logvals (vec,lparam): return [logval(prm,x) for prm,x in zip(lparam,vec)] - - -""" -# -def assignparams (vparam,lparam,useExp=False): - if useExp: - for prm,val in zip(lparam,expvals(vparam,lparam)): # set parameters - exec(prm.assignstr(val)) - else: - for prm,val in zip(lparam,vparam): # set parameters - exec(prm.assignstr(val)) - -# -def assignrow (nqp, row): - if row < 0 or row >= nqp.v[0].size(): return None - nprm = int(nqp.m[0]) - 2 # -2 for idx,err - vprm = [] - for col in xrange(nprm): vprm.append(nqp.v[col].x[row]) - assignparams(vprm,lparam) - safereconfig() - return vprm - -# -def printparams (vparam,lparam,useExp=False): - if useExp: - for prm,val in zip(lparam,expvals(vparam,lparam)): print(prm.var, ' = ' , val) # set parameters - else: - for prm,val in zip(lparam,vparam): print(prm.var, ' = ' , val) # set parameters - -myerrfunc = None # error function - -# create an empty NQS with parameter and error columns -def makeprmnq (): - lp = prmnames() - nqp = h.NQS() - for s in lp: - if type(s) == list: - nqp.resize(s[0]) - else: - nqp.resize(s) - nqp.resize('idx'); nqp.resize('err'); nqp.clear(1e3) - return nqp - -# append parameter values and error to the NQS -def appendprmnq (nqp,vprm,err): - for i,x in enumerate(vprm): nqp.v[i].append(x) - sz = nqp.v[0].size() - nqp.getcol('idx').append(nqp.v[0].size()-1) - nqp.getcol('err').append(err) - -nqp = makeprmnq() - -# -def traceerr (): - toterr = 0.0 # total error across traces - for tracedx in ltrace: - print('stim.amp is ', lstimamp[tracedx]) - myrun(reconfig=False,inj=lstimamp[tracedx]) # - if drawOut: voltcompare(tracedx) - err = myerrfunc(tracedx) - print('err is ' , round(err,6)) - toterr += err - return toterr - -# errwrap - assigns params (xp are param values), evaluates and returns error (uses traceerr) -def errwrap (xp): - assignparams(xp,lparam,useExp=False) - printparams(xp,lparam,useExp=False) - safereconfig() - toterr = traceerr() - print('toterr is ' , toterr) - return toterr - -# optimization run - for an individual set of params specified in vparam -# NB: vparam contains the log of actual param values, & the meaning of params is specified in global lparam -def optrun (vparam): - if prtime: clockStart = time() - global tracedx, ltrace - for prm,val in zip(lparam,expvals(vparam,lparam)): # set parameters - if val >= prm.minval and val <= prm.maxval: - exec(prm.assignstr(val)) - else: - print(val, 'out of bounds for ' , prm.var, prm.minval, prm.maxval) - appendprmnq(nqp,expvals(vparam,lparam),1e9) - return 1e9 - if type(vparam)==list: print('set params:', vparam) - else: print('set params:', vparam.as_numpy()) - safereconfig() # make sure parameters are set in cell - 
toterr = traceerr() # total error across traces - if prtime: - clockEnd = time() - print('\nsim runtime:',str(round(clockEnd-clockStart,2)),'secs') - print('toterr is ' , round(toterr/len(ltrace),6)) - appendprmnq(nqp,expvals(vparam,lparam),toterr / len(ltrace)) - return toterr / len(ltrace) # average - -# run sims specified in ltrace and plot comparison of voltages -def voltcomprun (ltrace=None,prtime=False): - if ltrace is None: ltrace = ltracedxsubth - for tracedx in ltrace: - myrun(reconfig=False,inj=lstimamp[tracedx],prtime=prtime) - voltcompare(tracedx) - -# mean squared error of voltage -lvoltwin = [] # can use to specify time ranges for volterr -lvoltscale = [] # can use to scale errors (matches to lvoltwin indices) - -# -def volterr (tdx): - it,iv = interpvolt(vtime,vsoma,1e3/sampr) - dd,tt = cuttrace(dat,tdx) - npt = len(dd) - if it.size() > npt: it.resize(npt) - err = 0; ivnp = iv.as_numpy() - if len(lvoltwin) > 0: - if len(lvoltscale) > 0: - npt = 0; - for voltwin,fctr in zip(lvoltwin,lscale): - sidx,eidx = int(voltwin[0]*sampr/1e3),int(voltwin[1]*sampr/1e3) - npt += (eidx-sidx+1) - for idx in xrange(sidx,eidx+1,1): err += fctr * (ivnp[idx] - dd[idx])**2 - else: - npt = 0; - for voltwin in lvoltwin: - sidx,eidx = int(voltwin[0]*sampr/1e3),int(voltwin[1]*sampr/1e3) - npt += (eidx-sidx+1) - for idx in xrange(sidx,eidx+1,1): err += (ivnp[idx] - dd[idx])**2 - else: - for v1,v2 in zip(ivnp,dd): err += (v1-v2)**2 - return sqrt(err/npt) - -# scaled error, scale individual functions, then combine -useV = False; -scaleV = zeros((len(lstimamp),)) - -myerrfunc = volterr - -# randomized optimization - search random points in param space -def randopt (lparam,nstep,errfunc,saveevery=0,fout=None): - global myerrfunc, nqp - myerrfunc = errfunc - for i in xrange(nstep): - print('step ' , i+1 , ' of ' , nstep) - vparam = [p.minval + random.uniform() * (p.maxval-p.minval) for p in lparam] - vplog = [logval(p,x) for p,x in zip(lparam,vparam)] - optrun(vplog) - if fout is not None and saveevery > 0 and i%saveevery==0: nqp.sv(fout) - if fout is not None: nqp.sv(fout) - -# performs praxis optimization using specified params and error function (errfunc) -def praxismatch (vparam,nstep,tol,stepsz,errfunc): - global myerrfunc, nqp - h.nqsdel(nqp) - nqp = makeprmnq() - myerrfunc = errfunc - print('using these traces:', ltrace) - h.attr_praxis(tol, stepsz, 3) - h.stop_praxis(nstep) # - return h.fit_praxis(optrun, vparam) - -# use praxis to match voltage traces -def voltmatch (vparam,nstep=10,tol=0.001,stepsz=0.5): - global tstop - if len(lvoltwin) > 0: - tstop = tinit + amax(lvoltwin) - print('reset tstop to ' , tstop) - return praxismatch(vparam,nstep,tol,stepsz,volterr) - -# get the original param values (stored in lparam) -def getparamorig (): - vparam = h.Vector() - for p in lparam: vparam.append( logval(p,p.origval) ) - return vparam - -# get random param values (from the set stored in lparam) -def getparamrand (seed): - rdm = h.Random() - rdm.ACG(seed) - vparam = h.Vector() - for p in lparam: vparam.append( logval(p, rdm.uniform(p.minval,p.maxval)) ) - return vparam - -# get 'best' param values found (from opt) -def getparambest (): - vparam = h.Vector() - for p in lparam: vparam.append( logval(p,p.bestval) ) - return vparam - -def runsaveopt (): - global ltrace,nqp,vparam - vparam = getparambest(); - assignparams(vparam,lparam,useExp=True); - safereconfig(); - # lvoltwin = [[495.0,750.0]] - ltrace=ltracedxsubth - voltmatch(vparam,nstep=dconf['nstep'],tol=dconf['tol']); - nqp.sv(dconf['nqp']) - 
dconf['vparam'] = vparam.to_python() # output parameter values - dconf['lparam'] = lparam - pickle.dump(dconf,open(dconf['dout'],'w')) # save everything - -if __name__ == '__main__': - if dconf['runopt']: runsaveopt() -""" diff --git a/currentfn.py b/currentfn.py deleted file mode 100644 index a484d72a5..000000000 --- a/currentfn.py +++ /dev/null @@ -1,47 +0,0 @@ -# currentfn.py - current-based analysis functions -# -# v 1.8.22 -# rev 2013-11-19 (SL: added simple convert function) -# last major: (SL: added layers for plot to axis command) - -import numpy as np - -class SynapticCurrent(): - def __init__(self, fcurrent): - self.__parse_f(fcurrent) - - # parses the input file - def __parse_f(self, fcurrent): - x = np.loadtxt(open(fcurrent, 'r')) - self.t = x[:, 0] - - # this really should be a dictionary - self.I_soma_L2Pyr = x[:, 1] - self.I_soma_L5Pyr = x[:, 2] - self.units = 'nA' - - # ext fn to convert to uA - def convert_nA_to_uA(self): - self.I_soma_L2Pyr *= 1e-3 - self.I_soma_L5Pyr *= 1e-3 - self.units = 'uA' - - # external plot function - def plot_to_axis(self, a, layer=None): - # layer=None is redundant with L5Pyr, but it might be temporary - if layer is None: - a.plot(self.t, -self.I_soma_L5Pyr) - - elif layer is 'L2': - a.plot(self.t, -self.I_soma_L2Pyr) - - elif layer is 'L5': - a.plot(self.t, -self.I_soma_L5Pyr) - - # set the xlim - a.set_xlim((50., self.t[-1])) - -# external function to use SynapticCurrent() and plot to axis a -def pcurrent(a, fcurrent): - I_syn = SynapticCurrent(fcurrent) - I_syn.plot_to_axis(a) diff --git a/dipolefn.py b/dipolefn.py deleted file mode 100644 index c398f7abc..000000000 --- a/dipolefn.py +++ /dev/null @@ -1,1123 +0,0 @@ -# dipolefn.py - dipole-based analysis functions -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: itertools and return data dir) -# last major: (SL: toward python3) -import fileio as fio -import numpy as np -import ast -import os -import paramrw -import spikefn -import specfn -import matplotlib.pyplot as plt -import axes_create as ac -from math import ceil -from filt import boxfilt, hammfilt, emptyfilt - -# class Dipole() is for a single set of f_dpl and f_param -class Dipole(): - def __init__(self, f_dpl): # fix to allow init from data in memory (not disk) - """ some usage: dpl = Dipole(file_dipole, file_param) - this gives dpl.t and dpl.dpl - """ - self.units = None - self.N = None - self.__parse_f(f_dpl) - - # opens the file and sets units - def __parse_f(self, f_dpl): - x = np.loadtxt(open(f_dpl, 'r')) - # better implemented as a dict - self.t = x[:, 0] - self.dpl = { - 'agg': x[:, 1], - 'L2': x[:, 2], - 'L5': x[:, 3], - } - self.N = self.dpl['agg'].shape[-1] - # string that holds the units - self.units = 'fAm' - - # truncate to a length and save here - def truncate(self, t0, T): - """ this is independent of the other stuff - moved to an external function so as to not disturb the delicate genius of this object - """ - self.t, self.dpl = self.truncate_ext(t0, T) - - # just return the values, do not modify the class internally - def truncate_ext(self, t0, T): - # only do this if the limits make sense - if (t0 >= self.t[0]) & (T <= self.t[-1]): - dpl_truncated = dict.fromkeys(self.dpl) - # do this for each dpl - for key in self.dpl.keys(): - dpl_truncated[key] = self.dpl[key][(self.t >= t0) & (self.t <= T)] - t_truncated = self.t[(self.t >= t0) & (self.t <= T)] - return t_truncated, dpl_truncated - - # conversion from fAm to nAm - def convert_fAm_to_nAm (self): - """ must be run after baseline_renormalization() - """ - for key in 
self.dpl.keys(): self.dpl[key] *= 1e-6 - # change the units string - self.units = 'nAm' - - def scale (self, fctr): - for key in self.dpl.keys(): self.dpl[key] *= fctr - return fctr - - def smooth (self, winsz): - if winsz <= 1: return - #for key in self.dpl.keys(): self.dpl[key] = boxfilt(self.dpl[key],winsz) - for key in self.dpl.keys(): self.dpl[key] = hammfilt(self.dpl[key],winsz) - - # average stationary dipole over a time window - def mean_stationary(self, opts_input={}): - # opts is default AND input to below, can be modified by opts_input - opts = { - 't0': 50., - 'tstop': self.t[-1], - 'layer': 'agg', - } - # attempt to override the keys in opts - for key in opts_input.keys(): - # check for each of the keys in opts - if key in opts.keys(): - # special rule for tstop - if key == 'tstop': - # if value in tstop is -1, then use end to T - if opts_input[key] == -1: - opts[key] = self.t[-1] - else: - opts[key] = opts_input[key] - # check for layer in keys - if opts['layer'] in self.dpl.keys(): - # get the dipole that matches the xlim - x_dpl = self.dpl[opts['layer']][(self.t > opts['t0']) & (self.t < opts['tstop'])] - # directly return the average - return np.mean(x_dpl, axis=0) - else: - print("Layer not found. Try one of %s" % self.dpl.keys()) - - # finds the max value within a specified xlim - # def max(self, layer, xlim): - def lim(self, layer, xlim): - # better implemented as a dict - if layer is None: - dpl_tmp = self.dpl['agg'] - elif layer in self.dpl.keys(): - dpl_tmp = self.dpl[layer] - # set xmin and xmax - if xlim is None: - xmin = self.t[0] - xmax = self.t[-1] - else: - xmin, xmax = xlim - if xmin < 0.: xmin = 0. - if xmax < 0.: xmax = self.f[-1] - dpl_tmp = dpl_tmp[(self.t > xmin) & (self.t < xmax)] - return (np.min(dpl_tmp), np.max(dpl_tmp)) - - # simple layer-specific plot function - def plot(self, ax, xlim, layer='agg'): - # plot the whole thing and just change the xlim and the ylim - # if layer is None: - # ax.plot(self.t, self.dpl['agg']) - # ymax = self.max(None, xlim) - # ylim = (-ymax, ymax) - # ax.set_ylim(ylim) - if layer in self.dpl.keys(): - ax.plot(self.t, self.dpl[layer]) - ylim = self.lim(layer, xlim) - # force ymax to be something sane - # commenting this out for now, but - # we can change if absolutely necessary. - # ax.set_ylim(top=ymax*1.2) - # set the lims here, as a default - ax.set_ylim(ylim) - ax.set_xlim(xlim) - else: - print("raise some error") - return ax.get_xlim() - - # ext function to renormalize - # this function changes in place but does NOT write the new values to the file - def baseline_renormalize(self, f_param): - # only baseline renormalize if the units are fAm - if self.units == 'fAm': - N_pyr_x = paramrw.find_param(f_param, 'N_pyr_x') - N_pyr_y = paramrw.find_param(f_param, 'N_pyr_y') - # N_pyr cells in grid. This is PER LAYER - N_pyr = N_pyr_x * N_pyr_y - # dipole offset calculation: increasing number of pyr cells (L2 and L5, simultaneously) - # with no inputs resulted in an aggregate dipole over the interval [50., 1000.] ms that - # eventually plateaus at -48 fAm. 
The range over this interval is something like 3 fAm - # so the resultant correction is here, per dipole - # dpl_offset = N_pyr * 50.207 - dpl_offset = { - # these values will be subtracted - 'L2': N_pyr * 0.0443, - 'L5': N_pyr * -49.0502 - # 'L5': N_pyr * -48.3642, - # will be calculated next, this is a placeholder - # 'agg': None, - } - # L2 dipole offset can be roughly baseline shifted over the entire range of t - self.dpl['L2'] -= dpl_offset['L2'] - # L5 dipole offset should be different for interval [50., 500.] and then it can be offset - # slope (m) and intercept (b) params for L5 dipole offset - # uncorrected for N_cells - # these values were fit over the range [37., 750.) - m = 3.4770508e-3 - b = -51.231085 - # these values were fit over the range [750., 5000] - t1 = 750. - m1 = 1.01e-4 - b1 = -48.412078 - # piecewise normalization - self.dpl['L5'][self.t <= 37.] -= dpl_offset['L5'] - self.dpl['L5'][(self.t > 37.) & (self.t < t1)] -= N_pyr * (m * self.t[(self.t > 37.) & (self.t < t1)] + b) - self.dpl['L5'][self.t >= t1] -= N_pyr * (m1 * self.t[self.t >= t1] + b1) - # recalculate the aggregate dipole based on the baseline normalized ones - self.dpl['agg'] = self.dpl['L2'] + self.dpl['L5'] - else: - print("Warning, no dipole renormalization done because units were in %s" % (self.units)) - - # function to write to a file! - # f_dpl must be fully specified - def write(self, f_dpl): - with open(f_dpl, 'w') as f: - for t, x_agg, x_L2, x_L5 in zip(self.t, self.dpl['agg'], self.dpl['L2'], self.dpl['L5']): - f.write("%03.3f\t" % t) - f.write("%9.8f\t" % x_agg) - f.write("%9.8f\t" % x_L2) - f.write("%9.8f\n" % x_L5) - -# throwaway save method for now - see note -def dpl_convert_and_save(ddata, i=0, j=0): - """ trial is currently undefined - function is broken for N_trials > 1 - """ - # take the ith sim, jth trial, do some stuff to it, resave it - # only uses first expmt_group - expmt_group = ddata.expmt_groups[0] - - # need n_trials - p_exp = paramrw.ExpParams(ddata.fparam) - if not p_exp.N_trials: - N_trials = 1 - else: - N_trials = p_exp.N_trials - - # absolute number - n = i*N_trials + j - - # grab the correct files - f_dpl = ddata.file_match(expmt_group, 'rawdpl')[n] - f_param = ddata.file_match(expmt_group, 'param')[n] - - # print ddata.sim_prefix, ddata.dsim - f_name_short = '%s-%03d-T%02d-dpltest.txt' % (ddata.sim_prefix, i, j) - f_name = os.path.join(ddata.dsim, expmt_group, f_name_short) - print(f_name) - - dpl = Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - print("baseline renormalized") - - dpl.convert_fAm_to_nAm() - print("converted to nAm") - - dpl.write(f_name) - -# ddata is a fio.SimulationPaths() object -def calc_aggregate_dipole(ddata): - for expmt_group in ddata.expmt_groups: - # create the filename - dexp = ddata.dexpmt_dict[expmt_group] - fname_short = '%s-%s-dpl' % (ddata.sim_prefix, expmt_group) - fname_data = os.path.join(dexp, fname_short + '.txt') - - # grab the list of raw data dipoles and assoc params in this expmt - dpl_list = ddata.file_match(expmt_group, 'rawdpl') - param_list = ddata.file_match(expmt_group, 'param') - - for f_dpl, f_param in zip(dpl_list, param_list): - dpl = Dipole(f_dpl) - # dpl.baseline_renormalize(f_param) - - # initialize and use x_dpl - if f_dpl is dpl_list[0]: - # assume time vec stays the same throughout - t_vec = dpl.t - x_dpl = dpl.dpl['agg'] - - else: - # guaranteed to exist after dpl_list[0] - x_dpl += dpl.dpl['agg'] - - # poor man's mean - x_dpl /= len(dpl_list) - - # write this data to the file - with open(fname_data, 'w') 
as f: - for t, x in zip(t_vec, x_dpl): - f.write("%03.3f\t%5.4f\n" % (t, x)) - -# calculate stimulus evoked dipole -def calc_avgdpl_stimevoked(ddata): - for expmt_group in ddata.expmt_groups: - # create the filename - dexp = ddata.dexpmt_dict[expmt_group] - fname_short = '%s-%s-dpl' % (ddata.sim_prefix, expmt_group) - fname_data = os.path.join(dexp, fname_short + '.txt') - - # grab the list of raw data dipoles and assoc params in this expmt - fdpl_list = ddata.file_match(expmt_group, 'rawdpl') - param_list = ddata.file_match(expmt_group, 'param') - spk_list = ddata.file_match(expmt_group, 'rawspk') - - # actual list of Dipole() objects - dpl_list = [Dipole(fdpl) for fdpl in fdpl_list] - t_truncated = [] - - # iterate through the lists, grab the spike time, phase align the signals, - # cut them to length, and then mean the dipoles - for dpl, f_spk, f_param in zip(dpl_list, spk_list, param_list): - _, p = paramrw.read(f_param) - - # grab the corresponding relevant starting spike time - s = spikefn.spikes_from_file(f_param, f_spk) - s = spikefn.alpha_feed_verify(s, p) - s = spikefn.add_delay_times(s, p) - - # t_evoked is the same for all of the cells in these simulations - t_evoked = s['evprox0'].spike_list[0][0] - - # attempt to give a 50 ms buffer - if t_evoked > 50.: - t0 = t_evoked - 50. - else: - t0 = t_evoked - - # truncate the dipole related vectors - dpl.t = dpl.t[dpl.t > t0] - dpl.dpl['agg'] = dpl.dpl['agg'][dpl.t > t0] - t_truncated.append(dpl.t[0]) - - # find the t0_max value to compare on other dipoles - t_truncated -= np.max(t_truncated) - - for dpl, t_adj in zip(dpl_list, t_truncated): - # negative numbers mean that this vector needs to be shortened by that many ms - T_new = dpl.t[-1] + t_adj - dpl.dpl['agg'] = dpl.dpl['agg'][dpl.t < T_new] - dpl.t = dpl.t[dpl.t < T_new] - - if dpl is dpl_list[0]: - dpl_total = dpl.dpl['agg'] - - else: - dpl_total += dpl.dpl['agg'] - - dpl_mean = dpl_total / len(dpl_list) - t_dpl = dpl_list[0].t - - # write this data to the file - with open(fname_data, 'w') as f: - for t, x in zip(t_dpl, dpl_mean): - f.write("%03.3f\t%5.4f\n" % (t, x)) - -# Creates a template of dpl activity by averaging dpl data over specified time intervals -# Assumes t_intervals are all the same length -def create_template(fname, dpl_list, param_list, t_interval_list): - # iterate over lists, load dpl data and average - for fdpl, fparam, t_int in zip(dpl_list, param_list, t_interval_list): - # load ts data - dpl = Dipole(fdpl) - dpl.baseline_renormalize(fparam) - # dpl.convert_fAm_to_nAm() - - # truncate data based on time ranges specified in dmax - t_cut, dpl_tcut = dpl.truncate_ext(t_int[0], t_int[1]) - - if fdpl is dpl_list[0]: - x_dpl_agg = dpl_tcut['agg'] - x_dpl_L2 = dpl_tcut['L2'] - x_dpl_L5 = dpl_tcut['L5'] - - else: - x_dpl_agg += dpl_tcut['agg'] - x_dpl_L2 += dpl_tcut['L2'] - x_dpl_L5 += dpl_tcut['L5'] - - # poor man's mean - x_dpl_agg /= len(dpl_list) - x_dpl_L2 /= len(dpl_list) - x_dpl_L5 /= len(dpl_list) - - # create a tvec that is symmetric about zero and of proper length - # assume time intervals are identical length for all data - t_range = t_interval_list[0][1] - t_interval_list[0][0] - t_start = - t_range / 2. - t_end = t_range / 2. 
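(A hedged check on the arithmetic above: a [600., 800.] ms interval gives t_range = 200 ms, so t_start = -100 and t_end = +100, and the averaged template written out below sits on a time axis centred on zero.)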
- tvec = np.linspace(t_start, t_end, x_dpl_agg.shape[0]) - # tvec = np.linspace(0, t_range, x_dpl_agg.shape[0]) - - # save to file - with open(fname, 'w') as f: - for t, x_agg, x_L2, x_L5 in zip(tvec, x_dpl_agg, x_dpl_L2, x_dpl_L5): - f.write("%03.3f\t%5.4f\t%5.4f\t%5.4f\n" % (t, x_agg, x_L2, x_L5)) - -# one off function to plot linear regression -def plinear_regression(ffig_dpl, fdpl): - dpl = Dipole(fdpl) - layer = 'L5' - t0 = 750. - - # dipole for the given layer, truncated - # order matters here - x_dpl = dpl.dpl[layer][(dpl.t > t0)] - t = dpl.t[dpl.t > t0] - - # take the transpose (T) of a vector of the times and ones for each element - A = np.vstack([t, np.ones(len(t))]).T - - # find the slope and the y-int of the line fit with least squares method (min. of Euclidean 2-norm) - m, c = np.linalg.lstsq(A, x_dpl)[0] - print(m, c) - - # plot me - f = ac.FigStd() - f.ax0.plot(t, x_dpl) - f.ax0.hold(True) - f.ax0.plot(t, m*t + c, 'r') - - # save over the original - f.savepng(ffig_dpl) - f.close() - -# plot a dipole to an axis from corresponding dipole and param files -def pdipole_ax(a, f_dpl, f_param): - dpl = Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - - a.plot(dpl.t, dpl.dpl['agg']) - - # any further xlim sets can be done by whoever wants to do them later - a.set_xlim((0., dpl.t[-1])) - - # at least make the ylim symmetrical about 0 - ylim = a.get_ylim() - abs_y_max = np.max(np.abs(ylim)) - ylim = (-abs_y_max, abs_y_max) - a.set_ylim(ylim) - - # return the actual time in form of xlim. ain't pretty but works - return a.get_xlim() - -# pdipole is for a single dipole file, should be for a -def pdipole(f_dpl, dfig, plot_dict, f_param=None, key_types={}): - """ single dipole file combination (incl. param file) - this should be done with an axis input too - two separate functions, a pdipole kernel function and a specific function for this simple plot - """ - # dpl is an obj of Dipole() class - dpl = Dipole(f_dpl) - - if f_param: - dpl.baseline_renormalize(f_param) - - dpl.convert_fAm_to_nAm() - - # split to find file prefix - file_prefix = f_dpl.split('/')[-1].split('.')[0] - - - # parse xlim from plot_dict - if plot_dict['xlim'] is None: - xmin = dpl.t[0] - xmax = dpl.t[-1] - - else: - xmin, xmax = plot_dict['xlim'] - - if xmin < 0.: - xmin = 0. - - if xmax < 0.: - xmax = self.f[-1] - - # # get xmin and xmax from the plot_dict - # if plot_dict['xmin'] is None: - # xmin = 0. 
- # else: - # xmin = plot_dict['xmin'] - - # if plot_dict['xmax'] is None: - # xmax = p_dict['tstop'] - # else: - # xmax = plot_dict['xmax'] - - # truncate them using logical indexing - t_range = dpl.t[(dpl.t >= xmin) & (dpl.t <= xmax)] - dpl_range = dpl.dpl['agg'][(dpl.t >= xmin) & (dpl.t <= xmax)] - - f = ac.FigStd() - f.ax0.plot(t_range, dpl_range) - - # sorry about the parity between vars here and above with xmin/xmax - if plot_dict['ylim'] is None: - # if plot_dict['ymin'] is None or plot_dict['ymax'] is None: - pass - else: - f.ax0.set_ylim(plot_dict['ylim'][0], plot_dict['ylim'][1]) - # f.ax0.set_ylim(plot_dict['ymin'], plot_dict['ymax']) - - # Title creation - if f_param and key_types: - # grabbing the p_dict from the f_param - _, p_dict = paramrw.read(f_param) - - # useful for title strings - title_str = ac.create_title(p_dict, key_types) - f.f.suptitle(title_str) - - # create new fig name - fig_name = os.path.join(dfig, file_prefix+'.png') - - # savefig - plt.savefig(fig_name, dpi=300) - f.close() - -# plot vertical lines corresponding to the evoked input times -def pdipole_evoked(dfig, f_dpl, f_spk, f_param, ylim=[]): - """ for each individual simulation/trial - """ - gid_dict, p_dict = paramrw.read(f_param) - - # get the spike dict from the files - s_dict = spikefn.spikes_from_file(f_param, f_spk) - s = s_dict.keys() - s.sort() - - # create an empty dict 'spk_unique' - spk_unique = dict.fromkeys([key for key in s_dict.keys() if key.startswith(('evprox', 'evdist'))]) - - for key in spk_unique: - spk_unique[key] = s_dict[key].unique_all(0) - - # draw vertical lines for each item in this - - # x_dipole is dipole data - # x_dipole = np.loadtxt(open(f_dpl, 'r')) - dpl = Dipole(f_dpl) - - # split to find file prefix - file_prefix = f_dpl.split('/')[-1].split('.')[0] - - # # set xmin value - # xmin = xlim[0] / p_dict['dt'] - - # # set xmax value - # if xlim[1] == 'tstop': - # xmax = p_dict['tstop'] / p_dict['dt'] - # else: - # xmax = xlim[1] / p_dict['dt'] - - # these are the vectors for now, but this is going to change - t_vec = dpl.t - dp_total = dpl.dpl['agg'] - - f = ac.FigStd() - - # hold on - f.ax0.hold(True) - - f.ax0.plot(t_vec, dp_total) - - lines_spk = dict.fromkeys(spk_unique) - - print(spk_unique) - - # plot the lines - for key in spk_unique: - print(key, spk_unique[key]) - x_val = spk_unique[key][0] - lines_spk[key] = plt.axvline(x=x_val, linewidth=0.5, color='r') - - # title_txt = [key + ': {:.2e}' % p_dict[key] for key in key_types['dynamic_keys']] - title_txt = 'test' - f.ax0.set_title(title_txt) - - if ylim: - f.ax0.set_ylim(ylim) - - fig_name = os.path.join(dfig, file_prefix+'.png') - - plt.savefig(fig_name, dpi=300) - f.close() - -# Plots dipole with histogram of alpha feed inputs - slightly deprecated, see note -def pdipole_with_hist(f_dpl, f_spk, dfig, f_param, key_types, plot_dict): - """ this function has not been converted to use the Dipole() class yet - """ - # dpl is an obj of Dipole() class - dpl = Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - # split to find file prefix - file_prefix = f_dpl.split('/')[-1].split('.')[0] - # grabbing the p_dict from the f_param - _, p_dict = paramrw.read(f_param) - # get xmin and xmax from the plot_dict - if plot_dict['xmin'] is None: - xmin = 0. 
- else: - xmin = plot_dict['xmin'] - if plot_dict['xmax'] is None: - xmax = p_dict['tstop'] - else: - xmax = plot_dict['xmax'] - # truncate tvec and dpl data using logical indexing - t_range = dpl.t[(dpl.t >= xmin) & (dpl.t <= xmax)] - dpl_range = dpl.dpl['agg'][(dpl.t >= xmin) & (dpl.t <= xmax)] - # Plotting - f = ac.FigDplWithHist() - # dipole - f.ax['dipole'].plot(t_range, dpl_range) - # set new xlim based on dipole plot - xlim_new = f.ax['dipole'].get_xlim() - # Get extinput data and account for delays - try: - extinputs = spikefn.ExtInputs(f_spk, f_param) - except ValueError: - print("Error: could not load spike timings from %s" % f_spk) - f.close() - return - - extinputs.add_delay_times() - # set number of bins (150 bins per 1000ms) - bins = ceil(150. * (xlim_new[1] - xlim_new[0]) / 1000.) # bins needs to be an int - # plot histograms - hist = {} - hist['feed_prox'] = extinputs.plot_hist(f.ax['feed_prox'], 'prox', dpl.t, bins, xlim_new, color='red') - hist['feed_dist'] = extinputs.plot_hist(f.ax['feed_dist'], 'dist', dpl.t, bins, xlim_new, color='green') - # Invert dist histogram - f.ax['feed_dist'].invert_yaxis() - # for now, set the xlim for the other one, force it! - f.ax['dipole'].set_xlim(xlim_new) - f.ax['feed_prox'].set_xlim(xlim_new) - f.ax['feed_dist'].set_xlim(xlim_new) - # set hist axis properties - f.set_hist_props(hist) - # Add legend to histogram - for key in f.ax.keys(): - if 'feed' in key: - f.ax[key].legend() - # force xlim on histograms - f.ax['feed_prox'].set_xlim((xmin, xmax)) - f.ax['feed_dist'].set_xlim((xmin, xmax)) - title_str = ac.create_title(p_dict, key_types) - f.f.suptitle(title_str) - fig_name = os.path.join(dfig, file_prefix+'.png') - plt.savefig(fig_name) - f.close() - -# For a given ddata (SimulationPaths object), find the mean dipole -def pdipole_exp(ddata, ylim=[]): - """ over ALL trials in ALL conditions in EACH experiment - """ - # sim_prefix - fprefix = ddata.sim_prefix - - # create the figure name - fname_exp = '%s_dpl' % (fprefix) - fname_exp_fig = os.path.join(ddata.dsim, fname_exp + '.png') - - # create one figure comparing across all - N_expmt_groups = len(ddata.expmt_groups) - f_exp = ac.FigDipoleExp(ddata.expmt_groups) - - # empty list for the aggregate dipole data - dpl_exp = [] - - # go through each expmt - for expmt_group in ddata.expmt_groups: - # create the filename - dexp = ddata.dexpmt_dict[expmt_group] - fname_short = '%s-%s-dpl' % (fprefix, expmt_group) - fname_data = os.path.join(dexp, fname_short + '.txt') - fname_fig = os.path.join(ddata.dfig[expmt_group]['figdpl'], fname_short + '.png') - - # grab the list of raw data dipoles and assoc params in this expmt - dpl_list = ddata.file_match(expmt_group, 'rawdpl') - param_list = ddata.file_match(expmt_group, 'param') - - for f_dpl, f_param in zip(dpl_list, param_list): - dpl = Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - # x_tmp = np.loadtxt(open(file, 'r')) - - # initialize and use x_dpl - if f_dpl is dpl_list[0]: - - # assume time vec stays the same throughout - t_vec = dpl.t - x_dpl = dpl.dpl['agg'] - - else: - # guaranteed to exist after dpl_list[0] - x_dpl += dpl.dpl['agg'] - - # poor man's mean - x_dpl /= len(dpl_list) - - # save this in a list to do comparison figure - # order is same as ddata.expmt_groups - dpl_exp.append(x_dpl) - - # write this data to the file - with open(fname_data, 'w') as f: - for t, x in zip(t_vec, x_dpl): - f.write("%03.3f\t%5.4f\n" % (t, x)) - - # create the plot I guess? 
- f = ac.FigStd() - f.ax0.plot(t_vec, x_dpl) - - if len(ylim): - f.ax0.set_ylim(ylim) - - f.savepng(fname_fig) - f.close() - - # plot the aggregate data using methods defined in FigDipoleExp() - f_exp.plot(t_vec, dpl_exp) - - # attempt at setting titles - for ax, expmt_group in zip(f_exp.ax, ddata.expmt_groups): - ax.set_title(expmt_group) - - f_exp.savepng(fname_exp_fig) - f_exp.close() - -# For a given ddata (SimulationPaths object), find the mean dipole -def pdipole_exp2(ddata): - """ over ALL trials in ALL conditions in EACH experiment - appears to be an iteration on pdipole_exp() - """ - # grab the original dipole from a specific dir - dproj = fio.return_data_dir() - - runtype = 'somethingotherthandebug' - # runtype = 'debug' - - # really shoddy testing code! sorry! - if runtype == 'debug': - ddate = '2013-04-08' - dsim = 'mubaseline-15-000' - i_ctrl = 0 - else: - ddate = raw_input('Short date directory? ') - dsim = raw_input('Sim name? ') - i_ctrl = ast.literal_eval(raw_input('Sim number: ')) - dcheck = os.path.join(dproj, ddate, dsim) - - # create a blank ddata structure - ddata_ctrl = fio.SimulationPaths() - dsim = ddata_ctrl.read_sim(dproj, dcheck) - - # find the mu_low and mu_high in the expmtgroup names - # this means the group names must be well formed - for expmt_group in ddata_ctrl.expmt_groups: - if 'mu_low' in expmt_group: - mu_low_group = expmt_group - elif 'mu_high' in expmt_group: - mu_high_group = expmt_group - - # choose the first [0] from the list of the file matches for mu_low - fdpl_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawdpl')[i_ctrl] - fparam_mu_low = ddata_ctrl.file_match(mu_low_group, 'param')[i_ctrl] - fspk_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawspk')[i_ctrl] - fspec_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawspec')[i_ctrl] - - # choose the first [0] from the list of the file matches for mu_high - fdpl_mu_high = ddata_ctrl.file_match(mu_high_group, 'rawdpl')[i_ctrl] - fparam_mu_high = ddata_ctrl.file_match(mu_high_group, 'param')[i_ctrl] - # fspk_mu_high = ddata_ctrl.file_match(mu_high_group, 'rawspk')[i_ctrl] - - # grab the relevant dipole and renormalize it for mu_low - dpl_mu_low = Dipole(fdpl_mu_low) - dpl_mu_low.baseline_renormalize(fparam_mu_low) - - # grab the relevant dipole and renormalize it for mu_high - dpl_mu_high = Dipole(fdpl_mu_high) - dpl_mu_high.baseline_renormalize(fparam_mu_high) - - # input feed information - s = spikefn.spikes_from_file(fparam_mu_low, fspk_mu_low) - _, p_ctrl = paramrw.read(fparam_mu_low) - s = spikefn.alpha_feed_verify(s, p_ctrl) - s = spikefn.add_delay_times(s, p_ctrl) - - # hard coded bin count for now - tstop = paramrw.find_param(fparam_mu_low, 'tstop') - bins = spikefn.bin_count(150., tstop) - - # sim_prefix - fprefix = ddata.sim_prefix - - # create the figure name - fname_exp = '%s_dpl' % (fprefix) - fname_exp_fig = os.path.join(ddata.dsim, fname_exp + '.png') - - # create one figure comparing across all - N_expmt_groups = len(ddata.expmt_groups) - ax_handles = [ - 'spec', - 'input', - 'dpl_mu_low', - 'dpl_mu_high', - ] - f_exp = ac.FigDipoleExp(ax_handles) - - # plot the ctrl dipoles - f_exp.ax['dpl_mu_low'].plot(dpl_mu_low.t, dpl_mu_low.dpl['agg'], color='k') - f_exp.ax['dpl_mu_low'].hold(True) - f_exp.ax['dpl_mu_high'].plot(dpl_mu_high.t, dpl_mu_high.dpl['agg'], color='k') - f_exp.ax['dpl_mu_high'].hold(True) - - # function creates an f_exp.ax_twinx list and returns the index of the new feed - ax_twin_name = f_exp.create_axis_twinx('input') - if not ax_twin_name: - print("You've got bigger 
problems, I'm afraid") - - # input hist information: predicated on the fact that the input histograms - # should be identical for *all* of the inputs represented in this figure - spikefn.pinput_hist(f_exp.ax['input'], f_exp.ax_twinx['input'], s['alpha_feed_prox'][0].spike_list, s['alpha_feed_dist'][0].spike_list, n_bins) - - # grab the max counts for both hists - # the [0] item of hist are the counts - max_hist = np.max([np.max(hist[key][0]) for key in hist.keys()]) - ymax = 2 * max_hist - - # plot the spec here - pc = specfn.pspec_ax(f_exp.ax['spec'], fspec_mu_low) - print(f_exp.ax[0].get_xlim()) - - # deal with the axes here - f_exp.ax_twinx['input'].set_ylim((ymax, 0)) - f_exp.ax['input'].set_ylim((0, ymax)) - - f_exp.ax['input'].set_xlim((50., tstop)) - f_exp.ax_twinx['input'].set_xlim((50., tstop)) - - # empty list for the aggregate dipole data - dpl_exp = [] - - # go through each expmt - # calculation is extremely redundant - for expmt_group in ddata.expmt_groups: - # a little sloppy, just find the param file - # this param file was for the baseline renormalization and - # assumes it's the same in all for this expmt_group - # also for getting the gid_dict, also assumed to be the same - fparam = ddata.file_match(expmt_group, 'param')[0] - - # general check to see if the aggregate dipole data exists - if 'mu_low' in expmt_group or 'mu_high' in expmt_group: - # check to see if these files exist - flist = ddata.find_aggregate_file(expmt_group, 'dpl') - - # if no file exists, then find one - if not len(flist): - calc_aggregate_dipole(ddata) - flist = ddata.find_aggregate_file(expmt_group, 'dpl') - - # testing the first file - list_spk = ddata.file_match(expmt_group, 'rawspk') - list_s_dict = [spikefn.spikes_from_file(fparam, fspk) for fspk in list_spk] - list_evoked = [s_dict['evprox0'].spike_list[0][0] for s_dict in list_s_dict] - lines_spk = [f_exp.ax[2].axvline(x=x_val, linewidth=0.5, color='r') for x_val in list_evoked] - lines_spk = [f_exp.ax[3].axvline(x=x_val, linewidth=0.5, color='r') for x_val in list_evoked] - - # handle mu_low and mu_high separately - if 'mu_low' in expmt_group: - dpl_mu_low_ev = Dipole(flist[0]) - dpl_mu_low_ev.baseline_renormalize(fparam) - f_exp.ax['dpl_mu_low'].plot(dpl_mu_low_ev.t, dpl_mu_low_ev.dpl['agg']) - - elif 'mu_high' in expmt_group: - dpl_mu_high_ev = Dipole(flist[0]) - dpl_mu_high_ev.baseline_renormalize(fparam) - f_exp.ax['dpl_mu_high'].plot(dpl_mu_high_ev.t, dpl_mu_high_ev.dpl['agg']) - - f_exp.ax['dpl_mu_low'].set_xlim(50., tstop) - f_exp.ax['dpl_mu_high'].set_xlim(50., tstop) - - f_exp.savepng(fname_exp_fig) - f_exp.close() - -# For a given ddata (SimulationPaths object), find the mean dipole -def pdipole_evoked_aligned(ddata): - """ over ALL trials in ALL conditions in EACH experiment - appears to be iteration over pdipole_exp2() - """ - # grab the original dipole from a specific dir - dproj = fio.return_data_dir() - - runtype = 'somethingotherthandebug' - # runtype = 'debug' - - if runtype == 'debug': - ddate = '2013-04-08' - dsim = 'mubaseline-04-000' - i_ctrl = 0 - else: - ddate = raw_input('Short date directory? ') - dsim = raw_input('Sim name? 
') - i_ctrl = ast.literal_eval(raw_input('Sim number: ')) - dcheck = os.path.join(dproj, ddate, dsim) - - # create a blank ddata structure - ddata_ctrl = fio.SimulationPaths() - dsim = ddata_ctrl.read_sim(dproj, dcheck) - - # find the mu_low and mu_high in the expmtgroup names - # this means the group names must be well formed - for expmt_group in ddata_ctrl.expmt_groups: - if 'mu_low' in expmt_group: - mu_low_group = expmt_group - elif 'mu_high' in expmt_group: - mu_high_group = expmt_group - - # choose the first [0] from the list of the file matches for mu_low - fdpl_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawdpl')[i_ctrl] - fparam_mu_low = ddata_ctrl.file_match(mu_low_group, 'param')[i_ctrl] - fspk_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawspk')[i_ctrl] - fspec_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawspec')[i_ctrl] - - # choose the first [0] from the list of the file matches for mu_high - fdpl_mu_high = ddata_ctrl.file_match(mu_high_group, 'rawdpl')[i_ctrl] - fparam_mu_high = ddata_ctrl.file_match(mu_high_group, 'param')[i_ctrl] - - # grab the relevant dipole and renormalize it for mu_low - dpl_mu_low = Dipole(fdpl_mu_low) - dpl_mu_low.baseline_renormalize(fparam_mu_low) - - # grab the relevant dipole and renormalize it for mu_high - dpl_mu_high = Dipole(fdpl_mu_high) - dpl_mu_high.baseline_renormalize(fparam_mu_high) - - # input feed information - s = spikefn.spikes_from_file(fparam_mu_low, fspk_mu_low) - _, p_ctrl = paramrw.read(fparam_mu_low) - s = spikefn.alpha_feed_verify(s, p_ctrl) - s = spikefn.add_delay_times(s, p_ctrl) - - # find tstop, assume same over all. grab the first param file, get the tstop - tstop = paramrw.find_param(fparam_mu_low, 'tstop') - - # hard coded bin count for now - n_bins = spikefn.bin_count(150., tstop) - - # sim_prefix - fprefix = ddata.sim_prefix - - # create the figure name - fname_exp = '%s_dpl_align' % (fprefix) - fname_exp_fig = os.path.join(ddata.dsim, fname_exp + '.png') - - # create one figure comparing across all - N_expmt_groups = len(ddata.expmt_groups) - ax_handles = [ - 'spec', - 'input', - 'dpl_mu', - 'spk', - ] - f_exp = ac.FigDipoleExp(ax_handles) - - # plot the ctrl dipoles - f_exp.ax['dpl_mu'].plot(dpl_mu_low.t, dpl_mu_low.dpl, color='k') - f_exp.ax['dpl_mu'].hold(True) - f_exp.ax['dpl_mu'].plot(dpl_mu_high.t, dpl_mu_high.dpl) - - # function creates an f_exp.ax_twinx list and returns the index of the new feed - f_exp.create_axis_twinx('input') - - # input hist information: predicated on the fact that the input histograms - # should be identical for *all* of the inputs represented in this figure - # places 2 histograms on two axes (meant to be one axis flipped) - hists = spikefn.pinput_hist(f_exp.ax['input'], f_exp.ax_twinx['input'], s['alpha_feed_prox'].spike_list, s['alpha_feed_dist'].spike_list, n_bins) - - # grab the max counts for both hists - # the [0] item of hist are the counts - max_hist = np.max([np.max(hists[key][0]) for key in hists.keys()]) - ymax = 2 * max_hist - - # plot the spec here - pc = specfn.pspec_ax(f_exp.ax['spec'], fspec_mu_low) - - # deal with the axes here - f_exp.ax['input'].set_ylim((0, ymax)) - f_exp.ax_twinx['input'].set_ylim((ymax, 0)) - # f_exp.ax[1].set_ylim((0, ymax)) - - # f_exp.ax[1].set_xlim((50., tstop)) - - # turn hold on - f_exp.ax[dpl_mu].hold(True) - - # empty list for the aggregate dipole data - dpl_exp = [] - - # go through each expmt - # calculation is extremely redundant - for expmt_group in ddata.expmt_groups: - # a little sloppy, just find the param file - # this 
param file was for the baseline renormalization and - # assumes it's the same in all for this expmt_group - # also for getting the gid_dict, also assumed to be the same - fparam = ddata.file_match(expmt_group, 'param')[0] - - # general check to see if the aggregate dipole data exists - if 'mu_low' in expmt_group or 'mu_high' in expmt_group: - # check to see if these files exist - flist = ddata.find_aggregate_file(expmt_group, 'dpl') - - # if no file exists, then find one - if not len(flist): - calc_aggregate_dipole(ddata) - flist = ddata.find_aggregate_file(expmt_group, 'dpl') - - # testing the first file - list_spk = ddata.file_match(expmt_group, 'rawspk') - list_s_dict = [spikefn.spikes_from_file(fparam, fspk) for fspk in list_spk] - list_evoked = [s_dict['evprox0'].spike_list[0][0] for s_dict in list_s_dict] - lines_spk = [f_exp.ax['dpl_mu'].axvline(x=x_val, linewidth=0.5, color='r') for x_val in list_evoked] - lines_spk = [f_exp.ax['spk'].axvline(x=x_val, linewidth=0.5, color='r') for x_val in list_evoked] - - # handle mu_low and mu_high separately - if 'mu_low' in expmt_group: - dpl_mu_low_ev = Dipole(flist[0]) - dpl_mu_low_ev.baseline_renormalize(fparam) - f_exp.ax['spk'].plot(dpl_mu_low_ev.t, dpl_mu_low_ev.dpl, color='k') - - # get xlim stuff - t0 = dpl_mu_low_ev.t[0] - T = dpl_mu_low_ev.t[-1] - - elif 'mu_high' in expmt_group: - dpl_mu_high_ev = Dipole(flist[0]) - dpl_mu_high_ev.baseline_renormalize(fparam) - f_exp.ax['spk'].plot(dpl_mu_high_ev.t, dpl_mu_high_ev.dpl, color='b') - - f_exp.ax['input'].set_xlim(50., tstop) - - for ax_name in f_exp.ax_handles[2:]: - ax.set_xlim((t0, T)) - - f_exp.savepng(fname_exp_fig) - f_exp.close() - -# create a grid of all dipoles in this dir -def pdipole_grid(ddata): - # iterate through expmt_groups - for expmt_group in ddata.expmt_groups: - fname_short = "%s-%s-dpl.png" % (ddata.sim_prefix, expmt_group) - fname = os.path.join(ddata.dsim, expmt_group, fname_short) - - # simple usage, just checks how many dipole files (total in an expmt) - # and then plots dumbly to a grid - dpl_list = ddata.file_match(expmt_group, 'rawdpl') - param_list = ddata.file_match(expmt_group, 'param') - - # assume tstop is the same everywhere - tstop = paramrw.find_param(param_list[0], 'tstop') - - # length of the dpl list - N_dpl = len(dpl_list) - - # make a 5-col figure - N_cols = 5 - - # force int arithmetic - # this is the BASE number of rows, one might be added! 
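(A hedged worked example for the row count computed below: with 12 dipole files and 5 columns, 12 // 5 gives 2 base rows, and the nonzero remainder adds a third.)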
- N_rows = int(N_dpl) // int(N_cols) - - # if the mod is not 0, add a row - if (N_dpl % N_cols): - N_rows += 1 - - # print(N_dpl, N_cols, N_rows) - f = ac.FigGrid(N_rows, N_cols, tstop) - - l = [] - r = 0 - for ax_list in f.ax: - l.extend([(r,c) for c in range(len(ax_list))]) - r += 1 - - # automatically truncates the loc list to the size of dpl_list - for loc, fdpl, fparam in zip(l, dpl_list, param_list): - r = loc[0] - c = loc[1] - pdipole_ax(f.ax[r][c], fdpl, fparam) - - f.savepng(fname) - f.close() - -def plot_specmax_interval(fname, dpl_list, param_list, specmax_list): - N_trials = len([d for d in specmax_list if d is not None]) - - # instantiate figure - f = ac.FigInterval(N_trials+1) - - # set spacing between plots - spacers = np.arange(0.5e-4, N_trials*1e-4, 1e-4) - - # invert order of spacers so first trial is at top of plot - spacers = spacers[::-1] - - # iterate over various lists and plot to axis - i = 0 - - for fdpl, fparam, dmax in zip(dpl_list, param_list, specmax_list): - # for fdpl, dmax, space in zip(dpl_list, specmax_list, spacers): - if dmax is not None: - # load ts data - dpl = Dipole(fdpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - # truncate data based on time ranges specified in dmax - t_cut, dpl_tcut = dpl.truncate_ext(dmax['t_int'][0], dmax['t_int'][1]) - - # create a tvec that is symmetric about zero and of proper length - t_range = dmax['t_int'][1] - dmax['t_int'][0] - t_start = 0 - t_range / 2. - t_end = 0 + t_range / 2. - tvec = np.linspace(t_start, t_end, dpl_tcut['agg'].shape[0]) - - # plot to proper height - f.ax['ts'].plot(tvec, dpl_tcut['agg']+spacers[i]) - - # add text with pertinent information - x_offset = f.ax['ts'].get_xlim()[1] + 25 - f.ax['ts'].text(x_offset, spacers[i], 'freq: %s Hz\ntime: %s ms\n%s' %(dmax['f_at_max'],dmax['t_at_max'], dmax['fname']), fontsize=12, verticalalignment='center') - - i += 1 - - # force xlim for now - # f.ax['ts'].set_xlim(-100, 100) - - # save fig - f.savepng(fname+'.png') - - # close fig - f.close() diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000..29c996385 --- /dev/null +++ b/environment.yml @@ -0,0 +1,12 @@ +name: hnn +channels: +- defaults +dependencies: +- python=3.7 +- pip +- numpy +- scipy +- matplotlib +- psutil +- pip: + - PyQt5 \ No newline at end of file diff --git a/example_analysis.py b/example_analysis.py deleted file mode 100644 index b1d6b2ce4..000000000 --- a/example_analysis.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python -# example_analysis.py - Example for an analysis workflow -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: changed it.izip()) -# last major: (RL: added dipole loading) - -import numpy as np -import fileio as fio -import paramrw -import dipolefn -import PT_example - -# example simulation -def example_analysis_for_simulation(): - # from these two directories - droot = fio.return_data_dir() - dsim = os.path.join(droot, '2016-02-03/beta-sweep-000') - - # create the SimulationPaths() object ddata and read the simulation - ddata = fio.SimulationPaths() - ddata.read_sim(droot, dsim) - - # print dir(ddata) - # print type(np.zeros(5)) - - # print ddata.expmt_groups - # print ddata.fparam - # for key, val in ddata.dfig['testing'].items(): - # print key, val - # print dir({}) - - # p_exp = paramrw.ExpParams(ddata.fparam) - # print p_exp.p_all['dt'] - # # print p_exp.p_all - - # iterate through experimental groups and do the analysis on individual files, etc. 
- for expmt_group in ddata.expmt_groups: - print "experiment group is: %s" % (expmt_group) - # print ddata.dfig[expmt_group] - flist_param = ddata.file_match(expmt_group, 'param') - flist_dpl = ddata.file_match(expmt_group, 'rawdpl') - # flist_spk = ddata.file_match(expmt_group, 'rawspk') - # fio.prettyprint(flist_spk) - - # iterate through files in the lists - for fparam, fdpl in zip(flist_param, flist_dpl): - # print fparam, fdpl - gid_dict, p_tr = paramrw.read(fparam) - - # for key, val in p_tr.items(): - # print key, val - # fio.prettyprint(p_tr.keys()) - - # create and load dipole data structure - d = dipolefn.Dipole(fdpl) - - # more or less analysis goes here. - # generate a filename for a dipole plot - fname_png = ddata.return_filename_example('figdpl', expmt_group, p_tr['Sim_No'], tr=p_tr['Trial']) - # print p_tr['Trial'], p_tr['Sim_No'], fname_png - - # example figure for this pair of files - fig = PT_example.FigExample() - - # plot dipole - fig.ax['dipole'].plot(d.t, d.dpl['agg']) - fig.ax['dipole'].plot(d.t, d.dpl['L2']) - fig.ax['dipole'].plot(d.t, d.dpl['L5']) - fig.savepng(fname_png) - fig.close() - -if __name__ == '__main__': - example_analysis_for_simulation() diff --git a/feed.py b/feed.py deleted file mode 100644 index 75324b960..000000000 --- a/feed.py +++ /dev/null @@ -1,224 +0,0 @@ -# feed.py - establishes FeedExt(), ParFeedAll() -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: updated for python3) -# last major: (SL: toward python3) - -import numpy as np -import itertools as it # this used? -from neuron import h - -class ParFeedAll (): - # p_ext has a different structure for the extinput - # usually, p_ext is a dict of cell types - def __init__ (self, ty, celltype, p_ext, gid): - #print("ParFeedAll __init__") - # VecStim setup - self.eventvec = h.Vector() - self.vs = h.VecStim() - # self.p_unique = p_unique[type] - self.p_ext = p_ext - self.celltype = celltype - self.ty = ty # feed type - self.gid = gid - self.set_prng() # sets seeds for random num generator - self.set_event_times() # sets event times into self.eventvec and plays into self.vs (VecStim) - - # inc random number generator seeds - def inc_prng (self, inc): - self.seed += inc - self.prng = np.random.RandomState(self.seed) - if hasattr(self,'seed2'): - self.seed2 += inc - self.prng2 = np.random.RandomState(self.seed2) - - def set_prng (self, seed = None): - if seed is None: # no seed specified then use p_ext to determine seed - # random generator for this instance - # qnd hack to make the seeds the same across all gids - # for just evoked - if self.ty.startswith(('evprox', 'evdist')): - if self.p_ext['sync_evinput']: - self.seed = self.p_ext['prng_seedcore'] - else: - self.seed = self.p_ext['prng_seedcore'] + self.gid - elif self.ty.startswith('extinput'): - self.seed = self.p_ext['prng_seedcore'] + self.gid # seed for events assuming a given start time - self.seed2 = self.p_ext['prng_seedcore'] # separate seed for start times - else: - self.seed = self.p_ext['prng_seedcore'] + self.gid - else: # if seed explicitly specified use it - self.seed = seed - if hasattr(self,'seed2'): self.seed2 = seed - self.prng = np.random.RandomState(self.seed) - if hasattr(self,'seed2'): self.prng2 = np.random.RandomState(self.seed2) - #print('ty,seed:',self.ty,self.seed) - - def set_event_times (self, inc_evinput = 0.0): - # print('self.p_ext:',self.p_ext) - # each of these methods creates self.eventvec for playback - if self.ty == 'extpois': - self.__create_extpois() - elif self.ty.startswith(('evprox', 'evdist')): - 
self.__create_evoked(inc_evinput) - elif self.ty == 'extgauss': - self.__create_extgauss() - elif self.ty == 'extinput': - self.__create_extinput() - # load eventvec into VecStim object - self.vs.play(self.eventvec) - - # based on cdf for exp wait time distribution from unif [0, 1) - # returns in ms based on lamtha in Hz - def __t_wait (self, lamtha): - return -1000. * np.log(1. - self.prng.rand()) / lamtha - - # new external pois designation - def __create_extpois (self): - #print("__create_extpois") - if self.p_ext[self.celltype][0] <= 0.0 and \ - self.p_ext[self.celltype][1] <= 0.0: return False # 0 ampa and 0 nmda weight - # check the t interval - t0 = self.p_ext['t_interval'][0] - T = self.p_ext['t_interval'][1] - lamtha = self.p_ext[self.celltype][3] # index 3 is frequency (lamtha) - # values MUST be sorted for VecStim()! - # start the initial value - if lamtha > 0.: - val_pois = np.array([]) - t_gen = t0 - lamtha * 2 # start before t0 to remove artifact of lower event rate at start of period - while t_gen < T: - t_gen += self.__t_wait(lamtha) # move forward by the wait time (so as to not clobber base off of t_gen) - if t_gen >= t0 and t_gen < T: # make sure event time is within the specified interval - # vals are guaranteed to be monotonically increasing, no need to sort - val_pois = np.append(val_pois, t_gen) - else: - val_pois = np.array([]) - # checks the distribution stats - # if len(val_pois): - # xdiff = np.diff(val_pois/1000) - # print(lamtha, np.mean(xdiff), np.var(xdiff), 1/lamtha**2) - # Convert array into nrn vector - # if len(val_pois)>0: print('val_pois:',val_pois) - self.eventvec.from_python(val_pois) - return self.eventvec.size() > 0 - - # mu and sigma vals come from p - def __create_evoked (self, inc=0.0): - #print("__create_evoked", self.p_ext) - if self.celltype in self.p_ext.keys(): - # assign the params - mu = self.p_ext['t0'] + inc - sigma = self.p_ext[self.celltype][3] # index 3 is sigma_t_ (stdev) - numspikes = int(self.p_ext['numspikes']) - # print('mu:',mu,'sigma:',sigma,'inc:',inc) - # if a non-zero sigma is specified - if sigma: - val_evoked = self.prng.normal(mu, sigma, numspikes) - else: - # if sigma is specified at 0 - val_evoked = np.array([mu] * numspikes) - val_evoked = val_evoked[val_evoked > 0] - # vals must be sorted - val_evoked.sort() - # print('__create_evoked val_evoked:',val_evoked) - self.eventvec.from_python(val_evoked) - else: - # return an empty eventvec list - self.eventvec.from_python([]) - return self.eventvec.size() > 0 - - def __create_extgauss (self): - # print("__create_extgauss") - # assign the params - if self.p_ext[self.celltype][0] <= 0.0 and \ - self.p_ext[self.celltype][1] <= 0.0: return False # 0 ampa and 0 nmda weight - # print('gauss params:',self.p_ext[self.celltype]) - mu = self.p_ext[self.celltype][3] - sigma = self.p_ext[self.celltype][4] - # mu and sigma values come from p - # one single value from Gaussian dist. - # values MUST be sorted for VecStim()! 
- val_gauss = self.prng.normal(mu, sigma, 50) - # val_gauss = np.random.normal(mu, sigma, 50) - # remove non-zero values brute force-ly - val_gauss = val_gauss[val_gauss > 0] - # sort values - critical for nrn - val_gauss.sort() - # if len(val_gauss)>0: print('val_gauss:',val_gauss) - # Convert array into nrn vector - self.eventvec.from_python(val_gauss) - return self.eventvec.size() > 0 - - def __create_extinput (self): # creates the ongoing external inputs (rhythmic) - #print("__create_extinput") - # store f_input as self variable for later use if it exists in p - # t0 is always defined - t0 = self.p_ext['t0'] - # If t0 is -1, randomize start time of inputs - if t0 == -1: - t0 = self.prng.uniform(25., 125.) - #print(self.ty,'t0 was -1; now', t0,'seed:',self.seed) - elif self.p_ext['t0_stdev'] > 0.0: # randomize start time based on t0_stdev - t0 = self.prng2.normal(t0, self.p_ext['t0_stdev']) # start time uses different prng - #print(self.ty,'t0 is', t0, 'seed:',self.seed,'seed2:',self.seed2) - f_input = self.p_ext['f_input'] - stdev = self.p_ext['stdev'] - events_per_cycle = self.p_ext['events_per_cycle'] - distribution = self.p_ext['distribution'] - # events_per_cycle = 1 - if events_per_cycle > 2 or events_per_cycle <= 0: - print("events_per_cycle should be either 1 or 2, trying 2") - events_per_cycle = 2 - # If frequency is 0, create empty vector if input times - if not f_input: - t_input = [] - elif distribution == 'normal': - # array of mean stimulus times, starts at t0 - isi_array = np.arange(t0, self.p_ext['tstop'], 1000. / f_input) - # array of single stimulus times -- no doublets - if stdev: - t_array = self.prng.normal(np.repeat(isi_array, self.p_ext['repeats']), stdev) - else: - t_array = isi_array - if events_per_cycle == 2: # spikes/burst in GUI - # Two arrays store doublet times - t_array_low = t_array - 5 - t_array_high = t_array + 5 - # Array with ALL stimulus times for input - # np.append concatenates two np arrays - t_input = np.append(t_array_low, t_array_high) - elif events_per_cycle == 1: - t_input = t_array - # brute force remove zero times. Might result in fewer vals than desired - t_input = t_input[t_input > 0] - t_input.sort() - # Uniform Distribution - elif distribution == 'uniform': - n_inputs = self.p_ext['repeats'] * f_input * (self.p_ext['tstop'] - t0) / 1000. - t_array = self.prng.uniform(t0, self.p_ext['tstop'], n_inputs) - if events_per_cycle == 2: - # Two arrays store doublet times - t_input_low = t_array - 5 - t_input_high = t_array + 5 - # Array with ALL stimulus times for input - # np.append concatenates two np arrays - t_input = np.append(t_input_low, t_input_high) - elif events_per_cycle == 1: - t_input = t_array - # brute force remove non-zero times. Might result in fewer vals than desired - t_input = t_input[t_input > 0] - t_input.sort() - else: - print("Indicated distribution not recognized. Not making any alpha feeds.") - t_input = [] - # Convert array into nrn vector - self.eventvec.from_python(t_input) - return self.eventvec.size() > 0 - - # for parallel, maybe be that postsyn for this is just nil (None) - def connect_to_target (self, threshold): - #print("connect_to_target") - nc = h.NetCon(self.vs, None) # why is target always nil?? 
- nc.threshold = threshold - return nc diff --git a/fileio.py b/fileio.py deleted file mode 100644 index 1fd9b337e..000000000 --- a/fileio.py +++ /dev/null @@ -1,368 +0,0 @@ -# fileio.py - general file input/output functions -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: return_data_dir() instead of hardcoded everywhere, etc.) -# last rev: (SL: toward python3) - -import datetime, fnmatch, os, shutil, sys -import subprocess, multiprocessing -import numpy as np -import paramrw - -# creates data dirs and a dictionary of useful types -# self.dfig is a dictionary of experiments, which is each a dictionary of data type -# keys and the specific directories that contain them. -class SimulationPaths (): - def __init__ (self, dbase=None): - # hard coded data types - # fig extensions are not currently being used as well as they could be - # add new directories here to be automatically created for every simulation - self.__datatypes = {'rawspk': 'spk.txt', - 'rawdpl': 'rawdpl.txt', - 'normdpl': 'dpl.txt', # same output name - do not need both raw and normalized dipole - unless debugging - 'rawcurrent': 'i.txt', - 'rawspec': 'spec.npz', - 'rawspeccurrent': 'speci.npz', - 'avgdpl': 'dplavg.txt', - 'avgspec': 'specavg.npz', - 'figavgdpl': 'dplavg.png', - 'figavgspec': 'specavg.png', - 'figdpl': 'dpl.png', - 'figspec': 'spec.png', - 'figspk': 'spk.png', - 'param': 'param.txt', - } - # empty until a sim is created or read - self.fparam = None - self.sim_prefix = None - self.trial_prefix_str = None - self.expmt_groups = [] - self.dproj = None - self.ddate = None - self.dsim = None - self.dexpmt_dict = {} - self.dfig = {} - if dbase is None: - self.dbase = os.path.join(os.path.expanduser('~'),'hnn') - else: - self.dbase = dbase - - # reads sim information based on sim directory and param files - def read_sim (self, dproj, dsim): - self.dproj = dproj - self.dsim = dsim - # match the param from this sim - self.fparam = file_match(dsim, '.param')[0] - self.expmt_groups = paramrw.read_expmt_groups(self.fparam) - self.sim_prefix = paramrw.read_sim_prefix(self.fparam) - # this should somehow be supplied by the ExpParams() class, but doing it here - self.trial_prefix_str = self.sim_prefix + "-%03d-T%02d" - self.dexpmt_dict = self.__create_dexpmt(self.expmt_groups) - # create dfig - self.dfig = self.__read_dirs() - return self.dsim - - # only run for the creation of a new simulation - def create_new_sim (self, dproj, expmt_groups, sim_prefix='test'): - self.dproj = dproj - self.expmt_groups = expmt_groups - # prefix for these simulations in both filenames and directory in ddate - self.sim_prefix = sim_prefix - # create date and sim directories if necessary - self.ddate = self.__datedir() - self.dsim = self.__simdir() - self.dexpmt_dict = self.__create_dexpmt(self.expmt_groups) - # dfig is just a record of all the fig directories, per experiment - # will only be written to at time of creation, by create_dirs - # dfig is a terrible variable name, sorry! 
- self.dfig = self.__ddata_dict_template() - - # this is a hack - # checks root expmt_group directory for any files i've thrown there - def find_aggregate_file (self, expmt_group, datatype): - # file name is in format: '%s-%s-%s' % (sim_prefix, expmt_group, datatype-ish) - fname = '%s-%s-%s.txt' % (self.sim_prefix, expmt_group, datatype) - # get a list of txt files in the expmt_group - # local=1 forces the search to be local to this directory and not recursive - local = 1 - flist = file_match(self.dexpmt_dict[expmt_group], fname, local) - return flist - - # returns a filename for an example type of data - def return_filename_example (self, datatype, expmt_group, sim_no=None, tr=None, ext='png'): - fname_short = "%s-%s" % (self.sim_prefix, expmt_group) - if sim_no is not None: fname_short += "-%03i" % (sim_no) - if tr is not None: fname_short += "-T%03i" % (tr) - # add the extension - fname_short += ".%s" % (ext) - fname = os.path.join(self.dfig[expmt_group][datatype], fname_short) - return fname - - # creates a dict of dicts for each experiment and all the datatype directories - # this is the empty template that gets filled in later. - def __ddata_dict_template (self): - dfig = dict.fromkeys(self.expmt_groups) - for key in dfig: dfig[key] = dict.fromkeys(self.__datatypes) - return dfig - - # read directories for an already existing sim - def __read_dirs (self): - dfig = self.__ddata_dict_template() - for expmt_group, dexpmt in self.dexpmt_dict.items(): - for key in self.__datatypes.keys(): - ddatatype = os.path.join(dexpmt, key) - dfig[expmt_group][key] = ddatatype - return dfig - - # create the data directory for the sim - def create_datadir (self): - dout = self.__simdir() - print('making dout:',dout) - safemkdir(dout) - - # extern function to create directories - def create_dirs (self): - # create expmt directories - for expmt_group, dexpmt in self.dexpmt_dict.items(): - dir_create(dexpmt) - for key in self.__datatypes.keys(): - ddatatype = os.path.join(dexpmt, key) - self.dfig[expmt_group][key] = ddatatype - dir_create(ddatatype) - - # Returns date directory - # this is NOT safe for midnight - def __datedir (self): - self.str_date = datetime.datetime.now().strftime("%Y-%m-%d") - ddate = os.path.join(self.dproj, self.str_date) - return ddate - - # returns the directory for the sim - def __simdir (self): - return os.path.join(self.dbase,'data',self.sim_prefix) - # return os.path.join(os.path.expanduser('~'),'hnn','data',self.sim_prefix) - - # creates all the experimental directories based on dproj - def __create_dexpmt (self, expmt_groups): - d = dict.fromkeys(expmt_groups) - for expmt_group in d: d[expmt_group] = os.path.join(self.dsim, expmt_group) - return d - - # dictionary creation - # this is specific to a expmt_group - def create_dict (self, expmt_group): - fileinfo = dict.fromkeys(self.__datatypes) - for key in self.__datatypes.keys(): - # join directory name - dtype = os.path.join(self.dexpmt_dict(expmt_group), key) - fileinfo[key] = (self.__datatypes[key], dtype) - return fileinfo - - def return_specific_filename(self, expmt_group, datatype, n_sim, n_trial): - f_list = self.file_match(expmt_group, datatype) - trial_prefix = self.trial_prefix_str % (n_sim, n_trial) - # assume there is only one match (this should be true) - f_datatype = [f for f in f_list if trial_prefix in f][0] - return f_datatype - - # requires dict lookup - def create_filename (self, expmt_group, key): - d = self.__simdir() - # some kind of if key in self.fileinfo.keys() catch - file_name_raw = 
self.__datatypes[key] - return os.path.join(d,file_name_raw) - # grab the whole experimental directory - dexpmt = self.dexpmt_dict[expmt_group] - # create the full path name for the file - file_path_full = os.path.join(dexpmt, key, file_name_raw) - return file_path_full - - # Get the data files matching file_ext in this directory - # functionally the same as the previous function but with a local scope - def file_match (self, expmt_group, key): - # grab the relevant fext - fext = self.__datatypes[key] - file_list = [] - ddata = self.__simdir() # - # search the sim directory for all relevant files - if os.path.exists(ddata): - for root, dirnames, filenames in os.walk(ddata): - for fname in fnmatch.filter(filenames, '*'+fext): file_list.append(os.path.join(root, fname)) - # sort file list? untested - file_list.sort() - return file_list - - def exp_files_of_type (self, datatype): - # create dict of experiments - d = dict.fromkeys(self.expmt_groups) - # create file lists that match the dict keys for only files for this experiment - # this all would be nicer with a freaking folder - for key in d: d[key] = [file for file in self.filelists[datatype] if key in file.split("/")[-1]] - return d - -# Cleans input files -def clean_lines (file): - with open(file) as f_in: - lines = (line.rstrip() for line in f_in) - lines = [line for line in lines if line] - return lines - -# this make a little more sense in fileio -def prettyprint (iterable_items): - for item in iterable_items: print(item) - -# create gid dict from a file -def gid_dict_from_file (fparam): - l = ['L2_pyramidal', 'L5_pyramidal', 'L2_basket', 'L5_basket', 'extinput'] - d = dict.fromkeys(l) - plist = clean_lines(fparam) - for param in plist: print(param) - -# create file name for temporary spike file -# that every processor is aware of -def file_spike_tmp (dproj): - filename_spikes = 'spikes_tmp.spk' - file_spikes = os.path.join(dproj, filename_spikes) - return file_spikes - -# this is ugly, potentially. sorry, future -# i.e will change when the file name format changes -def strip_extprefix (filename): - f_raw = filename.split("/")[-1] - f = f_raw.split(".")[0].split("-")[:-1] - ext_prefix = f.pop(0) - for part in f: ext_prefix += "-%s" % part - return ext_prefix - -# Get the data files matching file_ext in this directory -# this function traverses ALL directories -# local=1 makes the search local and not recursive -def file_match (dsearch, file_ext, local=0): - file_list = [] - if not local: - if os.path.exists(dsearch): - for root, dirnames, filenames in os.walk(dsearch): - for fname in fnmatch.filter(filenames, '*'+file_ext): - file_list.append(os.path.join(root, fname)) - else: - file_list = [os.path.join(dsearch, file) for file in os.listdir(dsearch) if file.endswith(file_ext)] - # sort file list? untested - file_list.sort() - return file_list - -# Get minimum list of param dicts (i.e. 
excludes duplicates due to N_trials > 1) -def fparam_match_minimal (dsim, p_exp): - # Complete list of all param dicts used in simulation - fparam_list_complete = file_match(dsim, '-param.txt') - # List of indices from which to pull param dicts from fparam_list_complete - N_trials = p_exp.N_trials - if not N_trials: N_trials = 1 - indexes = np.arange(0, len(fparam_list_complete), N_trials) - # Pull unique param dicts from fparam_list_complete - fparam_list_minimal = [fparam_list_complete[ind] for ind in indexes] - return fparam_list_minimal - -# check any directory -def dir_check (d): - if not os.path.isdir(d): return 0 - else: return os.path.isdir(d) - -# only create if check comes back 0 -def dir_create (d): - if not dir_check(d): os.makedirs(d) - -# non-destructive copy routine -def dir_copy (din, dout): - # this command should work on most posix systems - cmd_cp = 'cp -R %s %s' % (din, dout) - # if the dir doesn't already exist, copy it over - if not dir_check(dout): - # print the actual command when successful - print(cmd_cp) - # use call to run the command - subprocess.call(cmd_cp, shell=True) - return 0 - else: - print("Directory already exists.") - -# Finds and moves files to created subdirectories. -def subdir_move (dir_out, name_dir, file_pattern): - dir_name = os.path.join(dir_out, name_dir) - # create directories that do not exist - if not os.path.isdir(dir_name): os.mkdir(dir_name) - for filename in glob.iglob(os.path.join(dir_out, file_pattern)): shutil.move(filename, dir_name) - -# currently used only minimally in epscompress -# need to figure out how to change argument list in cmd as below -def cmds_runmulti (cmdlist): - n_threads = multiprocessing.cpu_count() - list_runs = [cmdlist[i:i+n_threads] for i in range(0, len(cmdlist), n_threads)] - # open devnull for writing extraneous output - with open(os.devnull, 'w') as devnull: - for sublist in list_runs: - procs = [subprocess.Popen(cmd, stdout=devnull, stderr=devnull) for cmd in sublist] - for proc in procs: proc.wait() - -# small kernel for png optimization based on fig directory -def pngoptimize (dfig): - local = 0 - pnglist = file_match(dfig, '.png', local) - cmds_opti = [('optipng', pngfile) for pngfile in pnglist] - cmds_runmulti(cmds_opti) - -# list spike raster eps files and then rasterize them to HQ png files, lossless compress, -# reencapsulate as eps, and remove backups when successful -def epscompress (dfig_spk, fext_figspk, local=0): - cmds_gs = [] - cmds_opti = [] - cmds_encaps = [] - n_threads = multiprocessing.cpu_count() - # lists of eps files and corresponding png files - # fext_figspk, dfig_spk = fileinfo['figspk'] - epslist = file_match(dfig_spk, fext_figspk, local) - pnglist = [f.replace('.eps', '.png') for f in epslist] - epsbackuplist = [f.replace('.eps', '.bak.eps') for f in epslist] - # create command lists for gs, optipng, and convert - for pngfile, epsfile in zip(pnglist, epslist): - cmds_gs.append(('gs -r300 -dEPSCrop -dTextAlphaBits=4 -sDEVICE=png16m -sOutputFile=%s -dBATCH -dNOPAUSE %s' % (pngfile, epsfile))) - cmds_opti.append(('optipng', pngfile)) - cmds_encaps.append(('convert %s eps3:%s' % (pngfile, epsfile))) - # create procs list of manageable lists and run - runs_gs = [cmds_gs[i:i+n_threads] for i in range(0, len(cmds_gs), n_threads)] - # run each sublist differently - for sublist in runs_gs: - procs_gs = [subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) for cmd in sublist] - for proc in procs_gs: proc.wait() - # create optipng procs list and run - 
cmds_runmulti(cmds_opti) - # backup original eps files temporarily - for epsfile, epsbakfile in zip(epslist, epsbackuplist): shutil.move(epsfile, epsbakfile) - # recreate original eps files, now encapsulated, optimized rasters - # cmds_runmulti(cmds_encaps) - runs_encaps = [cmds_encaps[i:i+n_threads] for i in range(0, len(cmds_encaps), n_threads)] - for sublist in runs_encaps: - procs_encaps = [subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) for cmd in sublist] - for proc in procs_encaps: proc.wait() - # remove all of the backup files - for epsbakfile in epsbackuplist: os.remove(epsbakfile) - -# make dir, catch exceptions -def safemkdir (dn): - try: - os.mkdir(dn) - except FileExistsError: - pass - except OSError: - print('ERR: incorrect permissions for creating', dn) - raise - - return True - -# returns the data dir -def return_data_dir (): - dfinal = os.path.join('.','data') - if not safemkdir(dfinal): sys.exit(1) - return dfinal - -if __name__ == '__main__': - return_data_dir() diff --git a/filt.py b/filt.py deleted file mode 100644 index e592a20ce..000000000 --- a/filt.py +++ /dev/null @@ -1,444 +0,0 @@ -from pylab import convolve -from numpy import hamming -import numpy as np - -# box filter -def boxfilt (x, winsz): - win = [1.0/winsz for i in range(int(winsz))] - return convolve(x,win,'same') - -# convolve with a hamming window -def hammfilt (x, winsz): - win = hamming(winsz) - win /= sum(win) - return convolve(x,win,'same') - -# returns x -def emptyfilt (x, winsz): - return np.array(x) - -# following from -#------------------------------------------------------------------- -# Purpose: Various Seismogram Filtering Functions -# Author: Tobias Megies, Moritz Beyreuther, Yannik Behr -# Email: tobias.megies@geophysik.uni-muenchen.de -# -# Copyright (C) 2009 Tobias Megies, Moritz Beyreuther, Yannik Behr -#--------------------------------------------------------------------- -""" -Various Seismogram Filtering Functions - -:copyright: - The ObsPy Development Team (devs@obspy.org) -:license: - GNU Lesser General Public License, Version 3 - (http://www.gnu.org/copyleft/lesser.html) -""" - -from numpy import array, where, fft -from scipy.fftpack import hilbert -from scipy.signal import iirfilter, lfilter, remez, convolve, get_window, butter -import math -from numpy import vstack, hstack, eye, ones, zeros, linalg, \ -newaxis, r_, flipud, convolve, matrix, array, vectorize, angle - -def bandpass(data, freqmin, freqmax, df=200, corners=4, zerophase=False): - """ - Butterworth-Bandpass Filter. - - Filter data from freqmin to freqmax using - corners corners. - - :param data: Data to filter, type numpy.ndarray. - :param freqmin: Pass band low corner frequency. - :param freqmax: Pass band high corner frequency. - :param df: Sampling rate in Hz; Default 200. - :param corners: Filter corners. Note: This is twice the value of PITSA's - filter sections - :param zerophase: If True, apply filter once forwards and once backwards. - This results in twice the number of corners but zero phase shift in - the resulting filtered trace. - :return: Filtered data. - """ - fe = 0.5 * df - [b, a] = iirfilter(corners, [freqmin / fe, freqmax / fe], btype='band', - ftype='butter', output='ba') - if zerophase: - firstpass = lfilter(b, a, data) - return lfilter(b, a, firstpass[::-1])[::-1] - else: - return lfilter(b, a, data) - - -def bandpassZPHSH(data, freqmin, freqmax, df=200, corners=2): - """ - DEPRECATED. Use :func:`~obspy.signal.filter.bandpass` instead. 
- """ - return bandpass(data, freqmin, freqmax, df, corners, zerophase=True) - - -def bandstop(data, freqmin, freqmax, df=200, corners=4, zerophase=False): - """ - Butterworth-Bandstop Filter. - - Filter data removing data between frequencies freqmin and freqmax using - corners corners. - - :param data: Data to filter, type numpy.ndarray. - :param freqmin: Stop band low corner frequency. - :param freqmax: Stop band high corner frequency. - :param df: Sampling rate in Hz; Default 200. - :param corners: Filter corners. Note: This is twice the value of PITSA's - filter sections - :param zerophase: If True, apply filter once forwards and once backwards. - This results in twice the number of corners but zero phase shift in - the resulting filtered trace. - :return: Filtered data. - """ - fe = 0.5 * df - [b, a] = iirfilter(corners, [freqmin / fe, freqmax / fe], - btype='bandstop', ftype='butter', output='ba') - if zerophase: - firstpass = lfilter(b, a, data) - return lfilter(b, a, firstpass[::-1])[::-1] - else: - return lfilter(b, a, data) - - -def bandstopZPHSH(data, freqmin, freqmax, df=200, corners=2): - """ - DEPRECATED. Use :func:`~obspy.signal.filter.bandstop` instead. - """ - return bandstop(data, freqmin, freqmax, df, corners, zerophase=True) - - -def lowpass(data, freq, df=200, corners=4, zerophase=False): - """ - Butterworth-Lowpass Filter. - - Filter data removing data over certain frequency freq using corners - corners. - - :param data: Data to filter, type numpy.ndarray. - :param freq: Filter corner frequency. - :param df: Sampling rate in Hz; Default 200. - :param corners: Filter corners. Note: This is twice the value of PITSA's - filter sections - :param zerophase: If True, apply filter once forwards and once backwards. - This results in twice the number of corners but zero phase shift in - the resulting filtered trace. - :return: Filtered data. - """ - fe = 0.5 * df - [b, a] = iirfilter(corners, freq / fe, btype='lowpass', ftype='butter', - output='ba') - if zerophase: - firstpass = lfilter(b, a, data) - return lfilter(b, a, firstpass[::-1])[::-1] - else: - return lfilter(b, a, data) - - -def lowpassZPHSH(data, freq, df=200, corners=2): - """ - DEPRECATED. Use :func:`~obspy.signal.filter.lowpass` instead. - """ - return lowpass(data, freq, df, corners, zerophase=True) - - -def highpass(data, freq, df=200, corners=4, zerophase=False): - """ - Butterworth-Highpass Filter. - - Filter data removing data below certain frequency freq using corners. - - :param data: Data to filter, type numpy.ndarray. - :param freq: Filter corner frequency. - :param df: Sampling rate in Hz; Default 200. - :param corners: Filter corners. Note: This is twice the value of PITSA's - filter sections - :param zerophase: If True, apply filter once forwards and once backwards. - This results in twice the number of corners but zero phase shift in - the resulting filtered trace. - :return: Filtered data. - """ - fe = 0.5 * df - [b, a] = iirfilter(corners, freq / fe, btype='highpass', ftype='butter', - output='ba') - if zerophase: - firstpass = lfilter(b, a, data) - return lfilter(b, a, firstpass[::-1])[::-1] - else: - return lfilter(b, a, data) - - -def highpassZPHSH(data, freq, df=200, corners=2): - """ - DEPRECATED. Use :func:`~obspy.signal.filter.highpass` instead. - """ - return highpass(data, freq, df, corners, zerophase=True) - - -def envelope(data): - """ - Envelope of a function. - - Computes the envelope of the given function. 
The envelope is determined by - adding the squared amplitudes of the function and it's Hilbert-Transform - and then taking the squareroot. - (See Kanasewich: Time Sequence Analysis in Geophysics) - The envelope at the start/end should not be taken too seriously. - - :param data: Data to make envelope of, type numpy.ndarray. - :return: Envelope of input data. - """ - hilb = hilbert(data) - data = pow(pow(data, 2) + pow(hilb, 2), 0.5) - return data - - -def remezFIR(data, freqmin, freqmax, samp_rate=200): - """ - The minimax optimal bandpass using Remez algorithm. Zerophase bandpass? - - Finite impulse response (FIR) filter whose transfer function minimizes - the maximum error between the desired gain and the realized gain in the - specified bands using the remez exchange algorithm - """ - # Remez filter description - # ======================== - # - # So, let's go over the inputs that you'll have to worry about. - # First is numtaps. This parameter will basically determine how good your - # filter is and how much processor power it takes up. If you go for some - # obscene number of taps (in the thousands) there's other things to worry - # about, but with sane numbers (probably below 30-50 in your case) that is - # pretty much what it affects (more taps is better, but more expensive - # processing wise). There are other ways to do filters as well - # which require less CPU power if you really need it, but I doubt that you - # will. Filtering signals basically breaks down to convolution, and apple - # has DSP libraries to do lightning fast convolution I'm sure, so don't - # worry about this too much. Numtaps is basically equivalent to the number - # of terms in the convolution, so a power of 2 is a good idea, 32 is - # probably fine. - # - # bands has literally your list of bands, so you'll break it up into your - # low band, your pass band, and your high band. Something like [0, 99, 100, - # 999, 1000, 22049] should work, if you want to pass frequencies between - # 100-999 Hz (assuming you are sampling at 44.1 kHz). - # - # desired will just be [0, 1, 0] as you want to drop the high and low - # bands, and keep the middle one without modifying the amplitude. - # - # Also, specify Hz = 44100 (or whatever). - # - # That should be all you need; run the function and it will spit out a list - # of coefficients [c0, ... c(N-1)] where N is your tap count. When you run - # this filter, your output signal y[t] will be computed from the input x[t] - # like this (t-N means N samples before the current one): - # - # y[t] = c0*x[t] + c1*x[t-1] + ... + c(N-1)*x[t-(N-1)] - # - # After playing around with remez for a bit, it looks like numtaps should - # be above 100 for a solid filter. See what works out for you. Eventually, - # take those coefficients and then move them over and do the convolution - # in C or whatever. Also, note the gaps between the bands in the call to - # remez. You have to leave some space for the transition in frequency - # response to occur, otherwise the call to remez will complain. 
- # - # SRC: # http://episteme.arstechnica.com/eve/forums/a/tpc/f/6330927813/m/175006289731 - # See also: - # http://aspn.activestate.com/ASPN/Mail/Message/scipy-dev/1592174 - # http://aspn.activestate.com/ASPN/Mail/Message/scipy-dev/1592172 - - #take 10% of freqmin and freqmax as """corners""" - flt = freqmin - 0.1 * freqmin - fut = freqmax + 0.1 * freqmax - #bandpass between freqmin and freqmax - filt = remez(50, array([0, flt, freqmin, freqmax, fut, samp_rate / 2 - 1]), - array([0, 1, 0]), Hz=samp_rate) - return convolve(filt, data) - - -def lowpassFIR(data, freq, samp_rate=200, winlen=2048): - """ - FIR-Lowpass Filter - - Filter data by passing data only below a certain frequency. - - :param data: Data to filter, type numpy.ndarray. - :param freq: Data below this frequency pass. - :param samprate: Sampling rate in Hz; Default 200. - :param winlen: Window length for filter in samples, must be power of 2; - Default 2048 - :return: Filtered data. - """ - # There is not currently an FIR-filter design program in SciPy. One - # should be constructed as it is not hard to implement (of course making - # it generic with all the options you might want would take some time). - # - # What kind of window are you currently using? - # - # For your purposes this is what I would do: - # SRC: Travis Oliphant - # http://aspn.activestate.com/ASPN/Mail/Message/scipy-user/2009409] - # - #winlen = 2**11 #2**10 = 1024; 2**11 = 2048; 2**12 = 4096 - #give frequency bins in Hz and sample spacing - w = fft.fftfreq(winlen, 1 / float(samp_rate)) - #cutoff is low-pass filter - myfilter = where((abs(w) < freq), 1., 0.) - #ideal filter - h = fft.ifft(myfilter) - beta = 11.7 - #beta implies Kaiser - myh = fft.fftshift(h) * get_window(beta, winlen) - return convolve(abs(myh), data)[winlen / 2:-winlen / 2] - -def lfilter_zi (b,a): - #compute the zi state from the filter parameters. see [Gust96]. - - #Based on: - # [Gust96] Fredrik Gustafsson, Determining the initial states in forward-backward - # filtering, IEEE Transactions on Signal Processing, pp. 
988--992, April 1996, - # Volume 44, Issue 4 - - n=max(len(a),len(b)) - - zin = ( eye(n-1) - hstack( (-a[1:n,newaxis], - vstack((eye(n-2), zeros(n-2)))))) - - zid= b[1:n] - a[1:n]*b[0] - - zi_matrix=linalg.inv(zin)*(matrix(zid).transpose()) - zi_return=[] - - #convert the result into a regular array (not a matrix) - for i in range(len(zi_matrix)): - zi_return.append(float(zi_matrix[i][0])) - - return array(zi_return) - -# http://www.scipy.org/Cookbook/FiltFilt -def filtfilt (b,a,x): - #For now only accepting 1d arrays - ntaps=max(len(a),len(b)) - edge=ntaps*3 - - if x.ndim != 1: - raise ValueError("Filiflit is only accepting 1 dimension arrays.") - - #x must be bigger than edge - if x.size < edge: - raise ValueError("Input vector needs to be bigger than 3 * max(len(a),len(b).") - - if len(a) < ntaps: - a=r_[a,zeros(len(b)-len(a))] - - if len(b) < ntaps: - b=r_[b,zeros(len(a)-len(b))] - - zi=lfilter_zi(b,a) - - #Grow the signal to have edges for stabilizing - #the filter with inverted replicas of the signal - s=r_[2*x[0]-x[edge:1:-1],x,2*x[-1]-x[-1:-edge:-1]] - #in the case of one go we only need one of the extrems - # both are needed for filtfilt - - (y,zf)=lfilter(b,a,s,-1,zi*s[0]) - - (y,zf)=lfilter(b,a,flipud(y),-1,zi*y[-1]) - - return flipud(y[edge-1:-edge+1]) - - -# bandfilt - bandpass filter signal x using filtfilt -# sampr = sampling rate in Hz -# lohz = low frequency cutoff -# hihz = high frequency cutoff -# bord = order for butterworth filter -def bandfilt (x,sampr,lohz,hihz,bord=2): - fr = [lohz/(sampr/2.0), hihz/(sampr/2.0)] - [b,a] = butter(bord,fr,btype='band') - y=filtfilt(b,a,x) - return y - -# bandfiltlist - apply bandpass filter on x using -# frequencies centered on lfreq +/- fwidths/2 -# return filtered signals as a list -def bandfiltlist (x,sampr,lfreq,fwidths): - lxf = zeros((len(lfreq),len(x))); i=0 - for f in lfreq: - lxf[i,:] = bandfilt(x,sampr,f - fwidths[i]/2.0,f + fwidths[i]/2.0) - i += 1 - return lxf - -#hilb - returns 2 element list with amplitude,phase numpy arrays -# from hilbert transform of x -# before running hilb, should bandpass filter signal with bandfilt -# typical sequence on signal x: -# xf = bandfilt(x,sampr,lohz,hihz) <- bandpass filter x into xf -# xfH = hilb(xf) -> xfH[0] = amplitude, xfH[1] = phase -def hilb (x): - xsz=x.size - lxsz=math.log(xsz)/math.log(2) - if (lxsz!=int(lxsz)): - x=x.copy() # to resize must make x a local copy of the externally referenced x - x.resize(2**(int(lxsz)+1)) - H = hilbert(x) - H.resize(xsz) # only needed if stretched - return [abs(H),angle(H)] - -# hilblist - apply hilbert transform on x using -# frequencies centered on lfreq +/- fwidths/2 -# return amplitude and phase for each frequency as 2 separate lists -# cuts specifies number of seconds to cut off start/end of filtered signal -# before applying the hilbert transform -def hilblist (x,sampr,lfreq,fwidths,cuts): - lhamp,lhang = [], []; i = 0; - cutsamps = int(cuts*sampr); # number of samples to cut off filtered signal - sz = len(x); # length of original signal - lhamp,lhang=zeros((len(lfreq), sz-2*cutsamps)), zeros((len(lfreq), sz-2*cutsamps)) - for f in lfreq: - xf = bandfilt(x,sampr,f - fwidths[i]/2.0, f + fwidths[i]/2.0) # bandpass filter - hamp,hang = hilb(xf) # hilbert transform - del xf # free some memory - lhamp[i,:] = hamp[cutsamps:sz-cutsamps] - lhang[i,:] = hang[cutsamps:sz-cutsamps] - del hamp,hang; - i += 1 - return lhamp,lhang - -#gethilbd - returns a dict with amplitude,phase,filtered signal -def gethilbd (x,sampr,lohz,hihz): - xf = 
bandfilt(x,sampr,lohz,hihz) # filter the signal - xfh = hilb(xf) # do the hilbert transform - d = {'amp':xfh[0], 'phase':xfh[1], 'filt':xf} - return d - - -if __name__=='__main__': - - from scipy import sin, arange, pi, randn - from pylab import plot, legend, show, hold - - t=arange(-1,1,.01) - x=sin(2*pi*t*.5+2) - #xn=x + sin(2*pi*t*10)*.1 - xn=x+randn(len(t))*0.05 - - [b,a]=butter(3,0.05) - - z=lfilter(b,a,xn) - y=filtfilt(b,a,xn) - - plot(x,'c') - hold(True) - plot(xn,'k') - plot(z,'r') - plot(y,'g') - - legend(('original','noisy signal','lfilter - butter 3 order','filtfilt - butter 3 order')) - hold(False) - show() diff --git a/gutils.py b/gutils.py deleted file mode 100644 index 4ed0ffd71..000000000 --- a/gutils.py +++ /dev/null @@ -1,59 +0,0 @@ -from PyQt5.QtCore import QCoreApplication -from PyQt5 import QtGui - -# some graphics utilities - -# use pyqt5 to get screen resolution -def getscreengeom (): - width,height = 2880, 1620 # default width,height - used for development - app = QCoreApplication.instance() # can only have 1 instance of qtapp; get that instance - app.setDesktopSettingsAware(True) - if len(app.screens()) > 0: - screen = app.screens()[0] - geom = screen.geometry() - return geom.width(), geom.height() - else: - return width, height - -# check if display has low resolution -def lowresdisplay (): - w, h = getscreengeom() - return w < 1400 or h < 700 - -# get DPI for use in matplotlib figures (part of simulation output canvas - in simdat.py) -def getmplDPI (): - if lowresdisplay(): return 60 - return 120 - -# get new window width, height scaled by current screen resolution relative to original development resolution -def scalegeom (width, height): - devwidth, devheight = 2880.0, 1620.0 # resolution used for development - used to scale window height/width - screenwidth, screenheight = getscreengeom() - widthnew = int((screenwidth / devwidth) * width) - heightnew = int((screenheight / devheight) * height) - if widthnew > 1000 or heightnew > 850: - widthnew = 1000 - heightnew = 850 - return widthnew, heightnew - -# set dialog's position (x,y) and rescale geometry based on original width and height and development resolution -def setscalegeom (dlg, x, y, origw, origh): - nw, nh = scalegeom(origw, origh) - # print('origw,origh:',origw, origh,'nw,nh:',nw, nh) - dlg.setGeometry(x, y, int(nw), int(nh)) - return int(nw), int(nh) - -# set dialog in center of screen width and rescale size based on original width and height and development resolution -def setscalegeomcenter (dlg, origw, origh): - nw, nh = scalegeom(origw, origh) - # print('origw,origh:',origw, origh,'nw,nh:',nw, nh) - sw, sh = getscreengeom() - x = (sw-nw)/2 - y = 0 - dlg.setGeometry(x, y, int(nw), int(nh)) - return int(nw), int(nh) - -# scale font size -def scalefont (fsize): - pass # devfont - diff --git a/hnn.cfg b/hnn.cfg deleted file mode 100644 index d0616bc6b..000000000 --- a/hnn.cfg +++ /dev/null @@ -1,61 +0,0 @@ -[run] -dorun = 1 -doquit = 1 -debug = 0 - -[paths] -paramindir = param - -[sim] -simf = run.py - -[opt] -decay_multiplier = 1.6 - -[draw] -drawindivdpl = 1 -drawavgdpl = 1 -drawindivrast = 1 - -[tips] -tstop = Simulation duration; Evoked response simulations typically take 170 ms while ongoing rhythms are run for longer. -dt = Simulation integration timeste; shorter timesteps mean more accuracy but longer runtimes. Default value: 0.025 ms. -N_trials = How many times a simulation is run using the specified parameters. Note that there is randomization across trials, e.g. 
for the input times, leading to differences in outputs. -NumCores = Specifies how many cores to use for running the simulation. Best to use default value, which is automatically determined to match your CPU capacity. -save_figs = Whether to save figures of model activity when the simulation is run; if set to 1, figures are saved in simulation output directory. -save_spec_data = Whether to save spectral simulation spectral data - time/frequency/power; if set to 1, saved to simulation output directory. -f_max_spec = Maximum frequency used in dipole spectral analysis. -dipole_scalefctr = Scaling used to match simulation dipole signal to data; implicitly estimates number of cells contributing to dipole signal. -dipole_smooth_win = Window size (ms) used for Hamming filtering of dipole signal (0 means no smoothing); for analysis of ongoing rhythms (alpha/beta/gamma), best to avoid smoothing, while for evoked responses, best to smooth with 15-30 ms window. -prng_seedcore_opt = Random number generator seed used for parameter estimation (optimization). -prng_seedcore_input_prox = Random number generator seed used for rhythmic proximal inputs. -prng_seedcore_input_dist = Random number generator seed used for rhythmic distal inputs. -prng_seedcore_extpois = Random number generator seed used for Poisson inputs. -prng_seedcore_extgauss = Random number generator seed used for Gaussian inputs. -prng_seedcore_evprox_1 = Random number generator seed used for evoked proximal input 1. -prng_seedcore_evdist_1 = Random number generator seed used for evoked distal input 1. -prng_seedcore_evprox_2 = Random number generator seed used for evoked proximal input 2. -prng_seedcore_evdist_2 = Random number generator seed used for evoked distal input 2. -# evoked inputs -t_evprox_1 = Average start time of first evoked proximal input. -sigma_t_evprox_1 = Standard deviation of start time of first evoked proximal input. -gbar_evprox_1_L2Pyr = Weight of first evoked proximal input to L2/3 Pyramidal cells. -gbar_evprox_1_L2Basket = Weight of first evoked proximal input to L2/3 Basket cells. -gbar_evprox_1_L5Pyr = Weight of first evoked proximal input to L5 Pyramidal cells. -gbar_evprox_1_L5Basket = Weight of first evoked proximal input to L5 Basket cells. -t_evprox_2 = Average start time of second evoked proximal input. -sigma_t_evprox_2 = Standard deviation of start time of second evoked proximal input. -gbar_evprox_2_L2Pyr = Weight of second evoked proximal input to L2/3 Pyramidal cells. -gbar_evprox_2_L2Basket = Weight of second evoked proximal input to L2/3 Basket cells. -gbar_evprox_2_L5Pyr = Weight of second evoked proximal input to L5 Pyramidal cells. -gbar_evprox_2_L5Basket = Weight of second evoked proximal input to L5 Basket cells. -t_evdist_1 = Average start time of evoked distal input. -sigma_t_evdist_1 = Standard deviation of start time of evoked distal input. -gbar_evdist_1_L2Pyr = Weight of evoked distal input to L2/3 Pyramidal cells. -gbar_evdist_1_L2Basket = Weight of evoked distal input to L2/3 Basket cells. -gbar_evdist_1_L5Pyr = Weight of evoked distal input to L5 Pyramidal cells. -sync_evinput = Whether to provide synchronous inputs to all cells receiving evoked inputs or to set input timing of each cell independently (the same distribution is used either way). -numspikes_evprox_1 = Number of synaptic input events. -numspikes_evdist_1 = Number of synaptic input events. 
- -[params] diff --git a/hnn.py b/hnn.py index 93df0360b..eda7c7bf0 100755 --- a/hnn.py +++ b/hnn.py @@ -1,40 +1,19 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- +"""Main file to launch HNN GUI""" + +# Authors: Sam Neymotin +# Blake Caldwell import sys -import os -import shlex -from subprocess import Popen, PIPE, call -from time import sleep +from PyQt5 import QtWidgets + +from hnn import HNNGUI -from hnn_qt5 import * -# -def runnrnui (): - pnrnui = Popen(shlex.split(os.path.join('NEURON-UI','NEURON-UI')),cwd=os.getcwd()) - sleep(5) # make sure NEURON-UI had a chance to start - pjup = Popen(shlex.split('jupyter run loadmodel_nrnui.py --existing'),cwd=os.getcwd()) - lproc = [pnrnui, pjup]; done = [False, False] - while pnrnui.poll() is None or pjup.poll() is None: sleep(1) - for p in lproc: - try: p.communicate() - except: pass - return lproc +def runqt5(): + app = QtWidgets.QApplication(sys.argv) + HNNGUI() + sys.exit(app.exec_()) -def runqt5 (): - app = QApplication(sys.argv) - ex = HNNGUI() - sys.exit(app.exec_()) - # app.exec_() - # print('\n'.join(repr(w) for w in app.allWidgets())) if __name__ == '__main__': - useqt5 = True - for s in sys.argv: - if s == 'nrnui': - useqt5 = False - if useqt5: runqt5() - else: - lproc = runnrnui() - diff --git a/hnn b/hnn.sh similarity index 100% rename from hnn rename to hnn.sh diff --git a/hnn/DataViewGUI.py b/hnn/DataViewGUI.py new file mode 100644 index 000000000..4a2f18dc2 --- /dev/null +++ b/hnn/DataViewGUI.py @@ -0,0 +1,172 @@ +""" GUI for viewing data from individual/all trials""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import os + +from PyQt5.QtWidgets import QMainWindow, QAction, QWidget, QComboBox +from PyQt5.QtWidgets import QGridLayout, QInputDialog +from PyQt5.QtGui import QIcon + +from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT +import matplotlib.pyplot as plt + +from .qt_lib import getmplDPI +from .paramrw import get_output_dir + +fontsize = plt.rcParams['font.size'] = 10 + + +class DataViewGUI(QMainWindow): + def __init__(self, CanvasType, params, sim_data, title): + super().__init__() + + global fontsize + + self.fontsize = fontsize + self.linewidth = plt.rcParams['lines.linewidth'] = 1 + self.markersize = plt.rcParams['lines.markersize'] = 5 + self.CanvasType = CanvasType + self.ntrial = params['N_trials'] + self.params = params + self.title = title + self.m = None + self.toolbar = None + self.sim_data = sim_data + self.initUI() + + def initMenu(self): + exitAction = QAction(QIcon.fromTheme('close'), 'Close', self) + exitAction.setShortcut('Ctrl+W') + exitAction.setStatusTip('Close ' + self.title + '.') + exitAction.triggered.connect(self.close) + + menubar = self.menuBar() + self.fileMenu = menubar.addMenu('&File') + menubar.setNativeMenuBar(False) + self.fileMenu.addAction(exitAction) + + self.viewMenu = menubar.addMenu('&View') + changeFontSizeAction = QAction('Change Font Size', self) + changeFontSizeAction.setStatusTip('Change Font Size.') + changeFontSizeAction.triggered.connect(self.changeFontSize) + self.viewMenu.addAction(changeFontSizeAction) + changeLineWidthAction = QAction('Change Line Width', self) + changeLineWidthAction.setStatusTip('Change Line Width.') + changeLineWidthAction.triggered.connect(self.changeLineWidth) + self.viewMenu.addAction(changeLineWidthAction) + changeMarkerSizeAction = QAction('Change Marker Size', self) + changeMarkerSizeAction.setStatusTip('Change Marker Size.') + changeMarkerSizeAction.triggered.connect(self.changeMarkerSize) + 
self.viewMenu.addAction(changeMarkerSizeAction) + + def changeFontSize(self): + global fontsize + + i, okPressed = QInputDialog.getInt(self, "Set Font Size", + "Font Size:", + plt.rcParams['font.size'], 1, 100, + 1) + if okPressed: + self.fontsize = plt.rcParams['font.size'] = fontsize = i + self.initCanvas() + self.m.plot() + + def changeLineWidth(self): + i, okPressed = QInputDialog.getInt(self, "Set Line Width", + "Line Width:", + plt.rcParams['lines.linewidth'], 1, + 20, 1) + if okPressed: + self.linewidth = plt.rcParams['lines.linewidth'] = i + self.initCanvas() + self.m.plot() + + def changeMarkerSize(self): + i, okPressed = QInputDialog.getInt(self, "Set Marker Size", + "Font Size:", self.markersize, 1, + 100, 1) + if okPressed: + self.markersize = plt.rcParams['lines.markersize'] = i + self.initCanvas() + self.m.plot() + + def printStat(self, s): + print(s) + self.statusBar().showMessage(s) + + def initCanvas(self, first_init=False): + """Initialize canvas + + Parameters + ---------- + first_init: bool | None + Whether canvas is being initialized for the first time + or if member objects have already been instantiated. If None, + then False is assumed. + """ + + if not first_init: + self.m.setParent(None) + self.toolbar.setParent(None) + self.grid.removeWidget(self.m) + self.grid.removeWidget(self.toolbar) + + self.grid.removeWidget(self.toolbar) + self.m = self.CanvasType(self.params, self.sim_data, self.index, + parent=self, width=12, height=10, + dpi=getmplDPI()) + # this is the Navigation widget + # it takes the Canvas widget and a parent + self.toolbar = NavigationToolbar2QT(self.m, self) + self.grid.addWidget(self.toolbar, 0, 0, 1, 4) + self.grid.addWidget(self.m, 1, 0, 1, 4) + + def updateCB(self): + self.cb.clear() + if self.ntrial > 1: + self.cb.addItem('Show All Trials') + for i in range(self.ntrial): + self.cb.addItem('Show Trial ' + str(i + 1)) + else: + self.cb.addItem('All Trials') + self.cb.activated[int].connect(self.onActivated) + + def initUI(self): + self.initMenu() + self.statusBar() + self.setGeometry(300, 300, 1300, 1100) + self.setWindowTitle(self.title + ' - ' + + os.path.join(get_output_dir(), 'data', + self.params['sim_prefix'] + + '.param')) + self.grid = grid = QGridLayout() + self.index = 0 + self.initCanvas(True) + self.cb = QComboBox(self) + self.grid.addWidget(self.cb, 2, 0, 1, 4) + + self.updateCB() + + # need a separate widget to put grid on + widget = QWidget(self) + widget.setLayout(grid) + self.setCentralWidget(widget) + + self.setWindowIcon(QIcon(os.path.join('res', 'icon.png'))) + + self.show() + + def onActivated(self, idx): + if idx != self.index: + self.index = idx + if self.index == 0: + self.statusBar().showMessage('Loading data from all trials.') + else: + self.statusBar().showMessage('Loading data from trial ' + + str(self.index) + '.') + self.m.index = self.index + self.initCanvas() + self.m.plot() + self.statusBar().showMessage('') diff --git a/hnn/__init__.py b/hnn/__init__.py new file mode 100644 index 000000000..f368b4e7d --- /dev/null +++ b/hnn/__init__.py @@ -0,0 +1,3 @@ +__version__ = "1.4.0" + +from .qt_main import HNNGUI diff --git a/hnn/paramrw.py b/hnn/paramrw.py new file mode 100644 index 000000000..6323f426d --- /dev/null +++ b/hnn/paramrw.py @@ -0,0 +1,323 @@ +# paramrw.py - routines for reading the param files +# +# v 1.10.0-py35 +# rev 2016-05-01 (SL: removed dependence on cartesian, updated for python3) +# last major: (SL: cleanup of self.p_all) + +import os +import numpy as np + + +def get_output_dir(): + """Return 
the base directory for storing output files""" + + try: + base_dir = os.environ["SYSTEM_USER_DIR"] + except KeyError: + base_dir = os.path.expanduser('~') + + return os.path.join(base_dir, 'hnn_out') + + +def get_fname(sim_dir, key, trial=None): + """Build the file names using the old HNN scheme + + Parameters + ---------- + sim_dir : str + The base data directory where simulation result files are stored + key : str + A string describing the type of file (HNN specific) + trial : int | None + Trial number for which to generate files (separate files per trial). + If None is given, then will use filename with trial suffix + + Returns + ---------- + fname : str + A string with the correct filename + """ + + datatypes = {'rawspk': ('spk', '.txt'), + 'rawdpl': ('rawdpl', '.txt'), + 'normdpl': ('dpl', '.txt'), + 'rawcurrent': ('i', '.txt'), + 'rawspec': ('rawspec', '.npz'), + 'rawspeccurrent': ('speci', '.npz'), + 'avgdpl': ('dplavg', '.txt'), + 'avgspec': ('specavg', '.npz'), + 'figavgdpl': ('dplavg', '.png'), + 'figavgspec': ('specavg', '.png'), + 'figdpl': ('dpl', '.png'), + 'figspec': ('rawspec', '.png'), + 'figspk': ('spk', '.png'), + 'param': ('param', '.txt'), + 'vsoma': ('vsoma', '.pkl')} + + if trial is None or key == 'param': + # param file currently identical for all trials + fname = os.path.join(sim_dir, datatypes[key][0] + datatypes[key][1]) + else: + fname = os.path.join(sim_dir, datatypes[key][0] + '_' + str(trial) + + datatypes[key][1]) + + return fname + + +def get_inputs(params): + """ get a dictionary of input types used in simulation + with distal/proximal specificity for evoked,ongoing inputs + """ + + dinty = {'evoked': usingEvokedInputs(params), + 'ongoing': usingOngoingInputs(params), + 'tonic': usingTonicInputs(params), + 'pois': usingPoissonInputs(params), + 'evdist': usingEvokedInputs(params, lsuffty=['_evdist_']), + 'evprox': usingEvokedInputs(params, lsuffty=['_evprox_']), + 'dist': usingOngoingInputs(params, lty=['_dist']), + 'prox': usingOngoingInputs(params, lty=['_prox'])} + + return dinty + +# Cleans input files + + +def clean_lines(file): + with open(file) as f_in: + lines = (line.rstrip() for line in f_in) + lines = [line for line in lines if line] + return lines + +# check if using ongoing inputs + + +def usingOngoingInputs(params, lty=['_prox', '_dist']): + if params is None: + return False + + try: + tstop = float(params['tstop']) + except KeyError: + return False + + dpref = {'_prox': 'input_prox_A_', '_dist': 'input_dist_A_'} + for postfix in lty: + if float(params['t0_input'+postfix]) <= tstop and \ + float(params['tstop_input'+postfix]) >= float(params['t0_input' + postfix]) and float(params['f_input'+postfix]) > 0.: # noqa: E501 + for k in ['weight_L2Pyr_ampa', 'weight_L2Pyr_nmda', + 'weight_L5Pyr_ampa', 'weight_L5Pyr_nmda', + 'weight_inh_ampa', 'weight_inh_nmda']: + if float(params[dpref[postfix]+k]) > 0.: + # print('usingOngoingInputs:',params[dpref[postfix]+k]) + return True + + return False + +# return number of evoked inputs (proximal, distal) +# using dictionary d (or if d is a string, first load the dictionary from +# filename d) + + +def countEvokedInputs(params): + nprox = ndist = 0 + if params is not None: + for k, v in params.items(): + if k.startswith('t_'): + if k.count('evprox') > 0: + nprox += 1 + elif k.count('evdist') > 0: + ndist += 1 + return nprox, ndist + +# check if using any evoked inputs + + +def usingEvokedInputs(params, lsuffty=['_evprox_', '_evdist_']): + nprox, ndist = countEvokedInputs(params) + if nprox == 0 and ndist == 0: 
+ return False + + try: + tstop = float(params['tstop']) + except KeyError: + return False + + lsuff = [] + if '_evprox_' in lsuffty: + for i in range(1, nprox+1, 1): + lsuff.append('_evprox_'+str(i)) + if '_evdist_' in lsuffty: + for i in range(1, ndist+1, 1): + lsuff.append('_evdist_'+str(i)) + for suff in lsuff: + k = 't' + suff + if k not in params: + continue + if float(params[k]) > tstop: + continue + k = 'gbar' + suff + for k1 in params.keys(): + if k1.startswith(k): + if float(params[k1]) > 0.0: + return True + return False + +# check if using any poisson inputs + + +def usingPoissonInputs(params): + if params is None: + return False + + try: + tstop = float(params['tstop']) + + if 't0_pois' in params and 'T_pois' in params: + t0_pois = float(params['t0_pois']) + if t0_pois > tstop: + return False + T_pois = float(params['T_pois']) + if t0_pois > T_pois and T_pois != -1.0: + return False + except KeyError: + return False + + for cty in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: + for sy in ['ampa', 'nmda']: + k = cty+'_Pois_A_weight_'+sy + if k in params: + if float(params[k]) != 0.0: + return True + + return False + +# check if using any tonic (IClamp) inputs + + +def usingTonicInputs(d): + if d is None: + return False + + tstop = float(d['tstop']) + for cty in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: + k = 'Itonic_A_' + cty + '_soma' + if k in d: + amp = float(d[k]) + if amp != 0.0: + print(k, 'amp != 0.0', amp) + k = 'Itonic_t0_' + cty + t0, t1 = 0.0, -1.0 + if k in d: + t0 = float(d[k]) + k = 'Itonic_T_' + cty + if k in d: + t1 = float(d[k]) + if t0 > tstop: + continue + # print('t0:',t0,'t1:',t1) + if t0 < t1 or t1 == -1.0: + return True + return False + + +def read_gids_param(fparam): + lines = clean_lines(fparam) + gid_dict = {} + for line in lines: + if line.startswith('#'): + continue + keystring, val = line.split(": ") + key = keystring.strip() + if val[0] == '[': + val_range = val[1:-1].split(', ') + if len(val_range) == 2: + ind_start = int(val_range[0]) + ind_end = int(val_range[1]) + 1 + gid_dict[key] = np.arange(ind_start, ind_end) + else: + gid_dict[key] = np.array([]) + + return gid_dict + + +def legacy_param_str_to_dict(param_str): + boolean_params = ['sync_evinput', 'record_vsoma', 'save_spec_data', + 'save_figs'] + string_params = ['sim_prefix', 'spec_cmap', 'distribution_prox', + 'distribution_dist'] + param_dict = {} + for line in param_str.splitlines(): + keystring, val = line.split(': ') + key = keystring.strip() + if key == 'expmt_groups': + continue + elif key in string_params: + param_dict[key] = val + elif key.startswith('N_') or key.startswith('numspikes_') or \ + key.startswith('events_per_cycle_') or \ + key.startswith('repeats_') or \ + key.startswith('prng_seedcore_'): + param_dict[key] = int(val) + elif key in boolean_params: + param_dict[key] = int(val) + else: + param_dict[key] = float(val) + + return param_dict + + +# write the params to a filename +def write_legacy_paramf(fparam, p): + """ now sorting + """ + + p_keys = [key for key, val in p.items()] + p_sorted = [(key, p[key]) for key in p_keys] + with open(fparam, 'w') as f: + pstring = '%26s: ' + # do the params in p_sorted + for param in p_sorted: + key, val = param + f.write(pstring % key) + if key.startswith('N_'): + f.write('%i\n' % val) + else: + f.write(str(val)+'\n') + + +def write_gids_param(fparam, gid_list): + with open(fparam, 'w') as f: + pstring = '%26s: ' + # write the gid info + for key in gid_list.keys(): + f.write(pstring % key) + if len(gid_list[key]): + f.write('[%4i, 
%4i] ' % (gid_list[key][0], gid_list[key][-1])) + else: + f.write('[]') + f.write('\n') + +def hnn_core_compat_params(params): + boolean_params = ['sync_evinput', 'record_vsoma', 'save_spec_data', + 'save_figs'] + string_params = ['sim_prefix', 'spec_cmap', 'distribution_prox', + 'distribution_dist'] + + param_dict = {} + for key, val in params.items(): + if key == 'expmt_groups': + continue + elif key in string_params: + param_dict[key] = val + elif key.startswith('N_') or key.startswith('numspikes_') or \ + key.startswith('events_per_cycle_') or \ + key.startswith('repeats_') or \ + key.startswith('prng_seedcore_'): + param_dict[key] = int(val) + elif key in boolean_params: + param_dict[key] = bool(val) + else: + param_dict[key] = float(val) + + return param_dict \ No newline at end of file diff --git a/hnn/qt_dialog.py b/hnn/qt_dialog.py new file mode 100644 index 000000000..98d280678 --- /dev/null +++ b/hnn/qt_dialog.py @@ -0,0 +1,1197 @@ +"""Classes for creating the dialog boxes""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import os +from collections import OrderedDict + +from PyQt5.QtWidgets import (QDialog, QToolTip, QTabWidget, QWidget, + QPushButton, QMessageBox, QComboBox, QLabel, + QLineEdit, QTextEdit, QFormLayout, + QVBoxLayout, QHBoxLayout, QGridLayout) +from PyQt5.QtGui import QFont, QPixmap, QIcon + +from hnn_core import read_params, Params + +from .paramrw import (usingOngoingInputs, usingEvokedInputs, get_output_dir, + legacy_param_str_to_dict) +from .qt_lib import (setscalegeom, setscalegeomcenter, lookupresource, + ClickLabel) +from .qt_evoked import EvokedInputParamDialog, OptEvokedInputParamDialog + + +def bringwintotop(win): + # bring a pyqt5 window to the top (parents still stay behind children) + # from https://www.programcreek.com/python/example/ + # 101663/PyQt5.QtCore.Qt.WindowActive + # win.show() + # win.setWindowState(win.windowState() & ~Qt.WindowMinimized | + # Qt.WindowActive) + # win.raise_() + win.showNormal() + win.activateWindow() + # win.setWindowState((win.windowState() & ~Qt.WindowMinimized) | + # Qt.WindowActive) + # win.activateWindow() + # win.raise_() + # win.show() + + +class DictDialog(QDialog): + """dictionary-based dialog with tabs + + should make all dialogs specifiable via cfg file format - + then can customize gui without changing py code + and can reduce code explosion / overlap between dialogs + """ + + def __init__(self, parent, din): + super(DictDialog, self).__init__(parent) + self.ldict = [] # subclasses should override + self.ltitle = [] + # for translating model variable name to more human-readable form + self.dtransvar = {} + self.stitle = '' + self.initd() + self.initUI() + self.initExtra() + self.setfromdin(din) # set values from input dictionary + # self.addtips() + + # TODO: add back tooltips + # def addtips (self): + # for ktip in dconf.keys(): + # if ktip in self.dqline: + # self.dqline[ktip].setToolTip(dconf[ktip]) + # elif ktip in self.dqextra: + # self.dqextra[ktip].setToolTip(dconf[ktip]) + + def __str__(self): + s = '' + for k, v in self.dqline.items(): + s += k + ': ' + v.text().strip() + os.linesep + return s + + def saveparams(self): + self.hide() + + def initd(self): + pass # implemented in subclass + + def getval(self, k): + if k in self.dqline.keys(): + return self.dqline[k].text().strip() + + def lines2val(self, ksearch, val): + for k in self.dqline.keys(): + if k.count(ksearch) > 0: + self.dqline[k].setText(str(val)) + + def setfromdin(self, din): + if not din: + return + for k, v in din.items(): + if k 
in self.dqline: + self.dqline[k].setText(str(v).strip()) + + def transvar(self, k): + if k in self.dtransvar: + return self.dtransvar[k] + return k + + def addtransvar(self, k, strans): + self.dtransvar[k] = strans + self.dtransvar[strans] = k + + def initExtra(self): + # extra items not written to param file + self.dqextra = {} + + def initUI(self): + self.layout = QVBoxLayout(self) + + # Add stretch to separate the form layout from the button + self.layout.addStretch(1) + + # Initialize tab screen + self.ltabs = [] + self.tabs = QTabWidget() + self.layout.addWidget(self.tabs) + + for _ in range(len(self.ldict)): + self.ltabs.append(QWidget()) + + self.tabs.resize(575, 200) + + # create tabs and their layouts + for tab, s in zip(self.ltabs, self.ltitle): + self.tabs.addTab(tab, s) + tab.layout = QFormLayout() + tab.setLayout(tab.layout) + + self.dqline = {} # QLineEdits dict; key is model variable + for d, tab in zip(self.ldict, self.ltabs): + for k, v in d.items(): + self.dqline[k] = QLineEdit(self) + self.dqline[k].setText(str(v)) + # add label,QLineEdit to the tab + tab.layout.addRow(self.transvar(k), self.dqline[k]) + + # Add tabs to widget + self.layout.addWidget(self.tabs) + self.setLayout(self.layout) + self.setWindowTitle(self.stitle) + + def TurnOff(self): + pass + + def addOffButton(self): + """Create a horizontal box layout to hold the button""" + self.button_box = QHBoxLayout() + self.btnoff = QPushButton('Turn Off Inputs', self) + self.btnoff.resize(self.btnoff.sizeHint()) + self.btnoff.clicked.connect(self.TurnOff) + self.btnoff.setToolTip('Turn Off Inputs') + self.button_box.addWidget(self.btnoff) + self.layout.addLayout(self.button_box) + + def addHideButton(self): + self.bbhidebox = QHBoxLayout() + self.btnhide = QPushButton('Hide Window', self) + self.btnhide.resize(self.btnhide.sizeHint()) + self.btnhide.clicked.connect(self.hide) + self.btnhide.setToolTip('Hide Window') + self.bbhidebox.addWidget(self.btnhide) + self.layout.addLayout(self.bbhidebox) + + +class OngoingInputParamDialog (DictDialog): + """widget to specify ongoing input params (proximal, distal)""" + + def __init__(self, parent, inty, din=None): + self.inty = inty + if self.inty.startswith('Proximal'): + self.prefix = 'input_prox_A_' + self.postfix = '_prox' + self.isprox = True + else: + self.prefix = 'input_dist_A_' + self.postfix = '_dist' + self.isprox = False + super(OngoingInputParamDialog, self).__init__(parent, din) + self.addOffButton() + self.addImages() + self.addHideButton() + + def addImages(self): + """add png cartoons to tabs""" + if self.isprox: + self.pix = QPixmap(lookupresource('proxfig')) + else: + self.pix = QPixmap(lookupresource('distfig')) + for tab in self.ltabs: + pixlbl = ClickLabel(self) + pixlbl.setPixmap(self.pix) + tab.layout.addRow(pixlbl) + + def TurnOff(self): + """ turn off by setting all weights to 0.0""" + self.lines2val('weight', 0.0) + + def initd(self): + self.dtiming = OrderedDict([('t0_input' + self.postfix, 1000.), + ('t0_input_stdev' + self.postfix, 0.), + ('tstop_input' + self.postfix, 250.), + ('f_input' + self.postfix, 10.), + ('f_stdev' + self.postfix, 20.), + ('events_per_cycle' + self.postfix, 2), + ('repeats' + self.postfix, 10)]) + + self.dL2 = OrderedDict([(self.prefix + 'weight_L2Pyr_ampa', 0.), + (self.prefix + 'weight_L2Pyr_nmda', 0.), + (self.prefix + 'weight_L2Basket_ampa', 0.), + (self.prefix + 'weight_L2Basket_nmda', 0.), + (self.prefix + 'delay_L2', 0.1)]) + + self.dL5 = OrderedDict([(self.prefix + 'weight_L5Pyr_ampa', 0.), + (self.prefix + 
'weight_L5Pyr_nmda', 0.)]) + + if self.isprox: + self.dL5[self.prefix + 'weight_L5Basket_ampa'] = 0.0 + self.dL5[self.prefix + 'weight_L5Basket_nmda'] = 0.0 + self.dL5[self.prefix + 'delay_L5'] = 0.1 + + self.ldict = [self.dtiming, self.dL2, self.dL5] + self.ltitle = ['Timing', 'Layer 2/3', 'Layer 5'] + self.stitle = 'Set Rhythmic ' + self.inty + ' Inputs' + + dtmp = {'L2': 'L2/3 ', 'L5': 'L5 '} + for d in [self.dL2, self.dL5]: + for k in d.keys(): + lk = k.split('_') + if k.count('weight') > 0: + self.addtransvar(k, dtmp[lk[-2][0:2]] + lk[-2][2:] + ' ' + + lk[-1].upper() + u' weight (µS)') + else: + self.addtransvar(k, 'Delay (ms)') + + self.addtransvar('t0_input' + self.postfix, 'Start time mean (ms)') + self.addtransvar('t0_input_stdev' + self.postfix, + 'Start time stdev (ms)') + self.addtransvar('tstop_input' + self.postfix, 'Stop time (ms)') + self.addtransvar('f_input' + self.postfix, 'Burst frequency (Hz)') + self.addtransvar('f_stdev' + self.postfix, 'Burst stdev (ms)') + self.addtransvar('events_per_cycle' + self.postfix, 'Spikes/burst') + self.addtransvar('repeats' + self.postfix, 'Number bursts') + + +class EvokedOrRhythmicDialog (QDialog): + def __init__(self, parent, distal, evwin, rhythwin): + super(EvokedOrRhythmicDialog, self).__init__(parent) + if distal: + self.prefix = 'Distal' + else: + self.prefix = 'Proximal' + self.evwin = evwin + self.rhythwin = rhythwin + self.initUI() + # TODO: add back tooltips + # self.addtips() + + def initUI(self): + self.layout = QVBoxLayout(self) + # Add stretch to separate the form layout from the button + self.layout.addStretch(1) + + self.btnrhythmic = QPushButton('Rhythmic ' + self.prefix + ' Inputs', + self) + self.btnrhythmic.resize(self.btnrhythmic.sizeHint()) + self.btnrhythmic.clicked.connect(self.showrhythmicwin) + self.layout.addWidget(self.btnrhythmic) + + self.btnevoked = QPushButton('Evoked Inputs', self) + self.btnevoked.resize(self.btnevoked.sizeHint()) + self.btnevoked.clicked.connect(self.showevokedwin) + self.layout.addWidget(self.btnevoked) + + self.addHideButton() + + setscalegeom(self, 150, 150, 270, 120) + self.setWindowTitle("Pick Input Type") + + def showevokedwin(self): + bringwintotop(self.evwin) + self.hide() + + def showrhythmicwin(self): + bringwintotop(self.rhythwin) + self.hide() + + def addHideButton(self): + self.bbhidebox = QHBoxLayout() + self.btnhide = QPushButton('Hide Window', self) + self.btnhide.resize(self.btnhide.sizeHint()) + self.btnhide.clicked.connect(self.hide) + self.btnhide.setToolTip('Hide Window') + self.bbhidebox.addWidget(self.btnhide) + self.layout.addLayout(self.bbhidebox) + + +class SynGainParamDialog(QDialog): + def __init__(self, parent, netparamwin): + super(SynGainParamDialog, self).__init__(parent) + self.netparamwin = netparamwin + self.initUI() + + def scalegain(self, k, fctr): + oldval = float(self.netparamwin.dqline[k].text().strip()) + newval = oldval * fctr + self.netparamwin.dqline[k].setText(str(newval)) + return newval + + def isE(self, ty): + return ty.count('Pyr') > 0 + + def isI(self, ty): + return ty.count('Basket') > 0 + + def tounity(self): + for k in self.dqle.keys(): + self.dqle[k].setText('1.0') + + def scalegains(self): + for _, k in enumerate(self.dqle.keys()): + fctr = float(self.dqle[k].text().strip()) + if fctr < 0.: + fctr = 0. 
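                # note: negative gain factors are clamped to zero and written
                # back to the dialog field below; a factor of exactly 1.0 is
                # skipped since it would leave the weights unchanged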
+ self.dqle[k].setText(str(fctr)) + elif fctr == 1.0: + continue + for k2 in self.netparamwin.dqline.keys(): + types = k2.split('_') + ty1, ty2 = types[1], types[2] + if self.isE(ty1) and self.isE(ty2) and k == 'E -> E': + self.scalegain(k2, fctr) + elif self.isE(ty1) and self.isI(ty2) and k == 'E -> I': + self.scalegain(k2, fctr) + elif self.isI(ty1) and self.isE(ty2) and k == 'I -> E': + self.scalegain(k2, fctr) + elif self.isI(ty1) and self.isI(ty2) and k == 'I -> I': + self.scalegain(k2, fctr) + + # go back to unity since pressed OK - next call to this dialog will + # reset new values + self.tounity() + self.hide() + + def initUI(self): + grid = QGridLayout() + grid.setSpacing(10) + + self.dqle = {} + for row, k in enumerate(['E -> E', 'E -> I', 'I -> E', 'I -> I']): + lbl = QLabel(self) + lbl.setText(k) + lbl.adjustSize() + grid.addWidget(lbl, row, 0) + qle = QLineEdit(self) + qle.setText('1.0') + grid.addWidget(qle, row, 1) + self.dqle[k] = qle + + row += 1 + self.btnok = QPushButton('OK', self) + self.btnok.resize(self.btnok.sizeHint()) + self.btnok.clicked.connect(self.scalegains) + grid.addWidget(self.btnok, row, 0, 1, 1) + self.btncancel = QPushButton('Cancel', self) + self.btncancel.resize(self.btncancel.sizeHint()) + self.btncancel.clicked.connect(self.hide) + grid.addWidget(self.btncancel, row, 1, 1, 1) + + self.setLayout(grid) + setscalegeom(self, 150, 150, 270, 180) + self.setWindowTitle("Synaptic Gains") + + +# widget to specify tonic inputs +class TonicInputParamDialog(DictDialog): + def __init__(self, parent, din): + super(TonicInputParamDialog, self).__init__(parent, din) + self.addOffButton() + self.addHideButton() + + # turn off by setting all weights to 0.0 + def TurnOff(self): + self.lines2val('A', 0.0) + + def initd(self): + self.dL2 = OrderedDict([ + # IClamp params for L2Pyr + ('Itonic_A_L2Pyr_soma', 0.), + ('Itonic_t0_L2Pyr_soma', 0.), + ('Itonic_T_L2Pyr_soma', -1.), + # IClamp param for L2Basket + ('Itonic_A_L2Basket', 0.), + ('Itonic_t0_L2Basket', 0.), + ('Itonic_T_L2Basket', -1.)]) + + self.dL5 = OrderedDict([ + # IClamp params for L5Pyr + ('Itonic_A_L5Pyr_soma', 0.), + ('Itonic_t0_L5Pyr_soma', 0.), + ('Itonic_T_L5Pyr_soma', -1.), + # IClamp param for L5Basket + ('Itonic_A_L5Basket', 0.), + ('Itonic_t0_L5Basket', 0.), + ('Itonic_T_L5Basket', -1.)]) + + # temporary dictionary for string translation + dtmp = {'L2': 'L2/3 ', 'L5': 'L5 '} + for d in [self.dL2, self.dL5]: + for k in d.keys(): + cty = k.split('_')[2] # cell type + tcty = dtmp[cty[0:2]] + cty[2:] # translated cell type + if k.count('A') > 0: + self.addtransvar(k, tcty + ' amplitude (nA)') + elif k.count('t0') > 0: + self.addtransvar(k, tcty + ' start time (ms)') + elif k.count('T') > 0: + self.addtransvar(k, tcty + ' stop time (ms)') + + self.ldict = [self.dL2, self.dL5] + self.ltitle = ['Layer 2/3', 'Layer 5'] + self.stitle = 'Set Tonic Inputs' + + +# widget to specify ongoing poisson inputs +class PoissonInputParamDialog(DictDialog): + def __init__(self, parent, din): + super(PoissonInputParamDialog, self).__init__(parent, din) + self.addOffButton() + self.addHideButton() + + def TurnOff(self): + """turn off by setting all weights to 0.0""" + self.lines2val('weight', 0.0) + + def initd(self): + self.dL2, self.dL5 = {}, {} + ld = [self.dL2, self.dL5] + + for i, lyr in enumerate(['L2', 'L5']): + d = ld[i] + for ty in ['Pyr', 'Basket']: + for sy in ['ampa', 'nmda']: + d[lyr + ty + '_Pois_A_weight' + '_' + sy] = 0. + d[lyr + ty + '_Pois_lamtha'] = 0. 
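        # the loops above generate parameter names of the form
        # '<layer><cell>_Pois_A_weight_<ampa|nmda>' and
        # '<layer><cell>_Pois_lamtha' (e.g. 'L2Pyr_Pois_A_weight_ampa');
        # the weight keys match those checked by usingPoissonInputs() in
        # paramrw when deciding whether Poisson inputs are active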
+ + self.dtiming = OrderedDict([('t0_pois', 0.), + ('T_pois', -1)]) + + self.addtransvar('t0_pois', 'Start time (ms)') + self.addtransvar('T_pois', 'Stop time (ms)') + + # temporary dictionary for string translation + dtmp = {'L2': 'L2/3 ', 'L5': 'L5 '} + for d in [self.dL2, self.dL5]: + for k in d.keys(): + ks = k.split('_') + cty = ks[0] # cell type + tcty = dtmp[cty[0:2]] + cty[2:] # translated cell type + if k.count('weight'): + self.addtransvar(k, tcty + ' ' + ks[-1].upper() + + u' weight (µS)') + elif k.endswith('lamtha'): + self.addtransvar(k, tcty + ' freq (Hz)') + + self.ldict = [self.dL2, self.dL5, self.dtiming] + self.ltitle = ['Layer 2/3', 'Layer 5', 'Timing'] + self.stitle = 'Set Poisson Inputs' + + +# widget to specify run params (tstop, dt, etc.) -- not many params here +class RunParamDialog(DictDialog): + def __init__(self, parent, mainwin, din=None): + self.mainwin = mainwin + super(RunParamDialog, self).__init__(parent, din) + self.addHideButton() + self.parent = parent + + def initd(self): + + self.drun = OrderedDict([('tstop', 250.), # simulation end time (ms) + ('dt', 0.025), # timestep + ('celsius', 37.0), # temperature + ('N_trials', 1), # number of trials + ('threshold', 0.0)]) # firing threshold + # cvode - not currently used by simulation + + # analysis + self.danalysis = OrderedDict([('save_figs', 0), + ('save_spec_data', 0), + ('f_max_spec', 40), + ('dipole_scalefctr', 30e3), + ('dipole_smooth_win', 15.0), + ('record_vsoma', 0)]) + + self.drand = OrderedDict([('prng_seedcore_opt', + self.mainwin.prng_seedcore_opt), + ('prng_seedcore_input_prox', 0), + ('prng_seedcore_input_dist', 0), + ('prng_seedcore_extpois', 0), + ('prng_seedcore_extgauss', 0), + ('prng_seedcore_evprox_1', 0), + ('prng_seedcore_evdist_1', 0), + ('prng_seedcore_evprox_2', 0), + ('prng_seedcore_evdist_2', 0)]) + + self.ldict = [self.drun, self.danalysis, self.drand] + self.ltitle = ['Run', 'Analysis', 'Randomization Seeds'] + self.stitle = 'Run Parameters' + + self.addtransvar('tstop', 'Duration (ms)') + self.addtransvar('dt', 'Integration Timestep (ms)') + self.addtransvar('celsius', 'Temperature (C)') + self.addtransvar('threshold', 'Firing Threshold (mV)') + self.addtransvar('N_trials', 'Trials') + self.addtransvar('save_spec_data', 'Save Spectral Data') + self.addtransvar('save_figs', 'Save Figures') + self.addtransvar('f_max_spec', 'Max Spectral Frequency (Hz)') + self.addtransvar('spec_cmap', 'Spectrogram Colormap') + self.addtransvar('dipole_scalefctr', 'Dipole Scaling') + self.addtransvar('dipole_smooth_win', 'Dipole Smooth Window (ms)') + self.addtransvar('record_vsoma', 'Record Somatic Voltages') + self.addtransvar('prng_seedcore_opt', 'Parameter Optimization') + self.addtransvar('prng_seedcore_input_prox', 'Ongoing Proximal Input') + self.addtransvar('prng_seedcore_input_dist', 'Ongoing Distal Input') + self.addtransvar('prng_seedcore_extpois', 'External Poisson') + self.addtransvar('prng_seedcore_extgauss', 'External Gaussian') + self.addtransvar('prng_seedcore_evprox_1', 'Evoked Proximal 1') + self.addtransvar('prng_seedcore_evdist_1', 'Evoked Distal 1 ') + self.addtransvar('prng_seedcore_evprox_2', 'Evoked Proximal 2') + self.addtransvar('prng_seedcore_evdist_2', 'Evoked Distal 2') + + def selectionchange(self, i): + self.spec_cmap = self.cmaps[i] + self.parent.update_gui_params({}) + + def initExtra(self): + DictDialog.initExtra(self) + self.dqextra['NumCores'] = QLineEdit(self) + self.dqextra['NumCores'].setText(str(self.mainwin.defncore)) + self.addtransvar('NumCores', 'Number 
Cores') + self.ltabs[0].layout.addRow('NumCores', self.dqextra['NumCores']) + + self.spec_cmap_cb = None + + self.cmaps = ['jet', + 'viridis', + 'plasma', + 'inferno', + 'magma', + 'cividis'] + + self.spec_cmap_cb = QComboBox() + for cmap in self.cmaps: + self.spec_cmap_cb.addItem(cmap) + self.spec_cmap_cb.currentIndexChanged.connect(self.selectionchange) + self.ltabs[1].layout.addRow( + self.transvar('spec_cmap'), self.spec_cmap_cb) + + def getntrial(self): + ntrial = int(self.dqline['N_trials'].text().strip()) + if ntrial < 1: + self.dqline['N_trials'].setText(str(1)) + ntrial = 1 + return ntrial + + def getncore(self): + ncore = int(self.dqextra['NumCores'].text().strip()) + if ncore < 1: + self.dqline['NumCores'].setText(str(1)) + ncore = 1 + + # update value in HNNGUI for persistence + self.mainwin.defncore = ncore + + return ncore + + def get_prng_seedcore_opt(self): + prng_seedcore_opt = self.dqline['prng_seedcore_opt'].text().strip() + + # update value in HNNGUI for persistence + self.mainwin.prng_seedcore_opt = int(prng_seedcore_opt) + + return int(self.mainwin.prng_seedcore_opt) + + def setfromdin(self, din): + if not din: + return + + # number of cores may have changed if the configured number failed + self.dqextra['NumCores'].setText(str(self.mainwin.defncore)) + + # update ordered dict of QLineEdit objects with new parameters + for k, v in din.items(): + if k in self.dqline: + self.dqline[k].setText(str(v).strip()) + elif k == 'spec_cmap': + self.spec_cmap = v + + # for spec_cmap we want the user to be able to change + # (e.g. 'viridis'), but the default is 'jet' to be consistent with + # prior publications on HNN + if 'spec_cmap' not in din: + self.spec_cmap = 'jet' + + # update the spec_cmap dropdown menu + self.spec_cmap_cb.setCurrentIndex(self.cmaps.index(self.spec_cmap)) + + def __str__(self): + s = '' + for k, v in self.dqline.items(): + s += k + ': ' + v.text().strip() + os.linesep + s += 'spec_cmap: ' + self.spec_cmap + os.linesep + return s + +# widget to specify (pyramidal) cell parameters (geometry, synapses, +# biophysics) + + +class CellParamDialog (DictDialog): + def __init__(self, parent=None, din=None): + super(CellParamDialog, self).__init__(parent, din) + self.addHideButton() + + def initd(self): + + self.dL2PyrGeom = OrderedDict([('L2Pyr_soma_L', 22.1), # Soma + ('L2Pyr_soma_diam', 23.4), + ('L2Pyr_soma_cm', 0.6195), + ('L2Pyr_soma_Ra', 200.), + # Dendrites + ('L2Pyr_dend_cm', 0.6195), + ('L2Pyr_dend_Ra', 200.), + ('L2Pyr_apicaltrunk_L', 59.5), + ('L2Pyr_apicaltrunk_diam', 4.25), + ('L2Pyr_apical1_L', 306.), + ('L2Pyr_apical1_diam', 4.08), + ('L2Pyr_apicaltuft_L', 238.), + ('L2Pyr_apicaltuft_diam', 3.4), + ('L2Pyr_apicaloblique_L', 340.), + ('L2Pyr_apicaloblique_diam', 3.91), + ('L2Pyr_basal1_L', 85.), + ('L2Pyr_basal1_diam', 4.25), + ('L2Pyr_basal2_L', 255.), + ('L2Pyr_basal2_diam', 2.72), + ('L2Pyr_basal3_L', 255.), + ('L2Pyr_basal3_diam', 2.72)]) + + self.dL2PyrSyn = OrderedDict([('L2Pyr_ampa_e', 0.), # Synapses + ('L2Pyr_ampa_tau1', 0.5), + ('L2Pyr_ampa_tau2', 5.), + ('L2Pyr_nmda_e', 0.), + ('L2Pyr_nmda_tau1', 1.), + ('L2Pyr_nmda_tau2', 20.), + ('L2Pyr_gabaa_e', -80.), + ('L2Pyr_gabaa_tau1', 0.5), + ('L2Pyr_gabaa_tau2', 5.), + ('L2Pyr_gabab_e', -80.), + ('L2Pyr_gabab_tau1', 1.), + ('L2Pyr_gabab_tau2', 20.)]) + + self.dL2PyrBiophys = OrderedDict([ # Biophysics soma + ('L2Pyr_soma_gkbar_hh2', 0.01), + ('L2Pyr_soma_gnabar_hh2', 0.18), + ('L2Pyr_soma_el_hh2', -65.), + ('L2Pyr_soma_gl_hh2', 4.26e-5), + ('L2Pyr_soma_gbar_km', 250.), + # Biophysics dends + 
('L2Pyr_dend_gkbar_hh2', 0.01), + ('L2Pyr_dend_gnabar_hh2', 0.15), + ('L2Pyr_dend_el_hh2', -65.), + ('L2Pyr_dend_gl_hh2', 4.26e-5), + ('L2Pyr_dend_gbar_km', 250.)]) + + self.dL5PyrGeom = OrderedDict([('L5Pyr_soma_L', 39.), # Soma + ('L5Pyr_soma_diam', 28.9), + ('L5Pyr_soma_cm', 0.85), + ('L5Pyr_soma_Ra', 200.), + # Dendrites + ('L5Pyr_dend_cm', 0.85), + ('L5Pyr_dend_Ra', 200.), + ('L5Pyr_apicaltrunk_L', 102.), + ('L5Pyr_apicaltrunk_diam', 10.2), + ('L5Pyr_apical1_L', 680.), + ('L5Pyr_apical1_diam', 7.48), + ('L5Pyr_apical2_L', 680.), + ('L5Pyr_apical2_diam', 4.93), + ('L5Pyr_apicaltuft_L', 425.), + ('L5Pyr_apicaltuft_diam', 3.4), + ('L5Pyr_apicaloblique_L', 255.), + ('L5Pyr_apicaloblique_diam', 5.1), + ('L5Pyr_basal1_L', 85.), + ('L5Pyr_basal1_diam', 6.8), + ('L5Pyr_basal2_L', 255.), + ('L5Pyr_basal2_diam', 8.5), + ('L5Pyr_basal3_L', 255.), + ('L5Pyr_basal3_diam', 8.5)]) + + self.dL5PyrSyn = OrderedDict([('L5Pyr_ampa_e', 0.), # Synapses + ('L5Pyr_ampa_tau1', 0.5), + ('L5Pyr_ampa_tau2', 5.), + ('L5Pyr_nmda_e', 0.), + ('L5Pyr_nmda_tau1', 1.), + ('L5Pyr_nmda_tau2', 20.), + ('L5Pyr_gabaa_e', -80.), + ('L5Pyr_gabaa_tau1', 0.5), + ('L5Pyr_gabaa_tau2', 5.), + ('L5Pyr_gabab_e', -80.), + ('L5Pyr_gabab_tau1', 1.), + ('L5Pyr_gabab_tau2', 20.)]) + + self.dL5PyrBiophys = OrderedDict([ # Biophysics soma + ('L5Pyr_soma_gkbar_hh2', 0.01), + ('L5Pyr_soma_gnabar_hh2', 0.16), + ('L5Pyr_soma_el_hh2', -65.), + ('L5Pyr_soma_gl_hh2', 4.26e-5), + ('L5Pyr_soma_gbar_ca', 60.), + ('L5Pyr_soma_taur_cad', 20.), + ('L5Pyr_soma_gbar_kca', 2e-4), + ('L5Pyr_soma_gbar_km', 200.), + ('L5Pyr_soma_gbar_cat', 2e-4), + ('L5Pyr_soma_gbar_ar', 1e-6), + # Biophysics dends + ('L5Pyr_dend_gkbar_hh2', 0.01), + ('L5Pyr_dend_gnabar_hh2', 0.14), + ('L5Pyr_dend_el_hh2', -71.), + ('L5Pyr_dend_gl_hh2', 4.26e-5), + ('L5Pyr_dend_gbar_ca', 60.), + ('L5Pyr_dend_taur_cad', 20.), + ('L5Pyr_dend_gbar_kca', 2e-4), + ('L5Pyr_dend_gbar_km', 200.), + ('L5Pyr_dend_gbar_cat', 2e-4), + ('L5Pyr_dend_gbar_ar', 1e-6)]) + + dtrans = {'gkbar': 'Kv', 'gnabar': 'Na', 'km': 'Km', 'gl': 'leak', + 'ca': 'Ca', 'kca': 'KCa', 'cat': 'CaT', 'ar': 'HCN', + 'cad': 'Ca decay time', 'dend': 'Dendrite', 'soma': 'Soma', + 'apicaltrunk': 'Apical Dendrite Trunk', + 'apical1': 'Apical Dendrite 1', + 'apical2': 'Apical Dendrite 2', + 'apical3': 'Apical Dendrite 3', + 'apicaltuft': 'Apical Dendrite Tuft', + 'apicaloblique': 'Oblique Apical Dendrite', + 'basal1': 'Basal Dendrite 1', 'basal2': 'Basal Dendrite 2', + 'basal3': 'Basal Dendrite 3'} + + for d in [self.dL2PyrGeom, self.dL5PyrGeom]: + for k in d.keys(): + lk = k.split('_') + if lk[-1] == 'L': + self.addtransvar( + k, dtrans[lk[1]] + ' ' + r'length (micron)') + elif lk[-1] == 'diam': + self.addtransvar( + k, dtrans[lk[1]] + ' ' + r'diameter (micron)') + elif lk[-1] == 'cm': + self.addtransvar( + k, dtrans[lk[1]] + ' ' + r'capacitive density (F/cm2)') + elif lk[-1] == 'Ra': + self.addtransvar( + k, dtrans[lk[1]] + ' ' + r'resistivity (ohm-cm)') + + for d in [self.dL2PyrSyn, self.dL5PyrSyn]: + for k in d.keys(): + lk = k.split('_') + if k.endswith('e'): + self.addtransvar(k, lk[1].upper() + ' ' + ' reversal (mV)') + elif k.endswith('tau1'): + self.addtransvar(k, lk[1].upper() + + ' ' + ' rise time (ms)') + elif k.endswith('tau2'): + self.addtransvar(k, lk[1].upper() + + ' ' + ' decay time (ms)') + + for d in [self.dL2PyrBiophys, self.dL5PyrBiophys]: + for k in d.keys(): + lk = k.split('_') + if lk[2].count('g') > 0: + if lk[3] == 'km' or lk[3] == 'ca' or lk[3] == 'kca' \ + or lk[3] == 'cat' or lk[3] == 'ar': + nv = 
dtrans[lk[1]] + ' ' + \ + dtrans[lk[3]] + ' ' + ' channel density ' + else: + nv = dtrans[lk[1]] + ' ' + \ + dtrans[lk[2]] + ' ' + ' channel density ' + if lk[3] == 'hh2' or lk[3] == 'cat' or lk[3] == 'ar': + nv += '(S/cm2)' + else: + nv += '(pS/micron2)' + elif lk[2].count('el') > 0: + nv = dtrans[lk[1]] + ' leak reversal (mV)' + elif lk[2].count('taur') > 0: + nv = dtrans[lk[1]] + ' ' + dtrans[lk[3]] + ' (ms)' + self.addtransvar(k, nv) + + self.ldict = [self.dL2PyrGeom, self.dL2PyrSyn, self.dL2PyrBiophys, + self.dL5PyrGeom, self.dL5PyrSyn, self.dL5PyrBiophys] + self.ltitle = ['L2/3 Pyr Geometry', 'L2/3 Pyr Synapses', + 'L2/3 Pyr Biophysics', 'L5 Pyr Geometry', + 'L5 Pyr Synapses', 'L5 Pyr Biophysics'] + self.stitle = 'Cell Parameters' + + +# widget to specify network parameters (number cells, weights, etc.) +class NetworkParamDialog (DictDialog): + def __init__(self, parent=None, din=None): + super(NetworkParamDialog, self).__init__(parent, din) + self.addHideButton() + + def initd(self): + # number of cells + self.dcells = OrderedDict([('N_pyr_x', 10), + ('N_pyr_y', 10)]) + + # max conductances TO L2Pyr + self.dL2Pyr = OrderedDict([('gbar_L2Pyr_L2Pyr_ampa', 0.), + ('gbar_L2Pyr_L2Pyr_nmda', 0.), + ('gbar_L2Basket_L2Pyr_gabaa', 0.), + ('gbar_L2Basket_L2Pyr_gabab', 0.)]) + + # max conductances TO L2Baskets + self.dL2Bas = OrderedDict([('gbar_L2Pyr_L2Basket', 0.), + ('gbar_L2Basket_L2Basket', 0.)]) + + # max conductances TO L5Pyr + self.dL5Pyr = OrderedDict([('gbar_L2Pyr_L5Pyr', 0.), + ('gbar_L2Basket_L5Pyr', 0.), + ('gbar_L5Pyr_L5Pyr_ampa', 0.), + ('gbar_L5Pyr_L5Pyr_nmda', 0.), + ('gbar_L5Basket_L5Pyr_gabaa', 0.), + ('gbar_L5Basket_L5Pyr_gabab', 0.)]) + + # max conductances TO L5Baskets + self.dL5Bas = OrderedDict([('gbar_L2Pyr_L5Basket', 0.), + ('gbar_L5Pyr_L5Basket', 0.), + ('gbar_L5Basket_L5Basket', 0.)]) + + self.ldict = [self.dcells, self.dL2Pyr, + self.dL5Pyr, self.dL2Bas, self.dL5Bas] + self.ltitle = ['Cells', 'Layer 2/3 Pyr', + 'Layer 5 Pyr', 'Layer 2/3 Bas', 'Layer 5 Bas'] + self.stitle = 'Local Network Parameters' + + self.addtransvar('N_pyr_x', 'Num Pyr Cells (X direction)') + self.addtransvar('N_pyr_y', 'Num Pyr Cells (Y direction)') + + dtmp = {'L2': 'L2/3 ', 'L5': 'L5 '} + + for d in [self.dL2Pyr, self.dL5Pyr, self.dL2Bas, self.dL5Bas]: + for k in d.keys(): + lk = k.split('_') + sty1 = dtmp[lk[1][0:2]] + lk[1][2:] + sty2 = dtmp[lk[2][0:2]] + lk[2][2:] + if len(lk) == 3: + self.addtransvar(k, sty1 + ' -> ' + sty2 + u' weight (µS)') + else: + self.addtransvar(k, sty1 + ' -> ' + sty2 + ' ' + + lk[3].upper() + u' weight (µS)') + + +class HelpDialog (QDialog): + def __init__(self, parent): + super(HelpDialog, self).__init__(parent) + self.initUI() + + def initUI(self): + self.layout = QVBoxLayout(self) + # Add stretch to separate the form layout from the button + self.layout.addStretch(1) + + setscalegeom(self, 100, 100, 300, 100) + self.setWindowTitle('Help') + + +class SchematicDialog (QDialog): + # class for holding model schematics (and parameter shortcuts) + def __init__(self, parent): + super(SchematicDialog, self).__init__(parent) + self.initUI() + + def initUI(self): + + self.setWindowTitle('Model Schematics') + QToolTip.setFont(QFont('SansSerif', 10)) + + self.grid = grid = QGridLayout() + grid.setSpacing(10) + + gRow = 0 + + self.locbtn = QPushButton( + 'Local Network' + os.linesep + 'Connections', self) + self.locbtn.setIcon(QIcon(lookupresource('connfig'))) + self.locbtn.clicked.connect(self.parent().shownetparamwin) + self.grid.addWidget(self.locbtn, gRow, 0, 1, 1) + 
+ self.proxbtn = QPushButton( + 'Proximal Drive' + os.linesep + 'Thalamus', self) + self.proxbtn.setIcon(QIcon(lookupresource('proxfig'))) + self.proxbtn.clicked.connect(self.parent().showproxparamwin) + self.grid.addWidget(self.proxbtn, gRow, 1, 1, 1) + + self.distbtn = QPushButton( + 'Distal Drive NonLemniscal' + os.linesep + + 'Thal./Cortical Feedback', self) + self.distbtn.setIcon(QIcon(lookupresource('distfig'))) + self.distbtn.clicked.connect(self.parent().showdistparamwin) + self.grid.addWidget(self.distbtn, gRow, 2, 1, 1) + + gRow = 1 + + # for schematic dialog box + self.pixConn = QPixmap(lookupresource('connfig')) + self.pixConnlbl = ClickLabel(self) + self.pixConnlbl.setScaledContents(True) + # self.pixConnlbl.resize(self.pixConnlbl.size()) + self.pixConnlbl.setPixmap(self.pixConn) + # self.pixConnlbl.clicked.connect(self.shownetparamwin) + self.grid.addWidget(self.pixConnlbl, gRow, 0, 1, 1) + + self.pixProx = QPixmap(lookupresource('proxfig')) + self.pixProxlbl = ClickLabel(self) + self.pixProxlbl.setScaledContents(True) + self.pixProxlbl.setPixmap(self.pixProx) + # self.pixProxlbl.clicked.connect(self.showproxparamwin) + self.grid.addWidget(self.pixProxlbl, gRow, 1, 1, 1) + + self.pixDist = QPixmap(lookupresource('distfig')) + self.pixDistlbl = ClickLabel(self) + self.pixDistlbl.setScaledContents(True) + self.pixDistlbl.setPixmap(self.pixDist) + # self.pixDistlbl.clicked.connect(self.showdistparamwin) + self.grid.addWidget(self.pixDistlbl, gRow, 2, 1, 1) + + self.setLayout(grid) + + +class BaseParamDialog (QDialog): + """Base widget for specifying params + + The params dictionary is stored within this class. Other Dialogs access it + here. + """ + def __init__(self, parent, paramfn): + super(BaseParamDialog, self).__init__(parent) + self.proxparamwin = None + self.distparamwin = None + self.netparamwin = None + self.syngainparamwin = None + self.runparamwin = RunParamDialog(self, parent) + self.cellparamwin = CellParamDialog(self) + self.netparamwin = NetworkParamDialog(self) + self.syngainparamwin = SynGainParamDialog(self, self.netparamwin) + self.proxparamwin = OngoingInputParamDialog(self, 'Proximal') + self.distparamwin = OngoingInputParamDialog(self, 'Distal') + self.evparamwin = EvokedInputParamDialog(self, None) + self.optparamwin = OptEvokedInputParamDialog(self, parent) + self.poisparamwin = PoissonInputParamDialog(self, None) + self.tonicparamwin = TonicInputParamDialog(self, None) + self.lsubwin = [self.runparamwin, self.cellparamwin, self.netparamwin, + self.proxparamwin, self.distparamwin, self.evparamwin, + self.poisparamwin, self.tonicparamwin, + self.optparamwin] + self.paramfn = paramfn + self.parent = parent + + self.params = read_params(self.paramfn) + self.initUI() # requires self.params + self.updateDispParam(self.params) + + def updateDispParam(self, params=None): + global drawavgdpl + + if params is None: + try: + params = read_params(self.paramfn) + except ValueError: + QMessageBox.information(self, "HNN", "WARNING: could not" + "retrieve parameters from %s" % + self.paramfn) + return + + self.params = params + + if usingEvokedInputs(self.params): + # default for evoked is to show average dipole + drawavgdpl = True + elif usingOngoingInputs(self.params): + # default for ongoing is NOT to show average dipole + drawavgdpl = False + + for dlg in self.lsubwin: + dlg.setfromdin(self.params) # update to values from file + self.qle.setText(self.params['sim_prefix']) # update simulation name + + def setrunparam(self): + bringwintotop(self.runparamwin) + + def 
setcellparam(self): + bringwintotop(self.cellparamwin) + + def setnetparam(self): + bringwintotop(self.netparamwin) + + def setsyngainparam(self): + bringwintotop(self.syngainparamwin) + + def setproxparam(self): + bringwintotop(self.proxparamwin) + + def setdistparam(self): + bringwintotop(self.distparamwin) + + def setevparam(self): + bringwintotop(self.evparamwin) + + def setpoisparam(self): + bringwintotop(self.poisparamwin) + + def settonicparam(self): + bringwintotop(self.tonicparamwin) + + def initUI(self): + grid = QGridLayout() + grid.setSpacing(10) + + row = 1 + + self.lbl = QLabel(self) + self.lbl.setText('Simulation Name:') + self.lbl.adjustSize() + self.lbl.setToolTip( + 'Simulation Name used to save parameter file and simulation data') + grid.addWidget(self.lbl, row, 0) + self.qle = QLineEdit(self) + self.qle.setText(self.params['sim_prefix']) + grid.addWidget(self.qle, row, 1) + row += 1 + + self.btnrun = QPushButton('Run', self) + self.btnrun.resize(self.btnrun.sizeHint()) + self.btnrun.setToolTip('Set Run Parameters') + self.btnrun.clicked.connect(self.setrunparam) + grid.addWidget(self.btnrun, row, 0, 1, 1) + + self.btncell = QPushButton('Cell', self) + self.btncell.resize(self.btncell.sizeHint()) + self.btncell.setToolTip( + 'Set Cell (Geometry, Synapses, Biophysics) Parameters') + self.btncell.clicked.connect(self.setcellparam) + grid.addWidget(self.btncell, row, 1, 1, 1) + row += 1 + + self.btnnet = QPushButton('Local Network', self) + self.btnnet.resize(self.btnnet.sizeHint()) + self.btnnet.setToolTip('Set Local Network Parameters') + self.btnnet.clicked.connect(self.setnetparam) + grid.addWidget(self.btnnet, row, 0, 1, 1) + + self.btnsyngain = QPushButton('Synaptic Gains', self) + self.btnsyngain.resize(self.btnsyngain.sizeHint()) + self.btnsyngain.setToolTip('Set Local Network Synaptic Gains') + self.btnsyngain.clicked.connect(self.setsyngainparam) + grid.addWidget(self.btnsyngain, row, 1, 1, 1) + + row += 1 + + self.btnprox = QPushButton('Rhythmic Proximal Inputs', self) + self.btnprox.resize(self.btnprox.sizeHint()) + self.btnprox.setToolTip('Set Rhythmic Proximal Inputs') + self.btnprox.clicked.connect(self.setproxparam) + grid.addWidget(self.btnprox, row, 0, 1, 2) + row += 1 + + self.btndist = QPushButton('Rhythmic Distal Inputs', self) + self.btndist.resize(self.btndist.sizeHint()) + self.btndist.setToolTip('Set Rhythmic Distal Inputs') + self.btndist.clicked.connect(self.setdistparam) + grid.addWidget(self.btndist, row, 0, 1, 2) + row += 1 + + self.btnev = QPushButton('Evoked Inputs', self) + self.btnev.resize(self.btnev.sizeHint()) + self.btnev.setToolTip('Set Evoked Inputs') + self.btnev.clicked.connect(self.setevparam) + grid.addWidget(self.btnev, row, 0, 1, 2) + row += 1 + + self.btnpois = QPushButton('Poisson Inputs', self) + self.btnpois.resize(self.btnpois.sizeHint()) + self.btnpois.setToolTip('Set Poisson Inputs') + self.btnpois.clicked.connect(self.setpoisparam) + grid.addWidget(self.btnpois, row, 0, 1, 2) + row += 1 + + self.btntonic = QPushButton('Tonic Inputs', self) + self.btntonic.resize(self.btntonic.sizeHint()) + self.btntonic.setToolTip('Set Tonic (Current Clamp) Inputs') + self.btntonic.clicked.connect(self.settonicparam) + grid.addWidget(self.btntonic, row, 0, 1, 2) + row += 1 + + self.btnsave = QPushButton('Save Parameters To File', self) + self.btnsave.resize(self.btnsave.sizeHint()) + self.btnsave.setToolTip( + 'Save All Parameters to File (Specified by Simulation Name)') + self.btnsave.clicked.connect(self.saveparams) + 
grid.addWidget(self.btnsave, row, 0, 1, 2) + row += 1 + + self.btnhide = QPushButton('Hide Window', self) + self.btnhide.resize(self.btnhide.sizeHint()) + self.btnhide.clicked.connect(self.hide) + self.btnhide.setToolTip('Hide Window') + grid.addWidget(self.btnhide, row, 0, 1, 2) + + self.setLayout(grid) + + self.setWindowTitle('Set Parameters') + + def saveparams(self, checkok=True): + param_dir = os.path.join(get_output_dir(), 'param') + tmpf = os.path.join(param_dir, self.qle.text() + '.param') + + oktosave = True + if os.path.isfile(tmpf) and checkok: + self.show() + msg = QMessageBox() + ret = msg.warning(self, 'Over-write file(s)?', + tmpf + ' already exists. Over-write?', + QMessageBox.Ok | QMessageBox.Cancel, + QMessageBox.Ok) + if ret == QMessageBox.Cancel: + oktosave = False + + if oktosave: + # update params dict with values from GUI + self.params = Params(legacy_param_str_to_dict(str(self))) + + os.makedirs(param_dir, exist_ok=True) + with open(tmpf, 'w') as fp: + fp.write(str(self)) + + self.paramfn = tmpf + data_dir = os.path.join(get_output_dir(), 'data') + sim_dir = os.path.join(data_dir, self.qle.text()) + os.makedirs(sim_dir, exist_ok=True) + + return oktosave + + def update_gui_params(self, dtest): + """ Update parameter values in GUI + + So user can see and so GUI will save these param values + """ + for win in self.lsubwin: + win.setfromdin(dtest) + + def __str__(self): + s = 'sim_prefix: ' + self.qle.text() + os.linesep + s += 'expmt_groups: {' + self.qle.text() + '}' + os.linesep + for win in self.lsubwin: + s += str(win) + return s + + +class WaitSimDialog (QDialog): + def __init__(self, parent): + super(WaitSimDialog, self).__init__(parent) + self.initUI() + self.txt = '' # text for display + + def updatetxt(self, txt): + self.qtxt.append(txt) + + def initUI(self): + self.layout = QVBoxLayout(self) + self.layout.addStretch(1) + + self.qtxt = QTextEdit(self) + self.layout.addWidget(self.qtxt) + + self.stopbtn = stopbtn = QPushButton('Stop All Simulations', self) + stopbtn.setToolTip('Stop All Simulations') + stopbtn.resize(stopbtn.sizeHint()) + stopbtn.clicked.connect(self.stopsim) + self.layout.addWidget(stopbtn) + + setscalegeomcenter(self, 500, 250) + self.setWindowTitle("Simulation Log") + + def stopsim(self): + self.parent().stopsim() + self.hide() diff --git a/hnn/qt_dipole.py b/hnn/qt_dipole.py new file mode 100644 index 000000000..c1d6907c1 --- /dev/null +++ b/hnn/qt_dipole.py @@ -0,0 +1,115 @@ +"""Class for Dipole viewing window""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import numpy as np + +from PyQt5.QtWidgets import QSizePolicy + +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure + +fontsize = plt.rcParams['font.size'] = 10 +random_label = np.random.rand(100) + + +class DipoleCanvas(FigureCanvasQTAgg): + """Class for displaying Dipole Viewer + + Required parameters in params dict: N_trials, tstop, dipole_scalefctr + """ + def __init__(self, params, sim_data, index, parent=None, width=12, + height=10, dpi=120, title='Dipole Viewer'): + FigureCanvasQTAgg.__init__(self, Figure(figsize=(width, height), + dpi=dpi)) + self.title = title + self.setParent(parent) + self.gui = parent + self.index = index + self.sim_data = sim_data + FigureCanvasQTAgg.setSizePolicy(self, QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + + self.params = params + self.scalefctr = self.params['dipole_scalefctr'] 
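        # the check below falls back to the default dipole scaling factor
        # (30e3) when 'dipole_scalefctr' is missing or not numeric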
+ if type(self.scalefctr) != float and type(self.scalefctr) != int: + self.scalefctr = 30e3 + self.tstop = self.params['tstop'] + self.ntrial = self.params['N_trials'] + + self.plot() + + def plot(self): + global random_label + + gdx = 311 + + ltitle = ['Layer 2/3', 'Layer 5', 'Aggregate'] + dipole_keys = ['L2', 'L5', 'agg'] + + white_patch = mpatches.Patch(color='white', label='Average') + gray_patch = mpatches.Patch(color='gray', label='Individual') + lpatch = [] + + if len(self.sim_data['dpls']) > 0: + lpatch = [white_patch, gray_patch] + + yl = [1e9, -1e9] + for key in dipole_keys: + yl[0] = min(yl[0], np.amin(self.sim_data['avg_dpl'].data[key])) + yl[1] = max(yl[1], np.amax(self.sim_data['avg_dpl'].data[key])) + + # plot dipoles from individual trials + if len(self.sim_data['dpls']) > 0: + for dpltrial in self.sim_data['dpls']: + yl[0] = min(yl[0], np.amin(dpltrial.data[key])) + yl[1] = max(yl[1], np.amax(dpltrial.data[key])) + yl = tuple(yl) + + for key, title in zip(dipole_keys, ltitle): + ax = self.figure.add_subplot(gdx, label=random_label) + random_label += 1 + + if key == 'agg': + ax.set_xlabel('Time (ms)') + + lw = self.gui.linewidth + if self.index != 0: + lw = self.gui.linewidth + 2 + + # plot dipoles from individual trials + if len(self.sim_data['dpls']) > 0: + for ddx, dpltrial in enumerate(self.sim_data['dpls']): + if self.index == 0 or (self.index > 0 and + ddx == (self.index - 1)): + ax.plot(dpltrial.times, dpltrial.data[key], + color='gray', linewidth=lw) + + # average dipole (across trials) + if self.index == 0: + ax.plot(self.sim_data['avg_dpl'].times, + self.sim_data['avg_dpl'].data[key], 'w', + linewidth=self.gui.linewidth + 2) + + ax.set_ylabel(r'(nAm $\times$ ' + str(self.scalefctr) + ')') + if self.tstop != -1: + ax.set_xlim((0, self.tstop)) + ax.set_ylim(yl) + + if key == 'L2' and len(self.sim_data['dpls']) > 0: + ax.legend(handles=lpatch) + + ax.set_facecolor('k') + ax.grid(True) + ax.set_title(title) + + gdx += 1 + + self.figure.subplots_adjust(bottom=0.06, left=0.06, right=1.0, + top=0.97, wspace=0.1, hspace=0.09) + + self.draw() diff --git a/hnn/qt_evoked.py b/hnn/qt_evoked.py new file mode 100644 index 000000000..7f9684285 --- /dev/null +++ b/hnn/qt_evoked.py @@ -0,0 +1,1520 @@ +"""Class for creating the optimization configuration window""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import os +import numpy as np +from math import isclose +from copy import deepcopy + +from PyQt5.QtWidgets import QPushButton, QTabWidget, QWidget, QDialog +from PyQt5.QtWidgets import QGridLayout, QLabel, QFrame, QSpacerItem +from PyQt5.QtWidgets import QCheckBox, QSizePolicy, QLineEdit +from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QFormLayout +from PyQt5.QtGui import QPixmap +from PyQt5.QtCore import Qt + +from .qt_lib import QRangeSlider, MyLineEdit, ClickLabel, setscalegeom +from .qt_lib import lookupresource +from .paramrw import countEvokedInputs + +decay_multiplier = 1.6 + + +def _consolidate_chunks(input_dict): + # MOVE to hnn-core + # get a list of sorted chunks + sorted_inputs = sorted( + input_dict.items(), key=lambda x: x[1]['user_start']) + + consolidated_chunks = [] + for one_input in sorted_inputs: + if 'opt_start' not in one_input[1]: + continue + + # extract info from sorted list + input_dict = {'inputs': [one_input[0]], + 'chunk_start': one_input[1]['user_start'], + 'chunk_end': one_input[1]['user_end'], + 'opt_start': one_input[1]['opt_start'], + 'opt_end': one_input[1]['opt_end'], + 'weights': one_input[1]['weights'], + } + + if 
(len(consolidated_chunks) > 0) and \ + (input_dict['chunk_start'] <= consolidated_chunks[-1]['chunk_end']): + # update previous chunk + consolidated_chunks[-1]['inputs'].extend(input_dict['inputs']) + consolidated_chunks[-1]['chunk_end'] = input_dict['chunk_end'] + consolidated_chunks[-1]['opt_end'] = max( + consolidated_chunks[-1]['opt_end'], input_dict['opt_end']) + # average the weights + consolidated_chunks[-1]['weights'] = ( + consolidated_chunks[-1]['weights'] + one_input[1]['weights'])/2 + else: + # new chunk + consolidated_chunks.append(input_dict) + + return consolidated_chunks + + +def _combine_chunks(input_chunks): + # MOVE to hnn-core + # Used for creating the opt params of the last step with all inputs + + final_chunk = {'inputs': [], + 'opt_start': 0.0, + 'opt_end': 0.0, + 'chunk_start': 0.0, + 'chunk_end': 0.0} + + for evinput in input_chunks: + final_chunk['inputs'].extend(evinput['inputs']) + if evinput['opt_end'] > final_chunk['opt_end']: + final_chunk['opt_end'] = evinput['opt_end'] + if evinput['chunk_end'] > final_chunk['chunk_end']: + final_chunk['chunk_end'] = evinput['chunk_end'] + + # wRMSE with weights of 1's is the same as regular RMSE. + final_chunk['weights'] = np.ones(len(input_chunks[-1]['weights'])) + return final_chunk + + +def _chunk_evinputs(opt_params, sim_tstop, sim_dt): + # MOVE to hnn-core + """ + Take dictionary (opt_params) sorted by input and + return a sorted list of dictionaries describing + chunks with inputs consolidated as determined the + range between 'user_start' and 'user_end'. + + The keys of the chunks in chunk_list dictionary + returned are: + 'weights' + 'chunk_start' + 'chunk_end' + 'opt_start' + 'opt_end' + """ + + import scipy.stats as stats + from math import ceil, floor + + num_step = ceil(sim_tstop / sim_dt) + 1 + times = np.linspace(0, sim_tstop, num_step) + + # input_dict will be passed to consolidate_chunks, so it has + # keys 'user_start' and 'user_end' instead of chunk_start and + # 'chunk_start' that will be returned in the dicts returned + # in chunk_list + input_dict = {} + cdfs = {} + + for input_name in opt_params.keys(): + if opt_params[input_name]['user_start'] > sim_tstop or \ + opt_params[input_name]['user_end'] < 0: + # can't optimize over this input + continue + + # calculate cdf using start time (minival of optimization range) + cdf = stats.norm.cdf(times, opt_params[input_name]['user_start'], + opt_params[input_name]['sigma']) + cdfs[input_name] = cdf.copy() + + for input_name in opt_params.keys(): + if opt_params[input_name]['user_start'] > sim_tstop or \ + opt_params[input_name]['user_end'] < 0: + # can't optimize over this input + continue + input_dict[input_name] = {'weights': cdfs[input_name].copy(), + 'user_start': opt_params[input_name]['user_start'], + 'user_end': opt_params[input_name]['user_end']} + + for other_input in opt_params: + if opt_params[other_input]['user_start'] > sim_tstop or \ + opt_params[other_input]['user_end'] < 0: + # not optimizing over that input + continue + if input_name == other_input: + # don't subtract our own cdf(s) + continue + if opt_params[other_input]['mean'] < \ + opt_params[input_name]['mean']: + # check ordering to only use inputs after us + continue + else: + decay_factor = opt_params[input_name]['decay_multiplier']*(opt_params[other_input]['mean'] - + opt_params[input_name]['mean']) / \ + sim_tstop + input_dict[input_name]['weights'] -= cdfs[other_input] * \ + decay_factor + + # weights should not drop below 0 + input_dict[input_name]['weights'] = np.clip( + 
input_dict[input_name]['weights'], a_min=0, a_max=None) + + # start and stop optimization where the weights are insignificant + good_indices = np.where(input_dict[input_name]['weights'] > 0.01) + if len(good_indices[0]) > 0: + input_dict[input_name]['opt_start'] = min( + opt_params[input_name]['user_start'], times[good_indices][0]) + input_dict[input_name]['opt_end'] = max( + opt_params[input_name]['user_end'], times[good_indices][-1]) + else: + input_dict[input_name]['opt_start'] = opt_params[other_input]['user_start'] + input_dict[input_name]['opt_end'] = opt_params[other_input]['user_end'] + + # convert to multiples of dt + input_dict[input_name]['opt_start'] = floor( + input_dict[input_name]['opt_start']/sim_dt)*sim_dt + input_dict[input_name]['opt_end'] = ceil( + input_dict[input_name]['opt_end']/sim_dt)*sim_dt + + # combined chunks that have overlapping ranges + # opt_params is a dict, turn into a list + chunk_list = _consolidate_chunks(input_dict) + + # add one last chunk to the end + if len(chunk_list) > 1: + chunk_list.append(_combine_chunks(chunk_list)) + + return chunk_list + + +def _get_param_inputs(params): + import re + input_list = [] + + # first pass through all params to get mu and sigma for each + for k in params.keys(): + input_mu = re.match('^t_ev(prox|dist)_([0-9]+)', k) + if input_mu: + id_str = 'ev' + input_mu.group(1) + '_' + input_mu.group(2) + input_list.append(id_str) + + return input_list + + +def _trans_input(input_var): + import re + + input_str = input_var + input_match = re.match('^ev(prox|dist)_([0-9]+)', input_var) + if input_match: + if input_match.group(1) == "prox": + input_str = 'Proximal ' + input_match.group(2) + if input_match.group(1) == "dist": + input_str = 'Distal ' + input_match.group(2) + + return input_str + + +def _format_range_str(value): + if value == 0: + value_str = "0.000" + elif value < 0.1: + value_str = ("%6f" % value) + else: + value_str = ("%.3f" % value) + + return value_str + + +def _get_prox_dict(nprox): + # evprox feed strength + + dprox = { + 't_evprox_' + str(nprox): 0., + 'sigma_t_evprox_' + str(nprox): 2.5, + 'numspikes_evprox_' + str(nprox): 1, + 'gbar_evprox_' + str(nprox) + '_L2Pyr_ampa': 0., + 'gbar_evprox_' + str(nprox) + '_L2Pyr_nmda': 0., + 'gbar_evprox_' + str(nprox) + '_L2Basket_ampa': 0., + 'gbar_evprox_' + str(nprox) + '_L2Basket_nmda': 0., + 'gbar_evprox_' + str(nprox) + '_L5Pyr_ampa': 0., + 'gbar_evprox_' + str(nprox) + '_L5Pyr_nmda': 0., + 'gbar_evprox_' + str(nprox) + '_L5Basket_ampa': 0., + 'gbar_evprox_' + str(nprox) + '_L5Basket_nmda': 0. 
+ } + return dprox + + +def _get_dist_dict(ndist): + # evdist feed strength + + ddist = { + 't_evdist_' + str(ndist): 0., + 'sigma_t_evdist_' + str(ndist): 6., + 'numspikes_evdist_' + str(ndist): 1, + 'gbar_evdist_' + str(ndist) + '_L2Pyr_ampa': 0., + 'gbar_evdist_' + str(ndist) + '_L2Pyr_nmda': 0., + 'gbar_evdist_' + str(ndist) + '_L2Basket_ampa': 0., + 'gbar_evdist_' + str(ndist) + '_L2Basket_nmda': 0., + 'gbar_evdist_' + str(ndist) + '_L5Pyr_ampa': 0., + 'gbar_evdist_' + str(ndist) + '_L5Pyr_nmda': 0., + } + return ddist + + +class EvokedInputBaseDialog(QDialog): + def __init__(self): + super(EvokedInputBaseDialog, self).__init__() + + self.nprox = self.ndist = 0 # number of proximal,distal inputs + self.ld = [] # list of dictionaries for proximal/distal inputs + self.dqline = {} + # for translating model variable name to more human-readable form + self.dtransvar = {} + # TODO: add back tooltips + # self.addtips() + + def transvar(self, k): + if k in self.dtransvar: + return self.dtransvar[k] + return k + + def addtransvarfromdict(self, d): + dtmp = {'L2': 'L2/3 ', 'L5': 'L5 '} + for k in d.keys(): + if k.startswith('gbar'): + ks = k.split('_') + stmp = ks[-2] + self.addtransvar(k, dtmp[stmp[0:2]] + stmp[2:] + ' ' + + ks[-1].upper() + u' weight (µS)') + elif k.startswith('t'): + self.addtransvar(k, 'Start time mean (ms)') + elif k.startswith('sigma'): + self.addtransvar(k, 'Start time stdev (ms)') + elif k.startswith('numspikes'): + self.addtransvar(k, 'Number spikes') + + def addtransvar(self, k, strans): + self.dtransvar[k] = strans + self.dtransvar[strans] = k + + def IsProx(self, idx): + d = self.ld[idx] + for k in d.keys(): + if k.count('evprox'): + return True + + return False + + def getInputID(self, idx): + """get evoked input number associated with idx""" + d = self.ld[idx] + for k in d.keys(): + lk = k.split('_') + if len(lk) >= 3: + return int(lk[2]) + + def downShift(self, idx): + """downshift the evoked input ID, keys, values""" + d = self.ld[idx] + dnew = {} # new dictionary + newidx = 0 # new evoked input ID + for k, v in d.items(): + lk = k.split('_') + if len(lk) >= 3: + if lk[0] == 'sigma': + newidx = int(lk[3]) - 1 + lk[3] = str(newidx) + else: + newidx = int(lk[2]) - 1 + lk[2] = str(newidx) + newkey = '_'.join(lk) + dnew[newkey] = v + if k in self.dqline: + self.dqline[newkey] = self.dqline[k] + del self.dqline[k] + self.ld[idx] = dnew + currtxt = self.tabs.tabText(idx) + newtxt = currtxt.split(' ')[0] + ' ' + str(newidx) + self.tabs.setTabText(idx, newtxt) + + def removeInput(self, idx): + # remove the evoked input specified by idx + if idx < 0 or idx > len(self.ltabs): + return + self.tabs.removeTab(idx) + tab = self.ltabs[idx] + self.ltabs.remove(tab) + d = self.ld[idx] + + isprox = self.IsProx(idx) # is it a proximal input? + isdist = not isprox # is it a distal input? + + # what's the proximal/distal input number? + inputID = self.getInputID(idx) + + for k in d.keys(): + if k in self.dqline: + del self.dqline[k] + self.ld.remove(d) + tab.setParent(None) + + # now downshift the evoked inputs (only proximal or only distal) that + # came after this one. 
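        # (for example, removing 'Proximal 2' while a 'Proximal 3' exists
        # retitles that tab to 'Proximal 2' and rewrites its parameter keys,
        # e.g. 't_evprox_3' -> 't_evprox_2', via the downShift calls below)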
+ # first get the IDs of the evoked inputs to downshift + lds = [] # list of inputs to downshift + for jdx in range(len(self.ltabs)): + if isprox and self.IsProx(jdx) and self.getInputID(jdx) > inputID: + lds.append(jdx) + elif isdist and not self.IsProx(jdx): + if self.getInputID(jdx) > inputID: + lds.append(jdx) + for jdx in lds: + self.downShift(jdx) # then do the downshifting + + def removeCurrentInput(self): + """ removes currently selected input""" + idx = self.tabs.currentIndex() + if idx < 0: + return + self.removeInput(idx) + + def removeAllInputs(self): + for _ in range(len(self.ltabs)): + self.removeCurrentInput() + self.nprox = self.ndist = 0 + + +class EvokedInputParamDialog (EvokedInputBaseDialog): + """ Evoked Input Dialog + allows adding/removing arbitrary number of evoked inputs""" + + def __init__(self, parent, din): + super(EvokedInputParamDialog, self).__init__() + self.initUI() + self.setfromdin(din) + + def transvar(self, k): + if k in self.dtransvar: + return self.dtransvar[k] + return k + + def set_qline_float(self, key_str, value): + try: + new_value = float(value) + except ValueError: + print("WARN: bad value for param %s: %s. Unable to convert" + " to a floating point number" % (key_str, value)) + return + + # Enforce no sci. not. + limit field len + remove trailing 0's + self.dqline[key_str].setText( + ("%7f" % new_value).rstrip('0').rstrip('.')) + + def setfromdin(self, din): + if not din: + return + + if 'dt' in din: + + # Optimization feature introduces the case where din just contains + # optimization relevant parameters. In that case, we don't want to + # remove all inputs, just modify existing inputs. + self.removeAllInputs() # turn off any previously set inputs + + nprox, ndist = countEvokedInputs(din) + for i in range(nprox+ndist): + if i % 2 == 0: + if self.nprox < nprox: + self.addProx() + elif self.ndist < ndist: + self.addDist() + else: + if self.ndist < ndist: + self.addDist() + elif self.nprox < nprox: + self.addProx() + + for k, v in din.items(): + if k == 'sync_evinput': + try: + new_value = bool(int(v)) + except ValueError: + print("WARN: bad value for param %s: %s. Unable to convert" + " to a boolean value" % (k, v)) + continue + if new_value: + self.chksync.setChecked(True) + else: + self.chksync.setChecked(False) + elif k == 'inc_evinput': + try: + new_value = float(v) + except ValueError: + print("WARN: bad value for param %s: %s. Unable to convert" + " to a floating point number" % (k, v)) + continue + self.incedit.setText(str(new_value).strip()) + elif k in self.dqline: + if k.startswith('numspikes'): + try: + new_value = int(v) + except ValueError: + print("WARN: bad value for param %s: %s. 
Unable to convert" + " to a integer" % (k, v)) + continue + self.dqline[k].setText(str(new_value)) + else: + self.set_qline_float(k, v) + elif k.count('gbar') > 0 and \ + (k.count('evprox') > 0 or + k.count('evdist') > 0): + # NOTE: will be deprecated in future release + # for back-compat with old-style specification which didn't have ampa,nmda in evoked gbar + lks = k.split('_') + eloc = lks[1] + enum = lks[2] + base_key_str = 'gbar_' + eloc + '_' + enum + '_' + if eloc == 'evprox': + for ct in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: + # ORIGINAL MODEL/PARAM: only ampa for prox evoked + # inputs + key_str = base_key_str + ct + '_ampa' + self.set_qline_float(key_str, v) + elif eloc == 'evdist': + for ct in ['L2Pyr', 'L2Basket', 'L5Pyr']: + # ORIGINAL MODEL/PARAM: both ampa and nmda for distal + # evoked inputs + key_str = base_key_str + ct + '_ampa' + self.set_qline_float(key_str, v) + key_str = base_key_str + ct + '_nmda' + self.set_qline_float(key_str, v) + + def initUI(self): + self.layout = QVBoxLayout(self) + + # Add stretch to separate the form layout from the button + self.layout.addStretch(1) + + self.ltabs = [] + self.tabs = QTabWidget() + self.layout.addWidget(self.tabs) + + self.button_box = QVBoxLayout() + self.btnprox = QPushButton('Add Proximal Input', self) + self.btnprox.resize(self.btnprox.sizeHint()) + self.btnprox.clicked.connect(self.addProx) + self.btnprox.setToolTip('Add Proximal Input') + self.button_box.addWidget(self.btnprox) + + self.btndist = QPushButton('Add Distal Input', self) + self.btndist.resize(self.btndist.sizeHint()) + self.btndist.clicked.connect(self.addDist) + self.btndist.setToolTip('Add Distal Input') + self.button_box.addWidget(self.btndist) + + self.chksync = QCheckBox('Synchronous Inputs', self) + self.chksync.resize(self.chksync.sizeHint()) + self.chksync.setChecked(True) + self.button_box.addWidget(self.chksync) + + self.incbox = QHBoxLayout() + self.inclabel = QLabel(self) + self.inclabel.setText('Increment start time (ms)') + self.inclabel.adjustSize() + self.inclabel.setToolTip( + 'Increment mean evoked input start time(s) by this amount on each trial.') + self.incedit = QLineEdit(self) + self.incedit.setText('0.0') + self.incbox.addWidget(self.inclabel) + self.incbox.addWidget(self.incedit) + + self.layout.addLayout(self.button_box) + self.layout.addLayout(self.incbox) + + self.tabs.resize(425, 200) + + # Add tabs to widget + self.layout.addWidget(self.tabs) + self.setLayout(self.layout) + + self.setWindowTitle('Evoked Inputs') + + self.addRemoveInputButton() + self.addHideButton() + # self.addtips() + + def lines2val(self, ksearch, val): + for k in self.dqline.keys(): + if k.count(ksearch) > 0: + self.dqline[k].setText(str(val)) + + def __str__(self): + s = '' + for k, v in self.dqline.items(): + s += k + ': ' + v.text().strip() + os.linesep + if self.chksync.isChecked(): + s += 'sync_evinput: 1' + os.linesep + else: + s += 'sync_evinput: 0' + os.linesep + s += 'inc_evinput: ' + self.incedit.text().strip() + os.linesep + return s + + def addRemoveInputButton(self): + self.bbremovebox = QHBoxLayout() + self.btnremove = QPushButton('Remove Input', self) + self.btnremove.resize(self.btnremove.sizeHint()) + self.btnremove.clicked.connect(self.removeCurrentInput) + self.btnremove.setToolTip('Remove This Input') + self.bbremovebox.addWidget(self.btnremove) + self.layout.addLayout(self.bbremovebox) + + def addHideButton(self): + self.bbhidebox = QHBoxLayout() + self.btnhide = QPushButton('Hide Window', self) + 
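
The alternating loop in setfromdin above adds proximal and distal inputs in interleaved order until each count from countEvokedInputs is exhausted, which fixes the tab order in the dialog. A small standalone sketch of that ordering (the helper name is illustrative, not part of the dialog):

def _tab_add_order(nprox, ndist):
    """Reproduce the alternating add order used by setfromdin above."""
    added_prox = added_dist = 0
    order = []
    for i in range(nprox + ndist):
        if i % 2 == 0:
            if added_prox < nprox:
                added_prox += 1
                order.append('Proximal %d' % added_prox)
            elif added_dist < ndist:
                added_dist += 1
                order.append('Distal %d' % added_dist)
        else:
            if added_dist < ndist:
                added_dist += 1
                order.append('Distal %d' % added_dist)
            elif added_prox < nprox:
                added_prox += 1
                order.append('Proximal %d' % added_prox)
    return order

# _tab_add_order(3, 1) == ['Proximal 1', 'Distal 1', 'Proximal 2', 'Proximal 3']
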
self.btnhide.resize(self.btnhide.sizeHint()) + self.btnhide.clicked.connect(self.hide) + self.btnhide.setToolTip('Hide Window') + self.bbhidebox.addWidget(self.btnhide) + self.layout.addLayout(self.bbhidebox) + + def addTab(self, s): + tab = QWidget() + self.ltabs.append(tab) + self.tabs.addTab(tab, s) + tab.layout = QFormLayout() + tab.setLayout(tab.layout) + return tab + + def addFormToTab(self, d, tab): + for k, v in d.items(): + self.dqline[k] = QLineEdit(self) + self.dqline[k].setText(str(v)) + # adds label,QLineEdit to the tab + tab.layout.addRow(self.transvar(k), self.dqline[k]) + + def makePixLabel(self, fn): + pix = QPixmap(fn) + pixlbl = ClickLabel(self) + pixlbl.setPixmap(pix) + return pixlbl + + def addProx(self): + self.nprox += 1 # starts at 1 + dprox = _get_prox_dict(self.nprox) + self.ld.append(dprox) + self.addtransvarfromdict(dprox) + self.addFormToTab(dprox, self.addTab('Proximal ' + str(self.nprox))) + self.ltabs[-1].layout.addRow( + self.makePixLabel(lookupresource('proxfig'))) + # print('index to', len(self.ltabs)-1) + self.tabs.setCurrentIndex(len(self.ltabs)-1) + # print('index now', self.tabs.currentIndex(), ' of ', self.tabs.count()) + # self.addtips() + + def addDist(self): + self.ndist += 1 + ddist = _get_dist_dict(self.ndist) + self.ld.append(ddist) + self.addtransvarfromdict(ddist) + self.addFormToTab(ddist, self.addTab('Distal ' + str(self.ndist))) + self.ltabs[-1].layout.addRow( + self.makePixLabel(lookupresource('distfig'))) + # print('index to', len(self.ltabs)-1) + self.tabs.setCurrentIndex(len(self.ltabs)-1) + # print('index now', self.tabs.currentIndex(), ' of ', self.tabs.count()) + # self.addtips() + + +class OptEvokedInputParamDialog (EvokedInputBaseDialog): + def __init__(self, parent, mainwin): + super(OptEvokedInputParamDialog, self).__init__() + self.nprox = self.ndist = 0 # number of proximal,distal inputs + self.ld = [] # list of dictionaries for proximal/distal inputs + self.dqline = {} # not used, prevents failure in removeInput + + self.dtab_idx = {} # for translating input names to tab indices + self.dtab_names = {} # for translating tab indices to input names + self.dparams = {} # actual values + + # these store values used in grid + self.dqchkbox = {} # optimize + self.dqparam_name = {} # parameter name + self.dqinitial_label = {} # initial + self.dqopt_label = {} # optimtized + self.dqdiff_label = {} # delta + self.dqrange_multiplier = {} # user-defined multiplier + self.dqrange_mode = {} # range mode (stdev, %, absolute) + self.dqrange_slider = {} # slider + self.dqrange_label = {} # defined range + self.dqrange_max = {} + self.dqrange_min = {} + + self.chunk_list = [] + self.lqnumsim = [] + self.lqnumparams = [] + self.lqinputs = [] + self.opt_params = {} + self.initial_opt_ranges = [] + self.dtabdata = [] + self.simlength = 0.0 + self.sim_dt = 0.0 + self.default_num_step_sims = 30 + self.default_num_total_sims = 50 + self.mainwin = mainwin + self.optimization_running = False + self.initUI() + self.parent = parent + self.old_num_steps = 0 + + def initUI(self): + # start with a reasonable size + setscalegeom(self, 150, 150, 475, 300) + + self.ltabs = [] + self.ltabkeys = [] + self.tabs = QTabWidget() + self.din = {} + + self.grid = QGridLayout() + + row = 0 + self.sublayout = QGridLayout() + self.old_numsims = [] + self.grid.addLayout(self.sublayout, row, 0) + + row += 1 + self.grid.addWidget(self.tabs, row, 0) + + row += 1 + self.btnrunop = QPushButton('Run Optimization', self) + self.btnrunop.resize(self.btnrunop.sizeHint()) + 
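
Both dialogs' setfromdin methods carry a back-compatibility branch that expands an old-style evoked weight key (no AMPA/NMDA suffix) into the per-receptor keys used by the current GUI. A minimal sketch of that key expansion, with an illustrative helper name:

def _expand_legacy_gbar_keys(key):
    """List the receptor-suffixed keys written for one legacy evoked gbar key.

    Mirrors the back-compat branch above: proximal evoked inputs carry only
    AMPA weights, distal evoked inputs carry both AMPA and NMDA weights.
    """
    lks = key.split('_')
    eloc, enum = lks[1], lks[2]
    base = 'gbar_%s_%s_' % (eloc, enum)
    if eloc == 'evprox':
        return [base + ct + '_ampa'
                for ct in ('L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket')]
    return [base + ct + suffix
            for ct in ('L2Pyr', 'L2Basket', 'L5Pyr')
            for suffix in ('_ampa', '_nmda')]

# _expand_legacy_gbar_keys('gbar_evdist_1') starts with
# ['gbar_evdist_1_L2Pyr_ampa', 'gbar_evdist_1_L2Pyr_nmda', ...]
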
self.btnrunop.setToolTip('Run Optimization') + self.btnrunop.clicked.connect(self.runOptimization) + self.grid.addWidget(self.btnrunop, row, 0) + + row += 1 + self.btnreset = QPushButton('Reset Ranges', self) + self.btnreset.resize(self.btnreset.sizeHint()) + self.btnreset.clicked.connect(self.updateOptRanges) + self.btnreset.setToolTip('Reset Ranges') + self.grid.addWidget(self.btnreset, row, 0) + + row += 1 + btnhide = QPushButton('Hide Window', self) + btnhide.resize(btnhide.sizeHint()) + btnhide.clicked.connect(self.hide) + btnhide.setToolTip('Hide Window') + self.grid.addWidget(btnhide, row, 0) + + self.setLayout(self.grid) + + self.setWindowTitle("Configure Optimization") + + # the largest horizontal component will be column 0 (headings) + self.resize(self.minimumSizeHint()) + + def toggle_enable_param(self, label): + + widget_dict_list = [self.dqinitial_label, self.dqopt_label, + self.dqdiff_label, self.dqparam_name, + self.dqrange_mode, self.dqrange_multiplier, + self.dqrange_label, self.dqrange_slider] + + if self.dqchkbox[label].isChecked(): + # set all other fields in the row to enabled + for widget_dict in widget_dict_list: + widget_dict[label].setEnabled(True) + toEnable = True + else: + # disable all other fields in the row + for widget_dict in widget_dict_list: + widget_dict[label].setEnabled(False) + toEnable = False + + self.changeParamEnabledStatus(label, toEnable) + + def addTab(self, id_str): + tab = QWidget() + self.ltabs.append(tab) + + name_str = _trans_input(id_str) + self.tabs.addTab(tab, name_str) + + tab_index = len(self.ltabs)-1 + self.dtab_idx[id_str] = tab_index + self.dtab_names[tab_index] = id_str + + return tab + + def cleanLabels(self): + """ + To avoid memory leaks we need to delete all widgets when we recreate grid. 
+ Go through all tabs and check for each var name (k) + """ + for idx in range(len(self.ltabs)): + for k in self.ld[idx].keys(): + if k in self.dqinitial_label: + del self.dqinitial_label[k] + if k in self.dqopt_label: + del self.dqopt_label[k] + if k in self.dqdiff_label: + del self.dqdiff_label[k] + if k in self.dqparam_name: + del self.dqparam_name[k] + if not self.optimization_running: + if k in self.dqrange_mode: + del self.dqrange_mode[k] + if k in self.dqrange_multiplier: + del self.dqrange_multiplier[k] + if k in self.dqrange_label: + del self.dqrange_label[k] + if k in self.dqrange_slider: + del self.dqrange_slider[k] + if k in self.dqrange_min: + del self.dqrange_min[k] + if k in self.dqrange_max: + del self.dqrange_max[k] + + def addGridToTab(self, d, tab): + from functools import partial + + current_tab = len(self.ltabs)-1 + tab.layout = QGridLayout() + # tab.layout.setSpacing(10) + + self.ltabkeys.append([]) + + # The first row has column headings + row = 0 + self.ltabkeys[current_tab].append("") + for column_index, column_name in enumerate(["Optimize", "Parameter name", + "Initial", "Optimized", "Delta"]): + widget = QLabel(column_name) + widget.resize(widget.sizeHint()) + tab.layout.addWidget(widget, row, column_index) + + column_index += 1 + widget = QLabel("Range specifier") + widget.setMinimumWidth(100) + tab.layout.addWidget(widget, row, column_index, 1, 2) + + column_index += 2 + widget = QLabel("Range slider") + # widget.setMinimumWidth(160) + tab.layout.addWidget(widget, row, column_index) + + column_index += 1 + widget = QLabel("Defined range") + tab.layout.addWidget(widget, row, column_index) + + # The second row is a horizontal line + row = 1 + self.ltabkeys[current_tab].append("") + qthline = QFrame() + qthline.setFrameShape(QFrame.HLine) + qthline.setFrameShadow(QFrame.Sunken) + tab.layout.addWidget(qthline, row, 0, 1, 9) + + # The rest are the parameters + row = 2 + for k, v in d.items(): + self.ltabkeys[current_tab].append(k) + + # create and format widgets + self.dparams[k] = float(v) + self.dqchkbox[k] = QCheckBox() + self.dqchkbox[k].setStyleSheet(""" + .QCheckBox { + spacing: 20px; + } + .QCheckBox::unchecked { + color: grey; + } + .QCheckBox::checked { + color: black; + } + """) + self.dqchkbox[k].setChecked(True) + # use partial instead of lamda (so args won't be evaluated ahead of time?) 
+ self.dqchkbox[k].clicked.connect( + partial(self.toggle_enable_param, k)) + self.dqparam_name[k] = QLabel(self) + self.dqparam_name[k].setText(self.transvar(k)) + self.dqinitial_label[k] = QLabel() + self.dqopt_label[k] = QLabel() + self.dqdiff_label[k] = QLabel() + + # add widgets to grid + tab.layout.addWidget( + self.dqchkbox[k], row, 0, alignment=Qt.AlignBaseline | Qt.AlignCenter) + tab.layout.addWidget(self.dqparam_name[k], row, 1) + tab.layout.addWidget( + self.dqinitial_label[k], row, 2) # initial value + tab.layout.addWidget( + self.dqopt_label[k], row, 3) # optimized value + tab.layout.addWidget(self.dqdiff_label[k], row, 4) # delta + + if k.startswith('t'): + range_mode = "(stdev)" + range_multiplier = "3.0" + elif k.startswith('sigma'): + range_mode = "(%)" + range_multiplier = "50.0" + else: + range_mode = "(%)" + range_multiplier = "500.0" + + if not self.optimization_running: + self.dqrange_slider[k] = QRangeSlider(k, self) + self.dqrange_slider[k].setMinimumWidth(140) + self.dqrange_label[k] = QLabel() + self.dqrange_multiplier[k] = MyLineEdit(range_multiplier, k) + self.dqrange_multiplier[k].textModified.connect( + self.updateRange) + self.dqrange_multiplier[k].setSizePolicy( + QSizePolicy.Ignored, QSizePolicy.Preferred) + self.dqrange_multiplier[k].setMinimumWidth(50) + self.dqrange_multiplier[k].setMaximumWidth(50) + self.dqrange_mode[k] = QLabel(range_mode) + tab.layout.addWidget( + self.dqrange_multiplier[k], row, 5) # range specifier + tab.layout.addWidget( + self.dqrange_mode[k], row, 6) # range mode + tab.layout.addWidget( + self.dqrange_slider[k], row, 7) # range slider + # calculated range + tab.layout.addWidget(self.dqrange_label[k], row, 8) + + row += 1 + + # A spacer in the last row stretches to fill remaining space. + # For inputs with fewer parameters than the rest, this pushes + # parameters to the top with the same spacing as the other inputs. 
+ tab.layout.addItem(QSpacerItem(0, 0), row, 0, 1, 9) + tab.layout.setRowStretch(row, 1) + tab.setLayout(tab.layout) + + def addProx(self): + self.nprox += 1 + dprox = _get_prox_dict(self.nprox) + self.ld.append(dprox) + self.addtransvarfromdict(dprox) + tab = self.addTab('evprox_' + str(self.nprox)) + self.addGridToTab(dprox, tab) + + def addDist(self): + self.ndist += 1 + ddist = _get_dist_dict(self.ndist) + self.ld.append(ddist) + self.addtransvarfromdict(ddist) + tab = self.addTab('evdist_' + str(self.ndist)) + self.addGridToTab(ddist, tab) + + def changeParamEnabledStatus(self, label, toEnable): + import re + + label_match = re.search('(evprox|evdist)_([0-9]+)', label) + if label_match: + my_input_name = label_match.group(1) + '_' + label_match.group(2) + else: + print("ERR: can't determine input name from parameter: %s" % label) + return + + # decrease the count of num params + for chunk_index in range(self.old_num_steps): + for input_name in self.chunk_list[chunk_index]['inputs']: + if input_name == my_input_name: + try: + num_params = int(self.lqnumparams[chunk_index].text()) + except ValueError: + print( + "ERR: could not get number of params for step %d" % chunk_index) + + if toEnable: + num_params += 1 + else: + num_params -= 1 + self.lqnumparams[chunk_index].setText(str(num_params)) + self.opt_params[input_name]['ranges'][label]['enabled'] = toEnable + + def updateRange(self, label, save_slider=True): + import re + + max_width = 0 + + label_match = re.search('(evprox|evdist)_([0-9]+)', label) + if label_match: + tab_name = label_match.group(1) + '_' + label_match.group(2) + else: + print("ERR: can't determine input name from parameter: %s" % label) + return + + if self.dqchkbox[label].isChecked(): + self.opt_params[tab_name]['ranges'][label]['enabled'] = True + else: + self.opt_params[tab_name]['ranges'][label]['enabled'] = False + return + + if tab_name not in self.initial_opt_ranges or \ + label not in self.initial_opt_ranges[tab_name]: + value = self.dparams[label] + else: + value = float(self.initial_opt_ranges[tab_name][label]['initial']) + + range_type = self.dqrange_mode[label].text() + if range_type == "(%)" and value == 0.0: + # change to range from 0 to 1 + range_type = "(max)" + self.dqrange_mode[label].setText(range_type) + self.dqrange_multiplier[label].setText("1.0") + elif range_type == "(max)" and value > 0.0: + # change back to % + range_type = "(%)" + self.dqrange_mode[label].setText(range_type) + self.dqrange_multiplier[label].setText("500.0") + + try: + range_multiplier = float(self.dqrange_multiplier[label].text()) + except ValueError: + range_multiplier = 0.0 + self.dqrange_multiplier[label].setText(str(range_multiplier)) + + if range_type == "(max)": + range_min = 0 + try: + range_max = float(self.dqrange_multiplier[label].text()) + except ValueError: + range_max = 1.0 + elif range_type == "(stdev)": # timing + timing_sigma = self.get_input_timing_sigma(tab_name) + timing_bound = timing_sigma * range_multiplier + range_min = max(0, value - timing_bound) + range_max = min(self.simlength, value + timing_bound) + else: # range_type == "(%)" + range_min = max(0, value - (value * range_multiplier / 100.0)) + range_max = value + (value * range_multiplier / 100.0) + + # set up the slider + self.dqrange_slider[label].setLine(value) + self.dqrange_slider[label].setMin(range_min) + self.dqrange_slider[label].setMax(range_max) + + if not save_slider: + self.dqrange_min.pop(label, None) + self.dqrange_max.pop(label, None) + + 
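
updateRange above derives each parameter's search range from the range mode shown next to its multiplier: '(stdev)' for start times, '(%)' for weights, and '(max)' when a weight's current value is zero. The arithmetic, as a standalone sketch with illustrative numbers in the usage comment:

def _param_range(value, mode, multiplier, timing_sigma=None, simlength=None):
    """Compute the (range_min, range_max) pair as updateRange does above."""
    if mode == '(max)':
        # a zero-valued weight is searched from 0 up to the multiplier itself
        return 0.0, multiplier
    if mode == '(stdev)':
        # start times: +/- multiplier standard deviations, clipped to the simulation
        bound = timing_sigma * multiplier
        return max(0.0, value - bound), min(simlength, value + bound)
    # '(%)': symmetric percentage window around the current value
    bound = value * multiplier / 100.0
    return max(0.0, value - bound), value + bound

# e.g. a start time of 26.0 ms with sigma 2.5 and the default 3.0 multiplier
# searches max(0, 26.0 - 7.5) .. min(simlength, 26.0 + 7.5) = 18.5 .. 33.5 ms
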
self.opt_params[tab_name]['ranges'][label]['initial'] = value + if label in self.dqrange_min and label in self.dqrange_max: + range_min = self.dqrange_min[label] + range_max = self.dqrange_max[label] + + self.opt_params[tab_name]['ranges'][label]['minval'] = range_min + self.opt_params[tab_name]['ranges'][label]['maxval'] = range_max + self.dqrange_slider[label].setRange(range_min, range_max) + + if range_min == range_max: + self.dqrange_label[label].setText( + _format_range_str(range_min)) # use the exact value + self.dqrange_label[label].setEnabled(False) + # uncheck because invalid range + self.dqchkbox[label].setChecked(False) + # disable slider + self.dqrange_slider[label].setEnabled(False) + self.changeParamEnabledStatus(label, False) + else: + self.dqrange_label[label].setText(_format_range_str(range_min) + + " - " + + _format_range_str(range_max)) + + if self.dqrange_label[label].sizeHint().width() > max_width: + max_width = self.dqrange_label[label].sizeHint().width() + 15 + # fix the size for the defined range so that changing the slider doesn't change + # the dialog's width + self.dqrange_label[label].setMinimumWidth(max_width) + self.dqrange_label[label].setMaximumWidth(max_width) + + def prepareOptimization(self): + self.createOptParams() + self.rebuildOptStepInfo() + self.updateOptDeltas() + self.updateOptRanges(save_sliders=True) + self.btnreset.setEnabled(True) + self.btnrunop.setText('Run Optimization') + self.btnrunop.clicked.disconnect() + self.btnrunop.clicked.connect(self.runOptimization) + + def runOptimization(self): + self.current_opt_step = 0 + + # update the ranges to find which parameters have been disabled + # (unchecked) + self.updateOptRanges(save_sliders=True) + + # update the opt info dict to capture num_sims from GUI + self.rebuildOptStepInfo() + self.optimization_running = True + self.populate_initial_opt_ranges() + + # run the actual optimization + num_steps = self.get_num_chunks() + self.mainwin.startoptmodel(num_steps) + + def get_chunk_start(self, step): + return self.chunk_list[step]['opt_start'] + + def get_chunk_end(self, step): + return self.chunk_list[step]['opt_end'] + + def get_chunk_weights(self, step): + return self.chunk_list[step]['weights'] + + def get_num_chunks(self): + return len(self.chunk_list) + + def get_sims_for_chunk(self, step): + try: + num_sims = int(self.lqnumsim[step].text()) + except KeyError: + print("ERR: number of sims not found for step %d" % step) + num_sims = 0 + except ValueError: + if step == self.old_num_steps - 1: + num_sims = self.default_num_total_sims + else: + num_sims = self.default_num_step_sims + + return num_sims + + def get_chunk_ranges(self, step): + ranges = {} + for input_name in self.chunk_list[step]['inputs']: + # make sure initial value is between minval or maxval before returning + # ranges to the optimization + for label in self.opt_params[input_name]['ranges'].keys(): + if not self.opt_params[input_name]['ranges'][label]['enabled']: + continue + range_min = self.opt_params[input_name]['ranges'][label]['minval'] + range_max = self.opt_params[input_name]['ranges'][label]['maxval'] + if range_min > self.opt_params[input_name]['ranges'][label]['initial']: + self.opt_params[input_name]['ranges'][label]['initial'] = range_min + if range_max < self.opt_params[input_name]['ranges'][label]['initial']: + self.opt_params[input_name]['ranges'][label]['initial'] = range_max + + # copy the values to the ranges dict to be returned + # to optimization + ranges[label] = 
self.opt_params[input_name]['ranges'][label].copy() + + return ranges + + def get_initial_params(self): + initial_params = {} + for input_name in self.opt_params.keys(): + for label in self.opt_params[input_name]['ranges'].keys(): + initial_params[label] = \ + self.opt_params[input_name]['ranges'][label]['initial'] + + return initial_params + + def get_num_params(self, step): + num_params = 0 + + for input_name in self.chunk_list[step]['inputs']: + for label in self.opt_params[input_name]['ranges'].keys(): + if not self.opt_params[input_name]['ranges'][label]['enabled']: + continue + else: + num_params += 1 + + return num_params + + def push_chunk_ranges(self, ranges): + for label, value in ranges.items(): + for tab_name in self.opt_params.keys(): + if label in self.opt_params[tab_name]['ranges']: + self.opt_params[tab_name]['ranges'][label]['initial'] = float( + value) + + def clean_opt_grid(self): + # This is the top part of the Configure Optimization dialog. + + column_count = self.sublayout.columnCount() + row = 0 + while True: + try: + self.sublayout.itemAtPosition(row, 0).widget() + except AttributeError: + # no more rows + break + + for column in range(column_count): + try: + # Use deleteLater() to avoid memory leaks. + self.sublayout.itemAtPosition( + row, column).widget().deleteLater() + except AttributeError: + # if item wasn't found + pass + row += 1 + + # reset data for number of sims per chunk (step) + self.lqnumsim = [] + self.lqnumparams = [] + self.lqinputs = [] + self.old_num_steps = 0 + + def rebuildOptStepInfo(self): + # split chunks from paramter file + self.chunk_list = _chunk_evinputs( + self.opt_params, self.simlength, self.sim_dt) + + if len(self.chunk_list) == 0: + self.clean_opt_grid() + + qlabel = QLabel("No valid evoked inputs to optimize!") + qlabel.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) + qlabel.resize(qlabel.minimumSizeHint()) + self.sublayout.addWidget(qlabel, 0, 0) + self.btnrunop.setEnabled(False) + self.btnreset.setEnabled(False) + else: + self.btnrunop.setEnabled(True) + self.btnreset.setEnabled(True) + + if len(self.chunk_list) < self.old_num_steps or \ + self.old_num_steps == 0: + # clean up the old grid sublayout + self.clean_opt_grid() + + # keep track of inputs to optimize over (check against self.opt_params later) + all_inputs = [] + + # create a new grid sublayout with a row for each optimization step + for chunk_index, chunk in enumerate(self.chunk_list): + chunk['num_params'] = self.get_num_params(chunk_index) + + inputs = [] + for input_name in chunk['inputs']: + all_inputs.append(input_name) + inputs.append(_trans_input(input_name)) + + if chunk_index >= self.old_num_steps: + qlabel = QLabel("Optimization step %d:" % (chunk_index+1)) + qlabel.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) + qlabel.resize(qlabel.minimumSizeHint()) + self.sublayout.addWidget(qlabel, chunk_index, 0) + + self.lqinputs.append(QLabel("Inputs: %s" % ', '.join(inputs))) + self.lqinputs[chunk_index].setAlignment( + Qt.AlignBaseline | Qt.AlignLeft) + self.lqinputs[chunk_index].resize( + self.lqinputs[chunk_index].minimumSizeHint()) + self.sublayout.addWidget( + self.lqinputs[chunk_index], chunk_index, 1) + + # spacer here for readability of input names and reduce size + # of "Num simulations:" + self.sublayout.addItem(QSpacerItem( + 0, 0, hPolicy=QSizePolicy.MinimumExpanding), chunk_index, 2) + + qlabel_params = QLabel("Num params:") + qlabel_params.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) + qlabel_params.resize(qlabel_params.minimumSizeHint()) + 
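
The accessors above (get_chunk_start/end/weights/ranges, get_sims_for_chunk) all index into self.chunk_list, which _chunk_evinputs builds from the per-input timing windows. Each entry is a dict shaped roughly like the sketch below; the values are illustrative, and 'weights' is filled in by _chunk_evinputs:

# Shape of one optimization step ("chunk") as read by the accessors above.
example_chunk = {
    'inputs': ['evprox_1'],   # evoked inputs optimized during this step
    'opt_start': 0.0,         # start of the fitting window (ms)
    'opt_end': 70.0,          # end of the fitting window (ms)
    'weights': None,          # per-timepoint weighting applied when scoring the window
    'num_params': 7,          # enabled parameters across the step's inputs
    'num_sims': 30,           # simulation budget (default_num_step_sims)
}
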
self.sublayout.addWidget(qlabel_params, chunk_index, 3) + + self.lqnumparams.append(QLabel(str(chunk['num_params']))) + self.lqnumparams[chunk_index].setAlignment( + Qt.AlignBaseline | Qt.AlignLeft) + self.lqnumparams[chunk_index].resize( + self.lqnumparams[chunk_index].minimumSizeHint()) + self.sublayout.addWidget( + self.lqnumparams[chunk_index], chunk_index, 4) + + qlabel_sims = QLabel("Num simulations:") + qlabel_sims.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) + qlabel_sims.resize(qlabel_sims.minimumSizeHint()) + self.sublayout.addWidget(qlabel_sims, chunk_index, 5) + + if chunk_index == len(self.chunk_list) - 1: + chunk['num_sims'] = self.default_num_total_sims + else: + chunk['num_sims'] = self.default_num_step_sims + self.lqnumsim.append(QLineEdit(str(chunk['num_sims']))) + self.lqnumsim[chunk_index].resize( + self.lqnumsim[chunk_index].minimumSizeHint()) + self.sublayout.addWidget(self.lqnumsim[chunk_index], + chunk_index, 6) + else: + self.lqinputs[chunk_index].setText( + "Inputs: %s" % ', '.join(inputs)) + self.lqnumparams[chunk_index].setText(str(chunk['num_params'])) + + self.old_num_steps = len(self.chunk_list) + + remove_list = [] + # remove a tab if necessary + for input_name in self.opt_params.keys(): + if input_name not in all_inputs and input_name in self.dtab_idx: + remove_list.append(input_name) + + while len(remove_list) > 0: + tab_name = remove_list.pop() + tab_index = self.dtab_idx[tab_name] + + self.removeInput(tab_index) + del self.dtab_idx[tab_name] + del self.dtab_names[tab_index] + self.ltabkeys.pop(tab_index) + + # rebuild dtab_idx and dtab_names + temp_dtab_names = {} + temp_dtab_idx = {} + for new_tab_index, old_tab_index in enumerate(self.dtab_idx.values()): + # self.dtab_idx[id_str] = tab_index + id_str = self.dtab_names[old_tab_index] + temp_dtab_names[new_tab_index] = id_str + temp_dtab_idx[id_str] = new_tab_index + self.dtab_names = temp_dtab_names + self.dtab_idx = temp_dtab_idx + + def toggle_enable_user_fields(self, step, enable=True): + if not enable: + # the optimization called this to disable parameters on + # for the step passed in to this function + self.current_opt_step = step + + for input_name in self.chunk_list[step]['inputs']: + tab_index = self.dtab_idx[input_name] + tab = self.ltabs[tab_index] + + for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer + label = self.ltabkeys[tab_index][row_index] + self.dqchkbox[label].setEnabled(enable) + self.dqrange_slider[label].setEnabled(enable) + self.dqrange_multiplier[label].setEnabled(enable) + + def get_input_timing_sigma(self, tab_name): + """ get timing_sigma from already loaded values """ + + label = 'sigma_t_' + tab_name + try: + timing_sigma = self.dparams[label] + except KeyError: + timing_sigma = 3.0 + print("ERR: Couldn't fing %s. Using default %f" % + (label, timing_sigma)) + + if timing_sigma == 0.0: + # sigma of 0 will not produce a CDF + timing_sigma = 0.01 + + return timing_sigma + + def createOptParams(self): + global decay_multiplier + + self.opt_params = {} + + # iterate through tabs. 
data is contained in grid layout + for tab_index, tab in enumerate(self.ltabs): + tab_name = self.dtab_names[tab_index] + + # before optimization has started update 'mean', 'sigma', + # 'start', and 'user_end' + start_time_label = 't_' + tab_name + try: + try: + range_multiplier = float( + self.dqrange_multiplier[start_time_label].text()) + except ValueError: + range_multiplier = 0.0 + value = self.dparams[start_time_label] + except KeyError: + print("ERR: could not find start time parameter: %s" % + start_time_label) + continue + + timing_sigma = self.get_input_timing_sigma(tab_name) + self.opt_params[tab_name] = {'ranges': {}, + 'mean': value, + 'sigma': timing_sigma, + 'decay_multiplier': decay_multiplier} + + timing_bound = timing_sigma * range_multiplier + self.opt_params[tab_name]['user_start'] = max( + 0, value - timing_bound) + self.opt_params[tab_name]['user_end'] = min( + self.simlength, value + timing_bound) + + # add an empty dictionary so that rebuildOptStepInfo() can + # determine how many parameters + for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer + label = self.ltabkeys[tab_index][row_index] + self.opt_params[tab_name]['ranges'][label] = {'enabled': True} + + def clear_initial_opt_ranges(self): + self.initial_opt_ranges = {} + + def populate_initial_opt_ranges(self): + self.initial_opt_ranges = {} + + for input_name in self.opt_params.keys(): + self.initial_opt_ranges[input_name] = deepcopy( + self.opt_params[input_name]['ranges']) + + def updateOptDeltas(self): + # iterate through tabs. data is contained in grid layout + for tab_index, tab in enumerate(self.ltabs): + tab_name = self.dtab_names[tab_index] + + # update the initial value + for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer + label = self.ltabkeys[tab_index][row_index] + value = self.dparams[label] + + # Calculate value to put in "Delta" column. 
When possible, use + # percentages, but when initial value is 0, use absolute + # changes + if tab_name not in self.initial_opt_ranges or \ + not self.dqchkbox[label].isChecked(): + self.dqdiff_label[label].setEnabled(False) + self.dqinitial_label[label].setText( + ("%6f" % self.dparams[label]).rstrip('0').rstrip('.')) + text = '--' + color_fmt = "QLabel { color : black; }" + self.dqopt_label[label].setText(text) + self.dqopt_label[label].setStyleSheet(color_fmt) + self.dqopt_label[label].setAlignment(Qt.AlignHCenter) + self.dqdiff_label[label].setAlignment(Qt.AlignHCenter) + else: + initial_value = float( + self.initial_opt_ranges[tab_name][label]['initial']) + self.dqinitial_label[label].setText( + ("%6f" % initial_value).rstrip('0').rstrip('.')) + self.dqopt_label[label].setText( + ("%6f" % self.dparams[label]).rstrip('0').rstrip('.')) + self.dqopt_label[label].setAlignment( + Qt.AlignVCenter | Qt.AlignLeft) + self.dqdiff_label[label].setAlignment( + Qt.AlignVCenter | Qt.AlignLeft) + + if isclose(value, initial_value, abs_tol=1e-7): + diff = 0 + text = "0.0" + color_fmt = "QLabel { color : black; }" + else: + diff = value - initial_value + + if initial_value == 0: + # can't calculate % + if diff < 0: + text = ("%6f" % diff).rstrip('0').rstrip('.') + color_fmt = "QLabel { color : red; }" + elif diff > 0: + text = ("+%6f" % diff).rstrip('0').rstrip('.') + color_fmt = "QLabel { color : green; }" + else: + # calculate percent difference + percent_diff = 100 * diff/abs(initial_value) + if percent_diff < 0: + text = ("%2.2f %%" % percent_diff) + color_fmt = "QLabel { color : red; }" + elif percent_diff > 0: + text = ("+%2.2f %%" % percent_diff) + color_fmt = "QLabel { color : green; }" + + self.dqdiff_label[label].setStyleSheet(color_fmt) + self.dqdiff_label[label].setText(text) + + def updateRangeFromSlider(self, label, range_min, range_max): + import re + + label_match = re.search('(evprox|evdist)_([0-9]+)', label) + if label_match: + tab_name = label_match.group(1) + '_' + label_match.group(2) + else: + print("ERR: can't determine input name from parameter: %s" % label) + return + + self.dqrange_min[label] = range_min + self.dqrange_max[label] = range_max + self.dqrange_label[label].setText(_format_range_str(range_min) + " - " + + _format_range_str(range_max)) + self.opt_params[tab_name]['ranges'][label]['minval'] = range_min + self.opt_params[tab_name]['ranges'][label]['maxval'] = range_max + + def updateOptRanges(self, save_sliders=False): + # iterate through tabs. data is contained in grid layout + for tab_index, tab in enumerate(self.ltabs): + # now update the ranges + for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer + label = self.ltabkeys[tab_index][row_index] + self.updateRange(label, save_sliders) + + def setfromdin(self, din): + if not din: + return + + if 'dt' in din: + # din proivdes a complete parameter set + self.din = din + self.simlength = float(din['tstop']) + self.sim_dt = float(din['dt']) + + self.cleanLabels() + self.removeAllInputs() # turn off any previously set inputs + self.ltabkeys = [] + self.dtab_idx = {} + self.dtab_names = {} + + for evinput in _get_param_inputs(din): + if 'evprox_' in evinput: + self.addProx() + elif 'evdist_' in evinput: + self.addDist() + + for k, v in din.items(): + if k in self.dparams: + try: + new_value = float(v) + except ValueError: + print("WARN: bad value for param %s: %s. 
Unable to convert" + " to a floating point number" % (k, v)) + continue + self.dparams[k] = new_value + elif k.count('gbar') > 0 and \ + (k.count('evprox') > 0 or + k.count('evdist') > 0): + # NOTE: will be deprecated in future release + # for back-compat with old-style specification which didn't + # have ampa,nmda in evoked gbar + try: + new_value = float(v) + except ValueError: + print("WARN: bad value for param %s: %s. Unable to convert" + " to a floating point number" % (k, v)) + continue + lks = k.split('_') + eloc = lks[1] + enum = lks[2] + base_key_str = 'gbar_' + eloc + '_' + enum + '_' + if eloc == 'evprox': + for ct in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: + # ORIGINAL MODEL/PARAM: only ampa for prox evoked + # inputs + key_str = base_key_str + ct + '_ampa' + self.dparams[key_str] = new_value + elif eloc == 'evdist': + for ct in ['L2Pyr', 'L2Basket', 'L5Pyr']: + # ORIGINAL MODEL/PARAM: both ampa and nmda for distal + # evoked inputs + key_str = base_key_str + ct + '_ampa' + self.dparams[key_str] = new_value + key_str = base_key_str + ct + '_nmda' + self.dparams[key_str] = new_value + + if not self.optimization_running: + self.createOptParams() + self.rebuildOptStepInfo() + self.updateOptRanges(save_sliders=True) + + self.updateOptDeltas() + + def __str__(self): + # don't write any values to param file + return '' diff --git a/hnn/qt_lib.py b/hnn/qt_lib.py new file mode 100644 index 000000000..aa8ee2261 --- /dev/null +++ b/hnn/qt_lib.py @@ -0,0 +1,453 @@ +"""Miscellaneous Qt functions for HNN GUI""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import os + +from PyQt5.QtWidgets import QWidget, QGridLayout, QLineEdit, QSplitter +from PyQt5.QtWidgets import QHBoxLayout, QGroupBox, QLabel +from PyQt5.QtGui import QColor, QPainter, QFont, QPen +from PyQt5.QtCore import QCoreApplication, pyqtSignal, Qt, QSize +from PyQt5.QtCore import QMetaObject + +DEFAULT_CSS = """ +QRangeSlider * { + border: 0px; + padding: 0px; +} +QRangeSlider #Head { + background-color: rgba(157, 163, 176, 50); +} +QRangeSlider #Span { + background-color: rgba(22, 31, 50, 150); +} +QRangeSlider #Span:active { + background-color: rgba(22, 31, 50, 150); +} +QRangeSlider #Tail { + background-color: rgba(157, 163, 176, 50); +} +QRangeSlider #LineBox { + background-color: rgba(255, 255, 255, 0); +} +QRangeSlider > QSplitter::handle { + background-color: rgba(79, 91, 102, 100); +} +QRangeSlider > QSplitter::handle:vertical { + height: 4px; +} +QRangeSlider > QSplitter::handle:pressed { + background: #ca5; +} +""" + + +def getscreengeom(): + """use pyqt5 to get screen resolution""" + + width, height = 2880, 1620 # default width,height - used for development + app = QCoreApplication.instance() # can only have 1 instance of qtapp + app.setDesktopSettingsAware(True) + if len(app.screens()) > 0: + screen = app.screens()[0] + geom = screen.geometry() + return geom.width(), geom.height() + else: + return width, height + + +def lowresdisplay(): + """check if display has low resolution""" + w, h = getscreengeom() + return w < 1400 or h < 700 + + +def getmplDPI(): + """get DPI for use in matplotlib figures + + used in simulation output canvas - in simdat.py + """ + if lowresdisplay(): + return 60 + return 120 + + +def scalegeom(width, height): + """get new window width, height + + scaled by current screen resolution relative to original + development resolution + """ + devwidth, devheight = 2880.0, 1620.0 # resolution used for development + screenwidth, screenheight = getscreengeom() + widthnew = int((screenwidth / 
devwidth) * width) + heightnew = int((screenheight / devheight) * height) + if widthnew > 1000 or heightnew > 850: + widthnew = 1000 + heightnew = 850 + return widthnew, heightnew + + +def setscalegeom(dlg, x, y, origw, origh): + """set dialog's position (x,y) and rescale geometry + + based on original width and height and development resolution + """ + + nw, nh = scalegeom(origw, origh) + dlg.setGeometry(x, y, int(nw), int(nh)) + return int(nw), int(nh) + + +def setscalegeomcenter(dlg, origw, origh): + """set dialog in center of screen width + + rescale size based on original width and height and development resolution + """ + + nw, nh = scalegeom(origw, origh) + sw, _ = getscreengeom() + x = (sw - nw) / 2 + y = 0 + dlg.setGeometry(x, y, int(nw), int(nh)) + return int(nw), int(nh) + + +def scale(val, src, dst): + numerator = val - src[0] + denominator = float(src[1] - src[0]) * (dst[1] - dst[0]) + dst[0] + + if denominator == 0: + return 0 + + return numerator / denominator + + +def lookupresource(fn): + """look up resource adjusted for screen resolution""" + lowres = lowresdisplay() # low resolution display + if lowres: + return os.path.join('res', fn + '2.png') + else: + return os.path.join('res', fn + '.png') + + +class Ui_Form(object): + def setupUi(self, Form): + Form.setObjectName("QRangeSlider") + Form.resize(300, 30) + Form.setStyleSheet(DEFAULT_CSS) + self._linebox = QWidget(Form) + self._linebox.setObjectName("LineBox") + self.gridLayout = QGridLayout(Form) + self.gridLayout.setContentsMargins(0, 0, 0, 0) + self.gridLayout.setSpacing(0) + self.gridLayout.setObjectName("gridLayout") + self._splitter = QSplitter(Form) + self._splitter.setMinimumSize(QSize(0, 0)) + self._splitter.setMaximumSize(QSize(16777215, 16777215)) + self._splitter.setOrientation(Qt.Horizontal) + self._splitter.setObjectName("splitter") + self._head = QGroupBox(self._splitter) + self._head.setTitle("") + self._head.setObjectName("Head") + self._handle = QGroupBox(self._splitter) + self._handle.setTitle("") + self._handle.setObjectName("Span") + self._tail = QGroupBox(self._splitter) + self._tail.setTitle("") + self._tail.setObjectName("Tail") + self.gridLayout.addWidget(self._splitter, 0, 0, 1, 1) + self.retranslateUi(Form) + QMetaObject.connectSlotsByName(Form) + + def retranslateUi(self, Form): + _translate = QCoreApplication.translate + Form.setWindowTitle(_translate("QRangeSlider", "QRangeSlider")) + + +class Element(QGroupBox): + def __init__(self, parent, main): + super(Element, self).__init__(parent) + self.main = main + + def setStyleSheet(self, style): + self.parent().setStyleSheet(style) + + def textColor(self): + return getattr(self, '__textColor', QColor(125, 125, 125)) + + def setTextColor(self, color): + if type(color) == tuple and len(color) == 3: + color = QColor(color[0], color[1], color[2]) + elif type(color) == int: + color = QColor(color, color, color) + setattr(self, '__textColor', color) + + def paintEvent(self, event): + qp = QPainter() + qp.begin(self) + if self.main.drawValues(): + self.drawText(event, qp) + qp.end() + + +class Head(Element): + def __init__(self, parent, main): + super(Head, self).__init__(parent, main) + + def drawText(self, event, qp): + qp.setPen(self.textColor()) + qp.setFont(QFont('Arial', 10)) + qp.drawText(event.rect(), Qt.AlignLeft, ("%.3f" % self.main.min())) + + +class Tail(Element): + def __init__(self, parent, main): + super(Tail, self).__init__(parent, main) + + def drawText(self, event, qp): + qp.setPen(self.textColor()) + qp.setFont(QFont('Arial', 10)) 
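
The range slider converts between parameter values and pixel positions through scale() above and the valueToPos/_posToValue helpers on QRangeSlider below. For reference, the conventional linear mapping between a source and a destination interval looks like this; a generic sketch, not the guarded implementation above:

def linear_rescale(val, src, dst):
    """Map val from the interval src=(s0, s1) onto dst=(d0, d1)."""
    s0, s1 = src
    d0, d1 = dst
    if s1 == s0:          # degenerate source interval
        return d0
    return (val - s0) / float(s1 - s0) * (d1 - d0) + d0

# linear_rescale(5.0, (0.0, 10.0), (0, 200)) == 100.0
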
+ qp.drawText(event.rect(), Qt.AlignRight, ("%.3f" % self.main.max())) + + +class LineBox(Element): + def __init__(self, parent, main): + super(LineBox, self).__init__(parent, main) + + def drawText(self, event, qp): + qp.setPen(QPen(Qt.red, 2, Qt.SolidLine, Qt.SquareCap, Qt.MiterJoin)) + pos = self.main.valueToPos(self.main.line_value) + if (pos == 0): + pos += 1 + qp.drawLine(pos, 0, pos, 50) + + +class Handle(Element): + def __init__(self, parent, main): + super(Handle, self).__init__(parent, main) + + def drawText(self, event, qp): + pass + # qp.setPen(self.textColor()) + # qp.setFont(QFont('Arial', 10)) + # qp.drawText(event.rect(), Qt.AlignLeft, str(self.main.start())) + # qp.drawText(event.rect(), Qt.AlignRight, str(self.main.end())) + + def mouseMoveEvent(self, event): + event.accept() + mx = event.globalX() + _mx = getattr(self, '__mx', None) + if not _mx: + setattr(self, '__mx', mx) + dx = 0 + else: + dx = mx - _mx + setattr(self, '__mx', mx) + if dx == 0: + event.ignore() + return + elif dx > 0: + dx = 1 + elif dx < 0: + dx = -1 + s = self.main.start() + dx + e = self.main.end() + dx + if s >= self.main.min() and e <= self.main.max(): + self.main.setRange(s, e) + + +class QRangeSlider(QWidget, Ui_Form): + endValueChanged = pyqtSignal(int) + maxValueChanged = pyqtSignal(int) + minValueChanged = pyqtSignal(int) + startValueChanged = pyqtSignal(int) + rangeValuesChanged = pyqtSignal(str, float, float) + + _SPLIT_START = 1 + _SPLIT_END = 2 + + def __init__(self, label, parent): + super(QRangeSlider, self).__init__(parent) + self.label = label + self.rangeValuesChanged.connect(parent.updateRangeFromSlider) + self.setupUi(self) + self.setMouseTracking(False) + self._splitter.splitterMoved.connect(self._handleMoveSplitter) + + self._linebox_layout = QHBoxLayout() + self._linebox_layout.setSpacing(0) + self._linebox_layout.setContentsMargins(0, 0, 0, 0) + self._linebox.setLayout(self._linebox_layout) + self.linebox = LineBox(self._linebox, main=self) + self._linebox_layout.addWidget(self.linebox) + self._head_layout = QHBoxLayout() + self._head_layout.setSpacing(0) + self._head_layout.setContentsMargins(0, 0, 0, 0) + self._head.setLayout(self._head_layout) + self.head = Head(self._head, main=self) + self._head_layout.addWidget(self.head) + self._handle_layout = QHBoxLayout() + self._handle_layout.setSpacing(0) + self._handle_layout.setContentsMargins(0, 0, 0, 0) + self._handle.setLayout(self._handle_layout) + self.handle = Handle(self._handle, main=self) + self.handle.setTextColor((150, 255, 150)) + self._handle_layout.addWidget(self.handle) + self._tail_layout = QHBoxLayout() + self._tail_layout.setSpacing(0) + self._tail_layout.setContentsMargins(0, 0, 0, 0) + self._tail.setLayout(self._tail_layout) + self.tail = Tail(self._tail, main=self) + self._tail_layout.addWidget(self.tail) + self.setDrawValues(True) + + def min(self): + return getattr(self, '__min', None) + + def max(self): + return getattr(self, '__max', None) + + def setMin(self, value): + setattr(self, '__min', value) + self.minValueChanged.emit(value) + + def setMax(self, value): + setattr(self, '__max', value) + self.maxValueChanged.emit(value) + + def start(self): + return getattr(self, '__start', None) + + def end(self): + return getattr(self, '__end', None) + + def _setStart(self, value): + setattr(self, '__start', value) + self.startValueChanged.emit(value) + + def setStart(self, value): + v = self.valueToPos(value) + self._splitter.splitterMoved.disconnect() + self._splitter.moveSplitter(v, self._SPLIT_START) + 
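
QRangeSlider is constructed with a parameter label and a parent that provides an updateRangeFromSlider(label, min, max) slot, which the optimization dialog earlier in this diff supplies. A minimal usage sketch, assuming a running QApplication; the host class is a stand-in for that dialog and the numeric values are illustrative:

from PyQt5.QtWidgets import QApplication, QWidget
from hnn.qt_lib import QRangeSlider

class _SliderHost(QWidget):
    """Stand-in parent exposing the slot QRangeSlider connects to."""
    def updateRangeFromSlider(self, label, range_min, range_max):
        print('new range for %s: %.3f - %.3f' % (label, range_min, range_max))

app = QApplication([])
host = _SliderHost()
slider = QRangeSlider('t_evprox_1', host)  # label plus parent with the slot above
slider.setMin(18.5)                        # lower bound of the allowed range
slider.setMax(33.5)                        # upper bound of the allowed range
slider.setLine(26.0)                       # red marker at the current value
slider.setRange(18.5, 33.5)                # initial handle positions
slider.show()
# app.exec_() would start the event loop in a real application
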
self._splitter.splitterMoved.connect(self._handleMoveSplitter) + self._setStart(value) + + def _setEnd(self, value): + setattr(self, '__end', value) + self.endValueChanged.emit(value) + + def setEnd(self, value): + v = self.valueToPos(value) + self._splitter.splitterMoved.disconnect() + self._splitter.moveSplitter(v, self._SPLIT_END) + self._splitter.splitterMoved.connect(self._handleMoveSplitter) + self._setEnd(value) + + def drawValues(self): + return getattr(self, '__drawValues', None) + + def setLine(self, value): + self.line_value = value + + def setDrawValues(self, draw): + setattr(self, '__drawValues', draw) + + def getRange(self): + return (self.start(), self.end()) + + def setRange(self, start, end): + self.setStart(start) + self.setEnd(end) + + def keyPressEvent(self, event): + key = event.key() + if key == Qt.Key_Left: + s = self.start() - 1 + e = self.end() - 1 + elif key == Qt.Key_Right: + s = self.start() + 1 + e = self.end() + 1 + else: + event.ignore() + return + event.accept() + if s >= self.min() and e <= self.max(): + self.setRange(s, e) + + def setBackgroundStyle(self, style): + self._tail.setStyleSheet(style) + self._head.setStyleSheet(style) + + def setSpanStyle(self, style): + self._handle.setStyleSheet(style) + + def valueToPos(self, value): + return int(scale(value, (self.min(), self.max()), (0, self.width()))) + + def _posToValue(self, xpos): + return scale(xpos, (0, self.width()), (self.min(), self.max())) + + def _handleMoveSplitter(self, xpos, index): + self._splitter.handleWidth() + + def _lockWidth(widget): + width = widget.size().width() + widget.setMinimumWidth(width) + widget.setMaximumWidth(width) + + def _unlockWidth(widget): + widget.setMinimumWidth(0) + widget.setMaximumWidth(16777215) + + if index == self._SPLIT_START: + v = self._posToValue(xpos) + _lockWidth(self._tail) + if v >= self.end(): + return + self._setStart(v) + self.rangeValuesChanged.emit(self.label, v, self.end()) + elif index == self._SPLIT_END: + # account for width of head + xpos += 4 + v = self._posToValue(xpos) + _lockWidth(self._head) + if v <= self.start(): + return + self._setEnd(v) + self.rangeValuesChanged.emit(self.label, self.start(), v) + _unlockWidth(self._tail) + _unlockWidth(self._head) + _unlockWidth(self._handle) + + +class MyLineEdit(QLineEdit): + textModified = pyqtSignal(str) # (label) + + def __init__(self, contents, label, parent=None): + super(MyLineEdit, self).__init__(contents, parent) + self.editingFinished.connect(self.__handleEditingFinished) + self.textChanged.connect(self.__handleTextChanged) + self._before = contents + self._label = label + + def __handleTextChanged(self, text): + if not self.hasFocus(): + self._before = text + + def __handleEditingFinished(self): + before, after = self._before, self.text() + if before != after: + self._before = after + self.textModified.emit(self._label) + + +class ClickLabel(QLabel): + """clickable label""" + + clicked = pyqtSignal() + + def mousePressEvent(self, event): + self.clicked.emit() diff --git a/hnn/qt_main.py b/hnn/qt_main.py new file mode 100644 index 000000000..16505acad --- /dev/null +++ b/hnn/qt_main.py @@ -0,0 +1,1188 @@ +"""Classes for creating the main HNN GUI""" + +# Authors: Sam Neymotin +# Blake Caldwell +# Shane Lee + +# Python builtins +import sys +import os +import multiprocessing +import numpy as np +import traceback +from collections import namedtuple +from copy import deepcopy +from psutil import cpu_count + +# External libraries +from PyQt5.QtWidgets import (QMainWindow, QAction, qApp, 
QApplication, + QFileDialog, QComboBox, QToolTip, QPushButton, + QGridLayout, QInputDialog, QMenu, QMessageBox, + QWidget, QLayout) +from PyQt5.QtGui import QIcon, QFont +from PyQt5.QtCore import Qt + +from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT +import matplotlib.pyplot as plt + +from hnn_core import read_params +from hnn_core.dipole import average_dipoles + +# HNN modules +from .qt_dialog import (BaseParamDialog, EvokedOrRhythmicDialog, + WaitSimDialog, HelpDialog, SchematicDialog, + bringwintotop) +from .paramrw import (usingOngoingInputs, get_output_dir, + write_gids_param, get_fname) +from .simdata import SimData +from .qt_sim import SIMCanvas +from .qt_thread import SimThread, OptThread +from .qt_lib import (getmplDPI, getscreengeom, lookupresource, + setscalegeomcenter) +from .specfn import spec_dpl_kernel, save_spec_data +from .DataViewGUI import DataViewGUI +from .qt_dipole import DipoleCanvas +from .qt_vsoma import VSomaViewGUI, VSomaCanvas +from .qt_spec import SpecViewGUI, SpecCanvas +from .qt_spike import SpikeViewGUI, SpikeCanvas +from .qt_psd import PSDViewGUI, PSDCanvas + +# TODO: These globals should be made configurable via the GUI +drawavgdpl = 0 +fontsize = plt.rcParams['font.size'] = 10 + + +def _get_defncore(): + """get default number of cores """ + + try: + defncore = len(os.sched_getaffinity(0)) + except AttributeError: + defncore = cpu_count(logical=False) + + if defncore is None or defncore == 0: + # in case psutil is not supported (e.g. BSD) + defncore = multiprocessing.cpu_count() + + return defncore + + +def isWindows(): + # are we on windows? or linux/mac ? + return sys.platform.startswith('win') + + +def getPyComm(): + """get the python command""" + + # check python command interpreter path - if available + if sys.executable is not None: + pyc = sys.executable + if pyc.count('python') > 0 and len(pyc) > 0: + return pyc # full path to python + if isWindows(): + return 'python' + return 'python3' + + +def _add_missing_frames(tb): + fake_tb = namedtuple( + 'fake_tb', ('tb_frame', 'tb_lasti', 'tb_lineno', 'tb_next') + ) + result = fake_tb(tb.tb_frame, tb.tb_lasti, tb.tb_lineno, tb.tb_next) + frame = tb.tb_frame.f_back + while frame: + result = fake_tb(frame, frame.f_lasti, frame.f_lineno, result) + frame = frame.f_back + return result + + +def bringwintobot(win): + # win.show() + # win.lower() + win.hide() + + +class HNNGUI(QMainWindow): + """main HNN GUI class""" + + def __init__(self): + """initialize the main HNN GUI""" + + super().__init__() + sys.excepthook = self.excepthook + + global fontsize + + relative_root_path = os.path.join(os.path.dirname(__file__), '..') + hnn_root_dir = os.path.realpath(relative_root_path) + + self.defncore = _get_defncore() + self.runningsim = False + self.runthread = None + self.fontsize = fontsize + self.linewidth = plt.rcParams['lines.linewidth'] = 1 + self.markersize = plt.rcParams['lines.markersize'] = 5 + self.schemwin = SchematicDialog(self) + self.sim_canvas = self.toolbar = None + paramfn = os.path.join(hnn_root_dir, 'param', 'default.param') + self.prng_seedcore_opt = 0 + self.baseparamwin = BaseParamDialog(self, paramfn) + self.is_optimization = False + self.sim_data = SimData() + self.initUI() + self.helpwin = HelpDialog(self) + self.erselectdistal = \ + EvokedOrRhythmicDialog(self, True, self.baseparamwin.evparamwin, + self.baseparamwin.distparamwin) + self.erselectprox = \ + EvokedOrRhythmicDialog(self, False, self.baseparamwin.evparamwin, + self.baseparamwin.proxparamwin) + self.waitsimwin = 
WaitSimDialog(self) + + default_param = os.path.join(get_output_dir(), 'data', 'default') + first_load = not (os.path.exists(default_param)) + + if first_load: + QMessageBox.information(self, "HNN", "Welcome to HNN! Default" + " parameter file loaded. Press 'Run" + " Simulation' to display simulation" + " output") + else: + self.statusBar().showMessage("Loaded %s" % default_param) + + def excepthook(self, exc_type, exc_value, exc_tb): + enriched_tb = _add_missing_frames(exc_tb) if exc_tb else exc_tb + # Note: sys.__excepthook__(...) would not work here. + # We need to use print_exception(...): + traceback.print_exception(exc_type, exc_value, enriched_tb) + msgBox = QMessageBox(self) + msgBox.information( + self, "Exception", "WARNING: an exception occurred" + "! Details can be found in the console output. Please " + "include this output when opening an issue on GitHub: " + "" + "https://github.com/jonescompneurolab/hnn/issues") + self.done('Exception') + + def redraw(self): + """redraw simulation and external data""" + self.sim_canvas.plot() + self.sim_canvas.draw() + + def changeFontSize(self): + """bring up window to change font sizes""" + global fontsize + + i, ok = QInputDialog.getInt(self, "Set Font Size", "Font Size:", + plt.rcParams['font.size'], 1, 100, 1) + if ok: + self.fontsize = plt.rcParams['font.size'] = fontsize = i + self.redraw() + + def changeLineWidth(self): + """bring up window to change line width(s)""" + i, ok = QInputDialog.getInt(self, "Set Line Width", "Line Width:", + plt.rcParams['lines.linewidth'], 1, 20, 1) + if ok: + self.linewidth = plt.rcParams['lines.linewidth'] = i + self.redraw() + + def changeMarkerSize(self): + """bring up window to change marker size""" + i, ok = QInputDialog.getInt(self, "Set Marker Size", "Font Size:", + self.markersize, 1, 100, 1) + if ok: + self.markersize = plt.rcParams['lines.markersize'] = i + self.redraw() + + def selParamFileDialog(self): + """bring up window to select simulation parameter file""" + + relative_root_path = os.path.join(os.path.dirname(__file__), '..') + hnn_root_dir = os.path.realpath(relative_root_path) + + qfd = QFileDialog() + qfd.setHistory([os.path.join(get_output_dir(), 'param'), + os.path.join(hnn_root_dir, 'param')]) + fn = qfd.getOpenFileName(self, 'Open param file', + os.path.join(hnn_root_dir, 'param'), + "Param files (*.param)") + if len(fn) > 0 and fn[0] == '': + # no file selected in dialog + return + + tmpfn = os.path.abspath(fn[0]) + + try: + params = read_params(tmpfn) + except ValueError: + QMessageBox.information(self, "HNN", "WARNING: could not" + "retrieve parameters from %s" % + tmpfn) + return + + # check that valid number of trials was given + if 'N_trials' not in params or params['N_trials'] == 0: + print("Warning: invalid configured number of trials." 
+ " Setting 'N_trials' to 1.") + params['N_trials'] = 1 + + # Now update GUI components + self.baseparamwin.paramfn = tmpfn + + # now update the GUI components to reflect the param file selected + self.baseparamwin.updateDispParam(params) + self.setWindowTitle(self.baseparamwin.paramfn) + + self.initSimCanvas() # recreate canvas + + # check if param file exists in combo box already + cb_index = self.cbsim.findText(self.baseparamwin.paramfn) + self.populateSimCB(cb_index) # populate the combobox + + if self.sim_data.get_exp_data_size() > 0: + self.toggleEnableOptimization(True) + + def loadDataFile(self, fn): + """load a dipole data file""" + + extdata = None + try: + extdata = np.loadtxt(fn) + except ValueError: + # possible that data file is comma delimited instead of whitespace + # delimited + try: + extdata = np.loadtxt(fn, delimiter=',') + except ValueError: + QMessageBox.information(self, "HNN", "WARNING: could not load" + " data file %s" % fn) + return False + except IsADirectoryError: + QMessageBox.information(self, "HNN", "WARNING: could not load data" + " file %s" % fn) + return False + + self.sim_data.update_exp_data(fn, extdata) + print('Loaded data in ', fn) + + self.sim_canvas.plot() + self.sim_canvas.draw() # make sure new lines show up in plot + + if self.baseparamwin.paramfn: + self.toggleEnableOptimization(True) + return True + + def loadDataFileDialog(self): + """bring up window to select/load external dipole data""" + hnn_root_dir = \ + os.path.realpath(os.path.join(os.path.dirname(__file__), '..')) + + qfd = QFileDialog() + qfd.setHistory([os.path.join(get_output_dir(), 'data'), + os.path.join(hnn_root_dir, 'data')]) + fn = qfd.getOpenFileName(self, 'Open data file', + os.path.join(hnn_root_dir, 'data'), + "Data files (*.txt)") + if len(fn) > 0 and fn[0] == '': + # no file selected in dialog + return + + # use abspath to make sure have right path separators + self.loadDataFile(os.path.abspath(fn[0])) + + def clearDataFile(self): + """clear external dipole data""" + self.sim_canvas.clearlextdatobj() + self.sim_data.clear_exp_data() + self.toggleEnableOptimization(False) + self.sim_canvas.plot() # recreate canvas + self.sim_canvas.draw() + + def setparams(self): + """show set parameters dialog window""" + if self.baseparamwin: + for win in self.baseparamwin.lsubwin: + bringwintobot(win) + bringwintotop(self.baseparamwin) + + def showAboutDialog(self): + """show HNN's about dialog box""" + from hnn import __version__ + msgBox = QMessageBox(self) + msgBox.setTextFormat(Qt.RichText) + msgBox.setWindowTitle('About') + msgBox.setText("Human Neocortical Neurosolver (HNN) v" + __version__ + + "
" + + "https://hnn.brown.edu" + + "
" + + "" + + "HNN On Github
© 2017-2019 " + "Brown University" + + ", Providence, RI
Software License") + msgBox.setStandardButtons(QMessageBox.Ok) + msgBox.exec_() + + def showOptWarnDialog(self): + # TODO : not implemented yet + msgBox = QMessageBox(self) + msgBox.setTextFormat(Qt.RichText) + msgBox.setWindowTitle('Warning') + msgBox.setText("") + msgBox.setStandardButtons(QMessageBox.Ok) + msgBox.exec_() + + def showHelpDialog(self): + # show the help dialog box + bringwintotop(self.helpwin) + + def show_plot(self, plot_type): + paramfn = self.baseparamwin.paramfn + if paramfn is None: + return + if paramfn in self.sim_data._sim_data: + sim_data = self.sim_data._sim_data[paramfn]['data'] + else: + sim_data = None + + if plot_type == 'dipole': + DataViewGUI(DipoleCanvas, self.baseparamwin.params, sim_data, + 'Dipole Viewer') + elif plot_type == 'vsoma': + VSomaViewGUI(VSomaCanvas, self.baseparamwin.params, sim_data, + 'Somatic Voltages Viewer') + elif plot_type == 'PSD': + PSDViewGUI(PSDCanvas, self.baseparamwin.params, sim_data, + 'PSD Viewer') + elif plot_type == 'spec': + SpecViewGUI(SpecCanvas, self.baseparamwin.params, sim_data, + 'Spectrogram Viewer') + elif plot_type == 'spike': + SpikeViewGUI(SpikeCanvas, self.baseparamwin.params, sim_data, + 'Spike Viewer') + else: + raise ValueError("Unknown plot type") + + def showSomaVPlot(self): + # start the somatic voltage visualization process (separate window) + if not float(self.baseparamwin.params['record_vsoma']): + smsg = 'In order to view somatic voltages you must first rerun' + \ + ' the simulation with saving somatic voltages. To do so' + \ + ' from the main GUI, click on Set Parameters -> Run ->' + \ + ' Analysis -> Save Somatic Voltages, enter a 1 and then' + \ + ' rerun the simulation.' + msg = QMessageBox() + msg.setIcon(QMessageBox.Information) + msg.setText(smsg) + msg.setWindowTitle('Rerun simulation') + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + else: + self.show_plot('vsoma') + + def showPSDPlot(self): + self.show_plot('PSD') + + def showSpecPlot(self): + self.show_plot('spec') + + def showRasterPlot(self): + self.show_plot('spike') + + def showDipolePlot(self): + self.show_plot('dipole') + + def showwaitsimwin(self): + """show the wait sim window (has simulation log)""" + bringwintotop(self.waitsimwin) + + def togAvgDpl(self): + """toggle drawing of the average (across trials) dipole""" + global drawavgdpl + + drawavgdpl = not drawavgdpl + self.sim_canvas.plot() + self.sim_canvas.draw() + + def hidesubwin(self): + """hide GUI's sub windows""" + self.baseparamwin.hide() + self.schemwin.hide() + self.baseparamwin.syngainparamwin.hide() + for win in self.baseparamwin.lsubwin: + win.hide() + self.activateWindow() + + def distribsubwin(self): + """distribute GUI's sub-windows on screen""" + sw, sh = getscreengeom() + lwin = [win for win in self.baseparamwin.lsubwin if win.isVisible()] + if self.baseparamwin.isVisible(): + lwin.insert(0, self.baseparamwin) + if self.schemwin.isVisible(): + lwin.insert(0, self.schemwin) + if self.baseparamwin.syngainparamwin.isVisible(): + lwin.append(self.baseparamwin.syngainparamwin) + curx, cury, maxh = 0, 0, 0 + for win in lwin: + win.move(curx, cury) + curx += win.width() + maxh = max(maxh, win.height()) + if curx >= sw: + curx = 0 + cury += maxh + maxh = win.height() + if cury >= sh: + cury = cury = 0 + + def updateDatCanv(self, params): + """update GUI to reflect param file selected""" + self.baseparamwin.updateDispParam(params) + self.initSimCanvas() # recreate canvas + self.setWindowTitle(self.baseparamwin.paramfn) + + def updateSelectedSim(self, sim_idx): 
+ """Update the sim shown in the ComboBox""" + + paramfn = self.cbsim.itemText(sim_idx) + try: + params = read_params(paramfn) + except ValueError: + QMessageBox.information(self, "HNN", "WARNING: could not" + "retrieve parameters from %s" % + paramfn) + return + self.baseparamwin.paramfn = paramfn + + # update GUI + self.updateDatCanv(params) + self.cbsim.setCurrentIndex(sim_idx) + + def removeSim(self): + """Remove the currently selected simulation""" + + sim_idx = self.cbsim.currentIndex() + paramfn = self.cbsim.itemText(sim_idx) + if not paramfn == '': + self.sim_data.remove_sim_by_fn(paramfn) + + self.cbsim.removeItem(sim_idx) + + # go to last entry + new_simidx = self.cbsim.count() - 1 + if new_simidx < 0: + self.clearSimulations() + else: + self.updateSelectedSim(new_simidx) + + def prevSim(self): + """Go to previous simulation""" + + new_simidx = self.cbsim.currentIndex() - 1 + if new_simidx < 0: + print("There is no previous simulation") + return + else: + self.updateSelectedSim(new_simidx) + + def nextSim(self): + """go to next simulation""" + + if self.cbsim.currentIndex() + 2 > self.cbsim.count(): + print("There is no next simulation") + return + else: + new_simidx = self.cbsim.currentIndex() + 1 + self.updateSelectedSim(new_simidx) + + def clearSimulationData(self): + """clear the simulation data""" + self.baseparamwin.params = None + self.baseparamwin.paramfn = None + + self.sim_data.clear_sim_data() + self.cbsim.clear() # un-populate the combobox + self.toggleEnableOptimization(False) + + def clearSimulations(self): + """clear all simulation data + + erase simulations from canvas (does not clear external data) + """ + self.clearSimulationData() + self.initSimCanvas() # recreate canvas + self.sim_canvas.draw() + self.setWindowTitle('') + + def clearCanvas(self): + # clear all simulation & external data and erase everything from the + # canvas + self.sim_canvas.clearlextdatobj() # clear the external data + self.clearSimulationData() + self.sim_data.clear_exp_data() + self.initSimCanvas() # recreate canvas + self.sim_canvas.draw() + self.setWindowTitle('') + + def initMenu(self): + """initialize the GUI's menu""" + exitAction = QAction(QIcon.fromTheme('exit'), 'Exit', self) + exitAction.setShortcut('Ctrl+Q') + exitAction.setStatusTip('Exit HNN application') + exitAction.triggered.connect(qApp.quit) + + selParamFile = QAction(QIcon.fromTheme('open'), 'Load parameter file', + self) + selParamFile.setShortcut('Ctrl+P') + selParamFile.setStatusTip('Load simulation parameter (.param) file') + selParamFile.triggered.connect(self.selParamFileDialog) + + clearCanv = QAction('Clear canvas', self) + clearCanv.setShortcut('Ctrl+X') + clearCanv.setStatusTip('Clear canvas (simulation+data)') + clearCanv.triggered.connect(self.clearCanvas) + + clearSims = QAction('Clear simulation(s)', self) + # clearSims.setShortcut('Ctrl+X') + clearSims.setStatusTip('Clear simulation(s)') + clearSims.triggered.connect(self.clearSimulations) + + loadDataFile = QAction(QIcon.fromTheme('open'), 'Load data file', + self) + loadDataFile.setShortcut('Ctrl+D') + loadDataFile.setStatusTip('Load (dipole) data file') + loadDataFile.triggered.connect(self.loadDataFileDialog) + + clearDataFileAct = QAction(QIcon.fromTheme('close'), + 'Clear data file(s)', self) + clearDataFileAct.setShortcut('Ctrl+C') + clearDataFileAct.setStatusTip('Clear (dipole) data file(s)') + clearDataFileAct.triggered.connect(self.clearDataFile) + + runSimAct = QAction('Run simulation', self) + runSimAct.setShortcut('Ctrl+S') + 
runSimAct.setStatusTip('Run simulation') + runSimAct.triggered.connect(self.controlsim) + + self.menubar = self.menuBar() + fileMenu = self.menubar.addMenu('&File') + self.menubar.setNativeMenuBar(False) + fileMenu.addAction(selParamFile) + fileMenu.addSeparator() + fileMenu.addAction(loadDataFile) + fileMenu.addAction(clearDataFileAct) + fileMenu.addSeparator() + fileMenu.addAction(exitAction) + + # part of edit menu for changing drawing properties (line thickness, + # font size, toggle avg dipole drawing) + editMenu = self.menubar.addMenu('&Edit') + viewAvgDplAction = QAction('Toggle Average Dipole Drawing', self) + viewAvgDplAction.setStatusTip('Toggle Average Dipole Drawing') + viewAvgDplAction.triggered.connect(self.togAvgDpl) + editMenu.addAction(viewAvgDplAction) + changeFontSizeAction = QAction('Change Font Size', self) + changeFontSizeAction.setStatusTip('Change Font Size.') + changeFontSizeAction.triggered.connect(self.changeFontSize) + editMenu.addAction(changeFontSizeAction) + changeLineWidthAction = QAction('Change Line Width', self) + changeLineWidthAction.setStatusTip('Change Line Width.') + changeLineWidthAction.triggered.connect(self.changeLineWidth) + editMenu.addAction(changeLineWidthAction) + changeMarkerSizeAction = QAction('Change Marker Size', self) + changeMarkerSizeAction.setStatusTip('Change Marker Size.') + changeMarkerSizeAction.triggered.connect(self.changeMarkerSize) + editMenu.addAction(changeMarkerSizeAction) + editMenu.addSeparator() + editMenu.addAction(clearSims) + # need new act to avoid DBus warning + clearDataFileAct2 = QAction(QIcon.fromTheme('close'), + 'Clear data file(s)', self) + clearDataFileAct2.setStatusTip('Clear (dipole) data file(s)') + clearDataFileAct2.triggered.connect(self.clearDataFile) + editMenu.addAction(clearDataFileAct2) + editMenu.addAction(clearCanv) + + # view menu - to view drawing/visualizations + viewMenu = self.menubar.addMenu('&View') + self.viewDipoleAction = QAction('View Simulation Dipoles', + self) + self.viewDipoleAction.setStatusTip('View Simulation Dipoles') + self.viewDipoleAction.triggered.connect(self.showDipolePlot) + viewMenu.addAction(self.viewDipoleAction) + self.viewRasterAction = QAction('View Simulation Spiking Activity', + self) + self.viewRasterAction.setStatusTip('View Simulation Raster Plot') + self.viewRasterAction.triggered.connect(self.showRasterPlot) + viewMenu.addAction(self.viewRasterAction) + self.viewPSDAction = QAction('View PSD', self) + self.viewPSDAction.setStatusTip('View PSD') + self.viewPSDAction.triggered.connect(self.showPSDPlot) + viewMenu.addAction(self.viewPSDAction) + + self.viewSomaVAction = QAction('View Somatic Voltage', self) + self.viewSomaVAction.setStatusTip('View Somatic Voltage') + self.viewSomaVAction.triggered.connect(self.showSomaVPlot) + viewMenu.addAction(self.viewSomaVAction) + + self.viewSpecAction = QAction('View Spectrograms', self) + self.viewSpecAction.setStatusTip('View Spectrograms/Dipoles' + ' from Experimental Data') + self.viewSpecAction.triggered.connect(self.showSpecPlot) + viewMenu.addAction(self.viewSpecAction) + + viewMenu.addSeparator() + viewSchemAction = QAction('View Model Schematics', self) + viewSchemAction.setStatusTip('View Model Schematics') + viewSchemAction.triggered.connect(self.showschematics) + viewMenu.addAction(viewSchemAction) + viewSimLogAction = QAction('View Simulation Log', self) + viewSimLogAction.setStatusTip('View Detailed Simulation Log') + viewSimLogAction.triggered.connect(self.showwaitsimwin) + 
viewMenu.addAction(viewSimLogAction) + viewMenu.addSeparator() + distributeWindowsAction = QAction('Distribute Windows', self) + distributeWindowsAction.setStatusTip('Distribute Parameter Windows' + ' Across Screen.') + distributeWindowsAction.triggered.connect(self.distribsubwin) + viewMenu.addAction(distributeWindowsAction) + hideWindowsAction = QAction('Hide Windows', self) + hideWindowsAction.setStatusTip('Hide Parameter Windows.') + hideWindowsAction.triggered.connect(self.hidesubwin) + hideWindowsAction.setShortcut('Ctrl+H') + viewMenu.addAction(hideWindowsAction) + + simMenu = self.menubar.addMenu('&Simulation') + setParmAct = QAction('Set Parameters', self) + setParmAct.setStatusTip('Set Simulation Parameters') + setParmAct.triggered.connect(self.setparams) + simMenu.addAction(setParmAct) + simMenu.addAction(runSimAct) + setOptParamAct = QAction('Configure Optimization', self) + setOptParamAct.setShortcut('Ctrl+O') + setOptParamAct.setStatusTip('Set parameters for evoked input' + ' optimization') + setOptParamAct.triggered.connect(self.showoptparamwin) + simMenu.addAction(setOptParamAct) + self.toggleEnableOptimization(False) + prevSimAct = QAction('Go to Previous Simulation', self) + prevSimAct.setShortcut('Ctrl+Z') + prevSimAct.setStatusTip('Go Back to Previous Simulation') + prevSimAct.triggered.connect(self.prevSim) + simMenu.addAction(prevSimAct) + nextSimAct = QAction('Go to Next Simulation', self) + nextSimAct.setShortcut('Ctrl+Y') + nextSimAct.setStatusTip('Go Forward to Next Simulation') + nextSimAct.triggered.connect(self.nextSim) + simMenu.addAction(nextSimAct) + # need another QAction to avoid DBus warning + clearSims2 = QAction('Clear simulation(s)', self) + clearSims2.setStatusTip('Clear simulation(s)') + clearSims2.triggered.connect(self.clearSimulations) + simMenu.addAction(clearSims2) + + aboutMenu = self.menubar.addMenu('&About') + aboutAction = QAction('About HNN', self) + aboutAction.setStatusTip('About HNN') + aboutAction.triggered.connect(self.showAboutDialog) + aboutMenu.addAction(aboutAction) + helpAction = QAction('Help', self) + helpAction.setStatusTip('Help on how to use HNN (parameters).') + helpAction.triggered.connect(self.showHelpDialog) + # aboutMenu.addAction(helpAction) + + def toggleEnableOptimization(self, toEnable): + for menu in self.menubar.findChildren(QMenu): + if menu.title() == '&Simulation': + for item in menu.actions(): + if item.text() == 'Configure Optimization': + item.setEnabled(toEnable) + break + break + + def addButtons(self, gRow): + self.pbtn = pbtn = QPushButton('Set Parameters', self) + pbtn.setToolTip('Set Parameters') + pbtn.resize(pbtn.sizeHint()) + pbtn.clicked.connect(self.setparams) + self.grid.addWidget(self.pbtn, gRow, 0, 1, 3) + + self.pfbtn = pfbtn = QPushButton('Set Parameters From File', self) + pfbtn.setToolTip('Set Parameters From File') + pfbtn.resize(pfbtn.sizeHint()) + pfbtn.clicked.connect(self.selParamFileDialog) + self.grid.addWidget(self.pfbtn, gRow, 3, 1, 3) + + self.btnsim = btn = QPushButton('Run Simulation', self) + btn.setToolTip('Run Simulation') + btn.resize(btn.sizeHint()) + btn.clicked.connect(self.controlsim) + self.grid.addWidget(self.btnsim, gRow, 6, 1, 3) + + self.qbtn = qbtn = QPushButton('Quit', self) + qbtn.clicked.connect(QApplication.exit) + qbtn.resize(qbtn.sizeHint()) + self.grid.addWidget(self.qbtn, gRow, 9, 1, 3) + + def shownetparamwin(self): + bringwintotop(self.baseparamwin.netparamwin) + + def showoptparamwin(self): + bringwintotop(self.baseparamwin.optparamwin) + + def 
showdistparamwin(self): + bringwintotop(self.erselectdistal) + + def showproxparamwin(self): + bringwintotop(self.erselectprox) + + def showschematics(self): + bringwintotop(self.schemwin) + + def addParamImageButtons(self, gRow): + """add parameter image buttons to the GUI""" + + self.locbtn = QPushButton('Local Network' + os.linesep + + 'Connections', self) + self.locbtn.setIcon(QIcon(lookupresource('connfig'))) + self.locbtn.clicked.connect(self.shownetparamwin) + self.grid.addWidget(self.locbtn, gRow, 0, 1, 4) + + self.proxbtn = QPushButton('Proximal Drive' + os.linesep + + 'Thalamus', + self) + self.proxbtn.setIcon(QIcon(lookupresource('proxfig'))) + self.proxbtn.clicked.connect(self.showproxparamwin) + self.grid.addWidget(self.proxbtn, gRow, 4, 1, 4) + + self.distbtn = QPushButton('Distal Drive NonLemniscal' + + os.linesep + + 'Thal./Cortical Feedback', self) + self.distbtn.setIcon(QIcon(lookupresource('distfig'))) + self.distbtn.clicked.connect(self.showdistparamwin) + self.grid.addWidget(self.distbtn, gRow, 8, 1, 4) + + gRow += 1 + + def initUI(self): + """initialize the user interface (UI)""" + + self.initMenu() + self.statusBar() + + # start GUI in center of screenm, scale based on screen w x h + setscalegeomcenter(self, 1500, 1300) + + # move param windows to be offset from main GUI + new_x = max(0, self.x() - 300) + new_y = max(0, self.y() + 100) + self.baseparamwin.move(new_x, new_y) + self.baseparamwin.evparamwin.move(new_x+50, new_y+50) + self.baseparamwin.optparamwin.move(new_x+100, new_y+100) + self.setWindowTitle(self.baseparamwin.paramfn) + QToolTip.setFont(QFont('SansSerif', 10)) + + self.grid = grid = QGridLayout() + # grid.setSpacing(10) + + gRow = 0 + + self.addButtons(gRow) + + gRow += 1 + + self.initSimCanvas(gRow=gRow, reInit=False) + gRow += 2 + + self.cbsim = QComboBox(self) + try: + self.populateSimCB() + except ValueError: + # If no simulations could be loaded into combobox + # don't crash the initialization process + print("Warning: no simulations to load") + pass + self.cbsim.activated[str].connect(self.onActivateSimCB) + self.grid.addWidget(self.cbsim, gRow, 0, 1, 8) + self.btnrmsim = QPushButton('Remove Simulation', self) + self.btnrmsim.resize(self.btnrmsim.sizeHint()) + self.btnrmsim.clicked.connect(self.removeSim) + self.btnrmsim.setToolTip('Remove Currently Selected Simulation') + self.grid.addWidget(self.btnrmsim, gRow, 8, 1, 4) + + gRow += 1 + self.addParamImageButtons(gRow) + + # need a separate widget to put grid on + widget = QWidget(self) + widget.setLayout(grid) + self.setCentralWidget(widget) + + self.setWindowIcon(QIcon(os.path.join('res', 'icon.png'))) + + self.schemwin.show() # so it's underneath main window + + self.show() + + def onActivateSimCB(self, paramfn): + """load simulation when activating simulation combobox""" + + if paramfn != self.baseparamwin.paramfn: + try: + params = read_params(paramfn) + except ValueError: + QMessageBox.information(self, "HNN", "WARNING: could not" + "retrieve parameters from %s" % + paramfn) + return + self.baseparamwin.paramfn = paramfn + + self.updateDatCanv(params) + + def populateSimCB(self, index=None): + """populate the simulation combobox""" + + self.cbsim.clear() + for paramfn in self.sim_data._sim_data.keys(): + self.cbsim.addItem(paramfn) + + if self.cbsim.count() == 0: + raise ValueError("No simulations to add to combo box") + + if index is None or index < 0: + # set to last entry + self.cbsim.setCurrentIndex(self.cbsim.count() - 1) + else: + self.cbsim.setCurrentIndex(index) + + def 
initSimCanvas(self, gRow=1, reInit=True): + """initialize the simulation canvas, loading any required data""" + gCol = 0 + + if reInit: + self.grid.itemAtPosition(gRow, gCol).widget().deleteLater() + self.grid.itemAtPosition(gRow + 1, gCol).widget().deleteLater() + + # if just initialized or after clearSimulationData + if self.baseparamwin.paramfn and self.baseparamwin.params is None: + try: + self.baseparamwin.params = read_params( + self.baseparamwin.paramfn) + except ValueError: + QMessageBox.information(self, "HNN", "WARNING: could not" + "retrieve parameters from %s" % + self.baseparamwin.paramfn) + return + + self.sim_canvas = SIMCanvas(self.baseparamwin.paramfn, + self.baseparamwin.params, + parent=self, width=10, height=1, + dpi=getmplDPI(), + is_optimization=self.is_optimization) + + # this is the Navigation widget + # it takes the Canvas widget and a parent + self.toolbar = NavigationToolbar2QT(self.sim_canvas, self) + gWidth = 12 + self.grid.addWidget(self.toolbar, gRow, gCol, 1, gWidth) + self.grid.addWidget(self.sim_canvas, gRow + 1, gCol, 1, gWidth) + + if self.sim_canvas.saved_exception is not None: + raise self.sim_canvas.saved_exception + + def setcursors(self, cursor): + """set cursors of self and children""" + + self.setCursor(cursor) + self.update() + kids = self.children() + kids.append(self.sim_canvas) # matplotlib simcanvas + for k in kids: + if type(k) == QLayout or type(k) == QAction: + # These types don't have setCursor() + continue + k.setCursor(cursor) + k.update() + + def startoptmodel(self, num_steps): + """start model optimization""" + if self.runningsim: + raise ValueError("Optimization already running") + + self.is_optimization = True + if not self.baseparamwin.saveparams(): + # user may have pressed 'cancel' + return + + # optimize the model + print('Starting model optimization. . .') + # save initial parameters file + # data_dir = op.join(get_output_dir(), 'data') + # sim_dir = op.join(data_dir, self.params['sim_prefix']) + # param_out = os.path.join(sim_dir, 'before_opt.param') + # write_legacy_paramf(param_out, self.params) + seed = self.baseparamwin.runparamwin.get_prng_seedcore_opt() + ncore = self.baseparamwin.runparamwin.getncore() + self.runthread = OptThread(ncore, self.baseparamwin.params, + num_steps, + seed, self.sim_data, + self.sim_result_callback, + self.opt_callback, mainwin=self) + self.runningsim = True + self.runthread.start() + + # update optimization dialog + self.baseparamwin.optparamwin.btnreset.setEnabled(False) + self.baseparamwin.optparamwin.btnrunop.setText('Stop Optimization') + self.baseparamwin.optparamwin.btnrunop.clicked.disconnect() + self.baseparamwin.optparamwin.btnrunop.clicked.connect( + self.stopsim) + + # update GUI + self.statusBar().showMessage("Optimizing model. . .") + self.setcursors(Qt.WaitCursor) + self.btnsim.setText("Stop Optimization") + self.qbtn.setEnabled(False) + self.waitsimwin.updatetxt('Optimizing model. . .') + bringwintotop(self.waitsimwin) + + def controlsim(self): + """control the simulation""" + if self.runningsim: + # stop sim works but leaves subproc as zombie until this main + # GUI + self.stopsim() + else: + self.is_optimization = False + self.startsim(self.baseparamwin.runparamwin.getncore()) + + def stopsim(self): + """stop the simulation""" + if self.runningsim: + self.waitsimwin.hide() + print('Terminating simulation. . .') + self.statusBar().showMessage('Terminating sim. . 
.')
+            self.runningsim = False
+            self.runthread.stop()  # killed = True # terminate()
+            self.runthread.wait(1000)
+            self.runthread.terminate()
+            self.btnsim.setText("Run Simulation")
+            self.qbtn.setEnabled(True)
+            self.statusBar().showMessage('')
+            self.setcursors(Qt.ArrowCursor)
+
+            if self.is_optimization:
+                self.baseparamwin.optparamwin.btnrunop.setText(
+                    'Run Optimization')
+                self.baseparamwin.optparamwin.btnrunop.clicked.disconnect()
+                self.baseparamwin.optparamwin.btnrunop.clicked.connect(
+                    self.baseparamwin.optparamwin.runOptimization)
+                self.is_optimization = False
+
+    def startsim(self, ncore):
+        """start the simulation"""
+        # update self.baseparamwin.params with values from GUI
+        # and save to file
+        if not self.baseparamwin.saveparams():
+            return  # make sure params saved and ok to run
+
+        self.setcursors(Qt.WaitCursor)
+
+        print('Starting simulation (%d cores). . .' % ncore)
+        self.runningsim = True
+
+        self.statusBar().showMessage("Running simulation. . .")
+
+        # check that a valid number of trials was given
+        if 'N_trials' not in self.baseparamwin.params or \
+                self.baseparamwin.params['N_trials'] == 0:
+            print("Warning: invalid configured number of trials."
+                  " Setting to 1.")
+            self.baseparamwin.params['N_trials'] = 1
+
+        if self.baseparamwin.params['record_vsoma'] and ncore > 1:
+            txt = 'A bug currently prevents recording somatic voltages' + \
+                  ' for simulations run on more than one core. This' + \
+                  ' simulation will proceed using only a single core.'
+            msg = QMessageBox()
+            msg.setIcon(QMessageBox.Information)
+            msg.setText(txt)
+            msg.setWindowTitle('Rerun simulation')
+            msg.setStandardButtons(QMessageBox.Ok)
+            msg.exec_()
+            ncore = 1
+
+        self.runthread = SimThread(ncore, self.baseparamwin.params,
+                                   self.sim_result_callback, mainwin=self)
+
+        # all the events we need are connected, so we can start the thread
+        self.runthread.start()
+        # at this point we want to allow the user to stop/terminate the
+        # thread, so we enable that button
+        self.btnsim.setText("Stop Simulation")  # setEnabled(False)
+        self.qbtn.setEnabled(False)
+
+        bringwintotop(self.waitsimwin)
+
+    def sim_result_callback(self, result):
+        sim_data = result.data
+        sim_data['spec'] = []
+        params = result.params
+
+        sim_data['dpls'] = deepcopy(sim_data['raw_dpls'])
+        ntrial = len(sim_data['raw_dpls'])
+        for trial_idx in range(ntrial):
+            window_len = params['dipole_smooth_win']  # specified in ms
+            fctr = params['dipole_scalefctr']
+            if window_len > 0:  # param files set this to zero for no smoothing
+                sim_data['dpls'][trial_idx].smooth(window_len=window_len)
+            if fctr > 0:
+                sim_data['dpls'][trial_idx].scale(fctr)
+
+        # save average dipole from individual trials in a single file
+        if ntrial > 1:
+            sim_data['avg_dpl'] = average_dipoles(sim_data['dpls'])
+        elif ntrial == 1:
+            sim_data['avg_dpl'] = sim_data['dpls'][0]
+        else:
+            raise ValueError("No dipole(s) returned from simulation")
+
+        # make sure the directory for saving data has been created
+        data_dir = os.path.join(get_output_dir(), 'data')
+        sim_dir = os.path.join(data_dir, params['sim_prefix'])
+        try:
+            os.mkdir(sim_dir)
+        except FileExistsError:
+            pass
+
+        # TODO: Can below be removed if spk.txt is new hnn-core format with 3
+        # columns (including spike type)?
+ # Follow https://github.com/jonescompneurolab/hnn-core/issues/219 + write_gids_param(get_fname(sim_dir, 'param'), sim_data['gid_ranges']) + + # save spikes by trial + glob = os.path.join(sim_dir, 'spk_%d.txt') + sim_data['spikes'].write(glob) + + # save dipole for each trial and perform spectral analysis + for trial_idx, dpl in enumerate(sim_data['dpls']): + dipole_fn = get_fname(sim_dir, 'normdpl', trial_idx) + dpl.write(dipole_fn) + + if params['save_dpl']: + raw_dipole_fn = get_fname(sim_dir, 'rawdpl', trial_idx) + sim_data['raw_dpls'][trial_idx].write(raw_dipole_fn) + + if params['save_spec_data'] or \ + usingOngoingInputs(params): + spec_results = spec_dpl_kernel(dpl, params['f_max_spec'], + params['dt'], params['tstop']) + sim_data['spec'].append(spec_results) + + if params['save_spec_data']: + spec_fn = get_fname(sim_dir, 'rawspec', trial_idx) + save_spec_data(spec_fn, spec_results) + + paramfn = os.path.join(get_output_dir(), 'param', + params['sim_prefix'] + '.param') + + self.sim_data.update_sim_data(paramfn, params, sim_data['dpls'], + sim_data['avg_dpl'], sim_data['spikes'], + sim_data['gid_ranges'], + sim_data['spec'], sim_data['vsoma']) + + def opt_callback(self): + # re-enable all the range sliders (last step) + self.baseparamwin.optparamwin.toggle_enable_user_fields( + self.baseparamwin.optparamwin.get_num_chunks() - 1, + enable=True) + + self.baseparamwin.optparamwin.clear_initial_opt_ranges() + self.baseparamwin.optparamwin.optimization_running = False + # self.done() + + def done(self, except_msg=''): + """called when the simulation completes running""" + self.runningsim = False + self.waitsimwin.hide() + self.statusBar().showMessage("") + self.btnsim.setText("Run Simulation") + self.qbtn.setEnabled(True) + + if len(except_msg) > 0: + failed = True + else: + failed = False + + if failed: + msg = "%s: Failed " % except_msg + + if self.is_optimization: + msg += "running optimization " + self.baseparamwin.optparamwin.btnrunop.setText( + 'Run Optimization') + self.baseparamwin.optparamwin.btnrunop.clicked.disconnect() + self.baseparamwin.optparamwin.btnrunop.clicked.connect( + self.baseparamwin.optparamwin.runOptimization) + else: + msg += "running sim " + + QMessageBox.critical(self, "Failed!", msg + "using " + + self.baseparamwin.paramfn + + '. Check simulation log or console for error ' + 'messages') + else: + # save params to file after successful completion + self.baseparamwin.saveparams(checkok=False) + + msg = "Finished " + + if self.is_optimization: + msg += "running optimization " + self.baseparamwin.optparamwin.btnrunop.setText( + 'Prepare for Another Optimization') + self.baseparamwin.optparamwin.btnrunop.clicked.disconnect() + self.baseparamwin.optparamwin.btnrunop.clicked.connect( + self.baseparamwin.optparamwin.prepareOptimization) + else: + msg += "running sim " + + if self.baseparamwin.params['save_figs']: + self.sim_data.save_dipole_with_hist(self.baseparamwin.paramfn, + self.baseparamwin.params) + self.sim_data.save_spec_with_hist(self.baseparamwin.paramfn, + self.baseparamwin.params) + + if self.baseparamwin.params['record_vsoma']: + self.sim_data.save_vsoma(self.baseparamwin.paramfn, + self.baseparamwin.params) + + data_dir = os.path.join(get_output_dir(), 'data') + sim_dir = os.path.join(data_dir, + self.baseparamwin.params['sim_prefix']) + QMessageBox.information(self, "Done!", msg + "using " + + self.baseparamwin.paramfn + + '. 
Saved data/figures in: ' + sim_dir) + + # recreate canvas (plots too) to avoid incorrect axes + self.initSimCanvas() + # self.sim_canvas.plot() + self.setcursors(Qt.ArrowCursor) + + self.setWindowTitle(self.baseparamwin.paramfn) + + self.populateSimCB() # populate the combobox + cb_index = self.cbsim.findText(self.baseparamwin.paramfn) + if cb_index >= 0: + self.cbsim.setCurrentIndex(cb_index) + + self.is_optimization = False + + +if __name__ == '__main__': + app = QApplication(sys.argv) + HNNGUI() + sys.exit(app.exec_()) diff --git a/hnn/qt_psd.py b/hnn/qt_psd.py new file mode 100644 index 000000000..71359f597 --- /dev/null +++ b/hnn/qt_psd.py @@ -0,0 +1,346 @@ +import os +from math import sqrt +from copy import deepcopy + +from PyQt5.QtWidgets import QSizePolicy, QAction, QFileDialog +from PyQt5.QtGui import QIcon + +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure +import matplotlib.gridspec as gridspec + +from hnn_core.dipole import average_dipoles, Dipole + +from .DataViewGUI import DataViewGUI +from .specfn import spec_dpl_kernel, extract_spec + +fontsize = plt.rcParams['font.size'] = 10 +random_label = np.random.rand(100) + + +def extract_psd(dpl, f_max_spec): + """Extract PSDs for layers using Morlet method + + Parameters + ---------- + dpls: Dipole object + Dipole for a single trial + f_max_spec: float + Maximum frequency of analysis + + Returns + ---------- + F: array + Frequencies associated with Morlet spectral analysis + psds: list of MortletSpec objects + List containing results of spectral analysis for each layer + + """ + + psds = [] + dt = dpl.times[1] - dpl.times[0] + tstop = dpl.times[-1] + + spec_results = spec_dpl_kernel(dpl, f_max_spec, dt, tstop) + + for col in ['TFR', 'TFR_L2', 'TFR_L5']: + psds.append(np.mean(spec_results[col], axis=1)) + + return spec_results['freq'], np.array(psds) + + +class PSDCanvas(FigureCanvasQTAgg): + def __init__(self, params, sim_data, index, parent=None, width=12, + height=10, dpi=120, title='PSD Viewer'): + FigureCanvasQTAgg.__init__(self, Figure(figsize=(width, height), + dpi=dpi)) + self.title = title + self.setParent(parent) + self.gui = parent + self.index = index + FigureCanvasQTAgg.setSizePolicy(self, QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + self.params = params + self.invertedhistax = False + self.G = gridspec.GridSpec(10, 1) + self.dpls = self.gui.dpls + self.specs = self.gui.specs + self.avg_spec = self.gui.avg_spec + self.avg_dpl = self.gui.avg_dpl + self.lextdatobj = [] + + self.plot() + + def drawpsd(self, dspec, fig, G, ltextra=''): + global random_label + + lax = [] + avgs = [] + stds = [] + + lkF = ['f_L2', 'f_L5', 'f_L2'] + lkS = ['TFR_L2', 'TFR_L5', 'TFR'] + + plt.ion() + + gdx = 311 + + ltitle = ['Layer 2/3', 'Layer 5', 'Aggregate'] + + yl = [1e9, -1e9] + + for _, kS in enumerate(lkS): + avg = np.mean(dspec[kS], axis=1) + std = np.std(dspec[kS], axis=1) / sqrt(dspec[kS].shape[1]) + yl[0] = min(yl[0], np.amin(avg - std)) + yl[1] = max(yl[1], np.amax(avg + std)) + avgs.append(avg) + stds.append(std) + + yl = tuple(yl) + xl = (dspec['f_L2'][0], dspec['f_L2'][-1]) + + for i, kS in enumerate(lkS): + ax = fig.add_subplot(gdx, label=random_label) + random_label += 1 + lax.append(ax) + + if i == 2: + ax.set_xlabel('Frequency (Hz)') + + ax.plot(dspec[lkF[i]], np.mean(dspec[lkS[i]], axis=1), color='w', + linewidth=self.gui.linewidth + 
2) + ax.plot(dspec[lkF[i]], avgs[i] - stds[i], color='gray', + linewidth=self.gui.linewidth) + ax.plot(dspec[lkF[i]], avgs[i] + stds[i], color='gray', + linewidth=self.gui.linewidth) + + ax.set_ylim(yl) + ax.set_xlim(xl) + + ax.set_facecolor('k') + ax.grid(True) + ax.set_title(ltitle[i]) + ax.set_ylabel(r'$nAm^2$') + + gdx += 1 + return lax + + def clearaxes(self): + for ax in self.lax: + ax.set_yticks([]) + ax.cla() + + def clearlextdatobj(self): + # clear list of external data objects + for o in self.lextdatobj: + if isinstance(o, list): + # this is the plot. clear the line + o[0].set_visible(False) + else: + # this is the legend entry + o.set_visible(False) + del self.lextdatobj + self.lextdatobj = [] # reset list of external data objects + + def plotextdat(self, lF, lextpsd, lextfiles): + """plot 'external' data (e.g. from experiment/other simulation)""" + + white_patch = mpatches.Patch(color='white', label='Simulation') + self.lpatch = [white_patch] + + ax = self.lax[2] # plot on agg + + yl = ax.get_ylim() + + cmap = plt.get_cmap('nipy_spectral') + csm = plt.cm.ScalarMappable(cmap=cmap) + csm.set_clim((0, 100)) + + for f, lpsd, fname in zip(lF, lextpsd, lextfiles): + clr = csm.to_rgba(int(np.random.RandomState().uniform(5, 101, 1))) + avg = np.mean(lpsd, axis=0) + std = np.std(lpsd, axis=0) / sqrt(lpsd.shape[1]) + self.lextdatobj.append(ax.plot(f, avg, color=clr, + linewidth=self.gui.linewidth + 2)) + self.lextdatobj.append(ax.plot(f, avg - std, '--', color=clr, + linewidth=self.gui.linewidth)) + self.lextdatobj.append(ax.plot(f, avg + std, '--', color=clr, + linewidth=self.gui.linewidth)) + yl = ((min(yl[0], min(avg))), (max(yl[1], max(avg)))) + label_str = fname.split(os.path.sep)[-1].split('.txt')[0] + new_patch = mpatches.Patch(color=clr, label=label_str) + self.lpatch.append(new_patch) + + ax.set_ylim(yl) + self.lextdatobj.append(ax.legend(handles=self.lpatch)) + + def plot(self): + if len(self.specs) == 0: + # data hasn't been loaded yet + return + + if self.index == 0: + ltextra = 'All Trials' + self.lax = self.drawpsd(self.avg_spec, self.figure, self.G, + ltextra=ltextra) + else: + ltextra = 'Trial ' + str(self.index) + self.lax = self.drawpsd(self.specs[self.index - 1], self.figure, + self.G, ltextra=ltextra) + + self.figure.subplots_adjust(bottom=0.06, left=0.06, right=0.98, + top=0.97, wspace=0.1, hspace=0.09) + + self.draw() + + +class PSDViewGUI(DataViewGUI): + """Class for displaying spectrogram viewer + + Required parameters: N_trials, f_max_spec, sim_prefix + """ + def __init__(self, CanvasType, params, sim_data, title): + self.specs = [] # used by drawspec + self.psds = [] # used by plotextdat + self.lextfiles = [] # external data files + self.lF = [] # frequencies associated with external data psd + self.dpls = None + self.avg_dpl = [] + self.avg_spec = {} + self.params = params + + # used by loadSimData + self.sim_data = sim_data + super(PSDViewGUI, self).__init__(CanvasType, params, sim_data, title) + self.addLoadDataActions() + self.loadSimData() + + def addLoadDataActions(self): + loadDataFile = QAction(QIcon.fromTheme('open'), 'Load data file.', + self) + loadDataFile.setShortcut('Ctrl+D') + loadDataFile.setStatusTip('Load experimental (.txt) data.') + loadDataFile.triggered.connect(self.loadDisplayData) + + clearDataFileAct = QAction(QIcon.fromTheme('close'), + 'Clear data.', self) + clearDataFileAct.setShortcut('Ctrl+C') + clearDataFileAct.setStatusTip('Clear data.') + clearDataFileAct.triggered.connect(self.clearDataFile) + + 
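+        # both actions are added to the File menu created by the
+        # DataViewGUI base class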
self.fileMenu.addAction(loadDataFile) + self.fileMenu.addAction(clearDataFileAct) + + def loadSimData(self): + """Load and plot from SimData""" + + # store copy of data in this object, that can be reused by + # canvas (self.m) on re-instantiation + if self.sim_data is not None: + self.avg_dpl = self.sim_data['avg_dpl'] + self.dpls = self.sim_data['dpls'] + self.specs = self.sim_data['spec'] + if self.specs is None or len(self.specs) == 0: + self.specs = extract_spec(self.dpls, self.params['f_max_spec']) + + # calculate TFR from spec trial data + self.avg_spec = deepcopy(self.specs[0]) + ntrials = self.params['N_trials'] + TFR_list = [self.specs[i]['TFR'] for i in range(ntrials)] + TFR_L2_list = [self.specs[i]['TFR_L2'] for i in range(ntrials)] + TFR_L5_list = [self.specs[i]['TFR_L5'] for i in range(ntrials)] + self.avg_spec['TFR'] = np.mean(np.array(TFR_list), axis=0) + self.avg_spec['TFR_L2'] = np.mean(np.array(TFR_L2_list), axis=0) + self.avg_spec['TFR_L5'] = np.mean(np.array(TFR_L5_list), axis=0) + + # populate the data inside canvas object before calling + # self.m.plot() + self.m.avg_dpl = self.avg_dpl + self.m.dpls = self.dpls + self.m.specs = self.specs + self.m.avg_spec = self.avg_spec + + if len(self.specs) > 0: + self.printStat('Plotting simulation PSDs.') + self.m.lF = self.lF + self.m.dpls = self.dpls + self.m.avg_dpl = self.avg_dpl + self.m.plot() + self.m.draw() # make sure new lines show up in plot + self.printStat('') + + def loadDisplayData(self): + """Load dipole(s) from .txt file and plot PSD""" + fname = QFileDialog.getOpenFileName(self, 'Open .txt file', 'data') + fname = os.path.abspath(fname[0]) + + if not os.path.isfile(fname): + return + + self.m.index = 0 + file_data = np.loadtxt(fname, dtype=float) + if file_data.shape[1] > 2: + # Multiple trials contained in this file. Only 'agg' dipole is + # present for each trial + dpls = [] + ntrials = file_data.shape[1] + for trial in range(1, ntrials): + dpl_data = np.c_[file_data[:, trial], + np.zeros(len(file_data[:, trial])), + np.zeros(len(file_data[:, trial]))] + dpl = Dipole(file_data[:, 0], dpl_data) + dpls.append(dpl) + self.dpls = dpls + self.avg_dpl = average_dipoles(dpls) + else: + # Normal dipole file saved by HNN. There is a single trial with + # column 0: times, column 1: 'agg' dipole, column 2: 'L2' dipole + # and column 3: 'L5' dipole + + ntrials = 1 + dpl_data = np.c_[file_data[:, 1], + file_data[:, 1], + file_data[:, 1]] + dpl = Dipole(file_data[:, 0], dpl_data) + + self.avg_dpl = dpl + self.dpls = [self.avg_dpl] + + print('Loaded data from %s: %d trials.' % (fname, ntrials)) + print('Extracting Spectrograms...') + f_max_spec = 120.0 # use 120 Hz as maximum for PSD plots + + # a progress bar would be helpful right here! 
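+        # extract_psd() (defined above) runs the Morlet spectral analysis on
+        # the loaded dipole and averages it over time, giving one PSD for the
+        # aggregate signal and one per layer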
+ f, psd = extract_psd(self.avg_dpl, f_max_spec) + self.psds.append(psd) + self.lF.append(f) + + # updateCB depends on ntrial being set + # self.ntrial = len(self.specs) + # self.updateCB() + self.printStat('Extracted ' + str(len(self.psds)) + ' PSDs from ' + + fname) + self.lextfiles.append(fname) + + if len(self.psds) > 0: + self.printStat('Plotting ext data PSDs.') + self.m.lF = self.lF + self.m.psds = self.psds + self.m.dpls = self.dpls + self.m.avg_dpl = self.avg_dpl + self.m.plotextdat(self.lF, self.psds, self.lextfiles) + self.m.draw() # make sure new lines show up in plot + self.printStat('') + + def clearDataFile(self): + self.m.clearlextdatobj() + self.lextpsd = [] + self.lextfiles = [] + self.lF = [] + self.m.draw() diff --git a/hnn/qt_sim.py b/hnn/qt_sim.py new file mode 100644 index 000000000..c7bf1fb60 --- /dev/null +++ b/hnn/qt_sim.py @@ -0,0 +1,447 @@ +import os +import numpy as np +from math import ceil + +from PyQt5 import QtWidgets +import matplotlib.pyplot as plt +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure +import matplotlib.patches as mpatches +import matplotlib.gridspec as gridspec + +from .paramrw import countEvokedInputs +from .qt_lib import getscreengeom +from .paramrw import get_output_dir, get_inputs +from .simdata import check_feeds_to_plot, plot_hists_on_gridspec +from .specfn import plot_spec +from .spikefn import ExtInputs + +fontsize = plt.rcParams['font.size'] = 10 + + +class SIMCanvas(FigureCanvasQTAgg): + # matplotlib/pyqt-compatible canvas for drawing simulation & external data + # based on https://pythonspot.com/en/pyqt5-matplotlib/ + + def __init__(self, paramfn, params, parent=None, width=5, height=4, + dpi=40, is_optimization=False, title='Simulation Viewer'): + FigureCanvasQTAgg.__init__(self, Figure(figsize=(width, height), + dpi=dpi)) + + self.title = title + self.sim_data = parent.sim_data + self.lextdatobj = [] # external data object + self.clridx = 5 # index for next color for drawing external data + self.errtot = None + self.lerr = [] + + # legend for dipole signals + self.lpatch = [mpatches.Patch(color='black', label='Sim.')] + self.setParent(parent) + self.linewidth = parent.linewidth + FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, + QtWidgets.QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + self.params = params + self.paramfn = paramfn + self.axdipole = self.axspec = None + self.G = gridspec.GridSpec(10, 1) + self._data_dir = os.path.join(get_output_dir(), 'data') + + self.is_optimization = is_optimization + if not is_optimization: + self.sim_data.clear_opt_data() + + self.saved_exception = None + try: + self.plot() + except Exception as err: + self.saved_exception = err + + def plotinputhist(self, extinputs=None, feeds_to_plot=None): + """ plot input histograms""" + + xmin = 0. 
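+        # the histogram axes below share a time axis spanning 0..tstop at the
+        # simulation time step; if no simulated spikes are available, the
+        # configured input distributions are drawn instead (plot_distribs)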
+ xmax = self.params['tstop'] + xlim = (xmin, xmax) + axes = [] + + sim_dt = self.params['dt'] + num_step = ceil(xmax / sim_dt) + 1 + times = np.linspace(xmin, xmax, num_step) + + plot_distribs = True + if extinputs is not None and feeds_to_plot is not None: + if feeds_to_plot is None: + feeds_to_plot = check_feeds_to_plot(extinputs.inputs, + self.params) + + if feeds_to_plot['ongoing'] or feeds_to_plot['evoked'] or \ + feeds_to_plot['pois']: + # hist gridspec + axes = plot_hists_on_gridspec(self.figure, self.G, + feeds_to_plot, extinputs, times, + xlim, self.linewidth) + + plot_distribs = False + + if plot_distribs: + dinput = self.getInputDistrib() + feeds_to_plot = check_feeds_to_plot(dinput, self.params) + if not (feeds_to_plot['ongoing'] or feeds_to_plot['evoked'] or + feeds_to_plot['pois']): + # no plots to create + return axes + + n_hists = 0 + + if feeds_to_plot['evdist']: + dist_tot = np.zeros(len(dinput['evdist'][0][0])) + for dist in dinput['evdist']: + dist_tot += dist[1] + + axdist = self.figure.add_subplot(self.G[n_hists, :]) + n_hists += 1 + + axdist.plot(dinput['evdist'][0][0], dist_tot, color='g', + lw=self.linewidth, + label='evdist distribution') + axdist.set_xlim(dinput['evdist'][0][0][0], + dinput['evdist'][0][0][-1]) + axdist.invert_yaxis() # invert the distal input axes + axes.append(axdist) + + if feeds_to_plot['evprox']: + prox_tot = np.zeros(len(dinput['evprox'][0][0])) + for prox in dinput['evprox']: + prox_tot += prox[1] + + axprox = self.figure.add_subplot(self.G[n_hists, :]) + n_hists += 1 + + axprox.plot(dinput['evprox'][0][0], prox_tot, color='r', + lw=self.linewidth, + label='evprox distribution') + axprox.set_xlim(dinput['evprox'][0][0][0], + dinput['evprox'][0][0][-1]) + axes.append(axprox) + + return axes + + def clearaxes(self): + # clear the figures axes + for ax in self.figure.get_axes(): + if ax: + ax.cla() + + def getInputDistrib(self): + import scipy.stats as stats + + dinput = {'evprox': [], 'evdist': [], 'prox': [], 'dist': [], + 'pois': []} + try: + sim_tstop = self.params['tstop'] + sim_dt = self.params['dt'] + except KeyError: + return dinput + + num_step = ceil(sim_tstop / sim_dt) + 1 + times = np.linspace(0, sim_tstop, num_step) + ltprox, ltdist = self.getEVInputTimes() + for prox in ltprox: + pdf = stats.norm.pdf(times, prox[0], prox[1]) + dinput['evprox'].append((times, pdf)) + for dist in ltdist: + pdf = stats.norm.pdf(times, dist[0], dist[1]) + dinput['evdist'].append((times, pdf)) + return dinput + + def getEVInputTimes(self): + # get the evoked input times + + if self.params is None: + raise ValueError("No valid params found") + + nprox, ndist = countEvokedInputs(self.params) + ltprox, ltdist = [], [] + for i in range(nprox): + input_mu = self.params['t_evprox_' + str(i + 1)] + input_sigma = self.params['sigma_t_evprox_' + str(i + 1)] + ltprox.append((input_mu, input_sigma)) + for i in range(ndist): + input_mu = self.params['t_evdist_' + str(i + 1)] + input_sigma = self.params['sigma_t_evdist_' + str(i + 1)] + ltdist.append((input_mu, input_sigma)) + return ltprox, ltdist + + def drawEVInputTimes(self, ax, yl, h=0.1, hw=15, hl=15): + # draw the evoked input times using arrows + ltprox, ltdist = self.getEVInputTimes() + yrange = abs(yl[1] - yl[0]) + + for tt in ltprox: + ax.arrow(tt[0], yl[0], 0, h * yrange, fc='r', ec='r', + head_width=hw, head_length=hl) + for tt in ltdist: + ax.arrow(tt[0], yl[1], 0, -h * yrange, fc='g', ec='g', + head_width=hw, head_length=hl) + + def getnextcolor(self): + # get next color for external data (colors 
selected in order) + self.clridx += 5 + if self.clridx > 100: + self.clridx = 5 + return self.clridx + + def _has_simdata(self): + """check if any simulation data available""" + if self.paramfn in self.sim_data._sim_data: + avg_dpl = self.sim_data._sim_data[self.paramfn]['data']['avg_dpl'] + if avg_dpl is not None: + return True + + return False + + def plotextdat(self): + global fontsize + + if self.sim_data._exp_data is None or \ + len(self.sim_data._exp_data) == 0: + return + + # plot 'external' data (e.g. from experiment/other simulation) + if self._has_simdata(): # has the simulation been run yet? + tstop = self.params['tstop'] + # recalculate/save the error + self.lerr, self.errtot = self.sim_data.calcerr(self.paramfn, + tstop) + + if self.axdipole is None: + self.axdipole = self.figure.add_subplot(self.G[0:-1, 0]) + xl = (0.0, 1.0) + yl = (-0.001, 0.001) + else: + xl = self.axdipole.get_xlim() + yl = self.axdipole.get_ylim() + + cmap = plt.get_cmap('nipy_spectral') + csm = plt.cm.ScalarMappable(cmap=cmap) + csm.set_clim((0, 100)) + + self.clearlextdatobj() # clear annotation objects + + ddx = 0 + for fn, dat in self.sim_data._exp_data.items(): + shp = dat.shape + clr = csm.to_rgba(self.getnextcolor()) + c = min(shp[1], 1) + self.lextdatobj.append(self.axdipole.plot(dat[:, 0], dat[:, c], + color=clr, linewidth=self.linewidth + 1)) + xl = ((min(xl[0], min(dat[:, 0]))), (max(xl[1], max(dat[:, 0])))) + yl = ((min(yl[0], min(dat[:, c]))), (max(yl[1], max(dat[:, c])))) + fx = int(shp[0] * float(c) / shp[1]) + if self.lerr: + tx, ty = dat[fx, 0], dat[fx, c] + txt = 'RMSE: %.2f' % round(self.lerr[ddx], 2) + if not self.is_optimization: + self.axdipole.annotate(txt, xy=(dat[0, 0], dat[0, c]), + xytext=(tx, ty), color=clr, + fontweight='bold') + label = fn.split(os.path.sep)[-1].split('.txt')[0] + self.lpatch.append(mpatches.Patch(color=clr, label=label)) + ddx += 1 + + self.axdipole.set_xlim(xl) + self.axdipole.set_ylim(yl) + + if len(self.lpatch) > 0: + self.axdipole.legend(handles=self.lpatch, loc=2) + + if self.errtot is not None: + textcoords = 'axes fraction' + clr = 'black' + txt = 'Avg. RMSE: %.2f' % round(self.errtot, 2) + if self.is_optimization: + if 'initial_error' in self.sim_data._opt_data: + initial_error = self.sim_data._opt_data['initial_error'] + txt = 'Initial RMSE: %.2f' % round(initial_error, 2) + annot_initial = \ + self.axdipole.annotate(txt, xy=(0, 0), + xytext=(0.86, 0.005), + textcoords=textcoords, + color=clr, + fontweight='bold') + self.lextdatobj.append(annot_initial) + txt = 'Opt RMSE: %.2f' % round(self.errtot, 2) + clr = 'gray' + + annot_avg = self.axdipole.annotate(txt, xy=(0, 0), + xytext=(0.005, 0.005), + textcoords=textcoords, + color=clr, + fontweight='bold') + self.lextdatobj.append(annot_avg) + + if not self._has_simdata(): # need axis labels + self.axdipole.set_xlabel('Time (ms)', fontsize=fontsize) + self.axdipole.set_ylabel('Dipole (nAm)', fontsize=fontsize) + myxl = self.axdipole.get_xlim() + if myxl[0] < 0.0: + self.axdipole.set_xlim((0.0, myxl[1] + myxl[0])) + + def clearlextdatobj(self): + # clear list of external data objects + for o in self.lextdatobj: + if isinstance(o, list): + # this is the plot. 
clear the line + o[0].set_visible(False) + del self.lextdatobj + self.lextdatobj = [] # reset list of external data objects + self.lpatch = [] # reset legend + self.clridx = 5 # reset index for next color for drawing ext data + + if self.is_optimization: + self.lpatch.append(mpatches.Patch(color='grey', + label='Optimization')) + self.lpatch.append(mpatches.Patch(color='black', label='Initial')) + elif self._has_simdata(): + self.lpatch.append(mpatches.Patch(color='black', + label='Simulation')) + if hasattr(self, 'annot_avg'): + self.annot_avg.set_visible(False) + del self.annot_avg + + def plotsimdat(self): + """plot the simulation data""" + + global fontsize + + DrawSpec = False + xlim = (0.0, 1.0) + ylim = (-0.001, 0.001) + + if self.params is None: + data_to_plot = False + gRow = 0 + else: + # for later + ntrial = self.params['N_trials'] + tstop = self.params['tstop'] + dipole_scalefctr = self.params['dipole_scalefctr'] + N_pyr_x = self.params['N_pyr_x'] + N_pyr_y = self.params['N_pyr_y'] + + # update xlim to tstop + xlim = (0.0, tstop) + + # for trying to plot a simulation read from disk (e.g. default) + if self.paramfn not in self.sim_data._sim_data: + # load simulation data from disk + data_to_plot = self.sim_data.update_sim_data_from_disk( + self.paramfn, self.params) + else: + data_to_plot = True + + if data_to_plot: + sim_data = self.sim_data._sim_data[self.paramfn]['data'] + trials = [trial_idx for trial_idx in range(ntrial)] + extinputs = ExtInputs(sim_data['spikes'], + sim_data['gid_ranges'], + trials, self.params) + + feeds_to_plot = check_feeds_to_plot(extinputs.inputs, + self.params) + else: + # best we can do is plot the distributions of the inputs + extinputs = feeds_to_plot = None + + hist_axes = self.plotinputhist(extinputs, feeds_to_plot) + gRow = len(hist_axes) + + if data_to_plot: + # check that dipole data is present + single_sim = self.sim_data._sim_data[self.paramfn]['data'] + if single_sim['avg_dpl'] is None: + data_to_plot = False + + # whether to draw the specgram - should draw if user saved + # it or have ongoing, poisson, or tonic inputs + if single_sim['spec'] is not None \ + and len(single_sim['spec']) > 0 \ + and (self.params['save_spec_data'] or + feeds_to_plot['ongoing'] or + feeds_to_plot['pois'] or feeds_to_plot['tonic']): + DrawSpec = True + + first_spec_trial = single_sim['spec'][0] + + # adjust dipole to match spectogram limits (e.g. 
missing + # first 50 ms b/c edge effects) + xlim = (first_spec_trial['time'][0], + first_spec_trial['time'][-1]) + + if DrawSpec: # dipole axis takes fewer rows if also drawing specgram + self.axdipole = self.figure.add_subplot(self.G[gRow:5, 0]) + bottom = 0.08 + + # set the axes of input histograms to match dipole and spec plots + for ax in hist_axes: + ax.set_xlim(xlim) + + else: + self.axdipole = self.figure.add_subplot(self.G[gRow:-1, 0]) + # there is no spec plot below, so label dipole with time on x-axis + self.axdipole.set_xlabel('Time (ms)', fontsize=fontsize) + bottom = 0.0 + + self.axdipole.set_ylim(ylim) + self.axdipole.set_xlim(xlim) + + left = 0.08 + w, _ = getscreengeom() + if w < 2800: + left = 0.1 + # reduce padding + self.figure.subplots_adjust(left=left, right=0.99, bottom=bottom, + top=0.99, hspace=0.1, wspace=0.1) + + if not data_to_plot: + # no dipole or spec data to plot + return + + self.sim_data.plot_dipole(self.paramfn, self.axdipole, self.linewidth, + dipole_scalefctr, N_pyr_x, N_pyr_y, + self.is_optimization) + + if DrawSpec: + self.axspec = self.figure.add_subplot(self.G[6:10, 0]) + cax = plot_spec(self.axspec, sim_data['spec'], ntrial, + self.params['spec_cmap'], xlim, + fontsize) + + # plot colorbar horizontally to save space + cbaxes = self.figure.add_axes([0.6, 0.49, 0.3, 0.005]) + plt.colorbar(cax, cax=cbaxes, orientation='horizontal') + + def plotarrows(self): + # run after scales have been updated + xl = self.axdipole.get_xlim() + yl = self.axdipole.get_ylim() + + using_feeds = get_inputs(self.params) + if using_feeds['evoked']: + self.drawEVInputTimes(self.axdipole, yl, 0.1, + (xl[1] - xl[0]) * .02, + (yl[1] - yl[0]) * .02) + + def plot(self): + self.clearaxes() + plt.close(self.figure) + self.figure.clf() + self.axdipole = None + + self.plotsimdat() # creates self.axdipole + self.plotextdat() + self.plotarrows() + + self.draw() diff --git a/hnn/qt_spec.py b/hnn/qt_spec.py new file mode 100644 index 000000000..d127fc2b0 --- /dev/null +++ b/hnn/qt_spec.py @@ -0,0 +1,252 @@ +"""Create the Spectrogram viewing window""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import numpy as np +import os + +from PyQt5.QtWidgets import QSizePolicy, QAction, QFileDialog +from PyQt5.QtGui import QIcon + +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure +from hnn_core.dipole import average_dipoles, Dipole + +from .DataViewGUI import DataViewGUI +from .specfn import plot_spec, extract_spec + +fontsize = plt.rcParams['font.size'] = 10 +random_label = np.random.rand(100) + + +class SpecCanvas(FigureCanvasQTAgg): + """Class for the Spectrogram viewer + + This is designed to be called from SpecViewGUI class to add functionality + for loading and clearing data + """ + def __init__(self, params, sim_data, index, parent=None, width=12, + height=10, dpi=120, title='Spectrogram Viewer'): + FigureCanvasQTAgg.__init__(self, Figure(figsize=(width, height), + dpi=dpi)) + self.title = title + self.setParent(parent) + self.gui = parent + self.index = index + FigureCanvasQTAgg.setSizePolicy(self, QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + self.params = params + self.invertedhistax = False + self.G = gridspec.GridSpec(10, 1) + self.dpls = self.gui.dpls + self.specs = self.gui.specs + self.avg_dpl = self.gui.avg_dpl + self.lax = [] + + if 'spec_cmap' in self.params: + self.spec_cmap = self.params['spec_cmap'] + 
else: + # default to jet, but allow user to change in param file + self.spec_cmap = 'jet' + + self.plot() + + def clearaxes(self): + for ax in self.lax: + ax.set_yticks([]) + ax.cla() + + def drawspec(self, dpls, avgdipole, spec_data, fig, G, + ltextra=''): + global random_label + + ntrial = len(spec_data) + if ntrial == 0: + return + + if self.index == 0: + ntrial = 1 + + plt.ion() + + gdx = 211 + + ax = fig.add_subplot(gdx, label=random_label) + random_label += 1 + lax = [ax] + + # use spectogram limits (missing first 50 ms b/c edge effects) + xlim = (spec_data[0]['time'][0], + spec_data[0]['time'][-1]) + + if self.index == 0: + for dpltrial in dpls: + ax.plot(dpltrial.times, dpltrial.data['agg'], + linewidth=self.gui.linewidth, color='gray') + ax.plot(avgdipole.times, avgdipole.data['agg'], + linewidth=self.gui.linewidth + 1, color='black') + else: + ax.plot(dpls[self.index - 1].times, + dpls[self.index - 1].data['agg'], + linewidth=self.gui.linewidth + 1, + color='gray') + + ax.set_xlim(xlim) + ax.set_ylabel('Dipole (nAm)') + + gdx = 212 + + ax = fig.add_subplot(gdx, label=random_label) + random_label += 1 + ntrial = len(dpls) + + plot_spec(ax, spec_data, ntrial, self.spec_cmap, xlim) + + lax.append(ax) + + return lax + + def plot(self): + ltextra = 'Trial ' + str(self.index) + if self.index == 0: + ltextra = 'All Trials' + self.lax = self.drawspec(self.dpls, self.avg_dpl, self.specs, + self.figure, self.G, ltextra=ltextra) + self.figure.subplots_adjust(bottom=0.06, left=0.06, right=0.98, + top=0.97, wspace=0.1, hspace=0.09) + self.draw() + + +class SpecViewGUI(DataViewGUI): + """Class for displaying spectrogram viewer + + Required parameters: N_trials, f_max_spec, sim_prefix, + spec_cmap + """ + def __init__(self, CanvasType, params, sim_data, title): + self.specs = [] + self.lextfiles = [] # external data files + self.dpls = None + self.avg_dpl = [] + self.params = params + + # used by loadSimData + self.sim_data = sim_data + super(SpecViewGUI, self).__init__(CanvasType, self.params, sim_data, + title) + self._addLoadDataActions() + self.loadSimData(self.params['sim_prefix'], self.params['f_max_spec']) + + def _addLoadDataActions(self): + loadDataFile = QAction(QIcon.fromTheme('open'), 'Load data.', self) + loadDataFile.setShortcut('Ctrl+D') + loadDataFile.setStatusTip('Load experimental (.txt) data.') + loadDataFile.triggered.connect(self.loadDisplayData) + + clearDataFileAct = QAction(QIcon.fromTheme('close'), 'Clear data.', + self) + clearDataFileAct.setShortcut('Ctrl+C') + clearDataFileAct.setStatusTip('Clear data.') + clearDataFileAct.triggered.connect(self.clearDataFile) + + self.fileMenu.addAction(loadDataFile) + self.fileMenu.addAction(clearDataFileAct) + + def loadSimData(self, sim_prefix, f_max_spec): + """Load and plot from SimData""" + + # store copy of data in this object, that can be reused by + # canvas (self.m) on re-instantiation + if self.sim_data is not None: + self.avg_dpl = self.sim_data['avg_dpl'] + self.dpls = self.sim_data['dpls'] + self.specs = self.sim_data['spec'] + if self.specs is None or len(self.specs) == 0: + self.specs = extract_spec(self.dpls, f_max_spec) + + # populate the data inside canvas object before calling + # self.m.plot() + self.m.avg_dpl = self.avg_dpl + self.m.dpls = self.dpls + self.m.specs = self.specs + + self.ntrial = len(self.specs) + + self.updateCB() + self.printStat('Extracted ' + str(len(self.m.specs)) + + ' spectrograms for ' + sim_prefix) + + if len(self.m.specs) > 0: + self.printStat('Plotting Spectrograms.') + self.m.plot() + 
self.m.draw() # make sure new lines show up in plot + self.printStat('') + + def loadDisplayData(self): + """Load dipole(s) from .txt file and plot spectrograms""" + fname = QFileDialog.getOpenFileName(self, 'Open .txt file', 'data') + fname = os.path.abspath(fname[0]) + + if not os.path.isfile(fname): + return + + self.m.index = 0 + file_data = np.loadtxt(fname, dtype=float) + if file_data.shape[1] > 2: + # Multiple trials contained in this file. Only 'agg' dipole is + # present for each trial + dpls = [] + ntrials = file_data.shape[1] + for trial in range(1, ntrials): + dpl_data = np.c_[file_data[:, trial], + np.zeros(len(file_data[:, trial])), + np.zeros(len(file_data[:, trial]))] + dpl = Dipole(file_data[:, 0], dpl_data) + dpls.append(dpl) + self.dpls = dpls + self.avg_dpl = average_dipoles(dpls) + else: + # Normal dipole file saved by HNN. There is a single trial with + # column 0: times, column 1: 'agg' dipole, column 2: 'L2' dipole + # and column 3: 'L5' dipole + + ntrials = 1 + dpl_data = np.c_[file_data[:, 1], + file_data[:, 1], + file_data[:, 1]] + dpl = Dipole(file_data[:, 0], dpl_data) + + self.avg_dpl = dpl + self.dpls = [self.avg_dpl] + + print('Loaded data from %s: %d trials.' % (fname, ntrials)) + print('Extracting Spectrograms...') + # a progress bar would be helpful right here! + self.specs = extract_spec(self.dpls, self.params['f_max_spec']) + + # updateCB depends on ntrial being set + self.ntrial = len(self.specs) + self.updateCB() + self.printStat('Extracted ' + str(len(self.specs)) + + ' spectrograms from ' + fname) + self.lextfiles.append(fname) + + if len(self.specs) > 0: + self.printStat('Plotting Spectrograms.') + self.m.specs = self.specs + self.m.dpls = self.dpls + self.m.avg_dpl = self.avg_dpl + self.m.plot() + self.m.draw() # make sure new lines show up in plot + self.printStat('') + + def clearDataFile(self): + """Clear data from file and revert to SimData""" + self.specs = [] + self.lextfiles = [] + self.m.index = 0 + self.loadSimData(self.params['sim_prefix'], self.params['f_max_spec']) diff --git a/hnn/qt_spike.py b/hnn/qt_spike.py new file mode 100644 index 000000000..e0f0de441 --- /dev/null +++ b/hnn/qt_spike.py @@ -0,0 +1,319 @@ +"""Create the spike raster plot viewing window""" + +# Authors: Sam Neymotin +# Blake Caldwell + +import numpy as np +from numpy import hamming +from math import ceil + +from PyQt5.QtWidgets import QAction, QSizePolicy + +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure +import matplotlib.gridspec as gridspec + +from pylab import convolve + +from .DataViewGUI import DataViewGUI +from .paramrw import usingEvokedInputs, usingOngoingInputs, usingPoissonInputs +from .spikefn import ExtInputs + +plt.rcParams['lines.linewidth'] = 1 +fontsize = plt.rcParams['font.size'] = 10 +rastmarksz = 5 # raster dot size +binsz = 5.0 +smoothsz = 0 # no smoothing +random_label = np.random.rand(100) + +# colors for the different cell types +dclr = {'L2_pyramidal': 'g', + 'L5_pyramidal': 'r', + 'L2_basket': 'w', + 'L5_basket': 'b'} + + +# convolve with a hamming window +def hammfilt(x, winsz): + win = hamming(winsz) + win /= sum(win) + return convolve(x, win, 'same') + + +# adjust input gids for display purposes +def adjustinputgid(extinputs, gid): + if gid == extinputs.gid_prox: + return 0 + elif gid == extinputs.gid_dist: + return 1 + elif extinputs.is_prox_gid(gid): + return 2 + elif extinputs.is_dist_gid(gid): + return 3 + 
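+    # not one of the known external-input gids; return it unchanged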
return gid + + +def gid_to_type(extinputs, gid): + for gidtype, gids in extinputs.gid_ranges.items(): + if gid in gids: + return gidtype + + +def getdspk(spikes, extinputs, tstop): + ddat = {} + ddat['spk'] = spikes + + dspk = {'Cell': ([], [], []), + 'Input': ([], [], [])} + dhist = {} + for ty in dclr.keys(): + dhist[ty] = [] + haveinputs = False + for (t, gid) in ddat['spk']: + ty = gid_to_type(extinputs, gid) + if ty in dclr: + dspk['Cell'][0].append(t) + dspk['Cell'][1].append(gid) + dspk['Cell'][2].append(dclr[ty]) + dhist[ty].append(t) + else: + dspk['Input'][0].append(t) + dspk['Input'][1].append(adjustinputgid(extinputs, gid)) + if extinputs.is_prox_gid(gid): + dspk['Input'][2].append('r') + elif extinputs.is_dist_gid(gid): + dspk['Input'][2].append('g') + else: + dspk['Input'][2].append('orange') + haveinputs = True + for ty in dhist.keys(): + dhist[ty] = np.histogram(dhist[ty], range=(0, tstop), + bins=int(tstop / binsz)) + if smoothsz > 0: + dhist[ty] = hammfilt(dhist[ty][0], smoothsz) + else: + dhist[ty] = dhist[ty][0] + return dspk, haveinputs, dhist + + +class SpikeCanvas(FigureCanvasQTAgg): + def __init__(self, params, sim_data, index, parent=None, width=12, + height=10, dpi=120, title='Spike Viewer'): + FigureCanvasQTAgg.__init__(self, Figure(figsize=(width, height), + dpi=dpi)) + self.title = title + self.setParent(parent) + self.gui = parent + self.index = index + FigureCanvasQTAgg.setSizePolicy(self, QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + self.params = params + self.invertedhistax = False + self.G = gridspec.GridSpec(16, 1) + + self.sim_data = sim_data + self.alldat = {} + + # whether to draw histograms (spike counts per time) + self.bDrawHist = True + + self.plot() + + def clearaxes(self): + for ax in self.lax: + ax.set_yticks([]) + ax.cla() + + def drawhist(self, dhist, ntrial, tstop): + global random_label + ax = self.figure.add_subplot(self.G[-4:-1, :], label=random_label) + random_label += 1 + fctr = 1.0 + if ntrial > 0: + fctr = 1.0 / ntrial + for ty in dhist.keys(): + ax.plot(np.arange(binsz / 2, tstop + binsz / 2, binsz), + dhist[ty] * fctr, dclr[ty], linestyle='--') + ax.set_xlim((0, tstop)) + ax.set_ylabel('Cell Spikes') + return ax + + def drawrast(self, dspk, extinputs, haveinputs, fig, G, sz=8): + global random_label + + lax = [] + lk = ['Cell'] + row = 0 + tstop = self.params['tstop'] + + if haveinputs: + lk.append('Input') + lk.reverse() + + dinput = extinputs.inputs + + for _, k in enumerate(lk): + if k == 'Input': # input spiking + bins = ceil(150. * tstop / 1000.) 
# bins needs to be an int + + EvokedInputs = usingEvokedInputs(self.params) + OngoingInputs = usingOngoingInputs(self.params) + PoissonInputs = usingPoissonInputs(self.params) + haveEvokedDist = (EvokedInputs and len(dinput['evdist']) > 0) + haveOngoingDist = (OngoingInputs and len(dinput['dist']) > 0) + haveEvokedProx = (EvokedInputs and len(dinput['evprox']) > 0) + haveOngoingProx = (OngoingInputs and len(dinput['prox']) > 0) + + if haveEvokedDist or haveOngoingDist: + ax = fig.add_subplot(G[row:row + 2, :], label=random_label) + random_label += 1 + row += 2 + lax.append(ax) + if haveEvokedDist: + extinputs.plot_hist(ax, 'evdist', 0, bins, (0, tstop), + color='g', hty='step') + if haveOngoingDist: + extinputs.plot_hist(ax, 'dist', 0, bins, (0, tstop), + color='g') + ax.invert_yaxis() + ax.set_ylabel('Distal Input') + + if haveEvokedProx or haveOngoingProx: + ax2 = fig.add_subplot(G[row:row + 2, :], + label=random_label) + random_label += 1 + row += 2 + lax.append(ax2) + if haveEvokedProx: + extinputs.plot_hist(ax2, 'evprox', 0, bins, (0, tstop), + color='r', hty='step') + if haveOngoingProx: + extinputs.plot_hist(ax2, 'prox', 0, bins, (0, tstop), + color='r') + ax2.set_ylabel('Proximal Input') + + if PoissonInputs and len(dinput['pois']): + axp = fig.add_subplot(G[row:row + 2, :], + label=random_label) + random_label += 1 + row += 2 + lax.append(axp) + extinputs.plot_hist(axp, 'pois', 0, bins, (0, tstop), + color='orange') + axp.set_ylabel('Poisson Input') + + else: # local circuit neuron spiking + ncell = len(extinputs.gid_ranges['L2_pyramidal']) + \ + len(extinputs.gid_ranges['L2_basket']) + \ + len(extinputs.gid_ranges['L5_pyramidal']) + \ + len(extinputs.gid_ranges['L5_basket']) + + endrow = -1 + if self.bDrawHist: + endrow = -4 + + ax = fig.add_subplot(G[row:endrow, :], label=random_label) + random_label += 1 + lax.append(ax) + + ax.scatter(dspk[k][0], dspk[k][1], c=dspk[k][2], s=sz**2) + ax.set_ylabel(k + ' ID') + white_patch = mpatches.Patch(color='white', + label='L2/3 Basket') + green_patch = mpatches.Patch(color='green', label='L2/3 Pyr') + red_patch = mpatches.Patch(color='red', label='L5 Pyr') + blue_patch = mpatches.Patch(color='blue', label='L5 Basket') + ax.legend(handles=[white_patch, green_patch, blue_patch, + red_patch], loc='best') + ax.set_ylim((-1, ncell + 1)) + ax.invert_yaxis() + + return lax + + def loadspk(self, idx): + if idx in self.alldat: + return + self.alldat[idx] = {} + + trials = [trial_idx for trial_idx in range(self.params['N_trials'])] + self.extinputs = ExtInputs(self.sim_data['spikes'], + self.sim_data['gid_ranges'], + trials, self.params) + + if idx == 0 and self.params['N_trials'] > 1: + # combine spikes into a single list for all trials + spike_times = np.array(sum(self.sim_data['spikes'].spike_times, + [])) + spike_gids = np.array(sum(self.sim_data['spikes'].spike_gids, [])) + else: + spike_times = self.sim_data['spikes'].spike_times[idx - 1] + spike_gids = self.sim_data['spikes'].spike_gids[idx - 1] + + empty_array = np.empty((len(spike_times), 0), np.float64) + spike_arr = np.array([spike_times, spike_gids]) + spike_arr = np.append(empty_array, spike_arr.transpose(), axis=1) + + dspk, haveinputs, dhist = getdspk(spike_arr, self.extinputs, + self.params['tstop']) + self.alldat[idx]['dspk'] = dspk + self.alldat[idx]['haveinputs'] = haveinputs + self.alldat[idx]['dhist'] = dhist + self.alldat[idx]['extinputs'] = self.extinputs + + def plot(self): + self.loadspk(self.index) + + idx = self.index + dspk = self.alldat[idx]['dspk'] + haveinputs = 
self.alldat[idx]['haveinputs'] + dhist = self.alldat[idx]['dhist'] + extinputs = self.alldat[idx]['extinputs'] + + self.lax = self.drawrast(dspk, extinputs, haveinputs, self.figure, + self.G, rastmarksz) + if self.bDrawHist: + self.lax.append(self.drawhist(dhist, self.params['N_trials'], + self.params['tstop'])) + + for ax in self.lax: + ax.set_facecolor('k') + ax.grid(True) + if self.params['tstop'] != -1: + ax.set_xlim((0, self.params['tstop'])) + + if idx == 0: + self.lax[0].set_title('All Trials') + else: + self.lax[0].set_title('Trial ' + str(self.index)) + + self.lax[-1].set_xlabel('Time (ms)') + + self.figure.subplots_adjust(bottom=0.0, left=0.06, right=1.0, + top=0.97, wspace=0.1, hspace=0.09) + + self.draw() + + +class SpikeViewGUI(DataViewGUI): + """Class for displaying spiking raster plot viewer + + Required parameters: tstop, N_trials + """ + def __init__(self, CanvasType, params, sim_data, title): + super(SpikeViewGUI, self).__init__(CanvasType, params, sim_data, title) + self.addViewHistAction() + + def addViewHistAction(self): + """Add 'Toggle Histograms' to view menu""" + drawHistAction = QAction('Toggle Histograms', self) + drawHistAction.setStatusTip('Toggle Histogram Drawing.') + drawHistAction.triggered.connect(self.toggleHist) + self.viewMenu.addAction(drawHistAction) + + def toggleHist(self): + self.m.bDrawHist = not self.m.bDrawHist + self.initCanvas() + self.m.plot() diff --git a/hnn/qt_thread.py b/hnn/qt_thread.py new file mode 100755 index 000000000..3c579d796 --- /dev/null +++ b/hnn/qt_thread.py @@ -0,0 +1,657 @@ +"""File with functions and classes for running the NEURON """ + +# Authors: Blake Caldwell +# Sam Neymotin +# Shane Lee + +import os +import sys +from math import ceil, isclose +from contextlib import redirect_stdout +from psutil import wait_procs, process_iter, NoSuchProcess +import traceback +from queue import Queue +from threading import Event +import numpy as np + +import nlopt +from PyQt5 import QtCore +from hnn_core import simulate_dipole, Network, MPIBackend + +from .paramrw import get_output_dir, hnn_core_compat_params + + +class BasicSignal(QtCore.QObject): + """for signaling""" + sig = QtCore.pyqtSignal() + + +class ObjectSignal(QtCore.QObject): + """for returning an object""" + sig = QtCore.pyqtSignal(object) + + +class QueueSignal(QtCore.QObject): + """for returning data""" + qsig = QtCore.pyqtSignal(Queue, str, float) + + +class QueueDataSignal(QtCore.QObject): + """for returning data""" + qsig = QtCore.pyqtSignal(Queue, str, np.ndarray, float, float) + + +class EventSignal(QtCore.QObject): + """for synchronization""" + esig = QtCore.pyqtSignal(Event, str) + + +class TextSignal(QtCore.QObject): + """for passing text""" + tsig = QtCore.pyqtSignal(str) + + +class DataSignal(QtCore.QObject): + """for signalling data read""" + dsig = QtCore.pyqtSignal(str, dict) + + +class ParamSignal(QtCore.QObject): + """for updating GUI & param file during optimization""" + psig = QtCore.pyqtSignal(dict) + + +class CanvSignal(QtCore.QObject): + """for updating main GUI canvas""" + csig = QtCore.pyqtSignal(bool) + + +class ResultObj(QtCore.QObject): + def __init__(self, data, params): + self.data = data + self.params = params + + +def _kill_list_of_procs(procs): + """tries to terminate processes in a list before sending kill signal""" + # try terminate first + for p in procs: + try: + p.terminate() + except NoSuchProcess: + pass + _, alive = wait_procs(procs, timeout=3) + + # now try kill + for p in alive: + p.kill() + _, alive = wait_procs(procs, timeout=3) 
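+    # anything still in 'alive' at this point survived both terminate() and
+    # kill() within the 3-second waits above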
+ + return alive + + +def _get_nrniv_procs_running(): + """return a list of nrniv processes running""" + ls = [] + name = 'nrniv' + for p in process_iter(attrs=["name", "exe", "cmdline"]): + if name == p.info['name'] or \ + p.info['exe'] and os.path.basename(p.info['exe']) == name or \ + p.info['cmdline'] and p.info['cmdline'][0] == name: + ls.append(p) + return ls + + +def _kill_and_check_nrniv_procs(): + """handle killing any stale nrniv processess""" + procs = _get_nrniv_procs_running() + if len(procs) > 0: + running = _kill_list_of_procs(procs) + if len(running) > 0: + pids = [str(proc.pid) for proc in running] + print("ERROR: failed to kill nrniv process(es) %s" % + ','.join(pids)) + + +def simulate(params, n_procs=None): + """Start the simulation with hnn_core.simulate + + Parameters + ---------- + params : dict + The parameters + + n_procs : int | None + The number of MPI processes requested by the user. If None, then will + attempt to detect number of cores (including hyperthreads) and start + parallel simulation over all of them. + """ + + # create the network from the parameter file. note, NEURON objects haven't + # been created yet + net = Network(params, add_drives_from_params=True) + + sim_data = {} + # run the simulation with MPIBackend for faster completion time + with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'): + record_vsoma = bool(params['record_vsoma']) + sim_data['raw_dpls'] = simulate_dipole(net, params['N_trials'], + postproc=False, + record_vsoma=record_vsoma) + + # hnn-core changes this to bool, change back to int + if isinstance(params['record_vsoma'], bool): + params['record_vsoma'] = int(params['record_vsoma']) + sim_data['gid_ranges'] = net.gid_ranges + sim_data['spikes'] = net.cell_response + sim_data['vsoma'] = net.cell_response.vsoma + + return sim_data + + +# based on https://nikolak.com/pyqt-threading-tutorial/ +class SimThread(QtCore.QThread): + """The SimThread class. 
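+
+    Runs a single simulation (hnn_core's simulate_dipole through MPIBackend)
+    in a worker thread so the Qt GUI is not blocked while NEURON runs.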
+ + Parameters + ---------- + + ncore : int + Number of cores to run this simulation over + params : dict + Dictionary of params describing simulation config + result_callback: function + Handle to for callback to call after every sim completion + waitsimwin : WaitSimDialog + Handle to the Qt dialog during a simulation + mainwin : HNNGUI + Handle to the main application window + + Attributes + ---------- + ncore : int + Number of cores to run this simulation over + params : dict + Dictionary of params describing simulation config + mainwin : HNNGUI + Handle to the main application window + opt : bool + Whether this simulation thread is running an optimization + killed : bool + Whether this simulation was forcefully terminated + """ + + def __init__(self, ncore, params, result_callback, mainwin): + QtCore.QThread.__init__(self) + self.ncore = ncore + self.params = params + self.mainwin = mainwin + self.is_optimization = self.mainwin.is_optimization + self.baseparamwin = self.mainwin.baseparamwin + self.result_signal = ObjectSignal() + self.result_signal.sig.connect(result_callback) + self.killed = False + + self.paramfn = os.path.join(get_output_dir(), 'param', + self.params['sim_prefix'] + '.param') + + self.txtComm = TextSignal() + self.txtComm.tsig.connect(self.mainwin.waitsimwin.updatetxt) + + self.param_signal = ParamSignal() + self.param_signal.psig.connect(self.baseparamwin.updateDispParam) + + self.done_signal = TextSignal() + self.done_signal.tsig.connect(self.mainwin.done) + + def _updatewaitsimwin(self, txt): + """Used to write messages to simulation window""" + self.txtComm.tsig.emit(txt) + + class _log_sim_status(object): + """Replaces sys.stdout.write() to write message to simulation window""" + def __init__(self, parent): + self.out = sys.stdout + self.parent = parent + + def write(self, message): + self.out.write(message) + stripped_message = message.strip() + if not stripped_message == '': + self.parent._updatewaitsimwin(stripped_message) + + def flush(self): + self.out.flush() + + def stop(self): + """Terminate running simulation""" + _kill_and_check_nrniv_procs() + self.killed = True + + def run(self, sim_length=None): + """Start simulation""" + + msg = '' + banner = not self.is_optimization + try: + self._run(banner=banner, sim_length=sim_length) # run simulation + # update params in all windows (optimization) + except RuntimeError as e: + msg = str(e) + self.done_signal.tsig.emit(msg) + return + + if not self.is_optimization: + self.param_signal.psig.emit(self.params) + self.done_signal.tsig.emit(msg) + + # gracefully stop this thread + self.quit() + + def _run(self, banner=True, sim_length=None): + self.killed = False + + sim_params = hnn_core_compat_params(self.params) + if sim_length is not None: + sim_params['tstop'] = sim_length + + while True: + if self.ncore == 0: + raise RuntimeError("No cores available for simulation") + + try: + sim_log = self._log_sim_status(parent=self) + with redirect_stdout(sim_log): + sim_data = simulate(sim_params, self.ncore) + break + except RuntimeError as e: + if self.ncore == 1: + # can't reduce ncore any more + print(str(e)) + self._updatewaitsimwin(str(e)) + _kill_and_check_nrniv_procs() + raise RuntimeError("Simulation failed to start") + + # check if proc was killed before retrying with fewer cores + if self.killed: + # exit using RuntimeError + raise RuntimeError("Terminated") + + self.ncore = ceil(self.ncore / 2) + txt = "INFO: Failed starting simulation, retrying with %d cores" \ + % self.ncore + print(txt) + 
self._updatewaitsimwin(txt) + _kill_and_check_nrniv_procs() + + # put sim_data into the val attribute of a ResultObj + self.result_signal.sig.emit(ResultObj(sim_data, self.params)) + + +class OptThread(SimThread): + """The OptThread class. + + Parameters + ---------- + + ncore : int + Number of cores to run this simulation over + params : dict + Dictionary of params describing simulation config + waitsimwin : WaitSimDialog + Handle to the Qt dialog during a simulation + result_callback: function + Handle to for callback to call after every sim completion + mainwin : HNNGUI + Handle to the main application window + + Attributes + ---------- + ncore : int + Number of cores to run this simulation over + params : dict + Dictionary of params describing simulation config + mainwin : HNNGUI instance + Handle to the main application window + baseparamwin: BaseParamDialog instance + Handle to base parameters dialog + paramfn : str + Full pathname of the written parameter file name + """ + def __init__(self, ncore, params, num_steps, seed, sim_data, + result_callback, opt_callback, mainwin): + super().__init__(ncore, params, result_callback, mainwin) + self.waitsimwin = self.mainwin.waitsimwin + self.optparamwin = self.baseparamwin.optparamwin + self.cur_itr = 0 + self.num_steps = num_steps + self.sim_data = sim_data + self.result_callback = result_callback + self.seed = seed + self.best_step_werr = 1e9 + self.sim_running = False + self.killed = False + + self.done_signal.tsig.connect(opt_callback) + + self.refresh_signal = BasicSignal() + self.refresh_signal.sig.connect(self.mainwin.initSimCanvas) + + self.update_sim_data_from_opt_data = EventSignal() + self.update_sim_data_from_opt_data.esig.connect( + sim_data.update_sim_data_from_opt_data) + + self.update_opt_data_from_sim_data = EventSignal() + self.update_opt_data_from_sim_data.esig.connect( + sim_data.update_opt_data_from_sim_data) + + self.update_initial_opt_data_from_sim_data = EventSignal() + self.update_initial_opt_data_from_sim_data.esig.connect( + sim_data.update_initial_opt_data_from_sim_data) + + self.get_err_from_sim_data = QueueSignal() + self.get_err_from_sim_data.qsig.connect(sim_data.get_err_wrapper) + + self.get_werr_from_sim_data = QueueDataSignal() + self.get_werr_from_sim_data.qsig.connect(sim_data.get_werr_wrapper) + + def run(self): + msg = '' + try: + self._run() # run optimization + except RuntimeError as e: + msg = str(e) + + self.done_signal.tsig.emit(msg) + + def stop(self): + """Terminate running simulation""" + self.sim_thread.stop() + self.sim_thread.terminate() + self.sim_thread.wait() + self.killed = True + self.done_signal.tsig.emit("Optimization terminated") + + def _run(self): + # initialize RNG with seed from config + nlopt.srand(self.seed) + self.get_initial_data() + + for step in range(self.num_steps): + self.cur_step = step + + # disable range sliders for each step once that step has begun + self.optparamwin.toggle_enable_user_fields(step, enable=False) + + self.step_ranges = self.optparamwin.get_chunk_ranges(step) + self.step_sims = self.optparamwin.get_sims_for_chunk(step) + + if self.step_sims == 0: + txt = "Skipping optimization step %d (0 simulations)" % \ + (step + 1) + self._updatewaitsimwin(txt) + continue + + if len(self.step_ranges) == 0: + txt = "Skipping optimization step %d (0 parameters)" % \ + (step + 1) + self._updatewaitsimwin(txt) + continue + + txt = "Starting optimization step %d/%d" % (step + 1, + self.num_steps) + self._updatewaitsimwin(txt) + print(txt) + + opt_results = 
self.run_opt_step() + + # update with optimzed params for the next round + for var_name, new_value in zip(self.step_ranges, opt_results): + old_value = self.step_ranges[var_name]['initial'] + + # only change the parameter value if it changed significantly + if not isclose(old_value, new_value, abs_tol=1e-9): + self.step_ranges[var_name]['final'] = new_value + else: + self.step_ranges[var_name]['final'] = \ + self.step_ranges[var_name]['initial'] + + # push into GUI and save to param file so that next simulation + # starts from there. + push_values = {} + for param_name in self.step_ranges.keys(): + push_values[param_name] = self.step_ranges[param_name]['final'] + self.baseparamwin.update_gui_params(push_values) + + # update optimization dialog window + self.optparamwin.push_chunk_ranges(push_values) + + # update opt_data with the final best + update_event = Event() + self.update_sim_data_from_opt_data.esig.emit(update_event, + self.paramfn) + update_event.wait() + + # check that optimization improved RMSE + err_queue = Queue() + self.get_err_from_sim_data.qsig.emit(err_queue, self.paramfn, + self.params['tstop']) + final_err = err_queue.get() + if final_err > self.initial_err: + txt = "Warning: optimization failed to improve RMSE below" + \ + " %.2f. Reverting to old parameters." % \ + round(self.initial_err, 2) + self._updatewaitsimwin(txt) + print(txt) + + initial_params = self.optparamwin.get_initial_params() + # populate param values into GUI and save params to file + self.baseparamwin.update_gui_params(initial_params) + + # update optimization dialog window + self.optparamwin.push_chunk_ranges(initial_params) + + # run a full length simulation + self.sim_thread = SimThread(self.ncore, self.params, + self.result_callback, + mainwin=self.mainwin) + self.sim_running = True + try: + self.sim_thread.run() + self.sim_thread.wait() + if self.killed: + self.quit() + self.sim_running = False + except Exception: + traceback.print_exc() + raise RuntimeError("Failed to run final simulation. " + "See previous traceback.") + + def run_opt_step(self): + self.cur_itr = 0 + self.opt_start = self.optparamwin.get_chunk_start(self.cur_step) + self.opt_end = self.optparamwin.get_chunk_end(self.cur_step) + txt = 'Optimizing from [%3.3f-%3.3f] ms' % (self.opt_start, + self.opt_end) + self._updatewaitsimwin(txt) + print(txt) + + # weights calculated once per step + self.opt_weights = \ + self.optparamwin.get_chunk_weights(self.cur_step) + + # run an opt step + algorithm = nlopt.LN_COBYLA + self.num_params = len(self.step_ranges) + self.opt = nlopt.opt(algorithm, self.num_params) + opt_results = self.optimize(self.step_ranges, self.step_sims, + algorithm) + + return opt_results + + def get_initial_data(self): + # Has this simulation been run before (is there data?) + if not self.sim_data.in_sim_data(self.paramfn): + # run a full length simulation + txt = "Running a simulation with initial parameter set before" + \ + " beginning optimization." + self._updatewaitsimwin(txt) + print(txt) + + self.sim_thread = SimThread(self.ncore, self.params, + self.result_callback, + mainwin=self.mainwin) + self.sim_running = True + try: + self.sim_thread.run() + self.sim_thread.wait() + if self.killed: + self.quit() + self.sim_running = False + except Exception: + traceback.print_exc() + raise RuntimeError("Failed to run initial simulation. " + "See previous traceback.") + + # results are in self.sim_data now + + # store the initial fit for display in final dipole plot as + # black dashed line. 
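+        # the Event is passed along with the signal so this worker thread
+        # can block on wait() until the connected SimData slot has finished
+        # copying the data and called set()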
+ update_event = Event() + self.update_opt_data_from_sim_data.esig.emit(update_event, + self.paramfn) + update_event.wait() + update_event.clear() + self.update_initial_opt_data_from_sim_data.esig.emit(update_event, + self.paramfn) + update_event.wait() + + err_queue = Queue() + self.get_err_from_sim_data.qsig.emit(err_queue, self.paramfn, + self.params['tstop']) + self.initial_err = err_queue.get() + + def opt_sim(self, new_params, grad=0): + txt = "Optimization step %d, simulation %d" % (self.cur_step + 1, + self.cur_itr + 1) + self._updatewaitsimwin(txt) + print(txt) + + # Prepare a dict of parameters for this simulation to populate in GUI + opt_params = {} + for param_name, param_value in zip(self.step_ranges.keys(), + new_params): + if param_value >= self.step_ranges[param_name]['minval'] and \ + param_value <= self.step_ranges[param_name]['maxval']: + opt_params[param_name] = param_value + else: + # This test is not strictly necessary with COBYLA, but in + # case the algorithm is changed at some point in the future + print('INFO: optimization chose ' + '%.3f for %s outside of [%.3f-%.3f].' + % (param_value, param_name, + self.step_ranges[param_name]['minval'], + self.step_ranges[param_name]['maxval'])) + return 1e9 # invalid param value -> large error + + # populate param values into GUI + self.baseparamwin.update_gui_params(opt_params) + + sim_params = hnn_core_compat_params(self.params) + for param_name, param_value in opt_params.items(): + sim_params[param_name] = param_value + + # run the simulation, but stop at self.opt_end + self.sim_thread = SimThread(self.ncore, sim_params, + self.result_callback, + mainwin=self.mainwin) + + self.sim_running = True + try: + # may not need to run the entire simulation + self.sim_thread.run(sim_length=self.opt_end) + self.sim_thread.wait() + if self.killed: + self.quit() + self.sim_running = False + except Exception: + traceback.print_exc() + raise RuntimeError("Failed to run simulation. " + "See previous traceback.") + + # calculate wRMSE for all steps + err_queue = Queue() + self.get_werr_from_sim_data.qsig.emit(err_queue, self.paramfn, + self.opt_weights, self.opt_end, + self.opt_start) + werr = err_queue.get() + + txt = "Weighted RMSE = %f" % werr + print(txt) + self._updatewaitsimwin(os.linesep + 'Simulation finished: ' + txt + + os.linesep) + + # save params numbered by cur_itr + # data_dir = op.join(get_output_dir(), 'data') + # sim_dir = op.join(data_dir, self.params['sim_prefix']) + # param_out = os.path.join(sim_dir, 'step_%d_sim_%d.param' % + # (self.cur_step, self.cur_itr)) + # write_legacy_paramf(param_out, self.params) + + if werr < self.best_step_werr: + self._updatewaitsimwin("new best with RMSE %f" % werr) + + update_event = Event() + self.update_opt_data_from_sim_data.esig.emit(update_event, + self.paramfn) + update_event.wait() + + self.best_step_werr = werr + # save best param file + # param_out = os.path.join(sim_dir, 'step_%d_best.param' % + # self.cur_step) + # write_legacy_paramf(param_out, self.params) + + if self.cur_itr == 0 and self.cur_step > 0: + # Update plots for the first simulation only of this step + # (best results from last round). Skip the first step because + # there are no optimization results to show yet. 
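+            # refresh_signal was connected to the main window's
+            # initSimCanvas in __init__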
+ self.refresh_signal.sig.emit() # redraw with updated RMSE + + self.cur_itr += 1 + + return werr + + def optimize(self, params_input, num_sims, algorithm): + opt_params = [] + lb = [] + ub = [] + + for param_name in params_input.keys(): + upper = params_input[param_name]['maxval'] + lower = params_input[param_name]['minval'] + if upper == lower: + continue + + ub.append(upper) + lb.append(lower) + opt_params.append(params_input[param_name]['initial']) + + if algorithm == nlopt.G_MLSL_LDS or algorithm == nlopt.G_MLSL: + # In case these mixed mode (global + local) algorithms are + # used in the future + local_opt = nlopt.opt(nlopt.LN_COBYLA, self.num_params) + self.opt.set_local_optimizer(local_opt) + + self.opt.set_lower_bounds(lb) + self.opt.set_upper_bounds(ub) + + # minimize the wRMSE returned by self.opt_sim + self.opt.set_min_objective(self.opt_sim) + self.opt.set_xtol_rel(1e-4) + self.opt.set_maxeval(num_sims) + + # start the optimization: run self.runsim for # iterations in num_sims + opt_results = self.opt.optimize(opt_params) + + return opt_results diff --git a/hnn/qt_vsoma.py b/hnn/qt_vsoma.py new file mode 100644 index 000000000..86867dfcd --- /dev/null +++ b/hnn/qt_vsoma.py @@ -0,0 +1,152 @@ +"""Create the somatic voltage viewing window""" + +# Authors: Sam Neymotin +# Blake Caldwell + +from PyQt5.QtWidgets import QSizePolicy + +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure +import matplotlib.gridspec as gridspec + +from .DataViewGUI import DataViewGUI + +fontsize = plt.rcParams['font.size'] = 10 +random_label = np.random.rand(100) + + +class VSomaCanvas(FigureCanvasQTAgg): + """Class for the somatic voltages viewer + + This is designed to be called from VSomaViewGUI class to add functionality + for loading and clearing data + """ + + def __init__(self, params, sim_data, index, parent=None, width=12, + height=10, dpi=120, title='Somatic Voltage Viewer'): + FigureCanvasQTAgg.__init__(self, Figure(figsize=(width, height), + dpi=dpi)) + self.title = title + self.setParent(parent) + self.gui = parent + self.index = index + FigureCanvasQTAgg.setSizePolicy(self, QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + self.params = params + self.invertedax = False + self.G = gridspec.GridSpec(10, 1) + + self.sim_data = sim_data + # colors for the different cell types + self.dclr = {'L2_pyramidal': 'g', + 'L5_pyramidal': 'r', + 'L2_basket': 'w', + 'L5_basket': 'b'} + + self.plot() + + def drawvolt(self, dvolt, times, fig, G, maxperty=10, ltextra=''): + """Create voltage plots + + + Parameters + ---------- + + dvolt: dict + Dictionary with somatic voltages for a single trial. Keys are + gids and value is array of somatic voltages at each timestep + times: array + 1-D array containing the corresponding times to each somatic + voltage + fig: Figure object + The figure to plot voltages + G: GridSpec object + Grid on which to place axes + maxperty: int + How many cells of a type to draw. 
If None, 10 cells will be shown + ltextra: str + String containing title of window + """ + + global random_label + + ax = fig.add_subplot(G[0:-1, :], label=random_label) + random_label += 1 + + dcnt = {} # counts number of times cell of a type drawn + yoff = 0 + + for gid_type in self.dclr.keys(): + for gid in self.sim_data['gid_ranges'][gid_type]: + if gid_type not in dcnt: + dcnt[gid_type] = 0 + elif dcnt[gid_type] > maxperty: + continue + vsoma = np.array(dvolt[gid]) + ax.plot(times, -vsoma + yoff, self.dclr[gid_type], + linewidth=self.gui.linewidth) + yoff += max(vsoma) - min(vsoma) + dcnt[gid_type] += 1 + + white_patch = mpatches.Patch(color='white', label='L2/3 Basket') + green_patch = mpatches.Patch(color='green', label='L2/3 Pyr') + red_patch = mpatches.Patch(color='red', label='L5 Pyr') + blue_patch = mpatches.Patch(color='blue', label='L5 Basket') + ax.legend(handles=[white_patch, green_patch, blue_patch, red_patch]) + + if not self.invertedax: + ax.set_ylim(ax.get_ylim()[::-1]) + self.invertedax = True + + ax.set_yticks([]) + + ax.set_facecolor('k') + ax.grid(True) + if self.params['tstop'] > 0: + ax.set_xlim((0, self.params['tstop'])) + + ax.set_title(ltextra) + ax.set_xlabel('Time (ms)') + + def plot(self): + if len(self.sim_data['vsoma']) == 0: + # data hasn't been loaded yet + return + + ltextra = 'Trial ' + str(self.index) + volt_data = self.sim_data['vsoma'][self.index] + times = self.sim_data['dpls'][0].times + + self.drawvolt(volt_data, times, self.figure, self.G, + ltextra=ltextra) + self.figure.subplots_adjust(bottom=0.01, left=0.01, right=0.99, + top=0.99, wspace=0.1, hspace=0.09) + + self.draw() + + +class VSomaViewGUI(DataViewGUI): + """Class for displaying somatic voltages viewer + + Required parameters in params dict: N_trials, tstop + """ + def updateCB(self): + self.cb.clear() + for i in range(self.ntrial): + self.cb.addItem('Show Trial ' + str(i + 1)) + self.cb.activated[int].connect(self.onActivated) + + def onActivated(self, idx): + if idx != self.index: + self.index = idx + self.statusBar().showMessage('Loading data from trial ' + + str(self.index) + '.') + self.m.index = self.index + self.initCanvas() + self.m.plot() + self.statusBar().showMessage('') diff --git a/hnn/simdata.py b/hnn/simdata.py new file mode 100644 index 000000000..966c5e960 --- /dev/null +++ b/hnn/simdata.py @@ -0,0 +1,779 @@ +import os +import numpy as np +from math import ceil +from glob import glob +from pickle import dump, load +from copy import deepcopy + +from scipy import signal +import matplotlib as mpl +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + +from hnn_core import read_spikes +from hnn_core.dipole import read_dipole, average_dipoles + +from .spikefn import ExtInputs +from .specfn import plot_spec +from .paramrw import get_output_dir, get_fname, get_inputs +from .paramrw import read_gids_param + +drawindivdpl = 1 +drawavgdpl = 1 +fontsize = plt.rcParams['font.size'] = 10 + + +def get_dipoles_from_disk(sim_dir, ntrials): + """Read dipole trial data from disk + + Parameters + ---------- + sim_dir : str + Path of simulation data directory + ntrials : int + Number of trials expected to be read from disk + + Returns + ---------- + dpls: list of Dipole objects + List containing Dipoles of each trial + """ + + dpls = [] + dpl_fname_pattern = os.path.join(sim_dir, 'dpl_*.txt') + glob_list = sorted(glob(str(dpl_fname_pattern))) + if len(glob_list) == 0: + # get the old style filename + glob_list = [get_fname(sim_dir, 'normdpl')] + + for dipole_fn in 
glob_list: + dpl_trial = None + try: + dpl_trial = read_dipole(dipole_fn) + except OSError: + if os.path.exists(sim_dir): + print('Warning: could not read file:', dipole_fn) + except ValueError: + if os.path.exists(sim_dir): + print('Warning: could not read file:', dipole_fn) + + dpls.append(dpl_trial) + + if len(dpls) == 0: + print("Warning: no dipole(s) read from %s" % sim_dir) + + if len(dpls) < ntrials: + print("Warning: only read %d of %d dipole files in %s" % + (len(dpls), ntrials, sim_dir)) + + return dpls + + +def read_spectrials(sim_dir): + """read spectrogram data files for individual trials""" + spec_list = [] + + spec_fname_pattern = os.path.join(sim_dir, 'rawspec_*.npz') + glob_list = sorted(glob(str(spec_fname_pattern))) + if len(glob_list) == 0: + # get the old style filename + glob_list = [get_fname(sim_dir, 'rawspec')] + + for spec_fn in glob_list: + spec_trial = None + try: + with np.load(spec_fn, allow_pickle=True) as spec_data: + # need to make a copy of data so we can close NpzFile + spec_trial = dict(spec_data) + except OSError: + if os.path.exists(sim_dir): + print('Warning: could not read file:', spec_fn) + except ValueError: + if os.path.exists(sim_dir): + print('Warning: could not read file:', spec_fn) + + spec_list.append(spec_trial) + + return spec_list + + +def read_vsomatrials(sim_dir): + """read somatic voltage data files for individual trials""" + vsoma_list = [] + + vsoma_fname_pattern = os.path.join(sim_dir, 'vsoma_*.pkl') + glob_list = sorted(glob(str(vsoma_fname_pattern))) + if len(glob_list) == 0: + # get the old style filename + glob_list = [get_fname(sim_dir, 'vsoma')] + + for vsoma_fn in glob_list: + vsoma_trial = None + try: + with open(vsoma_fn, 'rb') as f: + vsoma_trial = load(f) + except OSError: + if os.path.exists(sim_dir): + print('Warning: could not read file:', vsoma_fn) + except ValueError: + if os.path.exists(sim_dir): + print('Warning: could not read file:', vsoma_fn) + + vsoma_list.append(vsoma_trial) + + return vsoma_list + + +def read_spktrials(sim_dir, gid_ranges): + spk_fname_pattern = os.path.join(sim_dir, 'spk_*.txt') + if len(glob(str(spk_fname_pattern))) == 0: + # if legacy HNN only ran one trial, then no spk_0.txt gets written + spk_fname_pattern = get_fname(sim_dir, 'rawspk') + + try: + spikes = read_spikes(spk_fname_pattern, gid_ranges) + except FileNotFoundError: + print('Warning: could not read file:', spk_fname_pattern) + + return spikes + + +def check_feeds_to_plot(feeds_from_spikes, params): + # ensures synaptic weight > 0 + using_feeds = get_inputs(params) + + feed_counts = {'pois': len(feeds_from_spikes['pois']) > 0, + 'evdist': len(feeds_from_spikes['evdist']) > 0, + 'evprox': len(feeds_from_spikes['evprox']) > 0, + 'dist': len(feeds_from_spikes['dist']) > 0, + 'prox': len(feeds_from_spikes['prox']) > 0} + feed_counts['evoked'] = feed_counts['evdist'] or feed_counts['evprox'] + feed_counts['ongoing'] = feed_counts['dist'] or feed_counts['prox'] + + feeds_to_plot = {} + for key in feed_counts.keys(): + feeds_to_plot[key] = feed_counts[key] and using_feeds[key] + + return feeds_to_plot + + +def plot_hists_on_gridspec(figure, gridspec, feeds_to_plot, extinputs, times, + xlim, linewidth): + + axdist = axpois = axprox = None + axes = [] + n_hists = 0 + + # check poisson inputs, create subplot + if feeds_to_plot['pois']: + axpois = figure.add_subplot(gridspec[n_hists, :]) + n_hists += 1 + + # check distal inputs, create subplot + if feeds_to_plot['evdist'] or feeds_to_plot['dist']: + axdist = 
figure.add_subplot(gridspec[n_hists, :]) + n_hists += 1 + + # check proximal inputs, create subplot + if feeds_to_plot['evprox'] or feeds_to_plot['prox']: + axprox = figure.add_subplot(gridspec[n_hists, :]) + n_hists += 1 + + plot_linewidth = linewidth + 1 + # check input types provided in simulation + if feeds_to_plot['pois']: + extinputs.plot_hist(axpois, 'pois', times, 'auto', xlim, + color='k', hty='step', + lw=plot_linewidth) + axes.append(axpois) + if feeds_to_plot['dist']: + extinputs.plot_hist(axdist, 'dist', times, 'auto', xlim, + color='g', lw=plot_linewidth) + axes.append(axdist) + if feeds_to_plot['prox']: + extinputs.plot_hist(axprox, 'prox', times, 'auto', xlim, + color='r', lw=plot_linewidth) + axes.append(axprox) + if feeds_to_plot['evdist']: + extinputs.plot_hist(axdist, 'evdist', times, 'auto', xlim, + color='g', hty='step', + lw=plot_linewidth) + axes.append(axdist) + if feeds_to_plot['evprox']: + extinputs.plot_hist(axprox, 'evprox', times, 'auto', xlim, + color='r', hty='step', + lw=plot_linewidth) + axes.append(axprox) + + # get the ymax for the two histograms + ymax = 0 + for ax in [axpois, axdist, axprox]: + if ax is not None: + if ax.get_ylim()[1] > ymax: + ymax = ax.get_ylim()[1] + + # set ymax for both to be the same + for ax in [axpois, axdist, axprox]: + if ax is not None: + ax.set_ylim(0, ymax) + ax.set_xlim(xlim) + ax.legend(loc=1) # legend in upper right + + # invert the distal input axes + if axdist is not None: + axdist.invert_yaxis() + + return axes + + +class SimData(object): + """The SimData class""" + + def __init__(self): + self._sim_data = {} + self._opt_data = {} + self._exp_data = {} + self._data_dir = os.path.join(get_output_dir(), 'data') + + def remove_sim_by_fn(self, paramfn): + """Deletes sim from SimData + + Parameters + ---------- + paramfn : str + Filename of parameter file to remove + """ + del self._sim_data[paramfn] + + def update_sim_data(self, paramfn, params, dpls, avg_dpl, spikes, + gid_ranges, spec=None, vsoma=None): + self._sim_data[paramfn] = {'params': params, + 'data': {'dpls': dpls, + 'avg_dpl': avg_dpl, + 'spikes': spikes, + 'gid_ranges': gid_ranges, + 'spec': spec, 'vsoma': vsoma}} + + def clear_exp_data(self): + """Clear all experimental data from SimData""" + + self._exp_data = {} + + def clear_sim_data(self): + """Clear all simulation data from SimData""" + + self._sim_data.clear() + + def update_exp_data(self, exp_fn, exp_data): + """Adds experimental data to SimData + + Parameters + ---------- + exp_fn : str + Filename of experimental data + exp_data : array + Data from np.loadtxt() on experimental data file + """ + self._exp_data[exp_fn] = exp_data + + def get_exp_data_size(self): + """Adds experimental data to SimData + + Returns + ---------- + length: int + The number of experimental data files in SimData + """ + + return len(self._exp_data) + + def update_sim_data_from_disk(self, paramfn, params): + """Adds simulation data to SimData + + Parameters + ---------- + paramfn : str + Simulation parameter filename + params : dict + Dictionary containing parameters + + Returns + ---------- + success: bool + Whether simulation data could be read + """ + + sim_dir = os.path.join(self._data_dir, params['sim_prefix']) + if not os.path.exists(sim_dir): + self.update_sim_data(paramfn, params, None, None, None, None) + return False + + dpls = get_dipoles_from_disk(sim_dir, params['N_trials']) + if len(dpls) == 0: + self.update_sim_data(paramfn, params, None, None, None, None) + return False + elif len(dpls) == 1: + avg_dpl = 
dpls[0] + else: + avg_dpl = average_dipoles(dpls) + + warning_message = 'Warning: could not read file:' + # gid_ranges + paramtxt_fn = get_fname(sim_dir, 'param') + try: + gid_ranges = read_gids_param(paramtxt_fn) + except FileNotFoundError: + print(warning_message, paramtxt_fn) + return False + + # spikes + spikes = read_spktrials(sim_dir, gid_ranges) + if len(spikes.spike_times) == 0: + print("Warning: no spikes read from %s" % sim_dir) + elif len(spikes.spike_times) < params['N_trials']: + print("Warning: only read %d of %d spike files in %s" % + (len(spikes.spike_times), params['N_trials'], sim_dir)) + + # spec data + spec = None + using_feeds = get_inputs(params) + if params['save_spec_data'] or using_feeds['ongoing'] or \ + using_feeds['pois'] or using_feeds['tonic']: + spec = read_spectrials(sim_dir) + if len(spec) == 0: + print("Warning: no spec data read from %s" % sim_dir) + elif len(spec) < params['N_trials']: + print("Warning: only read %d of %d spec files in %s" % + (len(spec), params['N_trials'], sim_dir)) + + # somatic voltages + vsoma = None + if params['record_vsoma']: + vsoma = read_vsomatrials(sim_dir) + if len(vsoma) == 0: + print("Warning: no somatic voltages read from %s" % sim_dir) + elif len(vsoma) < params['N_trials']: + print("Warning: only read %d of %d voltage files in %s" % + (len(vsoma), params['N_trials'], sim_dir)) + + self.update_sim_data(paramfn, params, dpls, avg_dpl, spikes, + gid_ranges, spec, vsoma) + + return True + + def calcerr(self, paramfn, tstop, tstart=0.0, weights=None): + """Calculate root mean squared error using SimData + + Parameters + ---------- + paramfn : str + Simulation parameter filename to calculate RMSE for against + experimental data previously loaded in SimData + tstop : float + Time in ms defining the end of the region to calculate RMSE + tstart : float | None + Time in ms defining the start of the region to calculate RMSE. + If None is provided, this defaults to 0.0 ms. + weights : array | None + An array containing weights for each data point of the simulation. + If weights is provided, then the weighted root mean square error + will be returned. If None is provided, then standard RMSE will be + returned. + + Returns + ---------- + lerr : list of floats + A list of RMSE values between the simulation and each experimental + data files stored in SimData + errtot : float + Average RMSE over all experimental data files + """ + + NSig = errtot = 0.0 + lerr = [] + for _, dat in self._exp_data.items(): + shp = dat.shape + + exp_times = dat[:, 0] + sim_times = self._sim_data[paramfn]['data']['avg_dpl'].times + + # do tstart and tstop fall within both datasets? 
+ # if not, use the closest data point as the new tstop/tstart + for tseries in [exp_times, sim_times]: + if tstart < tseries[0]: + tstart = tseries[0] + if tstop > tseries[-1]: + tstop = tseries[-1] + + # make sure start and end times are valid for both dipoles + exp_start_index = (np.abs(exp_times - tstart)).argmin() + exp_end_index = (np.abs(exp_times - tstop)).argmin() + exp_length = exp_end_index - exp_start_index + + sim_start_index = (np.abs(sim_times - tstart)).argmin() + sim_end_index = (np.abs(sim_times - tstop)).argmin() + sim_length = sim_end_index - sim_start_index + + if weights is not None: + weight = weights[sim_start_index:sim_end_index] + + for c in range(1, shp[1], 1): + sim_dpl = self._sim_data[paramfn]['data']['avg_dpl'] + dpl1 = sim_dpl.data['agg'][sim_start_index:sim_end_index] + dpl2 = dat[exp_start_index:exp_end_index, c] + + if (sim_length > exp_length): + # downsample simulation timeseries to match exp data + dpl1 = signal.resample(dpl1, exp_length) + if weights is not None: + weight = signal.resample(weight, exp_length) + indices = np.where(weight < 1e-4) + weight[indices] = 0 + + elif (sim_length < exp_length): + # downsample exp timeseries to match simulation data + dpl2 = signal.resample(dpl2, sim_length) + + if weights is not None: + err0 = np.sqrt((weight * ((dpl1 - dpl2) ** 2)).sum() / + weight.sum()) + else: + err0 = np.sqrt(((dpl1 - dpl2) ** 2).mean()) + lerr.append(err0) + errtot += err0 + NSig += 1 + + if not NSig == 0.0: + errtot /= NSig + return lerr, errtot + + def clear_opt_data(self): + self._initial_opt = {} + self._opt_data = {} + + def in_sim_data(self, paramfn): + if paramfn in self._sim_data: + return True + return False + + def update_opt_data(self, paramfn, params, avg_dpl, dpls=None, + spikes=None, gid_ranges=None, spec=None, + vsoma=None): + self._opt_data = {'paramfn': paramfn, + 'params': params, + 'data': {'dpls': None, + 'avg_dpl': avg_dpl, + 'spikes': None, + 'gid_ranges': None, + 'spec': None, + 'vsoma': None}} + + def update_initial_opt_data_from_sim_data(self, event, paramfn): + if paramfn not in self._sim_data: + raise ValueError("Simulation not in sim_data: %s" % paramfn) + + single_sim_data = self._sim_data[paramfn]['data'] + self._opt_data['initial_dpl'] = \ + deepcopy(single_sim_data['avg_dpl']) + self._opt_data['initial_error'] = self.get_err(paramfn) + + event.set() + + def get_err(self, paramfn, tstop=None): + if paramfn not in self._sim_data: + raise ValueError("Simulation not in sim_data: %s" % paramfn) + + if tstop is None: + tstop = self._sim_data[paramfn]['params']['tstop'] + _, err = self.calcerr(paramfn, tstop) + return err + + def get_err_wrapper(self, queue, paramfn, tstop=None): + err = self.get_err(paramfn, tstop) + queue.put(err) + + def get_werr(self, paramfn, weights, tstop=None, tstart=None): + if paramfn not in self._sim_data: + raise ValueError("Simulation not in sim_data: %s" % paramfn) + + if tstop is None: + tstop = self._sim_data[paramfn]['params']['tstop'] + _, werr = self.calcerr(paramfn, tstop, tstart, weights) + return werr + + def get_werr_wrapper(self, queue, paramfn, weights, tstop=None, + tstart=None): + err = self.get_werr(paramfn, weights, tstop, tstart) + queue.put(err) + + def update_opt_data_from_sim_data(self, event, paramfn): + if paramfn not in self._sim_data: + raise ValueError("Simulation not in sim_data: %s" % paramfn) + + sim_params = self._sim_data[paramfn]['params'] + single_sim = self._sim_data[paramfn]['data'] + self._opt_data = {'paramfn': paramfn, + 'params': 
deepcopy(sim_params), + 'data': {'dpls': deepcopy(single_sim['dpls']), + 'avg_dpl': deepcopy(single_sim['avg_dpl']), + 'spikes': deepcopy(single_sim['spikes']), + 'gid_ranges': + deepcopy(single_sim['gid_ranges']), + 'spec': deepcopy(single_sim['spec']), + 'vsoma': deepcopy(single_sim['vsoma'])}} + + event.set() + + def update_sim_data_from_opt_data(self, event, paramfn): + opt_data = self._opt_data['data'] + single_sim = {'paramfn': paramfn, + 'params': deepcopy(self._opt_data['params']), + 'data': {'dpls': deepcopy(opt_data['dpls']), + 'avg_dpl': deepcopy(opt_data['avg_dpl']), + 'spikes': deepcopy(opt_data['spikes']), + 'gid_ranges': + deepcopy(opt_data['gid_ranges']), + 'spec': deepcopy(opt_data['spec']), + 'vsoma': deepcopy(opt_data['vsoma'])}} + self._sim_data[paramfn] = single_sim + + event.set() + + def _read_dpl(self, paramfn, trial_idx, ntrial): + if ntrial == 1: + dpltrial = self._sim_data[paramfn]['data']['avg_dpl'] + + else: + trial_data = self._sim_data[paramfn]['data']['dpls'] + if trial_idx > len(trial_data): + print("Warning: data not available for trials above index", + (len(trial_data) - 1)) + return None + + dpltrial = self._sim_data[paramfn]['data']['dpls'][trial_idx] + + return dpltrial + + def save_spec_with_hist(self, paramfn, params): + sim_dir = os.path.join(self._data_dir, params['sim_prefix']) + ntrial = params['N_trials'] + xmin = 0. + xmax = params['tstop'] + xlim = (xmin, xmax) + linewidth = 1 + + num_step = ceil(xmax / params['dt']) + 1 + times = np.linspace(xmin, xmax, num_step) + + for trial_idx in range(ntrial): + f = plt.figure(figsize=(8, 8)) + font_prop = {'size': 8} + mpl.rc('font', **font_prop) + + # get inputs from spike file + gid_ranges = self._sim_data[paramfn]['data']['gid_ranges'] + spikes = self._sim_data[paramfn]['data']['spikes'][trial_idx] + extinputs = ExtInputs(spikes, gid_ranges, [trial_idx], params) + feeds_to_plot = check_feeds_to_plot(extinputs.inputs, params) + + if feeds_to_plot['ongoing'] or feeds_to_plot['evoked'] or \ + feeds_to_plot['pois']: + # hist gridspec + gs2 = gridspec.GridSpec(2, 1, hspace=0.14, bottom=0.75, + top=0.95, left=0.1, right=0.82) + + plot_hists_on_gridspec(f, gs2, feeds_to_plot, extinputs, + times, xlim, linewidth) + + # the right margin is a hack and NOT guaranteed! + # it's making space for the stupid colorbar that creates a new + # grid to replace gs1 when called, and it doesn't update the + # params of gs1 + gs0 = gridspec.GridSpec(1, 4, wspace=0.05, hspace=0., bottom=0.05, + top=0.45, left=0.1, right=1.) + gs1 = gridspec.GridSpec(2, 1, height_ratios=[1, 3], bottom=0.50, + top=0.70, left=0.1, right=0.82) + + axspec = f.add_subplot(gs0[:, :]) + axdipole = f.add_subplot(gs1[:, :]) + + spec_data = self._sim_data[paramfn]['data']['spec'] + cax = plot_spec(axspec, spec_data, ntrial, + params['spec_cmap'], xlim, fontsize) + f.colorbar(cax, ax=axspec) + + # set xlim based on TFR plot + # xlim_new = axspec.get_xlim() + + # dipole + dpl = self._read_dpl(paramfn, trial_idx, ntrial) + if dpl is None: + break + dpl.plot(layer='agg', ax=axdipole, show=False) + axdipole.set_xlim(xlim) + + spec_fig_fn = get_fname(sim_dir, 'figspec', trial_idx) + f.savefig(spec_fig_fn, dpi=300) + plt.close(f) + + def save_dipole_with_hist(self, paramfn, params): + sim_dir = os.path.join(self._data_dir, params['sim_prefix']) + ntrial = params['N_trials'] + xmin = 0. 
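+        # the x-axis limits and histogram time vector below are rebuilt from
+        # the params (tstop and dt) rather than taken from the dipole data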
+ xmax = params['tstop'] + xlim = (xmin, xmax) + linewidth = 1 + + num_step = ceil(xmax / params['dt']) + 1 + times = np.linspace(xmin, xmax, num_step) + + for trial_idx in range(ntrial): + f = plt.figure(figsize=(12, 6)) + font_prop = {'size': 8} + mpl.rc('font', **font_prop) + + gid_ranges = self._sim_data[paramfn]['data']['gid_ranges'] + spikes = self._sim_data[paramfn]['data']['spikes'][trial_idx] + extinputs = ExtInputs(spikes, gid_ranges, [trial_idx], params) + feeds_to_plot = check_feeds_to_plot(extinputs.inputs, params) + + if feeds_to_plot['ongoing'] or feeds_to_plot['evoked'] or \ + feeds_to_plot['pois']: + # hist gridspec + gs1 = gridspec.GridSpec(2, 1, hspace=0.14, bottom=0.60, + top=0.95, left=0.1, right=0.90) + + plot_hists_on_gridspec(f, gs1, feeds_to_plot, extinputs, + times, xlim, linewidth) + + # dipole gridpec + gs0 = gridspec.GridSpec(1, 1, wspace=0.05, hspace=0, bottom=0.10, + top=0.55, left=0.1, right=0.90) + axdipole = f.add_subplot(gs0[:, :]) + + # dipole + dpl = self._read_dpl(paramfn, trial_idx, ntrial) + if dpl is None: + break + dpl.plot(layer='agg', ax=axdipole, show=False) + axdipole.set_xlim(xlim) + + dipole_fig_fn = get_fname(sim_dir, 'figdpl', trial_idx) + f.savefig(dipole_fig_fn, dpi=300) + plt.close(f) + + def save_vsoma(self, paramfn, params): + ntrial = params['N_trials'] + sim_dir = os.path.join(self._data_dir, params['sim_prefix']) + current_sim_data = self._sim_data[paramfn]['data'] + + for trial_idx in range(ntrial): + vsoma_outfn = get_fname(sim_dir, 'vsoma', trial_idx) + + if trial_idx + 1 > len(current_sim_data['vsoma']): + raise ValueError("No vsoma data for trial %d" % trial_idx) + + vsoma = current_sim_data['vsoma'][trial_idx] + + # store tvec with voltages. it will be the same for + # all trials + vsoma['vtime'] = current_sim_data['dpls'][0].times + with open(str(vsoma_outfn), 'wb') as f: + dump(vsoma, f) + + def plot_dipole(self, paramfn, ax, linewidth, dipole_scalefctr, N_pyr_x=0, + N_pyr_y=0, is_optimization=False): + """Plot the dipole(s) HNN style + + Parameters + ---------- + paramfn : str + Simulation parameter filename to lookup data from prior simulation + ax : axis object + Axis on which to plot dipoles(s) + linewidth : int + Base width for dipole lines. 
Averages will be one size larger + dipole_scalefctr : float + Scaling factor applied to dipole data + N_pyr_x : int + Nr of cells (x) + N_pyr_y : int + Nr of cells (y) + is_optimization : bool + True if plots should be specific for optimization results + """ + + yl = [0, 0] + dpl = self._sim_data[paramfn]['data']['avg_dpl'] + yl[0] = min(yl[0], np.amin(dpl.data['agg'])) + yl[1] = max(yl[1], np.amax(dpl.data['agg'])) + + if not is_optimization: + # plot average dipoles from prior simulations + old_dpl = self._sim_data[paramfn]['data']['avg_dpl'] + ax.plot(old_dpl.times, old_dpl.data['agg'], '--', color='black', + linewidth=linewidth) + + sim_data = self._sim_data[paramfn]['data'] + ntrial = len(sim_data['dpls']) + # plot dipoles from individual trials + if ntrial > 1 and drawindivdpl: + for dpltrial in sim_data['dpls']: + ax.plot(dpltrial.times, dpltrial.data['agg'], + color='gray', + linewidth=linewidth) + yl[0] = min(yl[0], dpltrial.data['agg'].min()) + yl[1] = max(yl[1], dpltrial.data['agg'].max()) + + if drawavgdpl or ntrial == 1: + # this is the average dipole (across trials) + # it's also the ONLY dipole when running a single trial + ax.plot(dpl.times, dpl.data['agg'], 'k', + linewidth=linewidth + 1) + yl[0] = min(yl[0], dpl.data['agg'].min()) + yl[1] = max(yl[1], dpl.data['agg'].max()) + else: + if 'avg_dpl' not in self._opt_data or \ + 'initial_dpl' not in self._opt_data: + # if there was an exception running optimization + # still plot average dipole from sim + ax.plot(dpl.times, dpl.data['agg'], 'k', + linewidth=linewidth + 1) + yl[0] = min(yl[0], dpl.data['agg'].min()) + yl[1] = max(yl[1], dpl.data['agg'].max()) + else: + if self._opt_data['avg_dpl'] is not None: + # show optimized dipole as gray line + optdpl = self._opt_data['avg_dpl'] + ax.plot(optdpl.times, optdpl.data['agg'], 'k', + color='gray', linewidth=linewidth + 1) + yl[0] = min(yl[0], optdpl.data['agg'].min()) + yl[1] = max(yl[1], optdpl.data['agg'].max()) + + if self._opt_data['initial_dpl'] is not None: + # show initial dipole in dotted black line + plot_data = self._opt_data['initial_dpl'] + times = plot_data.times + plot_dpl = plot_data.data['agg'] + ax.plot(times, plot_dpl, '--', color='black', + linewidth=linewidth) + dpl = self._opt_data['initial_dpl'].data['agg'] + yl[0] = min(yl[0], dpl.min()) + yl[1] = max(yl[1], dpl.max()) + + # get the number of pyramidal neurons used in the simulation and + # multiply by scale factor to get estimated number of pyramidal + # neurons for y-axis label + num_pyr = int(N_pyr_x * N_pyr_y * 2) + NEstPyr = int(num_pyr * float(dipole_scalefctr)) + if NEstPyr > 0: + ax.set_ylabel(r'Dipole (nAm $\times$ ' + + str(dipole_scalefctr) + + ')\nFrom Estimated ' + + str(NEstPyr) + ' Cells', fontsize=fontsize) + else: + # is this handling overflow? 
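+            # NEstPyr is 0 whenever N_pyr_x/N_pyr_y keep their default of 0
+            # or dipole_scalefctr is 0, so leave the estimated cell count out
+            # of the y-axis label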
+ ax.set_ylabel(r'Dipole (nAm $\times$ ' + + str(dipole_scalefctr) + + ')\n', fontsize=fontsize) + ax.set_ylim(yl) diff --git a/hnn/specfn.py b/hnn/specfn.py new file mode 100644 index 000000000..3ec004e37 --- /dev/null +++ b/hnn/specfn.py @@ -0,0 +1,264 @@ +# specfn.py - Average time-frequency energy representation using Morlet +# wavelet method +# +# v 1.10.2-py35 +# rev 2017-02-21 (SL: fixed an issue with indexing) +# last major: (SL: more comments on the units of Morlet Spec) +# 11-29-2020: BC removed code that no longer uses in preparation for +# hnn-core integration + +import numpy as np +import scipy.signal as sps +from copy import deepcopy +import matplotlib.pyplot as plt + +fontsize = plt.rcParams['font.size'] = 10 + + +# MorletSpec class based on a time vec tvec and a time series vec tsvec +class MorletSpec(): + def __init__(self, tvec, tsvec, f_max, dt, tstop, tmin=50.0, + f_min=1.): + # Save variable portion of fdata_spec as identifying attribute + # self.name = fdata_spec + + # Import dipole data and remove extra dimensions from signal array. + self.tvec = tvec + self.tsvec = tsvec + + self.f_min = f_min + self.tstop = tstop + self.dt = dt + + # maximum frequency of analysis + # Add 1 to ensure analysis is inclusive of maximum frequency + self.f_max = f_max + 1 + + # cutoff time in ms + self.tmin = tmin + + # truncate these vectors appropriately based on tmin + if self.tstop > self.tmin: + # must be done in this order! timeseries first! + self.tsvec = self.tsvec[self.tvec >= self.tmin] + self.tvec = self.tvec[self.tvec >= self.tmin] + + # Check that tstop is greater than tmin + if self.tstop > self.tmin: + # Array of frequencies over which to sort + self.f = np.arange(self.f_min, self.f_max) + + # Number of cycles in wavelet (>5 advisable) + self.width = 7. + + # Calculate sampling frequency + self.fs = 1000. / self.dt + + # Generate Spec data + self.TFR = self.__traces2TFR() + else: + print("tstop not greater than %4.2f ms. " % self.tmin + + "Skipping wavelet analysis.") + + # also creates self.timevec + def __traces2TFR(self): + self.S_trans = self.tsvec.transpose() + # self.S_trans = self.S.transpose() + + # range should probably be 0 to len(self.S_trans) + # shift tvec to reflect change + # this is in ms + self.t = 1000. * np.arange(1, len(self.S_trans) + 1) / self.fs + \ + self.tmin - self.dt + + # preallocation + B = np.zeros((len(self.f), len(self.S_trans))) + + if self.S_trans.ndim == 1: + for j in range(0, len(self.f)): + s = sps.detrend(self.S_trans[:]) + + # += is used here because these were zeros and now it's adding + # the solution + B[j, :] += self.__energyvec(self.f[j], s) + + return B + + # this code doesn't return anything presently ... + else: + for i in range(0, self.S_trans.shape[0]): + for j in range(0, len(self.f)): + s = sps.detrend(self.S_trans[i, :]) + B[j, :] += self.__energyvec(self.f[j], s) + + # calculate the morlet wavelet for central frequency f + def __morlet(self, f, t): + """ Morlet's wavelet for frequency f and time t + Wavelet normalized so total energy is 1 + f: specific frequency + y: final units are 1/s + """ + # sf in Hz + sf = f / self.width + + # st in s + st = 1. / (2. * np.pi * sf) + + # A in 1 / s + A = 1. / (st * np.sqrt(2. * np.pi)) + + # units: 1/s * (exp (s**2 / s**2)) * exp( 1/ s * s) + y = A * np.exp(-t**2. / (2. * st**2.)) * np.exp(1.j * 2. * np.pi * f * + t) + + return y + + # Return an array containing the energy as function of time for freq f + def __energyvec(self, f, s): + """ Final units of y: signal units squared. 
+
+        For instance, a signal of Am would have Am^2
+        The energy is calculated using Morlet's wavelets
+        f: frequency
+        s: signal
+        """
+        dt = 1. / self.fs
+        sf = f / self.width
+        st = 1. / (2. * np.pi * sf)
+
+        t = np.arange(-3.5 * st, 3.5 * st, dt)
+
+        # calculate the morlet wavelet for this frequency
+        # units of m are 1/s
+        m = self.__morlet(f, t)
+
+        # convolve wavelet with signal
+        y = sps.fftconvolve(s, m)
+
+        # take the power ...
+        y = (2. * abs(y) / self.fs)**2.
+        i_lower = int(np.ceil(len(m) / 2.))
+        i_upper = int(len(y) - np.floor(len(m) / 2.) + 1)
+        y = y[i_lower:i_upper]
+
+        return y
+
+
+# core class for frequency analysis assuming stationary time series
+class Welch():
+    def __init__(self, t_vec, ts_vec, dt):
+        # assign data internally
+        self.t_vec = t_vec
+        self.ts_vec = ts_vec
+        self.dt = dt
+        self.units = 'tsunits^2'
+
+        # only assign length if same
+        if len(self.t_vec) == len(self.ts_vec):
+            self.N = len(ts_vec)
+
+        else:
+            # raise an exception for real sometime in the future, for now
+            # just say something
+            print("in specfn.Welch(), your lengths don't match! Something"
+                  " will fail!")
+
+        # grab the dt (in ms) and calc sampling frequency
+        self.fs = 1000. / self.dt
+
+        # calculate the actual Welch
+        self.f, self.P = sps.welch(self.ts_vec, self.fs, window='hanning',
+                                   nperseg=self.N, noverlap=0, nfft=self.N,
+                                   return_onesided=True, scaling='spectrum')
+
+
+def spec_dpl_kernel(dpl, f_max, dt, tstop):
+    # Do the conversion prior to generating these spec
+    # dpl.convert_fAm_to_nAm()
+
+    print("Extracting spectrogram from dipole")
+    # Generate various spec results
+    spec_agg = MorletSpec(dpl.times, dpl.data['agg'], f_max, dt, tstop)
+    spec_L2 = MorletSpec(dpl.times, dpl.data['L2'], f_max, dt, tstop)
+    spec_L5 = MorletSpec(dpl.times, dpl.data['L5'], f_max, dt, tstop)
+
+    # Get max spectral power data
+    # BC (11/29/2020): no longer calculating this
+    max_agg = []
+
+    # Generate periodogram results
+    pgram = Welch(dpl.times, dpl.data['agg'], dt)
+
+    spec_results = {'time': spec_agg.t, 'freq': spec_agg.f,
+                    'TFR': spec_agg.TFR, 'max_agg': max_agg,
+                    't_L2': spec_L2.t, 'f_L2': spec_L2.f,
+                    'TFR_L2': spec_L2.TFR, 't_L5': spec_L5.t,
+                    'f_L5': spec_L5.f, 'TFR_L5': spec_L5.TFR,
+                    'pgram_p': pgram.P, 'pgram_f': pgram.f}
+
+    return spec_results
+
+
+def save_spec_data(fspec, spec):
+    # Save spec results
+    print("Saving %s" % fspec)
+    np.savez_compressed(fspec, time=spec['time'], freq=spec['freq'],
+                        TFR=spec['TFR'], max_agg=spec['max_agg'],
+                        t_L2=spec['t_L2'], f_L2=spec['f_L2'],
+                        TFR_L2=spec['TFR_L2'], t_L5=spec['t_L5'],
+                        f_L5=spec['f_L5'], TFR_L5=spec['TFR_L5'],
+                        pgram_p=spec['pgram_p'], pgram_f=spec['pgram_f'])
+
+
+def plot_spec(ax, spec_data, ntrial, spec_cmap, xlim, fontsize=fontsize):
+    """Plot spectrogram"""
+
+    # calculate TFR from spec trial data
+    # start with data from the first trial, but make deepcopy
+    # we will be modifying it in place
+    spec_TFR = deepcopy(spec_data[0])
+    spec_list = [spec_data[i]['TFR'] for i in range(ntrial)]
+    spec_TFR['TFR'] = np.mean(np.array(spec_list), axis=0)
+
+    # Plot TFR data and add colorbar
+    plot = ax.imshow(spec_TFR['TFR'],
+                     extent=(spec_TFR['time'][0],
+                             spec_TFR['time'][-1],
+                             spec_TFR['freq'][-1],
+                             spec_TFR['freq'][0]),
+                     aspect='auto', origin='upper',
+                     cmap=plt.get_cmap(spec_cmap))
+    ax.set_ylabel('Frequency (Hz)', fontsize=fontsize)
+    ax.set_xlabel('Time (ms)', fontsize=fontsize)
+    ax.set_xlim(xlim)
+    ax.set_ylim(spec_TFR['freq'][-1], spec_TFR['freq'][0])
+
+    return plot
+
+
+def extract_spec(dpls, f_max_spec):
+    """Extract Morlet spectrograms from dipoles
+
+    Parameters
+    ----------
+    dpls: list of Dipole objects
+        List containing Dipoles of each trial
+    f_max_spec: float
+        Maximum frequency of analysis
+
+    Returns
+    ----------
+    specs: list of dict
+        List containing spectrogram results of each trial
+
+    """
+
+    specs = []
+    for dpltrial in dpls:
+        dt = dpltrial.times[1] - dpltrial.times[0]
+        tstop = dpltrial.times[-1]
+
+        spec_results = spec_dpl_kernel(dpltrial, f_max_spec, dt, tstop)
+        specs.append(spec_results)
+
+    return specs
diff --git a/hnn/spikefn.py b/hnn/spikefn.py
new file mode 100644
index 000000000..84ce37d52
--- /dev/null
+++ b/hnn/spikefn.py
@@ -0,0 +1,242 @@
+# spikefn.py - dealing with spikes
+#
+# v 1.10.0-py35
+# rev 2016-05-01 (SL: minor)
+# last major: (SL: toward python3)
+# 2020-12-1 BC: use hnn-core and remove old code
+
+import numpy as np
+
+
+# histogram bin optimization
+def _hist_bin_opt(x, N_trials):
+    """ Shimazaki and Shinomoto, Neural Comput, 2007 """
+
+    bin_checks = np.arange(80, 300, 10)
+    # bin_checks = np.linspace(150, 300, 16)
+    costs = np.zeros(len(bin_checks))
+    i = 0
+    # this might be vectorizable in np
+    for n_bins in bin_checks:
+        # use np.histogram to do the numerical minimization
+        pdf, bin_edges = np.histogram(x, n_bins)
+        # calculate bin width
+        # some discrepancy here but should be fine
+        w_bin = np.unique(np.diff(bin_edges))
+        if len(w_bin) > 1:
+            w_bin = w_bin[0]
+        # calc mean and var
+        kbar = np.mean(pdf)
+        kvar = np.var(pdf)
+        # calc cost
+        costs[i] = (2. * kbar - kvar) / (N_trials * w_bin)**2.
+        i += 1
+    # find the bin size corresponding to a minimization of the costs
+    bin_opt_list = bin_checks[costs.min() == costs]
+    bin_opt = bin_opt_list[0]
+    return bin_opt
+
+
+class ExtInputs(object):
+    """Class for extracting gids and times from external inputs"""
+
+    def __init__(self, spikes, gid_ranges, trials, params):
+        self.p_dict = params
+        self.gid_ranges = gid_ranges
+
+        if 'common' in self.gid_ranges:
+            # hnn-core
+            extinput_key = 'common'
+        elif 'extinput' in self.gid_ranges:
+            # hnn legacy
+            extinput_key = 'extinput'
+        elif 'bursty1' in self.gid_ranges or 'bursty2' in self.gid_ranges:
+            extinput_key = ['bursty1', 'bursty2']
+        else:
+            print(self.gid_ranges)
+            raise ValueError("Unable to find key for external inputs")
+
+        # parse evoked prox and dist input gids from gid_ranges
+        self.gid_evprox, self.gid_evdist = self._get_evokedinput_gids()
+
+        # parse ongoing prox and dist input gids from gid_ranges
+        self.gid_prox, self.gid_dist = self._get_extinput_gids(extinput_key)
+
+        # poisson input gids
+        self.gid_pois = self._get_poisinput_gids()
+
+        # self.inputs is dict of input times with keys 'prox' and 'dist'
+        self.inputs = self._get_extinput_times(spikes, trials)
+
+        self._add_delay_times()
+
+    def _get_extinput_gids(self, extinput_key):
+        """Determine if both feeds exist in this sim
+
+        If they do, self.gid_ranges[extinput_key] has length 2
+        If so, first gid is guaranteed to be prox feed, second to be dist
+        feed
+        """
+
+        if isinstance(extinput_key, list):
+            gids = list()
+            for val in extinput_key:
+                gids.extend(self.gid_ranges[val])
+            # the order here is defined by create_pext to be prox, dist
+            return gids
+        elif len(self.gid_ranges[extinput_key]) == 2:
+            return self.gid_ranges[extinput_key]
+        elif len(self.gid_ranges[extinput_key]) > 0:
+            # Otherwise, only one feed exists in this sim
+            # Must use param file to figure out which one...
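+            # a feed is only active if it starts before tstop, so check
+            # t0_input_prox / t0_input_dist against tstop below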
+            if self.p_dict['t0_input_prox'] < self.p_dict['tstop']:
+                return self.gid_ranges[extinput_key][0], None
+            elif self.p_dict['t0_input_dist'] < self.p_dict['tstop']:
+                return None, self.gid_ranges[extinput_key][0]
+            else:
+                return None, None
+
+    def _get_poisinput_gids(self):
+        """get Poisson input gids"""
+
+        gids = []
+        if len(self.gid_ranges['extpois']) > 0:
+            if self.p_dict['t0_pois'] < self.p_dict['tstop']:
+                gids = np.array(self.gid_ranges['extpois'])
+                self.pois_gid_range = (min(gids), max(gids))
+        return gids
+
+    def countevinputs(self, ty):
+        # count number of evoked inputs
+        num_inputs = 0
+        for key in self.gid_ranges.keys():
+            if key.startswith(ty) and len(self.gid_ranges[key]) > 0:
+                num_inputs += 1
+        return num_inputs
+
+    def countevprox(self):
+        return self.countevinputs('evprox')
+
+    def countevdist(self):
+        return self.countevinputs('evdist')
+
+    def _get_evokedinput_gids(self):
+        gid_prox, gid_dist = None, None
+        nprox, ndist = self.countevprox(), self.countevdist()
+
+        if nprox > 0:
+            gid_prox = []
+            for i in range(nprox):
+                if len(self.gid_ranges['evprox' + str(i + 1)]) > 0:
+                    gid_prox += list(self.gid_ranges['evprox' + str(i + 1)])
+            gid_prox = np.array(gid_prox)
+            self.evprox_gid_range = (min(gid_prox), max(gid_prox))
+        if ndist > 0:
+            gid_dist = []
+            for i in range(ndist):
+                if len(self.gid_ranges['evdist' + str(i + 1)]) > 0:
+                    gid_dist += list(self.gid_ranges['evdist' + str(i + 1)])
+            gid_dist = np.array(gid_dist)
+            self.evdist_gid_range = (min(gid_dist), max(gid_dist))
+
+        return gid_prox, gid_dist
+
+    def _filter(self, spikes, trials, filter_range):
+        """Return an array of spike times, pooled across trials, for cells
+        whose gids are in filter_range
+        """
+
+        filtered_spike_times = []
+        for trial_idx in trials:
+            indices = np.where(np.in1d(spikes.spike_gids[trial_idx],
+                                       filter_range))[0]
+            matches = np.array(spikes.spike_times[trial_idx])[indices]
+            filtered_spike_times += list(matches)
+
+        return np.array(filtered_spike_times)
+
+    def _get_times(self, spikes, trials, filter_range):
+        return self._filter(spikes, trials, filter_range)
+
+    def _unique_times(self, spikes, trials, filter_range):
+        filtered_spike_times = self._get_times(spikes, trials, filter_range)
+
+        return np.unique(filtered_spike_times)
+
+    def _get_extinput_times(self, spikes, trials):
+        """load all spike times from file"""
+
+        inputs = {k: np.array([]) for k in ['prox', 'dist', 'evprox', 'evdist',
+                                            'pois']}
+        if self.gid_prox is not None:
+            inputs['prox'] = self._get_times(spikes, trials, [self.gid_prox])
+        if self.gid_dist is not None:
+            inputs['dist'] = self._get_times(spikes, trials, [self.gid_dist])
+        if self.gid_evprox is not None:
+            inputs['evprox'] = self._unique_times(spikes, trials,
+                                                  self.gid_evprox)
+        if self.gid_evdist is not None:
+            inputs['evdist'] = self._unique_times(spikes, trials,
+                                                  self.gid_evdist)
+        if self.gid_pois is not None:
+            inputs['pois'] = self._unique_times(spikes, trials, self.gid_pois)
+
+        return inputs
+
+    def is_prox_gid(self, gid):
+        """check if gid is associated with a proximal input"""
+
+        if gid == self.gid_prox:
+            return True
+        if len(self.inputs['evprox']) > 0:
+            return self.evprox_gid_range[0] <= gid <= self.evprox_gid_range[1]
+
+        return False
+
+    def is_dist_gid(self, gid):
+        """check if gid is associated with a distal input"""
+
+        if gid == self.gid_dist:
+            return True
+        if len(self.inputs['evdist']) > 0:
+            return self.evdist_gid_range[0] <= gid <= self.evdist_gid_range[1]
+
+        return False
+
+    def is_pois_gid(self, gid):
+        """check if gid is associated with a Poisson input"""
+        if len(self.inputs['pois']) > 0:
+            return self.pois_gid_range[0] <= gid <= self.pois_gid_range[1]
+
+        return False
+
+    def _add_delay_times(self):
+        # if same prox delay to both layers, add it to the prox input times
+        if self.p_dict['input_prox_A_delay_L2'] == \
+                self.p_dict['input_prox_A_delay_L5']:
+            self.inputs['prox'] += self.p_dict['input_prox_A_delay_L2']
+
+        # if same dist delay to both layers, add it to the dist input times
+        if self.p_dict['input_dist_A_delay_L2'] == \
+                self.p_dict['input_dist_A_delay_L5']:
+            self.inputs['dist'] += self.p_dict['input_dist_A_delay_L2']
+
+    def plot_hist(self, ax, extinput, tvec, bins='auto', xlim=None,
+                  color='green', hty='bar', lw=4):
+        # extinput is either 'dist' or 'prox'
+
+        if bins == 'auto':
+            bins = _hist_bin_opt(self.inputs[extinput], 1)
+        if not xlim:
+            xlim = (0., self.p_dict['tstop'])
+        if len(self.inputs[extinput]):
+            hist = ax.hist(self.inputs[extinput], bins, range=xlim,
+                           color=color, label=extinput, histtype=hty,
+                           linewidth=lw)
+            ax.set_xticklabels([])
+            ax.tick_params(bottom=False, left=False)
+        else:
+            hist = None
+
+        return hist
diff --git a/hnn/tests/test_compare_hnn.py b/hnn/tests/test_compare_hnn.py
new file mode 100644
index 000000000..59418b38f
--- /dev/null
+++ b/hnn/tests/test_compare_hnn.py
@@ -0,0 +1,80 @@
+import os.path as op
+import sys
+from numpy import loadtxt
+from numpy.testing import assert_allclose
+
+from mne.utils import _fetch_file
+from PyQt5 import QtWidgets, QtCore
+import pytest
+
+from hnn import HNNGUI
+from hnn.paramrw import get_output_dir, get_fname
+
+
+def run_hnn(qtbot, monkeypatch):
+    # for pressing exit button
+    exit_calls = []
+    monkeypatch.setattr(QtWidgets.QApplication, "exit",
+                        lambda: exit_calls.append(1))
+
+    # skip over warning message dialogs
+    monkeypatch.setattr(QtWidgets.QMessageBox, "warning",
+                        lambda *args: QtWidgets.QMessageBox.Ok)
+    monkeypatch.setattr(QtWidgets.QMessageBox, "information",
+                        lambda *args: QtWidgets.QMessageBox.Ok)
+
+    main = HNNGUI()
+    qtbot.addWidget(main)
+
+    # start the simulation by pressing the button
+    qtbot.mouseClick(main.btnsim, QtCore.Qt.LeftButton)
+    qtbot.waitUntil(lambda: main.runningsim, 10000)
+
+    # wait up to 300 seconds for simulation to finish
+    qtbot.waitUntil(lambda: not main.runningsim, 300000)
+    qtbot.mouseClick(main.qbtn, QtCore.Qt.LeftButton)
+    assert exit_calls == [1]
+
+
+@pytest.mark.skipif(sys.platform == 'win32',
+                    reason="does not run on windows")
+def test_hnn(qtbot, monkeypatch):
+    """Test HNN can run a simulation"""
+
+    run_hnn(qtbot, monkeypatch)
+    dirname = op.join(get_output_dir(), 'data', 'default')
+    dipole_fn = get_fname(dirname, 'normdpl', 0)
+    pr = loadtxt(op.join(dirname, dipole_fn))
+    assert len(pr) > 0
+
+
+@pytest.mark.skip(reason="Skipping until #232 verification is complete")
+def test_compare_hnn(qtbot, monkeypatch):
+    """Test simulation data are consistent with master"""
+
+    # do we need to run a simulation?
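+    # only if the saved output ('normdpl' dipole and 'rawspk' spikes) from a
+    # previous run is missing from the default output directory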
+    run_sim = False
+    dirname = op.join(get_output_dir(), 'data', 'default')
+    for data_type in ['normdpl', 'rawspk']:
+        fname = get_fname(dirname, data_type, 0)
+        if not op.exists(fname):
+            run_sim = True
+            break
+
+    if run_sim:
+        run_hnn(qtbot, monkeypatch)
+
+    data_dir = ('https://raw.githubusercontent.com/jonescompneurolab/'
+                'hnn/test_data/')
+    for data_type in ['normdpl', 'rawspk']:
+        fname = get_fname(dirname, data_type, 0)
+        data_url = op.join(data_dir, fname)
+        if not op.exists(fname):
+            _fetch_file(data_url, fname)
+
+        pr = loadtxt(op.join(dirname, fname))
+        master = loadtxt(fname)
+
+        assert_allclose(pr[:, 1], master[:, 1], rtol=1e-4, atol=0)
+        assert_allclose(pr[:, 2], master[:, 2], rtol=1e-4, atol=0)
+        assert_allclose(pr[:, 3], master[:, 3], rtol=1e-4, atol=0)
diff --git a/hnn/tests/test_gui.py b/hnn/tests/test_gui.py
new file mode 100644
index 000000000..11e201f0a
--- /dev/null
+++ b/hnn/tests/test_gui.py
@@ -0,0 +1,18 @@
+import sys
+
+from PyQt5 import QtWidgets, QtCore
+import pytest
+
+from hnn import HNNGUI
+
+
+@pytest.mark.skipif(sys.platform == 'win32',
+                    reason="does not run on windows")
+def test_exit_button(qtbot, monkeypatch):
+    exit_calls = []
+    monkeypatch.setattr(QtWidgets.QApplication, "exit",
+                        lambda: exit_calls.append(1))
+    main = HNNGUI()
+    qtbot.addWidget(main)
+    qtbot.mouseClick(main.qbtn, QtCore.Qt.LeftButton)
+    assert exit_calls == [1]
diff --git a/hnn/tests/test_view_windows.py b/hnn/tests/test_view_windows.py
new file mode 100644
index 000000000..d9bfb80ee
--- /dev/null
+++ b/hnn/tests/test_view_windows.py
@@ -0,0 +1,79 @@
+import os.path as op
+
+import pytest
+from mne.utils import _fetch_file
+
+from hnn import HNNGUI
+
+
+def fetch_file(fname):
+    data_dir = ('https://raw.githubusercontent.com/jonescompneurolab/'
+                'hnn/test_data/')
+
+    data_url = op.join(data_dir, fname)
+    if not op.exists(fname):
+        _fetch_file(data_url, fname)
+
+
+@pytest.mark.skip(reason="Skipping until #232 improves launching view windows")
+def test_view_rast(qtbot):
+    """Show the spiking activity window"""
+    fname = 'spk.txt'
+    fetch_file(fname)
+
+    # start the GUI
+    main = HNNGUI()
+    qtbot.addWidget(main)
+
+    main.viewRasterAction.trigger()
+
+
+@pytest.mark.skip(reason="Skipping until #232 improves launching view windows")
+def test_view_dipole(qtbot):
+    """Show the dipole window"""
+    fname = 'dpl.txt'
+    fetch_file(fname)
+
+    # start the GUI
+    main = HNNGUI()
+    qtbot.addWidget(main)
+
+    main.viewDipoleAction.trigger()
+
+
+@pytest.mark.skip(reason="Skipping until #232 improves launching view windows")
+def test_view_psd(qtbot):
+    """Show the PSD window"""
+    fname = 'dpl.txt'
+    fetch_file(fname)
+
+    # start the GUI
+    main = HNNGUI()
+    qtbot.addWidget(main)
+
+    main.viewPSDAction.trigger()
+
+
+@pytest.mark.skip(reason="Skipping until #232 improves launching view windows")
+def test_view_spec(qtbot):
+    """Show the spectrogram window"""
+    fname = 'dpl.txt'
+    fetch_file(fname)
+
+    # start the GUI
+    main = HNNGUI()
+    qtbot.addWidget(main)
+
+    main.viewSpecAction.trigger()
+
+
+@pytest.mark.skip(reason="Skipping until #232 improves launching view windows")
+def test_view_soma(qtbot):
+    fname = 'spike.txt'
+    fetch_file(fname)
+
+    # start the GUI
+    main = HNNGUI()
+    qtbot.addWidget(main)
+
+    main.viewSomaVAction.trigger()
diff --git a/hnn_nrnui.py b/hnn_nrnui.py
deleted file mode 100644
index d3ed708c6..000000000
--- a/hnn_nrnui.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import logging
-from neuron import h
-
-from jupyter_geppetto.geppetto_comm import GeppettoCoreAPI as G
-from neuron_ui
import neuron_utils -from neuron_ui import neuron_geometries_utils - -import params_default -from paramrw import quickreadprm -from conf import fcfg,dconf -from simdat import * - -class HNN: - def __init__ (self): - logging.debug('Loading HNN') - neuron_utils.createProject(name='HNN') - import run - - self.t_vec = h.Vector() - self.t_vec.record(h._ref_t) - neuron_utils.createStateVariable(id='time', name='time', - units='ms', python_variable={"record_variable": self.t_vec, - "segment": None}) - - neuron_geometries_utils.extractGeometries() - - logging.debug('HNN loaded') - - self.RunSimButton = neuron_utils.add_button('Run', extraData={'commands':[]}) - self.StopSimButton = neuron_utils.add_button('Stop',extraData={'commands':[]}) - self.SetParams = neuron_utils.add_button('Set Parameters') - self.BasePanel = neuron_utils.add_panel('HNN Control', items=[self.RunSimButton,self.StopSimButton,self.SetParams],widget_id='hnnBasePanel') - self.BasePanel.on_close(self.close) - self.BasePanel.display() - - def close (self): - self.BasePanel.close() - del self.BasePanel - -""" -to start (first 2 commands from console) - -./NEURON-UI/NEURON-UI & -jupyter console --existing -import os -os.chdir('/u/samn/hnn/NEURON-UI/neuron_ui/models/hnn') -import hnn_nrnui -net=hnn_nrnui.HNN() - -""" - diff --git a/hnn_qt5.py b/hnn_qt5.py deleted file mode 100644 index d3c801c91..000000000 --- a/hnn_qt5.py +++ /dev/null @@ -1,4258 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox, QTextEdit, QInputDialog, QSpacerItem, QFrame, QSplitter -from PyQt5.QtGui import QIcon, QFont, QPixmap, QColor, QPainter, QFont, QPen -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot, Qt, QSize -from PyQt5.QtCore import QMetaObject, QUrl -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -import matplotlib.pyplot as plt -import multiprocessing -from subprocess import Popen, PIPE -import shlex, shutil -from collections import OrderedDict -from copy import deepcopy -from time import time, sleep -from conf import dconf -import conf -import numpy as np -from math import ceil, isclose -import spikefn -import params_default -from paramrw import quickreadprm, usingOngoingInputs, countEvokedInputs, usingEvokedInputs, ExpParams -from paramrw import chunk_evinputs, get_inputs, trans_input, find_param, validate_param_file -from simdat import SIMCanvas, getinputfiles, updatedat -from gutils import setscalegeom, lowresdisplay, setscalegeomcenter, getmplDPI, getscreengeom -import nlopt -from psutil import cpu_count, wait_procs, process_iter, NoSuchProcess -from threading import Lock -import traceback -from collections import namedtuple - -prtime = False - -def isWindows (): - # are we on windows? or linux/mac ? 
- return sys.platform.startswith('win') - -def getPyComm (): - # get the python command - Windows only has python linux/mac have python3 - if sys.executable is not None: # check python command interpreter path - if available - pyc = sys.executable - if pyc.count('python') > 0 and len(pyc) > 0: - return pyc # full path to python - if isWindows(): - return 'python' - return 'python3' - -def parseargs (): - for i in range(len(sys.argv)): - if sys.argv[i] == '-dataf' and i + 1 < len(sys.argv): - print('-dataf is ', sys.argv[i+1]) - conf.dconf['dataf'] = dconf['dataf'] = sys.argv[i+1] - i += 1 - elif sys.argv[i] == '-paramf' and i + 1 < len(sys.argv): - print('-paramf is ', sys.argv[i+1]) - conf.dconf['paramf'] = dconf['paramf'] = sys.argv[i+1] - i += 1 - -parseargs() - -simf = dconf['simf'] -paramf = dconf['paramf'] -debug = dconf['debug'] -testLFP = dconf['testlfp'] or dconf['testlaminarlfp'] -param_fname = os.path.splitext(os.path.basename(paramf)) -basedir = os.path.join(dconf['datdir'], param_fname[0]) - -# get default number of cores -defncore = 0 -hyperthreading=False - -try: - defncore = len(os.sched_getaffinity(0)) -except AttributeError: - physical_cores = cpu_count(logical=False) - logical_cores = multiprocessing.cpu_count() - - if logical_cores is not None and logical_cores > physical_cores: - hyperthreading=True - defncore = logical_cores - else: - defncore = physical_cores - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: plt.rcParams['font.size'] = dconf['fontsize'] = 10 - -if debug: print('getPyComm:',getPyComm()) - -hnn_root_dir = os.path.dirname(os.path.realpath(__file__)) - -DEFAULT_CSS = """ -QRangeSlider * { - border: 0px; - padding: 0px; -} -QRangeSlider #Head { - background-color: rgba(157, 163, 176, 50); -} -QRangeSlider #Span { - background-color: rgba(22, 31, 50, 150); -} -QRangeSlider #Span:active { - background-color: rgba(22, 31, 50, 150); -} -QRangeSlider #Tail { - background-color: rgba(157, 163, 176, 50); -} -QRangeSlider #LineBox { - background-color: rgba(255, 255, 255, 0); -} -QRangeSlider > QSplitter::handle { - background-color: rgba(79, 91, 102, 100); -} -QRangeSlider > QSplitter::handle:vertical { - height: 4px; -} -QRangeSlider > QSplitter::handle:pressed { - background: #ca5; -} -""" - -def scale(val, src, dst): - try: - return ((val - src[0]) / float(src[1]-src[0]) * (dst[1]-dst[0]) + dst[0]) - except ZeroDivisionError: - return 0 - -class Ui_Form(object): - def setupUi(self, Form): - Form.setObjectName("QRangeSlider") - Form.resize(300, 30) - Form.setStyleSheet(DEFAULT_CSS) - self._linebox = QWidget(Form) - self._linebox.setObjectName("LineBox") - self.gridLayout = QGridLayout(Form) - self.gridLayout.setContentsMargins(0, 0, 0, 0) - self.gridLayout.setSpacing(0) - self.gridLayout.setObjectName("gridLayout") - self._splitter = QSplitter(Form) - self._splitter.setMinimumSize(QSize(0, 0)) - self._splitter.setMaximumSize(QSize(16777215, 16777215)) - self._splitter.setOrientation(Qt.Horizontal) - self._splitter.setObjectName("splitter") - self._head = QGroupBox(self._splitter) - self._head.setTitle("") - self._head.setObjectName("Head") - self._handle = QGroupBox(self._splitter) - self._handle.setTitle("") - self._handle.setObjectName("Span") - self._tail = QGroupBox(self._splitter) - self._tail.setTitle("") - self._tail.setObjectName("Tail") - self.gridLayout.addWidget(self._splitter, 0, 0, 1, 1) - self.retranslateUi(Form) - QMetaObject.connectSlotsByName(Form) - - def retranslateUi(self, Form): - _translate = 
QCoreApplication.translate - Form.setWindowTitle(_translate("QRangeSlider", "QRangeSlider")) - - -class Element(QGroupBox): - def __init__(self, parent, main): - super(Element, self).__init__(parent) - self.main = main - - def setStyleSheet(self, style): - self.parent().setStyleSheet(style) - - def textColor(self): - return getattr(self, '__textColor', QColor(125, 125, 125)) - - def setTextColor(self, color): - if type(color) == tuple and len(color) == 3: - color = QColor(color[0], color[1], color[2]) - elif type(color) == int: - color = QColor(color, color, color) - setattr(self, '__textColor', color) - - def paintEvent(self, event): - qp = QPainter() - qp.begin(self) - if self.main.drawValues(): - self.drawText(event, qp) - qp.end() - - -class Head(Element): - def __init__(self, parent, main): - super(Head, self).__init__(parent, main) - - def drawText(self, event, qp): - qp.setPen(self.textColor()) - qp.setFont(QFont('Arial', 10)) - qp.drawText(event.rect(), Qt.AlignLeft, ("%.3f"%self.main.min())) - - -class Tail(Element): - def __init__(self, parent, main): - super(Tail, self).__init__(parent, main) - - def drawText(self, event, qp): - qp.setPen(self.textColor()) - qp.setFont(QFont('Arial', 10)) - qp.drawText(event.rect(), Qt.AlignRight, ("%.3f"%self.main.max())) - - -class LineBox(Element): - def __init__(self, parent, main): - super(LineBox, self).__init__(parent, main) - - def drawText(self, event, qp): - qp.setPen(QPen(Qt.red, 2, Qt.SolidLine, Qt.SquareCap, Qt.MiterJoin)) - pos = self.main.valueToPos(self.main.line_value) - if (pos == 0): - pos += 1 - qp.drawLine(pos, 0, pos, 50) - - -class Handle(Element): - def __init__(self, parent, main): - super(Handle, self).__init__(parent, main) - - def drawText(self, event, qp): - pass - # qp.setPen(self.textColor()) - # qp.setFont(QFont('Arial', 10)) - # qp.drawText(event.rect(), Qt.AlignLeft, str(self.main.start())) - # qp.drawText(event.rect(), Qt.AlignRight, str(self.main.end())) - - def mouseMoveEvent(self, event): - event.accept() - mx = event.globalX() - _mx = getattr(self, '__mx', None) - if not _mx: - setattr(self, '__mx', mx) - dx = 0 - else: - dx = mx - _mx - setattr(self, '__mx', mx) - if dx == 0: - event.ignore() - return - elif dx > 0: - dx = 1 - elif dx < 0: - dx = -1 - s = self.main.start() + dx - e = self.main.end() + dx - if s >= self.main.min() and e <= self.main.max(): - self.main.setRange(s, e) - - -class QRangeSlider(QWidget, Ui_Form): - endValueChanged = pyqtSignal(int) - maxValueChanged = pyqtSignal(int) - minValueChanged = pyqtSignal(int) - startValueChanged = pyqtSignal(int) - rangeValuesChanged = pyqtSignal(str, float, float) - - _SPLIT_START = 1 - _SPLIT_END = 2 - - def __init__(self, label, parent): - super(QRangeSlider, self).__init__(parent) - self.label = label - self.rangeValuesChanged.connect(parent.updateRangeFromSlider) - self.setupUi(self) - self.setMouseTracking(False) - self._splitter.splitterMoved.connect(self._handleMoveSplitter) - - self._linebox_layout = QHBoxLayout() - self._linebox_layout.setSpacing(0) - self._linebox_layout.setContentsMargins(0, 0, 0, 0) - self._linebox.setLayout(self._linebox_layout) - self.linebox = LineBox(self._linebox, main=self) - self._linebox_layout.addWidget(self.linebox) - self._head_layout = QHBoxLayout() - self._head_layout.setSpacing(0) - self._head_layout.setContentsMargins(0, 0, 0, 0) - self._head.setLayout(self._head_layout) - self.head = Head(self._head, main=self) - self._head_layout.addWidget(self.head) - self._handle_layout = QHBoxLayout() - 
self._handle_layout.setSpacing(0) - self._handle_layout.setContentsMargins(0, 0, 0, 0) - self._handle.setLayout(self._handle_layout) - self.handle = Handle(self._handle, main=self) - self.handle.setTextColor((150, 255, 150)) - self._handle_layout.addWidget(self.handle) - self._tail_layout = QHBoxLayout() - self._tail_layout.setSpacing(0) - self._tail_layout.setContentsMargins(0, 0, 0, 0) - self._tail.setLayout(self._tail_layout) - self.tail = Tail(self._tail, main=self) - self._tail_layout.addWidget(self.tail) - self.setDrawValues(True) - - def min(self): - return getattr(self, '__min', None) - - def max(self): - return getattr(self, '__max', None) - - def setMin(self, value): - setattr(self, '__min', value) - self.minValueChanged.emit(value) - - def setMax(self, value): - setattr(self, '__max', value) - self.maxValueChanged.emit(value) - - def start(self): - return getattr(self, '__start', None) - - def end(self): - return getattr(self, '__end', None) - - def _setStart(self, value): - setattr(self, '__start', value) - self.startValueChanged.emit(value) - - def setStart(self, value): - v = self.valueToPos(value) - self._splitter.splitterMoved.disconnect() - self._splitter.moveSplitter(v, self._SPLIT_START) - self._splitter.splitterMoved.connect(self._handleMoveSplitter) - self._setStart(value) - - def _setEnd(self, value): - setattr(self, '__end', value) - self.endValueChanged.emit(value) - - def setEnd(self, value): - v = self.valueToPos(value) - self._splitter.splitterMoved.disconnect() - self._splitter.moveSplitter(v, self._SPLIT_END) - self._splitter.splitterMoved.connect(self._handleMoveSplitter) - self._setEnd(value) - - def drawValues(self): - return getattr(self, '__drawValues', None) - - def setLine(self, value): - self.line_value = value - - def setDrawValues(self, draw): - setattr(self, '__drawValues', draw) - - def getRange(self): - return (self.start(), self.end()) - - def setRange(self, start, end): - self.setStart(start) - self.setEnd(end) - - def keyPressEvent(self, event): - key = event.key() - if key == Qt.Key_Left: - s = self.start()-1 - e = self.end()-1 - elif key == Qt.Key_Right: - s = self.start()+1 - e = self.end()+1 - else: - event.ignore() - return - event.accept() - if s >= self.min() and e <= self.max(): - self.setRange(s, e) - - def setBackgroundStyle(self, style): - self._tail.setStyleSheet(style) - self._head.setStyleSheet(style) - - def setSpanStyle(self, style): - self._handle.setStyleSheet(style) - - def valueToPos(self, value): - return int(scale(value, (self.min(), self.max()), (0, self.width()))) - - def _posToValue(self, xpos): - return scale(xpos, (0, self.width()), (self.min(), self.max())) - - def _handleMoveSplitter(self, xpos, index): - hw = self._splitter.handleWidth() - def _lockWidth(widget): - width = widget.size().width() - widget.setMinimumWidth(width) - widget.setMaximumWidth(width) - def _unlockWidth(widget): - widget.setMinimumWidth(0) - widget.setMaximumWidth(16777215) - if index == self._SPLIT_START: - v = self._posToValue(xpos) - _lockWidth(self._tail) - if v >= self.end(): - return - offset = -20 - w = xpos + offset - self._setStart(v) - self.rangeValuesChanged.emit(self.label, v, self.end()) - elif index == self._SPLIT_END: - # account for width of head - xpos += 4 - v = self._posToValue(xpos) - _lockWidth(self._head) - if v <= self.start(): - return - offset = -40 - w = self.width() - xpos + offset - self._setEnd(v) - self.rangeValuesChanged.emit(self.label, self.start(), v) - _unlockWidth(self._tail) - _unlockWidth(self._head) - 
_unlockWidth(self._handle) - -# see https://stackoverflow.com/questions/12182133/pyqt4-combine-textchanged-and-editingfinished-for-qlineedit -class MyLineEdit(QLineEdit): - textModified = pyqtSignal(str) # (label) - - def __init__(self, contents, label, parent=None): - super(MyLineEdit, self).__init__(contents, parent) - self.editingFinished.connect(self.__handleEditingFinished) - self.textChanged.connect(self.__handleTextChanged) - self._before = contents - self._label = label - - def __handleTextChanged(self, text): - if not self.hasFocus(): - self._before = text - - def __handleEditingFinished(self): - before, after = self._before, self.text() - if before != after: - self._before = after - self.textModified.emit(self._label) - -# for signaling -class Communicate (QObject): - commsig = pyqtSignal() - -class DoneSignal (QObject): - finishSim = pyqtSignal(bool, bool) - -# for signaling - passing text -class TextSignal (QObject): - tsig = pyqtSignal(str) - -# for signaling - updating GUI & param file during optimization -class ParamSignal (QObject): - psig = pyqtSignal(OrderedDict) - -class CanvSignal (QObject): - csig = pyqtSignal(bool, bool) - -def bringwintobot (win): - #win.show() - #win.lower() - win.hide() - -def kill_list_of_procs(procs): - # try terminate first - for p in procs: - try: - p.terminate() - except NoSuchProcess: - pass - gone, alive = wait_procs(procs, timeout=3) - - # now try kill - for p in alive: - p.kill() - gone, alive = wait_procs(procs, timeout=3) - - return alive - - -def get_nrniv_procs_running(): - ls = [] - name = 'nrniv' - for p in process_iter(attrs=["name", "exe", "cmdline"]): - if name == p.info['name'] or \ - p.info['exe'] and os.path.basename(p.info['exe']) == name or \ - p.info['cmdline'] and p.info['cmdline'][0] == name: - ls.append(p) - return ls - -def kill_and_check_nrniv_procs(): - procs = get_nrniv_procs_running() - if len(procs) > 0: - running = kill_list_of_procs(procs) - if len(running) > 0: - pids = [ str(proc.pid) for proc in running ] - print("ERROR: failed to kill nrniv process(es) %s" % ','.join(pids)) - -def bringwintotop (win): - # bring a pyqt5 window to the top (parents still stay behind children) - # based on examples from https://www.programcreek.com/python/example/101663/PyQt5.QtCore.Qt.WindowActive - #win.show() - #win.setWindowState(win.windowState() & ~Qt.WindowMinimized | Qt.WindowActive) - #win.raise_() - win.showNormal() - win.activateWindow() - #win.setWindowState((win.windowState() & ~Qt.WindowMinimized) | Qt.WindowActive) - #win.activateWindow() - #win.raise_() - #win.show() - -# based on https://nikolak.com/pyqt-threading-tutorial/ -class RunSimThread (QThread): - def __init__ (self,c,d,ntrial,ncore,waitsimwin,opt=False,baseparamwin=None,mainwin=None,onNSG=False): - QThread.__init__(self) - self.c = c - self.d = d - self.killed = False - self.proc = None - self.ntrial = ntrial - self.ncore = ncore - self.waitsimwin = waitsimwin - self.opt = opt - self.baseparamwin = baseparamwin - self.mainwin = mainwin - self.onNSG = onNSG - - self.txtComm = TextSignal() - self.txtComm.tsig.connect(self.waitsimwin.updatetxt) - - self.prmComm = ParamSignal() - if self.baseparamwin is not None: - self.prmComm.psig.connect(self.baseparamwin.updatesaveparams) - - self.canvComm = CanvSignal() - if self.mainwin is not None: - self.canvComm.csig.connect(self.mainwin.initSimCanvas) - - self.lock = Lock() - - def updatewaitsimwin (self, txt): - # print('RunSimThread updatewaitsimwin, txt=',txt) - self.txtComm.tsig.emit(txt) - - def 
updatebaseparamwin (self, d): - self.prmComm.psig.emit(d) - - def updatedispparam (self): - self.c.commsig.emit() - - def updatedrawerr (self): - self.canvComm.csig.emit(False, self.opt) # False means do not recalculate error - - def stop (self): - self.killproc() - - def __del__ (self): - self.quit() - self.wait() - - def run (self): - failed=False - - if self.opt and self.baseparamwin is not None: - try: - self.optmodel() # run optimization - except RuntimeError: - failed = True - self.baseparamwin.optparamwin.toggleEnableUserFields(self.cur_step, enable=True) - self.baseparamwin.optparamwin.clear_initial_opt_ranges() - self.baseparamwin.optparamwin.optimization_running = False - else: - try: - self.runsim() # run simulation - self.updatedispparam() # update params in all windows (optimization) - except RuntimeError: - failed = True - - self.d.finishSim.emit(self.opt, failed) # send the finish signal - - - def killproc (self): - if self.proc is None: - # any nrniv processes found are not part of current sim - return - - if debug: print('Thread killing sim. . .') - - # try the nice way to stop the mpiexec proc - self.proc.terminate() - - retries = 0 - while self.proc.poll() is None and retries < 5: - # mpiexec still running - self.proc.kill() - if self.proc.poll() is None: - sleep(1) - retries += 1 - - # make absolute sure all nrniv procs have been killed - kill_and_check_nrniv_procs() - - self.lock.acquire() - self.killed = True - self.lock.release() - - def spawn_sim (self, simlength, banner=False): - global paramf, hyperthreading - import simdat - - - if isWindows() or not hyperthreading: - mpicmd = 'mpiexec -np ' - else: - mpicmd = 'mpiexec --use-hwthread-cpus -np ' - - if banner: - nrniv_cmd = ' nrniv -python -mpi ' - else: - nrniv_cmd = ' nrniv -python -mpi -nobanner ' - - if self.onNSG: - cmd = 'python nsgr.py ' + paramf + ' ' + str(self.ntrial) + ' 710.0' - elif not simlength is None: - cmd = mpicmd + str(self.ncore) + nrniv_cmd + simf + ' ' + paramf + ' ntrial ' + str(self.ntrial) + ' simlength ' + str(simlength) - else: - cmd = mpicmd + str(self.ncore) + nrniv_cmd + simf + ' ' + paramf + ' ntrial ' + str(self.ntrial) - cmdargs = shlex.split(cmd,posix="win" not in sys.platform) # https://github.com/maebert/jrnl/issues/348 - if debug: print("cmd:",cmd,"cmdargs:",cmdargs) - if prtime: - self.proc = Popen(cmdargs,cwd=os.getcwd()) - else: - self.proc = Popen(cmdargs,stdout=PIPE,stderr=PIPE,cwd=os.getcwd(),universal_newlines=True) - - def get_proc_stream (self, stream, print_to_console=False): - try: - for line in iter(stream.readline, ""): - if print_to_console: - print(line.strip()) - try: # see https://stackoverflow.com/questions/2104779/qobject-qplaintextedit-multithreading-issues - self.updatewaitsimwin(line.strip()) # sends a pyqtsignal to waitsimwin, which updates its textedit - except: - if debug: print('RunSimThread updatewaitsimwin exception...') - pass # catch exception in case anything else goes wrong - except ValueError: - # if process is killed and stream.readline() gives I/O error - pass - stream.close() - - # run sim command via mpi, then delete the temp file. 
- def runsim (self, is_opt=False, banner=True, simlength=None): - import simdat - - global defncore, paramf, hyperthreading - self.lock.acquire() - self.killed = False - self.lock.release() - - self.spawn_sim(simlength, banner=banner) - retried = False - - #cstart = time() - while True: - status = self.proc.poll() - if not status is None: - if status == 0: - # success, use same number of cores next time - defncore = self.ncore - break - elif status == 1 and not retried: - self.ncore = ceil(self.ncore/2) - txt = "INFO: Failed starting mpiexec, retrying with %d cores" % self.ncore - print(txt) - self.updatewaitsimwin(txt) - self.spawn_sim(simlength, banner=banner) - retried = True - else: - txt = "Simulation exited with return code %d. Stderr from console:"%status - print(txt) - self.updatewaitsimwin(txt) - self.get_proc_stream(self.proc.stderr, print_to_console=True) - kill_and_check_nrniv_procs() - raise RuntimeError - - self.get_proc_stream(self.proc.stdout, print_to_console=False) - - # check if proc was killed - self.lock.acquire() - if self.killed: - self.lock.release() - # exit using RuntimeError - raise RuntimeError - else: - self.lock.release() - - sleep(1) - - #cend = time() - #rtime = cend - cstart - #if debug: print('sim finished in %.3f s'%rtime) - - try: - simdat.updatedat(paramf) - except ValueError: - print("Warning: failed to load simulation results for %s" % paramf) - - if not is_opt and 'dpl' in simdat.ddat: - simdat.updatelsimdat(paramf,simdat.ddat['dpl']) # update lsimdat and its current sim index - - - def optmodel (self): - import simdat - - global basedir - need_initial_ddat = False - - # initialize RNG with seed from config - seed = int(find_param(paramf,'prng_seedcore_opt')) - nlopt.srand(seed) - - # initial_ddat stores the initial fit (from "Run Simulation"). - # To be displayed in final dipole plot as black dashed line. - if len(simdat.ddat) > 0: - simdat.initial_ddat['dpl'] = deepcopy(simdat.ddat['dpl']) - simdat.initial_ddat['errtot'] = deepcopy(simdat.ddat['errtot']) - else: - need_initial_ddat = True - - self.baseparamwin.optparamwin.populate_initial_opt_ranges() - - # save initial parameters file - param_out = os.path.join(basedir,'before_opt.param') - shutil.copyfile(paramf, param_out) - - self.updatewaitsimwin('Optimizing model. . 
.') - - self.last_step = False - self.first_step = True - num_steps = self.baseparamwin.optparamwin.get_num_chunks() - for step in range(num_steps): - self.cur_step = step - if step == num_steps - 1: - self.last_step = True - - # disable range sliders for each step once that step has begun - self.baseparamwin.optparamwin.toggleEnableUserFields(step, enable=False) - - self.step_ranges = self.baseparamwin.optparamwin.get_chunk_ranges(step) - self.step_sims = self.baseparamwin.optparamwin.get_sims_for_chunk(step) - - if self.step_sims == 0: - txt = "Skipping optimization step %d (0 simulations)"%(step+1) - self.updatewaitsimwin(txt) - continue - - if len(self.step_ranges) == 0: - txt = "Skipping optimization step %d (0 parameters)"%(step+1) - self.updatewaitsimwin(txt) - continue - - txt = "Starting optimization step %d/%d" % (step + 1, num_steps) - self.updatewaitsimwin(txt) - self.runOptStep(step) - - if 'dpl' in self.best_ddat: - simdat.ddat['dpl'] = deepcopy(self.best_ddat['dpl']) - if 'errtot' in self.best_ddat: - simdat.ddat['errtot'] = deepcopy(self.best_ddat['errtot']) - - if need_initial_ddat: - simdat.initial_ddat = deepcopy(simdat.ddat) - - simdat.updateoptdat(paramf,simdat.ddat['dpl']) # update optdat with best from this step - - # put best opt results into GUI and save to param file - push_values = OrderedDict() - for param_name in self.step_ranges.keys(): - push_values[param_name] = self.step_ranges[param_name]['final'] - self.updatebaseparamwin(push_values) - self.baseparamwin.optparamwin.push_chunk_ranges(step,push_values) - - sleep(1) - - self.first_step = False - - # one final sim with the best parameters to update display - self.runsim(is_opt=True, banner=False) - simdat.updatelsimdat(paramf,simdat.ddat['dpl']) # update lsimdat and its current sim index - simdat.updateoptdat(paramf,simdat.ddat['dpl']) # update optdat with the final best - - # re-enable all the range sliders - self.baseparamwin.optparamwin.toggleEnableUserFields(step, enable=True) - - self.baseparamwin.optparamwin.clear_initial_opt_ranges() - self.baseparamwin.optparamwin.optimization_running = False - - - def runOptStep (self, step): - import simdat - global basedir, paramf - - self.optsim = 0 - self.minopterr = 1e9 - self.stepminopterr = self.minopterr - self.best_ddat = {} - self.opt_start = self.baseparamwin.optparamwin.get_chunk_start(step) - self.opt_end = self.baseparamwin.optparamwin.get_chunk_end(step) - self.opt_weights = self.baseparamwin.optparamwin.get_chunk_weights(step) - def optrun (new_params, grad=0): - txt = "Optimization step %d, simulation %d" % (step + 1, - self.optsim + 1) - self.updatewaitsimwin(txt) - print(txt) - - dtest = OrderedDict() # parameter values to test - for param_name, test_value in zip(self.step_ranges.keys(), new_params): # set parameters - if test_value >= self.step_ranges[param_name]['minval'] and \ - test_value <= self.step_ranges[param_name]['maxval']: - if debug: - print('optrun prm:', self.step_ranges[param_name]['initial'], - self.step_ranges[param_name]['minval'], - self.step_ranges[param_name]['maxval'], - test_value) - dtest[param_name] = test_value - else: - # This test is not strictly necessary with COBYLA, but in case the algorithm - # is changed at some point in the future - print('INFO: optimization chose %.3f for %s outside of [%.3f-%.3f].' 
- % (test_value, param_name, - self.step_ranges[param_name]['minval'], - self.step_ranges[param_name]['maxval'])) - return 1e9 # invalid param value -> large error - - # put new param values into GUI and save params to file - self.updatebaseparamwin(dtest) - sleep(1) - - # run the simulation, but stop early if possible - self.runsim(is_opt=True, banner=False, simlength=self.opt_end) - - # calculate wRMSE for all steps - simdat.weighted_rmse(simdat.ddat, - self.opt_end, - self.opt_weights, - tstart=self.opt_start) - err = simdat.ddat['werrtot'] - - if self.last_step: - # weighted RMSE with weights of all 1's is the same as - # regular RMSE - simdat.ddat['errtot'] = simdat.ddat['werrtot'] - txt = "RMSE = %f"%err - else: - # calculate regular RMSE for displaying on plot - simdat.calcerr(simdat.ddat, - self.opt_end, - tstart=self.opt_start) - - txt = "weighted RMSE = %f, RMSE = %f"% (err,simdat.ddat['errtot']) - - print(txt) - self.updatewaitsimwin(os.linesep+'Simulation finished: ' + txt + os.linesep) # print error - - fnoptinf = os.path.join(basedir,'optinf.txt') - with open(fnoptinf,'a') as fpopt: - fpopt.write(str(simdat.ddat['errtot'])+os.linesep) # write error - - # save copy param file - param_fname = os.path.basename(paramf) - curr_paramf = os.path.join(basedir, param_fname) - - # save params numbered by optsim - param_out = os.path.join(basedir,'step_%d_sim_%d.param'%(self.cur_step,self.optsim)) - shutil.copyfile(curr_paramf, param_out) - - if err < self.stepminopterr: - self.updatewaitsimwin("new best with RMSE %f"%err) - - self.stepminopterr = err - # save best param file - shutil.copyfile(curr_paramf, os.path.join(basedir,'step_%d_best.param'%self.cur_step)) # convenience, save best here - if 'dpl' in simdat.ddat: - self.best_ddat['dpl'] = simdat.ddat['dpl'] - if 'errtot' in simdat.ddat: - self.best_ddat['errtot'] = simdat.ddat['errtot'] - - if self.optsim == 0 and not self.first_step: - # Update plots for the first simulation only of this step (best results from last round) - # Skip the first step because there are no optimization results to show yet. 
- self.updatedrawerr() # send event to draw updated error (asynchronously) - - self.optsim += 1 - - return err # return error - - def optimize(params_input, evals, algorithm): - opt_params = [] - lb = [] - ub = [] - - for param_name in params_input.keys(): - upper = params_input[param_name]['maxval'] - lower = params_input[param_name]['minval'] - if upper == lower: - continue - - ub.append(upper) - lb.append(lower) - opt_params.append(params_input[param_name]['initial']) - - if algorithm == nlopt.G_MLSL_LDS or algorithm == nlopt.G_MLSL: - # In case these mixed mode (global + local) algorithms are used in the future - local_opt = nlopt.opt(nlopt.LN_COBYLA, num_params) - opt.set_local_optimizer(local_opt) - - opt.set_lower_bounds(lb) - opt.set_upper_bounds(ub) - opt.set_min_objective(optrun) - opt.set_xtol_rel(1e-4) - opt.set_maxeval(evals) - opt_results = opt.optimize(opt_params) - - return opt_results - - txt = 'Optimizing from [%3.3f-%3.3f] ms' % (self.opt_start, - self.opt_end) - self.updatewaitsimwin(txt) - - num_params = len(self.step_ranges) - algorithm = nlopt.LN_COBYLA - opt = nlopt.opt(algorithm, num_params) - opt_results = optimize(self.step_ranges, self.step_sims, algorithm) - - # update opt params for the next round - for var_name, new_value in zip(self.step_ranges, opt_results): - old_value = self.step_ranges[var_name]['initial'] - - # only change the parameter value if it changed significantly - if not isclose(old_value, new_value, abs_tol=1e-9): - self.step_ranges[var_name]['final'] = new_value - else: - self.step_ranges[var_name]['final'] = \ - self.step_ranges[var_name]['initial'] - -# look up resource adjusted for screen resolution -def lookupresource (fn): - lowres = lowresdisplay() # low resolution display - if lowres: - return os.path.join('res',fn+'2.png') - else: - return os.path.join('res',fn+'.png') - -def format_range_str(value): - if value == 0: - value_str = "0.000" - elif value < 0.1 : - value_str = ("%6f" % value) - else: - value_str = ("%.3f" % value) - - return value_str - -# DictDialog - dictionary-based dialog with tabs - should make all dialogs -# specifiable via cfg file format - then can customize gui without changing py code -# and can reduce code explosion / overlap between dialogs -class DictDialog (QDialog): - - def __init__ (self, parent, din): - super(DictDialog, self).__init__(parent) - self.ldict = [] # subclasses should override - self.ltitle = [] - self.dtransvar = {} # for translating model variable name to more human-readable form - self.stitle = '' - self.initd() - self.initUI() - self.initExtra() - self.setfromdin(din) # set values from input dictionary - self.addtips() - - def addtips (self): - for ktip in dconf.keys(): - if ktip in self.dqline: - self.dqline[ktip].setToolTip(dconf[ktip]) - elif ktip in self.dqextra: - self.dqextra[ktip].setToolTip(dconf[ktip]) - - def __str__ (self): - s = '' - for k,v in self.dqline.items(): s += k + ': ' + v.text().strip() + os.linesep - return s - - def saveparams (self): self.hide() - - def initd (self): pass # implemented in subclass - - def getval (self,k): - if k in self.dqline.keys(): - return self.dqline[k].text().strip() - - def lines2val (self,ksearch,val): - for k in self.dqline.keys(): - if k.count(ksearch) > 0: - self.dqline[k].setText(str(val)) - - def setfromdin (self,din): - if not din: return - for k,v in din.items(): - if k in self.dqline: - self.dqline[k].setText(str(v).strip()) - - def transvar (self,k): - if k in self.dtransvar: return self.dtransvar[k] - return k - - def addtransvar 
(self,k,strans): - self.dtransvar[k] = strans - self.dtransvar[strans] = k - - def initExtra (self): self.dqextra = OrderedDict() # extra items not written to param file - - def initUI (self): - self.layout = QVBoxLayout(self) - - # Add stretch to separate the form layout from the button - self.layout.addStretch(1) - - # Initialize tab screen - self.ltabs = [] - self.tabs = QTabWidget(); self.layout.addWidget(self.tabs) - - for i in range(len(self.ldict)): self.ltabs.append(QWidget()) - - self.tabs.resize(575,200) - - # create tabs and their layouts - for tab,s in zip(self.ltabs,self.ltitle): - self.tabs.addTab(tab,s) - tab.layout = QFormLayout() - tab.setLayout(tab.layout) - - self.dqline = OrderedDict() # QLineEdits dict; key is model variable - for d,tab in zip(self.ldict, self.ltabs): - for k,v in d.items(): - self.dqline[k] = QLineEdit(self) - self.dqline[k].setText(str(v)) - tab.layout.addRow(self.transvar(k),self.dqline[k]) # adds label,QLineEdit to the tab - - # Add tabs to widget - self.layout.addWidget(self.tabs) - self.setLayout(self.layout) - self.setWindowTitle(self.stitle) - #nw, nh = setscalegeom(self, 150, 150, 625, 300) - #nx = parent.rect().x()+parent.rect().width()/2-nw/2 - #ny = parent.rect().y()+parent.rect().height()/2-nh/2 - #print(parent.rect(),nx,ny) - #self.move(nx, ny) - #self.move(self.parent. - #self.move(self.parent.widget.rect().x+self.parent.widget.rect().width()/2-nw, - # self.parent.widget.rect().y+self.parent.widget.rect().height()/2-nh) - - def TurnOff (self): pass - - def addOffButton (self): - # Create a horizontal box layout to hold the button - self.button_box = QHBoxLayout() - self.btnoff = QPushButton('Turn Off Inputs',self) - self.btnoff.resize(self.btnoff.sizeHint()) - self.btnoff.clicked.connect(self.TurnOff) - self.btnoff.setToolTip('Turn Off Inputs') - self.button_box.addWidget(self.btnoff) - self.layout.addLayout(self.button_box) - - def addHideButton (self): - self.bbhidebox = QHBoxLayout() - self.btnhide = QPushButton('Hide Window',self) - self.btnhide.resize(self.btnhide.sizeHint()) - self.btnhide.clicked.connect(self.hide) - self.btnhide.setToolTip('Hide Window') - self.bbhidebox.addWidget(self.btnhide) - self.layout.addLayout(self.bbhidebox) - -# widget to specify ongoing input params (proximal, distal) -class OngoingInputParamDialog (DictDialog): - def __init__ (self, parent, inty, din=None): - self.inty = inty - if self.inty.startswith('Proximal'): - self.prefix = 'input_prox_A_' - self.postfix = '_prox' - self.isprox = True - else: - self.prefix = 'input_dist_A_' - self.postfix = '_dist' - self.isprox = False - super(OngoingInputParamDialog, self).__init__(parent,din) - self.addOffButton() - self.addImages() - self.addHideButton() - - # add png cartoons to tabs - def addImages (self): - if self.isprox: self.pix = QPixmap(lookupresource('proxfig')) - else: self.pix = QPixmap(lookupresource('distfig')) - for tab in self.ltabs: - pixlbl = ClickLabel(self) - pixlbl.setPixmap(self.pix) - tab.layout.addRow(pixlbl) - - - # turn off by setting all weights to 0.0 - def TurnOff (self): self.lines2val('weight',0.0) - - def initd (self): - self.dtiming = OrderedDict([#('distribution' + self.postfix, 'normal'), - ('t0_input' + self.postfix, 1000.), - ('t0_input_stdev' + self.postfix, 0.), - ('tstop_input' + self.postfix, 250.), - ('f_input' + self.postfix, 10.), - ('f_stdev' + self.postfix, 20.), - ('events_per_cycle' + self.postfix, 2), - ('repeats' + self.postfix, 10)]) - - self.dL2 = OrderedDict([(self.prefix + 'weight_L2Pyr_ampa', 0.), - 
(self.prefix + 'weight_L2Pyr_nmda', 0.), - (self.prefix + 'weight_L2Basket_ampa', 0.), - (self.prefix + 'weight_L2Basket_nmda',0.), - (self.prefix + 'delay_L2', 0.1),]) - - self.dL5 = OrderedDict([(self.prefix + 'weight_L5Pyr_ampa', 0.), - (self.prefix + 'weight_L5Pyr_nmda', 0.)]) - - if self.isprox: - self.dL5[self.prefix + 'weight_L5Basket_ampa'] = 0.0 - self.dL5[self.prefix + 'weight_L5Basket_nmda'] = 0.0 - self.dL5[self.prefix + 'delay_L5'] = 0.1 - - self.ldict = [self.dtiming, self.dL2, self.dL5] - self.ltitle = ['Timing', 'Layer 2/3', 'Layer 5'] - self.stitle = 'Set Rhythmic '+self.inty+' Inputs' - - dtmp = {'L2':'L2/3 ','L5':'L5 '} - for d in [self.dL2, self.dL5]: - for k in d.keys(): - lk = k.split('_') - if k.count('weight') > 0: - self.addtransvar(k, dtmp[lk[-2][0:2]] + lk[-2][2:]+' '+lk[-1].upper()+u' weight (µS)') - else: - self.addtransvar(k, 'Delay (ms)') - - #self.addtransvar('distribution'+self.postfix,'Distribution') - self.addtransvar('t0_input'+self.postfix,'Start time mean (ms)') - self.addtransvar('t0_input_stdev'+self.postfix,'Start time stdev (ms)') - self.addtransvar('tstop_input'+self.postfix,'Stop time (ms)') - self.addtransvar('f_input'+self.postfix,'Burst frequency (Hz)') - self.addtransvar('f_stdev'+self.postfix,'Burst stdev (ms)') - self.addtransvar('events_per_cycle'+self.postfix,'Spikes/burst') - self.addtransvar('repeats'+self.postfix,'Number bursts') - -class EvokedOrRhythmicDialog (QDialog): - def __init__ (self, parent, distal, evwin, rhythwin): - super(EvokedOrRhythmicDialog, self).__init__(parent) - if distal: self.prefix = 'Distal' - else: self.prefix = 'Proximal' - self.evwin = evwin - self.rhythwin = rhythwin - self.initUI() - - def initUI (self): - self.layout = QVBoxLayout(self) - # Add stretch to separate the form layout from the button - self.layout.addStretch(1) - - self.btnrhythmic = QPushButton('Rhythmic ' + self.prefix + ' Inputs',self) - self.btnrhythmic.resize(self.btnrhythmic.sizeHint()) - self.btnrhythmic.clicked.connect(self.showrhythmicwin) - self.layout.addWidget(self.btnrhythmic) - - self.btnevoked = QPushButton('Evoked Inputs',self) - self.btnevoked.resize(self.btnevoked.sizeHint()) - self.btnevoked.clicked.connect(self.showevokedwin) - self.layout.addWidget(self.btnevoked) - - self.addHideButton() - - setscalegeom(self, 150, 150, 270, 120) - self.setWindowTitle("Pick Input Type") - - def showevokedwin (self): - bringwintotop(self.evwin) - self.hide() - - def showrhythmicwin (self): - bringwintotop(self.rhythwin) - self.hide() - - def addHideButton (self): - self.bbhidebox = QHBoxLayout() - self.btnhide = QPushButton('Hide Window',self) - self.btnhide.resize(self.btnhide.sizeHint()) - self.btnhide.clicked.connect(self.hide) - self.btnhide.setToolTip('Hide Window') - self.bbhidebox.addWidget(self.btnhide) - self.layout.addLayout(self.bbhidebox) - - -class SynGainParamDialog (QDialog): - def __init__ (self, parent, netparamwin): - super(SynGainParamDialog, self).__init__(parent) - self.netparamwin = netparamwin - self.initUI() - - def scalegain (self, k, fctr): - oldval = float(self.netparamwin.dqline[k].text().strip()) - newval = oldval * fctr - self.netparamwin.dqline[k].setText(str(newval)) - if debug: print('scaling ',k,' by', fctr, 'from ',oldval,'to ',newval,'=',oldval*fctr) - return newval - - def isE (self,ty): return ty.count('Pyr') > 0 - def isI (self,ty): return ty.count('Basket') > 0 - - def tounity (self): - for k in self.dqle.keys(): self.dqle[k].setText('1.0') - - def scalegains (self): - if debug: print('scaling 
synaptic gains') - for i,k in enumerate(self.dqle.keys()): - fctr = float(self.dqle[k].text().strip()) - if fctr < 0.: - fctr = 0. - self.dqle[k].setText(str(fctr)) - elif fctr == 1.0: - continue - if debug: print(k,fctr) - for k2 in self.netparamwin.dqline.keys(): - l = k2.split('_') - ty1,ty2 = l[1],l[2] - if self.isE(ty1) and self.isE(ty2) and k == 'E -> E': - self.scalegain(k2,fctr) - elif self.isE(ty1) and self.isI(ty2) and k == 'E -> I': - self.scalegain(k2,fctr) - elif self.isI(ty1) and self.isE(ty2) and k == 'I -> E': - self.scalegain(k2,fctr) - elif self.isI(ty1) and self.isI(ty2) and k == 'I -> I': - self.scalegain(k2,fctr) - self.tounity() # go back to unity since pressed OK - next call to this dialog will reset new values - self.hide() - - def initUI (self): - grid = QGridLayout() - grid.setSpacing(10) - - self.dqle = OrderedDict() - for row,k in enumerate(['E -> E', 'E -> I', 'I -> E', 'I -> I']): - lbl = QLabel(self) - lbl.setText(k) - lbl.adjustSize() - grid.addWidget(lbl,row, 0) - qle = QLineEdit(self) - qle.setText('1.0') - grid.addWidget(qle,row, 1) - self.dqle[k] = qle - - row += 1 - self.btnok = QPushButton('OK',self) - self.btnok.resize(self.btnok.sizeHint()) - self.btnok.clicked.connect(self.scalegains) - grid.addWidget(self.btnok, row, 0, 1, 1) - self.btncancel = QPushButton('Cancel',self) - self.btncancel.resize(self.btncancel.sizeHint()) - self.btncancel.clicked.connect(self.hide) - grid.addWidget(self.btncancel, row, 1, 1, 1) - - self.setLayout(grid) - setscalegeom(self, 150, 150, 270, 180) - self.setWindowTitle("Synaptic Gains") - -# widget to specify tonic inputs -class TonicInputParamDialog (DictDialog): - def __init__ (self, parent, din): - super(TonicInputParamDialog, self).__init__(parent,din) - self.addOffButton() - self.addHideButton() - - # turn off by setting all weights to 0.0 - def TurnOff (self): self.lines2val('A',0.0) - - def initd (self): - - self.dL2 = OrderedDict([ - # IClamp params for L2Pyr - ('Itonic_A_L2Pyr_soma', 0.), - ('Itonic_t0_L2Pyr_soma', 0.), - ('Itonic_T_L2Pyr_soma', -1.), - # IClamp param for L2Basket - ('Itonic_A_L2Basket', 0.), - ('Itonic_t0_L2Basket', 0.), - ('Itonic_T_L2Basket', -1.)]) - - self.dL5 = OrderedDict([ - # IClamp params for L5Pyr - ('Itonic_A_L5Pyr_soma', 0.), - ('Itonic_t0_L5Pyr_soma', 0.), - ('Itonic_T_L5Pyr_soma', -1.), - # IClamp param for L5Basket - ('Itonic_A_L5Basket', 0.), - ('Itonic_t0_L5Basket', 0.), - ('Itonic_T_L5Basket', -1.)]) - - dtmp = {'L2':'L2/3 ','L5':'L5 '} # temporary dictionary for string translation - for d in [self.dL2, self.dL5]: - for k in d.keys(): - cty = k.split('_')[2] # cell type - tcty = dtmp[cty[0:2]] + cty[2:] # translated cell type - if k.count('A') > 0: - self.addtransvar(k, tcty + ' amplitude (nA)') - elif k.count('t0') > 0: - self.addtransvar(k, tcty + ' start time (ms)') - elif k.count('T') > 0: - self.addtransvar(k, tcty + ' stop time (ms)') - - self.ldict = [self.dL2, self.dL5] - self.ltitle = ['Layer 2/3', 'Layer 5'] - self.stitle = 'Set Tonic Inputs' - -# widget to specify ongoing poisson inputs -class PoissonInputParamDialog (DictDialog): - def __init__ (self, parent, din): - super(PoissonInputParamDialog, self).__init__(parent,din) - self.addOffButton() - self.addHideButton() - - # turn off by setting all weights to 0.0 - def TurnOff (self): self.lines2val('weight',0.0) - - def initd (self): - - self.dL2,self.dL5 = OrderedDict(),OrderedDict() - ld = [self.dL2,self.dL5] - - for i,lyr in enumerate(['L2','L5']): - d = ld[i] - for ty in ['Pyr', 'Basket']: - for sy in 
['ampa','nmda']: d[lyr+ty+'_Pois_A_weight'+'_'+sy]=0. - d[lyr+ty+'_Pois_lamtha']=0. - - self.dtiming = OrderedDict([('t0_pois', 0.), - ('T_pois', -1)]) - - self.addtransvar('t0_pois','Start time (ms)') - self.addtransvar('T_pois','Stop time (ms)') - - dtmp = {'L2':'L2/3 ','L5':'L5 '} # temporary dictionary for string translation - for d in [self.dL2, self.dL5]: - for k in d.keys(): - ks = k.split('_') - cty = ks[0] # cell type - tcty = dtmp[cty[0:2]] + cty[2:] # translated cell type - if k.count('weight'): - self.addtransvar(k, tcty+ ' ' + ks[-1].upper() + u' weight (µS)') - elif k.endswith('lamtha'): - self.addtransvar(k, tcty+ ' freq (Hz)') - - self.ldict = [self.dL2, self.dL5, self.dtiming] - self.ltitle = ['Layer 2/3', 'Layer 5', 'Timing'] - self.stitle = 'Set Poisson Inputs' - -# evoked input param dialog (allows adding/removing arbitrary number of evoked inputs) -class EvokedInputParamDialog (QDialog): - def __init__ (self, parent, din): - super(EvokedInputParamDialog, self).__init__(parent) - self.nprox = self.ndist = 0 # number of proximal,distal inputs - self.ld = [] # list of dictionaries for proximal/distal inputs - self.dqline = OrderedDict() - self.dtransvar = {} # for translating model variable name to more human-readable form - self.initUI() - self.setfromdin(din) - - def addtips (self): - for ktip in dconf.keys(): - if ktip in self.dqline: - self.dqline[ktip].setToolTip(dconf[ktip]) - - def transvar (self,k): - if k in self.dtransvar: return self.dtransvar[k] - return k - - def addtransvar (self,k,strans): - self.dtransvar[k] = strans - self.dtransvar[strans] = k - - def set_qline_float (self, key_str, value): - try: - new_value = float(value) - except ValueError: - print("WARN: bad value for param %s: %s. Unable to convert" - " to a floating point number" % (key_str, value)) - return - - # Enforce no sci. not. + limit field len + remove trailing 0's - self.dqline[key_str].setText(("%7f" % new_value).rstrip('0').rstrip('.')) - - def setfromdin (self,din): - if not din: return - - if 'dt' in din: - - # Optimization feature introduces the case where din just contains optimization - # relevant parameters. In that case, we don't want to remove all inputs, just - # modify existing inputs. - self.removeAllInputs() # turn off any previously set inputs - - nprox, ndist = countEvokedInputs(din) - for i in range(nprox+ndist): - if i % 2 == 0: - if self.nprox < nprox: - self.addProx() - elif self.ndist < ndist: - self.addDist() - else: - if self.ndist < ndist: - self.addDist() - elif self.nprox < nprox: - self.addProx() - - for k,v in din.items(): - if k == 'sync_evinput': - try: - new_value = bool(int(v)) - except ValueError: - print("WARN: bad value for param %s: %s. Unable to convert" - " to a boolean value" % (k,v)) - continue - if new_value: - self.chksync.setChecked(True) - else: - self.chksync.setChecked(False) - elif k == 'inc_evinput': - try: - new_value = float(v) - except ValueError: - print("WARN: bad value for param %s: %s. Unable to convert" - " to a floating point number" % (k,v)) - continue - self.incedit.setText(str(new_value).strip()) - elif k in self.dqline: - if k.startswith('numspikes'): - try: - new_value = int(v) - except ValueError: - print("WARN: bad value for param %s: %s. 
Unable to convert" - " to a integer" % (k, v)) - continue - self.dqline[k].setText(str(new_value)) - else: - self.set_qline_float(k, v) - elif k.count('gbar') > 0 and \ - (k.count('evprox') > 0 or \ - k.count('evdist') > 0): - # NOTE: will be deprecated in future release - # for back-compat with old-style specification which didn't have ampa,nmda in evoked gbar - lks = k.split('_') - eloc = lks[1] - enum = lks[2] - base_key_str = 'gbar_' + eloc + '_' + enum + '_' - if eloc == 'evprox': - for ct in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: - # ORIGINAL MODEL/PARAM: only ampa for prox evoked inputs - key_str = base_key_str + ct + '_ampa' - self.set_qline_float(key_str, v) - elif eloc == 'evdist': - for ct in ['L2Pyr', 'L2Basket', 'L5Pyr']: - # ORIGINAL MODEL/PARAM: both ampa and nmda for distal evoked inputs - key_str = base_key_str + ct + '_ampa' - self.set_qline_float(key_str, v) - key_str = base_key_str + ct + '_nmda' - self.set_qline_float(key_str, v) - - def initUI (self): - self.layout = QVBoxLayout(self) - - # Add stretch to separate the form layout from the button - self.layout.addStretch(1) - - self.ltabs = [] - self.tabs = QTabWidget() - self.layout.addWidget(self.tabs) - - self.button_box = QVBoxLayout() - self.btnprox = QPushButton('Add Proximal Input',self) - self.btnprox.resize(self.btnprox.sizeHint()) - self.btnprox.clicked.connect(self.addProx) - self.btnprox.setToolTip('Add Proximal Input') - self.button_box.addWidget(self.btnprox) - - self.btndist = QPushButton('Add Distal Input',self) - self.btndist.resize(self.btndist.sizeHint()) - self.btndist.clicked.connect(self.addDist) - self.btndist.setToolTip('Add Distal Input') - self.button_box.addWidget(self.btndist) - - self.chksync = QCheckBox('Synchronous Inputs',self) - self.chksync.resize(self.chksync.sizeHint()) - self.chksync.setChecked(True) - self.button_box.addWidget(self.chksync) - - self.incbox = QHBoxLayout() - self.inclabel = QLabel(self) - self.inclabel.setText('Increment start time (ms)') - self.inclabel.adjustSize() - self.inclabel.setToolTip('Increment mean evoked input start time(s) by this amount on each trial.') - self.incedit = QLineEdit(self) - self.incedit.setText('0.0') - self.incbox.addWidget(self.inclabel) - self.incbox.addWidget(self.incedit) - - self.layout.addLayout(self.button_box) - self.layout.addLayout(self.incbox) - - self.tabs.resize(425,200) - - # Add tabs to widget - self.layout.addWidget(self.tabs) - self.setLayout(self.layout) - - self.setWindowTitle('Evoked Inputs') - - self.addRemoveInputButton() - self.addHideButton() - self.addtips() - - def lines2val (self,ksearch,val): - for k in self.dqline.keys(): - if k.count(ksearch) > 0: - self.dqline[k].setText(str(val)) - - def allOff (self): self.lines2val('gbar',0.0) - - def removeAllInputs (self): - for i in range(len(self.ltabs)): self.removeCurrentInput() - self.nprox = self.ndist = 0 - - def IsProx (self,idx): - # is this evoked input proximal (True) or distal (False) ? 
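The helpers that follow answer that question purely from the parameter key names (e.g. 't_evprox_1', 'sigma_t_evdist_2', 'gbar_evprox_1_L2Pyr_ampa'). A minimal sketch of that naming convention, using the same regex this file applies later in changeParamEnabledStatus() and updateRange(); the helper name here is illustrative and not part of the removed file:

import re

def input_location_and_id(key):
    # evoked-input keys encode location (evprox/evdist) and a 1-based index
    m = re.search(r'(evprox|evdist)_([0-9]+)', key)
    return (m.group(1), int(m.group(2))) if m else (None, -1)

print(input_location_and_id('gbar_evdist_2_L2Pyr_nmda'))  # ('evdist', 2)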
- try: - d = self.ld[idx] - for k in d.keys(): - if k.count('evprox'): - return True - except: - pass - return False - - def getInputID (self,idx): - # get evoked input number of the evoked input associated with idx - try: - d = self.ld[idx] - for k in d.keys(): - lk = k.split('_') - if len(lk) >= 3: - return int(lk[2]) - except: - pass - return -1 - - def downShift (self,idx): - # downshift the evoked input ID, keys, values - d = self.ld[idx] - dnew = {} # new dictionary - newidx = 0 # new evoked input ID - for k,v in d.items(): - lk = k.split('_') - if len(lk) >= 3: - if lk[0]=='sigma': - newidx = int(lk[3])-1 - lk[3] = str(newidx) - else: - newidx = int(lk[2])-1 - lk[2] = str(newidx) - newkey = '_'.join(lk) - dnew[newkey] = v - if k in self.dqline: - self.dqline[newkey] = self.dqline[k] - del self.dqline[k] - self.ld[idx] = dnew - currtxt = self.tabs.tabText(idx) - newtxt = currtxt.split(' ')[0] + ' ' + str(newidx) - self.tabs.setTabText(idx,newtxt) - # print('d original:',d, 'd new:',dnew) - - def removeInput (self,idx): - # remove the evoked input specified by idx - if idx < 0 or idx > len(self.ltabs): return - # print('removing input at index', idx) - self.tabs.removeTab(idx) - tab = self.ltabs[idx] - self.ltabs.remove(tab) - d = self.ld[idx] - - isprox = self.IsProx(idx) # is it a proximal input? - isdist = not isprox # is it a distal input? - inputID = self.getInputID(idx) # wht's the proximal/distal input number? - - # print('isprox,isdist,inputid',isprox,isdist,inputID) - - for k in d.keys(): - if k in self.dqline: - del self.dqline[k] - self.ld.remove(d) - tab.setParent(None) - - # now downshift the evoked inputs (only proximal or only distal) that came after this one - # first get the IDs of the evoked inputs to downshift - lds = [] # list of inputs to downshift - for jdx in range(len(self.ltabs)): - if isprox and self.IsProx(jdx) and self.getInputID(jdx) > inputID: - #print('downshift prox',self.getInputID(jdx)) - lds.append(jdx) - elif isdist and not self.IsProx(jdx) and self.getInputID(jdx) > inputID: - #print('downshift dist',self.getInputID(jdx)) - lds.append(jdx) - for jdx in lds: self.downShift(jdx) # then do the downshifting - - # print(self) # for testing - - def removeCurrentInput (self): # removes currently selected input - idx = self.tabs.currentIndex() - if idx < 0: return - self.removeInput(idx) - - def __str__ (self): - s = '' - for k,v in self.dqline.items(): s += k + ': ' + v.text().strip() + os.linesep - if self.chksync.isChecked(): s += 'sync_evinput: 1'+os.linesep - else: s += 'sync_evinput: 0'+os.linesep - s += 'inc_evinput: ' + self.incedit.text().strip() + os.linesep - return s - - def addRemoveInputButton (self): - self.bbremovebox = QHBoxLayout() - self.btnremove = QPushButton('Remove Input',self) - self.btnremove.resize(self.btnremove.sizeHint()) - self.btnremove.clicked.connect(self.removeCurrentInput) - self.btnremove.setToolTip('Remove This Input') - self.bbremovebox.addWidget(self.btnremove) - self.layout.addLayout(self.bbremovebox) - - def addHideButton (self): - self.bbhidebox = QHBoxLayout() - self.btnhide = QPushButton('Hide Window',self) - self.btnhide.resize(self.btnhide.sizeHint()) - self.btnhide.clicked.connect(self.hide) - self.btnhide.setToolTip('Hide Window') - self.bbhidebox.addWidget(self.btnhide) - self.layout.addLayout(self.bbhidebox) - - def addTab (self,s): - tab = QWidget() - self.ltabs.append(tab) - self.tabs.addTab(tab,s) - tab.layout = QFormLayout() - tab.setLayout(tab.layout) - return tab - - def addFormToTab (self,d,tab): - 
for k,v in d.items(): - self.dqline[k] = QLineEdit(self) - self.dqline[k].setText(str(v)) - tab.layout.addRow(self.transvar(k),self.dqline[k]) # adds label,QLineEdit to the tab - - def makePixLabel (self,fn): - pix = QPixmap(fn) - pixlbl = ClickLabel(self) - pixlbl.setPixmap(pix) - return pixlbl - - def addtransvarfromdict (self,d): - dtmp = {'L2':'L2/3 ','L5':'L5 '} - for k in d.keys(): - if k.startswith('gbar'): - ks = k.split('_') - stmp = ks[-2] - self.addtransvar(k,dtmp[stmp[0:2]] + stmp[2:] + ' ' + ks[-1].upper() + u' weight (µS)') - elif k.startswith('t'): - self.addtransvar(k,'Start time mean (ms)') - elif k.startswith('sigma'): - self.addtransvar(k,'Start time stdev (ms)') - elif k.startswith('numspikes'): - self.addtransvar(k,'Number spikes') - - def addProx (self): - self.nprox += 1 # starts at 1 - # evprox feed strength - dprox = OrderedDict([('t_evprox_' + str(self.nprox), 0.), # times and stdevs for evoked responses - ('sigma_t_evprox_' + str(self.nprox), 2.5), - ('numspikes_evprox_' + str(self.nprox), 1), - ('gbar_evprox_' + str(self.nprox) + '_L2Pyr_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L2Pyr_nmda', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L2Basket_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L2Basket_nmda', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Pyr_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Pyr_nmda', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Basket_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Basket_nmda', 0.)]) - self.ld.append(dprox) - self.addtransvarfromdict(dprox) - self.addFormToTab(dprox, self.addTab('Proximal ' + str(self.nprox))) - self.ltabs[-1].layout.addRow(self.makePixLabel(lookupresource('proxfig'))) - #print('index to', len(self.ltabs)-1) - self.tabs.setCurrentIndex(len(self.ltabs)-1) - #print('index now', self.tabs.currentIndex(), ' of ', self.tabs.count()) - self.addtips() - - def addDist (self): - self.ndist += 1 - # evdist feed strengths - ddist = OrderedDict([('t_evdist_' + str(self.ndist), 0.), - ('sigma_t_evdist_' + str(self.ndist), 6.), - ('numspikes_evdist_' + str(self.ndist), 1), - ('gbar_evdist_' + str(self.ndist) + '_L2Pyr_ampa', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L2Pyr_nmda', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L2Basket_ampa', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L2Basket_nmda', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L5Pyr_ampa', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L5Pyr_nmda', 0.)]) - self.ld.append(ddist) - self.addtransvarfromdict(ddist) - self.addFormToTab(ddist,self.addTab('Distal ' + str(self.ndist))) - self.ltabs[-1].layout.addRow(self.makePixLabel(lookupresource('distfig'))) - #print('index to', len(self.ltabs)-1) - self.tabs.setCurrentIndex(len(self.ltabs)-1) - #print('index now', self.tabs.currentIndex(), ' of ', self.tabs.count()) - self.addtips() - -class OptEvokedInputParamDialog (EvokedInputParamDialog): - - def __init__ (self, parent, optrun_func): - super(EvokedInputParamDialog, self).__init__(None) - self.nprox = self.ndist = 0 # number of proximal,distal inputs - self.ld = [] # list of dictionaries for proximal/distal inputs - self.dtab_idx = {} # for translating input names to tab indices - self.dtab_names = {} # for translating tab indices to input names - self.dparams = OrderedDict() # actual values - self.dqline = OrderedDict() # not used, prevents failure in removeInput - - # these store values used in grid - self.dqchkbox = OrderedDict() # optimize - self.dqparam_name = OrderedDict() # parameter name - 
self.dqinitial_label = OrderedDict() # initial - self.dqopt_label = OrderedDict() # optimtized - self.dqdiff_label = OrderedDict() # delta - self.dqrange_multiplier = OrderedDict() # user-defined multiplier - self.dqrange_mode = OrderedDict() # range mode (stdev, %, absolute) - self.dqrange_slider = OrderedDict() # slider - self.dqrange_label = OrderedDict() # defined range - self.dqrange_max = OrderedDict() - self.dqrange_min = OrderedDict() - - self.chunk_list = [] - self.lqnumsim = [] - self.lqnumparams = [] - self.lqinputs = [] - self.opt_params = {} - self.initial_opt_ranges = [] - self.dtabdata = [] - self.dtransvar = {} # for translating model variable name to more human-readable form - self.simlength = 0.0 - self.sim_dt = 0.0 - self.default_num_step_sims = 30 - self.default_num_total_sims = 50 - self.optrun_func = optrun_func - self.optimization_running = False - self.initUI() - self.parent = parent - self.old_num_steps = 0 - - def initUI (self): - # start with a reasonable size - setscalegeom(self, 150, 150, 475, 300) - - self.ltabs = [] - self.ltabkeys = [] - self.tabs = QTabWidget() - self.din = {} - - self.grid = QGridLayout() - - row = 0 - self.sublayout = QGridLayout() - self.old_numsims = [] - self.grid.addLayout(self.sublayout, row, 0) - - row += 1 - self.grid.addWidget(self.tabs, row, 0) - - row += 1 - self.btnrunop = QPushButton('Run Optimization', self) - self.btnrunop.resize(self.btnrunop.sizeHint()) - self.btnrunop.setToolTip('Run Optimization') - self.btnrunop.clicked.connect(self.runOptimization) - self.grid.addWidget(self.btnrunop, row, 0) - - row += 1 - self.btnreset = QPushButton('Reset Ranges',self) - self.btnreset.resize(self.btnreset.sizeHint()) - self.btnreset.clicked.connect(self.updateOptRanges) - self.btnreset.setToolTip('Reset Ranges') - self.grid.addWidget(self.btnreset, row, 0) - - row += 1 - btnhide = QPushButton('Hide Window',self) - btnhide.resize(btnhide.sizeHint()) - btnhide.clicked.connect(self.hide) - btnhide.setToolTip('Hide Window') - self.grid.addWidget(btnhide, row, 0) - - self.setLayout(self.grid) - - self.setWindowTitle("Configure Optimization") - - # the largest horizontal component will be column 0 (headings) - self.resize(self.minimumSizeHint()) - - def toggle_enable_param(self, label): - import re - - widget_dict_list = [self.dqinitial_label, self.dqopt_label, - self.dqdiff_label, self.dqparam_name, - self.dqrange_mode, self.dqrange_multiplier, - self.dqrange_label, self.dqrange_slider] - - if self.dqchkbox[label].isChecked(): - # set all other fields in the row to enabled - for widget_dict in widget_dict_list: - widget_dict[label].setEnabled(True) - toEnable = True - else: - # disable all other fields in the row - for widget_dict in widget_dict_list: - widget_dict[label].setEnabled(False) - toEnable = False - - self.changeParamEnabledStatus(label, toEnable) - - def addTab (self,id_str): - tab = QWidget() - self.ltabs.append(tab) - - name_str = trans_input(id_str) - self.tabs.addTab(tab, name_str) - - tab_index = len(self.ltabs)-1 - self.dtab_idx[id_str] = tab_index - self.dtab_names[tab_index] = id_str - - return tab - - def cleanLabels(self): - """ - To avoid memory leaks we need to delete all widgets when we recreate grid. 
- Go through all tabs and check for each var name (k) - """ - for idx in range(len(self.ltabs)): - for k in self.ld[idx].keys(): - if k in self.dqinitial_label: - del self.dqinitial_label[k] - if k in self.dqopt_label: - del self.dqopt_label[k] - if k in self.dqdiff_label: - del self.dqdiff_label[k] - if k in self.dqparam_name: - del self.dqparam_name[k] - if not self.optimization_running: - if k in self.dqrange_mode: - del self.dqrange_mode[k] - if k in self.dqrange_multiplier: - del self.dqrange_multiplier[k] - if k in self.dqrange_label: - del self.dqrange_label[k] - if k in self.dqrange_slider: - del self.dqrange_slider[k] - if k in self.dqrange_min: - del self.dqrange_min[k] - if k in self.dqrange_max: - del self.dqrange_max[k] - - def addGridToTab (self, d, tab): - from functools import partial - import re - - current_tab = len(self.ltabs)-1 - tab.layout = QGridLayout() - #tab.layout.setSpacing(10) - - self.ltabkeys.append([]) - - # The first row has column headings - row = 0 - self.ltabkeys[current_tab].append("") - for column_index, column_name in enumerate(["Optimize", "Parameter name", - "Initial", "Optimized", "Delta"]): - widget = QLabel(column_name) - widget.resize(widget.sizeHint()) - tab.layout.addWidget(widget, row, column_index) - - column_index += 1 - widget = QLabel("Range specifier") - widget.setMinimumWidth(100) - tab.layout.addWidget(widget, row, column_index, 1, 2) - - column_index += 2 - widget = QLabel("Range slider") - # widget.setMinimumWidth(160) - tab.layout.addWidget(widget, row, column_index) - - column_index += 1 - widget = QLabel("Defined range") - tab.layout.addWidget(widget, row, column_index) - - # The second row is a horizontal line - row = 1 - self.ltabkeys[current_tab].append("") - qthline = QFrame() - qthline.setFrameShape(QFrame.HLine) - qthline.setFrameShadow(QFrame.Sunken) - tab.layout.addWidget(qthline, row, 0, 1, 9) - - # The rest are the parameters - row = 2 - for k,v in d.items(): - self.ltabkeys[current_tab].append(k) - - # create and format widgets - self.dparams[k] = float(v) - self.dqchkbox[k] = QCheckBox() - self.dqchkbox[k].setStyleSheet(""" - .QCheckBox { - spacing: 20px; - } - .QCheckBox::unchecked { - color: grey; - } - .QCheckBox::checked { - color: black; - } - """) - self.dqchkbox[k].setChecked(True) - # use partial instead of lamda (so args won't be evaluated ahead of time?) 
- self.dqchkbox[k].clicked.connect(partial(self.toggle_enable_param, k)) - self.dqparam_name[k] = QLabel(self) - self.dqparam_name[k].setText(self.transvar(k)) - self.dqinitial_label[k] = QLabel() - self.dqopt_label[k] = QLabel() - self.dqdiff_label[k] = QLabel() - - # add widgets to grid - tab.layout.addWidget(self.dqchkbox[k], row, 0, alignment = Qt.AlignBaseline | Qt.AlignCenter) - tab.layout.addWidget(self.dqparam_name[k], row, 1) - tab.layout.addWidget(self.dqinitial_label[k], row, 2) # initial value - tab.layout.addWidget(self.dqopt_label[k], row, 3) # optimized value - tab.layout.addWidget(self.dqdiff_label[k], row, 4) # delta - - if k.startswith('t'): - range_mode = "(stdev)" - range_multiplier = "3.0" - elif k.startswith('sigma'): - range_mode = "(%)" - range_multiplier = "50.0" - else: - range_mode = "(%)" - range_multiplier = "500.0" - - if not self.optimization_running: - self.dqrange_slider[k] = QRangeSlider(k,self) - self.dqrange_slider[k].setMinimumWidth(140) - self.dqrange_label[k] = QLabel() - self.dqrange_multiplier[k] = MyLineEdit(range_multiplier, k) - self.dqrange_multiplier[k].textModified.connect(self.updateRange) - self.dqrange_multiplier[k].setSizePolicy(QSizePolicy.Ignored, QSizePolicy.Preferred) - self.dqrange_multiplier[k].setMinimumWidth(50) - self.dqrange_multiplier[k].setMaximumWidth(50) - self.dqrange_mode[k] = QLabel(range_mode) - tab.layout.addWidget(self.dqrange_multiplier[k], row, 5) # range specifier - tab.layout.addWidget(self.dqrange_mode[k], row, 6) # range mode - tab.layout.addWidget(self.dqrange_slider[k], row, 7) # range slider - tab.layout.addWidget(self.dqrange_label[k], row, 8) # calculated range - - row += 1 - - # A spacer in the last row stretches to fill remaining space. - # For inputs with fewer parameters than the rest, this pushes parameters - # to the top with the same spacing as the other inputs. 
- tab.layout.addItem(QSpacerItem(0, 0), row, 0, 1, 9) - tab.layout.setRowStretch(row,1) - tab.setLayout(tab.layout) - - def addProx (self): - self.nprox += 1 # starts at 1 - # evprox feed strength - dprox = OrderedDict([('t_evprox_' + str(self.nprox), 0.), # times and stdevs for evoked responses - ('sigma_t_evprox_' + str(self.nprox), 2.5), - #('numspikes_evprox_' + str(self.nprox), 1), - ('gbar_evprox_' + str(self.nprox) + '_L2Pyr_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L2Pyr_nmda', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L2Basket_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L2Basket_nmda', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Pyr_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Pyr_nmda', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Basket_ampa', 0.), - ('gbar_evprox_' + str(self.nprox) + '_L5Basket_nmda', 0.)]) - self.ld.append(dprox) - self.addtransvarfromdict(dprox) - tab = self.addTab('evprox_' + str(self.nprox)) - self.addGridToTab(dprox, tab) - - def addDist (self): - self.ndist += 1 - # evdist feed strengths - ddist = OrderedDict([('t_evdist_' + str(self.ndist), 0.), - ('sigma_t_evdist_' + str(self.ndist), 6.), - #('numspikes_evdist_' + str(self.ndist), 1), - ('gbar_evdist_' + str(self.ndist) + '_L2Pyr_ampa', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L2Pyr_nmda', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L2Basket_ampa', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L2Basket_nmda', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L5Pyr_ampa', 0.), - ('gbar_evdist_' + str(self.ndist) + '_L5Pyr_nmda', 0.)]) - self.ld.append(ddist) - self.addtransvarfromdict(ddist) - tab = self.addTab('evdist_' + str(self.ndist)) - self.addGridToTab(ddist, tab) - - def changeParamEnabledStatus(self, label, toEnable): - import re - - label_match = re.search('(evprox|evdist)_([0-9]+)', label) - if label_match: - my_input_name = label_match.group(1) + '_' + label_match.group(2) - else: - print("ERR: can't determine input name from parameter: %s" % label) - return - - # decrease the count of num params - for chunk_index in range(self.old_num_steps): - for input_name in self.chunk_list[chunk_index]['inputs']: - if input_name == my_input_name: - try: - num_params = int(self.lqnumparams[chunk_index].text()) - except ValueError: - print("ERR: could not get number of params for step %d"%chunk_index) - - if toEnable: - num_params += 1 - else: - num_params -= 1 - self.lqnumparams[chunk_index].setText(str(num_params)) - self.opt_params[input_name]['ranges'][label]['enabled'] = toEnable - - def updateRange(self, label, save_slider=True): - import re - - max_width = 0 - - label_match = re.search('(evprox|evdist)_([0-9]+)', label) - if label_match: - tab_name = label_match.group(1) + '_' + label_match.group(2) - else: - print("ERR: can't determine input name from parameter: %s" % label) - return - - if self.dqchkbox[label].isChecked(): - self.opt_params[tab_name]['ranges'][label]['enabled'] = True - else: - self.opt_params[tab_name]['ranges'][label]['enabled'] = False - return - - if tab_name not in self.initial_opt_ranges or \ - label not in self.initial_opt_ranges[tab_name]: - value = self.dparams[label] - else: - value = float(self.initial_opt_ranges[tab_name][label]['initial']) - - range_type = self.dqrange_mode[label].text() - if range_type == "(%)" and value == 0.0: - # change to range from 0 to 1 - range_type = "(max)" - self.dqrange_mode[label].setText(range_type) - self.dqrange_multiplier[label].setText("1.0") - elif range_type == "(max)" and value > 0.0: - # change 
back to % - range_type = "(%)" - self.dqrange_mode[label].setText(range_type) - self.dqrange_multiplier[label].setText("500.0") - - try: - range_multiplier = float(self.dqrange_multiplier[label].text()) - except ValueError: - range_multiplier = 0.0 - self.dqrange_multiplier[label].setText(str(range_multiplier)) - - if range_type == "(max)": - range_min = 0 - try: - range_max = float(self.dqrange_multiplier[label].text()) - except ValueError: - range_max = 1.0 - elif range_type == "(stdev)": # timing - timing_sigma = self.get_input_timing_sigma(tab_name) - timing_bound = timing_sigma * range_multiplier - range_min = max(0, value - timing_bound) - range_max = min(self.simlength, value + timing_bound) - else: # range_type == "(%)" - range_min = max(0, value - (value * range_multiplier / 100.0)) - range_max = value + (value * range_multiplier / 100.0) - - # set up the slider - self.dqrange_slider[label].setLine(value) - self.dqrange_slider[label].setMin(range_min) - self.dqrange_slider[label].setMax(range_max) - - if not save_slider: - self.dqrange_min.pop(label, None) - self.dqrange_max.pop(label, None) - - self.opt_params[tab_name]['ranges'][label]['initial'] = value - if label in self.dqrange_min and label in self.dqrange_max: - range_min = self.dqrange_min[label] - range_max = self.dqrange_max[label] - - self.opt_params[tab_name]['ranges'][label]['minval'] = range_min - self.opt_params[tab_name]['ranges'][label]['maxval'] = range_max - self.dqrange_slider[label].setRange(range_min, range_max) - - if range_min == range_max: - self.dqrange_label[label].setText(format_range_str(range_min)) # use the exact value - self.dqrange_label[label].setEnabled(False) - # uncheck because invalid range - self.dqchkbox[label].setChecked(False) - # disable slider - self.dqrange_slider[label].setEnabled(False) - self.changeParamEnabledStatus(label, False) - else: - self.dqrange_label[label].setText(format_range_str(range_min) + - " - " + - format_range_str(range_max)) - - if self.dqrange_label[label].sizeHint().width() > max_width: - max_width = self.dqrange_label[label].sizeHint().width() + 15 - # fix the size for the defined range so that changing the slider doesn't change - # the dialog's width - self.dqrange_label[label].setMinimumWidth(max_width) - self.dqrange_label[label].setMaximumWidth(max_width) - - def prepareOptimization(self): - self.createOptParams() - self.rebuildOptStepInfo() - self.updateOptDeltas() - self.updateOptRanges(save_sliders=True) - self.btnreset.setEnabled(True) - self.btnrunop.setText('Run Optimization') - self.btnrunop.clicked.disconnect() - self.btnrunop.clicked.connect(self.runOptimization) - - def runOptimization(self): - self.current_opt_step = 0 - - # update the ranges to find which parameters have been disabled (unchecked) - self.updateOptRanges(save_sliders=True) - - # update the opt info dict to capture num_sims from GUI - self.rebuildOptStepInfo() - self.optimization_running = True - - # run the actual optimization. 
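For reference, the range arithmetic implemented by updateRange() above condenses to the standalone function below. The function name and defaults are illustrative only; the defaults mirror the 2.5 ms proximal timing stdev and 250 ms tstop used elsewhere in this file.

def opt_range(value, mode, multiplier, timing_sigma=2.5, simlength=250.0):
    if mode == "(max)":    # fallback when a (%) window around 0.0 would collapse
        return 0.0, multiplier
    if mode == "(stdev)":  # start times: +/- multiplier * timing stdev, clipped to the simulation
        bound = timing_sigma * multiplier
        return max(0.0, value - bound), min(simlength, value + bound)
    # "(%)": symmetric percentage window around the initial value
    delta = value * multiplier / 100.0
    return max(0.0, value - delta), value + delta

print(opt_range(0.02, "(%)", 500.0))  # (0.0, 0.12) for a conductance with the default 500 % specifier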
optrun_func comes from HNNGUI.startoptmodel(): - # passed to BaseParamDialog then finally OptEvokedInputParamDialog - self.optrun_func() - - def get_chunk_start(self, step): - return self.chunk_list[step]['opt_start'] - - def get_chunk_end(self, step): - return self.chunk_list[step]['opt_end'] - - def get_chunk_weights(self, step): - return self.chunk_list[step]['weights'] - - def get_num_chunks(self): - return len(self.chunk_list) - - def get_sims_for_chunk(self, step): - try: - num_sims = int(self.lqnumsim[step].text()) - except KeyError: - print("ERR: number of sims not found for step %d"%step) - num_sims = 0 - except ValueError: - if step == self.old_num_steps - 1: - num_sims = self.default_num_total_sims - else: - num_sims = self.default_num_step_sims - - return num_sims - - def get_chunk_ranges(self, step): - ranges = {} - for input_name in self.chunk_list[step]['inputs']: - # make sure initial value is between minval or maxval before returning - # ranges to the optimization - for label in self.opt_params[input_name]['ranges'].keys(): - if not self.opt_params[input_name]['ranges'][label]['enabled']: - continue - range_min = self.opt_params[input_name]['ranges'][label]['minval'] - range_max = self.opt_params[input_name]['ranges'][label]['maxval'] - if range_min > self.opt_params[input_name]['ranges'][label]['initial']: - self.opt_params[input_name]['ranges'][label]['initial'] = range_min - if range_max < self.opt_params[input_name]['ranges'][label]['initial']: - self.opt_params[input_name]['ranges'][label]['initial'] = range_max - - # copy the values to the ranges dict to be returned - # to optimization - ranges[label] = self.opt_params[input_name]['ranges'][label].copy() - - return ranges - - def get_num_params(self, step): - num_params = 0 - - for input_name in self.chunk_list[step]['inputs']: - for label in self.opt_params[input_name]['ranges'].keys(): - if not self.opt_params[input_name]['ranges'][label]['enabled']: - continue - else: - num_params += 1 - - return num_params - - def push_chunk_ranges(self, step, ranges): - import re - - for label, value in ranges.items(): - for tab_name in self.opt_params.keys(): - if label in self.opt_params[tab_name]['ranges']: - self.opt_params[tab_name]['ranges'][label]['initial'] = float(value) - - def clean_opt_grid(self): - # This is the top part of the Configure Optimization dialog. - - column_count = self.sublayout.columnCount() - row = 0 - while True: - try: - self.sublayout.itemAtPosition(row,0).widget() - except AttributeError: - # no more rows - break - - for column in range(column_count): - try: - # Use deleteLater() to avoid memory leaks. 
- self.sublayout.itemAtPosition(row, column).widget().deleteLater() - except AttributeError: - # if item wasn't found - pass - row += 1 - - # reset data for number of sims per chunk (step) - self.lqnumsim = [] - self.lqnumparams = [] - self.lqinputs = [] - self.old_num_steps = 0 - - def rebuildOptStepInfo(self): - # split chunks from paramter file - self.chunk_list = chunk_evinputs(self.opt_params, self.simlength, self.sim_dt) - - if len(self.chunk_list) == 0: - self.clean_opt_grid() - - qlabel = QLabel("No valid evoked inputs to optimize!") - qlabel.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) - qlabel.resize(qlabel.minimumSizeHint()) - self.sublayout.addWidget(qlabel, 0, 0) - self.btnrunop.setEnabled(False) - self.btnreset.setEnabled(False) - else: - self.btnrunop.setEnabled(True) - self.btnreset.setEnabled(True) - - if len(self.chunk_list) < self.old_num_steps or \ - self.old_num_steps == 0: - # clean up the old grid sublayout - self.clean_opt_grid() - - # keep track of inputs to optimize over (check against self.opt_params later) - all_inputs = [] - - # create a new grid sublayout with a row for each optimization step - for chunk_index, chunk in enumerate(self.chunk_list): - chunk['num_params'] = self.get_num_params(chunk_index) - - inputs = [] - for input_name in chunk['inputs']: - all_inputs.append(input_name) - inputs.append(trans_input(input_name)) - - if chunk_index >= self.old_num_steps: - qlabel = QLabel("Optimization step %d:"%(chunk_index+1)) - qlabel.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) - qlabel.resize(qlabel.minimumSizeHint()) - self.sublayout.addWidget(qlabel,chunk_index, 0) - - self.lqinputs.append(QLabel("Inputs: %s"%', '.join(inputs))) - self.lqinputs[chunk_index].setAlignment(Qt.AlignBaseline | Qt.AlignLeft) - self.lqinputs[chunk_index].resize(self.lqinputs[chunk_index].minimumSizeHint()) - self.sublayout.addWidget(self.lqinputs[chunk_index], chunk_index, 1) - - # spacer here for readability of input names and reduce size - # of "Num simulations:" - self.sublayout.addItem(QSpacerItem(0, 0, hPolicy = QSizePolicy.MinimumExpanding), chunk_index, 2) - - qlabel_params = QLabel("Num params:") - qlabel_params.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) - qlabel_params.resize(qlabel_params.minimumSizeHint()) - self.sublayout.addWidget(qlabel_params,chunk_index, 3) - - self.lqnumparams.append(QLabel(str(chunk['num_params']))) - self.lqnumparams[chunk_index].setAlignment(Qt.AlignBaseline | Qt.AlignLeft) - self.lqnumparams[chunk_index].resize(self.lqnumparams[chunk_index].minimumSizeHint()) - self.sublayout.addWidget(self.lqnumparams[chunk_index],chunk_index, 4) - - qlabel_sims = QLabel("Num simulations:") - qlabel_sims.setAlignment(Qt.AlignBaseline | Qt.AlignLeft) - qlabel_sims.resize(qlabel_sims.minimumSizeHint()) - self.sublayout.addWidget(qlabel_sims,chunk_index, 5) - - if chunk_index == len(self.chunk_list) - 1: - chunk['num_sims'] = self.default_num_total_sims - else: - chunk['num_sims'] = self.default_num_step_sims - self.lqnumsim.append(QLineEdit(str(chunk['num_sims']))) - self.lqnumsim[chunk_index].resize( - self.lqnumsim[chunk_index].minimumSizeHint()) - self.sublayout.addWidget(self.lqnumsim[chunk_index], - chunk_index, 6) - else: - self.lqinputs[chunk_index].setText("Inputs: %s"%', '.join(inputs)) - self.lqnumparams[chunk_index].setText(str(chunk['num_params'])) - - self.old_num_steps = len(self.chunk_list) - - remove_list = [] - # remove a tab if necessary - for input_name in self.opt_params.keys(): - if input_name not in all_inputs and input_name in 
self.dtab_idx: - remove_list.append(input_name) - - while len(remove_list) > 0: - tab_name = remove_list.pop() - tab_index = self.dtab_idx[tab_name] - - self.removeInput(tab_index) - del self.dtab_idx[tab_name] - del self.dtab_names[tab_index] - self.ltabkeys.pop(tab_index) - - # rebuild dtab_idx and dtab_names - temp_dtab_names = {} - temp_dtab_idx = {} - for new_tab_index, old_tab_index in enumerate(self.dtab_idx.values()): - # self.dtab_idx[id_str] = tab_index - id_str = self.dtab_names[old_tab_index] - temp_dtab_names[new_tab_index] = id_str - temp_dtab_idx[id_str] = new_tab_index - self.dtab_names = temp_dtab_names - self.dtab_idx = temp_dtab_idx - - def toggleEnableUserFields(self, step, enable=True): - if not enable: - # the optimization called this to disable parameters on - # for the step passed in to this function - self.current_opt_step = step - - for input_name in self.chunk_list[step]['inputs']: - tab_index = self.dtab_idx[input_name] - tab = self.ltabs[tab_index] - - for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer - label = self.ltabkeys[tab_index][row_index] - self.dqchkbox[label].setEnabled(enable) - self.dqrange_slider[label].setEnabled(enable) - self.dqrange_multiplier[label].setEnabled(enable) - - def get_input_timing_sigma(self, tab_name): - """ get timing_sigma from already loaded values """ - - label = 'sigma_t_' + tab_name - try: - timing_sigma = self.dparams[label] - except KeyError: - timing_sigma = 3.0 - print("ERR: Couldn't fing %s. Using default %f" % - (label,timing_sigma)) - - if timing_sigma == 0.0: - # sigma of 0 will not produce a CDF - timing_sigma = 0.01 - - return timing_sigma - - def createOptParams(self): - self.opt_params = {} - - # iterate through tabs. data is contained in grid layout - for tab_index, tab in enumerate(self.ltabs): - tab_name = self.dtab_names[tab_index] - - # before optimization has started update 'mean', 'sigma', - # 'start', and 'user_end' - start_time_label = 't_' + tab_name - try: - try: - range_multiplier = float(self.dqrange_multiplier[start_time_label].text()) - except ValueError: - range_multiplier = 0.0 - value = self.dparams[start_time_label] - except KeyError: - print("ERR: could not find start time parameter: %s" % start_time_label) - continue - - timing_sigma = self.get_input_timing_sigma(tab_name) - self.opt_params[tab_name] = {'ranges': {}, - 'mean' : value, - 'sigma': timing_sigma, - 'decay_multiplier': dconf['decay_multiplier']} - - timing_bound = timing_sigma * range_multiplier - self.opt_params[tab_name]['user_start'] = max(0, value - timing_bound) - self.opt_params[tab_name]['user_end'] = min(self.simlength, value + timing_bound) - - # add an empty dictionary so that rebuildOptStepInfo() can determine - # how many parameters - for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer - label = self.ltabkeys[tab_index][row_index] - self.opt_params[tab_name]['ranges'][label] = {'enabled': True} - - def clear_initial_opt_ranges(self): - self.initial_opt_ranges = {} - - def populate_initial_opt_ranges(self): - self.initial_opt_ranges = {} - - for input_name in self.opt_params.keys(): - self.initial_opt_ranges[input_name] = deepcopy(self.opt_params[input_name]['ranges']) - - def updateOptDeltas(self): - # iterate through tabs. 
data is contained in grid layout - for tab_index, tab in enumerate(self.ltabs): - tab_name = self.dtab_names[tab_index] - - # update the initial value - for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer - label = self.ltabkeys[tab_index][row_index] - value = self.dparams[label] - - # Calculate value to put in "Delta" column. When possible, use - # percentages, but when initial value is 0, use absolute changes - if tab_name not in self.initial_opt_ranges or \ - not self.dqchkbox[label].isChecked(): - self.dqdiff_label[label].setEnabled(False) - self.dqinitial_label[label].setText(("%6f"%self.dparams[label]).rstrip('0').rstrip('.')) - text = '--' - color_fmt = "QLabel { color : black; }" - self.dqopt_label[label].setText(text) - self.dqopt_label[label].setStyleSheet(color_fmt) - self.dqopt_label[label].setAlignment(Qt.AlignHCenter) - self.dqdiff_label[label].setAlignment(Qt.AlignHCenter) - else: - initial_value = float(self.initial_opt_ranges[tab_name][label]['initial']) - self.dqinitial_label[label].setText(("%6f"%initial_value).rstrip('0').rstrip('.')) - self.dqopt_label[label].setText(("%6f"%self.dparams[label]).rstrip('0').rstrip('.')) - self.dqopt_label[label].setAlignment(Qt.AlignVCenter|Qt.AlignLeft) - self.dqdiff_label[label].setAlignment(Qt.AlignVCenter|Qt.AlignLeft) - - if isclose(value, initial_value, abs_tol=1e-7): - diff = 0 - text = "0.0" - color_fmt = "QLabel { color : black; }" - else: - diff = value - initial_value - - if initial_value == 0: - # can't calculate % - if diff < 0: - text = ("%6f"%diff).rstrip('0').rstrip('.') - color_fmt = "QLabel { color : red; }" - elif diff > 0: - text = ("+%6f"%diff).rstrip('0').rstrip('.') - color_fmt = "QLabel { color : green; }" - else: - # calculate percent difference - percent_diff = 100 * diff/abs(initial_value) - if percent_diff < 0: - text = ("%2.2f %%"%percent_diff) - color_fmt = "QLabel { color : red; }" - elif percent_diff > 0: - text = ("+%2.2f %%"%percent_diff) - color_fmt = "QLabel { color : green; }" - - self.dqdiff_label[label].setStyleSheet(color_fmt) - self.dqdiff_label[label].setText(text) - - def updateRangeFromSlider(self, label, range_min, range_max): - import re - - label_match = re.search('(evprox|evdist)_([0-9]+)', label) - if label_match: - tab_name = label_match.group(1) + '_' + label_match.group(2) - else: - print("ERR: can't determine input name from parameter: %s" % label) - return - - self.dqrange_min[label] = range_min - self.dqrange_max[label] = range_max - self.dqrange_label[label].setText(format_range_str(range_min) + " - " + - format_range_str(range_max)) - self.opt_params[tab_name]['ranges'][label]['minval'] = range_min - self.opt_params[tab_name]['ranges'][label]['maxval'] = range_max - - def updateOptRanges(self, save_sliders=False): - # iterate through tabs. 
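The "Delta" column logic in updateOptDeltas() above reduces to the following sketch (not part of the removed file; it folds the separate positive/negative format strings into one signed format):

from math import isclose

def delta_text(initial, optimized):
    if isclose(optimized, initial, abs_tol=1e-7):
        return "0.0"
    diff = optimized - initial
    if initial == 0:  # no percentage possible, show the absolute change
        return ("%+6f" % diff).rstrip('0').rstrip('.')
    return "%+2.2f %%" % (100 * diff / abs(initial))

print(delta_text(10.0, 12.5))  # +25.00 %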
data is contained in grid layout - for tab_index, tab in enumerate(self.ltabs): - # now update the ranges - for row_index in range(2, tab.layout.rowCount()-1): # last row is a spacer - label = self.ltabkeys[tab_index][row_index] - self.updateRange(label, save_sliders) - - def setfromdin (self,din): - if not din: - return - - if 'dt' in din: - # din proivdes a complete parameter set - self.din = din - self.simlength = float(din['tstop']) - self.sim_dt = float(din['dt']) - - self.cleanLabels() - self.removeAllInputs() # turn off any previously set inputs - self.ltabkeys = [] - self.dtab_idx = {} - self.dtab_names = {} - - for evinput in get_inputs(din): - if 'evprox_' in evinput: - self.addProx() - elif 'evdist_' in evinput: - self.addDist() - - for k,v in din.items(): - if k in self.dparams: - try: - new_value = float(v) - except ValueError: - print("WARN: bad value for param %s: %s. Unable to convert" - " to a floating point number" % (k,v)) - continue - self.dparams[k] = new_value - elif k.count('gbar') > 0 and \ - (k.count('evprox') > 0 or \ - k.count('evdist') > 0): - # NOTE: will be deprecated in future release - # for back-compat with old-style specification which didn't have ampa,nmda in evoked gbar - try: - new_value = float(v) - except ValueError: - print("WARN: bad value for param %s: %s. Unable to convert" - " to a floating point number" % (k,v)) - continue - lks = k.split('_') - eloc = lks[1] - enum = lks[2] - base_key_str = 'gbar_' + eloc + '_' + enum + '_' - if eloc == 'evprox': - for ct in ['L2Pyr','L2Basket','L5Pyr','L5Basket']: - # ORIGINAL MODEL/PARAM: only ampa for prox evoked inputs - key_str = base_key_str + ct + '_ampa' - self.dparams[key_str] = new_value - elif eloc == 'evdist': - for ct in ['L2Pyr','L2Basket','L5Pyr']: - # ORIGINAL MODEL/PARAM: both ampa and nmda for distal evoked inputs - key_str = base_key_str + ct + '_ampa' - self.dparams[key_str] = new_value - key_str = base_key_str + ct + '_nmda' - self.dparams[key_str] = new_value - - if not self.optimization_running: - self.createOptParams() - self.rebuildOptStepInfo() - self.updateOptRanges(save_sliders=True) - - self.updateOptDeltas() - - def __str__ (self): - # don't write any values to param file - return '' - -# widget to specify run params (tstop, dt, etc.) 
-- not many params here -class RunParamDialog (DictDialog): - def __init__ (self, parent, din = None): - super(RunParamDialog, self).__init__(parent,din) - self.addHideButton() - self.parent = parent - - def initd (self): - - self.drun = OrderedDict([('tstop', 250.), # simulation end time (ms) - ('dt', 0.025), # timestep - ('celsius',37.0), # temperature - ('N_trials',1), # number of trials - ('threshold',0.0)]) # firing threshold - # cvode - not currently used by simulation - - # analysis - self.danalysis = OrderedDict([('save_figs',0), - ('save_spec_data', 0), - ('f_max_spec', 40), - ('dipole_scalefctr',30e3), - ('dipole_smooth_win',15.0), - ('save_vsoma',0)]) - - self.drand = OrderedDict([('prng_seedcore_opt', 0), - ('prng_seedcore_input_prox', 0), - ('prng_seedcore_input_dist', 0), - ('prng_seedcore_extpois', 0), - ('prng_seedcore_extgauss', 0), - ('prng_seedcore_evprox_1', 0), - ('prng_seedcore_evdist_1', 0), - ('prng_seedcore_evprox_2', 0), - ('prng_seedcore_evdist_2', 0)]) - - self.ldict = [self.drun, self.danalysis, self.drand] - self.ltitle = ['Run', 'Analysis', 'Randomization Seeds'] - self.stitle = 'Run Parameters' - - self.addtransvar('tstop','Duration (ms)') - self.addtransvar('dt','Integration Timestep (ms)') - self.addtransvar('celsius','Temperature (C)') - self.addtransvar('threshold','Firing Threshold (mV)') - self.addtransvar('N_trials','Trials') - self.addtransvar('save_spec_data','Save Spectral Data') - self.addtransvar('save_figs','Save Figures') - self.addtransvar('f_max_spec', 'Max Spectral Frequency (Hz)') - self.addtransvar('spec_cmap', 'Spectrogram Colormap') - self.addtransvar('dipole_scalefctr','Dipole Scaling') - self.addtransvar('dipole_smooth_win','Dipole Smooth Window (ms)') - self.addtransvar('save_vsoma','Save Somatic Voltages') - self.addtransvar('prng_seedcore_opt','Parameter Optimization') - self.addtransvar('prng_seedcore_input_prox','Ongoing Proximal Input') - self.addtransvar('prng_seedcore_input_dist','Ongoing Distal Input') - self.addtransvar('prng_seedcore_extpois','External Poisson') - self.addtransvar('prng_seedcore_extgauss','External Gaussian') - self.addtransvar('prng_seedcore_evprox_1','Evoked Proximal 1') - self.addtransvar('prng_seedcore_evdist_1','Evoked Distal 1 ') - self.addtransvar('prng_seedcore_evprox_2','Evoked Proximal 2') - self.addtransvar('prng_seedcore_evdist_2','Evoked Distal 2') - - def selectionchange(self,i): - self.spec_cmap = self.cmaps[i] - self.parent.updatesaveparams({}) - - def initExtra (self): - global defncore, paramf - - DictDialog.initExtra(self) - self.dqextra['NumCores'] = QLineEdit(self) - self.dqextra['NumCores'].setText(str(defncore)) - self.addtransvar('NumCores','Number Cores') - self.ltabs[0].layout.addRow('NumCores',self.dqextra['NumCores']) - - self.spec_map_cb = None - - self.cmaps = ['jet', - 'viridis', - 'plasma', - 'inferno', - 'magma', - 'cividis'] - - # get default spec_cmap - p_exp = ExpParams(paramf, 0) - if len(p_exp.expmt_groups) > 0: - expmt_group = p_exp.expmt_groups[0] - else: - expmt_group = None - p = p_exp.return_pdict(expmt_group, 0) - self.spec_cmap = p['spec_cmap'] - - self.spec_map_cb = QComboBox() - for cmap in self.cmaps: - self.spec_map_cb.addItem(cmap) - self.spec_map_cb.currentIndexChanged.connect(self.selectionchange) - self.ltabs[1].layout.addRow(self.transvar('spec_cmap'),self.spec_map_cb) - - def getntrial (self): return int(self.dqline['N_trials'].text().strip()) - - def getncore (self): return int(self.dqextra['NumCores'].text().strip()) - - def setfromdin (self,din): - 
global defncore - - if not din: return - - # number of cores may have changed if the configured number failed - self.dqextra['NumCores'].setText(str(defncore)) - for k,v in din.items(): - if k in self.dqline: - self.dqline[k].setText(str(v).strip()) - elif k == 'spec_cmap': - self.spec_cmap = v - self.spec_map_cb.setCurrentIndex(self.cmaps.index(self.spec_cmap)) - - def __str__ (self): - s = '' - for k,v in self.dqline.items(): s += k + ': ' + v.text().strip() + os.linesep - s += 'spec_cmap: ' + self.spec_cmap + os.linesep - return s - -# widget to specify (pyramidal) cell parameters (geometry, synapses, biophysics) -class CellParamDialog (DictDialog): - def __init__ (self, parent = None, din = None): - super(CellParamDialog, self).__init__(parent,din) - self.addHideButton() - - def initd (self): - - self.dL2PyrGeom = OrderedDict([('L2Pyr_soma_L', 22.1), # Soma - ('L2Pyr_soma_diam', 23.4), - ('L2Pyr_soma_cm', 0.6195), - ('L2Pyr_soma_Ra', 200.), - # Dendrites - ('L2Pyr_dend_cm', 0.6195), - ('L2Pyr_dend_Ra', 200.), - ('L2Pyr_apicaltrunk_L', 59.5), - ('L2Pyr_apicaltrunk_diam', 4.25), - ('L2Pyr_apical1_L', 306.), - ('L2Pyr_apical1_diam', 4.08), - ('L2Pyr_apicaltuft_L', 238.), - ('L2Pyr_apicaltuft_diam', 3.4), - ('L2Pyr_apicaloblique_L', 340.), - ('L2Pyr_apicaloblique_diam', 3.91), - ('L2Pyr_basal1_L', 85.), - ('L2Pyr_basal1_diam', 4.25), - ('L2Pyr_basal2_L', 255.), - ('L2Pyr_basal2_diam', 2.72), - ('L2Pyr_basal3_L', 255.), - ('L2Pyr_basal3_diam', 2.72)]) - - self.dL2PyrSyn = OrderedDict([('L2Pyr_ampa_e', 0.), # Synapses - ('L2Pyr_ampa_tau1', 0.5), - ('L2Pyr_ampa_tau2', 5.), - ('L2Pyr_nmda_e', 0.), - ('L2Pyr_nmda_tau1', 1.), - ('L2Pyr_nmda_tau2', 20.), - ('L2Pyr_gabaa_e', -80.), - ('L2Pyr_gabaa_tau1', 0.5), - ('L2Pyr_gabaa_tau2', 5.), - ('L2Pyr_gabab_e', -80.), - ('L2Pyr_gabab_tau1', 1.), - ('L2Pyr_gabab_tau2', 20.)]) - - self.dL2PyrBiophys = OrderedDict([('L2Pyr_soma_gkbar_hh2', 0.01), # Biophysics soma - ('L2Pyr_soma_gnabar_hh2', 0.18), - ('L2Pyr_soma_el_hh2', -65.), - ('L2Pyr_soma_gl_hh2', 4.26e-5), - ('L2Pyr_soma_gbar_km', 250.), - # Biophysics dends - ('L2Pyr_dend_gkbar_hh2', 0.01), - ('L2Pyr_dend_gnabar_hh2', 0.15), - ('L2Pyr_dend_el_hh2', -65.), - ('L2Pyr_dend_gl_hh2', 4.26e-5), - ('L2Pyr_dend_gbar_km', 250.)]) - - - self.dL5PyrGeom = OrderedDict([('L5Pyr_soma_L', 39.), # Soma - ('L5Pyr_soma_diam', 28.9), - ('L5Pyr_soma_cm', 0.85), - ('L5Pyr_soma_Ra', 200.), - # Dendrites - ('L5Pyr_dend_cm', 0.85), - ('L5Pyr_dend_Ra', 200.), - ('L5Pyr_apicaltrunk_L', 102.), - ('L5Pyr_apicaltrunk_diam', 10.2), - ('L5Pyr_apical1_L', 680.), - ('L5Pyr_apical1_diam', 7.48), - ('L5Pyr_apical2_L', 680.), - ('L5Pyr_apical2_diam', 4.93), - ('L5Pyr_apicaltuft_L', 425.), - ('L5Pyr_apicaltuft_diam', 3.4), - ('L5Pyr_apicaloblique_L', 255.), - ('L5Pyr_apicaloblique_diam', 5.1), - ('L5Pyr_basal1_L', 85.), - ('L5Pyr_basal1_diam', 6.8), - ('L5Pyr_basal2_L', 255.), - ('L5Pyr_basal2_diam', 8.5), - ('L5Pyr_basal3_L', 255.), - ('L5Pyr_basal3_diam', 8.5)]) - - self.dL5PyrSyn = OrderedDict([('L5Pyr_ampa_e', 0.), # Synapses - ('L5Pyr_ampa_tau1', 0.5), - ('L5Pyr_ampa_tau2', 5.), - ('L5Pyr_nmda_e', 0.), - ('L5Pyr_nmda_tau1', 1.), - ('L5Pyr_nmda_tau2', 20.), - ('L5Pyr_gabaa_e', -80.), - ('L5Pyr_gabaa_tau1', 0.5), - ('L5Pyr_gabaa_tau2', 5.), - ('L5Pyr_gabab_e', -80.), - ('L5Pyr_gabab_tau1', 1.), - ('L5Pyr_gabab_tau2', 20.)]) - - self.dL5PyrBiophys = OrderedDict([('L5Pyr_soma_gkbar_hh2', 0.01), # Biophysics soma - ('L5Pyr_soma_gnabar_hh2', 0.16), - ('L5Pyr_soma_el_hh2', -65.), - ('L5Pyr_soma_gl_hh2', 4.26e-5), - 
('L5Pyr_soma_gbar_ca', 60.), - ('L5Pyr_soma_taur_cad', 20.), - ('L5Pyr_soma_gbar_kca', 2e-4), - ('L5Pyr_soma_gbar_km', 200.), - ('L5Pyr_soma_gbar_cat', 2e-4), - ('L5Pyr_soma_gbar_ar', 1e-6), - # Biophysics dends - ('L5Pyr_dend_gkbar_hh2', 0.01), - ('L5Pyr_dend_gnabar_hh2', 0.14), - ('L5Pyr_dend_el_hh2', -71.), - ('L5Pyr_dend_gl_hh2', 4.26e-5), - ('L5Pyr_dend_gbar_ca', 60.), - ('L5Pyr_dend_taur_cad', 20.), - ('L5Pyr_dend_gbar_kca', 2e-4), - ('L5Pyr_dend_gbar_km', 200.), - ('L5Pyr_dend_gbar_cat', 2e-4), - ('L5Pyr_dend_gbar_ar', 1e-6)]) - - dtrans = {'gkbar':'Kv', 'gnabar':'Na', 'km':'Km', 'gl':'leak',\ - 'ca':'Ca', 'kca':'KCa','cat':'CaT','ar':'HCN','cad':'Ca decay time',\ - 'dend':'Dendrite','soma':'Soma','apicaltrunk':'Apical Dendrite Trunk',\ - 'apical1':'Apical Dendrite 1','apical2':'Apical Dendrite 2',\ - 'apical3':'Apical Dendrite 3','apicaltuft':'Apical Dendrite Tuft',\ - 'apicaloblique':'Oblique Apical Dendrite','basal1':'Basal Dendrite 1',\ - 'basal2':'Basal Dendrite 2','basal3':'Basal Dendrite 3'} - - for d in [self.dL2PyrGeom, self.dL5PyrGeom]: - for k in d.keys(): - lk = k.split('_') - if lk[-1] == 'L': - self.addtransvar(k,dtrans[lk[1]] + ' ' + r'length (micron)') - elif lk[-1] == 'diam': - self.addtransvar(k,dtrans[lk[1]] + ' ' + r'diameter (micron)') - elif lk[-1] == 'cm': - self.addtransvar(k,dtrans[lk[1]] + ' ' + r'capacitive density (F/cm2)') - elif lk[-1] == 'Ra': - self.addtransvar(k,dtrans[lk[1]] + ' ' + r'resistivity (ohm-cm)') - - for d in [self.dL2PyrSyn, self.dL5PyrSyn]: - for k in d.keys(): - lk = k.split('_') - if k.endswith('e'): - self.addtransvar(k,lk[1].upper() + ' ' + ' reversal (mV)') - elif k.endswith('tau1'): - self.addtransvar(k,lk[1].upper() + ' ' + ' rise time (ms)') - elif k.endswith('tau2'): - self.addtransvar(k,lk[1].upper() + ' ' + ' decay time (ms)') - - for d in [self.dL2PyrBiophys, self.dL5PyrBiophys]: - for k in d.keys(): - lk = k.split('_') - if lk[2].count('g') > 0: - if lk[3]=='km' or lk[3]=='ca' or lk[3]=='kca' or lk[3]=='cat' or lk[3]=='ar': - nv = dtrans[lk[1]] + ' ' + dtrans[lk[3]] + ' ' + ' channel density ' - else: - nv = dtrans[lk[1]] + ' ' + dtrans[lk[2]] + ' ' + ' channel density ' - if lk[3] == 'hh2' or lk[3] == 'cat' or lk[3] == 'ar' : nv += '(S/cm2)' - else: nv += '(pS/micron2)' - elif lk[2].count('el') > 0: - nv = dtrans[lk[1]] + ' leak reversal (mV)' - elif lk[2].count('taur') > 0: - nv = dtrans[lk[1]] + ' ' + dtrans[lk[3]] + ' (ms)' - self.addtransvar(k,nv) - - self.ldict = [self.dL2PyrGeom, self.dL2PyrSyn, self.dL2PyrBiophys,\ - self.dL5PyrGeom, self.dL5PyrSyn, self.dL5PyrBiophys] - self.ltitle = [ 'L2/3 Pyr Geometry', 'L2/3 Pyr Synapses', 'L2/3 Pyr Biophysics',\ - 'L5 Pyr Geometry', 'L5 Pyr Synapses', 'L5 Pyr Biophysics'] - self.stitle = 'Cell Parameters' - - -# widget to specify network parameters (number cells, weights, etc.) 
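As a quick key to the gbar_* names that NetworkParamDialog (below) exposes, the label translation it performs can be sketched as follows; the helper name is ours, and the mapping mirrors the addtransvar() calls in that class:

def weight_label(key):
    parts = key.split('_')       # e.g. 'gbar_L5Basket_L5Pyr_gabaa' or 'gbar_L2Pyr_L5Basket'
    pretty = {'L2': 'L2/3 ', 'L5': 'L5 '}
    src = pretty[parts[1][:2]] + parts[1][2:]
    dst = pretty[parts[2][:2]] + parts[2][2:]
    if len(parts) == 3:          # key without an explicit receptor
        return '%s -> %s weight (µS)' % (src, dst)
    return '%s -> %s %s weight (µS)' % (src, dst, parts[3].upper())

print(weight_label('gbar_L5Basket_L5Pyr_gabaa'))  # L5 Basket -> L5 Pyr GABAA weight (µS)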
-class NetworkParamDialog (DictDialog): - def __init__ (self, parent = None, din = None): - super(NetworkParamDialog, self).__init__(parent,din) - self.addHideButton() - - def initd (self): - # number of cells - self.dcells = OrderedDict([('N_pyr_x', 10), - ('N_pyr_y', 10)]) - - # max conductances TO L2Pyr - self.dL2Pyr = OrderedDict([('gbar_L2Pyr_L2Pyr_ampa', 0.), - ('gbar_L2Pyr_L2Pyr_nmda', 0.), - ('gbar_L2Basket_L2Pyr_gabaa', 0.), - ('gbar_L2Basket_L2Pyr_gabab', 0.)]) - - # max conductances TO L2Baskets - self.dL2Bas = OrderedDict([('gbar_L2Pyr_L2Basket', 0.), - ('gbar_L2Basket_L2Basket', 0.)]) - - # max conductances TO L5Pyr - self.dL5Pyr = OrderedDict([('gbar_L2Pyr_L5Pyr', 0.), - ('gbar_L2Basket_L5Pyr', 0.), - ('gbar_L5Pyr_L5Pyr_ampa', 0.), - ('gbar_L5Pyr_L5Pyr_nmda', 0.), - ('gbar_L5Basket_L5Pyr_gabaa', 0.), - ('gbar_L5Basket_L5Pyr_gabab', 0.)]) - - # max conductances TO L5Baskets - self.dL5Bas = OrderedDict([('gbar_L2Pyr_L5Basket', 0.), - ('gbar_L5Pyr_L5Basket', 0.), - ('gbar_L5Basket_L5Basket', 0.)]) - - self.ldict = [self.dcells, self.dL2Pyr, self.dL5Pyr, self.dL2Bas, self.dL5Bas] - self.ltitle = ['Cells', 'Layer 2/3 Pyr', 'Layer 5 Pyr', 'Layer 2/3 Bas', 'Layer 5 Bas'] - self.stitle = 'Local Network Parameters' - - self.addtransvar('N_pyr_x', 'Num Pyr Cells (X direction)') - self.addtransvar('N_pyr_y', 'Num Pyr Cells (Y direction)') - - dtmp = {'L2':'L2/3 ','L5':'L5 '} - - for d in [self.dL2Pyr, self.dL5Pyr, self.dL2Bas, self.dL5Bas]: - for k in d.keys(): - lk = k.split('_') - sty1 = dtmp[lk[1][0:2]] + lk[1][2:] - sty2 = dtmp[lk[2][0:2]] + lk[2][2:] - if len(lk) == 3: - self.addtransvar(k,sty1+' -> '+sty2+u' weight (µS)') - else: - self.addtransvar(k,sty1+' -> '+sty2+' '+lk[3].upper()+u' weight (µS)') - -class HelpDialog (QDialog): - def __init__ (self, parent): - super(HelpDialog, self).__init__(parent) - self.initUI() - - def initUI (self): - self.layout = QVBoxLayout(self) - # Add stretch to separate the form layout from the button - self.layout.addStretch(1) - - setscalegeom(self, 100, 100, 300, 100) - self.setWindowTitle('Help') - -# dialog for visualizing model -class VisnetDialog (QDialog): - def __init__ (self, parent): - super(VisnetDialog, self).__init__(parent) - self.initUI() - - def showcells3D (self): Popen([getPyComm(), 'visnet.py', 'cells', paramf]) # nonblocking - def showEconn (self): Popen([getPyComm(), 'visnet.py', 'Econn', paramf]) # nonblocking - def showIconn (self): Popen([getPyComm(), 'visnet.py', 'Iconn', paramf]) # nonblocking - - def runvisnet (self): - lcmd = [getPyComm(), 'visnet.py', 'cells'] - #if self.chkcells.isChecked(): lcmd.append('cells') - #if self.chkE.isChecked(): lcmd.append('Econn') - #if self.chkI.isChecked(): lcmd.append('Iconn') - lcmd.append(paramf) - Popen(lcmd) # nonblocking - - def initUI (self): - - self.layout = QVBoxLayout(self) - - # Add stretch to separate the form layout from the button - # self.layout.addStretch(1) - - """ - self.chkcells = QCheckBox('Cells in 3D',self) - self.chkcells.resize(self.chkcells.sizeHint()) - self.chkcells.setChecked(True) - self.layout.addWidget(self.chkcells) - self.chkE = QCheckBox('Excitatory Connections',self) - self.chkE.resize(self.chkE.sizeHint()) - self.layout.addWidget(self.chkE) - - self.chkI = QCheckBox('Inhibitory Connections',self) - self.chkI.resize(self.chkI.sizeHint()) - self.layout.addWidget(self.chkI) - """ - - # Create a horizontal box layout to hold the buttons - self.button_box = QHBoxLayout() - - self.btnok = QPushButton('Visualize',self) - 
self.btnok.resize(self.btnok.sizeHint()) - self.btnok.clicked.connect(self.runvisnet) - self.button_box.addWidget(self.btnok) - - self.btncancel = QPushButton('Cancel',self) - self.btncancel.resize(self.btncancel.sizeHint()) - self.btncancel.clicked.connect(self.hide) - self.button_box.addWidget(self.btncancel) - - self.layout.addLayout(self.button_box) - - setscalegeom(self, 100, 100, 300, 100) - - self.setWindowTitle('Visualize Model') - -class SchematicDialog (QDialog): - # class for holding model schematics (and parameter shortcuts) - def __init__ (self, parent): - super(SchematicDialog, self).__init__(parent) - self.initUI() - - def initUI (self): - - self.setWindowTitle('Model Schematics') - QToolTip.setFont(QFont('SansSerif', 10)) - - self.grid = grid = QGridLayout() - grid.setSpacing(10) - - gRow = 0 - - self.locbtn = QPushButton('Local Network'+os.linesep+'Connections',self) - self.locbtn.setIcon(QIcon(lookupresource('connfig'))) - self.locbtn.clicked.connect(self.parent().shownetparamwin) - self.grid.addWidget(self.locbtn,gRow,0,1,1) - - self.proxbtn = QPushButton('Proximal Drive'+os.linesep+'Thalamus',self) - self.proxbtn.setIcon(QIcon(lookupresource('proxfig'))) - self.proxbtn.clicked.connect(self.parent().showproxparamwin) - self.grid.addWidget(self.proxbtn,gRow,1,1,1) - - self.distbtn = QPushButton('Distal Drive NonLemniscal'+os.linesep+'Thal./Cortical Feedback',self) - self.distbtn.setIcon(QIcon(lookupresource('distfig'))) - self.distbtn.clicked.connect(self.parent().showdistparamwin) - self.grid.addWidget(self.distbtn,gRow,2,1,1) - - self.netbtn = QPushButton('Model'+os.linesep+'Visualization',self) - self.netbtn.setIcon(QIcon(lookupresource('netfig'))) - self.netbtn.clicked.connect(self.parent().showvisnet) - self.grid.addWidget(self.netbtn,gRow,3,1,1) - - gRow = 1 - - # for schematic dialog box - self.pixConn = QPixmap(lookupresource('connfig')) - self.pixConnlbl = ClickLabel(self) - self.pixConnlbl.setScaledContents(True) - #self.pixConnlbl.resize(self.pixConnlbl.size()) - self.pixConnlbl.setPixmap(self.pixConn) - # self.pixConnlbl.clicked.connect(self.shownetparamwin) - self.grid.addWidget(self.pixConnlbl,gRow,0,1,1) - - self.pixProx = QPixmap(lookupresource('proxfig')) - self.pixProxlbl = ClickLabel(self) - self.pixProxlbl.setScaledContents(True) - self.pixProxlbl.setPixmap(self.pixProx) - # self.pixProxlbl.clicked.connect(self.showproxparamwin) - self.grid.addWidget(self.pixProxlbl,gRow,1,1,1) - - self.pixDist = QPixmap(lookupresource('distfig')) - self.pixDistlbl = ClickLabel(self) - self.pixDistlbl.setScaledContents(True) - self.pixDistlbl.setPixmap(self.pixDist) - # self.pixDistlbl.clicked.connect(self.showdistparamwin) - self.grid.addWidget(self.pixDistlbl,gRow,2,1,1) - - self.pixNet = QPixmap(lookupresource('netfig')) - self.pixNetlbl = ClickLabel(self) - self.pixNetlbl.setScaledContents(True) - self.pixNetlbl.setPixmap(self.pixNet) - # self.pixNetlbl.clicked.connect(self.showvisnet) - self.grid.addWidget(self.pixNetlbl,gRow,3,1,1) - - self.setLayout(grid) - -class BaseParamDialog (QDialog): - # base widget for specifying params (contains buttons to create other widgets - def __init__ (self, parent, optrun_func): - super(BaseParamDialog, self).__init__(parent) - self.proxparamwin = self.distparamwin = self.netparamwin = self.syngainparamwin = None - self.initUI() - self.runparamwin = RunParamDialog(self) - self.cellparamwin = CellParamDialog(self) - self.netparamwin = NetworkParamDialog(self) - self.syngainparamwin = SynGainParamDialog(self,self.netparamwin) - 
self.proxparamwin = OngoingInputParamDialog(self,'Proximal') - self.distparamwin = OngoingInputParamDialog(self,'Distal') - self.evparamwin = EvokedInputParamDialog(self,None) - self.optparamwin = OptEvokedInputParamDialog(self,optrun_func) - self.poisparamwin = PoissonInputParamDialog(self,None) - self.tonicparamwin = TonicInputParamDialog(self,None) - self.lsubwin = [self.runparamwin, self.cellparamwin, self.netparamwin, - self.proxparamwin, self.distparamwin, self.evparamwin, - self.poisparamwin, self.tonicparamwin, self.optparamwin] - self.updateDispParam() - self.parent = parent - - def updateDispParam (self): - # now update the GUI components to reflect the param file selected - try: - validate_param_file(paramf) - except ValueError: - QMessageBox.information(self, "HNN", "WARNING: could not retrieve parameters from %s" % paramf) - return - - din = quickreadprm(paramf) - - if usingEvokedInputs(din): # default for evoked is to show average dipole - conf.dconf['drawavgdpl'] = True - elif usingOngoingInputs(din): # default for ongoing is NOT to show average dipole - conf.dconf['drawavgdpl'] = False - - for dlg in self.lsubwin: dlg.setfromdin(din) # update to values from file - self.qle.setText(paramf.split(os.path.sep)[-1].split('.param')[0]) # update simulation name - - def setrunparam (self): bringwintotop(self.runparamwin) - def setcellparam (self): bringwintotop(self.cellparamwin) - def setnetparam (self): bringwintotop(self.netparamwin) - def setsyngainparam (self): bringwintotop(self.syngainparamwin) - def setproxparam (self): bringwintotop(self.proxparamwin) - def setdistparam (self): bringwintotop(self.distparamwin) - def setevparam (self): bringwintotop(self.evparamwin) - def setpoisparam (self): bringwintotop(self.poisparamwin) - def settonicparam (self): bringwintotop(self.tonicparamwin) - - def initUI (self): - - grid = QGridLayout() - grid.setSpacing(10) - - row = 1 - - self.lbl = QLabel(self) - self.lbl.setText('Simulation Name:') - self.lbl.adjustSize() - self.lbl.setToolTip('Simulation Name used to save parameter file and simulation data') - grid.addWidget(self.lbl, row, 0) - self.qle = QLineEdit(self) - self.qle.setText(paramf.split(os.path.sep)[-1].split('.param')[0]) - grid.addWidget(self.qle, row, 1) - row+=1 - - self.btnrun = QPushButton('Run',self) - self.btnrun.resize(self.btnrun.sizeHint()) - self.btnrun.setToolTip('Set Run Parameters') - self.btnrun.clicked.connect(self.setrunparam) - grid.addWidget(self.btnrun, row, 0, 1, 1) - - self.btncell = QPushButton('Cell',self) - self.btncell.resize(self.btncell.sizeHint()) - self.btncell.setToolTip('Set Cell (Geometry, Synapses, Biophysics) Parameters') - self.btncell.clicked.connect(self.setcellparam) - grid.addWidget(self.btncell, row, 1, 1, 1) - row+=1 - - self.btnnet = QPushButton('Local Network',self) - self.btnnet.resize(self.btnnet.sizeHint()) - self.btnnet.setToolTip('Set Local Network Parameters') - self.btnnet.clicked.connect(self.setnetparam) - grid.addWidget(self.btnnet, row, 0, 1, 1) - - self.btnsyngain = QPushButton('Synaptic Gains',self) - self.btnsyngain.resize(self.btnsyngain.sizeHint()) - self.btnsyngain.setToolTip('Set Local Network Synaptic Gains') - self.btnsyngain.clicked.connect(self.setsyngainparam) - grid.addWidget(self.btnsyngain, row, 1, 1, 1) - - row+=1 - - self.btnprox = QPushButton('Rhythmic Proximal Inputs',self) - self.btnprox.resize(self.btnprox.sizeHint()) - self.btnprox.setToolTip('Set Rhythmic Proximal Inputs') - self.btnprox.clicked.connect(self.setproxparam) - 
grid.addWidget(self.btnprox, row, 0, 1, 2); row+=1 - - self.btndist = QPushButton('Rhythmic Distal Inputs',self) - self.btndist.resize(self.btndist.sizeHint()) - self.btndist.setToolTip('Set Rhythmic Distal Inputs') - self.btndist.clicked.connect(self.setdistparam) - grid.addWidget(self.btndist, row, 0, 1, 2) - row+=1 - - self.btnev = QPushButton('Evoked Inputs',self) - self.btnev.resize(self.btnev.sizeHint()) - self.btnev.setToolTip('Set Evoked Inputs') - self.btnev.clicked.connect(self.setevparam) - grid.addWidget(self.btnev, row, 0, 1, 2) - row+=1 - - self.btnpois = QPushButton('Poisson Inputs',self) - self.btnpois.resize(self.btnpois.sizeHint()) - self.btnpois.setToolTip('Set Poisson Inputs') - self.btnpois.clicked.connect(self.setpoisparam) - grid.addWidget(self.btnpois, row, 0, 1, 2) - row+=1 - - self.btntonic = QPushButton('Tonic Inputs',self) - self.btntonic.resize(self.btntonic.sizeHint()) - self.btntonic.setToolTip('Set Tonic (Current Clamp) Inputs') - self.btntonic.clicked.connect(self.settonicparam) - grid.addWidget(self.btntonic, row, 0, 1, 2) - row+=1 - - self.btnsave = QPushButton('Save Parameters To File',self) - self.btnsave.resize(self.btnsave.sizeHint()) - self.btnsave.setToolTip('Save All Parameters to File (Specified by Simulation Name)') - self.btnsave.clicked.connect(self.saveparams) - grid.addWidget(self.btnsave, row, 0, 1, 2) - row+=1 - - self.btnhide = QPushButton('Hide Window',self) - self.btnhide.resize(self.btnhide.sizeHint()) - self.btnhide.clicked.connect(self.hide) - self.btnhide.setToolTip('Hide Window') - grid.addWidget(self.btnhide, row, 0, 1, 2) - - self.setLayout(grid) - - self.setWindowTitle('Set Parameters') - - def saveparams (self, checkok = True): - global paramf,basedir - tmpf = os.path.join(dconf['paramoutdir'],self.qle.text() + '.param') - oktosave = True - if os.path.isfile(tmpf) and checkok: - self.show() - oktosave = False - msg = QMessageBox() - msg.setIcon(QMessageBox.Warning) - msg.setText(tmpf + ' already exists. 
Over-write?') - msg.setWindowTitle('Over-write file(s)?') - msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) - if msg.exec_() == QMessageBox.Ok: - oktosave = True - - if oktosave: - with open(tmpf,'w') as fp: - fp.write(str(self)) - - paramf = dconf['paramf'] = tmpf # update paramf - basedir = os.path.join(dconf['datdir'], self.qle.text()) - os.makedirs(basedir, exist_ok=True) - - return oktosave - - def updatesaveparams (self, dtest): - if debug: print('BaseParamDialog updatesaveparams: dtest=',dtest) - # update parameter values in GUI (so user can see and so GUI will save these param values) - for win in self.lsubwin: win.setfromdin(dtest) - # save parameters - do not ask if can over-write the param file - self.saveparams(checkok = False) - - def __str__ (self): - s = 'sim_prefix: ' + self.qle.text() + os.linesep - s += 'expmt_groups: {' + self.qle.text() + '}' + os.linesep - for win in self.lsubwin: s += str(win) - return s - -# clickable label -class ClickLabel (QLabel): - """ - def __init__(self, *args, **kwargs): - QLabel.__init__(self) - # self._pixmap = QPixmap(self.pixmap()) - # spolicy = QSizePolicy(QSizePolicy.MinimumExpanding,QSizePolicy.MinimumExpanding) - spolicy = QSizePolicy(QSizePolicy.Fixed,QSizePolicy.Fixed) - # spolicy = QSizePolicy(QSizePolicy.Preferred,QSizePolicy.Preferred) - # spolicy.setHorizontalStretch(0) - # spolicy.setVerticalStretch(0) - self.setSizePolicy(spolicy) - self.setMinimumWidth(150) - self.setMinimumHeight(150) - def setPixmap (self, pm): - QLabel.setPixmap(self,pm) - self._pixmap = pm - """ - clicked = pyqtSignal() - def mousePressEvent(self, event): - self.clicked.emit() - """ - def resizeEvent(self, event): - self.setPixmap(self._pixmap.scaled( - self.width(), self.height(), - QtCore.Qt.KeepAspectRatio)) - """ - -class WaitSimDialog (QDialog): - def __init__ (self, parent): - super(WaitSimDialog, self).__init__(parent) - self.initUI() - self.txt = '' # text for display - - def updatetxt (self,txt): - self.qtxt.append(txt) - - def initUI (self): - self.layout = QVBoxLayout(self) - self.layout.addStretch(1) - - self.qtxt = QTextEdit(self) - self.layout.addWidget(self.qtxt) - - self.stopbtn = stopbtn = QPushButton('Stop All Simulations', self) - stopbtn.setToolTip('Stop All Simulations') - stopbtn.resize(stopbtn.sizeHint()) - stopbtn.clicked.connect(self.stopsim) - self.layout.addWidget(stopbtn) - - setscalegeomcenter(self, 500, 250) - self.setWindowTitle("Simulation Log") - - def stopsim (self): - self.parent().stopsim() - self.hide() - - -class HNNGUI (QMainWindow): - # main HNN GUI class - def __init__ (self): - # initialize the main HNN GUI - global paramf, basedir - super().__init__() - self.runningsim = False - self.runthread = None - self.fontsize = dconf['fontsize'] - self.linewidth = plt.rcParams['lines.linewidth'] = 1 - self.markersize = plt.rcParams['lines.markersize'] = 5 - self.dextdata = OrderedDict() # external data - self.schemwin = SchematicDialog(self) - self.m = self.toolbar = None - self.baseparamwin = BaseParamDialog(self, self.startoptmodel) - self.optMode = False - self.initUI() - self.visnetwin = VisnetDialog(self) - self.helpwin = HelpDialog(self) - self.erselectdistal = EvokedOrRhythmicDialog(self, True, self.baseparamwin.evparamwin, self.baseparamwin.distparamwin) - self.erselectprox = EvokedOrRhythmicDialog(self, False, self.baseparamwin.evparamwin, self.baseparamwin.proxparamwin) - self.waitsimwin = WaitSimDialog(self) - default_param = os.path.join(dconf['dbase'],'data','default') - first_load = not 
(os.path.exists(default_param)) - - if "TRAVIS_TESTING" in os.environ and os.environ["TRAVIS_TESTING"] == "1": - print("Exiting because HNN was started with TRAVIS_TESTING=1") - qApp.quit() - exit(0) - - if first_load: - QMessageBox.information(self, "HNN", "Welcome to HNN! Default parameter file loaded. " - "Press 'Run Simulation' to display simulation output") - else: - self.statusBar().showMessage("Loaded %s"%default_param) - # successful initialization, catch all further exceptions - sys.excepthook = self.excepthook - - def _add_missing_frames(self, tb): - fake_tb = namedtuple( - 'fake_tb', ('tb_frame', 'tb_lasti', 'tb_lineno', 'tb_next') - ) - result = fake_tb(tb.tb_frame, tb.tb_lasti, tb.tb_lineno, tb.tb_next) - frame = tb.tb_frame.f_back - while frame: - result = fake_tb(frame, frame.f_lasti, frame.f_lineno, result) - frame = frame.f_back - return result - - def excepthook(self, exc_type, exc_value, exc_tb): - enriched_tb = self._add_missing_frames(exc_tb) if exc_tb else exc_tb - # Note: sys.__excepthook__(...) would not work here. - # We need to use print_exception(...): - traceback.print_exception(exc_type, exc_value, enriched_tb) - msgBox = QMessageBox(self) - msgBox.information(self, "Exception", "WARNING: an exception occurred! " - "Details can be found in the console output. Please " - "include this output when opening an issue og GitHub: " - "" - "https://github.com/jonescompneurolab/hnn/issues") - - - def redraw (self): - # redraw simulation & external data - self.m.plot() - self.m.draw() - - def changeFontSize (self): - # bring up window to change font sizes - i, ok = QInputDialog.getInt(self, "Set Font Size","Font Size:", plt.rcParams['font.size'], 1, 100, 1) - if ok: - self.fontsize = plt.rcParams['font.size'] = dconf['fontsize'] = i - self.redraw() - - def changeLineWidth (self): - # bring up window to change line width(s) - i, ok = QInputDialog.getInt(self, "Set Line Width","Line Width:", plt.rcParams['lines.linewidth'], 1, 20, 1) - if ok: - self.linewidth = plt.rcParams['lines.linewidth'] = i - self.redraw() - - def changeMarkerSize (self): - # bring up window to change marker size - i, ok = QInputDialog.getInt(self, "Set Marker Size","Font Size:", self.markersize, 1, 100, 1) - if ok: - self.markersize = plt.rcParams['lines.markersize'] = i - self.redraw() - - def selParamFileDialog (self): - # bring up window to select simulation parameter file - global paramf,basedir - qfd = QFileDialog() - qfd.setHistory([os.path.join(dconf['dbase'],'param'), os.path.join(hnn_root_dir,'param')]) - fn = qfd.getOpenFileName(self, 'Open param file', - os.path.join(hnn_root_dir,'param'), - "Param files (*.param)") - if len(fn) > 0 and fn[0] == '': - # no file selected in dialog - return - - paramf = os.path.abspath(fn[0]) # to make sure have right path separators on Windows OS - param_fname = os.path.splitext(os.path.basename(paramf)) - basedir = os.path.join(dconf['datdir'], param_fname[0]) - - try: - validate_param_file(paramf) - except ValueError: - QMessageBox.information(self, "HNN", "WARNING: could not retrieve parameters from %s" % fn[0]) - return - - # now update the GUI components to reflect the param file selected - self.baseparamwin.updateDispParam() - self.initSimCanvas() # recreate canvas - # self.m.plot() # replot data - self.setWindowTitle(paramf) - # store the sim just loaded in simdat's list - is this the desired behavior? or should we first erase prev sims? 
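The `excepthook`/`_add_missing_frames` pair above rebuilds the traceback chain by walking `f_back` from the frame where the exception surfaced, so that callers that would otherwise be missing from the report (for example, frames on the far side of the Qt event loop) are printed as well. A minimal standalone sketch of the same idea, with nothing HNN-specific in it (`worker`, `caller`, and `handle` are illustrative names, not from this codebase):

```python
# Sketch: enrich a traceback with the frames *above* the point where the
# exception was handled, so traceback.print_exception() shows the callers too.
import sys
import traceback
from collections import namedtuple

fake_tb = namedtuple('fake_tb', ('tb_frame', 'tb_lasti', 'tb_lineno', 'tb_next'))

def add_missing_frames(tb):
    # keep the real traceback, then prepend one fake node per parent frame
    result = fake_tb(tb.tb_frame, tb.tb_lasti, tb.tb_lineno, tb.tb_next)
    frame = tb.tb_frame.f_back
    while frame:
        result = fake_tb(frame, frame.f_lasti, frame.f_lineno, result)
        frame = frame.f_back
    return result

def handle(exc_type, exc_value, exc_tb):
    enriched = add_missing_frames(exc_tb) if exc_tb else exc_tb
    traceback.print_exception(exc_type, exc_value, enriched)

def worker():
    try:
        raise RuntimeError('demo error')
    except RuntimeError:
        # without add_missing_frames the report starts at worker();
        # with it, caller() and the module frame are listed as well
        handle(*sys.exc_info())

def caller():
    worker()

caller()
```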
- import simdat - if 'dpl' in simdat.ddat: - simdat.updatelsimdat(paramf,simdat.ddat['dpl']) # update lsimdat and its current sim index - self.populateSimCB() # populate the combobox - - if len(self.dextdata) > 0: - self.toggleEnableOptimization(True) - - def loadDataFile (self, fn): - # load a dipole data file - global paramf - - import simdat - try: - self.dextdata[fn] = np.loadtxt(fn) - except ValueError: - # possible that data file is comma delimted instead of whitespace delimted - try: - self.dextdata[fn] = np.loadtxt(fn, delimiter=',') - except ValueError: - QMessageBox.information(self, "HNN", "WARNING: could not load data file %s" % fn) - return False - except IsADirectoryError: - QMessageBox.information(self, "HNN", "WARNING: could not load data file %s" % fn) - return False - - simdat.ddat['dextdata'] = self.dextdata - print('Loaded data in ', fn) - - self.m.plot() - self.m.draw() # make sure new lines show up in plot - - if paramf: - self.toggleEnableOptimization(True) - return True - - def loadDataFileDialog (self): - # bring up window to select/load external dipole data file - qfd = QFileDialog() - qfd.setHistory([os.path.join(dconf['dbase'],'data'), os.path.join(hnn_root_dir,'data')]) - fn = qfd.getOpenFileName(self, 'Open data file', - os.path.join(hnn_root_dir,'data'), - "Data files (*.txt)") - if len(fn) > 0 and fn[0] == '': - # no file selected in dialog - return - - self.loadDataFile(os.path.abspath(fn[0])) # use abspath to make sure have right path separators - - def clearDataFile (self): - # clear external dipole data - import simdat - self.m.clearlextdatobj() - self.dextdata = simdat.ddat['dextdata'] = OrderedDict() - self.toggleEnableOptimization(False) - self.m.plot() # recreate canvas - self.m.draw() - - def setparams (self): - # show set parameters dialog window - if self.baseparamwin: - for win in self.baseparamwin.lsubwin: bringwintobot(win) - bringwintotop(self.baseparamwin) - - def showAboutDialog (self): - # show HNN's about dialog box - from __init__ import __version__ - msgBox = QMessageBox(self) - msgBox.setTextFormat(Qt.RichText) - msgBox.setWindowTitle('About') - msgBox.setText("Human Neocortical Neurosolver (HNN) v" + __version__ + "
"+\ - "https://hnn.brown.edu
"+\ - "HNN On Github
"+\ - "© 2017-2019 Brown University, Providence, RI
"+\ - "Software License") - msgBox.setStandardButtons(QMessageBox.Ok) - msgBox.exec_() - - def showOptWarnDialog (self): - # TODO : not implemented yet - msgBox = QMessageBox(self) - msgBox.setTextFormat(Qt.RichText) - msgBox.setWindowTitle('Warning') - msgBox.setText("") - msgBox.setStandardButtons(QMessageBox.Ok) - msgBox.exec_() - - def showHelpDialog (self): - # show the help dialog box - bringwintotop(self.helpwin) - - def showSomaVPlot (self): - # start the somatic voltage visualization process (separate window) - global basedir - if not float(self.baseparamwin.runparamwin.getval('save_vsoma')): - smsg='In order to view somatic voltages you must first rerun the simulation with saving somatic voltages. To do so from the main GUI, click on Set Parameters -> Run -> Analysis -> Save Somatic Voltages, enter a 1 and then rerun the simulation.' - msg = QMessageBox() - msg.setIcon(QMessageBox.Information) - msg.setText(smsg) - msg.setWindowTitle('Rerun simulation') - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - else: - lcmd = [getPyComm(), 'visvolt.py',paramf] - if debug: print('visvolt cmd:',lcmd) - Popen(lcmd) # nonblocking - - def showPSDPlot (self): - # start the PSD visualization process (separate window) - global basedir - lcmd = [getPyComm(), 'vispsd.py',paramf] - if debug: print('vispsd cmd:',lcmd) - Popen(lcmd) # nonblocking - - def showLFPPlot (self): - # start the LFP visualization process (separate window) - global basedir - lcmd = [getPyComm(), 'vislfp.py',paramf] - if debug: print('vislfp cmd:',lcmd) - Popen(lcmd) # nonblocking - - def showSpecPlot (self): - # start the spectrogram visualization process (separate window) - global basedir - lcmd = [getPyComm(), 'visspec.py',paramf] - if debug: print('visspec cmd:',lcmd) - Popen(lcmd) # nonblocking - - def showRasterPlot (self): - # start the raster plot visualization process (separate window) - global basedir - - spikefile = os.path.join(basedir,'spk.txt') - if os.path.isfile(spikefile): - lcmd = [getPyComm(), 'visrast.py',paramf,spikefile] - else: - QMessageBox.information(self, "HNN", "WARNING: no spiking data at %s" % spikefile) - return - - if dconf['drawindivrast']: lcmd.append('indiv') - if debug: print('visrast cmd:',lcmd) - Popen(lcmd) # nonblocking - - def showDipolePlot (self): - # start the dipole visualization process (separate window) - global basedir - - dipole_file = os.path.join(basedir,'dpl.txt') - if os.path.isfile(dipole_file): - lcmd = [getPyComm(), 'visdipole.py',paramf,dipole_file] - else: - QMessageBox.information(self, "HNN", "WARNING: no dipole data at %s" % dipole_file) - return - - if debug: print('visdipole cmd:',lcmd) - Popen(lcmd) # nonblocking - - def showwaitsimwin (self): - # show the wait sim window (has simulation log) - bringwintotop(self.waitsimwin) - - def togAvgDpl (self): - # toggle drawing of the average (across trials) dipole - conf.dconf['drawavgdpl'] = not conf.dconf['drawavgdpl'] - self.m.plot() - self.m.draw() - - def hidesubwin (self): - # hide GUI's sub windows - self.baseparamwin.hide() - self.schemwin.hide() - self.baseparamwin.syngainparamwin.hide() - for win in self.baseparamwin.lsubwin: win.hide() - self.activateWindow() - - def distribsubwin (self): - # distribute GUI's sub-windows on screen - sw,sh = getscreengeom() - lwin = [win for win in self.baseparamwin.lsubwin if win.isVisible()] - if self.baseparamwin.isVisible(): lwin.insert(0,self.baseparamwin) - if self.schemwin.isVisible(): lwin.insert(0,self.schemwin) - if self.baseparamwin.syngainparamwin.isVisible(): 
lwin.append(self.baseparamwin.syngainparamwin) - curx,cury,maxh=0,0,0 - for win in lwin: - win.move(curx, cury) - curx += win.width() - maxh = max(maxh,win.height()) - if curx >= sw: - curx = 0 - cury += maxh - maxh = win.height() - if cury >= sh: cury = cury = 0 - - def updateDatCanv (self,fn): - # update the simulation data and canvas - try: - getinputfiles(fn) # reset input data - if already exists - except: - pass - # now update the GUI components to reflect the param file selected - self.baseparamwin.updateDispParam() - self.initSimCanvas() # recreate canvas - self.setWindowTitle(fn) - - def removeSim (self): - # remove the currently selected simulation - global paramf,basedir - import simdat - if debug: print('removeSim',paramf,simdat.lsimidx) - if len(simdat.lsimdat) > 0 and simdat.lsimidx >= 0: - cidx = self.cbsim.currentIndex() # - a = simdat.lsimdat[:cidx] - b = simdat.lsimdat[cidx+1:] - c = [x for x in a] - for x in b: c.append(x) - simdat.lsimdat = c - self.cbsim.removeItem(cidx) - simdat.lsimidx = max(0,len(simdat.lsimdat) - 1) - if len(simdat.lsimdat) > 0: - paramf = simdat.lsimdat[simdat.lsimidx][0] - param_fname = os.path.splitext(os.path.basename(paramf)) - basedir = os.path.join(dconf['datdir'], param_fname[0]) - if debug: print('new paramf:',paramf,simdat.lsimidx) - self.updateDatCanv(paramf) - self.cbsim.setCurrentIndex(simdat.lsimidx) - else: - self.clearSimulations() - - def prevSim (self): - # go to previous simulation - global paramf,basedir - import simdat - if debug: print('prevSim',paramf,simdat.lsimidx) - if len(simdat.lsimdat) > 0 and simdat.lsimidx > 0: - simdat.lsimidx -= 1 - paramf = simdat.lsimdat[simdat.lsimidx][0] - param_fname = os.path.splitext(os.path.basename(paramf)) - basedir = os.path.join(dconf['datdir'], param_fname[0]) - if debug: print('new paramf:',paramf,simdat.lsimidx) - self.updateDatCanv(paramf) - self.cbsim.setCurrentIndex(simdat.lsimidx) - - def nextSim (self): - # go to next simulation - global paramf,basedir - import simdat - if debug: print('nextSim',paramf,simdat.lsimidx) - if len(simdat.lsimdat) > 0 and simdat.lsimidx + 1 < len(simdat.lsimdat): - simdat.lsimidx += 1 - paramf = simdat.lsimdat[simdat.lsimidx][0] - param_fname = os.path.splitext(os.path.basename(paramf)) - basedir = os.path.join(dconf['datdir'], param_fname[0]) - if debug: print('new paramf:',paramf,simdat.lsimidx) - self.updateDatCanv(paramf) - self.cbsim.setCurrentIndex(simdat.lsimidx) - - def clearSimulationData (self): - # clear the simulation data - global paramf - import simdat - paramf = '' # set paramf to empty so no data gets loaded - basedir = None - simdat.ddat = {} # clear data in simdat.ddat - simdat.lsimdat = [] - simdat.lsimidx = 0 - self.populateSimCB() # un-populate the combobox - self.toggleEnableOptimization(False) - - - def clearSimulations (self): - # clear all simulation data and erase simulations from canvas (does not clear external data) - self.clearSimulationData() - self.initSimCanvas() # recreate canvas - self.m.draw() - self.setWindowTitle('') - - def clearCanvas (self): - # clear all simulation & external data and erase everything from the canvas - import simdat - self.clearSimulationData() - self.m.clearlextdatobj() # clear the external data - self.dextdata = simdat.ddat['dextdata'] = OrderedDict() - self.initSimCanvas() # recreate canvas - self.m.draw() - self.setWindowTitle('') - - def initMenu (self): - # initialize the GUI's menu - exitAction = QAction(QIcon.fromTheme('exit'), 'Exit', self) - exitAction.setShortcut('Ctrl+Q') - 
exitAction.setStatusTip('Exit HNN application') - exitAction.triggered.connect(qApp.quit) - - selParamFile = QAction(QIcon.fromTheme('open'), 'Load parameter file', self) - selParamFile.setShortcut('Ctrl+P') - selParamFile.setStatusTip('Load simulation parameter (.param) file') - selParamFile.triggered.connect(self.selParamFileDialog) - - clearCanv = QAction('Clear canvas', self) - clearCanv.setShortcut('Ctrl+X') - clearCanv.setStatusTip('Clear canvas (simulation+data)') - clearCanv.triggered.connect(self.clearCanvas) - - clearSims = QAction('Clear simulation(s)', self) - #clearSims.setShortcut('Ctrl+X') - clearSims.setStatusTip('Clear simulation(s)') - clearSims.triggered.connect(self.clearSimulations) - - loadDataFile = QAction(QIcon.fromTheme('open'), 'Load data file', self) - loadDataFile.setShortcut('Ctrl+D') - loadDataFile.setStatusTip('Load (dipole) data file') - loadDataFile.triggered.connect(self.loadDataFileDialog) - - clearDataFileAct = QAction(QIcon.fromTheme('close'), 'Clear data file(s)', self) - clearDataFileAct.setShortcut('Ctrl+C') - clearDataFileAct.setStatusTip('Clear (dipole) data file(s)') - clearDataFileAct.triggered.connect(self.clearDataFile) - - runSimAct = QAction('Run simulation', self) - runSimAct.setShortcut('Ctrl+S') - runSimAct.setStatusTip('Run simulation') - runSimAct.triggered.connect(self.controlsim) - - runSimNSGAct = QAction('Run simulation on NSG', self) - runSimNSGAct.setShortcut('Ctrl+N') - runSimNSGAct.setStatusTip('Run simulation on Neuroscience Gateway Portal (requires NSG account and internet connection).') - runSimNSGAct.triggered.connect(self.controlNSGsim) - - self.menubar = self.menuBar() - fileMenu = self.menubar.addMenu('&File') - self.menubar.setNativeMenuBar(False) - fileMenu.addAction(selParamFile) - fileMenu.addSeparator() - fileMenu.addAction(loadDataFile) - fileMenu.addAction(clearDataFileAct) - fileMenu.addSeparator() - fileMenu.addAction(exitAction) - - # part of edit menu for changing drawing properties (line thickness, font size, toggle avg dipole drawing) - editMenu = self.menubar.addMenu('&Edit') - viewAvgDplAction = QAction('Toggle Average Dipole Drawing',self) - viewAvgDplAction.setStatusTip('Toggle Average Dipole Drawing') - viewAvgDplAction.triggered.connect(self.togAvgDpl) - editMenu.addAction(viewAvgDplAction) - changeFontSizeAction = QAction('Change Font Size',self) - changeFontSizeAction.setStatusTip('Change Font Size.') - changeFontSizeAction.triggered.connect(self.changeFontSize) - editMenu.addAction(changeFontSizeAction) - changeLineWidthAction = QAction('Change Line Width',self) - changeLineWidthAction.setStatusTip('Change Line Width.') - changeLineWidthAction.triggered.connect(self.changeLineWidth) - editMenu.addAction(changeLineWidthAction) - changeMarkerSizeAction = QAction('Change Marker Size',self) - changeMarkerSizeAction.setStatusTip('Change Marker Size.') - changeMarkerSizeAction.triggered.connect(self.changeMarkerSize) - editMenu.addAction(changeMarkerSizeAction) - editMenu.addSeparator() - editMenu.addAction(clearSims) - clearDataFileAct2 = QAction(QIcon.fromTheme('close'), 'Clear data file(s)', self) # need new act to avoid DBus warning - clearDataFileAct2.setStatusTip('Clear (dipole) data file(s)') - clearDataFileAct2.triggered.connect(self.clearDataFile) - editMenu.addAction(clearDataFileAct2) - editMenu.addAction(clearCanv) - - # view menu - to view drawing/visualizations - viewMenu = self.menubar.addMenu('&View') - viewDipoleAction = QAction('View Simulation Dipoles',self) - 
viewDipoleAction.setStatusTip('View Simulation Dipoles') - viewDipoleAction.triggered.connect(self.showDipolePlot) - viewMenu.addAction(viewDipoleAction) - viewRasterAction = QAction('View Simulation Spiking Activity',self) - viewRasterAction.setStatusTip('View Simulation Raster Plot') - viewRasterAction.triggered.connect(self.showRasterPlot) - viewMenu.addAction(viewRasterAction) - viewPSDAction = QAction('View PSD',self) - viewPSDAction.setStatusTip('View PSD') - viewPSDAction.triggered.connect(self.showPSDPlot) - viewMenu.addAction(viewPSDAction) - - viewSomaVAction = QAction('View Somatic Voltage',self) - viewSomaVAction.setStatusTip('View Somatic Voltage') - viewSomaVAction.triggered.connect(self.showSomaVPlot) - viewMenu.addAction(viewSomaVAction) - - if testLFP: - viewLFPAction = QAction('View Simulation LFPs',self) - viewLFPAction.setStatusTip('View LFP') - viewLFPAction.triggered.connect(self.showLFPPlot) - viewMenu.addAction(viewLFPAction) - - viewSpecAction = QAction('View Spectrograms',self) - viewSpecAction.setStatusTip('View Spectrograms/Dipoles from Experimental Data') - viewSpecAction.triggered.connect(self.showSpecPlot) - viewMenu.addAction(viewSpecAction) - - viewMenu.addSeparator() - viewSchemAction = QAction('View Model Schematics',self) - viewSchemAction.setStatusTip('View Model Schematics') - viewSchemAction.triggered.connect(self.showschematics) - viewMenu.addAction(viewSchemAction) - viewNetAction = QAction('View Local Network (3D)',self) - viewNetAction.setStatusTip('View Local Network Model (3D)') - viewNetAction.triggered.connect(self.showvisnet) - viewMenu.addAction(viewNetAction) - viewSimLogAction = QAction('View Simulation Log',self) - viewSimLogAction.setStatusTip('View Detailed Simulation Log') - viewSimLogAction.triggered.connect(self.showwaitsimwin) - viewMenu.addAction(viewSimLogAction) - viewMenu.addSeparator() - distributeWindowsAction = QAction('Distribute Windows',self) - distributeWindowsAction.setStatusTip('Distribute Parameter Windows Across Screen.') - distributeWindowsAction.triggered.connect(self.distribsubwin) - viewMenu.addAction(distributeWindowsAction) - hideWindowsAction = QAction('Hide Windows',self) - hideWindowsAction.setStatusTip('Hide Parameter Windows.') - hideWindowsAction.triggered.connect(self.hidesubwin) - hideWindowsAction.setShortcut('Ctrl+H') - viewMenu.addAction(hideWindowsAction) - - simMenu = self.menubar.addMenu('&Simulation') - setParmAct = QAction('Set Parameters',self) - setParmAct.setStatusTip('Set Simulation Parameters') - setParmAct.triggered.connect(self.setparams) - simMenu.addAction(setParmAct) - simMenu.addAction(runSimAct) - if dconf['nsgrun']: simMenu.addAction(runSimNSGAct) - setOptParamAct = QAction('Configure Optimization', self) - setOptParamAct.setShortcut('Ctrl+O') - setOptParamAct.setStatusTip('Set parameters for evoked input optimization') - setOptParamAct.triggered.connect(self.showoptparamwin) - simMenu.addAction(setOptParamAct) - self.toggleEnableOptimization(False) - prevSimAct = QAction('Go to Previous Simulation',self) - prevSimAct.setShortcut('Ctrl+Z') - prevSimAct.setStatusTip('Go Back to Previous Simulation') - prevSimAct.triggered.connect(self.prevSim) - simMenu.addAction(prevSimAct) - nextSimAct = QAction('Go to Next Simulation',self) - nextSimAct.setShortcut('Ctrl+Y') - nextSimAct.setStatusTip('Go Forward to Next Simulation') - nextSimAct.triggered.connect(self.nextSim) - simMenu.addAction(nextSimAct) - clearSims2 = QAction('Clear simulation(s)', self) # need another QAction to avoid DBus 
warning - clearSims2.setStatusTip('Clear simulation(s)') - clearSims2.triggered.connect(self.clearSimulations) - simMenu.addAction(clearSims2) - - aboutMenu = self.menubar.addMenu('&About') - aboutAction = QAction('About HNN',self) - aboutAction.setStatusTip('About HNN') - aboutAction.triggered.connect(self.showAboutDialog) - aboutMenu.addAction(aboutAction) - helpAction = QAction('Help',self) - helpAction.setStatusTip('Help on how to use HNN (parameters).') - helpAction.triggered.connect(self.showHelpDialog) - #aboutMenu.addAction(helpAction) - - def toggleEnableOptimization (self, toEnable): - for menu in self.menubar.findChildren(QMenu): - if menu.title() == '&Simulation': - for item in menu.actions(): - if item.text() == 'Configure Optimization': - item.setEnabled(toEnable) - break - break - - def addButtons (self, gRow): - self.pbtn = pbtn = QPushButton('Set Parameters', self) - pbtn.setToolTip('Set Parameters') - pbtn.resize(pbtn.sizeHint()) - pbtn.clicked.connect(self.setparams) - self.grid.addWidget(self.pbtn, gRow, 0, 1, 1) - - self.pfbtn = pfbtn = QPushButton('Set Parameters From File', self) - pfbtn.setToolTip('Set Parameters From File') - pfbtn.resize(pfbtn.sizeHint()) - pfbtn.clicked.connect(self.selParamFileDialog) - self.grid.addWidget(self.pfbtn, gRow, 1, 1, 1) - - self.btnsim = btn = QPushButton('Run Simulation', self) - btn.setToolTip('Run Simulation') - btn.resize(btn.sizeHint()) - btn.clicked.connect(self.controlsim) - self.grid.addWidget(self.btnsim, gRow, 2, 1, 1) - - self.qbtn = qbtn = QPushButton('Quit', self) - qbtn.clicked.connect(QCoreApplication.instance().quit) - qbtn.resize(qbtn.sizeHint()) - self.grid.addWidget(self.qbtn, gRow, 3, 1, 1) - - def shownetparamwin (self): bringwintotop(self.baseparamwin.netparamwin) - def showoptparamwin (self): bringwintotop(self.baseparamwin.optparamwin) - def showdistparamwin (self): bringwintotop(self.erselectdistal) - def showproxparamwin (self): bringwintotop(self.erselectprox) - def showvisnet (self): Popen([getPyComm(), 'visnet.py', 'cells', paramf]) # nonblocking - def showschematics (self): bringwintotop(self.schemwin) - - def addParamImageButtons (self,gRow): - # add parameter image buttons to the GUI - - self.locbtn = QPushButton('Local Network'+os.linesep+'Connections',self) - self.locbtn.setIcon(QIcon(lookupresource('connfig'))) - self.locbtn.clicked.connect(self.shownetparamwin) - self.grid.addWidget(self.locbtn,gRow,0,1,1) - - self.proxbtn = QPushButton('Proximal Drive'+os.linesep+'Thalamus',self) - self.proxbtn.setIcon(QIcon(lookupresource('proxfig'))) - self.proxbtn.clicked.connect(self.showproxparamwin) - self.grid.addWidget(self.proxbtn,gRow,1,1,1) - - self.distbtn = QPushButton('Distal Drive NonLemniscal'+os.linesep+'Thal./Cortical Feedback',self) - self.distbtn.setIcon(QIcon(lookupresource('distfig'))) - self.distbtn.clicked.connect(self.showdistparamwin) - self.grid.addWidget(self.distbtn,gRow,2,1,1) - - self.netbtn = QPushButton('Model'+os.linesep+'Visualization',self) - self.netbtn.setIcon(QIcon(lookupresource('netfig'))) - self.netbtn.clicked.connect(self.showvisnet) - self.grid.addWidget(self.netbtn,gRow,3,1,1) - - gRow += 1 - - return - - # for schematic dialog box - self.pixConn = QPixmap(lookupresource('connfig')) - self.pixConnlbl = ClickLabel(self) - self.pixConnlbl.setScaledContents(True) - #self.pixConnlbl.resize(self.pixConnlbl.size()) - self.pixConnlbl.setPixmap(self.pixConn) - # self.pixConnlbl.clicked.connect(self.shownetparamwin) - self.grid.addWidget(self.pixConnlbl,gRow,0,1,1) - - 
self.pixProx = QPixmap(lookupresource('proxfig')) - self.pixProxlbl = ClickLabel(self) - self.pixProxlbl.setScaledContents(True) - self.pixProxlbl.setPixmap(self.pixProx) - # self.pixProxlbl.clicked.connect(self.showproxparamwin) - self.grid.addWidget(self.pixProxlbl,gRow,1,1,1) - - self.pixDist = QPixmap(lookupresource('distfig')) - self.pixDistlbl = ClickLabel(self) - self.pixDistlbl.setScaledContents(True) - self.pixDistlbl.setPixmap(self.pixDist) - # self.pixDistlbl.clicked.connect(self.showdistparamwin) - self.grid.addWidget(self.pixDistlbl,gRow,2,1,1) - - self.pixNet = QPixmap(lookupresource('netfig')) - self.pixNetlbl = ClickLabel(self) - self.pixNetlbl.setScaledContents(True) - self.pixNetlbl.setPixmap(self.pixNet) - # self.pixNetlbl.clicked.connect(self.showvisnet) - self.grid.addWidget(self.pixNetlbl,gRow,3,1,1) - - - def initUI (self): - # initialize the user interface (UI) - - self.initMenu() - self.statusBar() - - setscalegeomcenter(self, 1500, 1300) # start GUI in center of screenm, scale based on screen w x h - - # move param windows to be offset from main GUI - new_x = max(0, self.x() - 300) - new_y = max(0, self.y() + 100) - self.baseparamwin.move(new_x, new_y) - self.baseparamwin.evparamwin.move(new_x+50, new_y+50) - self.baseparamwin.optparamwin.move(new_x+100, new_y+100) - self.setWindowTitle(paramf) - QToolTip.setFont(QFont('SansSerif', 10)) - - self.grid = grid = QGridLayout() - #grid.setSpacing(10) - - gRow = 0 - - self.addButtons(gRow) - - gRow += 1 - - self.initSimCanvas(gRow=gRow, reInit=False) - gRow += 2 - - # store any sim just loaded in simdat's list - is this the desired behavior? or should we start empty? - import simdat - if 'dpl' in simdat.ddat: - simdat.updatelsimdat(paramf,simdat.ddat['dpl']) # update lsimdat and its current sim index - - self.cbsim = QComboBox(self) - self.populateSimCB() # populate the combobox - self.cbsim.activated[str].connect(self.onActivateSimCB) - self.grid.addWidget(self.cbsim, gRow, 0, 1, 3)#, 1, 3) - self.btnrmsim = QPushButton('Remove Simulation',self) - self.btnrmsim.resize(self.btnrmsim.sizeHint()) - self.btnrmsim.clicked.connect(self.removeSim) - self.btnrmsim.setToolTip('Remove Currently Selected Simulation') - self.grid.addWidget(self.btnrmsim, gRow, 3)#, 4, 1) - - gRow += 1 - self.addParamImageButtons(gRow) - - # need a separate widget to put grid on - widget = QWidget(self) - widget.setLayout(grid) - self.setCentralWidget(widget) - - self.c = Communicate() - self.c.commsig.connect(self.baseparamwin.updateDispParam) - - self.d = DoneSignal() - self.d.finishSim.connect(self.done) - - try: self.setWindowIcon(QIcon(os.path.join('res','icon.png'))) - except: pass - - self.schemwin.show() # so it's underneath main window - - if 'dataf' in dconf: - if os.path.isfile(dconf['dataf']): - self.loadDataFile(dconf['dataf']) - - self.show() - - def onActivateSimCB (self, s): - # load simulation when activating simulation combobox - global paramf,basedir - import simdat - if debug: print('onActivateSimCB',s,paramf,self.cbsim.currentIndex(),simdat.lsimidx) - if self.cbsim.currentIndex() != simdat.lsimidx: - if debug: print('Loading',s) - paramf = s - param_fname = os.path.splitext(os.path.basename(paramf)) - basedir = os.path.join(dconf['datdir'], param_fname[0]) - simdat.lsimidx = self.cbsim.currentIndex() - self.updateDatCanv(paramf) - - def populateSimCB (self): - # populate the simulation combobox - if debug: print('populateSimCB') - global paramf - self.cbsim.clear() - import simdat - for l in simdat.lsimdat: - 
self.cbsim.addItem(l[0]) - self.cbsim.setCurrentIndex(simdat.lsimidx) - - def initSimCanvas (self,recalcErr=True,optMode=False,gRow=1,reInit=True): - # initialize the simulation canvas, loading any required data - - gCol = 0 - - if reInit == True: - self.grid.itemAtPosition(gRow, gCol).widget().deleteLater() - self.grid.itemAtPosition(gRow + 1, gCol).widget().deleteLater() - - if debug: print('paramf in initSimCanvas:',paramf) - self.m = SIMCanvas(paramf, parent = self, width=10, height=1, dpi=getmplDPI(), optMode=optMode) # also loads data - # this is the Navigation widget - # it takes the Canvas widget and a parent - self.toolbar = NavigationToolbar(self.m, self) - gWidth = 4 - self.grid.addWidget(self.toolbar, gRow, gCol, 1, gWidth) - self.grid.addWidget(self.m, gRow + 1, gCol, 1, gWidth) - if len(self.dextdata.keys()) > 0: - import simdat - simdat.ddat['dextdata'] = self.dextdata - self.m.plot(recalcErr) - self.m.draw() - - def setcursors (self,cursor): - # set cursors of self and children - self.setCursor(cursor) - self.update() - kids = self.children() - kids.append(self.m) # matplotlib simcanvas - for k in kids: - try: - k.setCursor(cursor) - k.update() - except: - pass - - def startoptmodel (self): - # start model optimization - if self.runningsim: - self.stopsim() # stop sim works but leaves subproc as zombie until this main GUI thread exits - else: - self.optMode = True - try: - self.optmodel(self.baseparamwin.runparamwin.getntrial(),self.baseparamwin.runparamwin.getncore()) - except RuntimeError: - print("ERR: Optimization aborted") - - def controlsim (self): - # control the simulation - if self.runningsim: - self.stopsim() # stop sim works but leaves subproc as zombie until this main GUI thread exits - else: - self.optMode = False - self.startsim(self.baseparamwin.runparamwin.getntrial(),self.baseparamwin.runparamwin.getncore()) - - def controlNSGsim (self): - # control simulation on NSG - if self.runningsim: - self.stopsim() # stop sim works but leaves subproc as zombie until this main GUI thread exits - else: - self.startsim(self.baseparamwin.runparamwin.getntrial(),self.baseparamwin.runparamwin.getncore(),True) - - def stopsim (self): - # stop the simulation - if self.runningsim: - self.waitsimwin.hide() - print('Terminating simulation. . .') - self.statusBar().showMessage('Terminating sim. . .') - self.runningsim = False - self.runthread.stop() # killed = True # terminate() - self.btnsim.setText("Run Simulation") - self.qbtn.setEnabled(True) - self.statusBar().showMessage('') - self.setcursors(Qt.ArrowCursor) - - def optmodel (self, ntrial, ncore): - # make sure params saved and ok to run - if not self.baseparamwin.saveparams(): - return - - self.baseparamwin.optparamwin.btnreset.setEnabled(False) - self.baseparamwin.optparamwin.btnrunop.setText('Stop Optimization') - self.baseparamwin.optparamwin.btnrunop.clicked.disconnect() - self.baseparamwin.optparamwin.btnrunop.clicked.connect(self.stopsim) - - # optimize the model - self.setcursors(Qt.WaitCursor) - print('Starting model optimization. . .') - - if debug: print('in optmodel') - self.runningsim = True - - self.statusBar().showMessage("Optimizing model. . 
.") - - self.runthread = RunSimThread(self.c, self.d, ntrial, ncore, self.waitsimwin, opt=True, baseparamwin=self.baseparamwin, mainwin=self, onNSG=False) - - # We have all the events we need connected we can start the thread - self.runthread.start() - # At this point we want to allow user to stop/terminate the thread - # so we enable that button - self.btnsim.setText("Stop Optimization") - self.qbtn.setEnabled(False) - bringwintotop(self.waitsimwin) - - def startsim (self, ntrial, ncore, onNSG=False): - # start the simulation - if not self.baseparamwin.saveparams(): return # make sure params saved and ok to run - - self.setcursors(Qt.WaitCursor) - - print('Starting simulation (%d cores). . .'%ncore) - self.runningsim = True - - if onNSG: - self.statusBar().showMessage("Running simulation on Neuroscience Gateway Portal. . .") - else: - self.statusBar().showMessage("Running simulation. . .") - - self.runthread=RunSimThread(self.c,self.d,ntrial,ncore,self.waitsimwin,opt=False,baseparamwin=None,mainwin=None,onNSG=onNSG) - - # We have all the events we need connected we can start the thread - self.runthread.start() - # At this point we want to allow user to stop/terminate the thread - # so we enable that button - self.btnsim.setText("Stop Simulation") # setEnabled(False) - # We don't want to enable user to start another thread while this one is - # running so we disable the start button. - # self.btn_start.setEnabled(False) - self.qbtn.setEnabled(False) - - bringwintotop(self.waitsimwin) - - def done (self, optMode, failed): - # called when the simulation completes running - if debug: print('done') - self.runningsim = False - self.waitsimwin.hide() - self.statusBar().showMessage("") - self.btnsim.setText("Run Simulation") - self.qbtn.setEnabled(True) - self.initSimCanvas(optMode=optMode) # recreate canvas (plots too) to avoid incorrect axes - # self.m.plot() - global basedir - self.setcursors(Qt.ArrowCursor) - if failed: - msg = "Failed " - else: - msg = "Finished " - - if optMode: - msg += "running optimization " - self.baseparamwin.optparamwin.btnrunop.setText('Prepare for Another Optimization') - self.baseparamwin.optparamwin.btnrunop.clicked.disconnect() - self.baseparamwin.optparamwin.btnrunop.clicked.connect(self.baseparamwin.optparamwin.prepareOptimization) - else: - msg += "running sim " - - QMessageBox.information(self, "Done!", msg + "using " + paramf + '. Saved data/figures in: ' + basedir) - self.setWindowTitle(paramf) - self.populateSimCB() # populate the combobox - -if __name__ == '__main__': - app = QApplication(sys.argv) - ex = HNNGUI() - sys.exit(app.exec_()) diff --git a/init.py b/init.py deleted file mode 100644 index 6e995eb75..000000000 --- a/init.py +++ /dev/null @@ -1,8 +0,0 @@ -# init.py - Starting script to run NetPyNE-based model. 
-# Usage: python init.py # Run simulation, optionally plot a raster -# MPI usage: mpiexec -n 4 nrniv -python -mpi init.py - -from netpyne import sim -cfg, netParams = sim.readCmdLineArgs('cfg.py','netParams.py') # read cfg and netParams from command line arguments -sim.createSimulateAnalyze(simConfig = cfg, netParams = netParams) - diff --git a/installer/README.md b/installer/README.md index 00d931778..cf99bb984 100644 --- a/installer/README.md +++ b/installer/README.md @@ -14,4 +14,4 @@ HNN also works on cloud and HPC environments: If you are running into problems with the instructions given for your machine, we recommend using the VirtualBox VM with HNN pre-installed: -* [HNN VirtualBox install instructions](virtualbox) \ No newline at end of file +* [HNN VirtualBox install instructions](virtualbox) diff --git a/installer/aws/aws-build.sh b/installer/aws/aws-build.sh index 1b3801740..24291af5d 100644 --- a/installer/aws/aws-build.sh +++ b/installer/aws/aws-build.sh @@ -15,99 +15,16 @@ sudo apt-get install -y git python3-dev python3-pip python3-psutil \ git vim iputils-ping net-tools iproute2 nano sudo \ telnet language-pack-en-base sudo pip3 install pip --upgrade -sudo pip install PyOpenGL matplotlib pyqt5 pyqtgraph scipy numpy nlopt - -# build MESA from source (for software 3D rendering) -# this part is optional if the 'Model Visualization' feature is not needed -sudo apt-get update && \ - sudo apt-get upgrade -y && \ - sudo apt-get install --no-install-recommends -y \ - wget \ - bzip2 \ - curl \ - python \ - libllvm6.0 \ - llvm-6.0-dev \ - zlib1g-dev \ - xserver-xorg-dev \ - build-essential \ - libxcb-dri2-0-dev \ - libxcb-xfixes0-dev \ - libxext-dev \ - libx11-xcb-dev \ - pkg-config && \ - update-alternatives --install \ - /usr/bin/llvm-config llvm-config /usr/bin/llvm-config-6.0 200 \ ---slave /usr/bin/llvm-ar llvm-ar /usr/bin/llvm-ar-6.0 \ ---slave /usr/bin/llvm-as llvm-as /usr/bin/llvm-as-6.0 \ ---slave /usr/bin/llvm-bcanalyzer llvm-bcanalyzer /usr/bin/llvm-bcanalyzer-6.0 \ ---slave /usr/bin/llvm-cov llvm-cov /usr/bin/llvm-cov-6.0 \ ---slave /usr/bin/llvm-diff llvm-diff /usr/bin/llvm-diff-6.0 \ ---slave /usr/bin/llvm-dis llvm-dis /usr/bin/llvm-dis-6.0 \ ---slave /usr/bin/llvm-dwarfdump llvm-dwarfdump /usr/bin/llvm-dwarfdump-6.0 \ ---slave /usr/bin/llvm-extract llvm-extract /usr/bin/llvm-extract-6.0 \ ---slave /usr/bin/llvm-link llvm-link /usr/bin/llvm-link-6.0 \ ---slave /usr/bin/llvm-mc llvm-mc /usr/bin/llvm-mc-6.0 \ ---slave /usr/bin/llvm-mcmarkup llvm-mcmarkup /usr/bin/llvm-mcmarkup-6.0 \ ---slave /usr/bin/llvm-nm llvm-nm /usr/bin/llvm-nm-6.0 \ ---slave /usr/bin/llvm-objdump llvm-objdump /usr/bin/llvm-objdump-6.0 \ ---slave /usr/bin/llvm-ranlib llvm-ranlib /usr/bin/llvm-ranlib-6.0 \ ---slave /usr/bin/llvm-readobj llvm-readobj /usr/bin/llvm-readobj-6.0 \ ---slave /usr/bin/llvm-rtdyld llvm-rtdyld /usr/bin/llvm-rtdyld-6.0 \ ---slave /usr/bin/llvm-size llvm-size /usr/bin/llvm-size-6.0 \ ---slave /usr/bin/llvm-stress llvm-stress /usr/bin/llvm-stress-6.0 \ ---slave /usr/bin/llvm-symbolizer llvm-symbolizer /usr/bin/llvm-symbolizer-6.0 \ ---slave /usr/bin/llvm-tblgen llvm-tblgen /usr/bin/llvm-tblgen-6.0 && \ - set -xe; \ - mkdir -p /var/tmp/build; \ - cd /var/tmp/build; \ - wget -q --no-check-certificate "https://mesa.freedesktop.org/archive/mesa-18.0.1.tar.gz"; \ - tar xf mesa-18.0.1.tar.gz; \ - rm mesa-18.0.1.tar.gz; \ - cd mesa-18.0.1; \ - ./configure --enable-glx=gallium-xlib --with-gallium-drivers=swrast,swr --disable-dri --disable-gbm --disable-egl --enable-gallium-osmesa 
--enable-llvm --prefix=/usr/local; \ - make; \ - sudo make install; \ - cd .. ; \ - rm -rf mesa-18.0.1; \ - sudo apt-get -y remove --purge llvm-6.0-dev \ - zlib1g-dev \ - xserver-xorg-dev \ - python3-dev \ - python \ - pkg-config \ - libxext-dev \ - libx11-xcb-dev \ - libxcb-xfixes0-dev \ - libxcb-dri2-0-dev && \ - sudo apt autoremove -y --purge && \ - sudo apt clean - -cd $HOME && \ - git clone https://github.com/neuronsimulator/nrn.git && \ - cd nrn && \ - git checkout 7.7 && \ - ./build.sh && \ - ./configure --with-nrnpython=python3 \ - --with-paranrn --without-iv --disable-rx3d && \ - make && \ - sudo make install +sudo pip install matplotlib pyqt5 nlopt hnn-core echo '# these lines define global session variables for HNN' >> ~/.bashrc -echo 'export CPU=$(uname -m)' >> ~/.bashrc -echo 'export PATH=$PATH:/usr/local/nrn/$CPU/bin' >> ~/.bashrc echo 'export OMPI_MCA_btl_base_warn_component_unused=0' >> ~/.bashrc -echo 'export PYTHONPATH=/usr/local/nrn/lib/python:$PYTHONPATH' >> ~/.bashrc -export CPU=$(uname -m) -export PATH=$PATH:/usr/local/nrn/$CPU/bin export OMPI_MCA_btl_base_warn_component_unused=0 -export PYTHONPATH=/usr/local/nrn/lib/python:$PYTHONPATH cd $HOME && \ git clone https://github.com/jonescompneurolab/hnn && \ - cd hnn && \ - make + cd hnn echo '#!/bin/bash' | sudo tee /usr/local/bin/hnn echo 'cd $HOME/hnn' | sudo tee -a /usr/local/bin/hnn diff --git a/installer/brown_ccv/oscar_setup.sh b/installer/brown_ccv/oscar_setup.sh index 05fe8a6dd..1464563b7 100644 --- a/installer/brown_ccv/oscar_setup.sh +++ b/installer/brown_ccv/oscar_setup.sh @@ -6,30 +6,10 @@ mkdir -p $HOME/HNN # Clone the source code for HNN and prerequisites cd $HOME/HNN -git clone https://github.com/neuronsimulator/nrn -git clone https://github.com/neuronsimulator/iv git clone https://github.com/jonescompneurolab/hnn -# Build HNN prerequisites - -# Build NEURON - -cd $HOME/HNN/nrn && \ - ./build.sh && \ - ./configure --with-nrnpython=python3 --with-paranrn --disable-rx3d \ - --without-iv --with-mpi --prefix=$(pwd)/build && \ - make -j2 && \ - make install -j2 && \ - cd src/nrnpython && \ - python3 setup.py install --home=$HOME/HNN/nrn/build/x86_64/python - -# Cleanup compiled prerequisites - -cd $HOME/HNN/nrn && \ - make clean - # Install python modules. Ignore the errors -pip3 install --user PyOpenGL pyqtgraph psutil nlopt >/dev/null 2>&1 +pip3 install --user psutil nlopt hnn-core >/dev/null 2>&1 # Build HNN cd $HOME/HNN/hnn && \ @@ -38,8 +18,7 @@ cd $HOME/HNN/hnn && \ # Set commands to run at login for future logins cat < /dev/null -export PATH="\$PATH:\$HOME/HNN/nrn/build/x86_64/bin" -export PYTHONPATH="/gpfs/runtime/opt/hnn/1.0/pyqt:\$HOME/HNN/nrn/build/x86_64/python/lib/python" +export PYTHONPATH="/gpfs/runtime/opt/hnn/1.0/pyqt:" export OMPI_MCA_btl_openib_allow_ib=1 # HNN settings if [[ ! "\$(ulimit -l)" =~ "unlimited" ]]; then diff --git a/installer/centos/README.md b/installer/centos/README.md index ce1dbf468..7f452d5d0 100644 --- a/installer/centos/README.md +++ b/installer/centos/README.md @@ -1,6 +1,6 @@ # HNN "Python" install (CentOS) -The script below assumes that it can update OS packages for python and prerequisites for NEURON and HNN. +The script below assumes that it can update OS packages for python and prerequisites for HNN. 
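Whichever install script is used, the switch in this diff from building NEURON from source to `pip install hnn-core` can be sanity-checked with a short import test, similar in spirit to the prerequisite checks in `installer/mac/check-post.sh`. A minimal sketch, assuming it is run with the same `python3` the script installed packages into; the module list simply mirrors the pip lines in this diff:

```python
# Sketch: verify that the pip-installed prerequisites import cleanly.
import importlib

for mod in ('hnn_core', 'neuron', 'matplotlib', 'nlopt', 'psutil'):
    try:
        importlib.import_module(mod)
        print(f'{mod}: ok')
    except ImportError as err:
        print(f'{mod}: missing ({err})')
```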
* CentOS 7: [centos7-installer.sh](centos7-installer.sh) diff --git a/installer/centos/hnn-centos6.sh b/installer/centos/hnn-centos6.sh index 86b335ca7..16ad5867d 100644 --- a/installer/centos/hnn-centos6.sh +++ b/installer/centos/hnn-centos6.sh @@ -16,6 +16,7 @@ sudo yum -y install python34-setuptools sudo easy_install-3.4 pip pip3 install --upgrade matplotlib --user pip3 install --upgrade nlopt scipy --user +pip3 install hnn-core sudo yum -y install ncurses-devel sudo yum -y install openmpi openmpi-devel sudo yum -y install libXext libXext-devel @@ -26,8 +27,6 @@ sudo PATH=$PATH:/usr/lib64/openmpi/bin pip3 install mpi4py startdir=$(pwd) echo $startdir -pip install NEURON - # move outside of nrn directories cd $startdir @@ -40,7 +39,7 @@ cd .. # create the global session variables, make available for all users echo '# these lines define global session variables for HNN' | sudo tee -a /etc/profile.d/hnn.sh echo 'export CPU=$(uname -m)' | sudo tee -a /etc/profile.d/hnn.sh -echo "export PATH=\$PATH::/usr/lib64/openmpi/bin:$startdir/nrn/build/\$CPU/bin" | sudo tee -a /etc/profile.d/hnn.sh +echo "export PATH=\$PATH::/usr/lib64/openmpi/bin" | sudo tee -a /etc/profile.d/hnn.sh # qt, pyqt, and supporting packages - needed for GUI # SIP unforutnately not available as a wheel for Python 3.4, so have to compile @@ -71,4 +70,4 @@ rm -f PyQt5_gpl-5.8.2.tar.gz # needed for matplotlib sudo yum -y install python34-tkinter -pip3 install psutil --user \ No newline at end of file +pip3 install psutil --user diff --git a/installer/centos/hnn-centos7.sh b/installer/centos/hnn-centos7.sh index 2880d5afe..0cbae2bdd 100755 --- a/installer/centos/hnn-centos7.sh +++ b/installer/centos/hnn-centos7.sh @@ -9,7 +9,7 @@ sudo yum -y install automake gcc gcc-c++ flex bison libtool git \ # the system version of pip installs nlopt in the wrong directory sudo pip3 install --upgrade pip -sudo /usr/local/bin/pip3 install NEURON PyOpenGL matplotlib pyqt5 pyqtgraph scipy numpy nlopt +sudo /usr/local/bin/pip3 install hnn-core matplotlib pyqt5 nlopt export PATH=$PATH:/usr/lib64/openmpi/bin diff --git a/installer/centos/uninstall.sh b/installer/centos/uninstall.sh deleted file mode 100644 index c1636639e..000000000 --- a/installer/centos/uninstall.sh +++ /dev/null @@ -1 +0,0 @@ -sudo rm -f /etc/profile.d/hnn.sh diff --git a/installer/docker/Dockerfile b/installer/docker/Dockerfile index f9b98ea8d..f417f381e 100644 --- a/installer/docker/Dockerfile +++ b/installer/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM caldweba/opengl-docker +FROM ubuntu:18.04 # avoid questions from debconf ENV DEBIAN_FRONTEND noninteractive @@ -32,8 +32,8 @@ RUN sudo pip3 install --no-cache-dir --upgrade pip && \ sudo apt-get update && \ sudo apt-get install --no-install-recommends -y \ gcc python3-dev && \ - sudo pip install --no-cache-dir matplotlib PyOpenGL \ - pyqt5 pyqtgraph scipy numpy nlopt psutil && \ + sudo pip install --no-cache-dir matplotlib \ + pyqt5 scipy numpy nlopt psutil && \ sudo apt-get -y remove --purge \ gcc python3-dev && \ sudo apt-get autoremove -y --purge && \ @@ -83,15 +83,16 @@ LABEL org.label-schema.build-date=$BUILD_DATE \ org.label-schema.vcs-ref=$VCS_REF \ org.label-schema.schema-version=$VCS_TAG -# install NEURON -RUN sudo pip install NEURON +# install hnn-core +RUN sudo pip install hnn-core # install HNN RUN sudo apt-get update && \ sudo apt-get install --no-install-recommends -y \ make gcc libc6-dev libtinfo-dev libncurses-dev \ libx11-dev libreadline-dev g++ && \ - git clone ${SOURCE_REPO} \ + git clone --single-branch --branch 
maint/pre-hnn-core \ + ${SOURCE_REPO} \ --depth 1 --single-branch --branch $SOURCE_BRANCH \ $HOME/hnn_source_code && \ cd $HOME/hnn_source_code && \ diff --git a/installer/mac/README.md b/installer/mac/README.md index 70977bb2c..49f56553f 100644 --- a/installer/mac/README.md +++ b/installer/mac/README.md @@ -45,7 +45,7 @@ The Xcode Command Line Tools package includes utilities for compiling code from 1. Create a conda environment with the Python prerequisites for HNN. ```bash - conda create -y -n hnn python=3.7 openmpi pyqtgraph pyopengl matplotlib scipy psutil + conda env create -f environment.yml ``` 2. Activate the HNN conda environment and install nlopt and NEURON diff --git a/installer/mac/check-post.sh b/installer/mac/check-post.sh index 6b2f02693..df71702d3 100755 --- a/installer/mac/check-post.sh +++ b/installer/mac/check-post.sh @@ -26,15 +26,6 @@ check_python_version () { echo "Performing post-install checks for HNN" echo "--------------------------------------" - -echo -n "Checking for XQuartz..." -XQUARTZ_VERSION=$(mdls -name kMDItemVersion /Applications/Utilities/XQuartz.app) -if [[ "$?" -eq "0" ]]; then - echo "ok" -else - echo "failed" -fi - CUR_DIR=$(pwd) echo -n "Checking if HNN is compiled..." @@ -211,7 +202,7 @@ else fi PREREQS_INSTALLED=1 -for prereq in "pyqtgraph" "matplotlib" "scipy" "psutil" "numpy" "nlopt" "neuron"; do +for prereq in "matplotlib" "scipy" "psutil" "numpy" "nlopt" "neuron"; do echo -n "Checking Python can import $prereq module..." $PYTHON -c "import $prereq" > /dev/null 2>&1 if [[ "$?" -eq "0" ]]; then diff --git a/installer/mac/check-pre.sh b/installer/mac/check-pre.sh index 327879b9e..1658122a3 100755 --- a/installer/mac/check-pre.sh +++ b/installer/mac/check-pre.sh @@ -30,18 +30,6 @@ check_python_version () { echo "Performing pre-install checks for HNN" echo "--------------------------------------" -echo -n "Checking if XQuartz is installed..." -XQUARTZ_VERSION=$(mdls -name kMDItemVersion /Applications/Utilities/XQuartz.app) -if [[ "$?" -eq "0" ]]; then - echo "ok" - echo "Xquartz version $(echo ${XQUARTZ_VERSION}|cut -d ' ' -f 3) is already installed" - echo "You can skip the XQuartz installation step" - echo -else - echo "failed" - return=2 -fi - echo -n "Checking for existing python..." PYTHON_VERSION= FOUND= diff --git a/installer/nsg/README.md b/installer/nsg/README.md deleted file mode 100644 index e90052e4a..000000000 --- a/installer/nsg/README.md +++ /dev/null @@ -1,184 +0,0 @@ -# Running HNN on Neuroscience Gateway Portal (NSG) - -## HNN GUI is temporality not available on NSG, please email us at hnneurosolver@gmail.com if you have questions about running HNN on NSG - ---- - -The [Neuroscience Gateway Portal (NSG)](http://www.nsgportal.org) is high-performance compute environment sponsored by NSF/NIH and the BBSRC (UK). Availability of HNN on NSG enables users to run simulations on one of many cluster compute nodes, which reduces simulation runtime. In addition, users will not have to install HNN on their local workstations, allowing them to get started running simulations more quickly. Below are instructions on how to run HNN on NSG. - -## Getting an account on NSG - -1. Create your NSG account [here](https://www.nsgportal.org/gest/reg.php ) if you don't already have one - -2. Log in to NSG through a web browser [here](https://nsgdev.sdsc.edu:8443/portal2/login!input.action) - -## Preliminary software (VNC, ssh) - -You should make sure you have a VNC viewer (e.g. Tightvnc) and `ssh` available from a terminal window. 
These tools are included by default with Mac/Linux but you may need to install additional software for Windows. - -[RealVNC](https://www.realvnc.com/en/connect/download/viewer/) is available on many platforms, but is not open source. - -For Ubuntu Linux you can install a vnc viewer using the following command from the terminal: - -```bash -sudo apt install xvnc4viewer -``` - -## Running HNN on NSG - -1. You will need to upload some data since NSG jobs require input data. The data can be a zip file with arbitrary content. To do so, click on **New Folder** on the front page, and enter a **Label**. Next, click on **Data**, underneath the new folder on the left hand side, then **Upload/Enter Data** to upload an arbitrary file with an arbitrary **Label**. - -2. Click on Toolkit on the top menu: - - - - This will take you to a list of computational neuroscience tools available on NSG: - - - -3. Click on **VNC on Jetstream** at the bottom of the list: - - - - This will create a new "task" that can allow you to access a desktop on the Jetstream supercomputer through VNC: - - - -4. If you want to customize how much time you have to run HNN once your VNC job is created, you can click on **Set Parameters** and then enter the **Maximum Hours to Run** (0.5 is the default). - -5. Enter a **Description** for your task in the text field next to **Description**. For example, you could write "HNN simulation". - -6. Click on **Select Data** then pick the zip file you uploaded in step 3. - -7. Click **Save Task**. This will take you to a list of available tasks to run: - - - -8. Click **Run Task** for the task you created to start the job. Click **OK** on the resulting dialogue box. - -9. Click on **View Status** to see how the job is progressing and to get the information needed to login to NSG using VNC. - - - -10. After a few minutes (or longer, depending on others' use of NSG) you will receive an email from [](nsguser@jetstream-cloud.org) stating that your VNC session is ready, and including instructions for how to login to your VNC session, which are as follows, but are explained in more detail below: - - 1. On your local machine, start an ssh tunnel from your local port 5902 to the vnc server node port 5901. Use the password in **Intermediate Results**: `ssh.password` - - ```bash - ssh -l nsguser -L 5902:localhost:5901 149.165.157.55 - ``` - - 2. Use a VNC client from [above](./README.md#preliminary-software-vnc-ssh) - - 3. Then start then vnc viewer to `localhost:2` - Use the password in **Intermediate Results**: `vnc.password` - - To examine or retrieve files on the VNC node, ssh or scp may be used: - - ```bash - ssh -l nsguser 149.165.157.55 - ``` - - For the hnn data directory: - - ```bash - scp -pr 'nsguser@149.165.157.55:~/hnn/data/*' mydownloaddir - ``` - - or for the entire job directory: - - ```bash - scp -pr 'nsguser@149.165.157.55:~' mydownloaddir - ``` - -11. Click on the **List** link next to **Intermediate Results**. This will bring up a window listing files after a few minutes that contains a link to `vnc.password` and `ssh.password`: - - - - You should download both of these files to your computer. - -12. Start a terminal. - -13. Enter the following command in the terminal: - - ```bash - ssh -l nsguser -L 5902:localhost:5901 149.165.157.55 - ``` - - When prompted for a password, use the text inside of the `ssh.password` file you downloaded in step 11. 
In order to see the text inside of the `ssh.password` file, enter the following command from the directory where you have stored the file (in a separate terminal): - - ```bash - more ssh.password - ``` - -14. Enter the following command in a separate terminal: - - ```bash - vncviewer localhost:2 - ``` - - If this does not work, you may have to switch to the VNC Viewer folder before running the vncviewer command. For example, on the Windows operating system, you would first type: - - ``` - cd C:\Program Files\RealVNC\VNC Viewer - ``` - - When prompted for a password use the text inside the vnc.password file you downloaded in step 11. (Note that the vncviewer command depends on which VNC software you have installed on your computer). This will open a VNC viewer display of your NSG desktop: - - - -15. In the VNC viewer window Click on **Applications** -> **System tools** -> **Terminal** - - This will start a command line terminal as shown here: - - - -16. In the terminal, enter the following command: - - ```bash - hnn - ``` - - This will start the HNN graphical user interface as shown here: - - - -17. Use HNN to Run simulations by clicking on **Run Simulation**. If asked whether to overwrite a file, click **OK**. After that, HNN will run its default ERP simulation, producing the output shown here: - - - - You can now continue to run other simulations, as desired. For more information on how to use HNN to run simulations and compare model to experimental data, see our [Tutorials](https://hnn.brown.edu/index.php/tutorials/) - -18. When you are done running simulations, log out of the VNC viewer by clicking first on the **nsguser** menu in the top right corner, and then on **Quit**, as shown here: - - - - When prompted click on **Log out**. NSG will then create an output file that you can download. - -19. You can download the data using either `scp` (19.1) or NSG's web browser interface (19.2). - - 1. To use `scp` for the hnn data directory enter the following in the terminal: - - ```bash - scp -pr 'nsguser@149.165.157.55:~/hnn/data/*' mydownloaddir - ``` - - or for the entire job directory: - - ```bash - scp -pr 'nsguser@149.165.157.55:~' mydownloaddir - ``` - - When prompted for a password, use the text in the `ssh.password` you downloaded in step 11. Note that `mydownloaddir` is a local directory on your computer that you want to download the data to. - - 2. Refresh the task list in the web browser. Then click on output. After a few minutes you will see a file `output.tar.gz` in the list, which contains the output data from the simulations you ran. After downloading it, you can extract its contents using the following command from the terminal: - - ```bash - tar -xvf output.tar.gz - ``` - - This will extract the data into a new directory. You can browse the output in the data subdirectory, and it will contain one subdirectory for each simulation run. See the HNN website for more information on the content of these files. 
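Once the output is on your machine, the simulation directories can also be inspected without HNN. The sketch below is only an example: it assumes a whitespace-delimited dipole text file such as `dpl.txt` (the file the GUI's dipole viewer reads), that the first column is time in ms and the remaining columns are dipole moments, and that the path is a placeholder for wherever you copied the data.

```python
# Sketch: plot a downloaded dipole trace.
# Assumptions: whitespace-delimited text, first column = time (ms),
# remaining columns = dipole moments; 'default' stands in for the
# name of the simulation directory you downloaded.
import numpy as np
import matplotlib.pyplot as plt

data = np.loadtxt('mydownloaddir/default/dpl.txt')
plt.plot(data[:, 0], data[:, 1])
plt.xlabel('time (ms)')
plt.ylabel('dipole moment')
plt.show()
```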
- -## Troubleshooting - -For HNN software issues, please visit the [HNN bulletin board](https://www.neuron.yale.edu/phpBB/viewforum.php?f=46) diff --git a/installer/nsg/install_pngs/001.png b/installer/nsg/install_pngs/001.png deleted file mode 100644 index c7678d378..000000000 Binary files a/installer/nsg/install_pngs/001.png and /dev/null differ diff --git a/installer/nsg/install_pngs/002.png b/installer/nsg/install_pngs/002.png deleted file mode 100644 index 7e1c32fef..000000000 Binary files a/installer/nsg/install_pngs/002.png and /dev/null differ diff --git a/installer/nsg/install_pngs/003.png b/installer/nsg/install_pngs/003.png deleted file mode 100644 index bde19fbb2..000000000 Binary files a/installer/nsg/install_pngs/003.png and /dev/null differ diff --git a/installer/nsg/install_pngs/004.png b/installer/nsg/install_pngs/004.png deleted file mode 100644 index 99e8c5a8a..000000000 Binary files a/installer/nsg/install_pngs/004.png and /dev/null differ diff --git a/installer/nsg/install_pngs/005.png b/installer/nsg/install_pngs/005.png deleted file mode 100644 index b02108b3c..000000000 Binary files a/installer/nsg/install_pngs/005.png and /dev/null differ diff --git a/installer/nsg/install_pngs/006.png b/installer/nsg/install_pngs/006.png deleted file mode 100644 index 2ea807940..000000000 Binary files a/installer/nsg/install_pngs/006.png and /dev/null differ diff --git a/installer/nsg/install_pngs/007.png b/installer/nsg/install_pngs/007.png deleted file mode 100644 index b2e3c735e..000000000 Binary files a/installer/nsg/install_pngs/007.png and /dev/null differ diff --git a/installer/nsg/install_pngs/008.png b/installer/nsg/install_pngs/008.png deleted file mode 100644 index 1f9c1910e..000000000 Binary files a/installer/nsg/install_pngs/008.png and /dev/null differ diff --git a/installer/nsg/install_pngs/009.png b/installer/nsg/install_pngs/009.png deleted file mode 100644 index fb87b1813..000000000 Binary files a/installer/nsg/install_pngs/009.png and /dev/null differ diff --git a/installer/nsg/install_pngs/010.png b/installer/nsg/install_pngs/010.png deleted file mode 100644 index b47567226..000000000 Binary files a/installer/nsg/install_pngs/010.png and /dev/null differ diff --git a/installer/nsg/install_pngs/011.png b/installer/nsg/install_pngs/011.png deleted file mode 100644 index 8e5912dee..000000000 Binary files a/installer/nsg/install_pngs/011.png and /dev/null differ diff --git a/installer/nsg/install_pngs/012.png b/installer/nsg/install_pngs/012.png deleted file mode 100644 index c6ab792c3..000000000 Binary files a/installer/nsg/install_pngs/012.png and /dev/null differ diff --git a/installer/ubuntu/uninstaller.sh b/installer/ubuntu/uninstaller.sh deleted file mode 100755 index 8d9d3a812..000000000 --- a/installer/ubuntu/uninstaller.sh +++ /dev/null @@ -1,5 +0,0 @@ -# clean up the bashrc (sed looks like it uses regex, but I don't -# think that it does. You do have to escape the '/' character, though) -sed -i '/# these lines define global session variables for HNN/d' ~/.bashrc -sed -i '/export CPU=$(uname -m)/d' ~/.bashrc -sed -i "/export PATH=\$PATH:$startdir\/nrn\/build\/\$CPU\/bin/d" ~/.bashrc diff --git a/lfp.py b/lfp.py deleted file mode 100644 index c68d31c32..000000000 --- a/lfp.py +++ /dev/null @@ -1,274 +0,0 @@ -""" -LFPsim - Simulation scripts to compute Local Field Potentials (LFP) from cable compartmental -models of neurons and networks implemented in NEURON simulation environment. 
- -LFPsim works reliably on biophysically detailed multi-compartmental neurons with ion channels in -some or all compartments. - -Last updated 12-March-2016 -Developed by : Harilal Parasuram & Shyam Diwakar -Computational Neuroscience & Neurophysiology Lab, School of Biotechnology, Amrita University, India. -Email: harilalp@am.amrita.edu; shyam@amrita.edu -www.amrita.edu/compneuro - -translated to Python and modified to use use_fast_imem by Sam Neymotin -based on mhines code -""" - -from neuron import h -from math import sqrt, log, pi, exp -from seg3d import * -from pylab import * - -# get all Sections -def getallSections (ty='Pyr'): - ls = h.allsec() - ls = [s for s in ls if s.name().count(ty)>0 or len(ty)==0] - return ls - -def getcoordinf (s): - lcoord = []; ldist = []; lend = []; lsegloc = [] - if s.nseg == 1: - i = 1 - x0, y0, z0 = s.x3d(i-1,sec=s), s.y3d(i-1, sec=s), s.z3d(i-1, sec=s) - x1, y1, z1 = s.x3d(i,sec=s), s.y3d(i, sec=s), s.z3d(i, sec=s) - lcoord.append([(x0+x1)/2.0,(y0+y1)/2.0,(z0+z1)/2.0]) - dist = sqrt((x1-x0)**2 + (y1-y0)**2 + (z1-z0)**2) - ldist.append( dist ) - lend.append([x1, y1, z1]) - lsegloc.append(0.5) - else: - for i in range(1,s.n3d(),1): - x0, y0, z0 = s.x3d(i-1,sec=s), s.y3d(i-1, sec=s), s.z3d(i-1, sec=s) - x1, y1, z1 = s.x3d(i,sec=s), s.y3d(i, sec=s), s.z3d(i, sec=s) - lcoord.append( [(x0+x1)/2.,(y0+y1)/2.(z0+z1)/2.] ) - dist = sqrt((x1-x0)**2 + (y1-y0)**2 + (z1-z0)**2) - ldist.append( dist ) - lend.append([x1, y1, z1]) - lsegloc.append() - return lcoord, ldist, lend, lsegloc - -# this function not used ... yet -def transfer_resistance2 (exyz): - vres = h.Vector() - lsec = getallSections() - sigma = 3.0 # extracellular conductivity in mS/cm - # see http://jn.physiology.org/content/104/6/3388.long shows table of values with conductivity - for s in lsec: - lcoord, ldist, lend = getcoordinf(s) - for i in range(len(lcoord)): - x,y,z = lcoord[i] - dis = sqrt((exyz[0] - x)**2 + (exyz[1] - y)**2 + (exyz[2] - z)**2 ) - # setting radius limit - if(dis<(s.diam/2.0)): dis = (s.diam/2.0) + 0.1 - - dist_comp = ldist[i] # length of the compartment - sum_dist_comp = sqrt(dist_comp[0]**2 + dist_comp[0]**2 + dist_comp[0]**2) - - # print "sum_dist_comp=",sum_dist_comp, secname() - - # setting radius limit - if sum_dist_comp < s.diam/2.0: sum_dist_comp = s.diam/2.0 + 0.1 - - long_dist_x = exyz[0] - lend[i][0] - long_dist_y = exyz[1] - lend[i][1] - long_dist_z = exyz[2] - lend[i][2] - - sum_HH = long_dist_x*dist_comp_x + long_dist_y*dist_comp_y + long_dist_z*dist_comp_z - - final_sum_HH = sum_HH / sum_dist_comp - - sum_temp1 = long_dist_x**2 + long_dist_y**2 + long_dist_z**2 - r_sq = sum_temp1 - (final_sum_HH * final_sum_HH) - - Length_vector = final_sum_HH + sum_dist_comp - - if final_sum_HH < 0 and Length_vector <= 0: - phi=log((sqrt(final_sum_HH**2 + r_sq) - final_sum_HH)/(sqrt(Length_vector**2+r_sq)-Length_vector)) - elif final_sum_HH > 0 and Length_vector > 0: - phi=log((sqrt(Length_vector**2+r_sq) + Length_vector)/(sqrt(final_sum_HH**2+r_sq) + final_sum_HH)) - else: - phi=log(((sqrt(Length_vector**2+r_sq)+Length_vector) * (sqrt(final_sum_HH**2+r_sq)-final_sum_HH))/r_sq) - - line_part1 = 1.0 / (4.0*pi*sum_dist_comp*sigma) * phi - vres.append(line_part1) - - return vres - -# represents a simple LFP electrode -class LFPElectrode (): - - def __init__ (self, coord, sigma = 3.0, pc = None, usePoint = True): - - self.sigma = sigma # extracellular conductivity in mS/cm (uniform for simplicity) - # see http://jn.physiology.org/content/104/6/3388.long shows table of values with 
conductivity - self.coord = coord - self.vres = None - self.vx = None - - self.imem_ptrvec = self.imem_vec = self.rx = self.vx = self.vres = None - self.bscallback = self.fih = None - - if pc is None: self.pc = h.ParallelContext() - else: self.pc = pc - - def setup (self): - h.cvode.use_fast_imem(1) # enables fast calculation of transmembrane current (nA) at each segment - self.bscallback = h.beforestep_callback(h.cas()(.5)) - self.bscallback.set_callback(self.callback) - fih = h.FInitializeHandler(1, self.LFPinit) - - def transfer_resistance (self, exyz,usePoint=True): - vres = h.Vector() - lsec = getallSections() - for s in lsec: - - x = (h.x3d(0,sec=s) + h.x3d(1,sec=s)) / 2.0 - y = (h.y3d(0,sec=s) + h.y3d(1,sec=s)) / 2.0 - z = (h.z3d(0,sec=s) + h.z3d(1,sec=s)) / 2.0 - - sigma = self.sigma - - dis = sqrt((exyz[0] - x)**2 + (exyz[1] - y)**2 + (exyz[2] - z)**2 ) - - # setting radius limit - if(dis<(s.diam/2.0)): dis = (s.diam/2.0) + 0.1 - - if usePoint: - point_part1 = 10000.0 * (1.0 / (4.0 * pi * dis * sigma)) # x10000 for units of microV : nA/(microm*(mS/cm)) -> microV - vres.append(point_part1) - else: - # calculate length of the compartment - dist_comp = sqrt((h.x3d(1,sec=s) - h.x3d(0,sec=s))**2 + (h.y3d(1,sec=s) - h.y3d(0,sec=s))**2 + (h.z3d(1,sec=s) - h.z3d(0,sec=s))**2) - - dist_comp_x = (h.x3d(1,sec=s) - h.x3d(0,sec=s)) - dist_comp_y = (h.y3d(1,sec=s) - h.y3d(0,sec=s)) - dist_comp_z = (h.z3d(1,sec=s) - h.z3d(0,sec=s)) - - sum_dist_comp = sqrt(dist_comp_x**2 + dist_comp_y**2 + dist_comp_z**2) - - # print "sum_dist_comp=",sum_dist_comp, secname() - - # setting radius limit - if sum_dist_comp < s.diam/2.0: sum_dist_comp = s.diam/2.0 + 0.1 - - long_dist_x = exyz[0] - h.x3d(1,sec=s) - long_dist_y = exyz[1] - h.y3d(1,sec=s) - long_dist_z = exyz[2] - h.z3d(1,sec=s) - - sum_HH = long_dist_x*dist_comp_x + long_dist_y*dist_comp_y + long_dist_z*dist_comp_z - - final_sum_HH = sum_HH / sum_dist_comp - - sum_temp1 = long_dist_x**2 + long_dist_y**2 + long_dist_z**2 - r_sq = sum_temp1 -(final_sum_HH * final_sum_HH) - - Length_vector = final_sum_HH + sum_dist_comp - - if final_sum_HH < 0 and Length_vector <= 0: - phi=log((sqrt(final_sum_HH**2 + r_sq) - final_sum_HH)/(sqrt(Length_vector**2+r_sq)-Length_vector)) - elif final_sum_HH > 0 and Length_vector > 0: - phi=log((sqrt(Length_vector**2+r_sq) + Length_vector)/(sqrt(final_sum_HH**2+r_sq) + final_sum_HH)) - else: - phi=log(((sqrt(Length_vector**2+r_sq)+Length_vector) * (sqrt(final_sum_HH**2+r_sq)-final_sum_HH))/r_sq) - - line_part1 = 10000.0 * (1.0 / (4.0*pi*sum_dist_comp*sigma) * phi) # x10000 for units of microV - vres.append(line_part1) - - return vres - - def LFPinit (self): - lsec = getallSections() - n = len(lsec) - # print('In LFPinit - pc.id = ',self.pc.id(),'len(lsec)=',n) - self.imem_ptrvec = h.PtrVector(n) # - self.imem_vec = h.Vector(n) - for i,s in enumerate(lsec): - seg = s(0.5) - #for seg in s # so do not need to use segments...? 
more accurate to use segments and their neighbors - self.imem_ptrvec.pset(i, seg._ref_i_membrane_) - - self.vres = self.transfer_resistance(self.coord) - self.lfp_t = h.Vector() - self.lfp_v = h.Vector() - - #for i, cellinfo in enumerate(gidinfo.values()): - # seg = cellinfo.cell.soma(0.5) - # imem_ptrvec.pset(i, seg._ref_i_membrane_) - #rx = h.Matrix(nelectrode, n) - #vx = h.Vector(nelectrode) - #for i in range(nelectrode): - # for j, cellinfo in enumerate(gidinfo.values()): - # rx.setval(i, j, transfer_resistance(cellinfo.cell, e_coord[i])) - # #rx.setval(i,1,1.0) - - def callback (self): - # print('In lfp callback - pc.id = ',self.pc.id(),' t=',self.pc.t(0)) - self.imem_ptrvec.gather(self.imem_vec) - #s = pc.allreduce(imem_vec.sum(), 1) #verify sum i_membrane_ == stimulus - #if rank == 0: print pc.t(0), s - - #sum up the weighted i_membrane_. Result in vx - # rx.mulv(imem_vec, vx) - - val = 0.0 - for j in range(len(self.vres)): val += self.imem_vec.x[j] * self.vres.x[j] - - # append to Vector - self.lfp_t.append(self.pc.t(0)) - self.lfp_v.append(val) - - def lfp_final (self): - self.pc.allreduce(self.lfp_v, 1) - - def lfpout (self,fn = 'LFP.txt', append=False, tvec = None): - fmode = 'w' - if append: fmode = 'a' - if int(self.pc.id()) == 0: - print('len(lfp_t) is %d' % len(self.lfp_t)) - f = open(fn, fmode) - if tvec is None: - for i in range(1, len(self.lfp_t), 1): - line = '%g' % self.lfp_v.x[i] - f.write(line + '\n') - else: - for i in range(1, len(self.lfp_t), 1): - line = '%g'%self.lfp_t.x[i] - line += ' %g' % self.lfp_v.x[i] - f.write(line + '\n') - f.close() - -def test (): - from L5_pyramidal import L5Pyr - cell = L5Pyr() - - h.load_file("stdgui.hoc") - h.cvode_active(1) - - ns = h.NetStim() - ns.number = 10 - ns.start = 100 - ns.interval=50.0 - - nc = h.NetCon(ns,cell.apicaltuft_ampa) - nc.weight[0] = 0.001 - - h.tstop=2000.0 - - elec = LFPElectrode([0, 100.0, 100.0], pc = h.ParallelContext()) - elec.setup() - elec.LFPinit() - h.run() - elec.lfp_final() - ion() - plot(elec.lfp_t, elec.lfp_v) - -if __name__ == '__main__': - test() - """ - for i in range(len(lfp_t)): - print(lfp_t.x[i],) - for j in range(nelectrode): - print(lfp_v[j].x[i],) - print("") - """ diff --git a/loadmodel_nrnui.py b/loadmodel_nrnui.py deleted file mode 100644 index 3dded0af6..000000000 --- a/loadmodel_nrnui.py +++ /dev/null @@ -1,6 +0,0 @@ -# Used by Jupyter/NEURON-UI to import model -import os -os.chdir('NEURON-UI/neuron_ui/models/hnn') -import hnn_nrnui -net=hnn_nrnui.HNN() - diff --git a/mod/ar.mod b/mod/ar.mod deleted file mode 100644 index 0c300dbc2..000000000 --- a/mod/ar.mod +++ /dev/null @@ -1,58 +0,0 @@ -TITLE Anomalous rectifier current for RD Traub, J Neurophysiol 89:909-921, 2003 - -COMMENT - Implemented by Maciej Lazarewicz 2003 (mlazarew@seas.upenn.edu) -ENDCOMMENT - -INDEPENDENT { t FROM 0 TO 1 WITH 1 (ms) } - -UNITS { - (mV) = (millivolt) - (mA) = (milliamp) -} - -NEURON { - SUFFIX ar - NONSPECIFIC_CURRENT i - RANGE gbar, i -} - -PARAMETER { - gbar = 0.0 (mho/cm2) - v (mV) - erev = -35 (mV) -} - -ASSIGNED { - i (mA/cm2) - minf (1) - mtau (ms) -} - -STATE { - m -} - -BREAKPOINT { - SOLVE states METHOD cnexp - i = gbar * m * ( v - erev ) -} - -INITIAL { - settables(v) - m = minf - m = 0.25 -} - -DERIVATIVE states { - settables(v) - m' = ( minf - m ) / mtau -} - -UNITSOFF -PROCEDURE settables(v) { - TABLE minf, mtau FROM -120 TO 40 WITH 641 - minf = 1 / ( 1 + exp( ( v + 75 ) / 5.5 ) ) - mtau = 1 / ( exp( -14.6 - 0.086 * v ) + exp( -1.87 + 0.07 * v ) ) -} -UNITSON diff --git 
a/mod/beforestep_py.mod b/mod/beforestep_py.mod deleted file mode 100644 index 3e2552a7f..000000000 --- a/mod/beforestep_py.mod +++ /dev/null @@ -1,48 +0,0 @@ -: Python callback from BEFORE STEP - -NEURON { - POINT_PROCESS beforestep_callback - POINTER ptr -} - -ASSIGNED { - ptr -} - -INITIAL { -} - -VERBATIM -extern int (*nrnpy_hoccommand_exec)(Object*); -extern Object** hoc_objgetarg(int); -extern int ifarg(int); -extern void hoc_obj_ref(Object*); -extern void hoc_obj_unref(Object*); -ENDVERBATIM - -BEFORE STEP { - :printf("beforestep_callback t=%g\n", t) -VERBATIM -{ - Object* cb = (Object*)(_p_ptr); - if (cb) { - (*nrnpy_hoccommand_exec)(cb); - } -} -ENDVERBATIM -} - -PROCEDURE set_callback() { -VERBATIM - Object** pcb = (Object**)(&(_p_ptr)); - if (*pcb) { - hoc_obj_unref(*pcb); - *pcb = (Object*)0; - } - if (ifarg(1)) { - *pcb = *(hoc_objgetarg(1)); - hoc_obj_ref(*pcb); - } -ENDVERBATIM -} - diff --git a/mod/ca.mod b/mod/ca.mod deleted file mode 100644 index e6571ba47..000000000 --- a/mod/ca.mod +++ /dev/null @@ -1,135 +0,0 @@ -COMMENT - 26 Ago 2002 Modification of original channel to allow variable time step - and to correct an initialization error. - - Done by Michael Hines(michael.hines@yale.edu) and Ruggero Scorcioni (rscorcio@gmu.edu) - at EU Advance Course in Computational Neuroscience. Obidos, Portugal - - ca.mod - Uses fixed eca instead of GHK eqn - - HVA Ca current - Based on Reuveni, Friedman, Amitai and Gutnick (1993) J. Neurosci. 13: 4609-4621. - - Author: Zach Mainen, Salk Institute, 1994, zach@salk.edu -ENDCOMMENT - -INDEPENDENT {t FROM 0 TO 1 WITH 1 (ms)} - -NEURON { - SUFFIX ca - USEION ca READ eca WRITE ica - RANGE m, h, gca, gbar - RANGE minf, hinf, mtau, htau - GLOBAL q10, temp, tadj, vmin, vmax, vshift, tshift -} - -PARAMETER { - gbar = 0.1 (pS/um2) : 0.12 mho/cm2 - vshift = 0 (mV) : voltage shift (affects all) - - cao = 2.5 (mM) : external ca concentration - cai (mM) - - temp = 23 (degC) : original temp - q10 = 2.3 : temperature sensitivity - tshift = 30.7 - - v (mV) - dt (ms) - celsius (degC) - vmin = -120 (mV) - vmax = 100 (mV) -} - -UNITS { - (mA) = (milliamp) - (mV) = (millivolt) - (pS) = (picosiemens) - (um) = (micron) - FARADAY = (faraday) (coulomb) - R = (k-mole) (joule/degC) - PI = (pi) (1) -} - -ASSIGNED { - ica (mA/cm2) - gca (pS/um2) - eca (mV) - minf hinf - mtau (ms) htau (ms) - tadj -} - -STATE { m h } - -INITIAL { - trates(v+vshift) - m = minf - h = hinf -} - -BREAKPOINT { - SOLVE states METHOD cnexp - gca = tadj * gbar * m * m * h - ica = (1e-4) * gca * (v - eca) -} - -LOCAL mexp, hexp - -: PROCEDURE states() { -: trates(v+vshift) -: m = m + mexp*(minf-m) -: h = h + hexp*(hinf-h) -: VERBATIM -: return 0; -: ENDVERBATIM -: } - -DERIVATIVE states { - trates(v + vshift) - m' = (minf - m) / mtau - h' = (hinf - h) / htau -} - -PROCEDURE trates(v) { - TABLE minf, hinf, mtau, htau - DEPEND celsius, temp - - FROM vmin TO vmax WITH 199 - - : not consistently executed from here if usetable == 1 - rates(v) - - : tinc = -dt * tadj - - : mexp = 1 - exp(tinc/mtau) - : hexp = 1 - exp(tinc/htau) -} - -PROCEDURE rates(vm) { - LOCAL a, b - - tadj = q10^((celsius - temp - tshift)/10) - - a = 0.055 * (-27 - vm) / (exp((-27 - vm) / 3.8) - 1) - b = 0.94 * exp((-75 - vm) / 17) - - mtau = 1 / tadj / (a+b) - minf = a / (a + b) - - : "h" inactivation - a = 0.000457 * exp((-13 - vm) / 50) - b = 0.0065 / (exp((-vm - 15) / 28) + 1) - - htau = 1 / tadj / (a + b) - hinf = a / (a + b) -} - -FUNCTION efun(z) { - if (fabs(z) < 1e-4) { - efun = 1 - z/2 - } else { - efun = z / 
(exp(z) - 1) - } -} diff --git a/mod/cad.mod b/mod/cad.mod deleted file mode 100644 index 4ec879cba..000000000 --- a/mod/cad.mod +++ /dev/null @@ -1,97 +0,0 @@ -COMMENT - 26 Ago 2002 Modification of original channel to allow variable time step - and to correct an initialization error. - - Done by Michael Hines (michael.hines@yale.edu) and Ruggero Scorcioni (rscorcio@gmu.edu) - at EU Advance Course in Computational Neuroscience. Obidos, Portugal - - Internal calcium concentration due to calcium currents and pump. - Differential equations. - - Simple model of ATPase pump with 3 kinetic constants (Destexhe 92) - Cai + P <-> CaP -> Cao + P (k1,k2,k3) - - A Michaelis-Menten approximation is assumed, which reduces the complexity - of the system to 2 parameters: - - kt = * k3 -> TIME CONSTANT OF THE PUMP - kd = k2/k1 (dissociation constant) -> EQUILIBRIUM CALCIUM VALUE - - The values of these parameters are chosen assuming a high affinity of - the pump to calcium and a low transport capacity (cfr. Blaustein, - TINS, 11: 438, 1988, and references therein). - - Units checked using "modlunit" -> factor 10000 needed in ca entry - - VERSION OF PUMP + DECAY (decay can be viewed as simplified buffering) - - All variables are range variables - - This mechanism was published in: Destexhe, A. Babloyantz, A. and - Sejnowski, TJ. Ionic mechanisms for intrinsic slow oscillations in - thalamic relay neurons. Biophys. J. 65: 1538-1552, 1993) - - Written by Alain Destexhe, Salk Institute, Nov 12, 1992 - -ENDCOMMENT - -TITLE Decay of internal calcium concentration - -INDEPENDENT {t FROM 0 TO 1 WITH 1 (ms)} - -NEURON { - SUFFIX cad - USEION ca - READ ica, cai - WRITE cai - - : SRJones put taur up here - RANGE ca, taur - GLOBAL depth, cainf - : GLOBAL depth, cainf, taur -} - -UNITS { - : moles do not appear in units - (molar) = (1/liter) - (mM) = (millimolar) - (um) = (micron) - (mA) = (milliamp) - (msM) = (ms mM) - FARADAY = (faraday) (coulomb) -} - -PARAMETER { - depth = .1 (um) : depth of shell - taur = 200 (ms) : rate of calcium removal 200 default - cainf = 100e-6 (mM) - cai (mM) -} - -STATE { - ca (mM) <1e-5> -} - -INITIAL { - ca = cainf - cai = ca -} - -ASSIGNED { - ica (mA/cm2) - drive_channel (mM/ms) -} - -BREAKPOINT { - SOLVE state METHOD cnexp -} - -DERIVATIVE state { - drive_channel = -(10000) * ica / (2 * FARADAY * depth) - - : cannot pump inward - if (drive_channel <= 0.) { drive_channel = 0. 
} - - ca' = drive_channel + (cainf-ca) / taur - cai = ca -} diff --git a/mod/cat.mod b/mod/cat.mod deleted file mode 100644 index 51ccc94e9..000000000 --- a/mod/cat.mod +++ /dev/null @@ -1,67 +0,0 @@ -TITLE Calcium low threshold T type current for RD Traub, J Neurophysiol 89:909-921, 2003 - -COMMENT - Implemented by Maciej Lazarewicz 2003 (mlazarew@seas.upenn.edu) -ENDCOMMENT - -INDEPENDENT { t FROM 0 TO 1 WITH 1 (ms) } - -UNITS { - (mV) = (millivolt) - (mA) = (milliamp) -} - -NEURON { - SUFFIX cat - NONSPECIFIC_CURRENT i : not causing [Ca2+] influx - RANGE gbar, i -} - -PARAMETER { - gbar = 0.0 (mho/cm2) - v eca (mV) -} - -ASSIGNED { - i (mA/cm2) - minf hinf (1) - mtau htau (ms) -} - -STATE { - m h -} - -BREAKPOINT { - SOLVE states METHOD cnexp - i = gbar * m * m * h * ( v - 125 ) -} - -INITIAL { - settables(v) - m = minf - h = hinf - m = 0 -} - -DERIVATIVE states { - settables(v) - m' = ( minf - m ) / mtau - h' = ( hinf - h ) / htau -} - -UNITSOFF -PROCEDURE settables(v) { - TABLE minf, mtau, hinf, htau FROM -120 TO 40 WITH 641 - - minf = 1 / (1 + exp(( -v - 56 ) / 6.2)) - mtau = 0.204 + 0.333 / (exp(( v + 15.8) / 18.2) + exp((-v - 131) / 16.7)) - hinf = 1 / (1 + exp((v + 80) / 4)) - - if (v < -81) { - htau = 0.333 * exp((v + 466 ) / 66.6) - } else { - htau = 9.32 + 0.333 * exp((-v - 21) / 10.5) - } -} -UNITSON diff --git a/mod/dipole.mod b/mod/dipole.mod deleted file mode 100644 index 06c0c5e4a..000000000 --- a/mod/dipole.mod +++ /dev/null @@ -1,52 +0,0 @@ -: dipole.mod - mod file for range variable dipole -: -: v 1.9.1m0 -: rev 2015-12-15 (SL: minor) -: last rev: (SL: Added back Qtotal, which WAS used in par version) - -NEURON { - SUFFIX dipole - RANGE ri, ia, Q, ztan - POINTER pv - - : for density. sums into Dipole at section position 1 - POINTER Qsum - POINTER Qtotal -} - -UNITS { - (nA) = (nanoamp) - (mV) = (millivolt) - (Mohm) = (megaohm) - (um) = (micrometer) - (Am) = (amp meter) - (fAm) = (femto amp meter) -} - -ASSIGNED { - ia (nA) - ri (Mohm) - pv (mV) - v (mV) - ztan (um) - Q (fAm) - - : human dipole order of 10 nAm - Qsum (fAm) - Qtotal (fAm) -} - -: solve for v's first then use them -AFTER SOLVE { - ia = (pv - v) / ri - Q = ia * ztan - Qsum = Qsum + Q - Qtotal = Qtotal + Q -} - -AFTER INITIAL { - ia = (pv - v) / ri - Q = ia * ztan - Qsum = Qsum + Q - Qtotal = Qtotal + Q -} diff --git a/mod/dipole_pp.mod b/mod/dipole_pp.mod deleted file mode 100644 index 5a9179e18..000000000 --- a/mod/dipole_pp.mod +++ /dev/null @@ -1,61 +0,0 @@ -: dipole_pp.mod - creates point process mechanism Dipole -: -: v 1.9.1m0 -: rev 2015-12-15 (SL: minor) -: last rev: (SL: added Qtotal back, used for par calc) - -NEURON { - POINT_PROCESS Dipole - RANGE ri, ia, Q, ztan - POINTER pv - - : for POINT_PROCESS. 
Gets additions from dipole - RANGE Qsum - POINTER Qtotal -} - -UNITS { - (nA) = (nanoamp) - (mV) = (millivolt) - (Mohm) = (megaohm) - (um) = (micrometer) - (Am) = (amp meter) - (fAm) = (femto amp meter) -} - -ASSIGNED { - ia (nA) - ri (Mohm) - pv (mV) - v (mV) - ztan (um) - Q (fAm) - Qsum (fAm) - Qtotal (fAm) -} - -: solve for v's first then use them -AFTER SOLVE { - ia = (pv - v) / ri - Q = ia * ztan - Qsum = Qsum + Q - Qtotal = Qtotal + Q -} - -AFTER INITIAL { - ia = (pv - v) / ri - Q = ia * ztan - Qsum = Qsum + Q - Qtotal = Qtotal + Q -} - -: following needed for POINT_PROCESS only but will work if also in SUFFIX -BEFORE INITIAL { - Qsum = 0 - Qtotal = 0 -} - -BEFORE BREAKPOINT { - Qsum = 0 - Qtotal = 0 -} diff --git a/mod/hh2.mod b/mod/hh2.mod deleted file mode 100644 index 7d0884c3f..000000000 --- a/mod/hh2.mod +++ /dev/null @@ -1,122 +0,0 @@ -TITLE hh2.mod sodium, potassium, and leak channels - -COMMENT - This is an adjusted Hodgkin-Huxley treatment for sodium, - potassium, and leakage channels. - Membrane voltage is in absolute mV and has been reversed in polarity - from the original HH convention and shifted to reflect a resting potential - of -65 mV. - Remember to set celsius in your HOC file. -ENDCOMMENT - -UNITS { - (mA) = (milliamp) - (mV) = (millivolt) - (S) = (siemens) -} - -? interface -NEURON { - SUFFIX hh2 - USEION na READ ena WRITE ina - USEION k READ ek WRITE ik - NONSPECIFIC_CURRENT il - RANGE gnabar, gkbar, gl, el, gna, gk - GLOBAL minf, hinf, ninf, mtau, htau, ntau, tshift, temp - THREADSAFE : assigned GLOBALs will be per thread -} - -PARAMETER { - gnabar = .12 (S/cm2) <0,1e9> - gkbar = .036 (S/cm2) <0,1e9> - gl = .0003 (S/cm2) <0,1e9> - el = -54.3 (mV) - temp = 6.3 - tshift = 30.7 -} - -STATE { - m h n -} - -ASSIGNED { - v (mV) - celsius (degC) - ena (mV) - ek (mV) - - gna (S/cm2) - gk (S/cm2) - ina (mA/cm2) - ik (mA/cm2) - il (mA/cm2) - minf hinf ninf - mtau (ms) htau (ms) ntau (ms) -} - -? currents -BREAKPOINT { - SOLVE states METHOD cnexp - gna = gnabar*m*m*m*h - ina = gna*(v - ena) - gk = gkbar*n*n*n*n - ik = gk*(v - ek) - il = gl*(v - el) -} - - -INITIAL { - rates(v) - m = minf - h = hinf - n = ninf -} - -? states -DERIVATIVE states { - rates(v) - m' = (minf-m)/mtau - h' = (hinf-h)/htau - n' = (ninf-n)/ntau -} - -:LOCAL q10 - - -? rates -PROCEDURE rates(v(mV)) { :Computes rate and other constants at current v. - :Call once from HOC to initialize inf at resting v. - LOCAL alpha, beta, sum, q10 - TABLE minf, mtau, hinf, htau, ninf, ntau DEPEND celsius FROM -100 TO 100 WITH 200 - -UNITSOFF - q10 = 3^((celsius - temp - tshift)/10) - :"m" sodium activation system - alpha = .1 * vtrap(-(v+40),10) - beta = 4 * exp(-(v+65)/18) - sum = alpha + beta - mtau = 1/(q10*sum) - minf = alpha/sum - :"h" sodium inactivation system - alpha = .07 * exp(-(v+65)/20) - beta = 1 / (exp(-(v+35)/10) + 1) - sum = alpha + beta - htau = 1/(q10*sum) - hinf = alpha/sum - :"n" potassium activation system - alpha = .01*vtrap(-(v+55),10) - beta = .125*exp(-(v+65)/80) - sum = alpha + beta - ntau = 1/(q10*sum) - ninf = alpha/sum -} - -FUNCTION vtrap(x,y) { :Traps for 0 in denominator of rate eqns. - if (fabs(x/y) < 1e-6) { - vtrap = y*(1 - x/y/2) - }else{ - vtrap = x/(exp(x/y) - 1) - } -} - -UNITSON diff --git a/mod/kca.mod b/mod/kca.mod deleted file mode 100644 index 951e42f04..000000000 --- a/mod/kca.mod +++ /dev/null @@ -1,108 +0,0 @@ -COMMENT - 26 Ago 2002 Modification of original channel to allow variable time step - and to correct an initialization error. 
- - Done by Michael Hines (michael.hines@yale.edu) and Ruggero Scorcioni (rscorcio@gmu.edu) - at EU Advance Course in Computational Neuroscience. Obidos, Portugal - - kca.mod - - Calcium-dependent potassium channel - Based on Pennefather (1990) -- sympathetic ganglion cells - taken from Reuveni et al (1993) -- neocortical cells - - Author: Zach Mainen, Salk Institute, 1995, zach@salk.edu - -ENDCOMMENT - -INDEPENDENT {t FROM 0 TO 1 WITH 1 (ms)} - -NEURON { - SUFFIX kca - USEION k READ ek WRITE ik - USEION ca READ cai - RANGE n, gk, gbar - RANGE ninf, ntau - GLOBAL Ra, Rb, caix - GLOBAL q10, temp, tadj, vmin, vmax, tshift -} - -UNITS { - (mA) = (milliamp) - (mV) = (millivolt) - (pS) = (picosiemens) - (um) = (micron) -} - -PARAMETER { - : 0.03 mho/cm2 - gbar = 10 (pS/um2) - v (mV) - cai (mM) - caix = 1 - - : max act rate - Ra = 0.01 (/ms) - - : max deact rate - Rb = 0.02 (/ms) - - dt (ms) - celsius (degC) - - : original temp - temp = 23 (degC) - q10 = 2.3 - tshift = 30.7 - - vmin = -120 (mV) - vmax = 100 (mV) -} - -ASSIGNED { - a (/ms) - b (/ms) - ik (mA/cm2) - gk (pS/um2) - ek (mV) - ninf - ntau (ms) - tadj -} - - -STATE { - n -} - -INITIAL { - rates(cai) - n = ninf -} - -BREAKPOINT { - SOLVE states METHOD cnexp - gk = tadj * gbar * n - ik = (1e-4) * gk * (v - ek) -} - -LOCAL nexp - -: Computes state variable n at the current v and dt. -DERIVATIVE states { - rates(cai) - n' = (ninf - n) / ntau -} - -PROCEDURE rates(cai(mM)) { - a = Ra * cai^caix - b = Rb - - tadj = q10^((celsius - temp - tshift) / 10) - - ntau = 1 / tadj / (a + b) - ninf = a / (a + b) - - : tinc = -dt * tadj - : nexp = 1 - exp(tinc/ntau) -} diff --git a/mod/km.mod b/mod/km.mod deleted file mode 100644 index df88cfa60..000000000 --- a/mod/km.mod +++ /dev/null @@ -1,126 +0,0 @@ -COMMENT - 26 Ago 2002 Modification of original channel to allow variable time step - and to correct an initialization error. - - Done by Michael Hines (michael.hines@yale.edu) and Ruggero Scorcioni (rscorcio@gmu.edu) - at EU Advance Course in Computational Neuroscience. Obidos, Portugal - - km.mod - - Potassium channel, Hodgkin-Huxley style kinetics - Based on I-M (muscarinic K channel) - Slow, noninactivating - - Original Author: Zach Mainen, Salk Institute, 1995, zach@salk.edu -ENDCOMMENT - -INDEPENDENT {t FROM 0 TO 1 WITH 1 (ms)} - -NEURON { - SUFFIX km - USEION k READ ek WRITE ik - RANGE n, gk, gbar - RANGE ninf, ntau - GLOBAL Ra, Rb - GLOBAL q10, temp, tadj, vmin, vmax, tshift -} - -UNITS { - (mA) = (milliamp) - (mV) = (millivolt) - (pS) = (picosiemens) - (um) = (micron) -} - -PARAMETER { - : 0.03 mho/cm2 - gbar = 10 (pS/um2) - v (mV) - - : v 1/2 for inf - tha = -30 (mV) - - : inf slope - qa = 9 (mV) - - : max act rate (slow) - Ra = 0.001 (/ms) - - : max deact rate (slow) - Rb = 0.001 (/ms) - - dt (ms) - celsius (degC) - - : original temp - temp = 23 (degC) - - : temp sensitivity - q10 = 2.3 - - tshift = 30.7 - - vmin = -120 (mV) - vmax = 100 (mV) -} - - -ASSIGNED { - a (/ms) - b (/ms) - ik (mA/cm2) - gk (pS/um2) - ek (mV) - ninf - ntau (ms) - tadj -} - - -STATE { - n -} - -INITIAL { - trates(v) - n = ninf -} - -BREAKPOINT { - SOLVE states METHOD cnexp - gk = tadj * gbar * n - ik = (1e-4) * gk * (v - ek) -} - -LOCAL nexp - -: Computes state variable n at the current v and dt. -DERIVATIVE states { - trates(v) - n' = (ninf - n) / ntau -} - -: Computes rate and other constants at current v. -: Call once from HOC to initialize inf at resting v. 
-PROCEDURE trates(v) { - TABLE ninf, ntau - DEPEND celsius, temp, Ra, Rb, tha, qa - - FROM vmin TO vmax WITH 199 - - : not consistently executed from here if usetable_hh == 1 - rates(v) - : tinc = -dt * tadj - : nexp = 1 - exp(tinc/ntau) -} - -: Computes rate and other constants at current v. -: Call once from HOC to initialize inf at resting v. -PROCEDURE rates(v) { - a = Ra * (v - tha) / (1 - exp(-(v - tha) / qa)) - b = -Rb * (v - tha) / (1 - exp((v - tha) / qa)) - - tadj = q10^((celsius - temp - tshift) / 10) - ntau = 1/tadj/(a+b) - ninf = a/(a+b) -} diff --git a/mod/lfp.mod b/mod/lfp.mod deleted file mode 100644 index 975a4dad1..000000000 --- a/mod/lfp.mod +++ /dev/null @@ -1,49 +0,0 @@ -: lfp.mod - -COMMENT -LFPsim - Simulation scripts to compute Local Field Potentials (LFP) from cable compartmental models of neurons and networks implemented in NEURON simulation environment. - -LFPsim works reliably on biophysically detailed multi-compartmental neurons with ion channels in some or all compartments. - -Last updated 12-March-2016 -Developed by : Harilal Parasuram & Shyam Diwakar -Computational Neuroscience & Neurophysiology Lab, School of Biotechnology, Amrita University, India. -Email: harilalp@am.amrita.edu; shyam@amrita.edu -www.amrita.edu/compneuro -ENDCOMMENT - -NEURON { - SUFFIX lfp - POINTER transmembrane_current - RANGE lfp_line,lfp_point,lfp_rc,initial_part_point, initial_part_line, initial_part_rc - -} - - -ASSIGNED { - - initial_part_line - initial_part_rc - transmembrane_current - lfp_line - lfp_point - lfp_rc - initial_part_point - - -} - -BREAKPOINT { - - :Point Source Approximation - lfp_point = transmembrane_current * initial_part_point * 1e-1 : So the calculated signal will be in nV - - :Line Source Approximation - lfp_line = transmembrane_current * initial_part_line * 1e-1 : So the calculated signal will be in nV - - :RC - lfp_rc = transmembrane_current * initial_part_rc * 1e-3 : So the calculated signal will be in nV - -} - - diff --git a/mod/mea.mod b/mod/mea.mod deleted file mode 100644 index c237b4b1f..000000000 --- a/mod/mea.mod +++ /dev/null @@ -1,89 +0,0 @@ -: mea.mod - -COMMENT -LFPsim - Simulation scripts to compute Local Field Potentials (LFP) from cable compartmental models of neurons and networks implemented in NEURON simulation environment. - -LFPsim works reliably on biophysically detailed multi-compartmental neurons with ion channels in some or all compartments. - -Last updated 12-March-2016 -Developed by : Harilal Parasuram & Shyam Diwakar -Computational Neuroscience & Neurophysiology Lab, School of Biotechnology, Amrita University, India. 
-Email: harilalp@am.amrita.edu; shyam@amrita.edu -www.amrita.edu/compneuro -ENDCOMMENT - -NEURON { - SUFFIX mea - POINTER transmembrane_current_m - RANGE mea_line0,mea_line1,mea_line2,mea_line3,mea_line4,mea_line5,mea_line6,mea_line7,mea_line8,mea_line9,mea_line10,mea_line11,mea_line12,mea_line13,mea_line14,mea_line15 - RANGE initial_part_line0,initial_part_line1,initial_part_line2,initial_part_line3,initial_part_line4,initial_part_line5,initial_part_line6,initial_part_line7,initial_part_line8,initial_part_line9,initial_part_line10,initial_part_line11,initial_part_line12,initial_part_line13,initial_part_line14,initial_part_line15 - -} - -PARAMETER { - : default values put here - - } - -ASSIGNED { - - transmembrane_current_m - initial_part_line0 - initial_part_line1 - initial_part_line2 - initial_part_line3 - initial_part_line4 - initial_part_line5 - initial_part_line6 - initial_part_line7 - initial_part_line8 - initial_part_line9 - initial_part_line10 - initial_part_line11 - initial_part_line12 - initial_part_line13 - initial_part_line14 - initial_part_line15 - - mea_line0 - mea_line1 - mea_line2 - mea_line3 - mea_line4 - mea_line5 - mea_line6 - mea_line7 - mea_line8 - mea_line9 - mea_line10 - mea_line11 - mea_line12 - mea_line13 - mea_line14 - mea_line15 - - -} - -BREAKPOINT { - - :Line Source Approximation - mea_line0 = transmembrane_current_m * initial_part_line0 * 1e-1 : 1e-1 (mA to uA) : calculated potential will be in uV - mea_line1 = transmembrane_current_m * initial_part_line1 * 1e-1 - mea_line2 = transmembrane_current_m * initial_part_line2 * 1e-1 - mea_line3 = transmembrane_current_m * initial_part_line3 * 1e-1 - mea_line4 = transmembrane_current_m * initial_part_line4 * 1e-1 - mea_line5 = transmembrane_current_m * initial_part_line5 * 1e-1 - mea_line6 = transmembrane_current_m * initial_part_line6 * 1e-1 - mea_line7 = transmembrane_current_m * initial_part_line7 * 1e-1 - mea_line8 = transmembrane_current_m * initial_part_line8 * 1e-1 - mea_line9 = transmembrane_current_m * initial_part_line9 * 1e-1 - mea_line10 = transmembrane_current_m * initial_part_line10 * 1e-1 - mea_line11 = transmembrane_current_m * initial_part_line11 * 1e-1 - mea_line12 = transmembrane_current_m * initial_part_line12 * 1e-1 - mea_line13 = transmembrane_current_m * initial_part_line13 * 1e-1 - mea_line14 = transmembrane_current_m * initial_part_line14 * 1e-1 - mea_line15 = transmembrane_current_m * initial_part_line15 * 1e-1 - -} - diff --git a/mod/vecevent.mod b/mod/vecevent.mod deleted file mode 100644 index c5f624d9c..000000000 --- a/mod/vecevent.mod +++ /dev/null @@ -1,69 +0,0 @@ -: Vector stream of events - -NEURON { - ARTIFICIAL_CELL VecStim -} - -ASSIGNED { - index - etime (ms) - space -} - -INITIAL { - index = 0 - element() - if (index > 0) { - net_send(etime - t, 1) - } -} - -NET_RECEIVE (w) { - if (flag == 1) { - net_event(t) - element() - if (index > 0) { - net_send(etime - t, 1) - } - } -} - -VERBATIM -extern double* vector_vec(); -extern int vector_capacity(); -extern void* vector_arg(); -ENDVERBATIM - -PROCEDURE element() { -VERBATIM - { void* vv; int i, size; double* px; - i = (int)index; - if (i >= 0) { - vv = *((void**)(&space)); - if (vv) { - size = vector_capacity(vv); - px = vector_vec(vv); - if (i < size) { - etime = px[i]; - index += 1.; - } else { - index = -1.; - } - } else { - index = -1.; - } - } - } -ENDVERBATIM -} - -PROCEDURE play() { -VERBATIM - void** vv; - vv = (void**)(&space); - *vv = (void*)0; - if (ifarg(1)) { - *vv = vector_arg(1); - } -ENDVERBATIM -} diff --git 
a/morphology.py b/morphology.py deleted file mode 100644 index afe9ddb51..000000000 --- a/morphology.py +++ /dev/null @@ -1,624 +0,0 @@ -# from https://github.com/ahwillia/PyNeuron-Toolbox 4/2017 -# adjusted for python3 -from __future__ import division -import numpy as np -import pylab as plt -from matplotlib.pyplot import cm -import string -from neuron import h -import numbers - -# a helper library, included with NEURON -h.load_file('stdlib.hoc') -h.load_file('import3d.hoc') - -class Cell: - def __init__(self,name='neuron',soma=None,apic=None,dend=None,axon=None): - self.soma = soma if soma is not None else [] - self.apic = apic if apic is not None else [] - self.dend = dend if dend is not None else [] - self.axon = axon if axon is not None else [] - self.all = self.soma + self.apic + self.dend + self.axon - - def delete(self): - self.soma = None - self.apic = None - self.dend = None - self.axon = None - self.all = None - - def __str__(self): - return self.name - -def load(filename, fileformat=None, cell=None, use_axon=True, xshift=0, yshift=0, zshift=0): - """ - Load an SWC from filename and instantiate inside cell. Code kindly provided - by @ramcdougal. - - Args: - filename = .swc file containing morphology - cell = Cell() object. (Default: None, creates new object) - filename = the filename of the SWC file - use_axon = include the axon? Default: True (yes) - xshift, yshift, zshift = use to position the cell - - Returns: - Cell() object with populated soma, axon, dend, & apic fields - - Minimal example: - # pull the morphology for the demo from NeuroMorpho.Org - from PyNeuronToolbox import neuromorphoorg - with open('c91662.swc', 'w') as f: - f.write(neuromorphoorg.morphology('c91662')) - cell = load_swc(filename) - - """ - - if cell is None: - cell = Cell(name=string.join(filename.split('.')[:-1])) - - if fileformat is None: - fileformat = filename.split('.')[-1] - - name_form = {1: 'soma[%d]', 2: 'axon[%d]', 3: 'dend[%d]', 4: 'apic[%d]'} - - # load the data. Use Import3d_SWC_read for swc, Import3d_Neurolucida3 for - # Neurolucida V3, Import3d_MorphML for MorphML (level 1 of NeuroML), or - # Import3d_Eutectic_read for Eutectic. - if fileformat == 'swc': - morph = h.Import3d_SWC_read() - elif fileformat == 'asc': - morph = h.Import3d_Neurolucida3() - else: - raise Exception('file format `%s` not recognized'%(fileformat)) - morph.input(filename) - - # easiest to instantiate by passing the loaded morphology to the Import3d_GUI - # tool; with a second argument of 0, it won't display the GUI, but it will allow - # use of the GUI's features - i3d = h.Import3d_GUI(morph, 0) - - # get a list of the swc section objects - swc_secs = i3d.swc.sections - swc_secs = [swc_secs.object(i) for i in range(int(swc_secs.count()))] - - # initialize the lists of sections - sec_list = {1: cell.soma, 2: cell.axon, 3: cell.dend, 4: cell.apic} - - # name and create the sections - real_secs = {} - for swc_sec in swc_secs: - cell_part = int(swc_sec.type) - - # skip everything else if it's an axon and we're not supposed to - # use it... 
or if is_subsidiary - if (not(use_axon) and cell_part == 2) or swc_sec.is_subsidiary: - continue - - # figure out the name of the new section - if cell_part not in name_form: - raise Exception('unsupported point type') - name = name_form[cell_part] % len(sec_list[cell_part]) - - # create the section - sec = h.Section(name=name) - - # connect to parent, if any - if swc_sec.parentsec is not None: - sec.connect(real_secs[swc_sec.parentsec.hname()](swc_sec.parentx)) - - # define shape - if swc_sec.first == 1: - h.pt3dstyle(1, swc_sec.raw.getval(0, 0), swc_sec.raw.getval(1, 0), - swc_sec.raw.getval(2, 0), sec=sec) - - j = swc_sec.first - xx, yy, zz = [swc_sec.raw.getrow(i).c(j) for i in range(3)] - dd = swc_sec.d.c(j) - if swc_sec.iscontour_: - # never happens in SWC files, but can happen in other formats supported - # by NEURON's Import3D GUI - raise Exception('Unsupported section style: contour') - - if dd.size() == 1: - # single point soma; treat as sphere - x, y, z, d = [dim.x[0] for dim in [xx, yy, zz, dd]] - for xprime in [x - d / 2., x, x + d / 2.]: - h.pt3dadd(xprime + xshift, y + yshift, z + zshift, d, sec=sec) - else: - for x, y, z, d in zip(xx, yy, zz, dd): - h.pt3dadd(x + xshift, y + yshift, z + zshift, d, sec=sec) - - # store the section in the appropriate list in the cell and lookup table - sec_list[cell_part].append(sec) - real_secs[swc_sec.hname()] = sec - - cell.all = cell.soma + cell.apic + cell.dend + cell.axon - return cell - -def sequential_spherical(xyz): - """ - Converts sequence of cartesian coordinates into a sequence of - line segments defined by spherical coordinates. - - Args: - xyz = 2d numpy array, each row specifies a point in - cartesian coordinates (x,y,z) tracing out a - path in 3D space. - - Returns: - r = lengths of each line segment (1D array) - theta = angles of line segments in XY plane (1D array) - phi = angles of line segments down from Z axis (1D array) - """ - d_xyz = np.diff(xyz,axis=0) - - r = np.linalg.norm(d_xyz,axis=1) - theta = np.arctan2(d_xyz[:,1], d_xyz[:,0]) - hyp = d_xyz[:,0]**2 + d_xyz[:,1]**2 - phi = np.arctan2(np.sqrt(hyp), d_xyz[:,2]) - - return (r,theta,phi) - -def spherical_to_cartesian(r,theta,phi): - """ - Simple conversion of spherical to cartesian coordinates - - Args: - r,theta,phi = scalar spherical coordinates - - Returns: - x,y,z = scalar cartesian coordinates - """ - x = r * np.sin(phi) * np.cos(theta) - y = r * np.sin(phi) * np.sin(theta) - z = r * np.cos(phi) - return (x,y,z) - -def find_coord(targ_length,xyz,rcum,theta,phi): - """ - Find (x,y,z) ending coordinate of segment path along section - path. - - Args: - targ_length = scalar specifying length of segment path, starting - from the begining of the section path - xyz = coordinates specifying the section path - rcum = cumulative sum of section path length at each node in xyz - theta, phi = angles between each coordinate in xyz - """ - # [1] Find spherical coordinates for the line segment containing - # the endpoint. 
- # [2] Find endpoint in spherical coords and convert to cartesian - i = np.nonzero(rcum <= targ_length)[0][-1] - if i == len(theta): - return xyz[-1,:] - else: - r_lcl = targ_length-rcum[i] # remaining length along line segment - (dx,dy,dz) = spherical_to_cartesian(r_lcl,theta[i],phi[i]) - return xyz[i,:] + [dx,dy,dz] - -def interpolate_jagged(xyz,nseg): - """ - Interpolates along a jagged path in 3D - - Args: - xyz = section path specified in cartesian coordinates - nseg = number of segment paths in section path - - Returns: - interp_xyz = interpolated path - """ - - # Spherical coordinates specifying the angles of all line - # segments that make up the section path - (r,theta,phi) = sequential_spherical(xyz) - - # cumulative length of section path at each coordinate - rcum = np.append(0,np.cumsum(r)) - - # breakpoints for segment paths along section path - breakpoints = np.linspace(0,rcum[-1],nseg+1) - np.delete(breakpoints,0) - - # Find segment paths - seg_paths = [] - for a in range(nseg): - path = [] - - # find (x,y,z) starting coordinate of path - if a == 0: - start_coord = xyz[0,:] - else: - start_coord = end_coord # start at end of last path - path.append(start_coord) - - # find all coordinates between the start and end points - start_length = breakpoints[a] - end_length = breakpoints[a+1] - mid_boolean = (rcum > start_length) & (rcum < end_length) - mid_indices = np.nonzero(mid_boolean)[0] - for mi in mid_indices: - path.append(xyz[mi,:]) - - # find (x,y,z) ending coordinate of path - end_coord = find_coord(end_length,xyz,rcum,theta,phi) - path.append(end_coord) - - # Append path to list of segment paths - seg_paths.append(np.array(path)) - - # Return all segment paths - return seg_paths - -def get_section_path(h,sec): - n3d = int(h.n3d(sec=sec)) - xyz = [] - for i in range(0,n3d): - xyz.append([h.x3d(i,sec=sec),h.y3d(i,sec=sec),h.z3d(i,sec=sec)]) - xyz = np.array(xyz) - return xyz - -def shapeplot(h,ax,sections=None,order='pre',cvals=None,\ - clim=None,cmap=cm.YlOrBr_r,**kwargs): - """ - Plots a 3D shapeplot - - Args: - h = hocObject to interface with neuron - ax = matplotlib axis for plotting - sections = list of h.Section() objects to be plotted - order = { None= use h.allsec() to get sections - 'pre'= pre-order traversal of morphology } - cvals = list/array with values mapped to color by cmap; useful - for displaying voltage, calcium or some other state - variable across the shapeplot. - **kwargs passes on to matplotlib (e.g. color='r' for red lines) - - Returns: - lines = list of line objects making up shapeplot - """ - - # Default is to plot all sections. - if sections is None: - if order == 'pre': - sections = allsec_preorder(h) # Get sections in "pre-order" - else: - sections = list(h.allsec()) - - # Determine color limits - if cvals is not None and clim is None: - cn = [ isinstance(cv, numbers.Number) for cv in cvals ] - if any(cn): - clim = [np.min(cvals[cn]), np.max(cvals[cn])] - - # Plot each segement as a line - lines = [] - i = 0 - for sec in sections: - xyz = get_section_path(h,sec) - seg_paths = interpolate_jagged(xyz,sec.nseg) - - for (j,path) in enumerate(seg_paths): - line, = plt.plot(path[:,0], path[:,1], path[:,2], '-k',**kwargs) - if cvals is not None: - if isinstance(cvals[i], numbers.Number): - # map number to colormap - col = cmap(int((cvals[i]-clim[0])*255/(clim[1]-clim[0]))) - else: - # use input directly. E.g. if user specified color with a string. 
- col = cvals[i] - line.set_color(col) - lines.append(line) - i += 1 - - return lines - -def getshapecoords (h,sections=None,order='pre',**kwargs): - if sections is None: - if order == 'pre': - sections = allsec_preorder(h) # Get sections in "pre-order" - else: - sections = list(h.allsec()) - i = 0 - lx,ly,lz=[],[],[] - for sec in sections: - xyz = get_section_path(h,sec) - seg_paths = interpolate_jagged(xyz,sec.nseg) - for path in seg_paths: - for i in [0,1]: - lx.append(path[i][0]) - ly.append(path[i][1]) - lz.append(path[i][2]) - return lx,ly,lz - - -def shapeplot_animate(v,lines,nframes=None,tscale='linear',\ - clim=[-80,50],cmap=cm.YlOrBr_r): - """ Returns animate function which updates color of shapeplot """ - if nframes is None: - nframes = v.shape[0] - if tscale == 'linear': - def animate(i): - i_t = int((i/nframes)*v.shape[0]) - for i_seg in range(v.shape[1]): - lines[i_seg].set_color(cmap(int((v[i_t,i_seg]-clim[0])*255/(clim[1]-clim[0])))) - return [] - elif tscale == 'log': - def animate(i): - i_t = int(np.round((v.shape[0] ** (1.0/(nframes-1))) ** i - 1)) - for i_seg in range(v.shape[1]): - lines[i_seg].set_color(cmap(int((v[i_t,i_seg]-clim[0])*255/(clim[1]-clim[0])))) - return [] - else: - raise ValueError("Unrecognized option '%s' for tscale" % tscale) - - return animate - -def mark_locations(h,section,locs,markspec='or',**kwargs): - """ - Marks one or more locations on along a section. Could be used to - mark the location of a recording or electrical stimulation. - - Args: - h = hocObject to interface with neuron - section = reference to section - locs = float between 0 and 1, or array of floats - optional arguments specify details of marker - - Returns: - line = reference to plotted markers - """ - - # get list of cartesian coordinates specifying section path - xyz = get_section_path(h,section) - (r,theta,phi) = sequential_spherical(xyz) - rcum = np.append(0,np.cumsum(r)) - - # convert locs into lengths from the beginning of the path - if type(locs) is float or type(locs) is np.float64: - locs = np.array([locs]) - if type(locs) is list: - locs = np.array(locs) - lengths = locs*rcum[-1] - - # find cartesian coordinates for markers - xyz_marks = [] - for targ_length in lengths: - xyz_marks.append(find_coord(targ_length,xyz,rcum,theta,phi)) - xyz_marks = np.array(xyz_marks) - - # plot markers - line, = plt.plot(xyz_marks[:,0], xyz_marks[:,1], \ - xyz_marks[:,2], markspec, **kwargs) - return line - -def root_sections(h): - """ - Returns a list of all sections that have no parent. - """ - roots = [] - for section in h.allsec(): - sref = h.SectionRef(sec=section) - # has_parent returns a float... cast to bool - if sref.has_parent() < 0.9: - roots.append(section) - return roots - -def leaf_sections(h): - """ - Returns a list of all sections that have no children. - """ - leaves = [] - for section in h.allsec(): - sref = h.SectionRef(sec=section) - # nchild returns a float... cast to bool - if sref.nchild() < 0.9: - leaves.append(section) - return leaves - -def root_indices(sec_list): - """ - Returns the index of all sections without a parent. - """ - roots = [] - for i,section in enumerate(sec_list): - sref = h.SectionRef(sec=section) - # has_parent returns a float... cast to bool - if sref.has_parent() < 0.9: - roots.append(i) - return roots - -def allsec_preorder(h): - """ - Alternative to using h.allsec(). This returns all sections in order from - the root. 
Traverses the topology each neuron in "pre-order" - """ - #Iterate over all sections, find roots - roots = root_sections(h) - - # Build list of all sections - sec_list = [] - for r in roots: - add_pre(h,sec_list,r) - return sec_list - -def add_pre(h,sec_list,section,order_list=None,branch_order=None): - """ - A helper function that traverses a neuron's morphology (or a sub-tree) - of the morphology in pre-order. This is usually not necessary for the - user to import. - """ - - sec_list.append(section) - sref = h.SectionRef(sec=section) - - if branch_order is not None: - order_list.append(branch_order) - if len(sref.child) > 1: - branch_order += 1 - - for next_node in sref.child: - add_pre(h,sec_list,next_node,order_list,branch_order) - -def dist_between(h,seg1,seg2): - """ - Calculates the distance between two segments. I stole this function from - a post by Michael Hines on the NEURON forum - (www.neuron.yale.edu/phpbb/viewtopic.php?f=2&t=2114) - """ - h.distance(0, seg1.x, sec=seg1.sec) - return h.distance(seg2.x, sec=seg2.sec) - -def all_branch_orders(h): - """ - Produces a list branch orders for each section (following pre-order tree - traversal) - """ - #Iterate over all sections, find roots - roots = [] - for section in h.allsec(): - sref = h.SectionRef(sec=section) - # has_parent returns a float... cast to bool - if sref.has_parent() < 0.9: - roots.append(section) - - # Build list of all sections - order_list = [] - for r in roots: - add_pre(h,[],r,order_list,0) - return order_list - -def branch_order(h,section, path=[]): - """ - Returns the branch order of a section - """ - path.append(section) - sref = h.SectionRef(sec=section) - # has_parent returns a float... cast to bool - if sref.has_parent() < 0.9: - return 0 # section is a root - else: - nchild = len(list(h.SectionRef(sec=sref.parent).child)) - if nchild <= 1.1: - return branch_order(h,sref.parent,path) - else: - return 1+branch_order(h,sref.parent,path) - -def dist_to_mark(h, section, secdict, path=[]): - path.append(section) - sref = h.SectionRef(sec=section) - # print 'current : '+str(section) - # print 'parent : '+str(sref.parent) - if secdict[sref.parent] is None: - # print '-> go to parent' - s = section.L + dist_to_mark(h, sref.parent, secdict, path) - # print 'summing, '+str(s) - return s - else: - # print 'end <- start summing: '+str(section.L) - return section.L # parent is marked - -def branch_precedence(h): - roots = root_sections(h) - leaves = leaf_sections(h) - seclist = allsec_preorder(h) - secdict = { sec:None for sec in seclist } - - for r in roots: - secdict[r] = 0 - - precedence = 1 - while len(leaves)>0: - # build list of distances of all paths to remaining leaves - d = [] - for leaf in leaves: - p = [] - dist = dist_to_mark(h, leaf, secdict, path=p) - d.append((dist,[pp for pp in p])) - - # longest path index - i = np.argmax([ dd[0] for dd in d ]) - leaves.pop(i) # this leaf will be marked - - # mark all sections in longest path - for sec in d[i][1]: - if secdict[sec] is None: - secdict[sec] = precedence - - # increment precedence across iterations - precedence += 1 - - #prec = secdict.values() - #return [0 if p is None else 1 for p in prec], d[i][1] - return [ secdict[sec] for sec in seclist ] - - -from neuron import h -import json - -def parent(sec): - seg = sec.trueparentseg() - if seg is None: - return None - else: - return seg.sec - -def parent_loc(sec, trueparent): - seg = sec.trueparentseg() - if seg is None: - return None - else: - return seg.x - -def morphology_to_dict(sections, outfile=None): - 
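    # Builds a JSON-serializable description of the morphology: one dict per
    # section with its name, 3D points (x/y/z/diam), parent section index
    # (-1 for roots), parent attachment location, and section orientation.
    # If outfile is given, the resulting list is also written to disk as JSON.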
section_map = {sec: i for i, sec in enumerate(sections)} - result = [] - h.define_shape() - - for sec in sections: - my_parent = parent(sec) - my_parent_loc = -1 if my_parent is None else parent_loc(sec, my_parent) - my_parent = -1 if my_parent is None else section_map[my_parent] - n3d = int(h.n3d(sec=sec)) - result.append({ - 'section_orientation': h.section_orientation(sec=sec), - 'parent': my_parent, - 'parent_loc': my_parent_loc, - 'x': [h.x3d(i, sec=sec) for i in range(n3d)], - 'y': [h.y3d(i, sec=sec) for i in range(n3d)], - 'z': [h.z3d(i, sec=sec) for i in range(n3d)], - 'diam': [h.diam3d(i, sec=sec) for i in range(n3d)], - 'name': sec.hname() - }) - - if outfile is not None: - with open(outfile, 'w') as f: - json.dump(result, f) - - return result - - -def load_json(morphfile): - - with open(morphfile, 'r') as f: - secdata = json.load(morphfile) - - seclist = [] - for sd in secdata: - # make section - sec = h.Section(name=sd['name']) - seclist.append(sec) - - - # make 3d morphology - for x,y,z,d in zip(sd['x'], sd['y'], sd['z'], sd('diam')): - h.pt3dadd(x, y, z, d, sec=sec) - - # connect children to parent compartments - for sec,sd in zip(seclist,secdata): - if sd['parent_loc'] >= 0: - parent_sec = sec_list[sd['parent']] - sec.connect(parent_sec(sd['parent_loc']), sd['section_orientation']) - - return seclist diff --git a/netParams.py b/netParams.py deleted file mode 100644 index 54271e666..000000000 --- a/netParams.py +++ /dev/null @@ -1,127 +0,0 @@ -# netParams.py - High-level specifications for network model using NetPyNE -from netpyne import specs - -try: - from __main__ import cfg # import SimConfig object with params from parent module -except: - from cfg import cfg # if no simConfig in parent module, import directly from cfg module - -############################################################################### -# -# NETWORK PARAMETERS -# -############################################################################### - -netParams = specs.NetParams() # object of class NetParams to store the network parameters - -############################################################################### -# Cell parameters -############################################################################### - -# L2Pyr params -cellRule = netParams.importCellParams(label='L2Pyr',conds={'cellType':'L2Pyr','cellModel':'HH_reduced'}, - fileName='L2_pyramidal.py',cellName='L2Pyr') - -cellRule['secLists']['alldend'] = [] -cellRule['secLists']['apicdend'] = [] -cellRule['secLists']['basaldend'] = [] - - -# L2Bas params -cellRule = netParams.importCellParams(label='L2Bas',conds={'cellType':'L2Bas','cellModel':'HH_simple'}, - fileName='L2_basket.py',cellName='L2Basket') - - - -# L5Pyr params -cellRule = netParams.importCellParams(label='L5Pyr',conds={'cellType':'L5Pyr','cellModel':'HH_reduced'}, - fileName='L5_pyramidal.py',cellName='L5Pyr') - - -# L5Bas params -cellRule = netParams.importCellParams(label='L5Bas',conds={'cellType':'L5Bas','cellModel':'HH_simple'}, - fileName='L5_basket.py',cellName='L5Basket') - - -""" -# PT cell params (6-comp) -cellRule = netParams.importCellParams(label='PT_6comp', conds={'cellType': 'PT', 'cellModel': 'HH_reduced'}, - fileName='cells/SPI6.py', cellName='SPI6') - -cellRule['secLists']['alldend'] = ['Bdend', 'Adend1', 'Adend2', 'Adend3'] # define section lists -cellRule['secLists']['apicdend'] = ['Adend1', 'Adend2', 'Adend3'] - -for secName,sec in cellRule['secs'].iteritems(): - sec['vinit'] = -75.0413649414 # set vinit for all secs - if secName in 
cellRule['secLists']['alldend']: - sec['mechs']['nax']['gbar'] = cfg.dendNa # set dend Na gmax for all dends -""" - -############################################################################### -# Population parameters -############################################################################### -#netParams.popParams['PT5B'] = {'cellModel': 'HH_reduced', 'cellType': 'PT', 'numCells': 1} - -num = { - 'E': 100, - 'I': 35 -} - -p = 1.0 - -netParams.popParams['L2Bas'] = {'cellModel': 'HH_simple', 'cellType': 'L2Bas', 'numCells': int(p*num['E'])} -netParams.popParams['L2Pyr'] = {'cellModel': 'HH_reduced', 'cellType': 'L2Pyr', 'numCells': int(p*num['I'])} -netParams.popParams['L5Bas'] = {'cellModel': 'HH_simple', 'cellType': 'L5Bas', 'numCells': int(p*num['E'])} -netParams.popParams['L5Pyr'] = {'cellModel': 'HH_reduced', 'cellType': 'L5Pyr', 'numCells': int(p*num['I'])} - - - -############################################################################### -# Synaptic mechanism parameters -############################################################################### -# netParams.synMechParams['NMDA'] = {'mod': 'MyExp2SynNMDABB', 'tau1NMDA': cfg.tau1NMDA, 'tau2NMDA': 150, 'e': 0} - -#------------------------------------------------------------------------------ -# Synaptic mechanism parameters -#------------------------------------------------------------------------------ -netParams.synMechParams['NMDA'] = {'mod': 'NMDA'} #, 'tau1NMDA': 15, 'tau2NMDA': 150, 'e': 0} -netParams.synMechParams['AMPA'] = {'mod':'AMPA'}#, 'tau1': 0.05, 'tau2': 5.3, 'e': 0} -netParams.synMechParams['GABAA'] = {'mod':'GABAA'}#, 'tau1': 0.07, 'tau2': 18.2, 'e': -80} - -ESynMech = ['AMPA','NMDA'] - - - -""" -############################################################################### -# Current inputs (IClamp) -############################################################################### -if cfg.addIClamp: - for iclabel in [k for k in dir(cfg) if k.startswith('IClamp')]: - ic = getattr(cfg, iclabel, None) # get dict with params - - # add stim source - netParams.stimSourceParams[iclabel] = {'type': 'IClamp', 'delay': ic['start'], 'dur': ic['dur'], 'amp': ic['amp']} - - # connect stim source to target - netParams.stimTargetParams[iclabel+'_'+ic['pop']] = \ - {'source': iclabel, 'conds': {'pop': ic['pop']}, 'sec': ic['sec'], 'loc': ic['loc']} - - -############################################################################### -# NetStim inputs -############################################################################### -if cfg.addNetStim: - for nslabel in [k for k in dir(cfg) if k.startswith('NetStim')]: - ns = getattr(cfg, nslabel, None) - - # add stim source - netParams.stimSourceParams[nslabel] = {'type': 'NetStim', 'start': ns['start'], 'interval': ns['interval'], - 'noise': ns['noise'], 'number': ns['number']} - - # connect stim source to target - netParams.stimTargetParams[nslabel+'_'+ns['pop']] = \ - {'source': nslabel, 'conds': {'pop': ns['pop']}, 'sec': ns['sec'], 'loc': ns['loc'], - 'synMech': ns['synMech'], 'weight': ns['weight'], 'delay': ns['delay']} -""" - diff --git a/network.py b/network.py deleted file mode 100644 index 0993073dc..000000000 --- a/network.py +++ /dev/null @@ -1,426 +0,0 @@ -# class_net.py - establishes the Network class and related methods -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed izip) -# last major: (SL: toward python3) - -import itertools as it -import numpy as np -import sys - -from neuron import h -from feed import ParFeedAll -from L2_pyramidal import L2Pyr -from 
L5_pyramidal import L5Pyr -from L2_basket import L2Basket -from L5_basket import L5Basket -import paramrw as paramrw - -# create Network class -class NetworkOnNode (): - - def __init__ (self, p): - # set the params internally for this net - # better than passing it around like ... - self.p = p - # Number of time points - # Originally used to create the empty vec for synaptic currents, - # ensuring that they exist on this node irrespective of whether - # or not cells of relevant type actually do - self.N_t = np.arange(0., h.tstop, self.p['dt']).size + 1 - # Create a h.Vector() with size 1xself.N_t, zero'd - self.current = { - 'L5Pyr_soma': h.Vector(self.N_t, 0), - 'L2Pyr_soma': h.Vector(self.N_t, 0), - } - # int variables for grid of pyramidal cells (for now in both L2 and L5) - self.gridpyr = { - 'x': self.p['N_pyr_x'], - 'y': self.p['N_pyr_y'], - } - # Parallel stuff - self.pc = h.ParallelContext() - self.n_hosts = int(self.pc.nhost()) - self.rank = int(self.pc.id()) - self.N_src = 0 - # seed debugging - # for key, val in self.p.items(): - # if key.startswith('prng_seedcore_'): - # print("in net: %i, %s, %i" % (self.rank, key, val)) - self.N = {} # numbers of sources - self.N_cells = 0 # init self.N_cells - # zdiff is expressed as a positive DEPTH of L5 relative to L2 - # this is a deviation from the original, where L5 was defined at 0 - # this should not change interlaminar weight/delay calculations - self.zdiff = 1307.4 - # params of external inputs in p_ext - # Global number of external inputs ... automatic counting makes more sense - # p_unique represent ext inputs that are going to go to each cell - self.p_ext, self.p_unique = paramrw.create_pext(self.p, h.tstop) - self.N_extinput = len(self.p_ext) - # Source list of names - # in particular order (cells, extinput, alpha names of unique inputs) - self.src_list_new = self.__create_src_list() - # cell position lists, also will give counts: must be known by ALL nodes - # extinput positions are all located at origin. 
- # sort of a hack bc of redundancy - self.pos_dict = dict.fromkeys(self.src_list_new) - # create coords in pos_dict for all cells first - self.__create_coords_pyr() - self.__create_coords_basket() - self.__count_cells() - # create coords for all other sources - self.__create_coords_extinput() - # count external sources - self.__count_extsrcs() - # create dictionary of GIDs according to cell type - # global dictionary of gid and cell type - self.gid_dict = {} - self.__create_gid_dict() - # assign gid to hosts, creates list of gids for this node in __gid_list - # __gid_list length is number of cells assigned to this id() - self.__gid_list = [] - self.__gid_assign() - # create cells (and create self.origin in create_cells_pyr()) - self.cells = [] - self.extinput_list = [] - # external unique input list dictionary - self.ext_list = dict.fromkeys(self.p_unique) - # initialize the lists in the dict - for key in self.ext_list.keys(): self.ext_list[key] = [] - # create sources and init - self.__create_all_src() - self.state_init() - # parallel network connector - self.__parnet_connect() - # set to record spikes - self.spiketimes = h.Vector() - self.spikegids = h.Vector() - self.__record_spikes() - - # creates the immutable source list along with corresponding numbers of cells - def __create_src_list (self): - # base source list of tuples, name and number, in this order - self.cellname_list = [ - 'L2_basket', - 'L2_pyramidal', - 'L5_basket', - 'L5_pyramidal', - ] - # add the legacy extinput here - self.extname_list = [] - self.extname_list.append('extinput') - # grab the keys for the unique set of inputs and sort the names - # append them to the src list along with the number of cells - unique_keys = sorted(self.p_unique.keys()) - self.extname_list += unique_keys - # return one final source list - src_list = self.cellname_list + self.extname_list - return src_list - - # Creates cells and grid - def __create_coords_pyr (self): - """ pyr grid is the immutable grid, origin now calculated in relation to feed - """ - xrange = np.arange(self.gridpyr['x']) - yrange = np.arange(self.gridpyr['y']) - # create list of tuples/coords, (x, y, z) - self.pos_dict['L2_pyramidal'] = [pos for pos in it.product(xrange, yrange, [0])] - self.pos_dict['L5_pyramidal'] = [pos for pos in it.product(xrange, yrange, [self.zdiff])] - - # create basket cell coords based on pyr grid - def __create_coords_basket (self): - # define relevant x spacings for basket cells - xzero = np.arange(0, self.gridpyr['x'], 3) - xone = np.arange(1, self.gridpyr['x'], 3) - # split even and odd y vals - yeven = np.arange(0, self.gridpyr['y'], 2) - yodd = np.arange(1, self.gridpyr['y'], 2) - # create general list of x,y coords and sort it - coords = [pos for pos in it.product(xzero, yeven)] + [pos for pos in it.product(xone, yodd)] - coords_sorted = sorted(coords, key=lambda pos: pos[1]) - # append the z value for position for L2 and L5 - # print(len(coords_sorted)) - self.pos_dict['L2_basket'] = [pos_xy + (0,) for pos_xy in coords_sorted] - self.pos_dict['L5_basket'] = [pos_xy + (self.zdiff,) for pos_xy in coords_sorted] - - # creates origin AND creates external input coords - def __create_coords_extinput (self): - """ (same thing for now but won't fix because could change) - """ - xrange = np.arange(self.gridpyr['x']) - yrange = np.arange(self.gridpyr['y']) - # origin's z component isn't really used in calculating distance functions from origin - # these will be forced as ints! 
- origin_x = xrange[int((len(xrange)-1)//2)] - origin_y = yrange[int((len(yrange)-1)//2)] - origin_z = np.floor(self.zdiff/2) - self.origin = (origin_x, origin_y, origin_z) - self.pos_dict['extinput'] = [self.origin for i in range(self.N_extinput)] - # at this time, each of the unique inputs is per cell - for key in self.p_unique.keys(): - # create the pos_dict for all the sources - self.pos_dict[key] = [self.origin for i in range(self.N_cells)] - - # cell counting routine - def __count_cells (self): - # cellname list is used *only* for this purpose for now - for src in self.cellname_list: - # if it's a cell, then add the number to total number of cells - self.N[src] = len(self.pos_dict[src]) - self.N_cells += self.N[src] - - # general counting method requires pos_dict is correct for each source - # and that all sources are represented - def __count_extsrcs (self): - # all src numbers are based off of length of pos_dict entry - # generally done here in lieu of upstream changes - for src in self.extname_list: - self.N[src] = len(self.pos_dict[src]) - - # creates gid dicts and pos_lists - def __create_gid_dict (self): - # initialize gid index gid_ind to start at 0 - gid_ind = [0] - # append a new gid_ind based on previous and next cell count - # order is guaranteed by self.src_list_new - for i in range(len(self.src_list_new)): - # N = self.src_list_new[i][1] - # grab the src name in ordered list src_list_new - src = self.src_list_new[i] - # query the N dict for that number and append here to gid_ind, based on previous entry - gid_ind.append(gid_ind[i]+self.N[src]) - # accumulate total source count - self.N_src += self.N[src] - # now actually assign the ranges - for i in range(len(self.src_list_new)): - src = self.src_list_new[i] - self.gid_dict[src] = range(gid_ind[i], gid_ind[i+1]) - - # this happens on EACH node - # creates self.__gid_list for THIS node - def __gid_assign (self): - # round robin assignment of gids - for gid in range(self.rank, self.N_cells, self.n_hosts): - # set the cell gid - self.pc.set_gid2node(gid, self.rank) - self.__gid_list.append(gid) - # now to do the cell-specific external input gids on the same proc - # these are guaranteed to exist because all of these inputs were created - # for each cell - for key in self.p_unique.keys(): - gid_input = gid + self.gid_dict[key][0] - self.pc.set_gid2node(gid_input, self.rank) - self.__gid_list.append(gid_input) - # legacy handling of the external inputs - # NOT perfectly balanced for now - for gid_base in range(self.rank, self.N_extinput, self.n_hosts): - # shift the gid_base to the extinput gid - gid = gid_base + self.gid_dict['extinput'][0] - # set as usual - self.pc.set_gid2node(gid, self.rank) - self.__gid_list.append(gid) - # extremely important to get the gids in the right order - self.__gid_list.sort() - - # reverse lookup of gid to type - def gid_to_type (self, gid): - for gidtype, gids in self.gid_dict.items(): - if gid in gids: - return gidtype - - """ - def checkInputOn (self, type): - if type.startswith('ev'): - if self.p['useEvokedInputs']: - return True - else: - return False - return True - """ - - # reset src (source/external) event times - # evinputinc is an offset for evoked inputs (added to mean start time - e.g. 
per trial increment) - def reset_src_event_times (self, seed=None,debug=False, inc_evinput = 0.0): - if debug: - print('in reset_src_input_times') - print('self.extinput_list:',self.extinput_list) - print('self.ext_list:',type(self.ext_list),self.ext_list) - - for feed in self.extinput_list: - if seed is None: - feed.inc_prng(1000) - else: - feed.set_prng(seed) - feed.set_event_times(inc_evinput) # uses feed.seed - - for k,lfeed in self.ext_list.items(): # dictionary of lists... - for feed in lfeed: # of feeds - if seed is None: - feed.inc_prng(1) - else: - feed.set_prng(seed) - feed.set_event_times(inc_evinput) # uses feed.seed - - # parallel create cells AND external inputs (feeds) - # these are spike SOURCES but cells are also targets - # external inputs are not targets - def __create_all_src (self): - #print('in __create_all_src') - # loop through gids on this node - for gid in self.__gid_list: - # check existence of gid with Neuron - if self.pc.gid_exists(gid): - # get type of cell and pos via gid - # now should be valid for ext inputs - type = self.gid_to_type(gid) - type_pos_ind = gid - self.gid_dict[type][0] - pos = self.pos_dict[type][type_pos_ind] - # figure out which cell type is assoc with the gid - # create cells based on loc property - # creates a NetCon object internally to Neuron - if type == 'L2_pyramidal': - self.cells.append(L2Pyr(gid, pos, self.p)) - self.pc.cell(gid, self.cells[-1].connect_to_target(None, self.p['threshold'])) - # run the IClamp function here - # create_all_IClamp() is defined in L2Pyr (etc) - self.cells[-1].create_all_IClamp(self.p) - if self.p['save_vsoma']: self.cells[-1].record_volt_soma() - elif type == 'L5_pyramidal': - self.cells.append(L5Pyr(gid, pos, self.p)) - self.pc.cell(gid, self.cells[-1].connect_to_target(None,self.p['threshold'])) - # run the IClamp function here - self.cells[-1].create_all_IClamp(self.p) - if self.p['save_vsoma']: self.cells[-1].record_volt_soma() - elif type == 'L2_basket': - self.cells.append(L2Basket(gid, pos)) - self.pc.cell(gid, self.cells[-1].connect_to_target(None,self.p['threshold'])) - # also run the IClamp for L2_basket - self.cells[-1].create_all_IClamp(self.p) - if self.p['save_vsoma']: self.cells[-1].record_volt_soma() - elif type == 'L5_basket': - self.cells.append(L5Basket(gid, pos)) - self.pc.cell(gid, self.cells[-1].connect_to_target(None,self.p['threshold'])) - # run the IClamp function here - self.cells[-1].create_all_IClamp(self.p) - if self.p['save_vsoma']: self.cells[-1].record_volt_soma() - elif type == 'extinput': - #print('type',type) - # to find param index, take difference between REAL gid - # here and gid start point of the items - p_ind = gid - self.gid_dict['extinput'][0] - # now use the param index in the params and create - # the cell and artificial NetCon - self.extinput_list.append(ParFeedAll(type, None, self.p_ext[p_ind], gid)) - self.pc.cell(gid, self.extinput_list[-1].connect_to_target(self.p['threshold'])) - elif type in self.p_unique.keys(): - gid_post = gid - self.gid_dict[type][0] - cell_type = self.gid_to_type(gid_post) - # create dictionary entry, append to list - self.ext_list[type].append(ParFeedAll(type, cell_type, self.p_unique[type], gid)) - self.pc.cell(gid, self.ext_list[type][-1].connect_to_target(self.p['threshold'])) - else: - print("None of these types in Net()") - exit() - else: - print("GID does not exist. See Cell()") - exit() - - # connections: - # this NODE is aware of its cells as targets - # for each syn, return list of source GIDs. 
- # for each item in the list, do a: - # nc = pc.gid_connect(source_gid, target_syn), weight,delay - # Both for synapses AND for external inputs - def __parnet_connect (self): - # loop over target zipped gids and cells - # cells has NO extinputs anyway. also no extgausses - for gid, cell in zip(self.__gid_list, self.cells): - # ignore iteration over inputs, since they are NOT targets - if self.pc.gid_exists(gid) and self.gid_to_type(gid) is not 'extinput': - # print("rank:", self.rank, "gid:", gid, cell, self.gid_to_type(gid)) - # for each gid, find all the other cells connected to it, based on gid - # this MUST be defined in EACH class of cell in self.cells - # parconnect receives connections from other cells - # parreceive receives connections from external inputs - cell.parconnect(gid, self.gid_dict, self.pos_dict, self.p) - cell.parreceive(gid, self.gid_dict, self.pos_dict, self.p_ext) - # now do the unique inputs specific to these cells - # parreceive_ext receives connections from UNIQUE external inputs - for type in self.p_unique.keys(): - p_type = self.p_unique[type] - # print('parnet_connect p_type:',p_type) - cell.parreceive_ext(type, gid, self.gid_dict, self.pos_dict, p_type) - - # setup spike recording for this node - def __record_spikes (self): - # iterate through gids on this node and - # set to record spikes in spike time vec and id vec - # agnostic to type of source, will sort that out later - for gid in self.__gid_list: - if self.pc.gid_exists(gid): - self.pc.spike_record(gid, self.spiketimes, self.spikegids) - - def get_vsoma (self): - dsoma = {} - for cell in self.cells: dsoma[cell.gid] = (cell.celltype, np.array(cell.vsoma.to_python())) - return dsoma - - # aggregate recording all the somatic voltages for pyr - def aggregate_currents (self): - """ this method must be run post-integration - """ - # this is quite ugly - for cell in self.cells: - # check for celltype - if cell.celltype == 'L5_pyramidal': - # iterate over somatic currents, assumes this list exists - # is guaranteed in L5Pyr() - for key, I_soma in cell.dict_currents.items(): - # self.current_L5Pyr_soma was created upon - # in parallel, each node has its own Net() - self.current['L5Pyr_soma'].add(I_soma) - elif cell.celltype == 'L2_pyramidal': - for key, I_soma in cell.dict_currents.items(): - # self.current_L5Pyr_soma was created upon - # in parallel, each node has its own Net() - self.current['L2Pyr_soma'].add(I_soma) - - # recording debug function - def rec_debug (self, rank_exec, gid): - # only execute on this rank, make sure called properly - if rank_exec == self.rank: - # only if the gid exists here - # this will break if non-cell source is attempted - if gid in self.__gid_list: - n = self.__gid_list.index(gid) - v = h.Vector() - v.record(self.cells[n].soma(0.5)._ref_v) - return v - - # initializes the state closer to baseline - def state_init (self): - for cell in self.cells: - seclist = h.SectionList() - seclist.wholetree(sec=cell.soma) - for sect in seclist: - for seg in sect: - if cell.celltype == 'L2_pyramidal': - seg.v = -71.46 - elif cell.celltype == 'L5_pyramidal': - if sect.name() == 'L5Pyr_apical_1': - seg.v = -71.32 - elif sect.name() == 'L5Pyr_apical_2': - seg.v = -69.08 - elif sect.name() == 'L5Pyr_apical_tuft': - seg.v = -67.30 - else: - seg.v = -72. 
- elif cell.celltype == 'L2_basket': - seg.v = -64.9737 - elif cell.celltype == 'L5_basket': - seg.v = -64.9737 - - # move cells 3d positions to positions used for wiring - def movecellstopos (self): - for cell in self.cells: cell.movetopos() diff --git a/nsgr.py b/nsgr.py deleted file mode 100644 index 4d70de488..000000000 --- a/nsgr.py +++ /dev/null @@ -1,270 +0,0 @@ -# based on https://github.com/kenneth59715/nsg-rest-client/blob/master/nsg.nopassword.ipynb -# This works with python 3, with requests module installed -# use port 8443 for production, 8444 for test -# register at https://www.nsgportal.org/reg/reg.php for username and password - -import os -import requests -import xml.etree.ElementTree -import time -import sys -import re -import zipfile -import tarfile -import glob -from conf import dconf - -debug = dconf['debug'] - -def getuserpass (): - f = open('nsgr.txt') - l = f.readlines() - CRA_USER = l[0].strip() - PASSWORD = l[1].strip() # #'changeme' - f.close() - return CRA_USER,PASSWORD - -CRA_USER,PASSWORD = getuserpass() # this will be collected from the HNN GUI later on - -# for production version: -# log in at https://nsgr.sdsc.edu:8443/restusers/login.action -# Tool names can be found at Developer->Documentation (Tools: How to Configure Specific Tools) -# create a new application at Developer->Application Management (Create New Application) -# save the Application Key for use in REST requests - -KEY = 'HNN-418776D750A84FC28A19D5EF1C7B4933' -TOOL = 'SINGULARITY_HNN_TG' -URL = 'https://nsgr.sdsc.edu:8443/cipresrest/v1' # for production version - -def createpayload (paramf, ntrial, tstop): - # returns dictionary of parameters for the NSG job - payload = {'metadata.statusEmail' : 'true'} - payload['vparam.runtime_'] = 0.1 # 0.5 - payload['vparam.filename_'] = 'run.py' - if ntrial == 0: ntrial = 1 - payload['vparam.cmdlineopts_'] = '-nohomeout -paramf ' + os.path.join('param',paramf) + ' ' + str(ntrial) - payload['vparam.number_nodes_'] = 1 - payload['tool'] = TOOL - return payload - -# -def prepinputzip (fout='inputfile.zip'): - """ prepares input zip file for NSGR; file contains all py,mod,param,cfg - files needed to run the simulation """ - try: - if debug: print('Preparing NSGR input zip file...',fout) - fp = zipfile.ZipFile(fout, "w") - lglob = ['*.py','mod/*.mod','*.cfg','param/*.param','res/*.png','Makefile'] - for glb in lglob: - for name in glob.glob(glb): - #if debug: print('adding:',os.path.realpath(name)) - if name.endswith('.mod'): - fp.write(name, 'hnn/mod/'+os.path.basename(name), zipfile.ZIP_DEFLATED) - elif name.endswith('.param'): - fp.write(name,'hnn/param/'+os.path.basename(name),zipfile.ZIP_DEFLATED) - elif name.endswith('.png'): - fp.write(name,'hnn/res/'+os.path.basename(name),zipfile.ZIP_DEFLATED) - else: - fp.write(name, 'hnn/'+os.path.basename(name), zipfile.ZIP_DEFLATED) - fp.close() - return True - except: - print('prepinputzip ERR: could not prepare input zip file',fout,'for NSGR.') - return False - -# -def untar (fname): - # extract contents of tar gz file to current directory - tar = tarfile.open(fname) - tar.extractall() - tar.close() - print("Extracted",fname," in Current Directory.") - -# -def procoutputtar (fname='output.tar.gz'):#,simstr='default'): - """ process HNN NSGR output tar file, saving simulation data - and param file to appropriate directories """ - try: - tar = tarfile.open(fname) - for member in tar.getmembers(): - if member.isreg(): # skip if not a file (e.g. 
directory) - f = member.name - if f.count('data')>0: - lp = f.split(os.path.sep) - member.name = os.path.basename(member.name) # remove the path by resetting it - tar.extract(member,os.path.join('data',lp[-2])) # extract to data subdir - #tar.extract(member,os.path.join('data',simstr)) # extract to data subdir - if f.endswith('.param'): - tar.extract(member,'param') # extract to param subdir - tar.close() - if debug: print("Extracted",fname) - return True - except: - print('procoutputtar ERR: Could not extract contents of ',fname) - return False - -# -def cleanup (zippath='inputfile.zip'): - # cleanup the temporary NSGR files - try: - l = ['output.tar.gz', zippath, 'STDERR', 'STDOUT', 'scheduler_stdout.txt', 'scheduler_stderr.txt'] - for f in l: os.unlink(f) - except: - print('Could not cleanup temp files.') - -def runjobNSGR (paramf='default.param', ntrial=1, tstop=710.0): - """ run a simulation job on NSG using Restful interface; first prepares input zip - file, then submits job and waits for it to finish, finally downloads simulation output - data and extracts it to appropriate location """ - - try: - - payload = createpayload(paramf,ntrial,tstop) - print('payload:',payload) - headers = {'cipres-appkey' : KEY} # application KEY - zippath = os.path.realpath('inputfile.zip') - - if not prepinputzip(zippath): - print('runjobNSGR ERR: could not prepare NSGR input zip file',zippath) - return False - - files = {'input.infile_' : open(zippath,'rb')} # input zip file with code to run - - r = requests.post('{}/job/{}'.format(URL, CRA_USER), auth=(CRA_USER, PASSWORD), data=payload, headers=headers, files=files) - #if debug: print(r.text) - root = xml.etree.ElementTree.fromstring(r.text) - - #if debug: print(r.text) - print(r.url) - - for child in root: - if child.tag == 'resultsUri': - for urlchild in child: - if urlchild.tag == 'url': - outputuri = urlchild.text - if child.tag == 'selfUri': - for urlchild in child: - if urlchild.tag == 'url': - selfuri = urlchild.text - - if debug: print(outputuri) - if debug: print(selfuri) - - print('Waiting for NSG job to complete. . 
.') - jobdone = False - while not jobdone: - r = requests.get(selfuri, auth=(CRA_USER, PASSWORD), headers=headers) - #if debug: print(r.text) - root = xml.etree.ElementTree.fromstring(r.text) - for child in root: - if child.tag == 'terminalStage': - jobstatus = child.text - if jobstatus == 'false': - time.sleep(5) - else: - jobdone = True - break - - print('Job completion detected, getting download URIs...') - - r = requests.get(outputuri, - headers= headers, auth=(CRA_USER, PASSWORD)) - #if debug: print(r.text) - globaldownloadurilist = [] - root = xml.etree.ElementTree.fromstring(r.text) - for child in root: - if child.tag == 'jobfiles': - for jobchild in child: - if jobchild.tag == 'jobfile': - for downloadchild in jobchild: - if downloadchild.tag == 'downloadUri': - for attchild in downloadchild: - if attchild.tag == 'url': - print(attchild.text) - globaldownloadurilist.append(attchild.text) - - print('NSG download complete.') - - globaloutputdict = {} - for downloaduri in globaldownloadurilist: - r = requests.get(downloaduri, auth=(CRA_USER, PASSWORD), headers=headers) - #if debug: print(r.text) - globaloutputdict[downloaduri] = r.text - - #http://stackoverflow.com/questions/31804799/how-to-get-pdf-filename-with-python-requests - for downloaduri in globaldownloadurilist: - r = requests.get(downloaduri, auth=(CRA_USER, PASSWORD), headers=headers) - print(r.headers) - d = r.headers['content-disposition'] - fname_list = re.findall("filename=(.+)", d) - for fname in fname_list: - if debug: print(fname) - - # download all output files - for downloaduri in globaldownloadurilist: - r = requests.get(downloaduri, auth=(CRA_USER, PASSWORD), headers=headers) - #if debug: print(r.json()) - #r.content - d = r.headers['content-disposition'] - filename_list = re.findall('filename=(.+)', d) - for filename in filename_list: - #http://docs.python-requests.org/en/master/user/quickstart/#raw-response-content - with open(filename, 'wb') as fd: - for chunk in r.iter_content(): - fd.write(chunk) - - # get a list of jobs for user and app key, and terminalStage status - r = requests.get("%s/job/%s" % (URL,CRA_USER), auth=(CRA_USER, PASSWORD), headers=headers) - #print(r.text) - - ldeluri = [] # list of jobs to delete - root = xml.etree.ElementTree.fromstring(r.text) - for child in root: - if child.tag == 'jobs': - for jobchild in child: - if jobchild.tag == 'jobstatus': - for statuschild in jobchild: - if statuschild.tag == 'selfUri': - for selfchild in statuschild: - if selfchild.tag == 'url': - #print(child) - joburi = selfchild.text - jobr = requests.get(joburi, auth=(CRA_USER, PASSWORD), headers=headers) - jobroot = xml.etree.ElementTree.fromstring(jobr.text) - for jobrchild in jobroot: - if jobrchild.tag == 'terminalStage': - jobstatus = jobrchild.text - if debug: print('job url:',joburi,' status terminalStage:',jobstatus) - ldeluri.append(joburi) - - # get information for a single job, print out raw XML, need to set joburi according to above list - # delete an old job, need to set joburi - for joburi in ldeluri: - if debug: print('deleting old job with joburi = ',joburi) - r = requests.get(joburi, headers= headers, auth=(CRA_USER, PASSWORD)) - #print(r.text) - r = requests.delete(joburi, auth=(CRA_USER, PASSWORD), headers=headers) - if debug: print(r.text) - - #if not procoutputtar('output.tar.gz',paramf.split('.param')[0]): - if not procoutputtar('output.tar.gz'): - print('runjobNSGR ERR: could not extract simulation output data.') - return False - - if not debug: cleanup() - - return True - - except: 
- print('runjobNSGR: unhandled exception!') - return False - -if __name__ == '__main__': - if debug: print(sys.argv) - if len(sys.argv) < 4: - print('usage: python nsgr.py paramf ntrial tstop') - else: - print(sys.argv) - paramf = sys.argv[1].split(os.path.sep)[-1] - print('paramf:',paramf) - runjobNSGR(paramf=paramf, ntrial=int(sys.argv[2]), tstop=float(sys.argv[3])) diff --git a/param/Alpha.param b/param/Alpha.param index 75078df30..4224a3dc5 100644 --- a/param/Alpha.param +++ b/param/Alpha.param @@ -3,14 +3,14 @@ expmt_groups: {Alpha} tstop: 710. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 40. dipole_scalefctr: 300000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_input_prox: 13 prng_seedcore_input_dist: 14 prng_seedcore_extpois: 4 @@ -210,7 +210,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/AlphaAndBeta.param b/param/AlphaAndBeta.param index ac04295fa..7f41ab9f8 100644 --- a/param/AlphaAndBeta.param +++ b/param/AlphaAndBeta.param @@ -3,14 +3,14 @@ expmt_groups: {AlphaAndBeta} tstop: 710. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 40. dipole_scalefctr: 300000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 13 prng_seedcore_input_dist: 14 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/AlphaAndBeta2.param b/param/AlphaAndBeta2.param index c4a3e6fc0..b134dda21 100644 --- a/param/AlphaAndBeta2.param +++ b/param/AlphaAndBeta2.param @@ -3,14 +3,14 @@ expmt_groups: {AlphaAndBeta2} tstop: 710. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 40. dipole_scalefctr: 30000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/AlphaAndBetaJitter0.param b/param/AlphaAndBetaJitter0.param index b24775913..98bba1003 100644 --- a/param/AlphaAndBetaJitter0.param +++ b/param/AlphaAndBetaJitter0.param @@ -10,7 +10,7 @@ save_spec_data: 1 f_max_spec: 40. dipole_scalefctr: 150000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/AlphaAndBetaSpike.param b/param/AlphaAndBetaSpike.param index 17a02926a..5a51862c0 100644 --- a/param/AlphaAndBetaSpike.param +++ b/param/AlphaAndBetaSpike.param @@ -3,14 +3,14 @@ expmt_groups: {AlphaAndBetaSpike} tstop: 710. 
dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 120 dipole_scalefctr: 150000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/AlphaAndMoreBeta.param b/param/AlphaAndMoreBeta.param index 483bf52c0..cd4c30410 100644 --- a/param/AlphaAndMoreBeta.param +++ b/param/AlphaAndMoreBeta.param @@ -3,14 +3,14 @@ expmt_groups: {AlphaAndMoreBeta} tstop: 710. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 40. dipole_scalefctr: 150000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/ERPNo100Trials.param b/param/ERPNo100Trials.param index c9bb9a691..22da2767f 100644 --- a/param/ERPNo100Trials.param +++ b/param/ERPNo100Trials.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 30 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/ERPYes100Trials.param b/param/ERPYes100Trials.param index dbdd0a27d..3c33db67a 100644 --- a/param/ERPYes100Trials.param +++ b/param/ERPYes100Trials.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 30 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/ERPYesSupraT.param b/param/ERPYesSupraT.param index 4bc809cff..147e41363 100644 --- a/param/ERPYesSupraT.param +++ b/param/ERPYesSupraT.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 20 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/OnlyRhythmicDist.param b/param/OnlyRhythmicDist.param index 5bf917798..394aaa6af 100644 --- a/param/OnlyRhythmicDist.param +++ b/param/OnlyRhythmicDist.param @@ -3,14 +3,14 @@ expmt_groups: {OnlyRhythmicDist} tstop: 710. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 40. 
dipole_scalefctr: 150000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/OnlyRhythmicProx.param b/param/OnlyRhythmicProx.param index ef75a7ea2..7d1e2d41b 100644 --- a/param/OnlyRhythmicProx.param +++ b/param/OnlyRhythmicProx.param @@ -3,14 +3,14 @@ expmt_groups: {OnlyRhythmicProx} tstop: 710. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 40. dipole_scalefctr: 150000.0 dipole_smooth_win: 0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 710.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/SRJ_2007_Fig5A_Super.param b/param/SRJ_2007_Fig5A_Super.param index a18da02be..2dfe20579 100644 --- a/param/SRJ_2007_Fig5A_Super.param +++ b/param/SRJ_2007_Fig5A_Super.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 12.5 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/SRJ_2007_Fig5A_ThreshNonP.param b/param/SRJ_2007_Fig5A_ThreshNonP.param index 75f47ad49..f14a5293c 100644 --- a/param/SRJ_2007_Fig5A_ThreshNonP.param +++ b/param/SRJ_2007_Fig5A_ThreshNonP.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 1.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/SRJ_2007_Fig6_ThreshP.param b/param/SRJ_2007_Fig6_ThreshP.param index ef5b15e01..451e6f270 100644 --- a/param/SRJ_2007_Fig6_ThreshP.param +++ b/param/SRJ_2007_Fig6_ThreshP.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 1.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/default.param b/param/default.param index d492c869f..89ca3e752 100644 --- a/param/default.param +++ b/param/default.param @@ -10,7 +10,7 @@ save_spec_data: 0 f_max_spec: 100 dipole_scalefctr: 3000 dipole_smooth_win: 30 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 4 prng_seedcore_input_dist: 4 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 170.0 
Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/gamma_L5ping_L2ping.param b/param/gamma_L5ping_L2ping.param index 1241e861f..e0ff34c48 100644 --- a/param/gamma_L5ping_L2ping.param +++ b/param/gamma_L5ping_L2ping.param @@ -3,14 +3,14 @@ expmt_groups: {gamma_L5ping_L2ping} tstop: 550. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 80. dipole_scalefctr: 30000.0 dipole_smooth_win: 5.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 0 prng_seedcore_input_dist: 0 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 550.0 Itonic_A_L2Pyr_soma: 4.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/gamma_L5weak_L2weak.param b/param/gamma_L5weak_L2weak.param index d4d93cf3e..5532bf69b 100644 --- a/param/gamma_L5weak_L2weak.param +++ b/param/gamma_L5weak_L2weak.param @@ -3,14 +3,14 @@ expmt_groups: {gamma_L5weak_L2weak} tstop: 550. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 100 dipole_scalefctr: 30000.0 dipole_smooth_win: 0.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 0 prng_seedcore_input_dist: 0 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 550.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/gamma_L5weak_L2weak_bursty.param b/param/gamma_L5weak_L2weak_bursty.param index c92af3721..5409adae0 100644 --- a/param/gamma_L5weak_L2weak_bursty.param +++ b/param/gamma_L5weak_L2weak_bursty.param @@ -10,7 +10,7 @@ save_spec_data: 1 f_max_spec: 100 dipole_scalefctr: 5.0 dipole_smooth_win: 0.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 0 prng_seedcore_input_dist: 0 @@ -212,7 +212,7 @@ L5Basket_Pois_A_weight_ampa: 0 L5Basket_Pois_A_weight_nmda: 0 L5Basket_Pois_lamtha: 0 t0_pois: 0.0 -T_pois: -1 +T_pois: 700.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/gamma_rhythmic_drive.param b/param/gamma_rhythmic_drive.param index 4070ebf8c..43b4970e7 100644 --- a/param/gamma_rhythmic_drive.param +++ b/param/gamma_rhythmic_drive.param @@ -3,14 +3,14 @@ expmt_groups: {gamma_rhythmic_drive} tstop: 550. dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 140.0 dipole_scalefctr: 30000.0 dipole_smooth_win: 0.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 0 prng_seedcore_input_dist: 0 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 550.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/param/gamma_rhythmic_drive_more_noise.param b/param/gamma_rhythmic_drive_more_noise.param index 495717abc..1bcf33a08 100644 --- a/param/gamma_rhythmic_drive_more_noise.param +++ b/param/gamma_rhythmic_drive_more_noise.param @@ -3,14 +3,14 @@ expmt_groups: {gamma_rhythmic_drive_more_noise} tstop: 550. 
dt: 0.025 celsius: 37.0 -N_trials: 0 +N_trials: 1 threshold: 0.0 save_figs: 0 save_spec_data: 1 f_max_spec: 140.0 dipole_scalefctr: 30000.0 dipole_smooth_win: 0.0 -save_vsoma: 0 +record_vsoma: 0 prng_seedcore_opt: 0 prng_seedcore_input_prox: 0 prng_seedcore_input_dist: 0 @@ -211,7 +211,7 @@ L5Basket_Pois_A_weight_ampa: 0.0 L5Basket_Pois_A_weight_nmda: 0.0 L5Basket_Pois_lamtha: 0.0 t0_pois: 0.0 -T_pois: -1 +T_pois: 550.0 Itonic_A_L2Pyr_soma: 0.0 Itonic_t0_L2Pyr_soma: 0.0 Itonic_T_L2Pyr_soma: -1.0 diff --git a/paramrw.py b/paramrw.py deleted file mode 100644 index 3d1556705..000000000 --- a/paramrw.py +++ /dev/null @@ -1,980 +0,0 @@ -# paramrw.py - routines for reading the param files -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed dependence on cartesian, updated for python3) -# last major: (SL: cleanup of self.p_all) - -import re -import fileio as fio -import numpy as np -import itertools as it -# from cartesian import cartesian -from params_default import get_params_default - -# get dict of ':' separated params from fn; ignore lines starting with # -def quickreadprm (fn): - d = {} - with open(fn,'r') as fp: - ln = fp.readlines() - for l in ln: - s = l.strip() - if s.startswith('#'): continue - sp = s.split(':') - if len(sp) > 1: - d[sp[0].strip()]=str(sp[1]).strip() - return d - -def validate_param_file (fn): - try: - fp = open(fn, 'r') - except OSError: - print("ERROR: could not open/read file") - raise ValueError - - d = {} - with fp: - try: - ln = fp.readlines() - except UnicodeDecodeError: - print("ERROR: bad file format") - raise ValueError - for l in ln: - s = l.strip() - if s.startswith('#'): continue - sp = s.split(':') - if len(sp) > 1: - d[sp[0].strip()]=str(sp[1]).strip() - if not 'tstop' in d: - print("ERROR: parameter file not valid. 
Could not find 'tstop'") - raise ValueError - -# get dict of ':' separated params from fn; ignore lines starting with # -def quickgetprm (fn,k,ty): - d = quickreadprm(fn) - return ty(d[k]) - -# check if using ongoing inputs -def usingOngoingInputs (d, lty = ['_prox', '_dist']): - if type(d)==str: d = quickreadprm(d) - tstop = float(d['tstop']) - dpref = {'_prox':'input_prox_A_','_dist':'input_dist_A_'} - try: - for postfix in lty: - if float(d['t0_input'+postfix])<= tstop and \ - float(d['tstop_input'+postfix])>=float(d['t0_input'+postfix]) and \ - float(d['f_input'+postfix])>0.: - for k in ['weight_L2Pyr_ampa','weight_L2Pyr_nmda',\ - 'weight_L5Pyr_ampa','weight_L5Pyr_nmda',\ - 'weight_inh_ampa','weight_inh_nmda']: - if float(d[dpref[postfix]+k])>0.: - #print('usingOngoingInputs:',d[dpref[postfix]+k]) - return True - except: - return False - return False - -# return number of evoked inputs (proximal, distal) -# using dictionary d (or if d is a string, first load the dictionary from filename d) -def countEvokedInputs (d): - if type(d) == str: d = quickreadprm(d) - nprox = ndist = 0 - for k,v in d.items(): - if k.startswith('t_'): - if k.count('evprox') > 0: - nprox += 1 - elif k.count('evdist') > 0: - ndist += 1 - return nprox, ndist - -# check if using any evoked inputs -def usingEvokedInputs (d, lsuffty = ['_evprox_', '_evdist_']): - if type(d) == str: d = quickreadprm(d) - nprox,ndist = countEvokedInputs(d) - tstop = float(d['tstop']) - lsuff = [] - if '_evprox_' in lsuffty: - for i in range(1,nprox+1,1): lsuff.append('_evprox_'+str(i)) - if '_evdist_' in lsuffty: - for i in range(1,ndist+1,1): lsuff.append('_evdist_'+str(i)) - for suff in lsuff: - k = 't' + suff - if k not in d: continue - if float(d[k]) > tstop: continue - k = 'gbar' + suff - for k1 in d.keys(): - if k1.startswith(k): - if float(d[k1]) > 0.0: return True - return False - -# check if using any poisson inputs -def usingPoissonInputs (d): - if type(d)==str: d = quickreadprm(d) - tstop = float(d['tstop']) - if 't0_pois' in d and 'T_pois' in d: - t0_pois = float(d['t0_pois']) - if t0_pois > tstop: return False - T_pois = float(d['T_pois']) - if t0_pois > T_pois and T_pois != -1.0: - return False - for cty in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: - for sy in ['ampa','nmda']: - k = cty+'_Pois_A_weight_'+sy - if k in d: - if float(d[k]) != 0.0: return True - return False - -# check if using any tonic (IClamp) inputs -def usingTonicInputs (d): - if type(d)==str: d = quickreadprm(d) - tstop = float(d['tstop']) - for cty in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket']: - k = 'Itonic_A_' + cty + '_soma' - if k in d: - amp = float(d[k]) - if amp != 0.0: - print(k,'amp != 0.0',amp) - k = 'Itonic_t0_' + cty - t0,t1 = 0.0,-1.0 - if k in d: t0 = float(d[k]) - k = 'Itonic_T_' + cty - if k in d: t1 = float(d[k]) - if t0 > tstop: continue - #print('t0:',t0,'t1:',t1) - if t0 < t1 or t1 == -1.0: return True - return False - -# class controlling multiple simulation files (.param) -class ExpParams(): - def __init__ (self, f_psim, debug=False): - - self.debug = debug - - self.expmt_group_params = [] - - # self.prng_seedcore = {} - # this list is simply to access these easily - self.prng_seed_list = [] - - # read in params from a file - p_all_input = self.__read_sim(f_psim) - self.p_template = dict.fromkeys(self.expmt_group_params) - - # create non-exp params dict from default dict - self.p_all = self.__create_dict_from_default(p_all_input) - - # pop off fixed known vals and create experimental prefix templates - self.__pop_known_values() - 
- # make dict of coupled params - self.coupled_params = self.__find_coupled_params() - - # create the list of iterated params - self.list_params = self.__create_paramlist() - self.N_sims = len(self.list_params[0][1]) - - # return pdict based on that one value, PLUS append the p_ext here ... yes, hack-y - def return_pdict(self, expmt_group, i): - # p_template was always updated to include the ones from exp and others - p_sim = dict.fromkeys(self.p_template) - - # go through params in list_params - for param, val_list in self.list_params: - if param.startswith('prng_seedcore_'): - p_sim[param] = int(val_list[i]) - else: - p_sim[param] = val_list[i] - - # go through the expmt group-based params - for param, val in self.p_group[expmt_group].items(): - p_sim[param] = val - - # add alpha distributions. A bit hack-y - for param, val in self.alpha_distributions.items(): - p_sim[param] = val - - # Add coupled params - for coupled_param, val_param in self.coupled_params.items(): - p_sim[coupled_param] = p_sim[val_param] - - # Add spec_cmap - p_sim['spec_cmap'] = self.spec_cmap - - return p_sim - - # reads .param file and returns p_all_input dict - def __read_sim(self, f_psim): - lines = fio.clean_lines(f_psim) - - # ignore comments - lines = [line for line in lines if line[0] is not '#'] - p = {} - - for line in lines: - # splits line by ':' - param, val = line.split(": ") - - # sim_prefix is not a rotated variable - # not sure why `if param is 'sim_prefix':` does not work here - if param == 'sim_prefix': - p[param] = str(val) - - # expmt_groups must be listed before other vals - elif param == 'expmt_groups': - # this list will be the preservation of the original order - self.expmt_groups = [expmt_group for expmt_group in val[1:-1].split(', ')] - - # this dict here for easy access - # p_group saves each of the changed params per group - self.p_group = dict.fromkeys(self.expmt_groups) - - # create empty dicts in each - for group in self.p_group: - self.p_group[group] = {} - - elif param.startswith('prng_seedcore_'): - p[param] = int(val) - # key = param.split('prng_seedcore_')[-1] - # self.prng_seedcore[key] = val - - # only add values that will change - if p[param] == -1: - self.prng_seed_list.append(param) - - elif param.startswith('distribution_'): - p[param] = str(val) - - elif param == 'Run_Date': - pass - - else: - # assign group params first - if val[0] is '{': - # check for a linspace as a param! - if val[1] is 'L': - # in this case, val_range must be as long as the correct expmt_group length - # everything beyond that will be truncated by the zip operation below - # param passed will strip away the curly braces and just pass the linspace - val_range = self.__expand_linspace(val[1:-1]) - else: - val_range = self.__expand_array(val) - - # add the expmt_group param to the list if it's not already present - if param not in self.expmt_group_params: - self.expmt_group_params.append(param) - - # parcel out vals to exp groups with assigned param names - for expmt_group, val in zip(self.expmt_groups, val_range): - self.p_group[expmt_group][param] = val - - # interpret this as a list of vals - # type floats to a np array - elif val[0] is '[': - p[param] = self.__expand_array(val) - - # interpret as a linspace - elif val[0] is 'L': - p[param] = self.__expand_linspace(val) - - elif val[0] is 'A': - p[param] = self.__expand_arange(val) - - else: - try: - p[param] = float(val) - except ValueError: - p[param] = str(val) - - # hack-y. 
sorry, future - # tstop_* = 0 is valid now, resets to the actual tstop - # with the added bonus of saving this time to the indiv params - for param, val in p.items(): - if param.startswith('tstop_'): - if isinstance(val, float): - if val == 0: - p[param] = p['tstop'] - elif isinstance(val, np.ndarray): - p[param][p[param] == 0] = p['tstop'] - - return p - - # general function to expand a list of values - def __expand_array(self, str_val): - val_list = str_val[1:-1].split(', ') - val_range = np.array([float(item) for item in val_list]) - - return val_range - - # general function to expand the arange - def __expand_arange(self, str_val): - # strip away the leading character along with the brackets and split the csv values - val_list = str_val[2:-1].split(', ') - - # use the values in val_list as params for np.linspace - val_range = np.arange(float(val_list[0]), float(val_list[1]), float(val_list[2])) - - # return the final linspace expanded - return val_range - - # general function to expand the linspace - def __expand_linspace(self, str_val): - # strip away the leading character along with the brackets and split the csv values - val_list = str_val[2:-1].split(', ') - - # use the values in val_list as params for np.linspace - val_range = np.linspace(float(val_list[0]), float(val_list[1]), int(val_list[2])) - - # return the final linspace expanded - return val_range - - # creates dict of params whose values are to be coupled - def __find_coupled_params(self): - coupled_params = {} - # iterates over all key/value pairs to find vals that are strings - for key, val in self.p_all.items(): - if isinstance(val, str): - # check that string is another param in p_all - if val in self.p_all.keys(): - coupled_params[key] = val - else: - print("Unknown key: %s. Probably going to error." 
% (val)) - - # Pop coupled params - for key in coupled_params: - self.p_all.pop(key) - - return coupled_params - - # pop known values & strings off of the params list - def __pop_known_values(self): - self.sim_prefix = self.p_all.pop('sim_prefix') - self.spec_cmap = self.p_all.pop('spec_cmap') - - # create an experimental string prefix template - self.exp_prefix_str = self.sim_prefix+"-%03d" - self.trial_prefix_str = self.exp_prefix_str+"-T%02d" - - # self.N_trials = int(self.p_all.pop('N_trials')) - # self.prng_state = self.p_all.pop('prng_state')[1:-1] - - # Save alpha distribution types in dict for later use - self.alpha_distributions = { - 'distribution_prox': self.p_all.pop('distribution_prox'), - 'distribution_dist': self.p_all.pop('distribution_dist'), - } - - # create the dict based on the default param dict - def __create_dict_from_default (self, p_all_input): - nprox, ndist = countEvokedInputs(p_all_input) - # print('found nprox,ndist ev inputs:', nprox, ndist) - - # create a copy of params_default through which to iterate - p_all = get_params_default(nprox, ndist) - - # now find ONLY the values that are present in the supplied p_all_input - # based on the default dict - for key in p_all.keys(): - # automatically expects that keys are either in p_all_input OR will resort - # to default value - if key in p_all_input: - # pop val off so the remaining items in p_all_input are extraneous - p_all[key] = p_all_input.pop(key) - - # now display extraneous keys, if there were any - if len(p_all_input): - if self.debug: print("Invalid keys from param file not found in default params: %s" % str(p_all_input.keys())) - - return p_all - - # creates all combination of non-exp params - def __create_paramlist (self): - # p_all is the dict specifying all of the changing params - plist = [] - - # get all key/val pairs from the all dict - list_sorted = [item for item in self.p_all.items()] - - # sort the list by the key (alpha) - list_sorted.sort(key=lambda x: x[0]) - - # grab just the keys (but now in order) - self.keys_sorted = [item[0] for item in list_sorted] - self.p_template.update(dict.fromkeys(self.keys_sorted)) - - # grab just the values (but now in order) - # plist = [item[1] for item in list_sorted] - for item in list_sorted: - if isinstance(item[1], np.ndarray): - plist.append(item[1]) - else: - plist.append(np.array([item[1]])) - - # print(plist) - # vals_all = cartesian(plist) - vals_new = np.array([np.array(val) for val in it.product(*plist)]) - vals_new = vals_new.transpose() - - return [item for item in zip(self.keys_sorted, vals_new)] - - # Find keys that change anytime during simulation - # (i.e. 
have more than one associated value) - def get_key_types(self): - key_dict = { - 'expmt_keys': [], - 'dynamic_keys': [], - 'static_keys': [], - } - - # Save exmpt keys - key_dict['expmt_keys'] = self.expmt_group_params - - # Save expmt keys as dynamic keys - key_dict['dynamic_keys'] = self.expmt_group_params - - # Find keys that change run to run within experiments - for key in self.p_all.keys(): - # if key has length associated with it, must change run to run - try: - len(self.p_all[key]) - - # Before storing key, check to make sure it has not already been stored - if key not in key_dict['dynamic_keys']: - key_dict['dynamic_keys'].append(key) - - except TypeError: - key_dict['static_keys'].append(key) - - # Check if coupled params are dynamic - for dep_param, ind_param in self.coupled_params.items(): - if ind_param in key_dict['dynamic_keys']: - key_dict['dynamic_keys'].append(dep_param) - else: - key_dict['static_keys'].append(dep_param) - - return key_dict - -# reads params from a generated txt file and returns gid dict and p dict -def read (fparam): - lines = fio.clean_lines(fparam) - p = {} - gid_dict = {} - for line in lines: - if line.startswith('#'): continue - keystring, val = line.split(": ") - key = keystring.strip() - if val[0] is '[': - val_range = val[1:-1].split(', ') - if len(val_range) is 2: - ind_start = int(val_range[0]) - ind_end = int(val_range[1]) + 1 - gid_dict[key] = np.arange(ind_start, ind_end) - else: - gid_dict[key] = np.array([]) - else: - try: - p[key] = float(val) - except ValueError: - p[key] = str(val) - return gid_dict, p - -# write the params to a filename -def write(fparam, p, gid_list): - """ now sorting - """ - # sort the items in the dict by key - # p_sorted = [item for item in p.items()] - p_keys = [key for key, val in p.items()] - p_sorted = [(key, p[key]) for key in p_keys] - # for some reason this is now crashing in python/mpi - # specifically, lambda sorting in place? - # p_sorted = [item for item in p.items()] - # p_sorted.sort(key=lambda x: x[0]) - # open the file for writing - with open(fparam, 'w') as f: - pstring = '%26s: ' - # write the gid info first - for key in gid_list.keys(): - f.write(pstring % key) - if len(gid_list[key]): - f.write('[%4i, %4i] ' % (gid_list[key][0], gid_list[key][-1])) - else: - f.write('[]') - f.write('\n') - # do the params in p_sorted - for param in p_sorted: - key, val = param - f.write(pstring % key) - if key.startswith('N_'): - f.write('%i\n' % val) - else: - f.write(str(val)+'\n') - -# Searches f_param for any match of p -def find_param(fparam, param_key): - _, p = read(fparam) - - try: - return p[param_key] - - except KeyError: - return "There is no key by the name %s" % param_key - -# reads the simgroup name from fparam -def read_sim_prefix(fparam): - lines = fio.clean_lines(fparam) - param_list = [line for line in lines if line.split(': ')[0].startswith('sim_prefix')] - - # Assume we found something ... 
- if param_list: - return param_list[0].split(" ")[1] - else: - print("No sim_prefix found") - return 0 - -# Finds the experiments list from the simulation param file (.param) -def read_expmt_groups(fparam): - lines = fio.clean_lines(fparam) - lines = [line for line in lines if line.split(': ')[0] == 'expmt_groups'] - - try: - return lines[0].split(': ')[1][1:-1].split(', ') - except: - print("Couldn't get a handle on expmts") - return 0 - -# qnd function to add feeds if they are sensible -def feed_validate(p_ext, d, tstop): - """ whips into shape ones that are not - could be properly made into a meaningful class. - """ - # only append if t0 is less than simulation tstop - if tstop > d['t0']: - # # reset tstop if the specified tstop exceeds the - # # simulation runtime - # if d['tstop'] == 0: - # d['tstop'] = tstop - - if d['tstop'] > tstop: - d['tstop'] = tstop - - # if stdev is zero, increase synaptic weights 5 fold to make - # single input equivalent to 5 simultaneous input to prevent spiking <<---- SN: WHAT IS THIS RULE!?!?!? - if not d['stdev'] and d['distribution'] != 'uniform': - for key in d.keys(): - if key.endswith('Pyr'): - d[key] = (d[key][0] * 5., d[key][1]) - elif key.endswith('Basket'): - d[key] = (d[key][0] * 5., d[key][1]) - - # if L5 delay is -1, use same delays as L2 unless L2 delay is 0.1 in which case use 1. <<---- SN: WHAT IS THIS RULE!?!?!? - if d['L5Pyr_ampa'][1] == -1: - for key in d.keys(): - if key.startswith('L5'): - if d['L2Pyr'][1] != 0.1: - d[key] = (d[key][0], d['L2Pyr'][1]) - else: - d[key] = (d[key][0], 1.) - - p_ext.append(d) - - return p_ext - -# -def checkevokedsynkeys (p, nprox, ndist): - # make sure ampa,nmda gbar values are in the param dict for evoked inputs(for backwards compatibility) - lctprox = ['L2Pyr','L5Pyr','L2Basket','L5Basket'] # evoked distal target cell types - lctdist = ['L2Pyr','L5Pyr','L2Basket'] # evoked proximal target cell types - lsy = ['ampa','nmda'] # synapse types used in evoked inputs - for nev,pref,lct in zip([nprox,ndist],['evprox_','evdist_'],[lctprox,lctdist]): - for i in range(nev): - skey = pref + str(i+1) - for sy in lsy: - for ct in lct: - k = 'gbar_'+skey+'_'+ct+'_'+sy - # if the synapse-specific gbar not present, use the existing weight for both ampa,nmda - if k not in p: - p[k] = p['gbar_'+skey+'_'+ct] - -# -def checkpoissynkeys (p): - # make sure ampa,nmda gbar values are in the param dict for Poisson inputs (for backwards compatibility) - lct = ['L2Pyr','L5Pyr','L2Basket','L5Basket'] # target cell types - lsy = ['ampa','nmda'] # synapse types used in Poisson inputs - for ct in lct: - for sy in lsy: - k = ct + '_Pois_A_weight_' + sy - # if the synapse-specific weight not present, set it to 0 in p - if k not in p: - p[k] = 0.0 - -# creates the external feed params based on individual simulation params p -def create_pext (p, tstop): - # indexable py list of param dicts for parallel - # turn off individual feeds by commenting out relevant line here. 
- # always valid, no matter the length - p_ext = [] - - # p_unique is a dict of input param types that end up going to each cell uniquely - p_unique = {} - - # default params for proximal rhythmic inputs - feed_prox = { - 'f_input': p['f_input_prox'], - 't0': p['t0_input_prox'], - 'tstop': p['tstop_input_prox'], - 'stdev': p['f_stdev_prox'], - 'L2Pyr_ampa': (p['input_prox_A_weight_L2Pyr_ampa'], p['input_prox_A_delay_L2']), - 'L2Pyr_nmda': (p['input_prox_A_weight_L2Pyr_nmda'], p['input_prox_A_delay_L2']), - 'L5Pyr_ampa': (p['input_prox_A_weight_L5Pyr_ampa'], p['input_prox_A_delay_L5']), - 'L5Pyr_nmda': (p['input_prox_A_weight_L5Pyr_nmda'], p['input_prox_A_delay_L5']), - 'L2Basket_ampa': (p['input_prox_A_weight_L2Basket_ampa'], p['input_prox_A_delay_L2']), - 'L2Basket_nmda': (p['input_prox_A_weight_L2Basket_nmda'], p['input_prox_A_delay_L2']), - 'L5Basket_ampa': (p['input_prox_A_weight_L5Basket_ampa'], p['input_prox_A_delay_L5']), - 'L5Basket_nmda': (p['input_prox_A_weight_L5Basket_nmda'], p['input_prox_A_delay_L5']), - 'events_per_cycle': p['events_per_cycle_prox'], - 'prng_seedcore': int(p['prng_seedcore_input_prox']), - 'distribution': p['distribution_prox'], - 'lamtha': 100., - 'loc': 'proximal', - 'repeats': p['repeats_prox'], - 't0_stdev': p['t0_input_stdev_prox'], - 'threshold': p['threshold'] - } - - # ensures time interval makes sense - p_ext = feed_validate(p_ext, feed_prox, tstop) - - # default params for distal rhythmic inputs - feed_dist = { - 'f_input': p['f_input_dist'], - 't0': p['t0_input_dist'], - 'tstop': p['tstop_input_dist'], - 'stdev': p['f_stdev_dist'], - 'L2Pyr_ampa': (p['input_dist_A_weight_L2Pyr_ampa'], p['input_dist_A_delay_L2']), - 'L2Pyr_nmda': (p['input_dist_A_weight_L2Pyr_nmda'], p['input_dist_A_delay_L2']), - 'L5Pyr_ampa': (p['input_dist_A_weight_L5Pyr_ampa'], p['input_dist_A_delay_L5']), - 'L5Pyr_nmda': (p['input_dist_A_weight_L5Pyr_nmda'], p['input_dist_A_delay_L5']), - 'L2Basket_ampa': (p['input_dist_A_weight_L2Basket_ampa'], p['input_dist_A_delay_L2']), - 'L2Basket_nmda': (p['input_dist_A_weight_L2Basket_nmda'], p['input_dist_A_delay_L2']), - 'events_per_cycle': p['events_per_cycle_dist'], - 'prng_seedcore': int(p['prng_seedcore_input_dist']), - 'distribution': p['distribution_dist'], - 'lamtha': 100., - 'loc': 'distal', - 'repeats': p['repeats_dist'], - 't0_stdev': p['t0_input_stdev_dist'], - 'threshold': p['threshold'] - } - - p_ext = feed_validate(p_ext, feed_dist, tstop) - - nprox, ndist = countEvokedInputs(p) - # print('nprox,ndist evoked inputs:', nprox, ndist) - - # NEW: make sure all evoked synaptic weights present (for backwards compatibility) - # could cause differences between output of param files since some nmda weights should - # be 0 while others > 0 - checkevokedsynkeys(p,nprox,ndist) - - # Create proximal evoked response parameters - # f_input needs to be defined as 0 - for i in range(nprox): - skey = 'evprox_' + str(i+1) - p_unique['evprox' + str(i+1)] = { - 't0': p['t_' + skey], - 'L2_pyramidal':(p['gbar_'+skey+'_L2Pyr_ampa'],p['gbar_'+skey+'_L2Pyr_nmda'],0.1,p['sigma_t_'+skey]), - 'L2_basket':(p['gbar_'+skey+'_L2Basket_ampa'],p['gbar_'+skey+'_L2Basket_nmda'],0.1,p['sigma_t_'+skey]), - 'L5_pyramidal':(p['gbar_'+skey+'_L5Pyr_ampa'],p['gbar_'+skey+'_L5Pyr_nmda'],1.,p['sigma_t_'+skey]), - 'L5_basket':(p['gbar_'+skey+'_L5Basket_ampa'],p['gbar_'+skey+'_L5Basket_nmda'],1.,p['sigma_t_'+skey]), - 'prng_seedcore': int(p['prng_seedcore_' + skey]), - 'lamtha_space': 3., - 'loc': 'proximal', - 'sync_evinput': p['sync_evinput'], - 'threshold': 
p['threshold'], - 'numspikes': p['numspikes_' + skey] - } - - # Create distal evoked response parameters - # f_input needs to be defined as 0 - for i in range(ndist): - skey = 'evdist_' + str(i+1) - p_unique['evdist' + str(i+1)] = { - 't0': p['t_' + skey], - 'L2_pyramidal':(p['gbar_'+skey+'_L2Pyr_ampa'],p['gbar_'+skey+'_L2Pyr_nmda'],0.1,p['sigma_t_'+skey]), - 'L5_pyramidal':(p['gbar_'+skey+'_L5Pyr_ampa'],p['gbar_'+skey+'_L5Pyr_nmda'],0.1,p['sigma_t_'+skey]), - 'L2_basket':(p['gbar_'+skey+'_L2Basket_ampa'],p['gbar_'+skey+'_L2Basket_nmda'],0.1,p['sigma_t_' + skey]), - 'prng_seedcore': int(p['prng_seedcore_' + skey]), - 'lamtha_space': 3., - 'loc': 'distal', - 'sync_evinput': p['sync_evinput'], - 'threshold': p['threshold'], - 'numspikes': p['numspikes_' + skey] - } - - # this needs to create many feeds - # (amplitude, delay, mu, sigma). ordered this way to preserve compatibility - p_unique['extgauss'] = { # NEW: note double weight specification since only use ampa for gauss inputs - 'stim': 'gaussian', - 'L2_basket':(p['L2Basket_Gauss_A_weight'],p['L2Basket_Gauss_A_weight'],1.,p['L2Basket_Gauss_mu'],p['L2Basket_Gauss_sigma']), - 'L2_pyramidal':(p['L2Pyr_Gauss_A_weight'],p['L2Pyr_Gauss_A_weight'],0.1,p['L2Pyr_Gauss_mu'],p['L2Pyr_Gauss_sigma']), - 'L5_basket':(p['L5Basket_Gauss_A_weight'],p['L5Basket_Gauss_A_weight'],1.,p['L5Basket_Gauss_mu'],p['L5Basket_Gauss_sigma']), - 'L5_pyramidal':(p['L5Pyr_Gauss_A_weight'],p['L5Pyr_Gauss_A_weight'],1.,p['L5Pyr_Gauss_mu'],p['L5Pyr_Gauss_sigma']), - 'lamtha': 100., - 'prng_seedcore': int(p['prng_seedcore_extgauss']), - 'loc': 'proximal', - 'threshold': p['threshold'] - } - - checkpoissynkeys(p) - - # define T_pois as 0 or -1 to reset automatically to tstop - if p['T_pois'] in (0, -1): p['T_pois'] = tstop - - # Poisson distributed inputs to proximal - p_unique['extpois'] = {# NEW: setting up AMPA and NMDA for Poisson inputs; why delays differ? - 'stim': 'poisson', - 'L2_basket': (p['L2Basket_Pois_A_weight_ampa'],p['L2Basket_Pois_A_weight_nmda'],1.,p['L2Basket_Pois_lamtha']), - 'L2_pyramidal': (p['L2Pyr_Pois_A_weight_ampa'],p['L2Pyr_Pois_A_weight_nmda'], 0.1,p['L2Pyr_Pois_lamtha']), - 'L5_basket': (p['L5Basket_Pois_A_weight_ampa'],p['L5Basket_Pois_A_weight_nmda'],1.,p['L5Basket_Pois_lamtha']), - 'L5_pyramidal': (p['L5Pyr_Pois_A_weight_ampa'],p['L5Pyr_Pois_A_weight_nmda'],1.,p['L5Pyr_Pois_lamtha']), - 'lamtha_space': 100., - 'prng_seedcore': int(p['prng_seedcore_extpois']), - 't_interval': (p['t0_pois'], p['T_pois']), - 'loc': 'proximal', - 'threshold': p['threshold'] - } - - return p_ext, p_unique - -# Finds the changed variables -# sort of inefficient, probably should be part of something else -# not worried about all that right now, as it appears to work -# brittle in that the match string needs to be correct to find all the changed params -# is redundant with(?) 
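# Illustrative sketch, not from the original module: the shape of the two
# values create_pext() returns, using made-up numbers. p_ext is an indexable
# list with one dict per rhythmic feed; p_unique is keyed by input name for
# the feeds that target each cell uniquely (evoked, Gaussian, Poisson).
p_ext_demo = [
    {'loc': 'proximal', 't0': 1000., 'tstop': 250., 'f_input': 10.},
    {'loc': 'distal',   't0': 1000., 'tstop': 250., 'f_input': 10.},
]
p_unique_demo = {
    'evprox1':  {'loc': 'proximal', 't0': 26., 'sync_evinput': 1},
    'extgauss': {'loc': 'proximal', 'stim': 'gaussian'},
    'extpois':  {'loc': 'proximal', 'stim': 'poisson'},
}

for feed in p_ext_demo:
    print('rhythmic', feed['loc'], 'starts at', feed['t0'])
for name, feed in p_unique_demo.items():
    print('unique', name, '->', feed['loc'])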
get_key_types() dynamic keys information -def changed_vars(fparam): - # Strip empty lines and comments - lines = fio.clean_lines(fparam) - lines = [line for line in lines if line[0] != '#'] - - # grab the keys and vals in a list of lists - # each item of keyvals is a pair [key, val] - keyvals = [line.split(": ") for line in lines] - - # match the list for changed items starting with "AKL[(" on the 1st char of the val - var_list = [line for line in keyvals if re.match('[AKL[\(]', line[1][0])] - - # additional default info to add always - list_meta = [ - 'N_trials', - 'N_sims', - 'Run_Date' - ] - - # list concatenate these lists - var_list += [line for line in keyvals if line[0] in list_meta] - - # return the list of "changed" or "default" vars - return var_list - -# Takes two dictionaries (d1 and d2) and compares the keys in d1 to those in d2 -# if any match, updates the (key, value) pair of d1 to match that of d2 -# not real happy with variable names, but will have to do for now -def compare_dictionaries(d1, d2): - # iterate over intersection of key sets (i.e. any common keys) - for key in d1.keys() and d2.keys(): - # update d1 to have same (key, value) pair as d2 - d1[key] = d2[key] - - return d1 - -# get diff on 2 dictionaries -def diffdict (d1, d2, verbose=True): - print('d1,d2 num keys - ', len(d1.keys()), len(d2.keys())) - for k in d1.keys(): - if not k in d2: - if verbose: print(k, ' in d1, not in d2') - for k in d2.keys(): - if not k in d1: - if verbose: print(k, ' in d2, not in d1') - for k in d1.keys(): - if k in d2: - if d1[k] != d2[k]: - print('d1[',k,']=',d1[k],' d2[',k,']=',d2[k]) - -def consolidate_chunks(input_dict): - # get a list of sorted chunks - sorted_inputs = sorted(input_dict.items(), key=lambda x: x[1]['user_start']) - - consolidated_chunks = [] - for one_input in sorted_inputs: - if not 'opt_start' in one_input[1]: - continue - - # extract info from sorted list - input_dict = {'inputs': [one_input[0]], - 'chunk_start': one_input[1]['user_start'], - 'chunk_end': one_input[1]['user_end'], - 'opt_start': one_input[1]['opt_start'], - 'opt_end': one_input[1]['opt_end'], - 'weights': one_input[1]['weights'], - } - - if (len(consolidated_chunks) > 0) and \ - (input_dict['chunk_start'] <= consolidated_chunks[-1]['chunk_end']): - # update previous chunk - consolidated_chunks[-1]['inputs'].extend(input_dict['inputs']) - consolidated_chunks[-1]['chunk_end'] = input_dict['chunk_end'] - consolidated_chunks[-1]['opt_end'] = max(consolidated_chunks[-1]['opt_end'], input_dict['opt_end']) - # average the weights - consolidated_chunks[-1]['weights'] = (consolidated_chunks[-1]['weights'] + one_input[1]['weights'])/2 - else: - # new chunk - consolidated_chunks.append(input_dict) - - return consolidated_chunks - -def combine_chunks(input_chunks): - # Used for creating the opt params of the last step with all inputs - - final_chunk = {'inputs': [], - 'opt_start': 0.0, - 'opt_end': 0.0, - 'chunk_start': 0.0, - 'chunk_end': 0.0} - - for evinput in input_chunks: - final_chunk['inputs'].extend(evinput['inputs']) - if evinput['opt_end'] > final_chunk['opt_end']: - final_chunk['opt_end'] = evinput['opt_end'] - if evinput['chunk_end'] > final_chunk['chunk_end']: - final_chunk['chunk_end'] = evinput['chunk_end'] - - # wRMSE with weights of 1's is the same as regular RMSE. 
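# Illustrative sketch, not from the original module: the overlap-merging rule
# in consolidate_chunks() above, applied by hand to two hypothetical inputs
# whose user windows overlap. The merged chunk extends the end times and
# averages the weight vectors, as in the code above.
import numpy as np

a = {'user_start': 20., 'user_end': 60., 'opt_start': 20., 'opt_end': 60.,
     'weights': np.array([1.0, 0.5, 0.0])}
b = {'user_start': 50., 'user_end': 90., 'opt_start': 50., 'opt_end': 90.,
     'weights': np.array([0.0, 0.5, 1.0])}

# b starts before a ends, so the two collapse into one chunk
merged = {
    'inputs': ['input_a', 'input_b'],                 # hypothetical input names
    'chunk_start': a['user_start'],
    'chunk_end': b['user_end'],
    'opt_start': a['opt_start'],
    'opt_end': max(a['opt_end'], b['opt_end']),
    'weights': (a['weights'] + b['weights']) / 2,
}
print(merged['chunk_start'], merged['chunk_end'], merged['weights'])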
- final_chunk['weights'] = np.ones(len(input_chunks[-1]['weights'])) - return final_chunk - -def chunk_evinputs(opt_params, sim_tstop, sim_dt): - """ - Take dictionary (opt_params) sorted by input and - return a sorted list of dictionaries describing - chunks with inputs consolidated as determined the - range between 'user_start' and 'user_end'. - - The keys of the chunks in chunk_list dictionary - returned are: - 'weights' - 'chunk_start' - 'chunk_end' - 'opt_start' - 'opt_end' - """ - - import re - import scipy.stats as stats - from math import ceil, floor - - num_step = ceil(sim_tstop / sim_dt) + 1 - times = np.linspace(0, sim_tstop, num_step) - - # input_dict will be passed to consolidate_chunks, so it has - # keys 'user_start' and 'user_end' instead of chunk_start and - # 'chunk_start' that will be returned in the dicts returned - # in chunk_list - input_dict = {} - cdfs = {} - - - for input_name in opt_params.keys(): - if opt_params[input_name]['user_start'] > sim_tstop or \ - opt_params[input_name]['user_end'] < 0: - # can't optimize over this input - continue - - # calculate cdf using start time (minival of optimization range) - cdf = stats.norm.cdf(times, opt_params[input_name]['user_start'], - opt_params[input_name]['sigma']) - cdfs[input_name] = cdf.copy() - - for input_name in opt_params.keys(): - if opt_params[input_name]['user_start'] > sim_tstop or \ - opt_params[input_name]['user_end'] < 0: - # can't optimize over this input - continue - input_dict[input_name] = {'weights': cdfs[input_name].copy(), - 'user_start': opt_params[input_name]['user_start'], - 'user_end': opt_params[input_name]['user_end']} - - for other_input in opt_params: - if opt_params[other_input]['user_start'] > sim_tstop or \ - opt_params[other_input]['user_end'] < 0: - # not optimizing over that input - continue - if input_name == other_input: - # don't subtract our own cdf(s) - continue - if opt_params[other_input]['mean'] < \ - opt_params[input_name]['mean']: - # check ordering to only use inputs after us - continue - else: - decay_factor = opt_params[input_name]['decay_multiplier']*(opt_params[other_input]['mean'] - \ - opt_params[input_name]['mean']) / \ - sim_tstop - input_dict[input_name]['weights'] -= cdfs[other_input] * decay_factor - - # weights should not drop below 0 - input_dict[input_name]['weights'] = np.clip(input_dict[input_name]['weights'], a_min=0, a_max=None) - - # start and stop optimization where the weights are insignificant - good_indices = np.where( input_dict[input_name]['weights'] > 0.01) - if len(good_indices[0]) > 0: - input_dict[input_name]['opt_start'] = min(opt_params[input_name]['user_start'], times[good_indices][0]) - input_dict[input_name]['opt_end'] = max(opt_params[input_name]['user_end'], times[good_indices][-1]) - else: - input_dict[input_name]['opt_start'] = opt_params[other_input]['user_start'] - input_dict[input_name]['opt_end'] = opt_params[other_input]['user_end'] - - # convert to multiples of dt - input_dict[input_name]['opt_start'] = floor(input_dict[input_name]['opt_start']/sim_dt)*sim_dt - input_dict[input_name]['opt_end'] = ceil(input_dict[input_name]['opt_end']/sim_dt)*sim_dt - - # combined chunks that have overlapping ranges - # opt_params is a dict, turn into a list - chunk_list = consolidate_chunks(input_dict) - - # add one last chunk to the end - if len(chunk_list) > 1: - chunk_list.append(combine_chunks(chunk_list)) - - return chunk_list - -def get_inputs (params): - import re - input_list = [] - - # first pass through all params to get mu and sigma 
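# Illustrative sketch, not from the original module: the weighting idea inside
# chunk_evinputs() above, stripped to two hypothetical inputs. Each input gets
# a normal-CDF ramp at its start time; ramps of later inputs are subtracted
# with a distance-dependent decay factor and the result is clipped at zero.
import numpy as np
from scipy import stats
from math import ceil

tstop, dt = 200., 0.5
times = np.linspace(0, tstop, ceil(tstop / dt) + 1)

early = {'start': 40., 'sigma': 5., 'mean': 45.}       # made-up timings
late = {'start': 120., 'sigma': 5., 'mean': 125.}
decay_multiplier = 1.6                                 # made-up value

w_early = stats.norm.cdf(times, early['start'], early['sigma'])
w_late = stats.norm.cdf(times, late['start'], late['sigma'])

decay = decay_multiplier * (late['mean'] - early['mean']) / tstop
weights = np.clip(w_early - w_late * decay, a_min=0, a_max=None)

good = np.where(weights > 0.01)[0]                     # non-negligible region
print(times[good][0], times[good][-1])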
for each - for k in params.keys(): - input_mu = re.match('^t_ev(prox|dist)_([0-9]+)', k) - if input_mu: - id_str = 'ev' + input_mu.group(1) + '_' + input_mu.group(2) - input_list.append(id_str) - - return input_list - -def trans_input (input_var): - import re - - input_str = input_var - input_match = re.match('^ev(prox|dist)_([0-9]+)', input_var) - if input_match: - if input_match.group(1) == "prox": - input_str = 'Proximal ' + input_match.group(2) - if input_match.group(1) == "dist": - input_str = 'Distal ' + input_match.group(2) - - return input_str -# debug test function -if __name__ == '__main__': - fparam = 'param/debug.param' - p = ExpParams(fparam,debug=True) - # print(find_param(fparam, 'WhoDat')) # ? - diff --git a/params_default.py b/params_default.py deleted file mode 100644 index b573f7399..000000000 --- a/params_default.py +++ /dev/null @@ -1,365 +0,0 @@ -from collections import OrderedDict - -# params_default.py - master list of changeable params. most set to default val of inactive -# -# v 1.9.01 -# rev 2015-12-08 (RL: added t0_pois) -# last major: (SL: Added default params for L2Basket and L5Basket cells) - -# returns default params - see note -def get_params_default (nprox = 2, ndist = 1): - """ Note that nearly all start times are set BEYOND tstop for this file - Most values here are set to whatever default value inactivates them, such as 0 for conductance - prng seed values are also set to 0 (non-random) - flat file of default values - will most often be overwritten - """ - # set default params - p = { - 'sim_prefix': 'default', - - # simulation end time (ms) - 'tstop': 250., - - # numbers of cells making up the pyramidal grids - 'N_pyr_x': 1, - 'N_pyr_y': 1, - - # amplitudes of individual Gaussian random inputs to L2Pyr and L5Pyr - # L2 Basket params - 'L2Basket_Gauss_A_weight': 0., - 'L2Basket_Gauss_mu': 2000., - 'L2Basket_Gauss_sigma': 3.6, - 'L2Basket_Pois_A_weight_ampa': 0., - 'L2Basket_Pois_A_weight_nmda': 0., - 'L2Basket_Pois_lamtha': 0., - - # L2 Pyr params - 'L2Pyr_Gauss_A_weight': 0., - 'L2Pyr_Gauss_mu': 2000., - 'L2Pyr_Gauss_sigma': 3.6, - 'L2Pyr_Pois_A_weight_ampa': 0., - 'L2Pyr_Pois_A_weight_nmda': 0., - 'L2Pyr_Pois_lamtha': 0., - - # L5 Pyr params - 'L5Pyr_Gauss_A_weight': 0., - 'L5Pyr_Gauss_mu': 2000., - 'L5Pyr_Gauss_sigma': 4.8, - 'L5Pyr_Pois_A_weight_ampa': 0., - 'L5Pyr_Pois_A_weight_nmda': 0., - 'L5Pyr_Pois_lamtha': 0., - - # L5 Basket params - 'L5Basket_Gauss_A_weight': 0., - 'L5Basket_Gauss_mu': 2000., - 'L5Basket_Gauss_sigma': 2., - 'L5Basket_Pois_A_weight_ampa': 0., - 'L5Basket_Pois_A_weight_nmda': 0., - 'L5Basket_Pois_lamtha': 0., - - # maximal conductances for all synapses - # max conductances TO L2Pyrs - 'gbar_L2Pyr_L2Pyr_ampa': 0., - 'gbar_L2Pyr_L2Pyr_nmda': 0., - 'gbar_L2Basket_L2Pyr_gabaa': 0., - 'gbar_L2Basket_L2Pyr_gabab': 0., - - # max conductances TO L2Baskets - 'gbar_L2Pyr_L2Basket': 0., - 'gbar_L2Basket_L2Basket': 0., - - # max conductances TO L5Pyr - 'gbar_L5Pyr_L5Pyr_ampa': 0., - 'gbar_L5Pyr_L5Pyr_nmda': 0., - 'gbar_L2Pyr_L5Pyr': 0., - 'gbar_L2Basket_L5Pyr': 0., - 'gbar_L5Basket_L5Pyr_gabaa': 0., - 'gbar_L5Basket_L5Pyr_gabab': 0., - - # max conductances TO L5Baskets - 'gbar_L5Basket_L5Basket': 0., - 'gbar_L5Pyr_L5Basket': 0., - 'gbar_L2Pyr_L5Basket': 0., - - # Ongoing proximal alpha rhythm - 'distribution_prox': 'normal', - 't0_input_prox': 1000., - 'tstop_input_prox': 250., - 'f_input_prox': 10., - 'f_stdev_prox': 20., - 'events_per_cycle_prox': 2, - 'repeats_prox': 10, - 't0_input_stdev_prox': 0.0, - - # Ongoing distal alpha rhythm - 
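# Illustrative sketch, not from the original module: the regex conventions
# used by get_inputs() and trans_input() above, re-applied to hypothetical
# parameter keys so the output can be inspected standalone.
import re

params = {'t_evprox_1': 26., 't_evdist_1': 64., 'tstop': 250.}  # made-up keys

# keys that look like 't_evprox_N' / 't_evdist_N' identify evoked inputs
input_list = []
for k in params:
    m = re.match(r'^t_ev(prox|dist)_([0-9]+)', k)
    if m:
        input_list.append('ev' + m.group(1) + '_' + m.group(2))

# human-readable labels, as trans_input() produces
def label(name):
    m = re.match(r'^ev(prox|dist)_([0-9]+)', name)
    if not m:
        return name
    side = 'Proximal' if m.group(1) == 'prox' else 'Distal'
    return side + ' ' + m.group(2)

print(sorted(input_list))                       # ['evdist_1', 'evprox_1']
print([label(n) for n in sorted(input_list)])   # ['Distal 1', 'Proximal 1']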
'distribution_dist': 'normal', - 't0_input_dist': 1000., - 'tstop_input_dist': 250., - 'f_input_dist': 10., - 'f_stdev_dist': 20., - 'events_per_cycle_dist': 2, - 'repeats_dist': 10, - 't0_input_stdev_dist': 0.0, - - # thalamic input amplitudes and delays - 'input_prox_A_weight_L2Pyr_ampa': 0., - 'input_prox_A_weight_L2Pyr_nmda': 0., - 'input_prox_A_weight_L5Pyr_ampa': 0., - 'input_prox_A_weight_L5Pyr_nmda': 0., - 'input_prox_A_weight_L2Basket_ampa': 0., - 'input_prox_A_weight_L2Basket_nmda': 0., - 'input_prox_A_weight_L5Basket_ampa': 0., - 'input_prox_A_weight_L5Basket_nmda': 0., - 'input_prox_A_delay_L2': 0.1, - 'input_prox_A_delay_L5': 1.0, - - # current values, not sure where these distal values come from, need to check - 'input_dist_A_weight_L2Pyr_ampa': 0., - 'input_dist_A_weight_L2Pyr_nmda': 0., - 'input_dist_A_weight_L5Pyr_ampa': 0., - 'input_dist_A_weight_L5Pyr_nmda': 0., - 'input_dist_A_weight_L2Basket_ampa': 0., - 'input_dist_A_weight_L2Basket_nmda': 0., - 'input_dist_A_delay_L2': 5., - 'input_dist_A_delay_L5': 5., - - # times and stdevs for evoked responses - 'dt_evprox0_evdist': -1, # not used in GUI - 'dt_evprox0_evprox1': -1, # not used in GUI - 'sync_evinput': 1, # whether evoked inputs arrive at same time to all cells - 'inc_evinput': 0.0, # increment (ms) for avg evoked input start (for trial n, avg start time is n * evinputinc - - # analysis - 'save_spec_data': 0, - 'f_max_spec': 40., - 'spec_cmap': 'jet', # default colormap for consistency with previous versions - 'dipole_scalefctr': 30e3, # scale factor for dipole - default at 30e3 - #based on scaling needed to match model ongoing rhythms from jones 2009 - for ERPs can use 300 - # for ongoing rhythms + ERPs ... use ... ? - 'dipole_smooth_win': 15.0, # window for smoothing (box filter) - 15 ms from jones 2009; shorten - # in case want to look at higher frequency activity - 'save_figs': 0, - 'save_vsoma': 0, # whether to record/save somatic voltage - - # IClamp params for L2Pyr - 'Itonic_A_L2Pyr_soma': 0., - 'Itonic_t0_L2Pyr_soma': 0., - 'Itonic_T_L2Pyr_soma': -1., - - # IClamp param for L2Basket - 'Itonic_A_L2Basket': 0., - 'Itonic_t0_L2Basket': 0., - 'Itonic_T_L2Basket': -1., - - # IClamp params for L5Pyr - 'Itonic_A_L5Pyr_soma': 0., - 'Itonic_t0_L5Pyr_soma': 0., - 'Itonic_T_L5Pyr_soma': -1., - - # IClamp param for L5Basket - 'Itonic_A_L5Basket': 0., - 'Itonic_t0_L5Basket': 0., - 'Itonic_T_L5Basket': -1., - - # numerics - # N_trials of 1 means that seed is set by rank - 'N_trials': 1, - - # prng_state is a string for a filename containing the random state one wants to use - # prng seed cores are the base integer seed for the specific - # prng object for a specific random number stream - # 'prng_state': None, - 'prng_seedcore_opt': 1, - 'prng_seedcore_input_prox': 0, - 'prng_seedcore_input_dist': 0, - 'prng_seedcore_extpois': 0, - 'prng_seedcore_extgauss': 0, - - # default end time for pois inputs - 't0_pois': 0., - 'T_pois': -1, - 'dt': 0.025, - 'celsius': 37.0, - 'threshold': 0.0 # firing threshold - } - - # grab cell-specific params and update p accordingly - p_L2Pyr = get_L2Pyr_params_default() - p_L5Pyr = get_L5Pyr_params_default() - p.update(p_L2Pyr) - p.update(p_L5Pyr) - - # get evoked params and update p accordingly - p_ev_prox = get_ev_params_default(nprox,True) - p_ev_dist = get_ev_params_default(ndist,False) - p.update(p_ev_prox) - p.update(p_ev_dist) - - return p - -# return dict with default params (empty) for evoked inputs; n is number of evoked inputs -# isprox == True iff proximal (otherwise distal) -def 
get_ev_params_default (n,isprox): - dout = {}#OrderedDict() - if isprox: pref = 'evprox' - else: pref = 'evdist' - # print('isprox:',isprox,'n:',n) - lty = ['L2Pyr', 'L5Pyr', 'L2Basket'] - if isprox: lty.append('L5Basket') - lsy = ['ampa', 'nmda'] # allow changing both ampa and nmda weights - for i in range(n): - tystr = pref + '_' + str(i+1) # this string includes input number - for ty in lty: - for sy in lsy: - dout['gbar_' + tystr + '_' + ty + '_' + sy] = 0. # feed strength - dout['t_' + tystr] = 0. # times and stdevs for evoked responses - dout['sigma_t_' + tystr] = 0. - dout['prng_seedcore_' + tystr] = 0 # random number generator seed for this input - dout['numspikes_' + tystr] = 1 # number of presynaptic spikes (postsynaptic inputs) - return dout - -# returns default params for L2 pyramidal cell -def get_L2Pyr_params_default(): - return { - # Soma - 'L2Pyr_soma_L': 22.1, - 'L2Pyr_soma_diam': 23.4, - 'L2Pyr_soma_cm': 0.6195, - 'L2Pyr_soma_Ra': 200., - - # Dendrites - 'L2Pyr_dend_cm': 0.6195, - 'L2Pyr_dend_Ra': 200., - - 'L2Pyr_apicaltrunk_L': 59.5, - 'L2Pyr_apicaltrunk_diam': 4.25, - - 'L2Pyr_apical1_L': 306., - 'L2Pyr_apical1_diam': 4.08, - - 'L2Pyr_apicaltuft_L': 238., - 'L2Pyr_apicaltuft_diam': 3.4, - - 'L2Pyr_apicaloblique_L': 340., - 'L2Pyr_apicaloblique_diam': 3.91, - - 'L2Pyr_basal1_L': 85., - 'L2Pyr_basal1_diam': 4.25, - - 'L2Pyr_basal2_L': 255., - 'L2Pyr_basal2_diam': 2.72, - - 'L2Pyr_basal3_L': 255., - 'L2Pyr_basal3_diam': 2.72, - - # Synapses - 'L2Pyr_ampa_e': 0., - 'L2Pyr_ampa_tau1': 0.5, - 'L2Pyr_ampa_tau2': 5., - - 'L2Pyr_nmda_e': 0., - 'L2Pyr_nmda_tau1': 1., - 'L2Pyr_nmda_tau2': 20., - - 'L2Pyr_gabaa_e': -80., - 'L2Pyr_gabaa_tau1': 0.5, - 'L2Pyr_gabaa_tau2': 5., - - 'L2Pyr_gabab_e': -80., - 'L2Pyr_gabab_tau1': 1., - 'L2Pyr_gabab_tau2': 20., - - # Biophysics soma - 'L2Pyr_soma_gkbar_hh2': 0.01, - 'L2Pyr_soma_gnabar_hh2': 0.18, - 'L2Pyr_soma_el_hh2': -65., - 'L2Pyr_soma_gl_hh2': 4.26e-5, - 'L2Pyr_soma_gbar_km': 250., - - # Biophysics dends - 'L2Pyr_dend_gkbar_hh2': 0.01, - 'L2Pyr_dend_gnabar_hh2': 0.15, - 'L2Pyr_dend_el_hh2': -65., - 'L2Pyr_dend_gl_hh2': 4.26e-5, - 'L2Pyr_dend_gbar_km': 250., - } - -# returns default params for L5 pyramidal cell -def get_L5Pyr_params_default(): - return { - # Soma - 'L5Pyr_soma_L': 39., - 'L5Pyr_soma_diam': 28.9, - 'L5Pyr_soma_cm': 0.85, - 'L5Pyr_soma_Ra': 200., - - # Dendrites - 'L5Pyr_dend_cm': 0.85, - 'L5Pyr_dend_Ra': 200., - - 'L5Pyr_apicaltrunk_L': 102., - 'L5Pyr_apicaltrunk_diam': 10.2, - - 'L5Pyr_apical1_L': 680., - 'L5Pyr_apical1_diam': 7.48, - - 'L5Pyr_apical2_L': 680., - 'L5Pyr_apical2_diam': 4.93, - - 'L5Pyr_apicaltuft_L': 425., - 'L5Pyr_apicaltuft_diam': 3.4, - - 'L5Pyr_apicaloblique_L': 255., - 'L5Pyr_apicaloblique_diam': 5.1, - - 'L5Pyr_basal1_L': 85., - 'L5Pyr_basal1_diam': 6.8, - - 'L5Pyr_basal2_L': 255., - 'L5Pyr_basal2_diam': 8.5, - - 'L5Pyr_basal3_L': 255., - 'L5Pyr_basal3_diam': 8.5, - - # Synapses - 'L5Pyr_ampa_e': 0., - 'L5Pyr_ampa_tau1': 0.5, - 'L5Pyr_ampa_tau2': 5., - - 'L5Pyr_nmda_e': 0., - 'L5Pyr_nmda_tau1': 1., - 'L5Pyr_nmda_tau2': 20., - - 'L5Pyr_gabaa_e': -80., - 'L5Pyr_gabaa_tau1': 0.5, - 'L5Pyr_gabaa_tau2': 5., - - 'L5Pyr_gabab_e': -80., - 'L5Pyr_gabab_tau1': 1., - 'L5Pyr_gabab_tau2': 20., - - # Biophysics soma - 'L5Pyr_soma_gkbar_hh2': 0.01, - 'L5Pyr_soma_gnabar_hh2': 0.16, - 'L5Pyr_soma_el_hh2': -65., - 'L5Pyr_soma_gl_hh2': 4.26e-5, - 'L5Pyr_soma_gbar_ca': 60., - 'L5Pyr_soma_taur_cad': 20., - 'L5Pyr_soma_gbar_kca': 2e-4, - 'L5Pyr_soma_gbar_km': 200., - 'L5Pyr_soma_gbar_cat': 2e-4, - 'L5Pyr_soma_gbar_ar': 
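# Illustrative sketch, not from the original module: the key families that
# get_ev_params_default() above produces for a single proximal evoked input.
# The loop restates its logic so the result can be inspected standalone.
pref, i = 'evprox', 1
cell_types = ['L2Pyr', 'L5Pyr', 'L2Basket', 'L5Basket']   # proximal targets include L5Basket
syn_types = ['ampa', 'nmda']

dout = {}
tystr = pref + '_' + str(i)
for ct in cell_types:
    for sy in syn_types:
        dout['gbar_' + tystr + '_' + ct + '_' + sy] = 0.   # feed strength per synapse
dout['t_' + tystr] = 0.               # mean start time
dout['sigma_t_' + tystr] = 0.         # start-time jitter
dout['prng_seedcore_' + tystr] = 0    # RNG seed for this input
dout['numspikes_' + tystr] = 1        # presynaptic spikes per event

print(len(dout))                      # 12 keys: 8 gbar_* plus 4 bookkeeping entries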
1e-6, - - # Biophysics dends - 'L5Pyr_dend_gkbar_hh2': 0.01, - 'L5Pyr_dend_gnabar_hh2': 0.14, - 'L5Pyr_dend_el_hh2': -71., - 'L5Pyr_dend_gl_hh2': 4.26e-5, - 'L5Pyr_dend_gbar_ca': 60., - 'L5Pyr_dend_taur_cad': 20., - 'L5Pyr_dend_gbar_kca': 2e-4, - 'L5Pyr_dend_gbar_km': 200., - 'L5Pyr_dend_gbar_cat': 2e-4, - 'L5Pyr_dend_gbar_ar': 1e-6, - } diff --git a/plotfn.py b/plotfn.py deleted file mode 100644 index 406490057..000000000 --- a/plotfn.py +++ /dev/null @@ -1,196 +0,0 @@ -# plotfn.py - pall and possibly other plot routines -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed it.izip() dependence) -# last major: (SL: toward python3) - -from praster import praster -import axes_create as ac -import dipolefn -import paramrw -import pspec -import specfn -import os -import fileio as fio -from multiprocessing import Pool - -# terrible handling of variables -def pkernel(dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim=None, ylim=None): - gid_dict, p_dict = paramrw.read(f_param) - tstop = p_dict['tstop'] - # fig dirs - dfig_dpl = dfig['figdpl'] - dfig_spec = dfig['figspec'] - dfig_spk = dfig['figspk'] - pdipole_dict = { - 'xlim': xlim, - 'ylim': ylim, - # 'xmin': xlim[0], - # 'xmax': xlim[1], - # 'ymin': None, - # 'ymax': None, - } - # plot kernels - praster(f_param, tstop, f_spk, dfig_spk) - dipolefn.pdipole(f_dpl, dfig_dpl, pdipole_dict, f_param, key_types) - # dipolefn.pdipole(f_dpl, f_param, dfig_dpl, key_types, pdipole_dict) - # usage of xlim to pspec is temporarily disabled. pspec_dpl() will use internal states for plotting - pspec.pspec_dpl(f_spec, f_dpl, dfig_spec, p_dict, key_types, xlim, ylim, f_param) - # pspec.pspec_dpl(f_spec, f_dpl, dfig_spec, p_dict, key_types) - # pspec.pspec_dpl(data_spec, f_dpl, dfig_spec, p_dict, key_types, xlim) - return 0 - -# Kernel for plotting dipole and spec with alpha feed histograms -def pkernel_with_hist(datdir, dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim=None, ylim=None): - # gid_dict, p_dict = paramrw.read(f_param) - # tstop = p_dict['tstop'] - # fig dirs - dfig_dpl = datdir - dfig_spec = datdir - dfig_spk = datdir - pdipole_dict = { - 'xmin': None, - 'xmax': None, - 'ymin': None, - 'ymax': None, - } - # plot kernels - dipolefn.pdipole_with_hist(f_dpl, f_spk, dfig_dpl, f_param, key_types, pdipole_dict) - pspec.pspec_with_hist(f_spec, f_dpl, f_spk, dfig_spec, f_param, key_types, xlim, ylim) - return 0 - -# r is the value returned by pkernel -# this is sort of a dummy function -def cb(r): pass - -# plot function - this is sort of a stop-gap and shouldn't live here, really -# reads all data except spec and gid_dict from files -def pallsimp (datdir, p_exp, doutf, xlim=None, ylim=None): - key_types = p_exp.get_key_types() - param_list = [doutf['file_param']] - dpl_list = [doutf['file_dpl']] - spec_list = [doutf['file_spec']] - spk_list = [doutf['file_spikes']] - dfig_list = [{'figavgdpl': None, 'avgspec': None, 'param': None, 'normdpl': None, 'rawspk': None, 'rawspec': None, 'figavgspec': None, 'rawdpl': None, 'figdpl': None, 'rawcurrent': None, 'avgdpl': None, 'figspk': None, 'rawspeccurrent': None, 'figspec': None}] - # print('dfig_list:',dfig_list) - for dfig, f_param, f_spk, f_dpl, f_spec in zip(dfig_list, param_list, spk_list, dpl_list, spec_list): - pkernel_with_hist(datdir, dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim, ylim) - -# plot function - this is sort of a stop-gap and shouldn't live here, really -# reads all data except spec and gid_dict from files -def pall(datdir, ddir, p_exp, xlim=None, ylim=None): - # def pall(ddir, 
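# Illustrative sketch, not from the original module: the Pool.apply_async()
# pattern that pall() and pdpl_pspec_with_hist() further below build around
# the pkernel* functions and the dummy callback cb(). The kernel here is a
# stand-in, not the real plotting kernel.
from multiprocessing import Pool

def kernel(run_id):
    # stand-in for pkernel(): do per-run work and return a status code
    return run_id * run_id

def cb(result):
    # no-op callback, mirroring cb() in plotfn.py
    pass

if __name__ == '__main__':
    pool = Pool()
    for run_id in range(4):
        pool.apply_async(kernel, (run_id,), callback=cb)
    pool.close()
    pool.join()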
p_exp, spec_results, xlim=[0., 'tstop']): - # runtype allows easy (hard coded switching between two modes) - # either 'parallel' or 'debug' - # runtype = 'parallel' - runtype = 'debug' - dsim = ddir.dsim - key_types = p_exp.get_key_types() - # preallocate lists for use below - param_list = [] - dpl_list = [] - spec_list = [] - spk_list = [] - dfig_list = [] - # aggregate all file types from individual expmts into lists - # NB The only reason this works is because the analysis results are returned - # IDENTICALLY! - for expmt_group in ddir.expmt_groups: - # these should be equivalent lengths - param_list.extend(ddir.file_match(expmt_group, 'param')) - dpl_list.extend(ddir.file_match(expmt_group, 'rawdpl')) - spec_list.extend(ddir.file_match(expmt_group, 'rawspec')) - spk_list.extend(ddir.file_match(expmt_group, 'rawspk')) - # append as many copies of expmt dfig dict as there were runs in expmt - # this must be done because we're iterating over ALL expmts at the same time - for i in range(len(ddir.file_match(expmt_group, 'param'))): - dfig_list.append(ddir.dfig[expmt_group]) - # create giant list of appropriate files and run them all at the same time - if runtype is 'parallel': - # apply async to compiled lists - pl = Pool() - for dfig, f_param, f_spk, f_dpl, f_spec in zip(dfig_list, param_list, spk_list, dpl_list, spec_list): - pl.apply_async(pkernel, (dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim, ylim), callback=cb) - pl.close() - pl.join() - elif runtype is 'debug': - # run serially - for dfig, f_param, f_spk, f_dpl, f_spec in zip(dfig_list, param_list, spk_list, dpl_list, spec_list): - pkernel_with_hist(dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim, ylim) - # pkernel(dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim, ylim) - -# Plots dipole and spec with alpha feed histograms -def pdpl_pspec_with_hist(ddir, p_exp, xlim=None, ylim=None): - # def pdpl_pspec_with_hist(ddir, p_exp, spec_results, xlim=[0., 'tstop']): - # runtype = 'debug' - runtype = 'parallel' - # preallocate lists for use below - param_list = [] - dpl_list = [] - spec_list = [] - spk_list = [] - dfig_list = [] - # Grab all necessary data in aggregated lists - for expmt_group in ddir.expmt_groups: - # these should be equivalent lengths - param_list.extend(ddir.file_match(expmt_group, 'param')) - dpl_list.extend(ddir.file_match(expmt_group, 'rawdpl')) - spec_list.extend(ddir.file_match(expmt_group, 'rawspec')) - spk_list.extend(ddir.file_match(expmt_group, 'rawspk')) - # append as many copies of expmt dfig dict as there were runs in expmt - for i in range(len(ddir.file_match(expmt_group, 'param'))): - dfig_list.append(ddir.dfig[expmt_group]) - # grab the key types - key_types = p_exp.get_key_types() - print(spec_list) - if runtype is 'parallel': - # apply async to compiled lists - pl = Pool() - for dfig, f_param, f_spk, f_dpl, f_spec in zip(dfig_list, param_list, spk_list, dpl_list, spec_list): - pl.apply_async(pkernel_with_hist, (dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim, ylim), callback=cb) - pl.close() - pl.join() - elif runtype is 'debug': - for dfig, f_param, f_spk, f_dpl, f_spec in zip(dfig_list, param_list, spk_list, dpl_list, spec_list): - pkernel_with_hist(dfig, f_param, f_spk, f_dpl, f_spec, key_types, xlim, ylim) - -def aggregate_spec_with_hist(ddir, p_exp, labels): - untype = 'debug' - # preallocate lists for use below - param_list = [] - dpl_list = [] - spec_list = [] - spk_list = [] - dfig_list = [] - spec_list = [] - # Get dimensions for aggregate fig - N_rows = 
len(ddir.expmt_groups) - N_cols = len(ddir.file_match(ddir.expmt_groups[0], 'param')) - # Create figure - f = ac.FigAggregateSpecWithHist(N_rows, N_cols) - # Grab all necessary data in aggregated lists - for expmt_group in ddir.expmt_groups: - # these should be equivalent lengths - param_list.extend(ddir.file_match(expmt_group, 'param')) - dpl_list.extend(ddir.file_match(expmt_group, 'rawdpl')) - spec_list.extend(ddir.file_match(expmt_group, 'rawspec')) - spk_list.extend(ddir.file_match(expmt_group, 'rawspk')) - # apply async to compiled lists - if runtype is 'parallel': - pl = Pool() - for f_param, f_spk, f_dpl, fspec, ax in zip(param_list, spk_list, dpl_list, spec_list, f.ax_list): - _, p_dict = paramrw.read(f_param) - pl.apply_async(specfn.aggregate_with_hist, (f, ax, fspec, f_dpl, f_spk, fparam, p_dict)) - pl.close() - pl.join() - elif runtype is 'debug': - for f_param, f_spk, f_dpl, fspec, ax in zip(param_list, spk_list, dpl_list, spec_list, f.ax_list): - # _, p_dict = paramrw.read(f_param) - pspec.aggregate_with_hist(f, ax, fspec, f_dpl, f_spk, f_param) - # add row labels - f.add_row_labels(param_list, labels[0]) - # add column labels - f.add_column_labels(param_list, labels[1]) - fig_name = os.path.join(ddir.dsim, 'aggregate_hist.png') - f.save(fig_name) - f.close() diff --git a/pmanu_gamma.py b/pmanu_gamma.py deleted file mode 100644 index 9f6aacceb..000000000 --- a/pmanu_gamma.py +++ /dev/null @@ -1,1713 +0,0 @@ -# pmanu_gamma.py - plot functions for gamma manuscript -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: return_data_dir() and it.izip) -# last major: (SL: plot updates) - -import numpy as np -import os -import fileio as fio -import currentfn -import dipolefn -import specfn -import spikefn -import paramrw -import ac_manu_gamma as acg -import axes_create as ac - -def spec_fig(): - f = acg.FigSimpleSpec() - dproj = fio.return_data_dir() - d = os.path.join(dproj, '2013-12-04/ftremor-003') - - ddata = fio.SimulationPaths() - ddata.read_sim(dproj, d) - - expmt_group = ddata.expmt_groups[0] - f_dpl = ddata.file_match(expmt_group, 'rawdpl')[0] - fspec = ddata.file_match(expmt_group, 'rawspec')[0] - fparam = ddata.file_match(expmt_group, 'param')[0] - - dpl = dipolefn.Dipole(f_dpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - xlim = (200., 700.) - - dpl.plot(f.ax['dipole'], xlim, layer='L5') - pc = { - 'spec': specfn.pspec_ax(f.ax['spec'], fspec, (50, 1050), layer='L5'), - } - cb = f.f.colorbar(pc['spec'], ax=f.ax['spec'], format='%.1e') - - f.ax['spec'].set_xlim(xlim) - f.ax['spec'].set_ylabel('Frequency (Hz)') - f.ax['spec'].set_xlabel('Time (ms)') - - f.set_fontsize(14) - - # fname = os.path.join(d, 'testing.eps') - - f.saveeps(d, 'testing') - f.close() - -# all the data comes from one sim -def hf_epochs(ddata): - # runtype = 'debug' - runtype = 'pub' - - # create the figure from the ac template - f = acg.FigHFEpochs(runtype) - - # hard coded for now - n_sim = 0 - n_trial = 0 - - # and assume just the first expmt for now - expmt = ddata.expmt_groups[0] - - # hard code the 50 ms epochs that will be used here. 
- # centers for the data in tuples (t, f) - tf_specmax = [ - (79.525, 115.), - (136.925, 114.), - (324.350, 109.), - (418.400, 106.), - ] - - # these are approximate centers on which to draw sines - tf_centers = [ - [83.], - [137.], - [330.], - [411., 429.], - ] - - # these all come from one filename - f_spec = ddata.return_specific_filename(expmt, 'rawspec', n_sim, n_trial) - f_dpl = ddata.return_specific_filename(expmt, 'rawdpl', n_sim, n_trial) - f_spk = ddata.return_specific_filename(expmt, 'rawspk', n_sim, n_trial) - f_param = ddata.return_specific_filename(expmt, 'param', n_sim, n_trial) - f_current = ddata.return_specific_filename(expmt, 'rawcurrent', n_sim, n_trial) - - # p_dict is needed for the spike thing. - _, p_dict = paramrw.read(f_param) - - # figure out the tstop and xlim - tstop = paramrw.find_param(f_param, 'tstop') - dt = paramrw.find_param(f_param, 'dt') - xlim = (50., tstop) - - # grab the dipole data - dpl = dipolefn.Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - - # current data - I_soma = currentfn.SynapticCurrent(f_current) - I_soma.convert_nA_to_uA() - - # grab the spike data for histogram - s = spikefn.spikes_from_file(f_param, f_spk) - n_bins = 500 - s_list = np.concatenate(s['L5_pyramidal'].spike_list) - - # spikes - spikes = { - 'L5': spikefn.filter_spike_dict(s, 'L5_'), - } - - # xrange is just the length of the window, dx is the distance from the center - xrange = 50. - dx = xrange / 2. - - # pc will be a list of the colorbar props for each spec (length len(f.gspec)) - pc = [] - - # plot the aggregate data too - specfn.pspec_ax(f.ax['L_spec'], f_spec, xlim, layer='L5') - dpl.plot(f.ax['L_dpl'], xlim, layer='L5') - spikefn.spike_png(f.ax['L_spk'], spikes['L5']) - - # change the color - color_dpl_L5 = '#1e90ff' - f.set_linecolor('L_dpl', color_dpl_L5) - - # grab a list of the ax handles for the leftmost plot - ax_L_keys = [ax for ax in f.ax.keys() if ax.startswith('L_')] - - # set all these xlim correctly - for ax_h in ax_L_keys: - f.ax[ax_h].set_xlim(xlim) - - list_ylim_dpl = [] - - # now plot the individual epochs - for i in range(len(f.gspec_ex)): - # for each (t, f) pair, find the xlim_window - t_center = tf_specmax[i][0] - f_center = tf_specmax[i][1] - xlim_window = (t_center - dx, t_center + dx) - print xlim_window - - # crude setting of vertical lines to denote roi - for ax_h in ax_L_keys: - f.ax[ax_h].axvline(x=xlim_window[0], color='b') - f.ax[ax_h].axvline(x=xlim_window[1], color='k') - - # this is the highlight portion, one fixed - dx_hl = 0.5 * (1000. / f_center) - dx_hl_fixed = 10. - - xlim_hl = (t_center - dx_hl, t_center + dx_hl) - xlim_fixed = (t_center - dx_hl_fixed, t_center + dx_hl_fixed) - - # fix xlim_window in case - if xlim_window[0] < xlim[0]: - xlim_window[0] = xlim[0] - - if xlim_window[1] == -1: - xlim_window[1] = tstop - - I_soma.plot_to_axis(f.ax_twinx['dpl'][i], 'L5') - - # truncate and then plot the dpl - # dpl_short must be a dict of all the different dipoles - # so only need here the L5 key - t_short, dpl_short = dpl.truncate_ext(xlim_hl[0], xlim_hl[1]) - t_fixed, dpl_fixed = dpl.truncate_ext(xlim_fixed[0], xlim_fixed[1]) - - # create a sine waveform for this interval - f_max = tf_specmax[i][1] - t_half = 0.5 * (1000. 
/ f_max) - - # plot the dipole for the xlim window - # plot the dipole and the current, either over the appropriate range or the whole window for now - dpl.plot(f.ax['dpl'][i], xlim_window, layer='L5') - f.ax['dpl'][i].hold(True) - - for j in range(len(tf_centers[i])): - t0 = tf_centers[i][j] - t_half - T = tf_centers[i][j] + t_half - - ylim_dpl = dpl.lim('L5', (t0, T)) - - props_dict = { - 't': (t0, T), - 'dt': dt, - 'f': f_max, - 'A': ylim_dpl[1], - } - - f.add_sine(f.ax['dpl'][i], props_dict) - - f.ax['dpl'][i].set_ylim((-0.025, 0.025)) - - # f.ax['dpl'][i].plot(t_fixed, dpl_fixed['L5'], 'g') - # f.ax['dpl'][i].plot(t_short, dpl_short['L5'], 'r') - - # spec - i think must be plotted as xlim first and then truncated? - pc.append(specfn.pspec_ax(f.ax['spec'][i], f_spec, xlim, layer='L5')) - - # plot the data - spikefn.pinput_hist_onesided(f.ax['hist'][i], s_list, n_bins) - spikefn.spike_png(f.ax['spk'][i], spikes['L5']) - - # set xlim_windows accordingly - f.ax['hist'][i].set_xlim(xlim_window) - f.ax['spk'][i].set_xlim(xlim_window) - f.ax_twinx['dpl'][i].set_xlim(xlim_window) - f.ax['spec'][i].set_xlim(xlim_window) - - # hist - f.ax['hist'][i].yaxis.set_ticks(np.arange(0, 9, 4)) - - # set the color of the I_soma line - f.ax['dpl'][i].lines[0].set_color(color_dpl_L5) - f.ax_twinx['dpl'][i].lines[0].set_color('k') - f.ysymmetry(f.ax_twinx['dpl'][i]) - - # no need for outputs - if runtype == 'debug': - f.f.colorbar(pc[i], ax=f.ax['spec'][i], format='%.3e') - - if runtype == 'pub': - p_ticks = np.arange(0, 9e-5, 2e-5) - pctest = f.f.colorbar(pc[-1], ax=f.ax['spec'][-1], format='%.2e', ticks=p_ticks) - pctest.ax.set_yticklabels(p_ticks) - # pctest.ax.set_ytick(np.arange(0, 8e-5, 2e-5)) - # pctest.ax.locator_params(axis='y', nbins=5) - - # some fig naming stuff - dfig = os.path.join(ddata.dsim, expmt) - trial_prefix = ddata.trial_prefix_str % (n_sim, n_trial) - fprefix_short = trial_prefix + '-hf_epochs' - - # use methods to save figs - f.savepng_new(dfig, fprefix_short) - f.saveeps(dfig, fprefix_short) - f.close() - -def hf(ddata, xlim_window, n_sim, n_trial): - # data directories (made up for now) - # the resultant figure is saved in d0 - # d = os.path.join(dproj, 'pub', '2013-06-28_gamma_weak_L5-000') - - # for now grab the first experiment - # ddata = fio.SimulationPaths() - # ddata.read_sim(dproj, d) - expmt = ddata.expmt_groups[0] - - runtype = 'debug' - # runtype = 'pub' - - # prints the fig in ddata0 - f = acg.FigHF(runtype) - - # grab the relevant files - f_spec = ddata.return_specific_filename(expmt, 'rawspec', n_sim, n_trial) - f_dpl = ddata.return_specific_filename(expmt, 'rawdpl', n_sim, n_trial) - f_spk = ddata.return_specific_filename(expmt, 'rawspk', n_sim, n_trial) - f_param = ddata.return_specific_filename(expmt, 'param', n_sim, n_trial) - f_current = ddata.return_specific_filename(expmt, 'rawcurrent', n_sim, n_trial) - - # p_dict is needed for the spike thing. 
- _, p_dict = paramrw.read(f_param) - - # figure out the tstop and xlim - tstop = paramrw.find_param(f_param, 'tstop') - dt = paramrw.find_param(f_param, 'dt') - xlim = (50., tstop) - - # fix xlim_window - if xlim_window[0] < xlim[0]: - xlim_window[0] = xlim[0] - - if xlim_window[1] == -1: - xlim_window[1] = tstop - - # grab the dipole data - dpl = dipolefn.Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - dpl.plot(f.ax['dpl_L'], xlim, layer='agg') - - # plot currents - I_soma = currentfn.SynapticCurrent(f_current) - I_soma.plot_to_axis(f.ax_twinx['dpl_L'], 'L5') - - # spec - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], f_spec, xlim, layer='L5'), - } - - # no need for outputs - if runtype == 'debug': - f.f.colorbar(pc['L'], ax=f.ax['spec_L'], format='%.3e') - - # grab the spike data for histogram - s = spikefn.spikes_from_file(f_param, f_spk) - n_bins = 500 - s_list = np.concatenate(s['L5_pyramidal'].spike_list) - spikefn.pinput_hist_onesided(f.ax['hist_L'], s_list, n_bins) - - # spikes - spikes = { - 'L5': spikefn.filter_spike_dict(s, 'L5_'), - } - - # plot the data - spikefn.spike_png(f.ax['spk'], spikes['L5']) - - # xlim_window - # xlim_window = (400., 450.) - f.ax['hist_L'].set_xlim(xlim_window) - f.ax['dpl_L'].set_xlim(xlim_window) - f.ax['spec_L'].set_xlim(xlim_window) - f.ax_twinx['dpl_L'].set_xlim(xlim_window) - - f.ax_twinx['dpl_L'].lines[0].set_color('k') - f.ysymmetry(f.ax_twinx['dpl_L']) - f.ax['spk'].set_xlim(xlim_window) - - # # save the fig in ddata0 (arbitrary) - trial_prefix = ddata.trial_prefix_str % (n_sim, n_trial) - f_prefix = '%s_hf' % trial_prefix - dfig = os.path.join(ddata.dsim, expmt) - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -def laminar(ddata): - # for now grab the first experiment - expmt = ddata.expmt_groups[0] - - # runtype = 'debug' - runtype = 'pub' - - # for now hard code the simulation run - n_run = 0 - - # prints the fig in ddata0 - f = acg.FigLaminarComparison(runtype) - - # grab the relevant files - f_spec = ddata.file_match(expmt, 'rawspec')[n_run] - f_dpl = ddata.file_match(expmt, 'rawdpl')[n_run] - f_spk = ddata.file_match(expmt, 'rawspk')[n_run] - f_param = ddata.file_match(expmt, 'param')[n_run] - f_current = ddata.file_match(expmt, 'rawcurrent')[n_run] - - # figure out the tstop and xlim - tstop = paramrw.find_param(f_param, 'tstop') - dt = paramrw.find_param(f_param, 'dt') - xlim = (50., tstop) - - # grab the dipole data - dpl = dipolefn.Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - - # calculate the Welch periodogram - pgram = { - 'agg': specfn.Welch(dpl.t, dpl.dpl['agg'], dt), - 'L2': specfn.Welch(dpl.t, dpl.dpl['L2'], dt), - 'L5': specfn.Welch(dpl.t, dpl.dpl['L5'], dt), - } - - # plot periodograms - pgram['agg'].plot_to_ax(f.ax['pgram_L']) - pgram['L2'].plot_to_ax(f.ax['pgram_M']) - pgram['L5'].plot_to_ax(f.ax['pgram_R']) - - # plot currents - I_soma = currentfn.SynapticCurrent(f_current) - I_soma.convert_nA_to_uA() - I_soma.plot_to_axis(f.ax['current_M'], 'L2') - I_soma.plot_to_axis(f.ax['current_R'], 'L5') - f.set_linecolor('current_M', 'k') - f.set_linecolor('current_R', 'k') - # f.set_axes_pingping() - - # cols have same suffix - list_cols = ['L', 'M', 'R'] - - # create handles list - list_h_pgram = ['pgram_'+col for col in list_cols] - list_h_dpl = ['dpl_'+col for col in list_cols] - - # spec - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], f_spec, xlim, layer='agg'), - 'R': specfn.pspec_ax(f.ax['spec_R'], f_spec, xlim, layer='L5'), - } - - 
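# Illustrative sketch, not from the original module: laminar() above uses the
# project's specfn.Welch wrapper; a rough, generic equivalent is to apply
# scipy.signal.welch directly to a dipole-like trace. Synthetic data is used
# here since the wrapper and real dipole files are not available in this sketch.
import numpy as np
from scipy.signal import welch

dt = 0.025e-3                      # 0.025 ms timestep expressed in seconds
t = np.arange(0, 0.5, dt)          # 500 ms of synthetic "dipole" signal
signal = np.sin(2 * np.pi * 40. * t) + 0.1 * np.random.randn(t.size)

# power spectral density; fs is the sampling rate in Hz
freqs, psd = welch(signal, fs=1.0 / dt, nperseg=4096)
print('peak near %.1f Hz' % freqs[np.argmax(psd)])   # ~40 Hz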
pc2 = { - 'M': specfn.pspec_ax(f.ax['spec_M'], f_spec, xlim, layer='L2'), - } - - # create a list of spec color handles - # list_h_spec_cb = ['pc_'+col for col in list_cols] - - # get the vmin, vmax and add them to the master list - # f.equalize_speclim(pc) - # list_lim_spec = [] - - # no need for outputs - if runtype == 'debug': - f.f.colorbar(pc['L'], ax=f.ax['spec_L'], format='%.1e') - f.f.colorbar(pc2['M'], ax=f.ax['spec_M'], format='%.1e') - f.f.colorbar(pc['R'], ax=f.ax['spec_R'], format='%.1e') - # list_spec_handles = [ax for ax in f.ax.keys() if ax.startswith('spec')] - list_spec_handles = ['spec_M', 'spec_R'] - f.remove_tick_labels(list_spec_handles, ax_xy='y') - - elif runtype == 'pub': - f.f.colorbar(pc['L'], ax=f.ax['spec_L'], format='%.1e') - f.f.colorbar(pc2['M'], ax=f.ax['spec_M'], format='%.1e') - f.f.colorbar(pc['R'], ax=f.ax['spec_R'], format='%.1e') - list_spec_handles = ['spec_L', 'spec_R'] - f.remove_tick_labels(list_spec_handles, ax_xy='y') - - # grab the spike data - s = spikefn.spikes_from_file(f_param, f_spk) - - # dipoles - dpl.plot(f.ax['dpl_L'], xlim, layer='agg') - # f.set_linecolor('dpl_L', 'k') - f.ax['dpl_L'].hold(True) - dpl.plot(f.ax['dpl_L'], xlim, layer='L5') - dpl.plot(f.ax['dpl_L'], xlim, layer='L2') - - color_dpl_L5 = '#1e90ff' - - # these colors mirror below, should be vars - f.ax['dpl_L'].lines[0].set_color('k') - f.ax['dpl_L'].lines[1].set_color(color_dpl_L5) - f.ax['dpl_L'].lines[2].set_color('#b22222') - - # plot and color - dpl.plot(f.ax['dpl_M'], xlim, layer='L2') - f.set_linecolor('dpl_M', '#b22222') - - # plot and color - dpl.plot(f.ax['dpl_R'], xlim, layer='L5') - f.set_linecolor('dpl_R', color_dpl_L5) - - # equalize the ylim - # f.equalize_ylim(list_h_pgram) - f.equalize_speclim(pc) - f.equalize_ylim(['dpl_L', 'dpl_R']) - ylim_dpl_M = dpl.lim('L2', xlim) - # f.ysymmetry(f.ax['dpl_M']) - # f.ax['dpl_M'].set_ylim(ylim_dpl_M) - f.ax['dpl_M'].set_ylim((-0.01, 0.01)) - for ax in f.ax.keys(): - if ax.startswith('dpl'): - f.ax[ax].locator_params(axis='y', nbins=7) - - # spikes - spikes = { - 'L2': spikefn.filter_spike_dict(s, 'L2_'), - 'L5': spikefn.filter_spike_dict(s, 'L5_'), - } - - # plot the data - spikefn.spike_png(f.ax['spk_M'], spikes['L2']) - spikefn.spike_png(f.ax['spk_R'], spikes['L5']) - f.ax['spk_M'].set_xlim(xlim) - f.ax['spk_R'].set_xlim(xlim) - - # thin the yaxis - # function defined in FigBase() - # f.thin_yaxis(f.ax['current_M'], 5) - f.ax['current_M'].locator_params(axis='y', nbins=5) - f.ax['current_M'].set_ylim((-0.20, 0)) - f.ax['current_R'].locator_params(axis='y', nbins=5) - f.ax['current_R'].set_ylim((-0.8, 0)) - - # Welch number of labels - f.ax['pgram_M'].locator_params(axis='y', nbins=5) - f.ax['pgram_R'].locator_params(axis='y', nbins=5) - f.ax['pgram_L'].locator_params(axis='y', nbins=5) - - # set the colors - f.set_linecolor('pgram_L', 'k') - f.set_linecolor('pgram_R', 'k') - f.set_linecolor('pgram_M', 'k') - - # save the fig in ddata0 (arbitrary) - f_prefix = '%s_laminar' % ddata.sim_prefix - dfig = os.path.join(ddata.dsim, expmt) - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -# compares PING regimes for two different trial runs -def compare_ping(): - dproj = fio.return_data_dir() - runtype = 'pub2' - # runtype = 'debug' - - # data directories (made up for now) - # the resultant figure is saved in d0 - d0 = os.path.join(dproj, 'pub', '2013-06-28_gamma_ping_L5-000') - d1 = os.path.join(dproj, 'pub', '2013-07-31_gamma_weak_L5-000') - # d1 = os.path.join(dproj, 'pub', 
'2013-06-28_gamma_weak_L5-000') - - # hard code the data for now - ddata0 = fio.SimulationPaths() - ddata1 = fio.SimulationPaths() - - # use read_sim() to read the simulations - ddata0.read_sim(dproj, d0) - ddata1.read_sim(dproj, d1) - - # for now grab the first experiment in each - expmt0 = ddata0.expmt_groups[0] - expmt1 = ddata1.expmt_groups[0] - - # for now hard code the simulation run - run0 = 0 - run1 = 0 - - # prints the fig in ddata0 - f = acg.FigL5PingExample(runtype) - - # first panel data - f_spec0 = ddata0.file_match(expmt0, 'rawspec')[run0] - f_dpl0 = ddata0.file_match(expmt0, 'rawdpl')[run0] - f_spk0 = ddata0.file_match(expmt0, 'rawspk')[run0] - f_param0 = ddata0.file_match(expmt0, 'param')[run0] - f_current0 = ddata0.file_match(expmt0, 'rawcurrent')[run0] - - # figure out the tstop and xlim - tstop0 = paramrw.find_param(f_param0, 'tstop') - dt = paramrw.find_param(f_param0, 'dt') - xlim0 = (50., tstop0) - - # grab the dipole data - dpl0 = dipolefn.Dipole(f_dpl0) - dpl0.baseline_renormalize(f_param0) - dpl0.convert_fAm_to_nAm() - - # calculate the Welch periodogram - f_max = 150. - pgram0 = specfn.Welch(dpl0.t, dpl0.dpl['L5'], dt) - pgram0.plot_to_ax(f.ax['pgram_L'], f_max) - - # grab the spike data - s0 = spikefn.spikes_from_file(f_param0, f_spk0) - s0_L5 = spikefn.filter_spike_dict(s0, 'L5_') - - # plot the spike histogram data - icell0_spikes = np.concatenate(s0_L5['L5_basket'].spike_list) - ecell0_spikes = np.concatenate(s0_L5['L5_pyramidal'].spike_list) - - # 1 ms bins - n_bins = int(tstop0) - - f.ax['hist_L'].hist(icell0_spikes, n_bins, facecolor='r', histtype='stepfilled', alpha=0.75, edgecolor='none') - f.ax_twinx['hist_L'].hist(ecell0_spikes, n_bins, facecolor='k') - - # based on number of cells - f.ax['hist_L'].set_ylim((0, 20)) - f.ax_twinx['hist_L'].set_ylim((0, 100)) - - f.ax_twinx['hist_L'].set_xlim(xlim0) - f.ax['hist_L'].set_xlim(xlim0) - - # hack - labels = f.ax['hist_L'].yaxis.get_ticklocs() - labels_text = [str(label) for label in labels[:-1]] - for i in range(len(labels_text)): - labels_text[i] = '' - - labels_text.append('20') - f.ax['hist_L'].set_yticklabels(labels_text) - - labels_twinx = f.ax_twinx['hist_L'].yaxis.get_ticklocs() - labels_text = [str(label) for label in labels_twinx[:-1]] - for i in range(len(labels_text)): - labels_text[i] = '' - - labels_text.append('100') - f.ax_twinx['hist_L'].set_yticklabels(labels_text) - - # grab the current data - I_soma0 = currentfn.SynapticCurrent(f_current0) - I_soma0.convert_nA_to_uA() - - # plot the data - dpl0.plot(f.ax['dpl_L'], xlim0, layer='L5') - spikefn.spike_png(f.ax['raster_L'], s0_L5) - f.ax['raster_L'].set_xlim(xlim0) - - # second panel data - f_spec1 = ddata1.file_match(expmt1, 'rawspec')[run1] - f_dpl1 = ddata1.file_match(expmt1, 'rawdpl')[run1] - f_spk1 = ddata1.file_match(expmt1, 'rawspk')[run1] - f_param1 = ddata1.file_match(expmt1, 'param')[run1] - f_current1 = ddata1.file_match(expmt1, 'rawcurrent')[run1] - - # figure out the tstop and xlim - tstop1 = paramrw.find_param(f_param1, 'tstop') - xlim1 = (50., tstop1) - - # grab the dipole data - dpl1 = dipolefn.Dipole(f_dpl1) - dpl1.baseline_renormalize(f_param1) - dpl1.convert_fAm_to_nAm() - - # calculate the Welch periodogram - pgram1 = specfn.Welch(dpl1.t, dpl1.dpl['L5'], dt) - pgram1.plot_to_ax(f.ax['pgram_R'], f_max) - - # grab the spike data - s1 = spikefn.spikes_from_file(f_param1, f_spk1) - s1_L5 = spikefn.filter_spike_dict(s1, 'L5_') - # s1_L2 = spikefn.filter_spike_dict(s1, 'L2_') - - # plot the spike histogram data - icell1_spikes = 
np.concatenate(s1_L5['L5_basket'].spike_list) - ecell1_spikes = np.concatenate(s1_L5['L5_pyramidal'].spike_list) - - # 1 ms bins - n_bins = int(tstop1) - - f.ax['hist_R'].hist(icell1_spikes, n_bins, facecolor='r', histtype='stepfilled', alpha=0.75, edgecolor='none') - f.ax_twinx['hist_R'].hist(ecell1_spikes, n_bins, facecolor='k') - - # based on number of cells - f.ax['hist_R'].set_ylim((0, 12)) - f.ax_twinx['hist_R'].set_ylim((0, 12)) - - f.ax_twinx['hist_R'].set_xlim(xlim0) - f.ax['hist_R'].set_xlim(xlim0) - - # hack - labels = f.ax['hist_R'].yaxis.get_ticklocs() - labels_text = [str(label) for label in labels[:-1]] - for i in range(len(labels_text)): - labels_text[i] = '' - - labels_text.append('12') - f.ax['hist_R'].set_yticklabels(labels_text) - - labels_twinx = f.ax_twinx['hist_R'].yaxis.get_ticklocs() - labels_text = [str(label) for label in labels_twinx[:-1]] - for i in range(len(labels_text)): - labels_text[i] = '' - - labels_text.append('12') - f.ax_twinx['hist_R'].set_yticklabels(labels_text) - - # grab the current data - I_soma1 = currentfn.SynapticCurrent(f_current1) - I_soma1.convert_nA_to_uA() - - # plot the data - dpl1.plot(f.ax['dpl_R'], xlim1, layer='L5') - f.ysymmetry(f.ax['dpl_R']) - spikefn.spike_png(f.ax['raster_R'], s1_L5) - f.ax['raster_R'].set_xlim(xlim1) - - # plot the spec data - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], f_spec0, xlim0, layer='L5'), - 'R': specfn.pspec_ax(f.ax['spec_R'], f_spec1, xlim1, layer='L5'), - } - - # f.equalize_speclim(pc) - - # grab the dipole figure handles - list_h_dpl = [h for h in f.ax.keys() if h.startswith('dpl')] - for ax_h in list_h_dpl: - f.ax[ax_h].locator_params(axis='y', nbins=5) - # f.equalize_ylim(list_h_dpl) - - # and the pgrams - # list_h_pgram = [h for h in f.ax.keys() if h.startswith('pgram')] - # test = f.equalize_ylim(list_h_pgram) - - # plot current and do lims - I_soma0.plot_to_axis(f.ax['current_L'], 'L5') - I_soma1.plot_to_axis(f.ax['current_R'], 'L5') - list_h_current = [ax_h for ax_h in f.ax.keys() if ax_h.startswith('current')] - f.equalize_ylim(list_h_current) - - # this is a hack - # now in uA instead of nA - for ax_handle in f.ax.keys(): - if ax_handle.startswith('current_'): - f.ax[ax_handle].set_ylim((-2, 0.)) - - # testing something - # f.ax['pgram_L'].set_yscale('log') - # f.ax['pgram_R'].set_yscale('log') - # f.ax['pgram_L'].set_ylim((1e-12, 1e-3)) - # f.ax['pgram_R'].set_ylim((1e-12, 1e-3)) - - # save the fig in ddata0 (arbitrary) - f_prefix = 'gamma_L5ping_L5weak' - dfig = os.path.join(ddata0.dsim, expmt0) - - # create the colorbars - cb = dict.fromkeys(pc) - - if runtype in ('debug', 'pub2'): - for key in pc.keys(): - key_ax = 'spec_' + key - cb[key] = f.f.colorbar(pc[key], ax=f.ax[key_ax], format='%.1e') - - elif runtype == 'pub': - cb['R'] = f.f.colorbar(pc['R'], ax=f.ax['spec_R'], format='%.1e') - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -def sub_dist_examples(): - dproj = fio.return_data_dir() - # runtype = 'pub2' - runtype = 'debug' - - # data directories (made up for now) - # the resultant figure is saved in d0 - d0 = os.path.join(dproj, 'pub', '2013-07-01_gamma_sub_50Hz-002') - d1 = os.path.join(dproj, 'pub', '2013-07-18_gamma_sub_100Hz-000') - - # hard code the data for now - ddata0 = fio.SimulationPaths() - ddata1 = fio.SimulationPaths() - - # use read_sim() to read the simulations - ddata0.read_sim(dproj, d0) - ddata1.read_sim(dproj, d1) - - # for now grab the first experiment in each - expmt0 = ddata0.expmt_groups[0] - expmt1 = ddata1.expmt_groups[0] - 
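# Illustrative sketch, not from the original module: the dual-axis spike
# histogram pattern used in the hist_L/hist_R panels above, with random spike
# times standing in for the real rasters and the same "show only the top
# y tick label" trick. The output file name and all values are made up.
import numpy as np
import matplotlib
matplotlib.use('Agg')                          # headless backend for this sketch
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
icell_spikes = rng.uniform(50., 550., 400)     # inhibitory spike times (ms)
ecell_spikes = rng.uniform(50., 550., 2000)    # excitatory spike times (ms)

fig, ax = plt.subplots()
ax_twin = ax.twinx()
ax.hist(icell_spikes, bins=500, facecolor='r', histtype='stepfilled',
        alpha=0.75, edgecolor='none')
ax_twin.hist(ecell_spikes, bins=500, facecolor='k')

ax.set_ylim((0, 20))
ax.set_yticks([0, 5, 10, 15, 20])
ax.set_yticklabels(['', '', '', '', '20'])     # only the top label is visible
ax_twin.set_ylim((0, 100))
fig.savefig('hist_sketch.png')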
- # for now hard code the simulation run - run0 = 0 - run1 = 0 - - # number of bins for the spike histograms - n_bins = 500 - - # prints the fig in ddata0 - f = acg.FigSubDistExample(runtype) - - # first panel data - f_spec0 = ddata0.file_match(expmt0, 'rawspec')[run0] - f_dpl0 = ddata0.file_match(expmt0, 'rawdpl')[run0] - f_spk0 = ddata0.file_match(expmt0, 'rawspk')[run0] - f_param0 = ddata0.file_match(expmt0, 'param')[run0] - # f_current0 = ddata0.file_match(expmt0, 'rawcurrent')[run0] - - # figure out the tstop and xlim - tstop0 = paramrw.find_param(f_param0, 'tstop') - dt = paramrw.find_param(f_param0, 'dt') - xlim0 = (50., tstop0) - - # grab the dipole data - dpl0 = dipolefn.Dipole(f_dpl0) - dpl0.baseline_renormalize(f_param0) - dpl0.convert_fAm_to_nAm() - - # grab the current data - # I_soma0 = currentfn.SynapticCurrent(f_current0) - - # grab the spike data - _, p_dict0 = paramrw.read(f_param0) - s0 = spikefn.spikes_from_file(f_param0, f_spk0) - s0 = spikefn.alpha_feed_verify(s0, p_dict0) - sp_list = s0['alpha_feed_prox'].spike_list[0] - sd_list = s0['alpha_feed_dist'].spike_list[0] - spikefn.pinput_hist(f.ax['hist_L'], f.ax_twinx['hist_L'], sp_list, sd_list, n_bins, xlim0) - - # plot the data - dpl0.plot(f.ax['dpl_L'], xlim0, layer='L5') - - # second panel data - f_spec1 = ddata1.file_match(expmt1, 'rawspec')[run1] - f_dpl1 = ddata1.file_match(expmt1, 'rawdpl')[run1] - f_spk1 = ddata1.file_match(expmt1, 'rawspk')[run1] - f_param1 = ddata1.file_match(expmt1, 'param')[run1] - # f_current1 = ddata1.file_match(expmt1, 'rawcurrent')[run1] - - # figure out the tstop and xlim - tstop1 = paramrw.find_param(f_param1, 'tstop') - xlim1 = (50., tstop1) - - # grab the dipole data - dpl1 = dipolefn.Dipole(f_dpl1) - dpl1.baseline_renormalize(f_param1) - dpl1.convert_fAm_to_nAm() - - # # calculate the Welch periodogram - # pgram1 = specfn.Welch(dpl1.t, dpl1.dpl['L5'], dt) - # pgram1.plot_to_ax(f.ax['pgram_R'], f_max) - - # grab the spike data - _, p_dict1 = paramrw.read(f_param1) - s1 = spikefn.spikes_from_file(f_param1, f_spk1) - s1 = spikefn.alpha_feed_verify(s1, p_dict1) - sp_list = s1['alpha_feed_prox'].spike_list[0] - sd_list = s1['alpha_feed_dist'].spike_list[0] - spikefn.pinput_hist(f.ax['hist_R'], f.ax_twinx['hist_R'], sp_list, sd_list, n_bins, xlim1) - - # grab the current data - # I_soma1 = currentfn.SynapticCurrent(f_current1) - - # plot the data - dpl1.plot(f.ax['dpl_R'], xlim1, layer='L5') - # spikefn.spike_png(f.ax['raster_R'], s1_L5) - # f.ax['raster_R'].set_xlim(xlim1) - - # plot the spec data - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], f_spec0, xlim0, layer='L5'), - 'R': specfn.pspec_ax(f.ax['spec_R'], f_spec1, xlim1, layer='L5'), - } - - # f.equalize_speclim(pc) - - # # grab the dipole figure handles - # # list_h_dpl = [h for h in f.ax.keys() if h.startswith('dpl')] - # # f.equalize_ylim(list_h_dpl) - - # # and the pgrams - # # list_h_pgram = [h for h in f.ax.keys() if h.startswith('pgram')] - # # test = f.equalize_ylim(list_h_pgram) - - # # plot current and do lims - # I_soma0.plot_to_axis(f.ax['current_L'], 'L5') - # I_soma1.plot_to_axis(f.ax['current_R'], 'L5') - # for ax_handle in f.ax.keys(): - # if ax_handle.startswith('current_'): - # f.ax[ax_handle].set_ylim((-2000, 0.)) - - # # testing something - # # f.ax['pgram_L'].set_yscale('log') - # # f.ax['pgram_R'].set_yscale('log') - # # f.ax['pgram_L'].set_ylim((1e-12, 1e-3)) - # # f.ax['pgram_R'].set_ylim((1e-12, 1e-3)) - - # save the fig in ddata0 (arbitrary) - f_prefix = 'gamma_sub_examples' - dfig = 
os.path.join(ddata0.dsim, expmt0) - - # create the colorbars - cb = dict.fromkeys(pc) - - if runtype == 'debug': - for key in pc.keys(): - key_ax = 'spec_' + key - cb[key] = f.f.colorbar(pc[key], ax=f.ax[key_ax]) - - elif runtype == 'pub': - cb['R'] = f.f.colorbar(pc['R'], ax=f.ax['spec_R']) - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -def sub_dist_example2(): - dproj = fio.return_data_dir() - runtype = 'pub2' - # runtype = 'debug' - - # data directories (made up for now) - # the resultant figure is saved in d0 - d0 = os.path.join(dproj, 'pub', '2013-08-07_gamma_sub_50Hz-000') - - # hard code the data for now - ddata0 = fio.SimulationPaths() - - # use read_sim() to read the simulations - ddata0.read_sim(dproj, d0) - - # for now grab the first experiment in each - expmt0 = ddata0.expmt_groups[0] - - # for now hard code the simulation run - run0 = 0 - run1 = 1 - - # number of bins for the spike histograms - n_bins = 500 - - # prints the fig in ddata0 - f = acg.FigSubDistExample(runtype) - - # first panel data - f_spec0 = ddata0.file_match(expmt0, 'rawspec')[run0] - f_dpl0 = ddata0.file_match(expmt0, 'rawdpl')[run0] - f_spk0 = ddata0.file_match(expmt0, 'rawspk')[run0] - f_param0 = ddata0.file_match(expmt0, 'param')[run0] - # f_current0 = ddata0.file_match(expmt0, 'rawcurrent')[run0] - - # figure out the tstop and xlim - tstop0 = paramrw.find_param(f_param0, 'tstop') - dt = paramrw.find_param(f_param0, 'dt') - xlim0 = (50., tstop0) - - # grab the dipole data - dpl0 = dipolefn.Dipole(f_dpl0) - dpl0.baseline_renormalize(f_param0) - dpl0.convert_fAm_to_nAm() - - # grab the current data - # I_soma0 = currentfn.SynapticCurrent(f_current0) - - # grab the spike data - _, p_dict0 = paramrw.read(f_param0) - s0 = spikefn.spikes_from_file(f_param0, f_spk0) - s0 = spikefn.alpha_feed_verify(s0, p_dict0) - sp_list = s0['alpha_feed_prox'].spike_list[0] - sd_list = s0['alpha_feed_dist'].spike_list[0] - spikefn.pinput_hist(f.ax['hist_L'], f.ax_twinx['hist_L'], sp_list, sd_list, n_bins, xlim0) - - # plot the data - dpl0.plot(f.ax['dpl_L'], xlim0, layer='L5') - - # second panel data - f_spec1 = ddata0.file_match(expmt0, 'rawspec')[run1] - f_dpl1 = ddata0.file_match(expmt0, 'rawdpl')[run1] - f_spk1 = ddata0.file_match(expmt0, 'rawspk')[run1] - f_param1 = ddata0.file_match(expmt0, 'param')[run1] - # f_current1 = ddata1.file_match(expmt1, 'rawcurrent')[run1] - - # figure out the tstop and xlim - tstop1 = paramrw.find_param(f_param1, 'tstop') - xlim1 = (50., tstop1) - - # grab the dipole data - dpl1 = dipolefn.Dipole(f_dpl1) - dpl1.baseline_renormalize(f_param1) - dpl1.convert_fAm_to_nAm() - - # # calculate the Welch periodogram - # pgram1 = specfn.Welch(dpl1.t, dpl1.dpl['L5'], dt) - # pgram1.plot_to_ax(f.ax['pgram_R'], f_max) - - # grab the spike data - _, p_dict1 = paramrw.read(f_param1) - s1 = spikefn.spikes_from_file(f_param1, f_spk1) - s1 = spikefn.alpha_feed_verify(s1, p_dict1) - sp_list = s1['alpha_feed_prox'].spike_list[0] - sd_list = s1['alpha_feed_dist'].spike_list[0] - spikefn.pinput_hist(f.ax['hist_R'], f.ax_twinx['hist_R'], sp_list, sd_list, n_bins, xlim1) - f.ax['hist_R'].set_ylim((0., 20.)) - f.ax_twinx['hist_R'].set_ylim((20., 0.)) - - # grab the current data - # I_soma1 = currentfn.SynapticCurrent(f_current1) - - # plot the data - dpl1.plot(f.ax['dpl_R'], xlim1, layer='L5') - # spikefn.spike_png(f.ax['raster_R'], s1_L5) - # f.ax['raster_R'].set_xlim(xlim1) - - # plot the spec data - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], f_spec0, xlim0, layer='L5'), - 'R': 
specfn.pspec_ax(f.ax['spec_R'], f_spec1, xlim1, layer='L5'), - } - - # change the xlim format - f.set_notation_scientific([ax for ax in f.ax.keys() if ax.startswith('dpl')], n=2) - - # f.equalize_speclim(pc) - - # # grab the dipole figure handles - list_h_dpl = [h for h in f.ax.keys() if h.startswith('dpl')] - f.equalize_ylim(list_h_dpl) - - # hack. - f.ax['dpl_R'].set_yticklabels('') - - # and the pgrams - # list_h_pgram = [h for h in f.ax.keys() if h.startswith('pgram')] - # test = f.equalize_ylim(list_h_pgram) - - # # plot current and do lims - # I_soma0.plot_to_axis(f.ax['current_L'], 'L5') - # I_soma1.plot_to_axis(f.ax['current_R'], 'L5') - # for ax_handle in f.ax.keys(): - # if ax_handle.startswith('current_'): - # f.ax[ax_handle].set_ylim((-2000, 0.)) - - # # testing something - # # f.ax['pgram_L'].set_yscale('log') - # # f.ax['pgram_R'].set_yscale('log') - # # f.ax['pgram_L'].set_ylim((1e-12, 1e-3)) - # # f.ax['pgram_R'].set_ylim((1e-12, 1e-3)) - - # save the fig in ddata0 (arbitrary) - f_prefix = 'gamma_sub_examples' - dfig = os.path.join(ddata0.dsim, expmt0) - - # create the colorbars - cb = dict.fromkeys(pc) - - if runtype in ('debug', 'pub2'): - for key in pc.keys(): - key_ax = 'spec_' + key - cb[key] = f.f.colorbar(pc[key], ax=f.ax[key_ax]) - - f.remove_twinx_labels() - - elif runtype == 'pub': - cb['R'] = f.f.colorbar(pc['R'], ax=f.ax['spec_R']) - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -# plots a histogram of e cell spikes relative to I cell spikes -def spikephase(): - dproj = fio.return_data_dir() - # runtype = 'pub2' - runtype = 'debug' - - # data directories (made up for now) - # the resultant figure is saved in d0 - d0 = os.path.join(dproj, 'pub', '2013-06-28_gamma_weak_L5-000') - - # hard code the data for now - ddata0 = fio.SimulationPaths() - - # use read_sim() to read the simulations - ddata0.read_sim(dproj, d0) - - # for now grab the first experiment in each - expmt0 = ddata0.expmt_groups[0] - - # for now hard code the simulation run - run0 = 0 - - # prints the fig in ddata0 - f = ac.FigStd() - - # create a twin axis - f.create_axis_twinx('ax0') - - # first panel data - f_spec0 = ddata0.file_match(expmt0, 'rawspec')[run0] - f_dpl0 = ddata0.file_match(expmt0, 'rawdpl')[run0] - f_spk0 = ddata0.file_match(expmt0, 'rawspk')[run0] - f_param0 = ddata0.file_match(expmt0, 'param')[run0] - # f_current0 = ddata0.file_match(expmt0, 'rawcurrent')[run0] - - # figure out the tstop and xlim - tstop0 = paramrw.find_param(f_param0, 'tstop') - dt = paramrw.find_param(f_param0, 'dt') - xlim0 = (0., tstop0) - - # grab the spike data - _, p_dict0 = paramrw.read(f_param0) - s0 = spikefn.spikes_from_file(f_param0, f_spk0) - icell_spikes = s0['L5_basket'].spike_list - ecell_spikes = s0['L5_pyramidal'].spike_list - - ispike_counts = [len(slist) for slist in icell_spikes] - espike_counts = [len(slist) for slist in ecell_spikes] - - # let's try a sort ... 
- icell_spikes_agg = np.concatenate(icell_spikes) - ecell_spikes_agg = np.concatenate(ecell_spikes) - - # lop off the first 50 ms - icell_spikes_agg = icell_spikes_agg[icell_spikes_agg >= 50] - icell_spikes_agg_sorted = np.sort(icell_spikes_agg) - - n_bins = int(tstop0 - 50) - - f.ax['ax0'].hist(icell_spikes_agg, n_bins, facecolor='r', histtype='stepfilled', alpha=0.75, edgecolor='none') - f.ax_twinx['ax0'].hist(ecell_spikes_agg, n_bins, facecolor='k') - # f.ax_twinx['ax0'].hist(ecell_spikes_agg, n_bins, facecolor='k', alpha=0.75) - - # sets these lims to the MAX number of possible events per bin (n_celltype limited) - f.ax['ax0'].set_ylim((0, 35)) - f.ax_twinx['ax0'].set_ylim((0, 100)) - - f.ax_twinx['ax0'].set_xlim((50, tstop0)) - f.ax['ax0'].set_xlim((50, tstop0)) - - f.savepng_new(d0, 'testing') - f.close() - - # save the fig in ddata0 (arbitrary) - f_prefix = 'gamma_spikephase' - dfig = os.path.join(ddata0.dsim, expmt0) - -def peaks(): - dproj = fio.return_data_dir() - # runtype = 'pub2' - runtype = 'debug' - - # data directories (made up for now) - # the resultant figure is saved in d0 - d0 = os.path.join(dproj, 'pub', '2013-07-01_gamma_sub_50Hz-002') - # d1 = os.path.join(dproj, 'pub', '2013-07-18_gamma_sub_100Hz-000') - - # hard code the data for now - ddata0 = fio.SimulationPaths() - - # use read_sim() to read the simulations - ddata0.read_sim(dproj, d0) - - # for now grab the first experiment in each - expmt0 = ddata0.expmt_groups[0] - - # for now hard code the simulation run - run0 = 0 - - # prints the fig in ddata0 - f = acg.FigPeaks(runtype) - - # first panel data - f_spec0 = ddata0.file_match(expmt0, 'rawspec')[run0] - f_dpl0 = ddata0.file_match(expmt0, 'rawdpl')[run0] - f_spk0 = ddata0.file_match(expmt0, 'rawspk')[run0] - f_param0 = ddata0.file_match(expmt0, 'param')[run0] - # f_current0 = ddata0.file_match(expmt0, 'rawcurrent')[run0] - - # figure out the tstop and xlim - tstop0 = paramrw.find_param(f_param0, 'tstop') - dt = paramrw.find_param(f_param0, 'dt') - xlim0 = (0., tstop0) - - # grab the dipole data - dpl0 = dipolefn.Dipole(f_dpl0) - dpl0.baseline_renormalize(f_param0) - dpl0.convert_fAm_to_nAm() - - # grab the current data - # I_soma0 = currentfn.SynapticCurrent(f_current0) - - # grab the spike data - _, p_dict0 = paramrw.read(f_param0) - s0 = spikefn.spikes_from_file(f_param0, f_spk0) - s0 = spikefn.alpha_feed_verify(s0, p_dict0) - sp_list = s0['alpha_feed_prox'].spike_list[0] - sd_list = s0['alpha_feed_dist'].spike_list[0] - # spikefn.pinput_hist(f.ax['hist_L'], f.ax_twinx['hist_L'], sp_list, sd_list, n_bins, xlim0) - - # plot the data - dpl0.plot(f.ax['dpl_L'], xlim0, layer='L5') - - # plot the spec data - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], f_spec0, xlim0, layer='L5'), - } - - # save the fig in ddata0 (arbitrary) - f_prefix = 'gamma_peaks' - dfig = os.path.join(ddata0.dsim, expmt0) - - # create the colorbars - cb = dict.fromkeys(pc) - - if runtype == 'debug': - for key in pc.keys(): - key_ax = 'spec_' + key - cb[key] = f.f.colorbar(pc[key], ax=f.ax[key_ax]) - - elif runtype == 'pub': - cb['R'] = f.f.colorbar(pc['L'], ax=f.ax['spec_L']) - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -# needs spec for multiple experiments, will plot 2 examples and aggregate -def pgamma_distal_phase(ddata, data_L=0, data_M=1, data_R=2): - layer_specific = 'agg' - - for expmt in ddata.expmt_groups: - f = acg.FigDistalPhase() - - # grab file lists - list_spec = ddata.file_match(expmt, 'rawspec') - list_dpl = ddata.file_match(expmt, 'rawdpl') - 
list_spk = ddata.file_match(expmt, 'rawspk') - list_param = ddata.file_match(expmt, 'param') - - # grab the tstop and make an xlim - T = paramrw.find_param(list_param[0], 'tstop') - xlim = (50., T) - - # grab the input frequency, try prox before dist - f_max = paramrw.find_param(list_param[0], 'f_input_prox') - - # only try dist if prox is 0, otherwise, use prox - if not f_max: - f_max = paramrw.find_param(list_param[0], 'f_input_dist') - - # dealing with the left panel - dpl_L = dipolefn.Dipole(list_dpl[data_L]) - dpl_L.baseline_renormalize(list_param[data_L]) - dpl_L.convert_fAm_to_nAm() - dpl_L.plot(f.ax['dpl_L'], xlim, layer='agg') - - # middle data panel - dpl_M = dipolefn.Dipole(list_dpl[data_M]) - dpl_M.baseline_renormalize(list_param[data_M]) - dpl_M.convert_fAm_to_nAm() - dpl_M.plot(f.ax['dpl_M'], xlim, layer='agg') - - # dealing with right panel - dpl_R = dipolefn.Dipole(list_dpl[data_R]) - dpl_R.baseline_renormalize(list_param[data_R]) - dpl_R.convert_fAm_to_nAm() - dpl_R.plot(f.ax['dpl_R'], xlim, layer='agg') - - # get the vmin, vmax and add them to the master list - pc = { - 'L': specfn.pspec_ax(f.ax['spec_L'], list_spec[data_L], xlim, layer=layer_specific), - 'M': specfn.pspec_ax(f.ax['spec_M'], list_spec[data_M], xlim, layer=layer_specific), - 'R': specfn.pspec_ax(f.ax['spec_R'], list_spec[data_R], xlim, layer=layer_specific), - } - - # use the equalize function - f.equalize_speclim(pc) - - # create colorbars - f.f.colorbar(pc['L'], ax=f.ax['spec_L']) - f.f.colorbar(pc['M'], ax=f.ax['spec_M']) - f.f.colorbar(pc['R'], ax=f.ax['spec_R']) - - # hist data - xlim_hist = (50., 100.) - - # get the data for the left panel - _, p_dict = paramrw.read(list_param[data_L]) - s_L = spikefn.spikes_from_file(list_param[data_L], list_spk[data_L]) - s_L = spikefn.alpha_feed_verify(s_L, p_dict) - # n_bins = spikefn.hist_bin_opt(s_L['alpha_feed_prox'].spike_list, 10) - n_bins = 500 - - # prox and dist spike lists - sp_list = spike_list_truncate(s_L['alpha_feed_prox'].spike_list[0]) - sd_list = spike_list_truncate(s_L['alpha_feed_dist'].spike_list[0]) - spikefn.pinput_hist(f.ax['hist_L'], f.ax_twinx['hist_L'], sp_list, sd_list, n_bins, xlim_hist) - - # same motif as previous lines, I'm tired. - _, p_dict = paramrw.read(list_param[data_M]) - s_M = spikefn.spikes_from_file(list_param[data_M], list_spk[data_M]) - s_M = spikefn.alpha_feed_verify(s_M, p_dict) - sp_list = spike_list_truncate(s_M['alpha_feed_prox'].spike_list[0]) - sd_list = spike_list_truncate(s_M['alpha_feed_dist'].spike_list[0]) - spikefn.pinput_hist(f.ax['hist_M'], f.ax_twinx['hist_M'], sp_list, sd_list, n_bins, xlim_hist) - - # same motif as previous lines, I'm tired. - _, p_dict = paramrw.read(list_param[data_R]) - s_R = spikefn.spikes_from_file(list_param[data_R], list_spk[data_R]) - s_R = spikefn.alpha_feed_verify(s_R, p_dict) - sp_list = spike_list_truncate(s_R['alpha_feed_prox'].spike_list[0]) - sd_list = spike_list_truncate(s_R['alpha_feed_dist'].spike_list[0]) - spikefn.pinput_hist(f.ax['hist_R'], f.ax_twinx['hist_R'], sp_list, sd_list, n_bins, xlim_hist) - - # now do the aggregate data - # theta is the normalized phase - list_spec_max = np.zeros(len(list_spec)) - list_theta = np.zeros(len(list_spec)) - list_delay = np.zeros(len(list_spec)) - - i = 0 - for fspec, fparam in zip(list_spec, list_param): - # f_max comes from the input f - # f_max = 50. - t_pd = 1000. 
/ f_max - - # read the data - data_spec = specfn.read(fspec) - - # use specpwr_stationary() to get an aggregate measure of power over the entire time - p_stat = specfn.specpwr_stationary(data_spec['time'], data_spec['freq'], data_spec['TFR']) - - # this is ONLY for aggregate and NOT for individual layers right now - # here, f_max is the hard coded one and NOT the calculated one from specpwr_stationary() - list_spec_max[i] = p_stat['p'][p_stat['f']==f_max] - - # get the relevant param's value - t0_prox = paramrw.find_param(fparam, 't0_input_prox') - t0_dist = paramrw.find_param(fparam, 't0_input_dist') - - # calculating these two together BUT don't need to. Cleanness beats efficiency here - list_delay[i] = t0_dist - t0_prox - list_theta[i] = list_delay[i] / t_pd - - i += 1 - - f.ax['aggregate'].plot(list_delay, list_spec_max, marker='o') - - # deal with names - f_prefix = 'gamma_%s_distal_phase' % expmt - dfig = os.path.join(ddata.dsim, expmt) - - f.savepng_new(dfig, f_prefix) - f.saveeps(dfig, f_prefix) - f.close() - -def spike_list_truncate(s_list): - return s_list[(s_list > 55.) & (s_list < 100.)] - -# needs spec for 3 experiments -# really a generic comparison of the top 3 sims in a given exp -# the list is naturally truncated by the length of ax_suffices -def pgamma_stdev(ddata): - for expmt in ddata.expmt_groups: - # runtype = 'debug' - # runtype = 'pub2' - runtype = 'pub' - - f = acg.Fig3PanelPlusAgg(runtype) - - # data types - list_spec = ddata.file_match(expmt, 'rawspec') - list_dpl = ddata.file_match(expmt, 'rawdpl') - list_param = ddata.file_match(expmt, 'param') - list_spk = ddata.file_match(expmt, 'rawspk') - - # time info - T = paramrw.find_param(list_param[0], 'tstop') - xlim = (50., T) - - # assume only the first 3 files are the ones we care about - ax_suffices = [ - '_L', - '_M', - '_R', - '_FR', - ] - - # dpl handles list - list_handles_dpl = [] - - # spec handles - pc = {} - - # lists in zip are naturally truncated by the shortest list - for ax_end, fdpl, fspec, fparam, fspk in zip(ax_suffices, list_dpl, list_spec, list_param, list_spk): - # create axis handle names - ax_dpl = 'dpl%s' % ax_end - ax_spec = 'spec%s' % ax_end - ax_hist = 'hist%s' % ax_end - - # add to my terrible list - list_handles_dpl.append(ax_dpl) - - # grab the dipole and convert - dpl = dipolefn.Dipole(fdpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - # plot relevant data - dpl.plot(f.ax[ax_dpl], xlim, layer='L5') - pc[ax_spec] = specfn.pspec_ax(f.ax[ax_spec], fspec, xlim, layer='L5') - - # only set the colorbar for all axes in debug mode - # otherwise set only for the rightmost spec axis - if runtype in ('debug', 'pub2'): - f.f.colorbar(pc[ax_spec], ax=f.ax[ax_spec]) - - elif runtype == 'pub': - if ax_end == '_FR': - f.f.colorbar(pc[ax_spec], ax=f.ax[ax_spec]) - - # histogram stuff - _, p_dict = paramrw.read(fparam) - s = spikefn.spikes_from_file(fparam, fspk) - s = spikefn.alpha_feed_verify(s, p_dict) - - # result of the optimization function, for the right 2 panels. 
100 - # was the value returned for the L panel for f plot - # result for stdev plot was 290, 80, 110 - # n_bins = spikefn.hist_bin_opt(s['alpha_feed_prox'][0].spike_list, 10) - # print n_bins - n_bins = 110 - - # plot the hist - spikefn.pinput_hist_onesided(f.ax[ax_hist], s['alpha_feed_prox'][0].spike_list, n_bins) - f.ax[ax_hist].set_xlim(xlim) - - # equalize ylim on hists - list_ax_hist = [ax for ax in f.ax.keys() if ax.startswith('hist')] - f.equalize_ylim(list_ax_hist) - - # normalize the spec - f.equalize_speclim(pc) - f.remove_twinx_labels() - - # normalize the dpl with that hack - # centers c and lim l - # c = [1e-3, 1.2e-3, 1.8e-3] - # l = 2e-3 - # for h in list_handles_dpl: - # f.ax[h].set_ylim((-3e-3, 3e-3)) - f.ysymmetry(f.ax['dpl_L']) - f.set_notation_scientific(['dpl_L']) - list_ax_dpl = [ax for ax in f.ax.keys() if ax.startswith('dpl')] - f.equalize_ylim(list_ax_dpl) - - # some fig naming stuff - fprefix_short = 'gamma_%s_compare3' % expmt - dfig = os.path.join(ddata.dsim, expmt) - - # use methods to save figs - f.savepng_new(dfig, fprefix_short) - f.saveeps(dfig, fprefix_short) - f.close() - -def pgamma_stdev_new(ddata, p): - for expmt in ddata.expmt_groups: - # runtype = 'debug' - runtype = 'pub2' - # runtype = 'pub' - - f = acg.Fig3PanelPlusAgg(runtype) - - # data types - list_spec = ddata.file_match(expmt, 'rawspec') - list_dpl = ddata.file_match(expmt, 'rawdpl') - list_param = ddata.file_match(expmt, 'param') - list_spk = ddata.file_match(expmt, 'rawspk') - - # time info - T = paramrw.find_param(list_param[0], 'tstop') - xlim = (50., T) - - # assume only the first 3 files are the ones we care about - ax_suffices = [ - '_L', - '_M', - '_R', - '_FR', - ] - - # dpl handles list - list_handles_dpl = [] - - # spec handles - pc = {} - - # pgram_list - list_pgram = [] - - # lists in zip are naturally truncated by the shortest list - for ax_end, fdpl, fspec, fparam, fspk in zip(ax_suffices, list_dpl, list_spec, list_param, list_spk): - # create axis handle names - ax_dpl = 'dpl%s' % ax_end - ax_spec = 'spec%s' % ax_end - ax_hist = 'hist%s' % ax_end - - # add to my terrible list - list_handles_dpl.append(ax_dpl) - - # grab the dipole and convert - dpl = dipolefn.Dipole(fdpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - # find the dpl lim - ylim_dpl = dpl.lim('L5', xlim) - - # plot relevant data - dpl.plot(f.ax[ax_dpl], xlim, layer='L5') - f.ax[ax_dpl].set_ylim(ylim_dpl) - pc[ax_spec] = specfn.pspec_ax(f.ax[ax_spec], fspec, xlim, layer='L5') - - # only set the colorbar for all axes in debug mode - # otherwise set only for the rightmost spec axis - if runtype in ('debug', 'pub2'): - f.f.colorbar(pc[ax_spec], ax=f.ax[ax_spec]) - - elif runtype == 'pub': - if ax_end == '_FR': - f.f.colorbar(pc[ax_spec], ax=f.ax[ax_spec]) - - # histogram stuff - _, p_dict = paramrw.read(fparam) - s = spikefn.spikes_from_file(fparam, fspk) - s = spikefn.alpha_feed_verify(s, p_dict) - - # result of the optimization function, for the right 2 panels. 
100 - # was the value returned for the L panel for f plot - # result for stdev plot was 290, 80, 110 - # n_bins = spikefn.hist_bin_opt(s['alpha_feed_prox'][0].spike_list, 10) - # print n_bins - n_bins = 110 - - # plot the hist - spikefn.pinput_hist_onesided(f.ax[ax_hist], s['alpha_feed_prox'][0].spike_list, n_bins) - # print s['alpha_feed_prox'] - f.ax[ax_hist].set_xlim(xlim) - - # run the Welch and plot it - # get the dt, for the welch - dt = paramrw.find_param(fparam, 'dt') - list_pgram.append(specfn.Welch(dpl.t, dpl.dpl['L5'], dt)) - list_pgram[-1].scale(1e7) - # f.ax['pgram'].plot(list_pgram[-1].f, list_pgram[-1].P) - list_pgram[-1].plot_to_ax(f.ax['pgram'], p['f_max_welch']) - print list_pgram[-1].units - - # equalize ylim on hists - list_ax_hist = [ax for ax in f.ax.keys() if ax.startswith('hist')] - f.equalize_ylim(list_ax_hist) - for ax_h in list_ax_hist: - f.ax[ax_h].set_ylim((0, 10)) - f.ax[ax_h].locator_params(axis='y', nbins=3) - - # f.ax['pgram'].yaxis.tick_right() - - # normalize the spec - # f.equalize_speclim(pc) - f.remove_twinx_labels() - - # normalize the dpl with that hack - # centers c and lim l - # c = [1e-3, 1.2e-3, 1.8e-3] - # l = 2e-3 - # for h in list_handles_dpl: - # f.ax[h].set_ylim((-3e-3, 3e-3)) - # f.ysymmetry(f.ax['dpl_L']) - # f.set_notation_scientific(['dpl_L']) - list_ax_dpl = [ax for ax in f.ax.keys() if ax.startswith('dpl')] - f.equalize_ylim(list_ax_dpl) - for ax_h in list_ax_dpl: - f.ax[ax_h].locator_params(axis='y', nbins=5) - - # some fig naming stuff - fprefix_short = 'gamma_%s_compare3' % expmt - dfig = os.path.join(ddata.dsim, expmt) - - # use methods to save figs - f.savepng_new(dfig, fprefix_short) - f.saveeps(dfig, fprefix_short) - f.close() - -def prox_dist_new(ddata, p): - for expmt in ddata.expmt_groups: - # runtype = 'debug' - runtype = 'pub2' - # runtype = 'pub' - - f = acg.Fig3PanelPlusAgg(runtype) - - # data types - list_spec = ddata.file_match(expmt, 'rawspec') - list_dpl = ddata.file_match(expmt, 'rawdpl') - list_param = ddata.file_match(expmt, 'param') - list_spk = ddata.file_match(expmt, 'rawspk') - - # time info - T = paramrw.find_param(list_param[0], 'tstop') - xlim = (50., T) - - # assume only the first 3 files are the ones we care about - ax_suffices = [ - '_L', - '_M', - '_R', - '_FR', - ] - - # dpl handles list - list_handles_dpl = [] - - # spec handles - pc = {} - - # pgram_list - list_pgram = [] - - # lists in zip are naturally truncated by the shortest list - for ax_end, fdpl, fspec, fparam, fspk in zip(ax_suffices, list_dpl, list_spec, list_param, list_spk): - # create axis handle names - ax_dpl = 'dpl%s' % ax_end - ax_spec = 'spec%s' % ax_end - ax_hist = 'hist%s' % ax_end - - # add to my terrible list - list_handles_dpl.append(ax_dpl) - - # grab the dipole and convert - dpl = dipolefn.Dipole(fdpl) - dpl.baseline_renormalize(fparam) - dpl.convert_fAm_to_nAm() - - # find the dpl lim - ylim_dpl = dpl.lim('L5', xlim) - - # plot relevant data - dpl.plot(f.ax[ax_dpl], xlim, layer='L5') - f.ax[ax_dpl].set_ylim(ylim_dpl) - pc[ax_spec] = specfn.pspec_ax(f.ax[ax_spec], fspec, xlim, layer='L5') - - # only set the colorbar for all axes in debug mode - # otherwise set only for the rightmost spec axis - if runtype in ('debug', 'pub2'): - f.f.colorbar(pc[ax_spec], ax=f.ax[ax_spec]) - - elif runtype == 'pub': - if ax_end == '_FR': - f.f.colorbar(pc[ax_spec], ax=f.ax[ax_spec]) - - # histogram stuff - n_bins = 110 - _, p_dict = paramrw.read(fparam) - s = spikefn.spikes_from_file(fparam, fspk) - s = spikefn.alpha_feed_verify(s, p_dict) - 
sp_list = s['alpha_feed_prox'].spike_list[0] - sd_list = s['alpha_feed_dist'].spike_list[0] - spikefn.pinput_hist(f.ax[ax_hist], f.ax_twinx[ax_hist], sp_list, sd_list, n_bins, xlim) - - # result of the optimization function, for the right 2 panels. 100 - # was the value returned for the L panel for f plot - # result for stdev plot was 290, 80, 110 - # n_bins = spikefn.hist_bin_opt(s['alpha_feed_prox'][0].spike_list, 10) - # print n_bins - - # plot the hist - # spikefn.pinput_hist_onesided(f.ax[ax_hist], s['alpha_feed_prox'][0].spike_list, n_bins) - # print s['alpha_feed_prox'] - f.ax[ax_hist].set_xlim(xlim) - - # run the Welch and plot it - # get the dt, for the welch - dt = paramrw.find_param(fparam, 'dt') - list_pgram.append(specfn.Welch(dpl.t, dpl.dpl['L5'], dt)) - list_pgram[-1].scale(1e7) - # f.ax['pgram'].plot(list_pgram[-1].f, list_pgram[-1].P) - list_pgram[-1].plot_to_ax(f.ax['pgram'], p['f_max_welch']) - print list_pgram[-1].units - - # equalize ylim on hists - list_ax_hist = [ax for ax in f.ax.keys() if ax.startswith('hist')] - f.equalize_ylim(list_ax_hist) - - for ax_h in list_ax_hist: - f.ax[ax_h].set_ylim((0, 20)) - f.ax_twinx[ax_h].set_ylim((20, 0)) - - f.ax[ax_h].locator_params(axis='y', nbins=3) - f.ax_twinx[ax_h].locator_params(axis='y', nbins=3) - - f.ax_twinx['hist_M'].set_yticklabels('') - f.ax_twinx['hist_R'].set_yticklabels('') - # f.ax['pgram'].yaxis.tick_right() - - # normalize the spec - # f.equalize_speclim(pc) - # f.remove_twinx_labels() - - # normalize the dpl with that hack - # centers c and lim l - # c = [1e-3, 1.2e-3, 1.8e-3] - # l = 2e-3 - # for h in list_handles_dpl: - # f.ax[h].set_ylim((-3e-3, 3e-3)) - # f.ysymmetry(f.ax['dpl_L']) - # f.set_notation_scientific(['dpl_L']) - list_ax_dpl = [ax for ax in f.ax.keys() if ax.startswith('dpl')] - f.equalize_ylim(list_ax_dpl) - for ax_h in list_ax_dpl: - f.ax[ax_h].locator_params(axis='y', nbins=5) - - # some fig naming stuff - fprefix_short = 'gamma_%s_compare3' % expmt - dfig = os.path.join(ddata.dsim, expmt) - - # use methods to save figs - f.savepng_new(dfig, fprefix_short) - f.saveeps(dfig, fprefix_short) - f.close() - -# manual setting of ylims -def ylim_hack(f, list_handles, ylim_centers, ylim_limit): - # ylim_centers = [1.5e-5, 2e-5, 2.5e-5] - # ylim_limit = 1.5e-5 - - # gross - for h, c in zip(list_handles, ylim_centers): - f.ax[h].grid(True, which='minor') - ylim = (c - ylim_limit, c + ylim_limit) - f.ax[h].set_ylim(ylim) - f.f.canvas.draw() - # labels = [tick.get_text() for tick in f.ax[list_handles[1]].get_yticklabels()] - labels = f.ax[h].yaxis.get_ticklocs() - labels_text = [str(label) for label in labels[:-1]] - labels_text[0] = '' - f.ax[h].set_yticklabels(labels_text) - # print labels_text - -if __name__ == '__main__': - hf_epochs() diff --git a/ppsth.py b/ppsth.py deleted file mode 100644 index 735309f7b..000000000 --- a/ppsth.py +++ /dev/null @@ -1,230 +0,0 @@ -# ppsth.py - Plots aggregate psth of all trials in an "experiment" -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed it.izip()) -# last rev: (SL: changed class names from axes_create.py) - -import numpy as np -import itertools as it -import matplotlib.pyplot as plt -import os -import paramrw, spikefn -import fileio as fio -import axes_create as ac -# from axes_create import ac.FigPSTH, ac.FigPSTHGrid - -def ppsth_grid(simpaths): - # get filename lists in dictionaries of experiments - dict_exp_param = simpaths.exp_files_of_type('param') - dict_exp_spk = simpaths.exp_files_of_type('rawspk') - - # recreate the ExpParams object used in the 
simulation - p_exp = paramrw.ExpParams(simpaths.fparam[0]) - - # need number of lambda vals (cols) and number of sigma vals (rows) - try: - N_rows = len(p_exp.p_all['L2Pyr_Gauss_A_weight']) - except TypeError: - N_rows = 1 - - try: - N_cols = len(p_exp.p_all['L2Basket_Pois_lamtha']) - except TypeError: - N_cols = 1 - - tstop = p_exp.p_all['tstop'] - - print N_rows, N_cols, tstop - - # ugly but slightly less ugly than the index arithmetic i had planned. muahaha - f = ac.FigGrid(N_rows, N_cols, tstop) - - # create coordinates for axes - # this is backward-looking for a reason! - axes_coords = [(j, i) for i, j in it.product(np.arange(N_cols), np.arange(N_rows))] - - if len(simpaths.expnames) != len(axes_coords): - print "um ... see ppsth.py" - - # assumes a match between expnames and the keys of the previous dicts - for expname, axis_coord in zip(simpaths.expnames, axes_coords): - # get the tstop - exp_param_list = dict_exp_param[expname] - exp_spk_list = dict_exp_spk[expname] - gid_dict, p = paramrw.read(exp_param_list[0]) - tstop = p['tstop'] - lamtha = p['L2Basket_Pois_lamtha'] - sigma = p['L2Pyr_Gauss_A_weight'] - - # these are total spike dicts for the experiments - s_L2Pyr_list = [] - # s_L5Pyr_list = [] - - # iterate through params and spikes for a given experiment - for fparam, fspk in zip(dict_exp_param[expname], dict_exp_spk[expname]): - # get gid dict - gid_dict, p = paramrw.read(fparam) - - # get spike dict - s_dict = spikefn.spikes_from_file(gid_dict, fspk) - - # add a new entry to list for each different file assoc with an experiment - s_L2Pyr_list.append(np.array(list(it.chain.from_iterable(s_dict['L2_pyramidal'].spike_list)))) - # s_L5Pyr_list.append(np.array(list(it.chain.from_iterable(s_dict['L5_pyramidal'].spike_list)))) - - # now aggregate over all spikes - s_L2Pyr = np.array(list(it.chain.from_iterable(s_L2Pyr_list))) - # s_L5Pyr = np.array(list(it.chain.from_iterable(s_L5Pyr_list))) - - # optimize bins, currently unused for comparison reasons! 
- N_trials = len(fparam) - bin_L2 = 250 - # bin_L5 = 120 - # bin_L2 = spikefn.hist_bin_opt(s_L2Pyr, N_trials) - # bin_L5 = spikefn.hist_bin_opt(s_L5Pyr, N_trials) - - r = axis_coord[0] - c = axis_coord[1] - # create standard fig and axes - f.ax[r][c].hist(s_L2Pyr, bin_L2, facecolor='g', alpha=0.75) - - if r == 0: - f.ax[r][c].set_title(r'$\lambda_i$ = %d' % lamtha) - - if c == 0: - f.ax[r][c].set_ylabel(r'$A_{gauss}$ = %.3e' % sigma) - # f.ax[r][c].set_ylabel(r'$\sigma_{gauss}$ = %d' % sigma) - - # normalize these axes - y_L2 = f.ax[r][c].get_ylim() - # y_L2 = f.ax['L2_psth'].get_ylim() - - print expname, lamtha, sigma, r, c, y_L2[1] - - f.ax[r][c].set_ylim((0, 250.)) - # f.ax['L2_psth'].set_ylim((0, 450.)) - # f.ax['L5_psth'].set_ylim((0, 450.)) - - # spikefn.spike_png(f.ax['L2'], s_dict_L2) - # spikefn.spike_png(f.ax['L5'], s_dict_L5) - # spikefn.spike_png(f.ax['L2_extpois'], s_dict_L2_extpois) - # spikefn.spike_png(f.ax['L2_extgauss'], s_dict_L2_extgauss) - # spikefn.spike_png(f.ax['L5_extpois'], s_dict_L5_extpois) - # spikefn.spike_png(f.ax['L5_extgauss'], s_dict_L5_extgauss) - - # testfig.ax0.plot(t_vec, dp_total) - fig_name = os.path.join(simpaths.dsim, 'aggregate.eps') - - plt.savefig(fig_name) - f.close() - - # run the compression - fio.epscompress(simpaths.dsim, '.eps', 1) - -# will take a directory, find the files bin all the psth's, plot a representative spike raster -def ppsth(simpaths): - # get filename lists in dictionaries of experiments - dict_exp_param = simpaths.exp_files_of_type('param') - dict_exp_spk = simpaths.exp_files_of_type('rawspk') - - # assumes a match between expnames and the keys of the previous dicts - for expname in simpaths.expnames: - # get the tstop - exp_param_list = dict_exp_param[expname] - exp_spk_list = dict_exp_spk[expname] - gid_dict, p = paramrw.read(exp_param_list[0]) - # gid_dict, p = paramrw.read(dict_exp_param[expname][0]) - tstop = p['tstop'] - - # get representative spikes - s_dict = spikefn.spikes_from_file(gid_dict, exp_spk_list[0]) - - s_dict_L2 = {} - s_dict_L5 = {} - s_dict_L2_extgauss = {} - s_dict_L2_extpois = {} - s_dict_L5_extgauss = {} - s_dict_L5_extpois = {} - - # clean out s_dict destructively - # borrowed from praster - for key in s_dict.keys(): - # do this first to remove all extgauss feeds - if 'extgauss' in key: - if 'L2_' in key: - s_dict_L2_extgauss[key] = s_dict.pop(key) - - elif 'L5_' in key: - s_dict_L5_extgauss[key] = s_dict.pop(key) - - elif 'extpois' in key: - # s_dict_extpois[key] = s_dict.pop(key) - if 'L2_' in key: - s_dict_L2_extpois[key] = s_dict.pop(key) - - elif 'L5_' in key: - s_dict_L5_extpois[key] = s_dict.pop(key) - - # L2 next - elif 'L2_' in key: - s_dict_L2[key] = s_dict.pop(key) - - elif 'L5_' in key: - s_dict_L5[key] = s_dict.pop(key) - - # these are total spike dicts for the experiments - s_L2Pyr_list = [] - s_L5Pyr_list = [] - - # iterate through params and spikes for a given experiment - for fparam, fspk in zip(dict_exp_param[expname], dict_exp_spk[expname]): - # get gid dict - gid_dict, p = paramrw.read(fparam) - - # get spike dict - s_dict = spikefn.spikes_from_file(gid_dict, fspk) - - # add a new entry to list for each different file assoc with an experiment - s_L2Pyr_list.append(np.array(list(it.chain.from_iterable(s_dict['L2_pyramidal'].spike_list)))) - s_L5Pyr_list.append(np.array(list(it.chain.from_iterable(s_dict['L5_pyramidal'].spike_list)))) - - # now aggregate over all spikes - s_L2Pyr = np.array(list(it.chain.from_iterable(s_L2Pyr_list))) - s_L5Pyr = 
np.array(list(it.chain.from_iterable(s_L5Pyr_list))) - - # optimize bins, currently unused for comparison reasons! - N_trials = len(fparam) - # bin_L2 = 120 - # bin_L5 = 120 - bin_L2 = spikefn.hist_bin_opt(s_L2Pyr, N_trials) - bin_L5 = spikefn.hist_bin_opt(s_L5Pyr, N_trials) - - # create standard fig and axes - f = ac.FigPSTH(400.) - f.ax['L2_psth'].hist(s_L2Pyr, bin_L2, facecolor='g', alpha=0.75) - f.ax['L5_psth'].hist(s_L5Pyr, bin_L5, facecolor='g', alpha=0.75) - - # normalize these axes - y_L2 = f.ax['L2_psth'].get_ylim() - y_L5 = f.ax['L5_psth'].get_ylim() - - print y_L2, y_L5 - - # f.ax['L2_psth'].set_ylim((0, 450.)) - # f.ax['L5_psth'].set_ylim((0, 450.)) - - spikefn.spike_png(f.ax['L2'], s_dict_L2) - spikefn.spike_png(f.ax['L5'], s_dict_L5) - spikefn.spike_png(f.ax['L2_extpois'], s_dict_L2_extpois) - spikefn.spike_png(f.ax['L2_extgauss'], s_dict_L2_extgauss) - spikefn.spike_png(f.ax['L5_extpois'], s_dict_L5_extpois) - spikefn.spike_png(f.ax['L5_extgauss'], s_dict_L5_extgauss) - - # # testfig.ax0.plot(t_vec, dp_total) - fig_name = os.path.join(simpaths.dsim, expname+'.eps') - - plt.savefig(fig_name) - f.close() - - # run the compression - fio.epscompress(simpaths.dsim, '.eps', 1) diff --git a/praster.py b/praster.py deleted file mode 100644 index 0ad9c021b..000000000 --- a/praster.py +++ /dev/null @@ -1,67 +0,0 @@ -# praster.py - plot dipole function -# -# v 1.9.2a -# rev 2013-04-08 (SL: changed spikes_from_file) -# last major: (SL: minor changes to FigRaster) - -import os -import numpy as np -import matplotlib.pyplot as plt -from neuron import h as nrn -from axes_create import FigRaster -import spikefn as spikefn - -# file_info is (rootdir, subdir, -def praster(f_param, tstop, file_spk, dfig): - # ddipole is dipole data - s_dict = spikefn.spikes_from_file(f_param, file_spk) - - s_dict_L2 = {} - s_dict_L5 = {} - s_dict_L2_extgauss = {} - s_dict_L2_extpois = {} - s_dict_L5_extgauss = {} - s_dict_L5_extpois = {} - - # clean out s_dict destructively - for key in s_dict.keys(): - # do this first to remove all extgauss feeds - if 'extgauss' in key: - if 'L2_' in key: - s_dict_L2_extgauss[key] = s_dict.pop(key) - - elif 'L5_' in key: - s_dict_L5_extgauss[key] = s_dict.pop(key) - - elif 'extpois' in key: - # s_dict_extpois[key] = s_dict.pop(key) - if 'L2_' in key: - s_dict_L2_extpois[key] = s_dict.pop(key) - - elif 'L5_' in key: - s_dict_L5_extpois[key] = s_dict.pop(key) - - # L2 next - elif 'L2_' in key: - s_dict_L2[key] = s_dict.pop(key) - - elif 'L5_' in key: - s_dict_L5[key] = s_dict.pop(key) - - # split to find file prefix - file_prefix = file_spk.split('/')[-1].split('.')[0] - - # create standard fig and axes - f = FigRaster(tstop) - spikefn.spike_png(f.ax['L2'], s_dict_L2) - spikefn.spike_png(f.ax['L5'], s_dict_L5) - spikefn.spike_png(f.ax['L2_extpois'], s_dict_L2_extpois) - spikefn.spike_png(f.ax['L2_extgauss'], s_dict_L2_extgauss) - spikefn.spike_png(f.ax['L5_extpois'], s_dict_L5_extpois) - spikefn.spike_png(f.ax['L5_extgauss'], s_dict_L5_extgauss) - - # testfig.ax0.plot(t_vec, dp_total) - fig_name = os.path.join(dfig, file_prefix+'.png') - - plt.savefig(fig_name, dpi=300) - f.close() diff --git a/praw.py b/praw.py deleted file mode 100644 index 4be678e5b..000000000 --- a/praw.py +++ /dev/null @@ -1,154 +0,0 @@ -# praw.py - all of the raw data types on one fig -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: updated for it.izip() and return_data_dir()) -# last major: (SL: minor) - -import fileio as fio -import numpy as np -import multiprocessing as mp -import ast -import os -import 
paramrw -import dipolefn -import spikefn -import specfn -import currentfn -import matplotlib.pyplot as plt -from neuron import h as nrn -import axes_create as ac - -def pkernel(dfig_dpl, f_dpl, f_spk, f_spec, f_current, f_spec_current, f_param, ax_handles, spec_cmap='jet'): - T = paramrw.find_param(f_param, 'tstop') - xlim = (50., T) - - # into the pdipole directory, this will plot dipole, spec, and spikes - # create the axis handle - f = ac.FigDipoleExp(ax_handles) - - # create the figure name - fprefix = fio.strip_extprefix(f_dpl) + '-dpl' - fname = os.path.join(dfig_dpl, fprefix + '.png') - - # grab the dipole - dpl = dipolefn.Dipole(f_dpl) - dpl.convert_fAm_to_nAm() - - # plot the dipole to the agg axes - dpl.plot(f.ax['dpl_agg'], xlim) - dpl.plot(f.ax['dpl_agg_L5'], xlim) - # f.ax['dpl_agg_L5'].hold(True) - # dpl.plot(f.ax['dpl_agg_L5'], xlim, 'L5') - - # plot individual dipoles - dpl.plot(f.ax['dpl'], xlim, 'L2') - dpl.plot(f.ax['dpl_L5'], xlim, 'L5') - - # f.ysymmetry(f.ax['dpl']) - # print dpl.max('L5', (0., -1)), dpl.max('L5', (50., -1)) - # print f.ax['dpl_L5'].get_ylim() - # f.ax['dpl_L5'].set_ylim((-1e5, 1e5)) - # f.ysymmetry(f.ax['dpl_L5']) - - # plot the current - I_soma = currentfn.SynapticCurrent(f_current) - I_soma.plot_to_axis(f.ax['I_soma'], 'L2') - I_soma.plot_to_axis(f.ax['I_soma_L5'], 'L5') - - # plot the dipole-based spec data - pc = specfn.pspec_ax(f.ax['spec_dpl'], f_spec, xlim, 'L2') - f.f.colorbar(pc, ax=f.ax['spec_dpl']) - - pc = specfn.pspec_ax(f.ax['spec_dpl_L5'], f_spec, xlim, 'L5') - f.f.colorbar(pc, ax=f.ax['spec_dpl_L5']) - - # grab the current spec and plot them - spec_L2, spec_L5 = data_spec_current = specfn.read(f_spec_current, type='current') - pc_L2 = f.ax['spec_I'].imshow(spec_L2['TFR'], aspect='auto', origin='upper',cmap=plt.get_cmap(spec_cmap)) - pc_L5 = f.ax['spec_I_L5'].imshow(spec_L5['TFR'], aspect='auto', origin='upper',cmap=plt.get_cmap(spec_cmap)) - - # plot the current-based spec data - # pci = specfn.pspec_ax(f.ax['spec_I'], f_spec_current, type='current') - f.f.colorbar(pc_L2, ax=f.ax['spec_I']) - f.f.colorbar(pc_L5, ax=f.ax['spec_I_L5']) - - # get all spikes - s = spikefn.spikes_from_file(f_param, f_spk) - - # these work primarily because of how the keys are done - # in the spike dict s (consequence of spikefn.spikes_from_file()) - s_L2 = spikefn.filter_spike_dict(s, 'L2_') - s_L5 = spikefn.filter_spike_dict(s, 'L5_') - - # resize xlim based on our 50 ms cutoff thingy - xlim = (50., xlim[1]) - - # plot the spikes - spikefn.spike_png(f.ax['spk'], s_L2) - spikefn.spike_png(f.ax['spk_L5'], s_L5) - - f.ax['dpl'].set_xlim(xlim) - # f.ax['dpl_L5'].set_xlim(xlim) - # f.ax['spec_dpl'].set_xlim(xlim) - f.ax['spk'].set_xlim(xlim) - f.ax['spk_L5'].set_xlim(xlim) - - f.savepng(fname) - f.close() - - return 0 - -# dummy function for callback -def cb(r): - pass - -# For a given ddata (SimulationPaths object), find the mean dipole -# over ALL trials in ALL conditions in EACH experiment -def praw(ddata): - # grab the original dipole from a specific dir - dproj = fio.return_data_dir() - - runtype = 'parallel' - # runtype = 'debug' - - # check on spec data - # generates both spec because both are needed here - specfn.generate_missing_spec(ddata) - - # test experiment - # expmt_group = ddata.expmt_groups[0] - - ax_handles = [ - 'dpl_agg', - 'dpl', - 'spec_dpl', - 'spk', - 'I_soma', - 'spec_I', - ] - - # iterate over exmpt groups - for expmt_group in ddata.expmt_groups: - dfig_dpl = ddata.dfig[expmt_group]['figdpl'] - - # grab lists of files (l_) - l_dpl = 
ddata.file_match(expmt_group, 'rawdpl') - l_spk = ddata.file_match(expmt_group, 'rawspk') - l_param = ddata.file_match(expmt_group, 'param') - l_spec = ddata.file_match(expmt_group, 'rawspec') - l_current = ddata.file_match(expmt_group, 'rawcurrent') - l_spec_current = ddata.file_match(expmt_group, 'rawspeccurrent') - - if runtype == 'parallel': - pl = mp.Pool() - - for f_dpl, f_spk, f_spec, f_current, f_spec_current, f_param \ - in zip(l_dpl, l_spk, l_spec, l_current, l_spec_current, l_param): - pl.apply_async(pkernel, (dfig_dpl, f_dpl, f_spk, f_spec, f_current, f_spec_current, f_param, ax_handles), callback=cb) - pl.close() - pl.join() - - elif runtype == 'debug': - for f_dpl, f_spk, f_spec, f_current, f_spec_current, f_param \ - in zip(l_dpl, l_spk, l_spec, l_current, l_spec_current, l_param): - pkernel(dfig_dpl, f_dpl, f_spk, f_spec, f_current, f_spec_current, f_param, ax_handles) diff --git a/pspec.py b/pspec.py deleted file mode 100644 index c88ab0bf6..000000000 --- a/pspec.py +++ /dev/null @@ -1,287 +0,0 @@ -# pspec.py - Very long plotting methods having to do with spec. -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed it.izip()) -# last major: (MS: Updated calls to extinput.plot_hist()) - -import os -import sys -import numpy as np -import scipy.signal as sps -import matplotlib.pyplot as plt -import paramrw -import fileio as fio -import multiprocessing as mp -from neuron import h as nrn -from math import ceil - -import fileio as fio -import currentfn -import dipolefn -import specfn -import spikefn -import axes_create as ac - -# this is actually a plot kernel for one sim that does dipole, etc. -# needs f_param not p_dict -def pspec_dpl(f_spec, f_dpl, dfig, p_dict, key_types, xlim=None, ylim=None, f_param=None): - # Generate file prefix - fprefix = f_spec.split('/')[-1].split('.')[0] - - # using png for now - # fig_name = os.path.join(dfig, fprefix+'.eps') - fig_name = os.path.join(dfig, fprefix+'.png') - - # f.f is the figure handle! - f = ac.FigSpec() - - # load spec data - spec = specfn.Spec(f_spec) - - # Plot TFR data and add colorbar - pc = spec.plot_TFR(f.ax['spec'], 'agg', xlim, ylim) - f.f.colorbar(pc, ax=f.ax['spec']) - - # grab the dipole data - # data_dipole = np.loadtxt(open(f_dpl, 'r')) - dpl = dipolefn.Dipole(f_dpl) - - # If f_param supplied, renormalize dipole data - if f_param: - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - - # plot routine - dpl.plot(f.ax['dipole'], xlim, 'agg') - - # Plot Welch data - # Use try/except for backwards compatibility - try: - spec.plot_pgram(f.ax['pgram']) - - except KeyError: - pgram = specfn.Welch(dpl.t, dpl.dpl['agg'], p_dict['dt']) - pgram.plot_to_ax(f.ax['pgram'], spec.spec['agg']['f'][-1]) - - # plot and create an xlim - xlim_new = f.ax['spec'].get_xlim() - xticks = f.ax['spec'].get_xticks() - xticks[0] = xlim_new[0] - - # for now, set the xlim for the other one, force it! 
- f.ax['dipole'].set_xlim(xlim_new) - f.ax['dipole'].set_xticks(xticks) - f.ax['spec'].set_xticks(xticks) - - # axis labels - f.ax['spec'].set_xlabel('Time (ms)') - f.ax['spec'].set_ylabel('Frequency (Hz)') - - # create title - title_str = ac.create_title(p_dict, key_types) - f.f.suptitle(title_str) - - # use our fig classes to save and close - f.savepng(fig_name) - # f.saveeps(dfig, fprefix) - f.close() - -# Spectral plotting kernel with alpha feed histogram for ONE simulation run -def pspec_with_hist(f_spec, f_dpl, f_spk, dfig, f_param, key_types, xlim=None, ylim=None): - # Generate file prefix - # print('f_spec:',f_spec) - fprefix = f_spec.split('/')[-1].split('.')[0] - # Create the fig name - fig_name = os.path.join(dfig, fprefix+'.png') - # print('fig_name:',fig_name) - # load param dict - _, p_dict = paramrw.read(f_param) - f = ac.FigSpecWithHist() - # load spec data - spec = specfn.Spec(f_spec) - # Plot TFR data and add colorbar - pc = spec.plot_TFR(f.ax['spec'], 'agg', xlim, ylim) - f.f.colorbar(pc, ax=f.ax['spec']) - # set xlim based on TFR plot - xlim_new = f.ax['spec'].get_xlim() - # grab the dipole data - dpl = dipolefn.Dipole(f_dpl) - dpl.baseline_renormalize(f_param) - dpl.convert_fAm_to_nAm() - # plot routine - dpl.plot(f.ax['dipole'], xlim_new, 'agg') - # data_dipole = np.loadtxt(open(f_dpl, 'r')) - # t_dpl = data_dipole[xmin_ind:xmax_ind+1, 0] - # dp_total = data_dipole[xmin_ind:xmax_ind+1, 1] - # f.ax['dipole'].plot(t_dpl, dp_total) - # x = (xmin, xmax) - # # grab alpha feed data. spikes_from_file() from spikefn.py - # s_dict = spikefn.spikes_from_file(f_param, f_spk) - # # check for existance of alpha feed keys in s_dict. - # s_dict = spikefn.alpha_feed_verify(s_dict, p_dict) - # # Account for possible delays - # s_dict = spikefn.add_delay_times(s_dict, p_dict) - # Get extinput data and account for delays - try: - extinputs = spikefn.ExtInputs(f_spk, f_param) - except ValueError: - print("Error: could not load spike timings from %s" % f_spk) - f.close() - return - - extinputs.add_delay_times() - extinputs.get_envelope(dpl.t, feed='dist') - # set number of bins (150 bins per 1000ms) - bins = ceil(150. * (xlim_new[1] - xlim_new[0]) / 1000.) # bins should be int - # plot histograms - hist = {} - hist['feed_prox'] = extinputs.plot_hist(f.ax['feed_prox'], 'prox', dpl.t, bins=bins, xlim=xlim_new, color='red') - hist['feed_dist'] = extinputs.plot_hist(f.ax['feed_dist'], 'dist', dpl.t, bins=bins, xlim=xlim_new, color='green') - f.ax['feed_dist'].invert_yaxis() - # for now, set the xlim for the other one, force it! - f.ax['dipole'].set_xlim(xlim_new) - f.ax['spec'].set_xlim(xlim_new) - f.ax['feed_prox'].set_xlim(xlim_new) - f.ax['feed_dist'].set_xlim(xlim_new) - # set hist axis props - f.set_hist_props(hist) - # axis labels - f.ax['spec'].set_xlabel('Time (ms)') - f.ax['spec'].set_ylabel('Frequency (Hz)') - # Add legend to histogram - for key in f.ax.keys(): - if 'feed' in key: - f.ax[key].legend() - # create title - title_str = ac.create_title(p_dict, key_types) - f.f.suptitle(title_str) - f.savepng(fig_name) - f.close() - -def pspecpwr(file_name, results_list, fparam_list, key_types, error_vec=[]): - # instantiate fig - f = ac.FigStd() - f.set_fontsize(18) - - # pspecpwr_ax is a plot kernel for specpwr plotting - legend_list = pspecpwr_ax(f.ax0, results_list, fparam_list, key_types) - - # Add error bars if necessary - if len(error_vec): - # errors are only used with avg'ed data. 
There will be only one entry in results_list - pyerrorbars_ax(f.ax0, results_list[0]['freq'], results_list[0]['p_avg'], error_vec) - - # insert legend - f.ax0.legend(legend_list, loc='upper right', prop={'size': 8}) - - # axes labels - f.ax0.set_xlabel('Freq (Hz)') - f.ax0.set_ylabel('Avgerage Power (nAm^2)') - - # add title - # f.set_title(fparam_list[0], key_types) - - f.savepng(file_name) - # f.save(file_name) - # f.saveeps(file_name) - f.close() - -# frequency-power analysis plotting kernel -def pspecpwr_ax(ax_specpwr, specpwr_list, fparam_list, key_types): - ax_specpwr.hold(True) - - # Preallocate legend list - legend_list = [] - - # iterate over freqpwr results and param list to plot and construct legend - for result, fparam in zip(specpwr_list, fparam_list): - # Plot to axis - ax_specpwr.plot(result['freq'], result['p_avg']) - - # Build legend - p = paramrw.read(fparam)[1] - lgd_temp = [key + ': %2.1f' %p[key] for key in key_types['dynamic_keys']] - legend_list.append(reduce(lambda x, y: x+', '+y, lgd_temp[:])) - - # Do not need to return axis, apparently - return legend_list - -# Plot vertical error bars -def pyerrorbars_ax(ax, x, y, yerr_vec): - ax.errorbar(x, y, xerr=None, yerr=yerr_vec, fmt=None, ecolor='blue') - -def aggregate_with_hist(f, ax, f_spec, f_dpl, f_spk, f_param, spec_cmap='jet'): - # load param dict - _, p_dict = paramrw.read(f_param) - - # load spec data from file - spec = specfn.Spec(f_spec) - # data_spec = np.load(f_spec) - - # timevec = data_spec['time'] - # freqvec = data_spec['freq'] - # TFR = data_spec['TFR'] - - xmin = timevec[0] - xmax = p_dict['tstop'] - x = (xmin, xmax) - - pc = spec.plot_TFR(ax['spec'], layer='agg', xlim=x) - # pc = ax['spec'].imshow(TFR, extent=[timevec[0], timevec[-1], freqvec[-1], freqvec[0]], aspect='auto', origin='upper') - f.f.colorbar(pc, ax=ax['spec'],cmap=plt.get_cmap(spec_cmap)) - - # grab the dipole data - dpl = dipolefn.Dipole(f_dpl) - dpl.plot(ax['dipole'], x, layer='agg') - # data_dipole = np.loadtxt(open(f_dpl, 'r')) - - # t_dpl = data_dipole[xmin/p_dict['dt']:, 0] - # dp_total = data_dipole[xmin/p_dict['dt']:, 1] - - # ax['dipole'].plot(t_dpl, dp_total) - - # grab alpha feed data. spikes_from_file() from spikefn.py - s_dict = spikefn.spikes_from_file(f_param, f_spk) - - # check for existance of alpha feed keys in s_dict. - s_dict = spikefn.alpha_feed_verify(s_dict, p_dict) - - # Account for possible delays - s_dict = spikefn.add_delay_times(s_dict, p_dict) - - # set number of bins (150 bins/1000ms) - bins = 150. * (xmax - xmin) / 1000. - - hist = {} - - # Proximal feed - hist['feed_prox'] = ax['feed_prox'].hist(s_dict['alpha_feed_prox'].spike_list, bins, range=[xmin, xmax], color='red', label='Proximal feed') - - # Distal feed - hist['feed_dist'] = ax['feed_dist'].hist(s_dict['alpha_feed_dist'].spike_list, bins, range=[xmin, xmax], color='green', label='Distal feed') - - # for now, set the xlim for the other one, force it! 
- ax['dipole'].set_xlim(x) - ax['spec'].set_xlim(x) - ax['feed_prox'].set_xlim(x) - ax['feed_dist'].set_xlim(x) - - # set hist axis props - f.set_hist_props(ax, hist) - - # axis labels - ax['spec'].set_xlabel('Time (ms)') - ax['spec'].set_ylabel('Frequency (Hz)') - - # Add legend to histogram - for key in ax.keys(): - if 'feed' in key: - ax[key].legend() - - # create title - # title_str = ac.create_title(p_dict, key_types) - # f.f.suptitle(title_str) - # title_str = [key + ': %2.1f' % p_dict[key] for key in key_types['dynamic_keys']] - - # plt.savefig(fig_name) - # f.close() diff --git a/ptest.py b/ptest.py deleted file mode 100644 index e0e600e6f..000000000 --- a/ptest.py +++ /dev/null @@ -1,22 +0,0 @@ -# ptest.py - plot test function -# -# v 1.7.11a -# rev 2012-08-23 (SL: created) -# last major: - -import matplotlib as mpl -mpl.use("Agg") - -import matplotlib.pyplot as plt -from neuron import h as nrn -from axes_create import FigStd - -def ptest(t_vec, v_e, v_i): - testfig = FigStd() - testfig.ax0.hold(True) - - testfig.ax0.plot(t_vec, v_e) - testfig.ax0.plot(t_vec, v_i) - - plt.savefig('outputspikes.png') - testfig.close() diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0d9ff76cc..000000000 --- a/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -MPI -Matplotlib -MNE-Python -NEURON compiled with MPI, Python support - PYTHONPATH must point to Python 3 -Numpy -PyOpenGL -Python 3 -PyQt5 -pyqtgraph -Scipy - diff --git a/run.py b/run.py deleted file mode 100755 index 6e390839e..000000000 --- a/run.py +++ /dev/null @@ -1,471 +0,0 @@ -#!/usr/bin/env python -# run.py - primary run function for s1 project -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: removed izip, fixed an nhost bug) -# last major: (SL: toward python3) -# other branch for hnn - -import os -import sys -import time -import shutil -import numpy as np -from neuron import h -h.load_file("stdrun.hoc") -# Cells are defined in other files -import network -import fileio as fio -import paramrw as paramrw -from paramrw import usingOngoingInputs -import plotfn as plotfn -import specfn as specfn -import pickle -from dipolefn import Dipole -from conf import readconf -from L5_pyramidal import L5Pyr -from L2_pyramidal import L2Pyr -from L2_basket import L2Basket -from L5_basket import L5Basket -from lfp import LFPElectrode -from morphology import shapeplot, getshapecoords -import traceback - -dconf = readconf() - -# data directory - ./data -dproj = dconf['datdir'] # fio.return_data_dir(dconf['datdir']) -debug = dconf['debug'] -pc = h.ParallelContext() -pcID = int(pc.id()) -f_psim = '' -ntrial = 1 -simlength = 0.0 -testLFP = dconf['testlfp']; -testlaminarLFP = dconf['testlaminarlfp'] -lelec = [] # list of LFP electrodes - -# reads the specified param file -foundprm = False -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.param'): - f_psim = sys.argv[i] - foundprm = True - if pcID==0 and debug: print('using ',f_psim,' param file.') - elif sys.argv[i] == 'ntrial' and i+1 0: - expmt_group = p_exp.expmt_groups[0] -else: - expmt_group = None -simparams = p = p_exp.return_pdict(expmt_group, 0) # return the param dict for this simulation - -pc.barrier() # get all nodes to this place before continuing -pc.gid_clear() - -# global variables, should be node-independent -h("dp_total_L2 = 0."); h("dp_total_L5 = 0.") - -# Set tstop before instantiating any classes -if simlength > 0.0: - h.tstop = simlength -else: - h.tstop = p['tstop'] # simulation duration - -h.dt = p['dt'] # simulation time-step -h.celsius = 
p['celsius'] # 37.0 # p['celsius'] # set temperature -# spike file needs to be known by all nodes -file_spikes_tmp = fio.file_spike_tmp(dproj) -net = network.NetworkOnNode(p) # create node-specific network - -t_vec = h.Vector(); t_vec.record(h._ref_t) # time recording -dp_rec_L2 = h.Vector(); dp_rec_L2.record(h._ref_dp_total_L2) # L2 dipole recording -dp_rec_L5 = h.Vector(); dp_rec_L5.record(h._ref_dp_total_L5) # L5 dipole recording - -net.movecellstopos() # position cells in 2D grid - -def expandbbox (boxA, boxB): - return [(min(boxA[i][0],boxB[i][0]),max(boxA[i][1],boxB[i][1])) for i in range(3)] - -def arrangelayers (): - # offsets for L2, L5 cells so that L5 below L2 in display - dyoff = {L2Pyr: 1000, 'L2_pyramidal' : 1000, - L5Pyr: -1000-149.39990234375, 'L5_pyramidal' : -1000-149.39990234375, - L2Basket: 1000, 'L2_basket' : 1000, - L5Basket: -1000-149.39990234375, 'L5_basket' : -1000-149.39990234375} - for cell in net.cells: cell.translate3d(0,dyoff[cell.celltype],0) - dcheck = {x:False for x in dyoff.keys()} - dbbox = {x:[[1e9,-1e9],[1e9,-1e9],[1e9,-1e9]] for x in dyoff.keys()} - for cell in net.cells: - - dbbox[cell.celltype] = expandbbox(dbbox[cell.celltype], cell.getbbox()) - #if dcheck[cell.celltype]: continue - """ - bbox = cell.getbbox() - lx,ly,lz = getshapecoords(h,cell.get_sections()) - if cell.celltype == L2Pyr or cell.celltype == 'L2_pyramidal': - print('L2Pyr bbox:',bbox)#,lx,ly,lz) - elif cell.celltype == L5Pyr or cell.celltype == 'L5_pyramidal': - print('L5Pyr bbox:',bbox)#,lx,ly,lz) - elif cell.celltype == L2Basket or cell.celltype == 'L2_basket': - print('L2Basket bbox:',bbox)#,lx,ly,lz) - elif cell.celltype == L5Basket or cell.celltype == 'L5_basket': - print('L5Basket bbox:',bbox)#,lx,ly,lz) - dcheck[cell.celltype]=True - """ - # for ty in ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']: print(ty, dbbox[ty]) - -arrangelayers() # arrange cells in layers - for visualization purposes - -pc.barrier() - -# save spikes from the individual trials in a single file -def catspks (): - lf = [os.path.join(datdir,'spk_'+str(i)+'.txt') for i in range(ntrial)] - if debug: print('catspk lf:',lf) - lspk = [[],[]] - for f in lf: - xarr = np.loadtxt(f) - for i in range(2): - lspk[i].extend(xarr[:,i]) - if debug: print('xarr.shape:',xarr.shape) - lspk = np.array(lspk).T - # lspk.sort(axis=1) # not multidim sort - can fix if want spikes across trials in temporal order - fout = os.path.join(datdir,'spk.txt') - with open(fout, 'w') as fspkout: - for i in range(lspk.shape[0]): - fspkout.write('%3.2f\t%d\n' % (lspk[i,0], lspk[i,1])) - if debug: print('lspk.shape:',lspk.shape) - return lspk - -# save average dipole from individual trials in a single file -def catdpl (): - ldpl = [] - for pre in ['dpl','rawdpl']: - lf = [os.path.join(datdir,pre+'_'+str(i)+'.txt') for i in range(ntrial)] - dpl_dat = np.array([np.loadtxt(f) for f in lf]) - try: - dpl = np.mean(dpl_dat,axis=0) - except ValueError: - print("ERROR: could not caluclate mean. 
Inconsistent trial lengths?") - with open(os.path.join(datdir,pre+'.txt'), 'w') as fp: - for i in range(dpl.shape[0]): - fp.write("%03.3f\t" % dpl[i,0]) - fp.write("%9.8f\t" % dpl[i,1]) - fp.write("%9.8f\t" % dpl[i,2]) - fp.write("%9.8f\n" % dpl[i,3]) - ldpl.append(dpl) - return ldpl - -# save average spectrogram from individual trials in a single file -def catspec (): - lf = [os.path.join(datdir,'rawspec_'+str(i)+'.npz') for i in range(ntrial)] - dspecin = {} - dout = {} - try: - for f in lf: dspecin[f] = np.load(f) - except: - return None - for k in ['t_L5', 'f_L5', 't_L2', 'f_L2', 'time', 'freq']: dout[k] = dspecin[lf[0]][k] - for k in ['TFR', 'TFR_L5', 'TFR_L2']: dout[k] = np.mean(np.array([dspecin[f][k] for f in lf]),axis=0) - with open(os.path.join(datdir,'rawspec.npz'), 'wb') as fdpl: - np.savez_compressed(fdpl,t_L5=dout['t_L5'],f_L5=dout['f_L5'],t_L2=dout['t_L2'],f_L2=dout['f_L2'],time=dout['time'],freq=dout['freq'],TFR=dout['TFR'],TFR_L5=dout['TFR_L5'],TFR_L2=dout['TFR_L2']) - return dout - -# gather trial outputs via either raw concatenation or averaging -def cattrialoutput (): - global doutf - lspk = catspks() # concatenate spikes from different trials to a single file - ldpl = catdpl() - dspec = catspec() - del lspk,ldpl,dspec # do not need these variables; returned for testing - -# run individual trials via runsim, then calc/save average dipole/specgram -# evinputinc is an increment (in milliseconds) that gets added to the evoked inputs on each -# successive trial. the default value is 0.0. -def runtrials (ntrial, inc_evinput=0.0): - global doutf - if pcID==0: print('Running', ntrial, 'trials.') - for i in range(ntrial): - if pcID==0: print(os.linesep+'Running trial',i+1,'...') - doutf = setoutfiles(ddir,i,ntrial) - # initrands(ntrial+(i+1)**ntrial) # reinit for each trial - net.state_init() # initialize voltages - runsim() # run the simulation - net.reset_src_event_times(inc_evinput = inc_evinput * (i + 1)) # adjusts the rng seeds and then the feed/event input times - doutf = setoutfiles(ddir,0,1) # reset output files based on sim name - if pcID==0: cattrialoutput() # get/save the averages - -def initrands (s=0): # fix to use s - # if there are N_trials, then randomize the seed - # establishes random seed for the seed seeder (yeah.) 
- # this creates a prng_tmp on each, but only the value from 0 will be used - prng_tmp = np.random.RandomState() - if pcID == 0: - r = h.Vector(1, s) # initialize vector to 1 element, with a 0 - if ntrial == 1: - prng_base = np.random.RandomState(pcID + s) - else: - # Create a random seed value - r.x[0] = prng_tmp.randint(1e9) - else: r = h.Vector(1, s) # create the vector 'r' but don't change its init value - pc.broadcast(r, 0) # broadcast random seed value in r to everyone - # set object prngbase to random state for the seed value - # other random seeds here will then be based on the gid - prng_base = np.random.RandomState(int(r.x[0])) - # seed list is now a list of seeds to be changed on each run - # otherwise, its originally set value will remain - # give a random int seed from [0, 1e9] - for param in p_exp.prng_seed_list: # this list empty for single experiment/trial - p[param] = prng_base.randint(1e9) - # print('simparams[prng_seedcore]:',simparams['prng_seedcore']) - - -initrands(0) # init once - -def setupLFPelectrodes (): - lelec = [] - if testlaminarLFP: - for y in np.linspace(1466.0,-72.0,16): lelec.append(LFPElectrode([370.0, y, 450.0], pc = pc)) - elif testLFP: - lelec.append(LFPElectrode([370.0, 1050.0, 450.0], pc = pc)) - lelec.append(LFPElectrode([370.0, 208.0, 450.0], pc = pc)) - return lelec - -lelec = setupLFPelectrodes() - -# All units for time: ms -def runsim (): - t0 = time.time() # clock start time - - pc.set_maxstep(10) # sets the default max solver step in ms (purposefully large) - - for elec in lelec: - elec.setup() - elec.LFPinit() - - h.finitialize() # initialize cells to -65 mV, after all the NetCon delays have been specified - if pcID == 0: - for tt in range(0,int(h.tstop),printdt): h.cvode.event(tt, prsimtime) # print time callbacks - - h.fcurrent() - h.frecord_init() # set state variables if they have been changed since h.finitialize - pc.psolve(h.tstop) # actual simulation - run the solver - pc.barrier() - - # these calls aggregate data across procs/nodes - pc.allreduce(dp_rec_L2, 1); - pc.allreduce(dp_rec_L5, 1) # combine dp_rec on every node, 1=add contributions together - for elec in lelec: elec.lfp_final() - net.aggregate_currents() # aggregate the currents independently on each proc - # combine net.current{} variables on each proc - pc.allreduce(net.current['L5Pyr_soma'], 1); pc.allreduce(net.current['L2Pyr_soma'], 1) - - pc.barrier() - - # write time and calculated dipole to data file only if on the first proc - # only execute this statement on one proc - savedat(p, pcID, t_vec, dp_rec_L2, dp_rec_L5, net) - - for elec in lelec: print('end; t_vec.size()',t_vec.size(),'elec.lfp_t.size()',elec.lfp_t.size()) - - if pcID == 0: - if debug: print("Simulation run time: %4.4f s" % (time.time()-t0)) - if debug: print("Simulation directory is: %s" % ddir.dsim) - if paramrw.find_param(doutf['file_param'],'save_spec_data') or usingOngoingInputs(doutf['file_param']): - runanalysis(p, doutf['file_param'], doutf['file_dpl_norm'], doutf['file_spec']) # run spectral analysis - if paramrw.find_param(doutf['file_param'],'save_figs'): - savefigs(ddir,p,p_exp) # save output figures - - pc.barrier() # make sure all done in case multiple trials - -def excepthook(exc_type, exc_value, exc_tb): - traceback.print_exception(exc_type, exc_value, exc_tb, file=sys.stdout, chain=False) - traceback.print_exception(exc_type, exc_value, exc_tb, file=sys.stderr, chain=False) - pc.runworker() - pc.done() - exit(-1) - -if __name__ == "__main__": - sys.excepthook = excepthook - if 
dconf['dorun']: - if ntrial > 1: runtrials(ntrial,p['inc_evinput']) - else: runsim() - pc.runworker() - pc.done() - if dconf['doquit']: h.quit() diff --git a/scripts/run-pytest.sh b/scripts/run-pytest.sh new file mode 100755 index 000000000..bc49d6fea --- /dev/null +++ b/scripts/run-pytest.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +if [[ "${WSL_INSTALL}" -eq 1 ]]; then + export PATH="$PATH:$HOME/.local/bin" + # unable to get vcxsrv to work (DLL loading problems) + unset DISPLAY +fi + +# first check code style with flake8 +echo "Checking code style compliance with flake8..." +flake8 --count --exclude __init__.py,qt_evoked.py +echo "Running unit tests with pytest..." +py.test . --cov=hnn hnn/tests/ # --cov-report=xml \ No newline at end of file diff --git a/scripts/run-travis-wsl.sh b/scripts/run-travis-wsl.sh deleted file mode 100755 index 9fdca1133..000000000 --- a/scripts/run-travis-wsl.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -export TRAVIS_TESTING=1 -export PATH="$PATH:$HOME/.local/bin" -export OMPI_MCA_btl_vader_single_copy_mechanism=none - -echo "Testing GUI on WSL..." -cd $DIR/../ - -export DISPLAY=:0 -python3 hnn.py - -echo "Testing MPI in WSL..." -mpiexec -np 2 nrniv -mpi -python run.py \ No newline at end of file diff --git a/scripts/setup-travis-linux.sh b/scripts/setup-travis-linux.sh new file mode 100755 index 000000000..222c7bacc --- /dev/null +++ b/scripts/setup-travis-linux.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +export PATH=/usr/bin:/usr/local/bin:$PATH + +echo "Starting fake Xserver" +Xvfb $DISPLAY -listen tcp -screen 0 1024x768x24 > /dev/null & + +echo "Starting Ubuntu install script" +installer/ubuntu/hnn-ubuntu.sh + +# test X server +xset -display $DISPLAY -q > /dev/null; diff --git a/scripts/setup-travis-mac.sh b/scripts/setup-travis-mac.sh deleted file mode 100755 index 226b12471..000000000 --- a/scripts/setup-travis-mac.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash -set -e - -export TRAVIS_TESTING=1 - -source scripts/utils.sh - -URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh" -FILENAME="$HOME/miniconda.sh" -start_download "$FILENAME" "$URL" - -echo "Installing miniconda..." -chmod +x "$HOME/miniconda.sh" -"$HOME/miniconda.sh" -b -p "${HOME}/Miniconda3" -export PATH=${HOME}/Miniconda3/bin:$PATH - -# create conda environment -conda create -n hnn --yes python=${PYTHON_VERSION} pip openmpi scipy numpy matplotlib pyqtgraph pyopengl psutil -source activate hnn && echo "activated conda HNN environment" - -# conda is faster to install nlopt -conda install -y -n hnn -c conda-forge nlopt - -pip install NEURON flake8 pytest pytest-cov coverage coveralls mne - -echo "Install finished" diff --git a/scripts/setup-travis-osx.sh b/scripts/setup-travis-osx.sh new file mode 100755 index 000000000..fc3b123b5 --- /dev/null +++ b/scripts/setup-travis-osx.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# we use start_download from utils.sh +source "$DIR/utils.sh" + +URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh" +FILENAME="$HOME/miniconda.sh" +start_download "$FILENAME" "$URL" + +echo "Installing miniconda..." 
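+# -b runs the Miniconda installer in batch (non-interactive) mode; -p sets the install prefix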
+bash "$HOME/miniconda.sh" -b -p "${HOME}/Miniconda3" +source "$HOME/Miniconda3/etc/profile.d/conda.sh" + +# create conda environment +conda env create -f environment.yml +conda install -y -n hnn openmpi mpi4py +# conda is faster to install nlopt +conda install -y -n hnn -c conda-forge nlopt + +conda activate hnn +pip install hnn-core pyqt5 \ No newline at end of file diff --git a/scripts/setup-travis-windows.sh b/scripts/setup-travis-windows.sh index 5508fb027..d85a83612 100755 --- a/scripts/setup-travis-windows.sh +++ b/scripts/setup-travis-windows.sh @@ -1,20 +1,13 @@ #!/bin/bash set -e -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" - -# we use wait_for_pid and start_download from utils.sh -source "$DIR/utils.sh" - -[[ $LOGFILE ]] || LOGFILE="hnn_travis.log" - -echo "Installing Ubuntu WSL..." -powershell.exe -ExecutionPolicy Bypass -File ./scripts/setup-travis-wsl.ps1 & -WSL_PID=$! +echo "Installing Microsoft MPI" +powershell -command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/microsoft/Microsoft-MPI/releases/download/v10.1.1/msmpisetup.exe', 'msmpisetup.exe')" && \ + ./msmpisetup.exe -unattend && \ + rm -f msmpisetup.exe -# prepare for installing msys2 -[[ ! -f C:/tools/msys64/msys2_shell.cmd ]] && rm -rf C:/tools/msys64 -choco uninstall -y mingw +echo "Running HNN Windows install script..." +powershell.exe -ExecutionPolicy Bypass -File ./installer/windows/hnn-windows.ps1 # enable windows remoting service to log in as a different user to run tests powershell -Command 'Start-Service -Name WinRM' > /dev/null @@ -44,43 +37,3 @@ else echo "No user home directory created at $TEST_USER_DIR" exit 2 fi - -echo "Installing Microsoft MPI" -powershell -command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/microsoft/Microsoft-MPI/releases/download/v10.1.1/msmpisetup.exe', 'msmpisetup.exe')" && \ - ./msmpisetup.exe -unattend && \ - rm -f msmpisetup.exe - -echo "Running HNN Windows install script..." -powershell.exe -ExecutionPolicy Bypass -File ./installer/windows/hnn-windows.ps1 -# add miniconda python to the path -export PATH=$PATH:$HOME/Miniconda3/Scripts -export PATH=$HOME/Miniconda3/envs/hnn/:$PATH -export PATH=$HOME/Miniconda3/envs/hnn/Scripts:$PATH -export PATH=$HOME/Miniconda3/envs/hnn/Library/bin:$PATH - -echo "Installing msys2 with choco..." -choco upgrade --no-progress -y msys2 &> /dev/null - -echo "Downloading VcXsrv..." -URL="https://downloads.sourceforge.net/project/vcxsrv/vcxsrv/1.20.8.1/vcxsrv-64.1.20.8.1.installer.exe" -FILENAME="$HOME/vcxsrv-64.1.20.8.1.installer.exe" -start_download "$FILENAME" "$URL" > /dev/null - -echo "Installing VcXsrv..." -cmd //c "$HOME/vcxsrv-64.1.20.8.1.installer.exe /S" - -# get opengl32.dll from mesa -# this is needed to be able to start vcxsrv -export msys2='cmd //C RefreshEnv.cmd ' -export msys2+='& set MSYS=winsymlinks:nativestrict ' -export msys2+='& C:\\tools\\msys64\\msys2_shell.cmd -defterm -no-start' -export mingw64="$msys2 -mingw64 -full-path -here -c "\"\$@"\" --" -export msys2+=" -msys2 -c "\"\$@"\" --" -$msys2 pacman --sync --noconfirm --needed mingw-w64-x86_64-mesa - -echo "Downloading python test packages..." -pip download flake8 pytest pytest-cov coverage coveralls mne - -echo "Waiting for WSL install to finish..." 
-NAME="installing WSL"
-wait_for_pid "${WSL_PID}" "$NAME" || script_fail
diff --git a/scripts/setup-travis-wsl.ps1 b/scripts/setup-travis-wsl.ps1
index cd0d4a2f7..790751808 100644
--- a/scripts/setup-travis-wsl.ps1
+++ b/scripts/setup-travis-wsl.ps1
@@ -2,7 +2,7 @@ $ErrorActionPreference = "Stop"
 Set-Location C:\Users\travis
 Enable-WindowsOptionalFeature -Online -FeatureName Microsoft-Windows-Subsystem-Linux
-Write-Host "Downloading Ubuntu image..."
+Write-Host "Downloading Ubuntu WSL image..."
 Invoke-WebRequest -Uri https://aka.ms/wsl-ubuntu-1804 -OutFile Ubuntu.appx -UseBasicParsing
 Write-Host "Finished downloading Ubuntu image. Extracting..."
@@ -13,20 +13,20 @@ $userenv = [System.Environment]::GetEnvironmentVariable("Path", "User"); [System
 Write-Host "Configuring Ubuntu WSL..."
 & .\Ubuntu\Ubuntu1804.exe install --root
-# add hnn_user
+# This creates "hnn_user" which will be the default user in WSL
 & wsl -- bash -ec "groupadd hnn_group && useradd -m -b /home/ -g hnn_group hnn_user && adduser hnn_user sudo && echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && chsh -s /bin/bash hnn_user"
-# copy hnn dir to hnn_user homedir and change permissions
+# Copy hnn source (from Travis clone) to hnn_user homedir and change permissions
 & wsl -- bash -ec "cp -r build/jonescompneurolab/hnn /home/hnn_user/ && chown -R hnn_user: /home/hnn_user && apt-get update && apt-get install -y dos2unix"
-# run future commands as hnn_user
+# Now all future commands can be run as hnn_user
 & .\Ubuntu\Ubuntu1804.exe config --default-user hnn_user
 # remove windows newlines
 & wsl -- bash -ec "dos2unix /home/hnn_user/hnn/scripts/* /home/hnn_user/hnn/installer/ubuntu/hnn-ubuntu.sh /home/hnn_user/hnn/installer/docker/hnn_envs"
 Write-Host "Installing HNN in Ubuntu WSL..."
-& wsl -- bash -ec "cd /home/hnn_user/hnn && source scripts/utils.sh && export LOGFILE=ubuntu_install.log && TRAVIS_TESTING=1 installer/ubuntu/hnn-ubuntu.sh || script_fail"
+& wsl -- bash -ec "cd /home/hnn_user/hnn && source scripts/utils.sh && export LOGFILE=ubuntu_install.log && installer/ubuntu/hnn-ubuntu.sh || script_fail"
 if (!$?) {
 exit 1
diff --git a/scripts/setup-travis-wsl.sh b/scripts/setup-travis-wsl.sh
new file mode 100755
index 000000000..b3de2b8f8
--- /dev/null
+++ b/scripts/setup-travis-wsl.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set -e
+
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+# we use wait_for_pid and start_download from utils.sh
+source "$DIR/utils.sh"
+
+echo "Installing Ubuntu WSL..."
+powershell.exe -ExecutionPolicy Bypass -File ./scripts/setup-travis-wsl.ps1 &
+WSL_PID=$!
+
+# Note: unable to get VcXsrv to start, always missing some DLL. First libcrypto, then
+# api-ms-win-core-delayload-l1-1-0.dll
+
+echo "Downloading VcXsrv..."
+URL="https://downloads.sourceforge.net/project/vcxsrv/vcxsrv/1.20.8.1/vcxsrv-64.1.20.8.1.installer.exe"
+FILENAME="$HOME/vcxsrv-64.1.20.8.1.installer.exe"
+start_download "$FILENAME" "$URL" > /dev/null
+
+echo "Installing VcXsrv..."
+cmd //c "$HOME/vcxsrv-64.1.20.8.1.installer.exe /S"
+
+echo "Starting VcXsrv..."
+# note: do not try messing with quotes and escape characters here. you will
+# regret it and the time wasted cannot be regained
+cmd //c "C:\\PROGRA~1\\VcXsrv\vcxsrv.exe -wgl -multiwindow"
+
+echo "Waiting for WSL install to finish..."
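+# wait_for_pid (sourced from utils.sh above) blocks until the backgrounded PowerShell WSL install exits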
+NAME="installing WSL" +wait_for_pid "${WSL_PID}" "$NAME" || script_fail \ No newline at end of file diff --git a/seg3d.py b/seg3d.py deleted file mode 100644 index 9a2549db9..000000000 --- a/seg3d.py +++ /dev/null @@ -1,91 +0,0 @@ -# from pt3d of section, calculate the 3d location of segment centers, or -# a list of 3d points for each segment where the first and last 3-d points -# are the segment boundaries -# written by mh - -from neuron import h - -def segment_centers_3d (sec): - ''' return list of x, y, z tuples for centers of non-zero area segments ''' - # vector of arc values of segment centers - arcseg = h.Vector([seg.x for seg in sec]) - return interpolate(pt3d(sec), arcseg) - -def segment_points_3d (sec): - ''' return list of nseg length list of x,y,z tuples for segments - first and last xyz tuple are the ends of the segment ''' - # vector of arc values of segment edges (nseg+1) - arcseg = h.Vector(sec.nseg + 1).indgen().div(sec.nseg) - n = int(sec.n3d()) - axyz = pt3d(sec) - xyz = interpolate(axyz, arcseg) - # organize into [[(x,y,z)]*nseg] but remove end identical points in [(x,y,z)] - ret = [] - j = 0 # index into axyz vectors; - for iseg in range(int(sec.nseg)): - segitem = [] - a1 = arcseg[iseg] # proximal edge - a2 = arcseg[iseg + 1] # distal edge - segitem.append(xyz[iseg]) - while j < n and axyz[0][j] < a2: # insert the ones > a1 and less than a2 - a, pt = axyz[0][j], (axyz[1][j], axyz[2][j], axyz[3][j]) - if a > a1: - segitem.append(pt) - j += 1 - segitem.append(xyz[iseg + 1]) - ret.append(segitem) - return ret - -def pt3d (sec): - ''' return list of arc, x, y, z vectors from pt3d info of sec ''' - n = int(sec.n3d()) - #list of 3d point vectors - sec.push() - axyz = [h.Vector([f(i) for i in range(n)]) for f in [h.arc3d, h.x3d, h.y3d, h.z3d]] - h.pop_section() - axyz[0].div(sec.L) - return axyz - -def interpolate (axyz, arcvec): - ''' return list of x, y, z tuples at the arcvec locations ''' - #interpolate onto arcvec - xyz = [v.c().interpolate(arcvec, axyz[0]) for v in axyz[1:]] - return [(xyz[0][i],xyz[1][i],xyz[2][i]) for i in range(len(xyz[0]))] - -def drawsec (sec): - # draw original 3d points (x,y values). Not using Shape because of origin issues - g = h.Graph(0) - g.view(2) - n = int(sec.n3d()) - g.beginline(1, 4) - for i in range(n): - g.line(sec.x3d(i), sec.y3d(i)) - return g - -def test_segment_centers (sec, g): - xyz = segment_centers_3d(sec) - for x in xyz: - #print (x) - g.mark(x[0], x[1], 'O', 10, 2, 1) - g.exec_menu("View = plot") - -def test_segment_points (sec, g): - xyzsegs = segment_points_3d(sec) - for iseg, xyzseg in enumerate(xyzsegs): - color = 4 + iseg%2 #blue, green - g.beginline(color, 1) - for x,y,z in xyzseg: - g.line(x, y) - -if __name__ == '__main__': - # load pyramidal of neurondemo - from neuron import gui - h.load_file(h.neuronhome() + '/demo/pyramid.nrn') - s = h.dendrite_1[8] #proximal apical - s.nseg = 5 - #h.load_file(h.neuronhome() + '/demo/pyramid.ses') - - g = drawsec(s) - test_segment_points(s, g) - test_segment_centers(s, g) - diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..59d909c2f --- /dev/null +++ b/setup.py @@ -0,0 +1,41 @@ +#! 
/usr/bin/env python + +from setuptools import setup, find_packages + +descr = """Human Neocortical Neurosolver""" + +DISTNAME = 'hnn' +DESCRIPTION = descr +MAINTAINER = 'Blake Caldwell' +MAINTAINER_EMAIL = 'blake_caldwell@brown.edu' +URL = '' +LICENSE = 'Brown CS License' +DOWNLOAD_URL = 'http://github.com/jonescompneurolab/hnn' +VERSION = '1.4.0' + +if __name__ == "__main__": + setup(name=DISTNAME, + maintainer=MAINTAINER, + maintainer_email=MAINTAINER_EMAIL, + description=DESCRIPTION, + license=LICENSE, + url=URL, + version=VERSION, + download_url=DOWNLOAD_URL, + long_description=open('README.md').read(), + classifiers=[ + 'Intended Audience :: Science/Research', + 'License :: OSI Approved', + 'Programming Language :: Python', + 'Topic :: Scientific/Engineering', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX', + 'Operating System :: Unix', + 'Operating System :: MacOS', + ], + platforms='any', + packages=find_packages(), + package_data={'hnn': + ['../param/*.param']}, + install_requires=['hnn-core'] + ) diff --git a/simdat.py b/simdat.py deleted file mode 100644 index c8fccdd3f..000000000 --- a/simdat.py +++ /dev/null @@ -1,760 +0,0 @@ -import os -from PyQt5.QtWidgets import QMenu, QSizePolicy -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.figure import Figure -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -import matplotlib.gridspec as gridspec -import numpy as np -from math import ceil -from conf import dconf -import conf -import spikefn -from paramrw import usingOngoingInputs, usingEvokedInputs, usingPoissonInputs, usingTonicInputs, find_param, quickgetprm, countEvokedInputs, ExpParams -from scipy import signal -from gutils import getscreengeom - -# dconf has settings from hnn.cfg -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: plt.rcParams['font.size'] = dconf['fontsize'] = 10 - -debug = dconf['debug'] - -ddat = {} # current simulation data -dfile = {} # data file information for current simulation - -lsimdat = [] # list of simulation data -lsimidx = 0 # index into lsimdat - -initial_ddat = {} -optdat = [] # list of optimization data - -def updatelsimdat(paramf,dpl): - # update lsimdat with paramf and dipole dpl - # but if the specific sim already run put dipole at that location in list - global lsimdat,lsimidx - # while len(lsimdat)>0 and lsimidx!=len(lsimdat)-1: lsimdat.pop() # redos popped - found = False - for i,l in enumerate(lsimdat): - if l[0] == paramf: - lsimdat[i][1] = dpl - found = True - break - if not found: lsimdat.append([paramf,dpl]) # if not found, append to end of the list - lsimidx = len(lsimdat) - 1 # current simulation index - - -def updateoptdat(paramf,dpl): - global optdat - - optdat.append([paramf,dpl]) - -def rmse (a1, a2): - # return root mean squared error between a1, a2; assumes same lengths, sampling rates - len1,len2 = len(a1),len(a2) - sz = min(len1,len2) - if debug: print('len1:',len1,'len2:',len2,'ty1:',type(a1),'ty2:',type(a2)) - return np.sqrt(((a1[0:sz] - a2[0:sz]) ** 2).mean()) - -def readdpltrials (basedir,ntrial): - # read dipole data files for individual trials - if debug: print('in readdpltrials',basedir,ntrial) - ldpl = [] - for i in range(ntrial): - fn = os.path.join(basedir,'dpl_'+str(i)+'.txt') - if not os.path.exists(fn): break - ldpl.append(readtxt(fn)) - if debug: print('loaded ', fn) - - if len(ldpl) < ntrial and ntrial > 1: - print("Warning: only read data for %d trials" % len(ldpl)) - - return ldpl 
- -def getinputfiles (paramf): - # get a dictionary of input files based on simulation parameter file paramf - global dfile,basedir - dfile = {} - basedir = os.path.join(dconf['datdir'],paramf.split(os.path.sep)[-1].split('.param')[0]) - # print('basedir:',basedir) - dfile['dpl'] = os.path.join(basedir,'dpl.txt') - dfile['spec'] = os.path.join(basedir,'rawspec.npz') - dfile['spk'] = os.path.join(basedir,'spk.txt') - dfile['outparam'] = os.path.join(basedir,'param.txt') - return dfile - -def readtxt (fn, silent=False): - contents = [] - - try: - contents = np.loadtxt(fn) - except OSError: - if not silent: - print('Warning: could not read file:', fn) - except ValueError: - if not silent: - print('Warning: error reading data from:', fn) - - return contents - -def updatedat (paramf): - # update data dictionary (ddat) from the param file - - global basedir - if debug: print('paramf:',paramf) - getinputfiles(paramf) - - for k in ['dpl','spk']: - if k in ddat: - del ddat[k] - silent = not os.path.exists(basedir) - ddat[k] = readtxt(dfile[k], silent) - if len(ddat[k]) == 0: - del ddat[k] - - if not 'dpl' in ddat or not 'spk' in ddat: - raise ValueError - - ddat['dpltrials'] = readdpltrials(basedir,quickgetprm(paramf,'N_trials',int)) - - if os.path.isfile(dfile['spec']): - ddat['spec'] = np.load(dfile['spec']) - else: - ddat['spec'] = None - -def getscalefctr (paramf): - # get dipole scaling factor parameter value from paramf file - try: - xx = quickgetprm(paramf,'dipole_scalefctr',float) - if type(xx) == float: return xx - except: - pass - if 'dipole_scalefctr' in dconf: - return dconf['dipole_scalefctr'] - return 30e3 - -def drawraster (): - # draw raster to standalone matplotlib figure - for debugging (not used in main HNN GUI) - if 'spk' in ddat: - # print('spk shape:',ddat['spk'].shape) - plt.ion() - plt.figure() - for pair in ddat['spk']: - plt.plot([pair[0]],[pair[1]],'ko',markersize=10) - plt.xlabel('Time (ms)',fontsize=dconf['fontsize']); plt.ylabel('ID',fontsize=dconf['fontsize']) - -def calcerr (ddat, tstop, tstart=0.0): - # calculates RMSE error from ddat dictionary - NSig = errtot = 0.0; lerr = [] - ddat['errtot']=None; ddat['lerr']=None - for fn,dat in ddat['dextdata'].items(): - shp = dat.shape - - exp_times = dat[:,0] - sim_times = ddat['dpl'][:,0] - - # do tstart and tstop fall within both datasets? - # if not, use the closest data point as the new tstop/tstart - for tseries in [exp_times, sim_times]: - if tstart < tseries[0]: - tstart = tseries[0] - if tstop > tseries[-1]: - tstop = tseries[-1] - - # make sure start and end times are valid for both dipoles - exp_start_index = (np.abs(exp_times - tstart)).argmin() - exp_end_index = (np.abs(exp_times - tstop)).argmin() - exp_length = exp_end_index - exp_start_index - - sim_start_index = (np.abs(sim_times - tstart)).argmin() - sim_end_index = (np.abs(sim_times - tstop)).argmin() - sim_length = sim_end_index - sim_start_index - - for c in range(1,shp[1],1): - dpl1 = ddat['dpl'][sim_start_index:sim_end_index,1] - dpl2 = dat[exp_start_index:exp_end_index,c] - - if (sim_length > exp_length): - # downsample simulation timeseries to match exp data - dpl1 = signal.resample(dpl1, exp_length) - elif (sim_length < exp_length): - # downsample exp timeseries to match simulation data - dpl2 = signal.resample(dpl2, sim_length) - err0 = np.sqrt(((dpl1 - dpl2) ** 2).mean()) - lerr.append(err0) - errtot += err0 - #print('RMSE: ',err0) - NSig += 1 - if not NSig == 0.0: - errtot /= NSig - #print('Avg. 
RMSE:' + str(round(errtot,2))) - ddat['errtot'] = errtot - ddat['lerr'] = lerr - return lerr, errtot - -def weighted_rmse(ddat, tstop, weights, tstart=0.0): - from numpy import sqrt - from scipy import signal - - # calculates RMSE error from ddat dictionary - NSig = errtot = 0.0; lerr = [] - ddat['werrtot']=None; ddat['lerr']=None - for fn,dat in ddat['dextdata'].items(): - shp = dat.shape - exp_times = dat[:,0] - sim_times = ddat['dpl'][:,0] - - # do tstart and tstop fall within both datasets? - # if not, use the closest data point as the new tstop/tstart - for tseries in [exp_times, sim_times]: - if tstart < tseries[0]: - tstart = tseries[0] - if tstop > tseries[-1]: - tstop = tseries[-1] - - # make sure start and end times are valid for both dipoles - exp_start_index = (np.abs(exp_times - tstart)).argmin() - exp_end_index = (np.abs(exp_times - tstop)).argmin() - exp_length = exp_end_index - exp_start_index - - sim_start_index = (np.abs(sim_times - tstart)).argmin() - sim_end_index = (np.abs(sim_times - tstop)).argmin() - sim_length = sim_end_index - sim_start_index - - weight = weights[sim_start_index:sim_end_index] - - for c in range(1,shp[1],1): - dpl1 = ddat['dpl'][sim_start_index:sim_end_index,1] - dpl2 = dat[exp_start_index:exp_end_index,c] - - if (sim_length > exp_length): - # downsample simulation timeseries to match exp data - dpl1 = signal.resample(dpl1, exp_length) - weight = signal.resample(weight, exp_length) - indices = np.where(weight < 1e-4) - weight[indices] = 0 - elif (sim_length < exp_length): - # downsample exp timeseries to match simulation data - dpl2 = signal.resample(dpl2, sim_length) - - err0 = np.sqrt((weight * ((dpl1 - dpl2) ** 2)).sum()/weight.sum()) - lerr.append(err0) - errtot += err0 - #print('RMSE: ',err0) - NSig += 1 - errtot /= NSig - #print('Avg. 
RMSE:' + str(round(errtot,2))) - ddat['werrtot'] = errtot - ddat['wlerr'] = lerr - return lerr, errtot - - -class SIMCanvas (FigureCanvas): - # matplotlib/pyqt-compatible canvas for drawing simulation & external data - # based on https://pythonspot.com/en/pyqt5-matplotlib/ - - def __init__ (self, paramf, parent=None, width=5, height=4, dpi=40, optMode=False, title='Simulation Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - - self.title = title - self.lextdatobj = [] # external data object - self.clridx = 5 # index for next color for drawing external data - self.lpatch = [mpatches.Patch(color='black', label='Sim.')] # legend for dipole signals - self.setParent(parent) - self.gui = parent - FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.initaxes() - self.G = gridspec.GridSpec(10,1) - - global initial_ddat, optdat - self.optMode = optMode - if not optMode: - initial_ddat = {} - optdat = [] - self.plot() - - def initaxes (self): - # initialize the axes - self.axdist = self.axprox = self.axdipole = self.axspec = self.axpois = None - - def plotinputhist (self, xl, dinty): - """ plot input histograms - xl = x axis limits - dinty = dict of input types used, determines how many/which axes created/displayed - """ - - extinputs = None - plot_distribs = False - - sim_tstop = quickgetprm(self.paramf,'tstop',float) - sim_dt = quickgetprm(self.paramf,'dt',float) - num_step = ceil(sim_tstop / sim_dt) + 1 - times = np.linspace(0, sim_tstop, num_step) - - try: - extinputs = spikefn.ExtInputs(dfile['spk'], dfile['outparam']) - extinputs.add_delay_times() - dinput = extinputs.inputs - except ValueError: - dinput = self.getInputDistrib() - plot_distribs = True - - if len(dinput['dist']) <= 0 and len(dinput['prox']) <= 0 and \ - len(dinput['evdist']) <= 0 and len(dinput['evprox']) <= 0 and \ - len(dinput['pois']) <= 0: - if debug: print('all hists 0!') - return False - - self.hist=hist={x:None for x in ['feed_dist','feed_prox','feed_evdist','feed_evprox','feed_pois']} - - hasPois = len(dinput['pois']) > 0 and dinty['Poisson'] # this ensures synaptic weight > 0 - - gRow = 0 - self.axdist = self.axprox = self.axpois = None # axis objects - - # check poisson inputs, create subplot - if hasPois: - self.axpois = self.figure.add_subplot(self.G[gRow,0]) - gRow += 1 - - # check distal inputs, create subplot - if (len(dinput['dist']) > 0 and dinty['OngoingDist']) or \ - (len(dinput['evdist']) > 0 and dinty['EvokedDist']): - self.axdist = self.figure.add_subplot(self.G[gRow,0]) - gRow+=1 - - # check proximal inputs, create subplot - if (len(dinput['prox']) > 0 and dinty['OngoingProx']) or \ - (len(dinput['evprox']) > 0 and dinty['EvokedProx']): - self.axprox = self.figure.add_subplot(self.G[gRow,0]) - gRow+=1 - - - # check input types provided in simulation - if extinputs is not None and self.hassimdata(): # only valid param.txt file after sim was run - if debug: - print(len(dinput['dist']),len(dinput['prox']),len(dinput['evdist']),len(dinput['evprox']),len(dinput['pois'])) - - if hasPois: # any Poisson inputs? 
- extinputs.plot_hist(self.axpois,'pois',times,'auto',xl,color='k',hty='step',lw=self.gui.linewidth+1) - - if len(dinput['dist']) > 0 and dinty['OngoingDist']: # dinty condition ensures synaptic weight > 0 - extinputs.plot_hist(self.axdist,'dist',times,'auto',xl,color='g',lw=self.gui.linewidth+1) - - if len(dinput['prox']) > 0 and dinty['OngoingProx']: # dinty condition ensures synaptic weight > 0 - extinputs.plot_hist(self.axprox,'prox',times,'auto',xl,color='r',lw=self.gui.linewidth+1) - - if len(dinput['evdist']) > 0 and dinty['EvokedDist']: # dinty condition ensures synaptic weight > 0 - extinputs.plot_hist(self.axdist,'evdist',times,'auto',xl,color='g',hty='step',lw=self.gui.linewidth+1) - - if len(dinput['evprox']) > 0 and dinty['EvokedProx']: # dinty condition ensures synaptic weight > 0 - extinputs.plot_hist(self.axprox,'evprox',times,'auto',xl,color='r',hty='step',lw=self.gui.linewidth+1) - elif plot_distribs: - if len(dinput['evprox']) > 0 and dinty['EvokedProx']: # dinty condition ensures synaptic weight > 0 - prox_tot = np.zeros(len(dinput['evprox'][0][0])) - for prox in dinput['evprox']: - prox_tot += prox[1] - plot = self.axprox.plot(dinput['evprox'][0][0],prox_tot,color='r',lw=self.gui.linewidth,label='evprox distribution') - self.axprox.set_xlim(dinput['evprox'][0][0][0],dinput['evprox'][0][0][-1]) - if len(dinput['evdist']) > 0 and dinty['EvokedDist']: # dinty condition ensures synaptic weight > 0 - dist_tot = np.zeros(len(dinput['evdist'][0][0])) - for dist in dinput['evdist']: - dist_tot += dist[1] - plot = self.axdist.plot(dinput['evdist'][0][0],dist_tot,color='g',lw=self.gui.linewidth,label='evdist distribution') - self.axprox.set_xlim(dinput['evdist'][0][0][0],dinput['evdist'][0][0][-1]) - - ymax = 0 - for ax in [self.axpois, self.axdist, self.axprox]: - if not ax is None: - if ax.get_ylim()[1] > ymax: - ymax = ax.get_ylim()[1] - - if ymax == 0: - if debug: print('all hists None!') - return False - else: - for ax in [self.axpois, self.axdist, self.axprox]: - if not ax is None: - ax.set_ylim(0,ymax) - if self.axdist: - self.axdist.invert_yaxis() - for ax in [self.axpois,self.axdist,self.axprox]: - if ax: - ax.set_xlim(xl) - ax.legend(loc=1) # legend in upper right - return True,gRow - - def clearaxes (self): - # clear the figures axes - for ax in self.figure.get_axes(): - if ax: - ax.cla() - - - def getNTrials (self): - # get the number of trials - N_trials = 1 - try: - xx = quickgetprm(self.paramf,'N_trials',int) - if type(xx) == int: N_trials = xx - except: - pass - return N_trials - - def getNPyr (self): - # get the number of pyramidal neurons used in the simulation - try: - x = quickgetprm(self.paramf,'N_pyr_x',int) - y = quickgetprm(self.paramf,'N_pyr_y',int) - if type(x)==int and type(y)==int: - return int(x * y * 2) - except: - return 0 - - def getInputDistrib (self): - import scipy.stats as stats - - dinput = {'evprox': [], 'evdist': [], 'prox': [], 'dist': [], 'pois': []} - try: - sim_tstop = quickgetprm(self.paramf,'tstop',float) - sim_dt = quickgetprm(self.paramf,'dt',float) - except FileNotFoundError: - return dinput - - num_step = ceil(sim_tstop / sim_dt) + 1 - times = np.linspace(0, sim_tstop, num_step) - ltprox, ltdist = self.getEVInputTimes() - for prox in ltprox: - pdf = stats.norm.pdf(times, prox[0], prox[1]) - dinput['evprox'].append((times,pdf)) - for dist in ltdist: - pdf = stats.norm.pdf(times, dist[0], dist[1]) - dinput['evdist'].append((times,pdf)) - return dinput - - def getEVInputTimes (self): - # get the evoked input times - nprox, ndist = 
countEvokedInputs(self.paramf) - ltprox, ltdist = [], [] - for i in range(nprox): - input_mu = quickgetprm(self.paramf,'t_evprox_' + str(i+1), float) - input_sigma = quickgetprm(self.paramf,'sigma_t_evprox_' + str(i+1), float) - ltprox.append((input_mu, input_sigma)) - for i in range(ndist): - input_mu = quickgetprm(self.paramf,'t_evdist_' + str(i+1), float) - input_sigma = quickgetprm(self.paramf,'sigma_t_evdist_' + str(i+1), float) - ltdist.append((input_mu, input_sigma)) - return ltprox, ltdist - - def drawEVInputTimes (self, ax, yl, h=0.1, hw=15, hl=15): - # draw the evoked input times using arrows - ltprox, ltdist = self.getEVInputTimes() - yrange = abs(yl[1] - yl[0]) - #print('drawEVInputTimes:',yl,yrange,h,hw,hl,h*yrange,-h*yrange,yl[0]+h*yrange,yl[1]-h*yrange) - for tt in ltprox: ax.arrow(tt[0],yl[0],0,h*yrange,fc='r',ec='r', head_width=hw,head_length=hl)#head_length=w,head_width=1.)#w/4)#length_includes_head=True, - for tt in ltdist: ax.arrow(tt[0],yl[1],0,-h*yrange,fc='g',ec='g',head_width=hw,head_length=hl)#head_length=w,head_width=1.)#w/4) - - def getInputs (self): - """ get a dictionary of input types used in simulation - with distal/proximal specificity for evoked,ongoing inputs - """ - - dinty = {'Evoked':False,'Ongoing':False,'Poisson':False,'Tonic':False,'EvokedDist':False,\ - 'EvokedProx':False,'OngoingDist':False,'OngoingProx':False} - - try: - dinty['Evoked'] = usingEvokedInputs(self.paramf) - dinty['EvokedDist'] = usingEvokedInputs(self.paramf, lsuffty = ['_evdist_']) - dinty['EvokedProx'] = usingEvokedInputs(self.paramf, lsuffty = ['_evprox_']) - dinty['Ongoing'] = usingOngoingInputs(self.paramf) - dinty['OngoingDist'] = usingOngoingInputs(self.paramf, lty = ['_dist']) - dinty['OngoingProx'] = usingOngoingInputs(self.paramf, lty = ['_prox']) - dinty['Poisson'] = usingPoissonInputs(self.paramf) - dinty['Tonic'] = usingTonicInputs(self.paramf) - except FileNotFoundError: - pass - - return dinty - - def getnextcolor (self): - # get next color for external data (colors selected in order) - self.clridx += 5 - if self.clridx > 100: self.clridx = 5 - return self.clridx - - def plotextdat (self, recalcErr=True): - if not 'dextdata' in ddat or len(ddat['dextdata']) == 0: - return - - lerr = None - errtot = None - initial_err = None - # plot 'external' data (e.g. from experiment/other simulation) - hassimdata = self.hassimdata() # has the simulation been run yet? - if hassimdata: - if recalcErr: - calcerr(ddat, ddat['dpl'][-1,0]) # recalculate/save the error? 
- - try: - lerr, errtot = ddat['lerr'], ddat['errtot'] - - if self.optMode: - initial_err = initial_ddat['errtot'] - except KeyError: - pass - - - if self.axdipole is None: - self.axdipole = self.figure.add_subplot(self.G[0:-1,0]) # dipole - xl = (0.0,1.0) - yl = (-0.001,0.001) - else: - xl = self.axdipole.get_xlim() - yl = self.axdipole.get_ylim() - - cmap=plt.get_cmap('nipy_spectral') - csm = plt.cm.ScalarMappable(cmap=cmap); - csm.set_clim((0,100)) - - self.clearlextdatobj() # clear annotation objects - - ddx = 0 - for fn,dat in ddat['dextdata'].items(): - shp = dat.shape - clr = csm.to_rgba(self.getnextcolor()) - c = min(shp[1],1) - self.lextdatobj.append(self.axdipole.plot(dat[:,0],dat[:,c],color=clr,linewidth=self.gui.linewidth+1)) - xl = ((min(xl[0],min(dat[:,0]))),(max(xl[1],max(dat[:,0])))) - yl = ((min(yl[0],min(dat[:,c]))),(max(yl[1],max(dat[:,c])))) - fx = int(shp[0] * float(c) / shp[1]) - if lerr: - tx,ty=dat[fx,0],dat[fx,c] - txt='RMSE: %.2f' % round(lerr[ddx],2) - if not self.optMode: - self.lextdatobj.append(self.axdipole.annotate(txt,xy=(dat[0,0],dat[0,c]),xytext=(tx,ty),color=clr,fontweight='bold')) - self.lpatch.append(mpatches.Patch(color=clr, label=fn.split(os.path.sep)[-1].split('.txt')[0])) - ddx+=1 - - self.axdipole.set_xlim(xl) - self.axdipole.set_ylim(yl) - - if self.lextdatobj and self.lpatch: - self.lextdatobj.append(self.axdipole.legend(handles=self.lpatch, loc=2)) - - if errtot: - tx,ty=0,0 - if self.optMode and initial_err: - clr = 'black' - txt='RMSE: %.2f' % round(initial_err,2) - self.annot_avg = self.axdipole.annotate(txt,xy=(0,0),xytext=(0.005,0.005),textcoords='axes fraction',color=clr,fontweight='bold') - clr = 'gray' - txt='RMSE: %.2f' % round(errtot,2) - self.annot_avg = self.axdipole.annotate(txt,xy=(0,0),xytext=(0.86,0.005),textcoords='axes fraction',color=clr,fontweight='bold') - else: - clr = 'black' - txt='Avg. 
RMSE: %.2f' % round(errtot,2) - self.annot_avg = self.axdipole.annotate(txt,xy=(0,0),xytext=(0.005,0.005),textcoords='axes fraction',color=clr,fontweight='bold') - - if not hassimdata: # need axis labels - self.axdipole.set_xlabel('Time (ms)',fontsize=dconf['fontsize']) - self.axdipole.set_ylabel('Dipole (nAm)',fontsize=dconf['fontsize']) - myxl = self.axdipole.get_xlim() - if myxl[0] < 0.0: self.axdipole.set_xlim((0.0,myxl[1]+myxl[0])) - - def hassimdata (self): - # check if any simulation data available in ddat dictionary - return 'dpl' in ddat - - def hasinitoptdata (self): - # check if any simulation data available in ddat dictionary - return 'dpl' in initial_ddat - - def clearlextdatobj (self): - # clear list of external data objects - for o in self.lextdatobj: - try: - o.set_visible(False) - except: - o[0].set_visible(False) - del self.lextdatobj - self.lextdatobj = [] # reset list of external data objects - self.lpatch = [] # reset legend - self.clridx = 5 # reset index for next color for drawing external data - - if self.optMode: - self.lpatch.append(mpatches.Patch(color='grey', label='Optimization')) - self.lpatch.append(mpatches.Patch(color='black', label='Initial')) - elif self.hassimdata(): - self.lpatch.append(mpatches.Patch(color='black', label='Simulation')) - if hasattr(self,'annot_avg'): - self.annot_avg.set_visible(False) - del self.annot_avg - - def plotsimdat (self): - # plot the simulation data - - self.gRow = 0 - bottom = 0.0 - - only_create_axes = False - if not os.path.isfile(self.paramf): - only_create_axes = True - DrawSpec = False - xl = (0.0, 1.0) - else: - # setup the figure axis for drawing the dipole signal - dinty = self.getInputs() - - # try loading data. ignore failures - try: - updatedat(self.paramf) - loaded_dat = True - except ValueError: - loaded_dat = False - pass - - xl = (0.0, quickgetprm(self.paramf,'tstop',float)) - if dinty['Ongoing'] or dinty['Evoked'] or dinty['Poisson']: - xo = self.plotinputhist(xl, dinty) - if xo: - self.gRow = xo[1] - - # whether to draw the specgram - should draw if user saved it or have ongoing, poisson, or tonic inputs - DrawSpec = loaded_dat and \ - 'spec' in ddat and \ - (find_param(dfile['outparam'],'save_spec_data') or dinty['Ongoing'] or dinty['Poisson'] or dinty['Tonic']) - - if DrawSpec: # dipole axis takes fewer rows if also drawing specgram - self.axdipole = self.figure.add_subplot(self.G[self.gRow:5,0]) # dipole - bottom = 0.08 - else: - self.axdipole = self.figure.add_subplot(self.G[self.gRow:-1,0]) # dipole - - yl = (-0.001,0.001) - self.axdipole.set_ylim(yl) - self.axdipole.set_xlim(xl) - - left = 0.08 - w,h=getscreengeom() - if w < 2800: left = 0.1 - self.figure.subplots_adjust(left=left,right=0.99,bottom=bottom,top=0.99,hspace=0.1,wspace=0.1) # reduce padding - - if only_create_axes: - return - - try: - updatedat(self.paramf) - except ValueError: - if 'dpl' not in ddat: - # failed to load dipole data, nothing more to plot - return - - ds = None - xl = (0,ddat['dpl'][-1,0]) - dt = ddat['dpl'][1,0] - ddat['dpl'][0,0] - - # get spectrogram if it exists, then adjust axis limits but only if drawing spectrogram - if DrawSpec: - if ddat['spec'] is not None: - ds = ddat['spec'] # spectrogram - xl = (ds['time'][0],ds['time'][-1]) # use specgram time limits - else: - DrawSpec = False - - sampr = 1e3/dt # dipole sampling rate - sidx, eidx = int(sampr*xl[0]/1e3), int(sampr*xl[1]/1e3) # use these indices to find dipole min,max - - N_trials = self.getNTrials() - if debug: print('simdat: N_trials:',N_trials) - - yl = 
[0,0] - yl[0] = min(yl[0],np.amin(ddat['dpl'][sidx:eidx,1])) - yl[1] = max(yl[1],np.amax(ddat['dpl'][sidx:eidx,1])) - - if not self.optMode: - # skip for optimization - for lsim in lsimdat: # plot average dipoles from prior simulations - olddpl = lsim[1] - if debug: print('olddpl has shape ',olddpl.shape,len(olddpl[:,0]),len(olddpl[:,1])) - self.axdipole.plot(olddpl[:,0],olddpl[:,1],'--',color='black',linewidth=self.gui.linewidth) - - if N_trials>1 and dconf['drawindivdpl'] and len(ddat['dpltrials']) > 0: # plot dipoles from individual trials - for dpltrial in ddat['dpltrials']: - self.axdipole.plot(dpltrial[:,0],dpltrial[:,1],color='gray',linewidth=self.gui.linewidth) - yl[0] = min(yl[0],dpltrial[sidx:eidx,1].min()) - yl[1] = max(yl[1],dpltrial[sidx:eidx,1].max()) - - if conf.dconf['drawavgdpl'] or N_trials <= 1: - # this is the average dipole (across trials) - # it's also the ONLY dipole when running a single trial - self.axdipole.plot(ddat['dpl'][:,0],ddat['dpl'][:,1],'k',linewidth=self.gui.linewidth+1) - yl[0] = min(yl[0],ddat['dpl'][sidx:eidx,1].min()) - yl[1] = max(yl[1],ddat['dpl'][sidx:eidx,1].max()) - else: - for idx, opt in enumerate(optdat): - optdpl = opt[1] - if idx == len(optdat) - 1: - # only show the last optimization - self.axdipole.plot(optdpl[:,0],optdpl[:,1],'k',color='gray',linewidth=self.gui.linewidth+1) - yl[0] = min(yl[0],optdpl[sidx:eidx,1].min()) - yl[1] = max(yl[1],optdpl[sidx:eidx,1].max()) - - if self.hasinitoptdata(): - # show initial dipole in dotted black line - self.axdipole.plot(initial_ddat['dpl'][:,0],initial_ddat['dpl'][:,1],'--',color='black',linewidth=self.gui.linewidth) - yl[0] = min(yl[0],initial_ddat['dpl'][sidx:eidx,1].min()) - yl[1] = max(yl[1],initial_ddat['dpl'][sidx:eidx,1].max()) - - scalefctr = getscalefctr(self.paramf) - NEstPyr = int(self.getNPyr() * scalefctr) - - if NEstPyr > 0: - self.axdipole.set_ylabel(r'Dipole (nAm $\times$ '+str(scalefctr)+')\nFrom Estimated '+str(NEstPyr)+' Cells',fontsize=dconf['fontsize']) - else: - self.axdipole.set_ylabel(r'Dipole (nAm $\times$ '+str(scalefctr)+')\n',fontsize=dconf['fontsize']) - self.axdipole.set_xlim(xl); self.axdipole.set_ylim(yl) - - if DrawSpec: # - if debug: print('ylim is : ', np.amin(ddat['dpl'][sidx:eidx,1]),np.amax(ddat['dpl'][sidx:eidx,1])) - - p_exp = ExpParams(self.paramf, debug=debug) - if len(p_exp.expmt_groups) > 0: - expmt_group = p_exp.expmt_groups[0] - else: - expmt_group = None - p = p_exp.return_pdict(expmt_group, 0) - - gRow = 6 - self.axspec = self.figure.add_subplot(self.G[gRow:10,0]); # specgram - cax = self.axspec.imshow(ds['TFR'],extent=(ds['time'][0],ds['time'][-1],ds['freq'][-1],ds['freq'][0]),aspect='auto',origin='upper',cmap=plt.get_cmap(p['spec_cmap'])) - self.axspec.set_ylabel('Frequency (Hz)',fontsize=dconf['fontsize']) - self.axspec.set_xlabel('Time (ms)',fontsize=dconf['fontsize']) - self.axspec.set_xlim(xl) - self.axspec.set_ylim(ds['freq'][-1],ds['freq'][0]) - cbaxes = self.figure.add_axes([0.6, 0.49, 0.3, 0.005]) - cb = plt.colorbar(cax, cax = cbaxes, orientation='horizontal') # horizontal to save space - else: - self.axdipole.set_xlabel('Time (ms)',fontsize=dconf['fontsize']) - - def plotarrows (self): - # run after scales have been updated - xl = self.axdipole.get_xlim() - yl = self.axdipole.get_ylim() - - dinty = self.getInputs() - if dinty['Evoked']: - self.drawEVInputTimes(self.axdipole,yl,0.1,(xl[1]-xl[0])*.02,(yl[1]-yl[0])*.02)#15.0) - - def plot (self, recalcErr=True): - self.clearaxes() - plt.close(self.figure) - self.figure.clf() - self.axdipole 
= None - - self.plotsimdat() # creates self.axdipole - self.plotextdat(recalcErr) - self.plotarrows() - - self.draw() diff --git a/specfn.py b/specfn.py deleted file mode 100644 index 9582b694a..000000000 --- a/specfn.py +++ /dev/null @@ -1,1130 +0,0 @@ -# specfn.py - Average time-frequency energy representation using Morlet wavelet method -# -# v 1.10.2-py35 -# rev 2017-02-21 (SL: fixed an issue with indexing) -# last major: (SL: more comments on the units of Morlet Spec) - -import os -import sys -import numpy as np -import scipy.signal as sps -import matplotlib.pyplot as plt -import paramrw -import fileio as fio -import multiprocessing as mp -from neuron import h as nrn - -import fileio as fio -import currentfn -import dipolefn -import spikefn -import axes_create as ac -from conf import dconf - -# MorletSpec class based on a time vec tvec and a time series vec tsvec -class MorletSpec(): - def __init__(self, tvec, tsvec, fparam, f_max=None, p_dict=None, tmin = 50.0, f_min = 1.): - # Save variable portion of fdata_spec as identifying attribute - # self.name = fdata_spec - - # Import dipole data and remove extra dimensions from signal array. - self.tvec = tvec - self.tsvec = tsvec - - # function is called this way because paramrw.read() returns 2 outputs - if p_dict is None: - self.p_dict = paramrw.read(fparam)[1] - else: - self.p_dict = p_dict - - self.f_min = f_min - - # maximum frequency of analysis - # Add 1 to ensure analysis is inclusive of maximum frequency - if not f_max: - self.f_max = self.p_dict['f_max_spec'] + 1 - else: - self.f_max = f_max + 1 - - # cutoff time in ms - self.tmin = tmin - - # truncate these vectors appropriately based on tmin - if self.p_dict['tstop'] > self.tmin: - # must be done in this order! timeseries first! - self.tsvec = self.tsvec[self.tvec >= self.tmin] - self.tvec = self.tvec[self.tvec >= self.tmin] - - # Check that tstop is greater than tmin - if self.p_dict['tstop'] > self.tmin: - # Array of frequencies over which to sort - self.f = np.arange(self.f_min, self.f_max) - - # Number of cycles in wavelet (>5 advisable) - self.width = 7. - - # Calculate sampling frequency - self.fs = 1000. / self.p_dict['dt'] - - # Generate Spec data - self.TFR = self.__traces2TFR() - - # Add time vector as first row of TFR data - # self.TFR = np.vstack([self.timevec, self.TFR]) - - else: - print("tstop not greater than %4.2f ms. Skipping wavelet analysis." % self.tmin) - - # externally callable save function - def save(self, fdata_spec): - write(fdata_spec, self.timevec, self.freqvec, self.TFR) - - # plots spec to axis - def plot_to_ax(self, ax_spec, dt): - # pc = ax.imshow(self.TFR, extent=[xmin, xmax, self.freqvec[-1], self.freqvec[0]], aspect='auto', origin='upper') - pc = ax_spec.imshow(self.TFR, aspect='auto', origin='upper', cmap=plt.get_cmap(self.p_dict['spec_cmap'])) - - return pc - - # get time and freq of max spectral power - def max(self): - print("Warning: you are using max() in MorletSpec(). It should be changed from == to np.isclose()") - max_spec = self.TFR.max() - - t_mask = (self.TFR==max_spec).sum(axis=0) - t_at_max = self.tvec[t_mask == 1] - - f_mask = (self.TFR==max_spec).sum(axis=1) - f_at_max = self.f[f_mask == 1] - - return np.array((max_spec, t_at_max, f_at_max)) - - # also creates self.timevec - def __traces2TFR(self): - self.S_trans = self.tsvec.transpose() - # self.S_trans = self.S.transpose() - - # range should probably be 0 to len(self.S_trans) - # shift tvec to reflect change - # this is in ms - self.t = 1000. 
* np.arange(1, len(self.S_trans)+1) / self.fs + self.tmin - self.p_dict['dt'] - - # preallocation - B = np.zeros((len(self.f), len(self.S_trans))) - - if self.S_trans.ndim == 1: - for j in range(0, len(self.f)): - s = sps.detrend(self.S_trans[:]) - - # += is used here because these were zeros and now it's adding the solution - B[j, :] += self.__energyvec(self.f[j], s) - # B[j,:] = B[j,:] + self.__energyvec(self.freqvec[j], self.__lnr50(s)) - - return B - - # this code doesn't return anything presently ... - else: - for i in range(0, self.S_trans.shape[0]): - for j in range(0, len(self.f)): - s = sps.detrend(self.S_trans[i,:]) - B[j,:] += self.__energyvec(self.f[j], s) - # B[j,:] = B[j,:] + self.__energyvec(self.freqvec[j], self.__lnr50(s)) - - # calculate the morlet wavelet for central frequency f - def __morlet(self, f, t): - """ Morlet's wavelet for frequency f and time t - Wavelet normalized so total energy is 1 - f: specific frequency - y: final units are 1/s - """ - # sf in Hz - sf = f / self.width - - # st in s - st = 1. / (2. * np.pi * sf) - - # A in 1 / s - A = 1. / (st * np.sqrt(2.*np.pi)) - - # units: 1/s * (exp (s**2 / s**2)) * exp( 1/ s * s) - y = A * np.exp(-t**2. / (2. * st**2.)) * np.exp(1.j * 2. * np.pi * f * t) - - return y - - # notch filter for UK - def __lnr50(self, s): - """ - presently unused - Line noise reduction (50 Hz) the amplitude and phase of the line notch is estimate. - A sinusoid with these characterisitics is then subtracted from the signal. - s: signal - """ - fNoise = 50. - tv = np.arange(0,len(s)) / self.fs - - if np.ndim(s) == 1: - Sc = np.zeros(s.shape) - Sft = self.__ft(s[:], fNoise) - Sc[:] = s[:] - abs(Sft) * np.cos(2. * np.pi * fNoise * tv - np.angle(Sft)) - - return Sc - - else: - s = s.transpose() - Sc = np.zeros(s.shape) - - for k in range(0, len(s)): - Sft = ft(s[k,:], fNoise) - Sc[k,:] = s[k,:] - abs(Sft) * np.cos(2. * np.pi * fNoise * tv - np.angle(Sft)) - - return Sc.tranpose() - - def __ft(self, s, f): - tv = np.arange(0,len(s)) / self.fs - tmp = np.exp(1.j*2. * np.pi * f * tv) - S = 2 * sum(s * tmp) / len(s) - - return S - - # Return an array containing the energy as function of time for freq f - def __energyvec(self, f, s): - """ Final units of y: signal units squared. For instance, a signal of Am would have Am^2 - The energy is calculated using Morlet's wavelets - f: frequency - s: signal - """ - dt = 1. / self.fs - sf = f / self.width - st = 1. / (2. * np.pi * sf) - - t = np.arange(-3.5*st, 3.5*st, dt) - - # calculate the morlet wavelet for this frequency - # units of m are 1/s - m = self.__morlet(f, t) - - # convolve wavelet with signal - y = sps.fftconvolve(s, m) - - # take the power ... - y = (2. * abs(y) / self.fs)**2. - i_lower = int(np.ceil(len(m) / 2.)) - i_upper = int(len(y) - np.floor(len(m) / 2.)+1) - y = y[i_lower:i_upper] - - return y - -# calculates a phase locking value between 2 time series via morlet wavelets -class PhaseLock(): - """ Based on 4Dtools (deprecated) MATLAB code - Might be a newer version in fieldtrip - """ - def __init__(self, tsarray1, tsarray2, fparam, f_max=60.): - # Save time-series arrays as self variables - # ohhhh. Do not use 1-indexed keys of a dict! - self.ts = { - 1: tsarray1, - 2: tsarray2, - } - - # Get param dict - self.p = paramrw.read(fparam)[1] - - # Set frequecies over which to sort - self.f = 1. + np.arange(0., f_max, 1.) - - # Set width of Morlet wavelet (>= 5 suggested) - self.width = 7. - - # Calculate sampling frequency - self.fs = 1000. 
/ self.p['dt'] - - self.data = self.__traces2PLS() - - def __traces2PLS(self): - # Not sure what's going on here... - # nshuffle = 200; - nshuffle = 1; - - # Construct timevec - tvec = np.arange(1, self.ts[1].shape[1]) / self.fs - - # Prellocated arrays - # Check sizes - B = np.zeros((self.f.size, self.ts[1].shape[1])) - Bstat = np.zeros((self.f.size, self.ts[1].shape[1])) - Bplf = np.zeros((self.f.size, self.ts[1].shape[1])) - - # Do the analysis - for i, freq in enumerate(self.f): - print('%i Hz' % freq) - - # Get phase of signals for given freq - # Check sizes - B1 = self.__phasevec(freq, num_ts=1) - B2 = self.__phasevec(freq, num_ts=2) - - # Potential conflict here - # Check size - B[i, :] = np.mean(B1 / B2, axis=0) - B[i, :] = abs(B[i, :]) - - # Randomly shuffle B2 - for j in range(0, nshuffle): - # Check size - idxShuffle = np.random.permutation(B2.shape[0]) - B2shuffle = B2[idxShuffle, :] - - Bshuffle = np.mean(B1 / B2shuffle, axis=0) - Bplf[i, :] += Bshuffle - - idxSign = (abs(B[i, :]) > abs(Bshuffle)) - Bstat[i, idxSign] += 1 - - # Final calculation of Bstat, Bplf - Bstat = 1. - Bstat / nshuffle - Bplf /= nshuffle - - # Store data - return { - 't': tvec, - 'f': self.f, - 'B': B, - 'Bstat': Bstat, - 'Bplf': Bplf, - } - - def __phasevec(self, f, num_ts=1): - """ should num_ts here be 0, as an index? - """ - dt = 1. / self.fs - sf = f / self.width - st = 1. / (2. * np.pi * sf) - - # create a time vector for the morlet wavelet - t = np.arange(-3.5*st, 3.5*st+dt, dt) - m = self.__morlet(f, t) - - y = np.array([]) - - for k in range(0, self.ts[num_ts].shape[0]): - if k == 0: - s = sps.detrend(self.ts[num_ts][k, :]) - y = np.array([sps.fftconvolve(s, m)]) - - else: - # convolve kth time series with morlet wavelet - # might as well let return valid length (not implemented) - y_tmp = sps.fftconvolve(self.ts[num_ts][k, :], m) - y = np.vstack((y, y_tmp)) - - # Change 0s to 1s to avoid division by 0 - # l is an index - # y is now complex, so abs(y) is the complex absolute value - l = (abs(y) == 0) - y[l] = 1. - - # normalize phase values and return 1s to zeros - y = y / abs(y) - y[l] = 0 - y = y[:, np.ceil(len(m)/2.)-1:y.shape[1]-np.floor(len(m)/2.)] - - return y - - def __morlet(self, f, t): - """ Calculate the morlet wavelet - """ - sf = f / self.width - st = 1. / (2. * np.pi * sf) - A = 1. / np.sqrt(st*np.sqrt(np.pi)) - - y = A * np.exp(-t**2./(2.*st**2.)) * np.exp(1.j*2.*np.pi*f*t) - - return y - -# functions on the aggregate spec data -class Spec(): - def __init__(self, fspec, spec_cmap='jet', dtype='dpl'): - # save dtype - self.dtype = dtype - - # save details of file - # may be better ways of doing this... - self.fspec = fspec - print('Spec: fspec:',fspec) - try: - self.expmt = fspec.split('/')[6].split('.')[0] - except: - self.expmt = '' - self.fname = 'spec.npz' # fspec.split('/')[-1].split('-spec')[0] - - self.spec_cmap = spec_cmap - - # parse data - self.__parse_f(fspec) - - # parses the specific data file - def __parse_f(self, fspec): - data_spec = np.load(fspec, allow_pickle=True) - - if self.dtype == 'dpl': - self.spec = {} - - # Try to load aggregate spec data - try: - self.spec['agg'] = { - 't': data_spec['t_agg'], - 'f': data_spec['f_agg'], - 'TFR': data_spec['TFR_agg'], - } - - except KeyError: - # Try loading aggregate spec data using old keys - try: - self.spec['agg'] = { - 't': data_spec['time'], - 'f': data_spec['freq'], - 'TFR': data_spec['TFR'], - } - except KeyError: - print("No aggregate spec data found. 
Don't use fns that require it...") - - # Try loading Layer specific data - try: - self.spec['L2'] = { - 't': data_spec['t_L2'], - 'f': data_spec['f_L2'], - 'TFR': data_spec['TFR_L2'], - } - - self.spec['L5'] = { - 't': data_spec['t_L5'], - 'f': data_spec['f_L5'], - 'TFR': data_spec['TFR_L5'], - } - - except KeyError: - print("All or some layer data is missing. Don't use fns that require it...") - - # Try loading periodigram data - try: - self.spec['pgram'] = { - 'p': data_spec['p_pgram'], - 'f': data_spec['f_pgram'], - } - - except KeyError: - try: - self.spec['pgram'] = { - 'p': data_spec['pgram_p'], - 'f': data_spec['pgram_f'], - } - except KeyError: - print("No periodigram data found. Don't use fns that require it...") - - # Try loading aggregate max spectral data - try: - self.spec['max_agg'] = { - 'p': data_spec['max_agg'][0], - 't': data_spec['max_agg'][1], - 'f': data_spec['max_agg'][2], - } - - except KeyError: - print("No aggregate max spectral data found. Don't use fns that require it...") - - elif self.dtype == 'current': - self.spec = { - 'L2': { - 't': data_spec['t_L2'], - 'f': data_spec['f_L2'], - 'TFR': data_spec['TFR_L2'], - }, - - 'L5': { - 't': data_spec['t_L5'], - 'f': data_spec['f_L5'], - 'TFR': data_spec['TFR_L5'], - } - } - - # Truncate t, f, and TFR for a specific layer over specified t and f intervals - # Be warned: MODIFIES THE CLASS INTERNALLY - def truncate(self, layer, t_interval, f_interval): - self.spec[layer] = self.truncate_ext(layer, t_interval, f_interval) - - # Truncate t, f, and TFR for a specific layer over specified t and f intervals - # Only returns truncated values. DOES NOT MODIFY THE CLASS INTERNALLY - def truncate_ext(self, layer, t_interval, f_interval): - # set f_max and f_min - if f_interval is None: - f_min = self.spec[layer]['f'][0] - f_max = self.spec[layer]['f'][-1] - - else: - f_min, f_max = f_interval - - # create an f_mask for the bounds of f, inclusive - f_mask = (self.spec[layer]['f']>=f_min) & (self.spec[layer]['f']<=f_max) - - # do the same for t - if t_interval is None: - t_min = self.spec[layer]['t'][0] - t_max = self.spec[layer]['t'][-1] - - else: - t_min, t_max = t_interval - - t_mask = (self.spec[layer]['t']>=t_min) & (self.spec[layer]['t']<=t_max) - - # use the masks truncate these appropriately - TFR_fcut = self.spec[layer]['TFR'][f_mask, :] - TFR_tfcut = TFR_fcut[:, t_mask] - - f_fcut = self.spec[layer]['f'][f_mask] - t_tcut = self.spec[layer]['t'][t_mask] - - return { - 't': t_tcut, - 'f': f_fcut, - 'TFR': TFR_tfcut, - } - - # find the max spectral power over specified time and frequency intervals - def max(self, layer, t_interval=None, f_interval=None, f_sort=None): - # If f_sort not provided, sort over all frequencies - if not f_sort: - f_sort = (self.spec['agg']['f'][0], self.spec['agg']['f'][-1]) - - # If f_sort is -1, assume upper abound is highest frequency - elif f_sort[1] < 0: - f_sort[1] = self.spec['agg']['f'][-1] - - # Only continue if absolute max of spectral power occurs at f in range of f_sorted - # Add +1 to f_sort[0] so range is inclusive - if self.spec['max_agg']['f'] not in np.arange(f_sort[0], f_sort[1]+1): - print("%s's absolute max spectral pwr does not occur between %i-%i Hz." %(self.fname, f_sort[0], f_sort[1])) - - else: - print("Warning: you are using max() in Spec(). 
It should be changed from == to np.isclose()") - # truncate data based on specified intervals - dcut = self.truncate_ext(layer, t_interval, f_interval) - - # find the max power over this new range - pwr_max = dcut['TFR'].max() - max_mask = (dcut['TFR'] == pwr_max) - - # find the t and f at max - # these are slightly crude and do not allow for the possibility of multiple maxes (rare?) - t_at_max = dcut['t'][max_mask.sum(axis=0) == 1][0] - f_at_max = dcut['f'][max_mask.sum(axis=1) == 1][0] - - # if f_interval provided and lower bound is not zero, set pd_at_max with lower bound: - # otherwise set it based on f_at_max - if f_interval and f_interval[0] > 0: - pd_at_max = 1000./f_interval[0] - else: - pd_at_max = 1000./f_at_max - - t_start = t_at_max - pd_at_max - t_end = t_at_max + pd_at_max - - # output structure - data_max = { - 'fname': self.fname, - 'pwr': pwr_max, - 't_int': [t_start, t_end], - 't_at_max': t_at_max, - 'f_at_max': f_at_max, - } - - return data_max - - # Averages spectral power over specified time interval for specified frequencies - def stationary_avg(self, layer='agg', t_interval=None, f_interval=None): - print("Warning: you are using stationary_avg() in Spec(). It should be changed from == to np.isclose()") - - # truncate data based on specified intervals - dcut = self.truncate_ext(layer, t_interval, f_interval) - - # avg TFR pwr over time - # axis = 1 sums over columns - pwr_avg = dcut['TFR'].sum(axis=1) / len(dcut['t']) - - # Get max pwr and freq at which max pwr occurs - pwr_max = pwr_avg.max() - f_at_max = dcut['f'][pwr_avg == pwr_max] - - return { - 'p_avg': pwr_avg, - 'p_max': pwr_max, - 'f_max': f_at_max, - 'freq': dcut['f'], - 'expmt': self.expmt, - } - - def plot_TFR(self, ax, layer='agg', xlim=None, ylim=None): - # truncate data based on specifed xlim and ylim - # xlim is a time interval - # ylim is a frequency interval - dcut = self.truncate_ext(layer, xlim, ylim) - - # Update xlim to have values guaranteed to exist - xlim_new = (dcut['t'][0], dcut['t'][-1]) - xmin, xmax = xlim_new - - # Update ylim to have values guaranteed to exist - ylim_new = (dcut['f'][0], dcut['f'][-1]) - ymin, ymax = ylim_new - - # set extent of plot - # order is ymax, ymin so y-axis is inverted - extent_xy = [xmin, xmax, ymax, ymin] - - # plot - im = ax.imshow(dcut['TFR'], extent=extent_xy, aspect='auto', origin='upper', cmap=plt.get_cmap(self.spec_cmap)) - - return im - - def plot_pgram(self, ax, f_max=None): - # If f_max is not supplied, set it to highest freq of aggregate analysis - if f_max is None: - f_max = self.spec['agg']['f'][-1] - - # plot - ax.plot(self.spec['pgram']['f'], self.spec['pgram']['p']) - ax.set_xlim((0., f_max)) - -# core class for frequency analysis assuming stationary time series -class Welch(): - def __init__(self, t_vec, ts_vec, dt): - # assign data internally - self.t_vec = t_vec - self.ts_vec = ts_vec - self.dt = dt - self.units = 'tsunits^2' - - # only assign length if same - if len(self.t_vec) == len(self.ts_vec): - self.N = len(ts_vec) - - else: - # raise an exception for real sometime in the future, for now just say something - print("in specfn.Welch(), your lengths don't match! Something will fail!") - - # in fact, this will fail (see above) - # self.N_fft = self.__nextpow2(self.N) - - # grab the dt (in ms) and calc sampling frequency - self.fs = 1000. 
/ self.dt - - # calculate the actual Welch - self.f, self.P = sps.welch(self.ts_vec, self.fs, window='hanning', nperseg=self.N, noverlap=0, nfft=self.N, return_onesided=True, scaling='spectrum') - - # simple plot to an axis - def plot_to_ax(self, ax, f_max=80.): - ax.plot(self.f, self.P) - ax.set_xlim((0., f_max)) - - def scale(self, scalefactor): - self.P *= scalefactor - self.units += ' x%3.4e' % scalefactor - - # return the next power of 2 generally for a given L - def __nextpow2(self, L): - n = 2 - # j = 1 - while n < L: - # j += 1 - n *= 2 - - return n - # return n, j - -# general spec write function -def write(fdata_spec, t_vec, f_vec, TFR): - np.savez_compressed(fdata_spec, time=t_vec, freq=f_vec, TFR=TFR) - -# general spec read function -def read(fdata_spec, type='dpl'): - if type == 'dpl': - data_spec = np.load(fdata_spec) - return data_spec - - elif type == 'current': - # split this up into 2 spec types - data_spec = np.load(fdata_spec) - spec_L2 = { - 't': data_spec['t_L2'], - 'f': data_spec['f_L2'], - 'TFR': data_spec['TFR_L2'], - } - - spec_L5 = { - 't': data_spec['t_L5'], - 'f': data_spec['f_L5'], - 'TFR': data_spec['TFR_L5'], - } - - return spec_L2, spec_L5 - -# average spec data for a given set of files -def average(fname, fspec_list, spec_cmap='jet'): - for fspec in fspec_list: - print(fspec) - # load spec data - spec = Spec(fspec, spec_cmap) - - # if this is first file, copy spec data structure wholesale to x - if fspec is fspec_list[0]: - x = spec.spec - - # else, iterate through spec data and add to x_agg - # there might be a more 'pythonic' way of doing this... - else: - for subdict in x: - for key in x[subdict]: - x[subdict][key] += spec.spec[subdict][key] - - # poor man's mean - for subdict in x: - for key in x[subdict]: - x[subdict][key] /= len(fspec_list) - - # save data - # if max_agg is a key in x, assume all keys are present - # else, assume only aggregate data is present - # Terrible way to save due to how np.savez_compressed works (i.e. must specify key=value) - if 'max_agg' in x.keys(): - max_agg = (x['max_agg']['p'], x['max_agg']['t'], x['max_agg']['f']) - # max_agg = (x['max_agg']['p_at_max'], x['max_agg']['t_at_max'], x['max_agg']['f_at_max']) - - np.savez_compressed(fname, t_agg=x['agg']['t'], f_agg=x['agg']['f'], TFR_agg=x['agg']['TFR'], t_L2=x['L2']['t'], f_L2=x['L2']['f'], TFR_L2=x['L2']['TFR'], t_L5=x['L5']['t'], f_L5=x['L5']['f'], TFR_L5=x['L5']['TFR'], max_agg=max_agg, pgram_p=x['pgram']['p'], pgram_f=x['pgram']['f']) - - else: - np.savez_compressed(fname, t_agg=x['agg']['t'], f_agg=x['agg']['f'], TFR_agg=x['agg']['TFR']) - -# spectral plotting kernel should be simpler and take just a file name and an axis handle -def pspec_ax(ax_spec, fspec, xlim, spec_cmap='jet', layer=None): - """ Spectral plotting kernel for ONE simulation run - ax_spec is the axis handle. fspec is the file name - """ - # read is a function in this file to read the fspec - data_spec = read(fspec) - - if layer in (None, 'agg'): - TFR = data_spec['TFR'] - - if 'f' in data_spec.keys(): - f = data_spec['f'] - else: - f = data_spec['freq'] - - else: - TFR_layer = 'TFR_%s' % layer - f_layer = 'f_%s' % layer - - if TFR_layer in data_spec.keys(): - TFR = data_spec[TFR_layer] - f = data_spec[f_layer] - - else: - print(data_spec.keys()) - - extent_xy = [xlim[0], xlim[1], f[-1], 0.] 
- pc = ax_spec.imshow(TFR, extent=extent_xy, aspect='auto', origin='upper', cmap=plt.get_cmap(spec_cmap)) - [vmin, vmax] = pc.get_clim() - # print(np.min(TFR), np.max(TFR)) - # print(vmin, vmax) - # ax_spec.colorbar(pc, ax=ax_spec) - - # return (vmin, vmax) - return pc - -# find max spectral power and associated time/freq for individual file -def specmax(fspec, opts): - print("Warning: you are using specmax(). It should be changed from == to np.isclose()") - # opts is a dict that includes t_interval and f_interval - # grab name of file - fname = fspec.split('/')[-1].split('-spec')[0] - - # load spec data - data = read(fspec) - - # grab the min and max f - f_min, f_max = opts['f_interval'] - - # set f_max and f_min - if f_max < 0: - f_max = data['freq'][-1] - - if f_min < 0: - f_min = data['freq'][0] - - # create an f_mask for the bounds of f, inclusive - f_mask = (data['freq']>=f_min) & (data['freq']<=f_max) - - # do the same for t - t_min, t_max = opts['t_interval'] - if t_max < 0: - t_max = data['time'][-1] - - if t_min < 0: - t_min = data['time'][0] - - t_mask = (data['time']>=t_min) & (data['time']<=t_max) - - # use the masks truncate these appropriately - TFR_fcut = data['TFR'][f_mask, :] - TFR_tfcut = TFR_fcut[:, t_mask] - - f_fcut = data['freq'][f_mask] - t_tcut = data['time'][t_mask] - - # find the max power over this new range - # the max_mask is for the entire TFR - pwr_max = TFR_tfcut.max() - max_mask = (TFR_tfcut == pwr_max) - - # find the t and f at max - # these are slightly crude and do not allow for the possibility of multiple maxes (rare?) - t_at_max = t_tcut[max_mask.sum(axis=0) == 1] - f_at_max = f_fcut[max_mask.sum(axis=1) == 1] - - pd_at_max = 1000. / f_at_max - t_start = t_at_max - pd_at_max - t_end = t_at_max + pd_at_max - - # output structure - data_max = { - 'fname': fname, - 'pwr': pwr_max, - 't_int': [t_start, t_end], - 't_at_max': t_at_max, - 'f_at_max': f_at_max, - } - - return data_max - -# return the max spectral power (simple) for a series of files -def spec_max(ddata, expmt_group, layer='agg'): - # grab the spec list, assumes it exists - list_spec = ddata.file_match(expmt_group, 'rawspec') - - # really only perform these actions if there are items in the list - if len(list_spec): - # simple prealloc - val_max = np.zeros(len(list_spec)) - - # iterate through list_spec - i = 0 - for fspec in list_spec: - data_spec = read(fspec) - - # for now only do the TFR for the aggregate data - val_max[i] = np.max(data_spec['TFR']) - i += 1 - - return spec_max - -# common function to generate spec if it appears to be missing -def generate_missing_spec(ddata, f_max=40): - # just check first expmt_group - expmt_group = ddata.expmt_groups[0] - - # list of spec data - l_spec = ddata.file_match(expmt_group, 'rawspec') - - # if this list is empty, assume it is everywhere and run the analysis function - if not l_spec: - opts = { - 'type': 'dpl_laminar', - 'f_max': f_max, - 'save_data': 1, - 'runtype': 'parallel', - } - analysis_typespecific(ddata, opts) - - else: - # this is currently incorrect, it should actually return the data that has been referred to - # as spec_results. such a function to properly get this without analysis (eg. reader to this data) - # should exist - spec = [] - - # do the one for current, too. 
Might as well at this point - l_speccurrent = ddata.file_match(expmt_group, 'rawspeccurrent') - - if not l_speccurrent: - p_exp = paramrw.ExpParams(ddata.fparam) - opts = { - 'type': 'current', - 'f_max': 90., - 'save_data': 1, - 'runtype': 'parallel', - } - analysis_typespecific(ddata, opts) - else: - spec_current = [] - -# Kernel for spec analysis of current data -# necessary for parallelization -def spec_current_kernel(fparam, fts, fspec, f_max): - I_syn = currentfn.SynapticCurrent(fts) - - # Generate spec results - spec_L2 = MorletSpec(I_syn.t, I_syn.I_soma_L2Pyr, fparam, f_max) - spec_L5 = MorletSpec(I_syn.t, I_syn.I_soma_L5Pyr, fparam, f_max) - - # Save spec data - np.savez_compressed(fspec, t_L2=spec_L2.t, f_L2=spec_L2.f, TFR_L2=spec_L2.TFR, t_L5=spec_L5.t, f_L5=spec_L5.f, TFR_L5=spec_L5.TFR) - -# Kernel for spec analysis of dipole data -# necessary for parallelization -def spec_dpl_kernel(fparam, fts, fspec, f_max): - dpl = dipolefn.Dipole(fts) - dpl.units = 'nAm' - - # Do the conversion prior to generating these spec - # dpl.convert_fAm_to_nAm() - - # Generate various spec results - spec_agg = MorletSpec(dpl.t, dpl.dpl['agg'], fparam, f_max) - spec_L2 = MorletSpec(dpl.t, dpl.dpl['L2'], fparam, f_max) - spec_L5 = MorletSpec(dpl.t, dpl.dpl['L5'], fparam, f_max) - - # Get max spectral power data - # for now, only doing this for agg - max_agg = spec_agg.max() - - # Generate periodogram resutls - p_dict = paramrw.read(fparam)[1] - pgram = Welch(dpl.t, dpl.dpl['agg'], p_dict['dt']) - - # Save spec results - np.savez_compressed(fspec, time=spec_agg.t, freq=spec_agg.f, TFR=spec_agg.TFR, max_agg=max_agg, t_L2=spec_L2.t, f_L2=spec_L2.f, TFR_L2=spec_L2.TFR, t_L5=spec_L5.t, f_L5=spec_L5.f, TFR_L5=spec_L5.TFR, pgram_p=pgram.P, pgram_f=pgram.f) - -def analysis_simp (opts, fparam, fdpl, fspec): - opts_run = {'type': 'dpl_laminar', - 'f_max': 100., - 'save_data': 0, - 'runtype': 'parallel', - } - if opts: - for key, val in opts.items(): - if key in opts_run.keys(): opts_run[key] = val - spec_dpl_kernel(fparam, fdpl, fspec, opts_run['f_max']) - -# Does spec analysis for all files in simulation directory -# ddata comes from fileio -def analysis_typespecific(ddata, opts=None): - # def analysis_typespecific(ddata, p_exp, opts=None): - # 'opts' input are the options in a dictionary - # if opts is defined, then make it well formed - # the valid keys of opts are in list_opts - opts_run = { - 'type': 'dpl_laminar', - 'f_max': 100., - 'save_data': 0, - 'runtype': 'parallel', - } - # check if opts is supplied - if opts: - # assume opts is a dict - # iterate through provided opts and assign if the key is present - # otherwise, ignore - for key, val in opts.items(): - if key in opts_run.keys(): - opts_run[key] = val - # preallocate lists for use below - list_param, list_ts, list_spec = [], [], [] - - # aggregrate all files from individual expmts into lists - expmt_group = ddata.expmt_groups[0] - # get the list of params - # returns an alpha SORTED list - # add to list of all param files - param_tmp = ddata.file_match(expmt_group, 'param') - print('param_tmp:',param_tmp) - list_param.extend(param_tmp) - # get exp prefix for each trial in this expmt group - list_exp_prefix = [fio.strip_extprefix(fparam) for fparam in param_tmp] - # get the list of dipoles and create spec output filenames - if opts_run['type'] in ('dpl', 'dpl_laminar'): - list_ts.extend(ddata.file_match(expmt_group, 'rawdpl')) - list_spec.extend([ddata.create_filename(expmt_group, 'rawspec', exp_prefix) for exp_prefix in list_exp_prefix]) - 
elif opts_run['type'] == 'current': - list_ts.extend(ddata.file_match(expmt_group, 'rawcurrent')) - list_spec.extend(ddata.create_filename(expmt_group, 'rawspeccurrent', list_exp_prefix[-1])) - # create list of spec output names - # this is sorted because of file_match - # exp_prefix_list = [fio.strip_extprefix(fparam) for fparam in list_param] - - # perform analysis on all runs from all exmpts at same time - if opts_run['type'] == 'current': - # list_spec.extend([ddata.create_filename(expmt_group, 'rawspeccurrent', exp_prefix) for exp_prefix in exp_prefix_list]) - if opts_run['runtype'] == 'parallel': - pl = mp.Pool() - for fparam, fts, fspec in zip(list_param, list_ts, list_spec): - pl.apply_async(spec_current_kernel, (fparam, fts, fspec, opts_run['f_max'])) - pl.close() - pl.join() - elif opts_run['runtype'] == 'debug': - for fparam, fts, fspec in zip(list_param, list_ts, list_spec): - spec_current_kernel(fparam, fts, fspec, opts_run['f_max']) - elif opts_run['type'] == 'dpl_laminar': - # these should be OUTPUT filenames that are being generated - # list_spec.extend([ddata.create_filename(expmt_group, 'rawspec', exp_prefix) for exp_prefix in exp_prefix_list]) - # also in this case, the original spec results will be overwritten - # and replaced by laminar specific ones and aggregate ones - # in this case, list_ts is a list of dipole - if opts_run['runtype'] == 'parallel': - pl = mp.Pool() - for fparam, fts, fspec in zip(list_param, list_ts, list_spec): - pl.apply_async(spec_dpl_kernel, (fparam, fts, fspec, opts_run['f_max'])) - pl.close() - pl.join() - elif opts_run['runtype'] == 'debug': - # spec_results_L2 and _L5 - for fparam, fts, fspec in zip(list_param, list_ts, list_spec): - spec_dpl_kernel(fparam, fts, fspec, opts_run['f_max']) - # else: - # print('Type %s not recognized. Try again later.' %(opts_run['type'])) - -# returns spec results *only* for a given experimental group -def from_expmt(spec_result_list, expmt_group): - return [spec_result for spec_result in spec_result_list if expmt_group in spec_result.name] - -# Averages spec power over time, returning an array of average pwr per frequency -def specpwr_stationary_avg(fspec): - print("Warning: you are using specpwr_stationary_avg(). It should be changed from == to np.isclose()") - - # Load data from file - data_spec = np.load(fspec) - - timevec = data_spec['time'] - freqvec = data_spec['freq'] - TFR = data_spec['TFR'] - - # get experiment name - expmt = fspec.split('/')[6].split('.')[0] - - # axis = 1 sums over columns - pwr_avg = TFR.sum(axis=1) / len(timevec) - pwr_max = pwr_avg.max() - f_at_max = freqvec[pwr_avg == pwr_max] - - return { - 'p_avg': pwr_avg, - 'p_max': pwr_max, - 'f_max': f_at_max, - 'freq': freqvec, - 'expmt': expmt, - } - -def specpwr_stationary(t, f, TFR): - print("Warning: you are using specpwr_stationary(). It should be changed from == to np.isclose()") - - # aggregate sum of power of all calculated frequencies - p = TFR.sum(axis=1) - - # calculate max power - p_max = p.max() - - # calculate max f - f_max = f[p == p_max] - - return { - 'p': p, - 'f': f, - 'p_max': p_max, - 'f_max': f_max, - } - -def calc_stderror(data_list): - # np.std returns standard deviation - # axis=0 performs standard deviation over rows - error_vec = np.std(data_list, axis=0) - - return error_vec - -def pfreqpwr_with_hist(file_name, freqpwr_result, f_spk, gid_dict, p_dict, key_types): - f = ac.FigFreqpwrWithHist() - f.ax['hist'].hold(True) - - xmin = 50. 
- xmax = p_dict['tstop'] - - f.ax['freqpwr'].plot(freqpwr_result['freq'], freqpwr_result['avgpwr']) - - # grab alpha feed data. spikes_from_file() from spikefn.py - s_dict = spikefn.spikes_from_file(gid_dict, f_spk) - - # check for existance of alpha feed keys in s_dict. - s_dict = spikefn.alpha_feed_verify(s_dict, p_dict) - - # Account for possible delays - s_dict = spikefn.add_delay_times(s_dict, p_dict) - - # set number of bins (150 bins/1000ms) - bins = 150. * (xmax - xmin) / 1000. - hist_data = [] - - # Proximal feed - hist_data.extend(f.ax['hist'].hist(s_dict['alpha_feed_prox'].spike_list, bins, range=[xmin, xmax], color='red', label='Proximal feed')[0]) - - # Distal feed - hist_data.extend(f.ax['hist'].hist(s_dict['alpha_feed_dist'].spike_list, bins, range=[xmin, xmax], color='green', label='Distal feed')[0]) - - # set hist axis props - f.set_hist_props(hist_data) - - # axis labels - f.ax['freqpwr'].set_xlabel('freq (Hz)') - f.ax['freqpwr'].set_ylabel('power') - f.ax['hist'].set_xlabel('time (ms)') - f.ax['hist'].set_ylabel('# spikes') - - # create title - title_str = ac.create_title(p_dict, key_types) - f.f.suptitle(title_str) - # title_str = [key + ': %2.1f' % p_dict[key] for key in key_types['dynamic_keys']] - - f.savepng(file_name) - f.close() - -def pmaxpwr(file_name, results_list, fparam_list): - f = ac.FigStd() - f.ax0.hold(True) - - # instantiate lists for storing x and y data - x_data = [] - y_data = [] - - # plot points - for result, fparam in zip(results_list, fparam_list): - p = paramrw.read(fparam)[1] - - x_data.append(p['f_input_prox']) - y_data.extend(result['freq_at_max']) - - f.ax0.plot(x_data[-1], y_data[-1], 'kx') - - # add trendline - fit = np.polyfit(x_data, y_data, 1) - fit_fn = np.poly1d(fit) - - f.ax0.plot(x_data, fit_fn(x_data), 'k-') - - # Axis stuff - f.ax0.set_xlabel('Proximal/Distal Input Freq (Hz)') - f.ax0.set_ylabel('Freq at which max avg power occurs (Hz)') - - f.save(file_name) - -if __name__ == '__main__': - x = np.arange(0, 10.1, 0.1) - s1 = np.array([np.sin(x)]) - s2 = np.array([np.sin(2*x)]) - dt = 0.1 - - p = PhaseLock(s1, s2, dt) diff --git a/spikefn.py b/spikefn.py deleted file mode 100644 index 9f0acf232..000000000 --- a/spikefn.py +++ /dev/null @@ -1,464 +0,0 @@ -# spikefn.py - dealing with spikes -# -# v 1.10.0-py35 -# rev 2016-05-01 (SL: minor) -# last major: (SL: toward python3) - -import fileio as fio -import numpy as np -import scipy.signal as sps -import matplotlib.pyplot as plt -import itertools as it -import os -import paramrw - -# meant as a class for ONE cell type -class Spikes(): - def __init__ (self, s_all, ranges): - self.r = ranges - self.spike_list = self.filter(s_all) - self.N_cells = len(self.r) - self.N_spikingcells = len(self.spike_list) - # this is set externally - self.tick_marks = [] - - # returns spike_list, a list of lists of spikes. 
- # Each list corresponds to a cell, counted by range - def filter (self, s_all): - spike_list = [] - if len(s_all) > 0: - for ri in self.r: - srange = s_all[s_all[:, 1] == ri][:, 0] - srange[srange.argsort()] - spike_list.append(srange) - - return spike_list - - # simple return of all spikes *or* each spike indexed i in every list - def collapse_all (self, i=None): - if i == 'None': - spk_all = [] - for spk_list in self.spike_list: - spk_all.extend(spk_list) - else: - spk_all = [spk_list[i] for spk_list in self.spike_list if spk_list] - return spk_all - - # uses self.collapse_all() and returns unique spike times - def unique_all (self, i=None): - spk_all = self.collapse_all(i) - return np.unique(spk_all) - - # plot psth - def ppsth (self, a): - # flatten list of spikes - s_agg = np.array(list(it.chain.from_iterable(self.spike_list))) - # plot histogram to axis 'a' - bins = hist_bin_opt(s_agg, 1) - a.hist(s_agg, bins, normed=True, facecolor='g', alpha=0.75) - -# Class to handle extinput event times -class ExtInputs (Spikes): - # class for external inputs - extracts gids and times - def __init__ (self, fspk, fparam, evoked=False): - # load gid and param dicts - try: - self.gid_dict, self.p_dict = paramrw.read(fparam) - except OSError: - raise ValueError - self.evoked = evoked - # parse evoked prox and dist input gids from gid_dict - # print('getting evokedinput gids') - self.gid_evprox, self.gid_evdist = self.__get_evokedinput_gids() - # print('got evokedinput gids') - # parse ongoing prox and dist input gids from gid_dict - self.gid_prox, self.gid_dist = self.__get_extinput_gids() - # poisson input gids - #print('getting pois input gids') - self.gid_pois = self.__get_poisinput_gids() - # self.inputs is dict of input times with keys 'prox' and 'dist' - self.inputs = self.__get_extinput_times(fspk) - - def __get_extinput_gids (self): - # Determine if both feeds exist in this sim - # If they do, self.gid_dict['extinput'] has length 2 - # If so, first gid is guaraneteed to be prox feed, second to be dist feed - if len(self.gid_dict['extinput']) == 2: - return self.gid_dict['extinput'] - # Otherwise, only one feed exists in this sim - # Must use param file to figure out which one... 
- elif len(self.gid_dict['extinput']) > 0: - if self.p_dict['t0_input_prox'] < self.p_dict['tstop']: - return self.gid_dict['extinput'][0], None - elif self.p_dict['t0_input_dist'] < self.p_dict['tstop']: - return None, self.gid_dict['extinput'][0] - else: - return None, None - - def __get_poisinput_gids (self): - # get Poisson input gids - gids = [] - if len(self.gid_dict['extpois']) > 0: - if self.p_dict['t0_pois'] < self.p_dict['tstop']: - gids = np.array(self.gid_dict['extpois']) - self.pois_gid_range = (min(gids),max(gids)) - return gids - - def countevinputs (self, ty): - # count number of evoked inputs - n = 0 - for k in self.gid_dict.keys(): - if k.startswith(ty) and len(self.gid_dict[k]) > 0: n += 1 - return n - - def countevprox (self): return self.countevinputs('evprox') - def countevdist (self): return self.countevinputs('evdist') - - def __get_evokedinput_gids (self): - gid_prox,gid_dist=None,None - nprox,ndist = self.countevprox(), self.countevdist() - #print('__get_evokedinput_gids keys:',self.gid_dict.keys(),'nprox:',nprox,'ndist:',ndist) - if nprox > 0: - gid_prox = [] - for i in range(nprox): - if len(self.gid_dict['evprox'+str(i+1)]) > 0: - l = list(self.gid_dict['evprox'+str(i+1)]) - for x in l: gid_prox.append(x) - gid_prox = np.array(gid_prox) - self.evprox_gid_range = (min(gid_prox),max(gid_prox)) - if ndist > 0: - gid_dist = [] - for i in range(ndist): - if len(self.gid_dict['evdist'+str(i+1)]) > 0: - l = list(self.gid_dict['evdist'+str(i+1)]) - for x in l: gid_dist.append(x) - gid_dist = np.array(gid_dist) - self.evdist_gid_range = (min(gid_dist),max(gid_dist)) - return gid_prox, gid_dist - - def unique_times (self,s_all,lidx): - self.r = [x for x in lidx] - lfilttime = self.filter(s_all); ltime = [] - for arr in lfilttime: - for time in arr: - ltime.append(time) - return np.array(list(set(ltime))) - - def get_times (self, gid, s_all): - # self.filter() inherited from Spikes() - # self.r weirdness is necessary to use self.filter() - # i.e. 
self.r must exist and be a list to execute self.filter() - self.r = [gid] - return self.filter(s_all)[0] - - def __get_extinput_times (self, fspk): - # load all spike times from file - s_all = np.loadtxt(open(fspk, 'rb')) - if len(s_all) == 0: - # couldn't read spike times - raise ValueError - - inputs = {k:np.array([]) for k in ['prox','dist','evprox','evdist','pois']} - if self.gid_prox is not None: inputs['prox'] = self.get_times(self.gid_prox,s_all) - if self.gid_dist is not None: inputs['dist'] = self.get_times(self.gid_dist,s_all) - if self.gid_evprox is not None: inputs['evprox'] = self.unique_times(s_all, self.gid_evprox) - if self.gid_evdist is not None: inputs['evdist'] = self.unique_times(s_all, self.gid_evdist) - if self.gid_pois is not None: inputs['pois'] = self.unique_times(s_all, self.gid_pois) - return inputs - - # gid associated with evoked input - def is_evoked_gid (self,gid): - if len(self.inputs['evprox']) > 0: - if self.evprox_gid_range[0] <= gid <= self.evprox_gid_range[1]: - return True - if len(self.inputs['evdist']) > 0: - if self.evdist_gid_range[0] <= gid <= self.evdist_gid_range[1]: - return True - return False - - # check if gid is associated with a proximal input - def is_prox_gid (self, gid): - if gid == self.gid_prox: return True - if len(self.inputs['evprox']) > 0: - return self.evprox_gid_range[0] <= gid <= self.evprox_gid_range[1] - return False - - # check if gid is associated with a distal input - def is_dist_gid (self, gid): - if gid == self.gid_dist: return True - if len(self.inputs['evdist']) > 0: - return self.evdist_gid_range[0] <= gid <= self.evdist_gid_range[1] - return False - - # check if gid is associated with a Poisson input - def is_pois_gid (self, gid): - try: - if len(self.inputs['pois']) > 0: - return self.pois_gid_range[0] <= gid <= self.pois_gid_range[1] - except: - pass - return False - - def truncate_ext (self, dtype, t_int): - if dtype == 'prox' or dtype == 'dist': - tmask = (self.inputs[dtype] >= t_int[0]) & (self.inputs[dtype] <= t_int[1]) - return self.inputs[dtype][tmask] - if dtype == 'env': - tmask = (self.inputs['t'] >= t_int[0]) & (self.inputs['t'] <= t_int[1]) - return [self.inputs[dtype][tmask], self.inputs['t'][tmask]] - - def add_delay_times (self): - # if prox delay to both layers is the same, add it to the prox input times - if self.p_dict['input_prox_A_delay_L2'] == self.p_dict['input_prox_A_delay_L5']: - self.inputs['prox'] += self.p_dict['input_prox_A_delay_L2'] - # if dist delay to both layers is the same, add it to the dist input times - if self.p_dict['input_dist_A_delay_L2'] == self.p_dict['input_dist_A_delay_L5']: - self.inputs['dist'] += self.p_dict['input_dist_A_delay_L2'] - - def get_envelope (self, tvec, feed='dist', bins=150): - h_range = (tvec[0], tvec[-1]) - hist, edges = np.histogram(self.inputs[feed], bins=bins, range=h_range) - centers = edges[0:bins] + np.diff(edges) / 2. 
- num = len(tvec) - env, t = sps.resample(hist, num, t=centers) - self.inputs['env'] = env - self.inputs['t'] = t - - # extinput is either 'dist' or 'prox' - def plot_hist (self, ax, extinput, tvec, bins='auto', xlim=None, color='green', hty='bar',lw=4): - if bins is 'auto': - bins = hist_bin_opt(self.inputs[extinput], 1) - if not xlim: - xlim = (0., p_dict['tstop']) - if len(self.inputs[extinput]): - #print("plot_hist bins:",bins,type(bins)) - hist = ax.hist(self.inputs[extinput], bins, range=xlim, color=color, label=extinput, histtype=hty,linewidth=lw) - ax.set_xticklabels([]) - ax.tick_params(bottom=False, left=False) - else: - hist = None - return hist - -# filters spike dict s_dict for keys that start with str_startswith -def filter_spike_dict (s_dict, str_startswith): - """ easy enough to modify for future conditions - just fix associated functions - """ - s_filt = {} - for key, val in s_dict.items(): - if key.startswith(str_startswith): - s_filt[key] = val - return s_filt - -# weird bin counting function -def bin_count(bins_per_second, tinterval): return bins_per_second * tinterval / 1000. - -# splits ext random feeds (of type exttype) by supplied cell type -def split_extrand(s, gid_dict, celltype, exttype): - gid_cell = gid_dict[celltype] - gid_exttype_start = gid_dict[exttype][0] - gid_exttype_cell = [gid + gid_exttype_start for gid in gid_dict[celltype]] - return Spikes(s, gid_exttype_cell) - -# histogram bin optimization -def hist_bin_opt(x, N_trials): - """ Shimazaki and Shinomoto, Neural Comput, 2007 - """ - bin_checks = np.arange(80, 300, 10) - # bin_checks = np.linspace(150, 300, 16) - costs = np.zeros(len(bin_checks)) - i = 0 - # this might be vectorizable in np - for n_bins in bin_checks: - # use np.histogram to do the numerical minimization - pdf, bin_edges = np.histogram(x, n_bins) - # calculate bin width - # some discrepancy here but should be fine - w_bin = np.unique(np.diff(bin_edges)) - if len(w_bin) > 1: w_bin = w_bin[0] - # calc mean and var - kbar = np.mean(pdf) - kvar = np.var(pdf) - # calc cost - costs[i] = (2.*kbar - kvar) / (N_trials * w_bin)**2. 
- i += 1 - # find the bin size corresponding to a minimization of the costs - bin_opt_list = bin_checks[costs.min() == costs] - bin_opt = bin_opt_list[0] - return bin_opt - -# "purely" from files, this is the new way to replace the old way -def spikes_from_file(fparam, fspikes): - gid_dict, _ = paramrw.read(fparam) - # cell list - requires cell to start with L2/L5 - src_list = [] - src_extinput_list = [] - src_unique_list = [] - # fill in 2 lists from the keys - for key in gid_dict.keys(): - if key.startswith('L2_') or key.startswith('L5_'): - src_list.append(key) - elif key == 'extinput': - src_extinput_list.append(key) - else: - src_unique_list.append(key) - # check to see if there are spikes in here, otherwise return an empty array - if os.stat(fspikes).st_size: - s = np.loadtxt(open(fspikes, 'rb')) - else: - s = np.array([], dtype='float64') - # get the skeleton s_dict from the cell_list - s_dict = dict.fromkeys(src_list) - # iterate through just the src keys - for key in s_dict.keys(): - # sort of a hack to separate extgauss - s_dict[key] = Spikes(s, gid_dict[key]) - # figure out its extgauss feed - newkey_gauss = 'extgauss_' + key - s_dict[newkey_gauss] = split_extrand(s, gid_dict, key, 'extgauss') - # figure out its extpois feed - newkey_pois = 'extpois_' + key - s_dict[newkey_pois] = split_extrand(s, gid_dict, key, 'extpois') - # do the keys in unique list - for key in src_unique_list: s_dict[key] = Spikes(s, gid_dict[key]) - # Deal with alpha feeds (extinputs) - # order guaranteed by order of inputs in p_ext in paramrw - # and by details of gid creation in class_net - # A little kludgy to deal with the fact that one might not exist - if len(gid_dict['extinput']) > 1: - s_dict['alpha_feed_prox'] = Spikes(s, [gid_dict['extinput'][0]]) - s_dict['alpha_feed_dist'] = Spikes(s, [gid_dict['extinput'][1]]) - else: - # not sure why this is done here - # handle the extinput: this is a LIST! - s_dict['extinput'] = [Spikes(s, [gid]) for gid in gid_dict['extinput']] - return s_dict - -# from the supplied key name, return a marker style -def get_markerstyle(key): - markerstyle = '' - # ext now same color, not ideal yet - # if 'L2' in key: - # markerstyle += 'k' - # elif 'L5' in key: - # markerstyle += 'b' - # short circuit this by putting extgauss first ... cheap. - if 'extgauss' in key: - markerstyle += 'k.' - elif 'extpois' in key: - markerstyle += 'k.' - elif 'pyramidal' in key: - markerstyle += 'k.' 
- elif 'basket' in key: - markerstyle += 'r|' - return markerstyle - -# spike_png plots spikes based on input dict -def spike_png(a, s_dict): - # new spikepng function: - # receive lists of cell spikes and the gid dict for now - # parse spikes file by cell type - # output all cell spikes - # get the length of s - new way - N_total = 0 - for key in s_dict.keys(): N_total += s_dict[key].N_cells - # 2 added to this in order to pad the y_ticks off the x axis and top - # e_ticks starts at 1 for padding - # i_ticks ends at -1 for padding - y_ticks = np.linspace(0, 1, N_total + 2) - # Turn the hold on - a.hold(True) - # define start point - tick_start = 1 - # sort the keys by alpha: consistency in names will lead to consistent behavior here - # reverse=True because _basket comes before _pyramidal, and the spikes plot bottom up - key_list = [key for key in s_dict.keys()] - key_list.sort(reverse=True) - # for key in s_dict.keys(): - for key in key_list: - # print key, s_dict[key].spike_list - s_dict[key].tick_marks = y_ticks[tick_start:tick_start+s_dict[key].N_cells] - tick_start += s_dict[key].N_cells - markerstyle = get_markerstyle(key) - # There must be congruency between lines in spike_list and the number of ticks - i = 0 - for spk_cell in s_dict[key].spike_list: - # a.plot(np.array([451.6]), e_ticks[i] * np.ones(1), 'k.', markersize=2.5) - # print len(s_dict[key].tick_marks), len(spk_cell) - a.plot(spk_cell, s_dict[key].tick_marks[i] * np.ones(len(spk_cell)), markerstyle, markeredgewidth=1, markersize=1.5) - i += 1 - a.set_ylim([0, 1]) - a.grid() - -# Add synaptic delays to alpha input times if applicable: -def add_delay_times(s_dict, p_dict): - # Only add delays if delay is same for L2 and L5 - # Proximal feed - # if L5 delay is -1, has same delays as L2 - # if p_dict['input_prox_A_delay_L5'] == -1: - # s_dict['alpha_feed_prox'].spike_list = [num+p_dict['input_prox_A_delay_L2'] for num in s_dict['alpha_feed_prox'].spike_list] - # else, check to see if delays are the same anyway - # else: - if s_dict['alpha_feed_prox'].spike_list and p_dict['input_prox_A_delay_L2'] == p_dict['input_prox_A_delay_L5']: - s_dict['alpha_feed_prox'].spike_list = [num+p_dict['input_prox_A_delay_L2'] for num in s_dict['alpha_feed_prox'].spike_list] - # Distal - # if L5 delay is -1, has same delays as L2 - # if p_dict['input_dist_A_delay_L5'] == -1: - # s_dict['alpha_feed_dist'].spike_list = [num+p_dict['input_dist_A_delay_L2'] for num in s_dict['alpha_feed_dist'].spike_list] - # else, check to see if delays are the same anyway - # else: - if s_dict['alpha_feed_dist'].spike_list and p_dict['input_dist_A_delay_L2'] == p_dict['input_dist_A_delay_L5']: - s_dict['alpha_feed_dist'].spike_list = [num+p_dict['input_dist_A_delay_L2'] for num in s_dict['alpha_feed_dist'].spike_list] - return s_dict - -# Checks for existance of alpha feed keys in s_dict. -def alpha_feed_verify(s_dict, p_dict): - """ If they do not exist, then simulation used one or no feeds. Creates keys accordingly - """ - # check for existance of keys. If exist, do nothing - if 'alpha_feed_prox' and 'alpha_feed_dist' in s_dict.keys(): - pass - # if they do not exist, create them and add proper data - else: - # if proximal feed's t0 < tstop, it exists and data is stored in s_dict['extinputs']. 
- # distal feed does not exist and gets empty list - if p_dict['t0_input_prox'] < p_dict['tstop']: - s_dict['alpha_feed_prox'] = s_dict['extinput'] - # make object on the fly with attribute 'spike_list' - # A little hack-y - s_dict['alpha_feed_dist'] = type('emptyspike', (object,), {'spike_list': np.array([])}) - # if distal feed's t0 < tstop, it exists and data is stored in s_dict['extinputs']. - # Proximal feed does not exist and gets empty list - elif p_dict['t0_input_dist'] < p_dict['tstop']: - s_dict['alpha_feed_prox'] = type('emptyspike', (object,), {'spike_list': np.array([])}) - s_dict['alpha_feed_dist'] = s_dict['extinput'] - # if neither had t0 < tstop, neither exists and both get empty list - else: - s_dict['alpha_feed_prox'] = type('emptyspike', (object,), {'spike_list': np.array([])}) - s_dict['alpha_feed_dist'] = type('emptyspike', (object,), {'spike_list': np.array([])}) - return s_dict - -# input histogram on 2 axes -def pinput_hist(a0, a1, s_list0, s_list1, n_bins, xlim): - hists = { - 'prox': a0.hist(s_list0, n_bins, color='red', label='Proximal input', alpha=0.75), - 'dist': a1.hist(s_list1, n_bins, color='green', label='Distal input', alpha=0.75), - } - # assumes these axes are inverted and figure it out - ylim_max = 2*np.max([a0.get_ylim()[1], a1.get_ylim()[1]]) + 1 - # set the ylims here - a0.set_ylim((0, ylim_max)) - a1.set_ylim((0, ylim_max)) - a0.set_xlim(xlim) - a1.set_xlim(xlim) - a1.invert_yaxis() - return hists - -def pinput_hist_onesided(a0, s_list, n_bins): - hists = { - 'prox': a0.hist(s_list, n_bins, color='k', label='Proximal input', alpha=0.75), - } - return hists - -if __name__ == '__main__': - pass diff --git a/tests/test_compare_hnn.py b/tests/test_compare_hnn.py deleted file mode 100644 index 732155679..000000000 --- a/tests/test_compare_hnn.py +++ /dev/null @@ -1,65 +0,0 @@ -import os.path as op - -from numpy import loadtxt -from numpy.testing import assert_allclose - -from mne.utils import _fetch_file - - -def test_hnn(): - """Test to check that HNN produces consistent results""" - # small snippet of data on data branch for now. To be deleted - # later. Data branch should have only commit so it does not - # pollute the history. - from subprocess import Popen, PIPE - import shlex - import os - import sys - - ntrials = 3 - paramf = op.join('param', 'default.param') - - nrniv_str = 'nrniv -python -nobanner' - cmd = nrniv_str + ' ' + sys.executable + ' run.py ' + paramf \ - + ' ntrial ' + str(ntrials) - - # Split the command into shell arguments for passing to Popen - cmdargs = shlex.split(cmd, posix="win" not in sys.platform) - - # Start the simulation - proc = Popen(cmdargs, stdin=PIPE, stdout=PIPE, stderr=PIPE, - cwd=os.getcwd(), universal_newlines=True) - out, err = proc.communicate() - - # print all messages (including error messages) - print('STDOUT', out) - print('STDERR', err) - - for trial in range(ntrials): - print("Checking data for trial %d" % trial) - if 'SYSTEM_USER_DIR' in os.environ: - basedir = os.environ['SYSTEM_USER_DIR'] - else: - basedir = os.path.expanduser('~') - dirname = op.join(basedir, 'hnn_out', 'data', 'default') - - data_dir = ('https://raw.githubusercontent.com/jonescompneurolab/' - 'hnn/test_data/') - for data_type in ['dpl', 'rawdpl', 'i']: - sys.stdout.write("%s..." 
% data_type) - - fname = "%s_%d.txt" % (data_type, trial) - data_url = op.join(data_dir, fname) - if not op.exists(fname): - _fetch_file(data_url, fname) - - print("comparing %s" % fname) - pr = loadtxt(op.join(dirname, fname)) - master = loadtxt(fname) - - assert_allclose(pr[:, 1], master[:, 1], rtol=1e-8, atol=0) - if data_type in ['dpl', 'rawdpl', 'i']: - assert_allclose(pr[:, 2], master[:, 2], rtol=1e-8, atol=0) - if data_type in ['dpl', 'rawdpl']: - assert_allclose(pr[:, 3], master[:, 3], rtol=1e-8, atol=0) - print("done") diff --git a/tests/test_view_windows.py b/tests/test_view_windows.py deleted file mode 100644 index 87f111970..000000000 --- a/tests/test_view_windows.py +++ /dev/null @@ -1,73 +0,0 @@ -import os.path as op -import os -import sys -import shlex -from subprocess import Popen, PIPE - -from mne.utils import _fetch_file - - -def fetch_file(fname): - data_dir = ('https://raw.githubusercontent.com/jonescompneurolab/' - 'hnn/test_data/') - - data_url = op.join(data_dir, fname) - if not op.exists(fname): - _fetch_file(data_url, fname) - - -def view_window(code_fname, paramf, data_fname=None): - """Test to check that viewer displays without error""" - - nrniv_str = 'nrniv -python -nobanner' - cmd = nrniv_str + ' ' + sys.executable + ' ' + code_fname + ' ' + \ - paramf - if data_fname is not None: - cmd += ' ' + data_fname - - # Windows will fail to load the correct Qt plugin when launched with nrniv - # This is a temporary fix until separate windows are no longer launched - # as different processes - basedir = os.path.expanduser('~') - plugin_dir = op.join(basedir, 'Miniconda3', 'envs', 'hnn', 'Library', - 'plugins', 'platforms') - os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_dir - - # Split the command into shell arguments for passing to Popen - cmdargs = shlex.split(cmd, posix="win" not in sys.platform) - - # Start the simulation - proc = Popen(cmdargs, stdin=PIPE, stdout=PIPE, stderr=PIPE, - cwd=os.getcwd(), universal_newlines=True) - out, err = proc.communicate() - - # print all messages (including error messages) - print('STDOUT', out) - print('STDERR', err) - - if proc.returncode != 0: - raise RuntimeError("Running command %s failed" % cmd) - - -def test_view_rast(): - fname = 'spk.txt' - fetch_file(fname) - paramf = op.join('param', 'default.param') - view_window('visrast.py', paramf, fname) - - -def test_view_dipole(): - fname = 'dpl.txt' - fetch_file(fname) - paramf = op.join('param', 'default.param') - view_window('visdipole.py', paramf, fname) - - -def test_view_psd(): - paramf = op.join('param', 'default.param') - view_window('vispsd.py', paramf) - - -def test_view_spec(): - paramf = op.join('param', 'default.param') - view_window('visspec.py', paramf) diff --git a/visdipole.py b/visdipole.py deleted file mode 100644 index dad52ea0d..000000000 --- a/visdipole.py +++ /dev/null @@ -1,141 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from matplotlib.backends.backend_qt5agg import 
FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import pylab as plt -import matplotlib.gridspec as gridspec -from DataViewGUI import DataViewGUI -from neuron import h -from run import net -import paramrw -from filt import boxfilt, hammfilt -import spikefn -from math import ceil -from simdat import readdpltrials -from conf import dconf - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: dconf['fontsize'] = 10 - -tstop = -1; ntrial = 1; scalefctr = 30e3; dplpath = ''; paramf = '' -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.txt'): - dplpath = sys.argv[i] - elif sys.argv[i].endswith('.param'): - paramf = sys.argv[i] - scalefctr = paramrw.find_param(paramf,'dipole_scalefctr') - if type(scalefctr)!=float and type(scalefctr)!=int: scalefctr=30e3 - tstop = paramrw.find_param(paramf,'tstop') - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - -basedir = os.path.join(dconf['datdir'],paramf.split(os.path.sep)[-1].split('.param')[0]) - -ddat = {} -ddat['dpltrials'] = readdpltrials(basedir,ntrial) -try: - ddat['dpl'] = np.loadtxt(os.path.join(basedir,'dpl.txt')) -except: - print('Could not load',dplpath) - quit() - -class DipoleCanvas (FigureCanvas): - - def __init__ (self, paramf, index, parent=None, width=12, height=10, dpi=120, title='Dipole Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - self.title = title - self.setParent(parent) - self.gui = parent - self.index = index - FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.plot() - - def clearaxes (self): - try: - for ax in self.lax: - ax.set_yticks([]) - ax.cla() - except: - pass - - def drawdipole (self, fig): - - gdx = 311 - - ltitle = ['Layer 2/3', 'Layer 5', 'Aggregate'] - - white_patch = mpatches.Patch(color='white', label='Average') - gray_patch = mpatches.Patch(color='gray', label='Individual') - lpatch = [] - - if len(ddat['dpltrials']) > 0: lpatch = [white_patch,gray_patch] - - yl = [1e9,-1e9] - - for i in [2,3,1]: - yl[0] = min(yl[0],ddat['dpl'][:,i].min()) - yl[1] = max(yl[1],ddat['dpl'][:,i].max()) - if len(ddat['dpltrials']) > 0: # plot dipoles from individual trials - for dpltrial in ddat['dpltrials']: - yl[0] = min(yl[0],dpltrial[:,i].min()) - yl[1] = max(yl[1],dpltrial[:,i].max()) - - yl = tuple(yl) - - self.lax = [] - - for i,title in zip([2, 3, 1],ltitle): - ax = fig.add_subplot(gdx) - self.lax.append(ax) - - if i == 1: ax.set_xlabel('Time (ms)'); - - lw = self.gui.linewidth - if self.index != 0: lw = self.gui.linewidth + 2 - - if len(ddat['dpltrials']) > 0: # plot dipoles from individual trials - for ddx,dpltrial in enumerate(ddat['dpltrials']): - if self.index == 0 or (self.index > 0 and ddx == self.index-1): - ax.plot(dpltrial[:,0],dpltrial[:,i],color='gray',linewidth=lw) - - # average dipole (across trials) - if self.index == 0: ax.plot(ddat['dpl'][:,0],ddat['dpl'][:,i],'w',linewidth=self.gui.linewidth+2) - - ax.set_ylabel(r'(nAm $\times$ '+str(scalefctr)+')') - if tstop != -1: ax.set_xlim((0,tstop)) - ax.set_ylim(yl) - - if i == 2 and len(ddat['dpltrials']) > 0: ax.legend(handles=lpatch) - - ax.set_facecolor('k') - ax.grid(True) - ax.set_title(title) - - gdx += 1 - - self.figure.subplots_adjust(bottom=0.06, left=0.06, right=1.0, top=0.97, wspace=0.1, hspace=0.09) - - def plot (self): - self.drawdipole(self.figure) - self.draw() - - if 
"TRAVIS_TESTING" in os.environ and os.environ["TRAVIS_TESTING"] == "1": - print("Exiting gracefully with TRAVIS_TESTING=1") - qApp.quit() - exit(0) - -if __name__ == '__main__': - app = QApplication(sys.argv) - ex = DataViewGUI(DipoleCanvas,paramf,ntrial,'Dipole Viewer') - sys.exit(app.exec_()) diff --git a/vislfp.py b/vislfp.py deleted file mode 100644 index 9b85c1e11..000000000 --- a/vislfp.py +++ /dev/null @@ -1,299 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import pylab as plt -import matplotlib.gridspec as gridspec -from DataViewGUI import DataViewGUI -from neuron import h -from run import net -import paramrw -from filt import boxfilt, hammfilt, lowpass -import spikefn -from math import ceil -from conf import dconf -from specfn import MorletSpec - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] - -debug = True - -tstop = -1; ntrial = 1; maxlfp = 0; scalefctr = 30e3; lfppath = ''; paramf = ''; laminar = False -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.txt'): - lfppath = sys.argv[i] - elif sys.argv[i].endswith('.param'): - paramf = sys.argv[i] - scalefctr = paramrw.find_param(paramf,'dipole_scalefctr') - if type(scalefctr)!=float and type(scalefctr)!=int: scalefctr=30e3 - tstop = paramrw.find_param(paramf,'tstop') - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - -basedir = os.path.join(dconf['datdir'],paramf.split(os.path.sep)[-1].split('.param')[0]) - -ddat = {}; tvec = None; dspec = None - -def readLFPs (basedir, ntrial): - if debug: print('readLFPs') - ddat = {'lfp':{}} - lfile = os.listdir(basedir) - maxlfp = 0; tvec = None - if debug: print('readLFPs:',lfile) - for f in lfile: - if f.count('lfp_') > 0 and f.endswith('.txt'): - lf = f.split('.txt')[0].split('_') - if debug: print('readLFPs: lf=',lf,'ntrial=',ntrial) - if ntrial > 1: - trial = int(lf[1]) - nlfp = int(lf[2]) - else: - trial = 0 - nlfp = int(lf[1]) - maxlfp = max(nlfp,maxlfp) - if debug: print('readLFPs:',trial,nlfp,maxlfp) - fullpath = os.path.join(basedir,f) - if debug: print('readLFPs: fullpath=',fullpath) - try: - k2 = (trial,nlfp) - #print('k2:',k2) - ddat['lfp'][k2] = np.loadtxt(fullpath) - if tvec is None: tvec = ddat['lfp'][k2][:,0] - except: - print('exception!') - print('readLFPs:',ddat['lfp'].keys()) - #print('ddat:',ddat,maxlfp) - return ddat, maxlfp, tvec - -# lowpass filter the items in lfps. lfps is a list or numpy array of LFPs arranged spatially by row -def getlowpass (lfps,sampr,maxf): - return np.array([lowpass(lfp,maxf,df=sampr,zerophase=True) for lfp in lfps]) - -# gets 2nd spatial derivative of voltage as approximation of CSD. 
-# performs lowpass filter on voltages before taking spatial derivative -# input dlfp is dictionary of LFP voltage time-series keyed by (trial, electrode) -# output dCSD is keyed by trial -def getCSD (dlfp,sampr,minf=0.1,maxf=300.0): - if debug: print('getCSD:','sampr=',sampr,'ntrial=',ntrial,'maxlfp=',maxlfp) - dCSD = {} - for trial in range(ntrial): - if debug: print('trial:',trial) - lfps = [dlfp[(trial,i)][:,1] for i in range(maxlfp+1)] - datband = getlowpass(lfps,sampr,maxf) - dCSD[trial] = -np.diff(datband,n=2,axis=0) # now each row is an electrode -- CSD along electrodes - return dCSD - -try: - ddat, maxlfp, tvec = readLFPs(basedir,ntrial) - if maxlfp > 1: laminar = True - ddat['spec'] = {} - waveprm = {'f_max_spec':40.0,'dt':tvec[1]-tvec[0],'tstop':tvec[-1]} - minwavet = 50.0 - sampr = 1e3 / (tvec[1]-tvec[0]) - - if laminar: - print('getting CSD') - ddat['CSD'] = getCSD(ddat['lfp'],sampr) - if ntrial > 1: - ddat['avgCSD'] = np.zeros(ddat['CSD'][1].shape) - for i in range(ntrial): ddat['avgCSD'] += ddat['CSD'][i] - ddat['avgCSD']/=float(ntrial) - - print('Extracting Wavelet spectrogram(s).') - for i in range(maxlfp+1): - for trial in range(ntrial): - ddat['spec'][(trial,i)] = MorletSpec(tvec, ddat['lfp'][(trial,i)][:,1],None,None,waveprm,minwavet) - if ntrial > 1: - if debug: print('here') - davglfp = {}; davgspec = {} - for i in range(maxlfp+1): - if debug: print(i,maxlfp,list(ddat['lfp'].keys())[0]) - davglfp[i] = np.zeros(len(ddat['lfp'][list(ddat['lfp'].keys())[0]]),) - try: - ms = ddat['spec'][(0,0)] - if debug: print('shape',ms.TFR.shape,ms.tmin,ms.f[0],ms.f[-1]) - davgspec[i] = [np.zeros(ms.TFR.shape), ms.tmin, ms.f] - except: - print('err in davgspec[i]=') - for trial in range(ntrial): - davglfp[i] += ddat['lfp'][(trial,i)][:,1] - davgspec[i][0] += ddat['spec'][(trial,i)].TFR - davglfp[i] /= float(ntrial) - davgspec[i][0] /= float(ntrial) - ddat['avglfp'] = davglfp - ddat['avgspec'] = davgspec -except: - print('Could not load LFPs') - quit() - -def getnorm (yin): - yout = yin - min(yin) - return yout / max(yout) - -def getrngfctroff (dat): - yrng = [max(dat[i,:])-min(dat[i,:]) for i in range(dat.shape[0])] - mxrng = np.amax(yrng) - yfctr = [yrng[i]/mxrng for i in range(len(yrng))] - yoff = [maxlfp - 1 - (i + 1) for i in range(len(yrng))] - return yrng,yfctr,yoff - -class LFPCanvas (FigureCanvas): - - def __init__ (self, paramf, index, parent=None, width=12, height=10, dpi=120, title='LFP Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - self.title = title - self.setParent(parent) - self.index = index - FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.drawwavelet = True - - # get spec_cmap - p_exp = paramrw.ExpParams(self.paramf, 0) - if len(p_exp.expmt_groups) > 0: - expmt_group = p_exp.expmt_groups[0] - else: - expmt_group = None - p = p_exp.return_pdict(expmt_group, 0) - self.spec_cmap = p['spec_cmap'] - - self.plot() - - def clearaxes (self): - try: - for ax in self.lax: - ax.set_yticks([]) - ax.cla() - except: - pass - - def drawCSD (self, fig, G): - ax = fig.add_subplot(G[:,2]) - ax.set_yticks([]) - lw = 2; clr = 'k' - if ntrial > 1: - if self.index == 0: - cax = ax.imshow(ddat['avgCSD'],extent=[0, tstop, 0, maxlfp-1], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap),interpolation='None') - # overlay the time-series - yrng,yfctr,yoff = getrngfctroff(ddat['avgCSD']) - for i in range(ddat['avgCSD'].shape[0]): - y = yfctr[i] * 
getnorm(ddat['avgCSD'][i,:]) + yoff[i] - ax.plot(tvec,y,clr,linewidth=lw) - else: - cax = ax.imshow(ddat['CSD'][self.index-1],extent=[0, tstop, 0, maxlfp-1], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap),interpolation='None') - # overlay the time-series - yrng,yfctr,yoff = getrngfctroff(ddat['CSD'][self.index-1]) - for i in range(ddat['CSD'][self.index-1].shape[0]): - y = yfctr[i] * getnorm(ddat['CSD'][self.index-1][i,:]) + yoff[i] - ax.plot(tvec,y,clr,linewidth=lw) - else: - # draw CSD as image; blue/red corresponds to excit/inhib - cax = ax.imshow(ddat['CSD'][0],extent=[0, tstop, 0, 15], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap),interpolation='None') - # overlay the time-series - yrng,yfctr,yoff = getrngfctroff(ddat['CSD'][0]) - for i in range(ddat['CSD'][0].shape[0]): - y = yfctr[i] * getnorm(ddat['CSD'][0][i,:]) + yoff[i] - ax.plot(tvec,y,clr,linewidth=lw) - cbaxes = fig.add_axes([0.69, 0.88, 0.005, 0.1]) - fig.colorbar(cax, cax=cbaxes, orientation='vertical') - ax.set_xlim((minwavet,tstop)); ax.set_ylim((0,maxlfp-1)) - - def drawLFP (self, fig): - - if laminar: - nrow = maxlfp+1 - ncol = 3 - ltitle = ['' for x in range(nrow*ncol)] - else: - nrow = (maxlfp+1) * 2 - ncol = 1 - ltitle = ['LFP'+str(x) for x in range(nrow)] - - G = gridspec.GridSpec(nrow,ncol) - - white_patch = mpatches.Patch(color='white', label='Average') - gray_patch = mpatches.Patch(color='gray', label='Individual') - lpatch = [] - - if debug: print('ntrial:',ntrial) - - if ntrial > 1: lpatch = [white_patch,gray_patch] - - yl = [1e9,-1e9] - - minx = 100 - - for i in [1]: # this gets min,max LFP values - # print('ddat[lfp].keys():',ddat['lfp'].keys()) - for k in ddat['lfp'].keys(): - yl[0] = min(yl[0],ddat['lfp'][k][minx:-1,i].min()) - yl[1] = max(yl[1],ddat['lfp'][k][minx:-1,i].max()) - - yl = tuple(yl) # y-axis range - - self.lax = [] - - for nlfp in range(maxlfp+1): - title = ltitle[nlfp] - - if laminar: ax = fig.add_subplot(G[nlfp, 0]) - else: ax = fig.add_subplot(G[nlfp*2]) - - self.lax.append(ax) - - if self.index == 0: # draw all along with average - if ntrial > 1: clr = 'gray' - else: clr = 'white' - for i in range(ntrial): ax.plot(tvec,ddat['lfp'][(i,nlfp)][:,1],color=clr,linewidth=2) - if ntrial > 1: - ax.plot(tvec,ddat['avglfp'][nlfp],'w',linewidth=3) - if nlfp == 0: ax.legend(handles=lpatch) - else: # draw individual trial - ax.plot(tvec,ddat['lfp'][(self.index-1,nlfp)][:,1],color='white',linewidth=2) - - if not laminar: ax.set_ylabel(r'$\mu V$') - if tstop != -1: ax.set_xlim((minwavet,tstop)) - ax.set_ylim(yl) - - ax.set_facecolor('k'); ax.grid(True); ax.set_title(title) - - # plot wavelet spectrogram - if laminar: ax = fig.add_subplot(G[nlfp, 1]) - else: ax = fig.add_subplot(G[nlfp*2+1]) - self.lax.append(ax) - if self.index == 0: - if ntrial > 1: - TFR,tmin,F = ddat['avgspec'][nlfp] - ax.imshow(TFR, extent=[tmin, tvec[-1], F[-1], F[0]], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap)) - else: - ms = ddat['spec'][(0,nlfp)] - ax.imshow(ms.TFR, extent=[ms.tmin, tvec[-1], ms.f[-1], ms.f[0]], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap)) - else: - ms = ddat['spec'][(self.index-1,nlfp)] - ax.imshow(ms.TFR, extent=[ms.tmin, tvec[-1], ms.f[-1], ms.f[0]], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap)) - ax.set_xlim(minwavet,tvec[-1]) - if nlfp == maxlfp: ax.set_xlabel('Time (ms)') - if not laminar: ax.set_ylabel('Frequency (Hz)'); - - if laminar: self.drawCSD(fig, G) - - self.figure.subplots_adjust(bottom=0.04, left=0.04, 
right=1.0, top=0.99, wspace=0.1, hspace=0.01) - - def plot (self): - self.drawLFP(self.figure) - self.draw() - -if __name__ == '__main__': - app = QApplication(sys.argv) - ex = DataViewGUI(LFPCanvas,paramf,ntrial,'LFP Viewer') - sys.exit(app.exec_()) diff --git a/visnet.py b/visnet.py deleted file mode 100644 index 3a01a9d6d..000000000 --- a/visnet.py +++ /dev/null @@ -1,231 +0,0 @@ -import sys, os -import pyqtgraph as pg -from pyqtgraph.Qt import QtCore, QtGui -#from pyqtgraph.graphicsItems.AxisItem import * -import pyqtgraph.opengl as gl -import pyqtgraph as pg -import numpy as np - -from morphology import shapeplot, getshapecoords -from mpl_toolkits.mplot3d import Axes3D -import pylab as plt -from neuron import h -from L5_pyramidal import L5Pyr -from L2_pyramidal import L2Pyr -from L2_basket import L2Basket -from L5_basket import L5Basket -from run import net - -drawallcells = True # False -cell = net.cells[-1] - -# colors for the different cell types -dclr = {'L2_pyramidal' : 'g', L2Pyr: (0.,1.,0.,0.6), - 'L5_pyramidal' : 'r', L5Pyr: (1.,0.,0.,0.6), - 'L2_basket' : 'k', L2Basket: (1.,1.,1.,0.6), - 'L5_basket' : 'b', L5Basket: (0.,0.,1.,0.6)} - -def getcellpos (net,ty): - lx,ly = [],[] - for cell in net.cells: - if type(cell) == ty: - lx.append(cell.pos[0]) - ly.append(cell.pos[1]) - return lx,ly - -def cellsecbytype (ty): - lss = [] - for cell in net.cells: - if type(cell) == ty: - ls = cell.get_sections() - for s in ls: lss.append(s) - return lss - -def getdrawsec (ncells=1,ct=L2Pyr): - global cell - if drawallcells: return list(h.allsec()) - ls = [] - nfound = 0 - for c in net.cells: - if type(c) == ct: - cell = c - lss = c.get_sections() - for s in lss: ls.append(s) - nfound += 1 - if nfound >= ncells: break - return ls - -dsec = {} -for ty in [L2Pyr, L5Pyr, L2Basket, L5Basket]: dsec[ty] = cellsecbytype(ty) -dlw = {L2Pyr:1, L5Pyr:1,L2Basket:4,L5Basket:4} -whichdraw = [L2Pyr, L2Basket, L5Pyr, L5Basket] - -lsecnames = cell.get_section_names() - -def get3dinfo (sidx,eidx): - llx,lly,llz,lldiam = [],[],[],[] - for i in range(sidx,eidx,1): - lx,ly,lz,ldiam = net.cells[i].get3dinfo() - llx.append(lx); lly.append(ly); llz.append(lz); lldiam.append(ldiam) - return llx,lly,llz,lldiam - -llx,lly,llz,lldiam = get3dinfo(0,len(net.cells)) - -def countseg (ls): return sum([s.nseg for s in ls]) - -defclr = 'k'; selclr = 'r' -useGL = True -fig = None - -def drawcellspylab3d (): - global shapeax,fig - plt.ion(); fig = plt.figure() - shapeax = plt.subplot(111, projection='3d') - #shapeax.set_xlabel('X',fontsize=24); shapeax.set_ylabel('Y',fontsize=24); shapeax.set_zlabel('Z',fontsize=24) - shapeax.set_xticks([]); shapeax.set_yticks([]); shapeax.set_zticks([]) - shapeax.view_init(elev=105,azim=-71) - shapeax.grid(False) - lshapelines = [] - for ty in whichdraw: - ls = dsec[ty] - lshapelines.append(shapeplot(h,shapeax,sections=ls,cvals=[dclr[ty] for i in range(countseg(ls))],lw=dlw[ty])) - return lshapelines - -if not useGL: drawcellspylab3d() - -def onclick(event): - try: - print('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % - (event.button, event.x, event.y, event.xdata, event.ydata)) - except: - pass - -def setcolor (ls,clr): - for l in ls: l.set_color(clr) - -# click on section event handler - not used for network -def onpick (event): - print('onpick') - thisline = event.artist - c = thisline.get_color() - idx = -1 - setcolor(shapelines,defclr) - for idx,l in enumerate(shapelines): - if l == thisline: - break - try: - print('idx is ', idx, 'selected',lsecnames[idx]) - xdata = thisline.get_xdata() - 
ydata = thisline.get_ydata() - ind = event.ind - points = tuple(zip(xdata[ind], ydata[ind])) - print('onpick points:', points) - if c == defclr: - thisline.set_color(selclr) - else: - thisline.set_color(defclr) - print(ind) - #print(dir(thisline)) - except: - pass - -def setcallbacks (): - if useGL: return [] - lcid = [] - if False: lcid.append(fig.canvas.mpl_connect('button_press_event', onclick)) - if not drawallcells: lcid.append(fig.canvas.mpl_connect('pick_event', onpick)) - return lcid - -lcid = setcallbacks() - -# -def drawinputs2d (cell,clr,ax): - for lsrc in [cell.ncfrom_L2Pyr, cell.ncfrom_L2Basket, cell.ncfrom_L5Pyr, cell.ncfrom_L5Basket]: - for src in lsrc: - precell = src.precell() - ax.plot([precell.pos[0],cell.pos[0]],[precell.pos[1],cell.pos[1]],clr) - -# -def drawconn2d (): - plt.figure() - ax = plt.gca() - """ - loc = np.array(net.pos_dict['L2_basket']) - plot(loc[:,0],loc[:,1],'ko',markersize=14) - loc = np.array(net.pos_dict['L2_pyramidal']) - plot(loc[:,0],loc[:,1],'ro',markersize=14) - loc = np.array(net.pos_dict['L2_basket']) - plot(loc[:,0],loc[:,1],'bo',markersize=10) - """ - lx = [cell.pos[0] for cell in net.cells] - ly = [cell.pos[1] for cell in net.cells] - ax.plot(lx,ly,'ko',markersize=14) - """ - self.ncfrom_L2Pyr = [] - self.ncfrom_L2Basket = [] - self.ncfrom_L5Pyr = [] - self.ncfrom_L5Basket = [] - """ - for cell in net.cells: - drawinputs2d(cell,'r',ax) - break - -# -def drawinputs3d (cell,clr,widg,width=2.0): - for lsrc in [cell.ncfrom_L2Pyr, cell.ncfrom_L2Basket, cell.ncfrom_L5Pyr, cell.ncfrom_L5Basket]: - for src in lsrc: - precell = src.precell() - pts = np.vstack([[precell.pos[0]*100,cell.pos[0]*100],[precell.pos[2],cell.pos[2]],[precell.pos[1]*100,cell.pos[1]*100]]).transpose() - plt = gl.GLLinePlotItem(pos=pts, color=clr, width=width, antialias=True, mode='lines') - widg.addItem(plt) - -# -def drawconn3d (widg,width=2.0,clr=(1.0,0.0,0.0,0.5)): - i = 0 - for cell in net.cells: - drawinputs3d(cell,clr,widg,width) - i += 1 - #if i > 20: break - -def drawcells3dgl (ty,widget,width=2.2): - for cell in net.cells: - if type(cell) != ty: continue - lx,ly,lz = getshapecoords(h,cell.get_sections()) - pts = np.vstack([lx,ly,lz]).transpose() - plt = gl.GLLinePlotItem(pos=pts, color=dclr[type(cell)], width=width, antialias=True, mode='lines') - #plt.showGrid(x=True,y=True) - widget.addItem(plt) - #axis = pg.AxisItem(orientation='bottom') - #print(dir(axis)) - #print(dir(widget)) - #print(widget.getViewport()) - #axis.linkToView(axis.getViewBox())#widget.getViewport()) - #widget.addItem(pg.AxisItem(orientation='bottom')) - -def drawallcells3dgl (wcells): - drawcells3dgl(L5Pyr,wcells,width=15.0) - drawcells3dgl(L2Pyr,wcells,width=15.0) - drawcells3dgl(L5Basket,wcells,width=40.0) - drawcells3dgl(L2Basket,wcells,width=40.0) - wcells.opts['distance'] = 4320.9087386478195 - wcells.opts['elevation']=105 - wcells.opts['azimuth']=-71 - wcells.opts['fov'] = 90 - wcells.setWindowTitle('Network Visualization') - -if __name__ == '__main__': - app = QtGui.QApplication([]) - widg = gl.GLViewWidget() - for s in sys.argv: - if s == 'cells': - drawallcells3dgl(widg) - if s == 'Econn': - drawconn3d(widg,clr=(1.0,0.0,0.0,0.25)) - if s == 'Iconn': - drawconn3d(widg,clr=(0.0,0.0,1.0,0.25)) - #app.axis = axis = pg.AxisItem(orientation='bottom') - #app.pqg_plot_item.showAxis('bottom',True) - widg.show() - if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'): - QtGui.QApplication.instance().exec_() - diff --git a/vispsd.py b/vispsd.py deleted file mode 100644 index 
269fce5dd..000000000 --- a/vispsd.py +++ /dev/null @@ -1,261 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import pylab as plt -import matplotlib.gridspec as gridspec -from DataViewGUI import DataViewGUI -from neuron import h -from run import net -import paramrw -from filt import boxfilt, hammfilt -import spikefn -from math import ceil, sqrt -from specfn import MorletSpec -from conf import dconf - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: dconf['fontsize'] = 10 - -ntrial = 1; specpath = ''; paramf = '' -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.txt'): - specpath = sys.argv[i] - elif sys.argv[i].endswith('.param'): - paramf = sys.argv[i] - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - -basedir = os.path.join(dconf['datdir'],paramf.split(os.path.sep)[-1].split('.param')[0]) -print('basedir:',basedir) - -ddat = {} -try: - specpath = os.path.join(basedir,'rawspec.npz') - print('specpath',specpath) - ddat['spec'] = np.load(specpath) -except: - print('Could not load',specpath) - quit() - -# assumes column 0 is time, rest of columns are time-series -def extractpsd (dat, fmax=120.0): - print('extractpsd',dat.shape) - lpsd = [] - tvec = dat[:,0] - dt = tvec[1] - tvec[0] - tstop = tvec[-1] - prm = {'f_max_spec':fmax,'dt':dt,'tstop':tstop} - for col in range(1,dat.shape[1],1): - ms = MorletSpec(tvec,dat[:,col],None,None,prm) - lpsd.append(np.mean(ms.TFR,axis=1)) - return ms.f, np.array(lpsd) - -class PSDCanvas (FigureCanvas): - def __init__ (self, paramf, index, parent=None, width=12, height=10, dpi=120, title='PSD Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - self.title = title - self.setParent(parent) - self.gui = parent - self.index = index - FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.invertedhistax = False - self.G = gridspec.GridSpec(10,1) - self.plot() - - def drawpsd (self, dspec, fig, G, ltextra=''): - - lax = [] - - lkF = ['f_L2', 'f_L5', 'f_L2'] - lkS = ['TFR_L2', 'TFR_L5', 'TFR'] - - plt.ion() - - gdx = 311 - - ltitle = ['Layer 2/3', 'Layer 5', 'Aggregate'] - - yl = [1e9,-1e9] - - for i in [0,1,2]: - ddat['avg'+str(i)] = avg = np.mean(dspec[lkS[i]],axis=1) - ddat['std'+str(i)] = std = np.std(dspec[lkS[i]],axis=1) / sqrt(dspec[lkS[i]].shape[1]) - yl[0] = min(yl[0],np.amin(avg-std)) - yl[1] = max(yl[1],np.amax(avg+std)) - - yl = tuple(yl) - xl = (dspec['f_L2'][0],dspec['f_L2'][-1]) - - for i,title in zip([0, 1, 2],ltitle): - ax = fig.add_subplot(gdx) - lax.append(ax) - - if i == 2: ax.set_xlabel('Frequency (Hz)'); - - ax.plot(dspec[lkF[i]],np.mean(dspec[lkS[i]],axis=1),color='w',linewidth=self.gui.linewidth+2) - avg 
= ddat['avg'+str(i)] - std = ddat['std'+str(i)] - ax.plot(dspec[lkF[i]],avg-std,color='gray',linewidth=self.gui.linewidth) - ax.plot(dspec[lkF[i]],avg+std,color='gray',linewidth=self.gui.linewidth) - - ax.set_ylim(yl) - ax.set_xlim(xl) - - ax.set_facecolor('k') - ax.grid(True) - ax.set_title(title) - ax.set_ylabel(r'$nAm^2$') - - gdx += 1 - return lax - - - def clearaxes (self): - try: - for ax in self.lax: - ax.set_yticks([]) - ax.cla() - except: - pass - - def clearlextdatobj (self): - if hasattr(self,'lextdatobj'): - for o in self.lextdatobj: - try: - o.set_visible(False) - except: - o[0].set_visible(False) - del self.lextdatobj - - def plotextdat (self, lF, lextpsd, lextfiles): # plot 'external' data (e.g. from experiment/other simulation) - - print('len(lax)',len(self.lax)) - - self.lextdatobj = [] - white_patch = mpatches.Patch(color='white', label='Simulation') - self.lpatch = [white_patch] - - ax = self.lax[2] # plot on agg - - yl = ax.get_ylim() - - cmap=plt.get_cmap('nipy_spectral') - csm = plt.cm.ScalarMappable(cmap=cmap); - csm.set_clim((0,100)) - - for f,lpsd,fname in zip(lF,lextpsd,lextfiles): - print(fname,len(f),lpsd.shape) - clr = csm.to_rgba(int(np.random.RandomState().uniform(5,101,1))) - avg = np.mean(lpsd,axis=0) - std = np.std(lpsd,axis=0) / sqrt(lpsd.shape[1]) - self.lextdatobj.append(ax.plot(f,avg,color=clr,linewidth=self.gui.linewidth+2)) - self.lextdatobj.append(ax.plot(f,avg-std,'--',color=clr,linewidth=self.gui.linewidth)) - self.lextdatobj.append(ax.plot(f,avg+std,'--',color=clr,linewidth=self.gui.linewidth)) - yl = ((min(yl[0],min(avg))),(max(yl[1],max(avg)))) - new_patch = mpatches.Patch(color=clr, label=fname.split(os.path.sep)[-1].split('.txt')[0]) - self.lpatch.append(new_patch) - - ax.set_ylim(yl) - self.lextdatobj.append(ax.legend(handles=self.lpatch)) - - def plot (self): - #self.clearaxes() - #plt.close(self.figure) - if self.index == 0: - self.lax = self.drawpsd(ddat['spec'],self.figure, self.G, ltextra='All Trials') - else: - specpathtrial = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'rawspec_'+str(self.index)+'.npz') - if 'spec'+str(self.index) not in ddat: - ddat['spec'+str(self.index)] = np.load(specpath) - self.lax=self.drawpsd(ddat['spec'+str(self.index)],self.figure, self.G, ltextra='Trial '+str(self.index)); - - self.figure.subplots_adjust(bottom=0.06, left=0.06, right=0.98, top=0.97, wspace=0.1, hspace=0.09) - - self.draw() - -class PSDViewGUI (DataViewGUI): - def __init__ (self,CanvasType,paramf,ntrial,title): - super(PSDViewGUI,self).__init__(CanvasType,paramf,ntrial,title) - self.addLoadDataActions() - self.lF = [] # frequencies associated with external data psd - self.lextpsd = [] # external data psd - self.lextfiles = [] # external data files - - if "TRAVIS_TESTING" in os.environ and os.environ["TRAVIS_TESTING"] == "1": - print("Exiting gracefully with TRAVIS_TESTING=1") - qApp.quit() - exit(0) - - def addLoadDataActions (self): - loadDataFile = QAction(QIcon.fromTheme('open'), 'Load data file.', self) - loadDataFile.setShortcut('Ctrl+D') - loadDataFile.setStatusTip('Load data file.') - loadDataFile.triggered.connect(self.loadDisplayData) - - clearDataFileAct = QAction(QIcon.fromTheme('close'), 'Clear data file.', self) - clearDataFileAct.setShortcut('Ctrl+C') - clearDataFileAct.setStatusTip('Clear data file.') - clearDataFileAct.triggered.connect(self.clearDataFile) - - self.fileMenu.addAction(loadDataFile) - self.fileMenu.addAction(clearDataFileAct) - - def loadDisplayData (self): - extdataf,dat = 
self.loadDataFileDialog() - if not extdataf: return - try: - f, lpsd = extractpsd(dat) - self.printStat('Extracted PSDs from ' + extdataf) - self.lextpsd.append(lpsd) - self.lextfiles.append(extdataf) - self.lF.append(f) - except: - self.printStat('Could not extract PSDs from ' + extdataf) - - try: - if len(self.lextpsd) > 0: - self.printStat('Plotting ext data PSDs.') - self.m.plotextdat(self.lF,self.lextpsd,self.lextfiles) - self.m.draw() # make sure new lines show up in plot - self.printStat('') - except: - self.printStat('Could not plot data from ' + extdataf) - - def loadDataFileDialog (self): - fn = QFileDialog.getOpenFileName(self, 'Open file', 'data') - if fn[0]: - try: - extdataf = os.path.abspath(fn[0]) # data file - dat = np.loadtxt(extdataf) - self.printStat('Loaded data in ' + extdataf + '. Extracting PSDs.') - return extdataf,dat - except: - self.printStat('Could not load data in ' + fn[0]) - return None,None - return None,None - - def clearDataFile (self): - self.m.clearlextdatobj() - self.lextpsd = [] - self.lextfiles = [] - self.lF = [] - self.m.draw() - - -if __name__ == '__main__': - app = QApplication(sys.argv) - ex = PSDViewGUI(PSDCanvas,paramf,ntrial,'PSD Viewer') - sys.exit(app.exec_()) - diff --git a/visrast.py b/visrast.py deleted file mode 100644 index 49cbbee7c..000000000 --- a/visrast.py +++ /dev/null @@ -1,400 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox, QInputDialog -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import pylab as plt -import matplotlib.gridspec as gridspec -from neuron import h -from run import net -import paramrw -from filt import boxfilt, hammfilt -import spikefn -from math import ceil -from conf import dconf -from gutils import getmplDPI - -#plt.rcParams['lines.markersize'] = 15 -plt.rcParams['lines.linewidth'] = 1 -rastmarksz = 5 # raster dot size -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: plt.rcParams['font.size'] = dconf['fontsize'] = 10 - -# colors for the different cell types -dclr = {'L2_pyramidal' : 'g', - 'L5_pyramidal' : 'r', - 'L2_basket' : 'w', - 'L5_basket' : 'b'} - -ntrial = 1; tstop = -1; outparamf = spkpath = paramf = ''; EvokedInputs = OngoingInputs = PoissonInputs = False; - -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.txt'): - spkpath = sys.argv[i] - elif sys.argv[i].endswith('.param'): - paramf = sys.argv[i] - tstop = paramrw.quickgetprm(paramf,'tstop',float) - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - EvokedInputs = paramrw.usingEvokedInputs(paramf) - OngoingInputs = paramrw.usingOngoingInputs(paramf) - PoissonInputs = paramrw.usingPoissonInputs(paramf) - outparamf = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'param.txt') - -extinputs = spikefn.ExtInputs(spkpath, outparamf) 
-extinputs.add_delay_times() - -alldat = {} - -ncell = len(net.cells) - -binsz = 5.0 -smoothsz = 0 # no smoothing - -bDrawHist = True # whether to draw histograms (spike counts per time) - -# adjust input gids for display purposes -def adjustinputgid (extinputs, gid): - if gid == extinputs.gid_prox: - return 0 - elif gid == extinputs.gid_dist: - return 1 - elif extinputs.is_prox_gid(gid): - return 2 - elif extinputs.is_dist_gid(gid): - return 3 - return gid - -def getdspk (fn): - ddat = {} - try: - ddat['spk'] = np.loadtxt(fn) - except: - print('Could not load',fn) - quit() - dspk = {'Cell':([],[],[]),'Input':([],[],[])} - dhist = {} - for ty in dclr.keys(): dhist[ty] = [] - haveinputs = False - for (t,gid) in ddat['spk']: - ty = net.gid_to_type(gid) - if ty in dclr: - dspk['Cell'][0].append(t) - dspk['Cell'][1].append(gid) - dspk['Cell'][2].append(dclr[ty]) - dhist[ty].append(t) - else: - dspk['Input'][0].append(t) - dspk['Input'][1].append(adjustinputgid(extinputs, gid)) - if extinputs.is_prox_gid(gid): - dspk['Input'][2].append('r') - elif extinputs.is_dist_gid(gid): - dspk['Input'][2].append('g') - else: - dspk['Input'][2].append('orange') - haveinputs = True - for ty in dhist.keys(): - dhist[ty] = np.histogram(dhist[ty],range=(0,tstop),bins=int(tstop/binsz)) - if smoothsz > 0: - #dhist[ty] = boxfilt(dhist[ty][0],smoothsz) - dhist[ty] = hammfilt(dhist[ty][0],smoothsz) - else: - dhist[ty] = dhist[ty][0] - return dspk,haveinputs,dhist - -def drawhist (dhist,fig,G): - ax = fig.add_subplot(G[-4:-1,:]) - fctr = 1.0 - if ntrial > 0: fctr = 1.0 / ntrial - for ty in dhist.keys(): - ax.plot(np.arange(binsz/2,tstop+binsz/2,binsz),dhist[ty]*fctr,dclr[ty],linestyle='--') - ax.set_xlim((0,tstop)) - ax.set_ylabel('Cell Spikes') - return ax - -invertedax = False - -def drawrast (dspk, fig, G, sz=8): - global invertedax - lax = [] - lk = ['Cell'] - row = 0 - - if haveinputs: - lk.append('Input') - lk.reverse() - - dinput = extinputs.inputs - - for i,k in enumerate(lk): - if k == 'Input': # input spiking - - bins = ceil(150. * tstop / 1000.) 
# bins needs to be an int - - haveEvokedDist = (EvokedInputs and len(dinput['evdist'])>0) - haveOngoingDist = (OngoingInputs and len(dinput['dist'])>0) - haveEvokedProx = (EvokedInputs and len(dinput['evprox'])>0) - haveOngoingProx = (OngoingInputs and len(dinput['prox'])>0) - - if haveEvokedDist or haveOngoingDist: - ax = fig.add_subplot(G[row:row+2,:]); row += 2 - lax.append(ax) - if haveEvokedDist: extinputs.plot_hist(ax,'evdist',0,bins,(0,tstop),color='g',hty='step') - if haveOngoingDist: extinputs.plot_hist(ax,'dist',0,bins,(0,tstop),color='g') - ax.invert_yaxis() - ax.set_ylabel('Distal Input') - - if haveEvokedProx or haveOngoingProx: - ax2 = fig.add_subplot(G[row:row+2,:]); row += 2 - lax.append(ax2) - if haveEvokedProx: extinputs.plot_hist(ax2,'evprox',0,bins,(0,tstop),color='r',hty='step') - if haveOngoingProx: extinputs.plot_hist(ax2,'prox',0,bins,(0,tstop),color='r') - ax2.set_ylabel('Proximal Input') - - if PoissonInputs and len(dinput['pois']): - axp = fig.add_subplot(G[row:row+2,:]); row += 2 - lax.append(axp) - extinputs.plot_hist(axp,'pois',0,bins,(0,tstop),color='orange') - axp.set_ylabel('Poisson Input') - - else: # local circuit neuron spiking - - endrow = -1 - if bDrawHist: endrow = -4 - - ax = fig.add_subplot(G[row:endrow,:]) - lax.append(ax) - - ax.scatter(dspk[k][0],dspk[k][1],c=dspk[k][2],s=sz**2) - ax.set_ylabel(k + ' ID') - white_patch = mpatches.Patch(color='white', label='L2/3 Basket') - green_patch = mpatches.Patch(color='green', label='L2/3 Pyr') - red_patch = mpatches.Patch(color='red', label='L5 Pyr') - blue_patch = mpatches.Patch(color='blue', label='L5 Basket') - ax.legend(handles=[white_patch,green_patch,blue_patch,red_patch],loc='best') - ax.set_ylim((-1,ncell+1)) - ax.invert_yaxis() - - return lax - -class SpikeCanvas (FigureCanvas): - def __init__ (self, paramf, index, parent=None, width=12, height=10, dpi=120, title='Spike Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - self.title = title - self.setParent(parent) - self.index = index - FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.invertedhistax = False - self.G = gridspec.GridSpec(16,1) - self.plot() - - def clearaxes (self): - try: - for ax in self.lax: - ax.set_yticks([]) - ax.cla() - except: - pass - - def loadspk (self,idx): - global haveinputs,extinputs - if idx in alldat: return - alldat[idx] = {} - if idx == 0: - try: - extinputs = spikefn.ExtInputs(spkpath, outparamf) - except ValueError: - print("Error: could not load spike timings from %s" % spkpath) - return - extinputs.add_delay_times() - dspk,haveinputs,dhist = getdspk(spkpath) - alldat[idx]['dspk'] = dspk - alldat[idx]['haveinputs'] = haveinputs - alldat[idx]['dhist'] = dhist - alldat[idx]['extinputs'] = extinputs - else: - spkpathtrial = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'spk_'+str(self.index-1)+'.txt') - dspktrial,haveinputs,dhisttrial = getdspk(spkpathtrial) # show spikes from first trial - try: - extinputs = spikefn.ExtInputs(spkpathtrial, outparamf) - except ValueError: - print("Error: could not load spike timings from %s" % spkpath) - return - extinputs.add_delay_times() - alldat[idx]['dspk'] = dspktrial - alldat[idx]['haveinputs'] = haveinputs - alldat[idx]['dhist'] = dhisttrial - alldat[idx]['extinputs'] = extinputs - - def plot (self): - global haveinputs,extinputs - - self.loadspk(self.index) - - idx = self.index - dspk = alldat[idx]['dspk'] - 
haveinputs = alldat[idx]['haveinputs'] - dhist = alldat[idx]['dhist'] - extinputs = alldat[idx]['extinputs'] - - self.lax = drawrast(dspk,self.figure, self.G, rastmarksz) - - if bDrawHist: self.lax.append(drawhist(dhist,self.figure,self.G)) - - for ax in self.lax: - ax.set_facecolor('k') - ax.grid(True) - if tstop != -1: ax.set_xlim((0,tstop)) - - if idx == 0: self.lax[0].set_title('All Trials') - else: self.lax[0].set_title('Trial '+str(self.index)) - - self.lax[-1].set_xlabel('Time (ms)'); - - self.figure.subplots_adjust(bottom=0.0, left=0.06, right=1.0, top=0.97, wspace=0.1, hspace=0.09) - - self.draw() - -class SpikeGUI (QMainWindow): - def __init__ (self): - global dfile, ddat, paramf - super().__init__() - self.initUI() - - if "TRAVIS_TESTING" in os.environ and os.environ["TRAVIS_TESTING"] == "1": - print("Exiting gracefully with TRAVIS_TESTING=1") - qApp.quit() - exit(0) - - def initMenu (self): - exitAction = QAction(QIcon.fromTheme('exit'), 'Exit', self) - exitAction.setShortcut('Ctrl+Q') - exitAction.setStatusTip('Exit HNN Spike Viewer.') - exitAction.triggered.connect(qApp.quit) - - menubar = self.menuBar() - fileMenu = menubar.addMenu('&File') - menubar.setNativeMenuBar(False) - fileMenu.addAction(exitAction) - - viewMenu = menubar.addMenu('&View') - drawHistAction = QAction('Toggle Histograms',self) - drawHistAction.setStatusTip('Toggle Histogram Drawing.') - drawHistAction.triggered.connect(self.toggleHist) - viewMenu.addAction(drawHistAction) - changeFontSizeAction = QAction('Change Font Size',self) - changeFontSizeAction.setStatusTip('Change Font Size.') - changeFontSizeAction.triggered.connect(self.changeFontSize) - viewMenu.addAction(changeFontSizeAction) - changeLineWidthAction = QAction('Change Line Width',self) - changeLineWidthAction.setStatusTip('Change Line Width.') - changeLineWidthAction.triggered.connect(self.changeLineWidth) - viewMenu.addAction(changeLineWidthAction) - changeMarkerSizeAction = QAction('Change Marker Size',self) - changeMarkerSizeAction.setStatusTip('Change Marker Size.') - changeMarkerSizeAction.triggered.connect(self.changeMarkerSize) - viewMenu.addAction(changeMarkerSizeAction) - - - def toggleHist (self): - global bDrawHist - bDrawHist = not bDrawHist - self.initCanvas() - self.m.plot() - - def changeFontSize (self): - i, okPressed = QInputDialog.getInt(self, "Set Font Size","Font Size:", plt.rcParams['font.size'], 1, 100, 1) - if okPressed: - plt.rcParams['font.size'] = dconf['fontsize'] = i - self.initCanvas() - self.m.plot() - - def changeLineWidth (self): - i, okPressed = QInputDialog.getInt(self, "Set Line Width","Line Width:", plt.rcParams['lines.linewidth'], 1, 20, 1) - if okPressed: - plt.rcParams['lines.linewidth'] = i - self.initCanvas() - self.m.plot() - - def changeMarkerSize (self): - global rastmarksz - i, okPressed = QInputDialog.getInt(self, "Set Marker Size","Font Size:", rastmarksz, 1, 100, 1) - if okPressed: - rastmarksz = i - self.initCanvas() - self.m.plot() - - def initCanvas (self): - try: # to avoid memory leaks remove any pre-existing widgets before adding new ones - self.grid.removeWidget(self.m) - self.grid.removeWidget(self.toolbar) - self.m.setParent(None) - self.toolbar.setParent(None) - self.m = self.toolbar = None - except: - pass - self.m = SpikeCanvas(paramf, self.index, parent = self, width=12, height=10, dpi=getmplDPI()) - # this is the Navigation widget - # it takes the Canvas widget and a parent - self.toolbar = NavigationToolbar(self.m, self) - self.grid.addWidget(self.toolbar, 0, 0, 1, 4); - 
self.grid.addWidget(self.m, 1, 0, 1, 4); - - def initUI (self): - self.initMenu() - self.statusBar() - self.setGeometry(300, 300, 1300, 1100) - self.setWindowTitle('Spike Viewer - ' + paramf) - self.grid = grid = QGridLayout() - self.index = 0 - self.initCanvas() - self.cb = QComboBox(self) - self.grid.addWidget(self.cb,2,0,1,4) - - if ntrial > 1: - self.cb.addItem('Show All Trials') - for i in range(ntrial): - self.cb.addItem('Show Trial ' + str(i+1)) - else: - self.cb.addItem('All Trials') - self.cb.activated[int].connect(self.onActivated) - - # need a separate widget to put grid on - widget = QWidget(self) - widget.setLayout(grid) - self.setCentralWidget(widget); - - try: self.setWindowIcon(QIcon(os.path.join('res','icon.png'))) - except: pass - - self.show() - - def onActivated(self, idx): - if idx != self.index: - self.index = idx - if self.index == 0: - self.statusBar().showMessage('Loading data from all trials.') - else: - self.statusBar().showMessage('Loading data from trial ' + str(self.index) + '.') - self.m.index = self.index - self.initCanvas() - self.m.plot() - self.statusBar().showMessage('') - -if __name__ == '__main__': - - app = QApplication(sys.argv) - ex = SpikeGUI() - sys.exit(app.exec_()) - - diff --git a/visspec.py b/visspec.py deleted file mode 100644 index 4100e150c..000000000 --- a/visspec.py +++ /dev/null @@ -1,291 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import pylab as plt -import matplotlib.gridspec as gridspec -from DataViewGUI import DataViewGUI -from specfn import MorletSpec -from conf import dconf -import simdat -from simdat import readdpltrials -import paramrw -from paramrw import quickgetprm - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] - -ntrial = 1; paramf = '' -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.txt'): - specpath = sys.argv[i] - elif sys.argv[i].endswith('.param'): - paramf = sys.argv[i] - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - -basedir = os.path.join(dconf['datdir'],paramf.split(os.path.sep)[-1].split('.param')[0]) -#print('basedir:',basedir,'paramf:',paramf,'ntrial:',ntrial) - -# assumes column 0 is time, rest of columns are time-series -def extractspec (dat, fmax=40.0): - global ntrial - #print('extractspec',dat.shape) - lspec = [] - tvec = dat[:,0] - dt = tvec[1] - tvec[0] - tstop = tvec[-1] - - prm = {'f_max_spec':fmax,'dt':dt,'tstop':tstop} - - if dat.shape[1] > 2: - for col in range(1,dat.shape[1],1): - ms = MorletSpec(tvec,dat[:,col],None,None,prm) - lspec.append(ms) - else: - ms = MorletSpec(tvec,dat[:,1],None,None,prm) - lspec.append(ms) - - ntrial = len(lspec) - - if ntrial > 1: - avgdipole = np.mean(dat[:,1:-1],axis=1) - else: - avgdipole = dat[:,1] - - avgspec = 
MorletSpec(tvec,avgdipole,None,None,prm) # !!should fix to average of individual spectrograms!! - - ltfr = [ms.TFR for ms in lspec] - npspec = np.array(ltfr) - avgspec.TFR = np.mean(npspec,axis=0)#,axis=0) - - return ms.f, lspec, avgdipole, avgspec - -def loaddat (fname): - try: - if fname.endswith('.txt'): - dat = np.loadtxt(fname) - print('Loaded data in ' + fname + '. Extracting Spectrograms.') - return dat - elif fname.endswith('.param'): - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - basedir = os.path.join(dconf['datdir'],paramf.split(os.path.sep)[-1].split('.param')[0]) - #simdat.updatedat(paramf) - #return paramf,simdat.ddat - if ntrial > 1: - ddat = readdpltrials(basedir,quickgetprm(paramf,'N_trials',int)) - #print('read dpl trials',ddat[0].shape) - dout = np.zeros((ddat[0].shape[0],1+ntrial)) - #print('set dout shape',dout.shape) - dout[:,0] = ddat[0][:,0] - for i in range(ntrial): - dout[:,i+1] = ddat[i][:,1] - return dout - else: - ddat = np.loadtxt(os.path.join(basedir,'dpl.txt')) - #print('ddat.shape:',ddat.shape) - dout = np.zeros((ddat.shape[0],2)) - #print('dout.shape:',dout.shape) - dout[:,0] = ddat[:,0] - dout[:,1] = ddat[:,1] - return dout - except: - print('Could not load data in ' + fname) - return None - return None - -class SpecCanvas (FigureCanvas): - def __init__ (self, paramf, index, parent=None, width=12, height=10, dpi=120, title='Spectrogram Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - self.title = title - self.setParent(parent) - self.gui = parent - self.index = index - FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.invertedhistax = False - self.G = gridspec.GridSpec(10,1) - self.dat = [] - self.lextspec = [] - self.lax = [] - self.avgdipole = [] - self.avgspec = [] - - # get spec_cmap - p_exp = paramrw.ExpParams(self.paramf, 0) - if len(p_exp.expmt_groups) > 0: - expmt_group = p_exp.expmt_groups[0] - else: - expmt_group = None - p = p_exp.return_pdict(expmt_group, 0) - self.spec_cmap = p['spec_cmap'] - - self.plot() - - def clearaxes (self): - try: - for ax in self.lax: - ax.set_yticks([]) - ax.cla() - except: - pass - - def clearlextdatobj (self): - if hasattr(self,'lextdatobj'): - for o in self.lextdatobj: - try: - o.set_visible(False) - except: - o[0].set_visible(False) - del self.lextdatobj - - def drawspec (self, dat, lspec, sdx, avgdipole, avgspec, fig, G, ltextra=''): - if len(lspec) == 0: return - - plt.ion() - - gdx = 211 - - ax = fig.add_subplot(gdx) - lax = [ax] - tvec = dat[:,0] - dt = tvec[1] - tvec[0] - tstop = tvec[-1] - - if sdx == 0: - for i in range(1,dat.shape[1],1): - #print('sdx is 0',dat.shape,i) - ax.plot(tvec, dat[:,i],linewidth=self.gui.linewidth,color='gray') - ax.plot(tvec,avgdipole,linewidth=self.gui.linewidth+1,color='black') - else: - ax.plot(dat[:,0], dat[:,sdx],linewidth=self.gui.linewidth+1,color='gray') - - ax.set_xlim(tvec[0],tvec[-1]) - ax.set_ylabel('Dipole (nAm)') - - gdx = 212 - - ax = fig.add_subplot(gdx) - - #print('sdx:',sdx,avgspec.TFR.shape) - - if sdx==0: ms = avgspec - else: ms = lspec[sdx-1] - #print('ms.TFR.shape:',ms.TFR.shape) - - ax.imshow(ms.TFR, extent=[tvec[0], tvec[-1], ms.f[-1], ms.f[0]], aspect='auto', origin='upper',cmap=plt.get_cmap(self.spec_cmap)) - - ax.set_xlim(tvec[0],tvec[-1]) - ax.set_xlabel('Time (ms)') - ax.set_ylabel('Frequency (Hz)'); - - lax.append(ax) - - return lax - - def plot (self): - ltextra = 'Trial '+str(self.index) - if self.index == 0: 
ltextra = 'All Trials' - self.lax = self.drawspec(self.dat, self.lextspec,self.index, self.avgdipole, self.avgspec, self.figure, self.G, ltextra=ltextra) - self.figure.subplots_adjust(bottom=0.06, left=0.06, right=0.98, top=0.97, wspace=0.1, hspace=0.09) - self.draw() - -class SpecViewGUI (DataViewGUI): - def __init__ (self,CanvasType,paramf,ntrial,title): - self.lF = [] # frequencies associated with external data spec - self.lextspec = [] # external data spec - self.lextfiles = [] # external data files - self.dat = None - self.avgdipole = [] - self.avgspec = [] - super(SpecViewGUI,self).__init__(CanvasType,paramf,ntrial,title) - self.addLoadDataActions() - #print('paramf:',paramf) - if len(paramf): - self.loadDisplayData(paramf) - - if "TRAVIS_TESTING" in os.environ and os.environ["TRAVIS_TESTING"] == "1": - print("Exiting gracefully with TRAVIS_TESTING=1") - qApp.quit() - exit(0) - - def initCanvas (self): - super(SpecViewGUI,self).initCanvas() - self.m.lextspec = self.lextspec - self.m.dat = self.dat - self.m.avgdipole = self.avgdipole - self.m.avgspec = self.avgspec - - def addLoadDataActions (self): - loadDataFile = QAction(QIcon.fromTheme('open'), 'Load data.', self) - loadDataFile.setShortcut('Ctrl+D') - loadDataFile.setStatusTip('Load experimental (.txt) / simulation (.param) data.') - loadDataFile.triggered.connect(self.loadDisplayData) - - clearDataFileAct = QAction(QIcon.fromTheme('close'), 'Clear data.', self) - clearDataFileAct.setShortcut('Ctrl+C') - clearDataFileAct.setStatusTip('Clear data.') - clearDataFileAct.triggered.connect(self.clearDataFile) - - self.fileMenu.addAction(loadDataFile) - self.fileMenu.addAction(clearDataFileAct) - - def loadDisplayData (self, fname=None): - if fname is None or fname is False: - fname = QFileDialog.getOpenFileName(self, 'Open .param or .txt file', 'data') - fname = os.path.abspath(fname[0]) - if not fname: return - dat = loaddat(fname) - self.dat = dat - try: - try: - fmax = quickgetprm(paramf,'f_max_spec',float) - except: - fmax = 40. 
- f, lspec, avgdipole, avgspec = extractspec(dat,fmax=fmax) - self.ntrial = len(lspec) - self.updateCB() - self.printStat('Extracted ' + str(len(lspec)) + ' spectrograms from ' + fname) - self.lextspec = lspec - self.lextfiles.append(fname) - self.avgdipole = avgdipole - self.avgspec = avgspec - self.lF.append(f) - except: - self.printStat('Could not extract Spectrograms from ' + fname) - - try: - if len(self.lextspec) > 0: - self.printStat('Plotting Spectrograms.') - self.m.lextspec = self.lextspec - self.m.dat = self.dat - self.m.avgspec = self.avgspec - self.m.avgdipole = self.avgdipole - self.m.plot() - self.m.draw() # make sure new lines show up in plot - self.printStat('') - except: - self.printStat('Could not plot data from ' + fname) - - def clearDataFile (self): - self.m.clearlextdatobj() - self.lextspec = [] - self.lextfiles = [] - self.lF = [] - self.m.draw() - - -if __name__ == '__main__': - app = QApplication(sys.argv) - ex = SpecViewGUI(SpecCanvas,paramf,ntrial,'Spectrogram Viewer') - sys.exit(app.exec_()) - diff --git a/visvolt.py b/visvolt.py deleted file mode 100644 index a8058170d..000000000 --- a/visvolt.py +++ /dev/null @@ -1,233 +0,0 @@ -import sys, os -from PyQt5.QtWidgets import QMainWindow, QAction, qApp, QApplication, QToolTip, QPushButton, QFormLayout -from PyQt5.QtWidgets import QMenu, QSizePolicy, QMessageBox, QWidget, QFileDialog, QComboBox, QTabWidget -from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QDialog, QGridLayout, QLineEdit, QLabel -from PyQt5.QtWidgets import QCheckBox, QInputDialog -from PyQt5.QtGui import QIcon, QFont, QPixmap -from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal, QObject, pyqtSlot -from PyQt5 import QtCore -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import pylab as plt -import matplotlib.gridspec as gridspec -from neuron import h -from run import net -import paramrw -import pickle -from conf import dconf -from gutils import getmplDPI - -if dconf['fontsize'] > 0: plt.rcParams['font.size'] = dconf['fontsize'] -else: dconf['fontsize'] = 10 - -# colors for the different cell types -dclr = {'L2_pyramidal' : 'g', - 'L5_pyramidal' : 'r', - 'L2_basket' : 'w', - 'L5_basket' : 'b'} - -ntrial = 1; tstop = -1; outparamf = voltpath = paramf = ''; - -maxperty = 10 # how many cells of a type to draw - -for i in range(len(sys.argv)): - if sys.argv[i].endswith('.param'): - paramf = sys.argv[i] - tstop = paramrw.quickgetprm(paramf,'tstop',float) - ntrial = paramrw.quickgetprm(paramf,'N_trials',int) - outparamf = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'param.txt') - elif sys.argv[i] == 'maxperty': - maxperty = int(sys.argv[i]) - -if ntrial <= 1: - voltpath = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'vsoma.pkl') -else: - voltpath = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'vsoma_1.pkl') - -class VoltCanvas (FigureCanvas): - def __init__ (self, paramf, index, parent=None, width=12, height=10, dpi=120, title='Voltage Viewer'): - FigureCanvas.__init__(self, Figure(figsize=(width, height), dpi=dpi)) - self.title = title - self.setParent(parent) - self.gui = parent - self.index = index - self.invertedax = False - 
FigureCanvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) - FigureCanvas.updateGeometry(self) - self.paramf = paramf - self.G = gridspec.GridSpec(10,1) - self.plot() - - def drawvolt (self, dvolt, fig, G, sz=8, ltextra=''): - row = 0 - ax = fig.add_subplot(G[row:-1,:]) - lax = [ax] - dcnt = {} # counts number of times cell of a type drawn - vtime = dvolt['vtime'] - yoff = 0 - # print(dvolt.keys()) - for gid,it in dvolt.items(): - ty,vsoma = it[0],it[1] - # print('ty:',ty,'gid:',gid) - if type(gid) != int: continue - if ty not in dcnt: dcnt[ty] = 1 - if dcnt[ty] > maxperty: continue - #ax.plot(vtime, -vsoma + yoff, dclr[ty], linewidth = self.gui.linewidth) - ax.plot(vtime, -vsoma + yoff, dclr[ty], linewidth = self.gui.linewidth) - yoff += max(vsoma) - min(vsoma) - dcnt[ty] += 1 - - white_patch = mpatches.Patch(color='white', label='L2/3 Basket') - green_patch = mpatches.Patch(color='green', label='L2/3 Pyr') - red_patch = mpatches.Patch(color='red', label='L5 Pyr') - blue_patch = mpatches.Patch(color='blue', label='L5 Basket') - ax.legend(handles=[white_patch,green_patch,blue_patch,red_patch]) - - if not self.invertedax: - ax.set_ylim(ax.get_ylim()[::-1]) - self.invertedax = True - #if not self.invertedax: - # ax.invert_yaxis() - # self.invertedax = True - - ax.set_yticks([]) - - ax.set_facecolor('k') - ax.grid(True) - if tstop != -1: ax.set_xlim((0,tstop)) - if i ==0: ax.set_title(ltextra) - ax.set_xlabel('Time (ms)'); - - self.figure.subplots_adjust(bottom=0.01, left=0.01, right=0.99, top=0.99, wspace=0.1, hspace=0.09) - - return lax - - def plot (self): - if self.index == 0: - if ntrial == 1: - dvolt = pickle.load(open(voltpath,'rb')) - else: - dvolt = pickle.load(open(voltpath,'rb')) - self.lax = self.drawvolt(dvolt,self.figure, self.G, 5, ltextra='All Trials') - else: - voltpathtrial = os.path.join(dconf['datdir'],paramf.split('.param')[0].split(os.path.sep)[-1],'vsoma_'+str(self.index)+'.pkl') - dvolttrial = pickle.load(open(voltpathtrial,'rb')) - self.lax=self.drawvolt(dvolttrial,self.figure, self.G, 5, ltextra='Trial '+str(self.index)); - self.draw() - -class VoltGUI (QMainWindow): - def __init__ (self): - global dfile, ddat, paramf - super().__init__() - self.fontsize = dconf['fontsize'] - self.linewidth = plt.rcParams['lines.linewidth'] = 1 - self.markersize = plt.rcParams['lines.markersize'] = 5 - self.initUI() - - def changeFontSize (self): - i, okPressed = QInputDialog.getInt(self, "Set Font Size","Font Size:", plt.rcParams['font.size'], 1, 100, 1) - if okPressed: - self.fontsize = plt.rcParams['font.size'] = dconf['fontsize'] = i - self.initCanvas() - self.m.plot() - - def changeLineWidth (self): - i, okPressed = QInputDialog.getInt(self, "Set Line Width","Line Width:", plt.rcParams['lines.linewidth'], 1, 20, 1) - if okPressed: - self.linewidth = plt.rcParams['lines.linewidth'] = i - self.initCanvas() - self.m.plot() - - def changeMarkerSize (self): - i, okPressed = QInputDialog.getInt(self, "Set Marker Size","Font Size:", self.markersize, 1, 100, 1) - if okPressed: - self.markersize = plt.rcParams['lines.markersize'] = i - self.initCanvas() - self.m.plot() - - def initMenu (self): - exitAction = QAction(QIcon.fromTheme('exit'), 'Exit', self) - exitAction.setShortcut('Ctrl+Q') - exitAction.setStatusTip('Exit HNN Volt Viewer.') - exitAction.triggered.connect(qApp.quit) - - menubar = self.menuBar() - fileMenu = menubar.addMenu('&File') - menubar.setNativeMenuBar(False) - fileMenu.addAction(exitAction) - - viewMenu = menubar.addMenu('&View') - 
-    changeFontSizeAction = QAction('Change Font Size',self)
-    changeFontSizeAction.setStatusTip('Change Font Size.')
-    changeFontSizeAction.triggered.connect(self.changeFontSize)
-    viewMenu.addAction(changeFontSizeAction)
-    changeLineWidthAction = QAction('Change Line Width',self)
-    changeLineWidthAction.setStatusTip('Change Line Width.')
-    changeLineWidthAction.triggered.connect(self.changeLineWidth)
-    viewMenu.addAction(changeLineWidthAction)
-    changeMarkerSizeAction = QAction('Change Marker Size',self)
-    changeMarkerSizeAction.setStatusTip('Change Marker Size.')
-    changeMarkerSizeAction.triggered.connect(self.changeMarkerSize)
-    viewMenu.addAction(changeMarkerSizeAction)
-
-
-  def initCanvas (self):
-    self.invertedax = False
-    try: # to avoid memory leaks remove any pre-existing widgets before adding new ones
-      self.grid.removeWidget(self.m)
-      self.grid.removeWidget(self.toolbar)
-      self.m.setParent(None)
-      self.toolbar.setParent(None)
-      self.m = self.toolbar = None
-    except:
-      pass
-    self.m = VoltCanvas(paramf, self.index, parent = self, width=12, height=10, dpi=getmplDPI())
-    # this is the Navigation widget
-    # it takes the Canvas widget and a parent
-    self.toolbar = NavigationToolbar(self.m, self)
-    self.grid.addWidget(self.toolbar, 0, 0, 1, 4);
-    self.grid.addWidget(self.m, 1, 0, 1, 4);
-
-  def initUI (self):
-    self.initMenu()
-    self.statusBar()
-    self.setGeometry(300, 300, 1300, 1100)
-    self.setWindowTitle('Volt Viewer - ' + paramf)
-    self.grid = grid = QGridLayout()
-    self.index = 0
-    self.initCanvas()
-    self.cb = QComboBox(self)
-    self.grid.addWidget(self.cb,2,0,1,4)
-
-    for i in range(ntrial): self.cb.addItem('Trial ' + str(i+1))
-    self.cb.activated[int].connect(self.onActivated)
-
-    # need a separate widget to put grid on
-    widget = QWidget(self)
-    widget.setLayout(grid)
-    self.setCentralWidget(widget);
-
-    try: self.setWindowIcon(QIcon(os.path.join('res','icon.png')))
-    except: pass
-
-    self.show()
-
-  def onActivated(self, idx):
-    if idx != self.index:
-      self.index = idx
-      self.statusBar().showMessage('Loading data from trial ' + str(self.index+1) + '.')
-      self.m.index = self.index
-      self.initCanvas()
-      self.m.plot()
-      self.statusBar().showMessage('')
-
-if __name__ == '__main__':
-
-  app = QApplication(sys.argv)
-  ex = VoltGUI()
-  sys.exit(app.exec_())
-
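The removed visvolt.py rendered each cell's somatic voltage as a vertically offset trace, negated and color-coded by cell type, so that all cells stack in one axis. The snippet below is a minimal illustrative sketch of that stacking technique only; the data, variable names, and trace count are synthetic placeholders and it is not part of the HNN codebase.

# Illustrative sketch (not HNN code): stacked, color-coded soma voltage traces,
# the same presentation idea the removed visvolt.py implemented.
import numpy as np
import matplotlib.pyplot as plt

clr = {'L2_pyramidal': 'g', 'L5_pyramidal': 'r',
       'L2_basket': 'w', 'L5_basket': 'b'}

tvec = np.linspace(0.0, 170.0, 1000)           # time in ms (synthetic)
rng = np.random.default_rng(0)
traces = {                                      # gid -> (cell type, Vsoma); synthetic data
    0: ('L2_pyramidal', -65 + 5 * rng.standard_normal(tvec.size)),
    1: ('L5_pyramidal', -65 + 5 * rng.standard_normal(tvec.size)),
    2: ('L5_basket',    -65 + 5 * rng.standard_normal(tvec.size)),
}

fig, ax = plt.subplots()
yoff = 0.0
for gid, (ty, vsoma) in traces.items():
    # negate and offset each trace so successive cells stack without overlapping
    ax.plot(tvec, -vsoma + yoff, clr[ty], linewidth=1)
    yoff += vsoma.max() - vsoma.min()
ax.set_facecolor('k')                           # black background, as in the old viewer
ax.set_yticks([])
ax.set_xlabel('Time (ms)')
plt.show()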