From 66fdd0f3f6fcb77dc78bee63dc8901fecb23787b Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 22 Apr 2019 16:54:12 +0200 Subject: [PATCH] Improvements to offline tracking workflow --- NEWS.md => CHANGELOG.md | 7 ++ docs/source/userguide/7_offline.rst | 4 +- setup.py | 2 +- stytra/collectors/accumulators.py | 19 +---- stytra/hardware/video/__init__.py | 2 +- stytra/offline/track_video.py | 119 ++++++++++++++++++++-------- stytra/utilities.py | 31 ++++++++ 7 files changed, 131 insertions(+), 53 deletions(-) rename NEWS.md => CHANGELOG.md (95%) diff --git a/NEWS.md b/CHANGELOG.md similarity index 95% rename from NEWS.md rename to CHANGELOG.md index bc33500c..0435b1eb 100644 --- a/NEWS.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 0.8.7 +## Improvements +- improved workflow for offline tracking + +## Fixes +- offline tracking works on OS X + # 0.8.6 ## New features diff --git a/docs/source/userguide/7_offline.rst b/docs/source/userguide/7_offline.rst index 1475dbfe..7c7ee49e 100644 --- a/docs/source/userguide/7_offline.rst +++ b/docs/source/userguide/7_offline.rst @@ -6,11 +6,11 @@ With Stytra installed, start the offline tracking script by running: python -m stytra.offline.track_video -Choose a video file and what you want to track +Choose a video file and what you want to track. Run Stytra and adjust the tracking parameters. Please see the corresponding :ref:`documentation section` for hints. -Click the `Track video` button in the upper right corner. The progress bar on top will show the progress of the tracking. When it is done, you will get a csv file with the tracking results in the same folder and with the same name as the input video, just with an extension corresponding to the chosen output format. +Click the `Track video` button in the toolbar. The progress bar will show the progress of the tracking. When it is done, the program will exit and you will get a tracking output file. It will have the same name as the input video, just with an extension corresponding to the chosen output format. If you want to batch process multiple videos with the same parameters, running the Stytra pipeline through a script or notebook might be convenient. For this, please refer to the `notebook repository `_. diff --git a/setup.py b/setup.py index d8b96b9a..ab487b6a 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="stytra", - version="0.8.6", + version="0.8.7", author="Vilim Stih, Luigi Petrucco @portugueslab", author_email="vilim@neuro.mpg.de", license="GPLv3+", diff --git a/stytra/collectors/accumulators.py b/stytra/collectors/accumulators.py index 44340bca..2c760c6f 100644 --- a/stytra/collectors/accumulators.py +++ b/stytra/collectors/accumulators.py @@ -3,11 +3,12 @@ import numpy as np from queue import Empty import pandas as pd -import json from collections import namedtuple from bisect import bisect_right from os.path import basename +from stytra.utilities import save_df + class Accumulator(QObject): def __init__(self, experiment, name="", max_history_if_not_running = 1000): @@ -217,23 +218,11 @@ def save(self, path, format="csv"): output format, csv, feather, hdf5, json """ - outpath = path + "." + format df = self.get_dataframe() if df is None: return - - if format == "csv": - # replace True and False in csv files: - df.replace({True: 1, False: 0}).to_csv(outpath, sep=";") - elif format == "feather": - df.to_feather(outpath) - elif format == "hdf5": - df.to_hdf(outpath, "/data", complib="blosc", complevel=5) - elif format == "json": - json.dump(df.to_dict(), open(outpath, "w")) - else: - raise (NotImplementedError(format + " is not an implemented log foramt")) - return basename(outpath) + saved_filename = save_df(df, path, format) + return basename(saved_filename) class QueueDataAccumulator(DataFrameAccumulator): diff --git a/stytra/hardware/video/__init__.py b/stytra/hardware/video/__init__.py index cdee397f..411c03e4 100644 --- a/stytra/hardware/video/__init__.py +++ b/stytra/hardware/video/__init__.py @@ -62,7 +62,7 @@ class VideoSource(FrameProcess): """ - def __init__(self, rotation=False, max_mbytes_queue=100, n_consumers=1): + def __init__(self, rotation=False, max_mbytes_queue=200, n_consumers=1): """ """ super().__init__(name="camera") self.rotation = rotation diff --git a/stytra/offline/track_video.py b/stytra/offline/track_video.py index 6a5c241d..feee00d0 100644 --- a/stytra/offline/track_video.py +++ b/stytra/offline/track_video.py @@ -1,22 +1,95 @@ from pathlib import Path from stytra import Stytra from PyQt5.QtWidgets import QFileDialog, QApplication, QDialog, QPushButton,\ - QComboBox, QGridLayout, QLabel + QComboBox, QGridLayout, QLabel, QToolBar, QProgressBar, QVBoxLayout import qdarkstyle from stytra.stimulation import Protocol from stytra.stimulation.stimuli import Stimulus from stytra.experiments.fish_pipelines import pipeline_dict +from stytra.utilities import save_df import imageio import pandas as pd +import json class EmptyProtocol(Protocol): - name = "parameters" + name = "Offline" def get_stim_sequence(self): return [Stimulus(duration=5.),] +class TrackingDialog(QDialog): + def __init__(self): + super().__init__() + self.setLayout(QVBoxLayout()) + self.setWindowTitle("Tracking") + self.prog_track = QProgressBar() + self.lbl_status = QLabel() + self.layout().addWidget(self.prog_track) + self.layout().addWidget(self.lbl_status) + + +class OfflineToolbar(QToolBar): + def __init__(self, app, exp, input_path, pipeline_type): + super().__init__() + self.app = app + self.setObjectName("toolbar_offline") + self.exp = exp + self.input_path = Path(input_path) + self.pipeline_type = pipeline_type + self.output_path = self.input_path.parent / self.input_path.stem + + self.cmb_fmt = QComboBox() + self.cmb_fmt.addItems([ + "csv", "feather", "hdf5", "json"]) + + self.addAction("Track video", self.track) + self.addAction("Output format") + self.addWidget(self.cmb_fmt) + self.addSeparator() + self.addAction("Save tracking params", self.save_params) + + self.diag_track = TrackingDialog() + + def track(self): + + fileformat = self.cmb_fmt.currentText() + + self.exp.camera.kill_event.set() + reader = imageio.get_reader(str(self.input_path)) + data = [] + self.exp.window_main.stream_plot.toggle_freeze() + + output_name = str(self.output_path)+"."+fileformat + self.diag_track.show() + self.diag_track.prog_track.setMaximum(reader.get_length()) + self.diag_track.lbl_status.setText("Tracking to "+ + output_name) + + for i, frame in enumerate(reader): + data.append(self.exp.pipeline.run(frame[:, :, 0]).data) + self.diag_track.prog_track.setValue(i) + if i % 100 == 0: + self.app.processEvents() + + self.diag_track.lbl_status.setText("Saving " + + output_name) + df = pd.DataFrame.from_records(data, + columns=data[0]._fields) + save_df(df, self.output_path, fileformat) + self.diag_track.lbl_status.setText("Completed " + + output_name) + self.exp.wrap_up() + + def save_params(self): + params = self.exp.pipeline.serialize_params() + json.dump(dict(pipeline_type=self.pipeline_type, + pipeline_params=params), + open(str(self.output_path) + + "_trackingparams.json", "w")) + + class StytraLoader(QDialog): """ A quick-and-dirty monkey-patch of Stytra for easy offline tracking @@ -35,11 +108,6 @@ def __init__(self, app): self.cmb_tracking = QComboBox() self.cmb_tracking.addItems(list(pipeline_dict.keys())) - self.lbl_outformat = QLabel("Tracking output format") - self.cmb_fmt = QComboBox() - self.cmb_fmt.addItems([ - "csv", "feather", "hdf5", "json"]) - self.btn_start = QPushButton("Start stytra") self.btn_start.clicked.connect(self.run_stytra) self.btn_start.setEnabled(False) @@ -51,10 +119,7 @@ def __init__(self, app): self.layout().addWidget(self.lbl_whattrack, 1, 0) self.layout().addWidget(self.cmb_tracking, 1, 1) - self.layout().addWidget(self.lbl_outformat, 2, 0) - self.layout().addWidget(self.cmb_fmt, 2, 1) - - self.layout().addWidget(self.btn_start, 3, 0, 1, 2) + self.layout().addWidget(self.btn_start, 2, 0, 1, 2) self.stytra = None @@ -69,33 +134,19 @@ def run_stytra(self): self.stytra = Stytra(app=self.app, protocol=EmptyProtocol(), camera=dict(video_file=self.filename), tracking=dict(method=self.cmb_tracking.currentText()), - log_format=self.cmb_fmt.currentText(), exec=False) - btn_track = QPushButton("Track video") - self.stytra.exp.window_main.toolbar_control.addWidget(btn_track) - btn_track.clicked.connect(self.track) + + offline_toolbar = OfflineToolbar(self.app, + self.stytra.exp, + self.filename, + pipeline_type=self.cmb_tracking.currentText()) + + self.stytra.exp.window_main.toolbar_control.hide() + self.stytra.exp.window_main.addToolBar(offline_toolbar) + self.stytra.exp.window_display.hide() self.close() - def track(self): - assert isinstance(self.stytra, Stytra) - self.stytra.exp.camera.kill_event.set() - reader = imageio.get_reader(self.filename) - data = [] - self.stytra.exp.window_main.stream_plot.toggle_freeze() - self.stytra.exp.window_main.toolbar_control.progress_bar.setMaximum(reader.get_length()) - self.stytra.exp.window_main.toolbar_control.progress_bar.setFormat("%v / %m") - for i, frame in enumerate(reader): - data.append(self.stytra.exp.pipeline.run(frame[:, :, 0]).data) - self.stytra.exp.window_main.toolbar_control.progress_bar.setValue(i) - if i % 100 == 0: - self.app.processEvents() - df = pd.DataFrame.from_records(data, columns=data[0]._fields) - out_path = Path(self.filename) - df.to_csv(out_path.parent / (out_path.stem + ".csv")) - self.stytra.exp.wrap_up() - self.app.quit() - if __name__ == "__main__": app = QApplication([]) diff --git a/stytra/utilities.py b/stytra/utilities.py index 25ebe41f..cb0480ca 100644 --- a/stytra/utilities.py +++ b/stytra/utilities.py @@ -1,4 +1,5 @@ import datetime +import json import time from collections import OrderedDict from multiprocessing import Process, Queue @@ -248,3 +249,33 @@ def recursive_update(d, u): def reduce_to_pi(angle): """Puts an angle or array of angles inside the (-pi, pi) range""" return np.mod(angle + np.pi, 2 * np.pi) - np.pi + + +def save_df(df, path, fileformat): + """ Saves the dataframe in one of the supported formats + + Parameters + ---------- + df + path + fileformat + + Returns + ------- + + """ + outpath = Path(str(path) + "." + fileformat) + if fileformat == "csv": + # replace True and False in csv files: + df.replace({True: 1, False: 0}).to_csv(outpath, sep=";") + elif fileformat == "feather": + df.to_feather(outpath) + elif fileformat == "hdf5": + df.to_hdf(outpath, "/data", complib="blosc", complevel=5) + elif fileformat == "json": + json.dump(df.to_dict(), open(outpath, "w")) + else: + raise ( + NotImplementedError(fileformat + " is not an implemented log format")) + return outpath.name +