diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26d04f24..c697fbbd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,8 @@ +ci: + autofix_prs: false repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v5.0.0 hooks: - id: check-docstring-first - id: end-of-file-fixer @@ -8,7 +10,7 @@ repos: exclude: ^\.napari-hub/.* - id: check-yaml # checks for correct yaml syntax for github actions ex. - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.7 + rev: v0.11.9 hooks: - id: ruff args: [--fix] diff --git a/README.md b/README.md index b33210ac..66049071 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,10 @@ [![tests](https://github.com/funkelab/motile_tracker/workflows/tests/badge.svg)](https://github.com/funkelab/motile_tracker/actions) [![codecov](https://codecov.io/gh/funkelab/motile_tracker/branch/main/graph/badge.svg)](https://codecov.io/gh/funkelab/motile_tracker) -[![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/motile_tracker)](https://napari-hub.org/plugins/motile_tracker) The full documentation of the plugin can be found [here](https://funkelab.github.io/motile_tracker/). -A plugin for tracking with [motile](https://github.com/funkelab/motile) in napari. +An application for interactive tracking with [motile](https://github.com/funkelab/motile). Motile is a library that makes it easy to solve tracking problems using optimization by framing the task as an Integer Linear Program (ILP). See the motile [documentation](https://funkelab.github.io/motile) @@ -16,7 +15,7 @@ for more details on the concepts and method. ## Installation -This plugin depends on [motile](https://github.com/funkelab/motile), which in +This application depends on [motile](https://github.com/funkelab/motile), which in turn depends on gurobi and ilpy. These dependencies must be installed with conda before installing the plugin with pip.
@@ -25,6 +24,21 @@ conda before installing the plugin with pip. conda install -c conda-forge -c funkelab -c gurobi ilpy pip install motile-tracker +## Running Motile Tracker + +To run the application: +* activate the conda environment created in the [Installation Step](#installation) + + conda activate motile-tracker + +* Run: + + python -m motile_tracker + +or + + motile_tracker + ## Issues If you encounter any problems, please diff --git a/conda_config.yml b/conda_config.yml index dcc2121d..0a5cee13 100644 --- a/conda_config.yml +++ b/conda_config.yml @@ -3,6 +3,5 @@ channels: - conda-forge - funkelab - gurobi - - defaults dependencies: - ilpy diff --git a/docs/source/conf.py b/docs/source/conf.py index 123f7f2c..a0767f63 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,5 +1,5 @@ project = "Motile Tracker" -copyright = "2024, Howard Hughes Medical Institute" +copyright = "2024, Howard Hughes Medical Institute" # noqa: A001 author = "Caroline Malin-Mayor" extensions = [ diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 4b406cb1..a24ea958 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -74,7 +74,16 @@ is incorporating the detection and linking corrections into the optimization tas Each ``Tracking Run`` will be stored in the ``Results List`` widget. These are the runs that are stored in memory - if you run tracking multiple times with different inputs or parameters, you can click back and forth -between the results here. Here you can also save any runs that you want to store for later. +between the results here. Here you can also save any runs that you want to store for later, +or export the tracks to a csv file. If your input was a Labels layer, the +``node_id`` will be determined by segmentation label id. If your original segmentation +repeated labels across time, the application will relabel them all to be unique, and +the new label id will be used as the node id. 
+If your input was a Points layer, the ``node_id`` is simply the index of the +node in the list of points. +Note: This does not save the output segmentation. If you want to save +the relabeled segmentation, you can do so through napari by selecting the +layer and then selecting ``File``-> ``Save selected layers``. Deleting runs you do not want to keep viewing is a good idea, since these are stored in memory. Runs that were saved in previous sessions do not appear here until you load them from disk with the ``Load Tracks`` button. The tracking results can also be visualized as a lineage tree. diff --git a/docs/source/motile.rst b/docs/source/motile.rst index 6fc66253..221b3770 100644 --- a/docs/source/motile.rst +++ b/docs/source/motile.rst @@ -47,19 +47,6 @@ The ``Run Viewer`` contains the following information: - The ``Graph of solver gap``, which is mostly for debugging purposes. The solver gap is an optimization value that should decrease at each iteration. - The run settings, including ``Hyperparameters``, ``Costs``, and ``Attribute weights``. -- The ``Save run`` button. This button will take you to a file dialog to save the - whole run, so that if you close napari and re-open it, you can load the run - and see the results. -- The ``Export tracks to CSV`` button, which will take you to a file dialog for saving - a csv file containing the tracks. If your input was a Labels layer, the - ``node_id`` will be determined by segmentation label id. If your original segmentation - repeated labels across time, the application will relabel them all to be unique, and - the new label id will be used as the node id. - If your input was a Points layer, the ``node_id`` is simply the index of the - node in the list of points. - - Note: This does not save the output segmentation.
If you want to save - the relabeled segmentation, you can do so through napari by selecting the - layer and then selecting ``File``-> ``Save selected layers`` - The ``Back to editing`` button, which will return you to the ``Run Editor`` in its previous state. - The ``Edit this run`` button. This button will take you back to the ``Run Editor``, diff --git a/docs/source/tree_view.rst b/docs/source/tree_view.rst index c0c24710..4d9f93db 100644 --- a/docs/source/tree_view.rst +++ b/docs/source/tree_view.rst @@ -20,11 +20,20 @@ Please visit :doc:`key bindings ` page for a complete list of avai Viewing Externally Generated Tracks *********************************** It is also possible to view tracks that were not created from the motile widget using -the synchronized Tree View and napari layers. This is not accessible from the UI, so -you will need to make a python script to create a Tracks object and load it into the -viewer. +the synchronized Tree View and napari layers. To do so, navigate to the ``Results List`` tab and select ``External tracks from CSV`` in the dropdown menu at the bottom of the widget, and click ``Load``. +A pop-up menu will allow you to select a CSV file and map its columns to the required default attributes and optional additional attributes. You may also provide the accompanying segmentation and specify scaling information. -A `SolutionTracks object`_ contains a networkx graph representing the tracking result, and optionally +The following columns have to be selected: + +- time: representing the position of the object in the time dimension. +- x: x centroid coordinate of the object. +- y: y centroid coordinate of the object. +- z (optional): z centroid coordinate of the object, if it is a 3D object. +- id: unique id of the object. +- parent_id: id of the directly connected predecessor (parent) of the object. Should be empty if the object is at the start of a lineage.
+- seg_id: label value in the segmentation image data (if provided) that corresponds to the object id. + +From this, a `SolutionTracks object`_ is generated, containing a networkx graph representing the tracking result, and optionally a segmentation. The networkx graph is directed, with nodes representing detections and edges going from a detection in time t to the same object in t+n (edges go forward in time). Nodes must have an attribute representing time, by default named "time" but a different name @@ -32,9 +41,7 @@ can be stored in the ``Tracks.time_attr`` attribute. Nodes must also have one or representing position. The default way of storing positions on nodes is an attribute called "pos" containing a list of position values, but dimensions can also be stored in separate attributes (e.g. "x" and "y", each with one value). The name or list of names of the position attributes -should be specified in ``Tracks.pos_attr``. If you want to view tracks by area of the nodes, -you will also need to store the area of the corresponding segmentation on the nodes of the graph -in an ``area`` attribute. +should be specified in ``Tracks.pos_attr``. If a segmentation is provided but no ``area`` attribute, it will be computed automatically. The segmentation is expected to be a numpy array with time as the first dimension, followed by the position dimensions in the same order as the ``Tracks.pos_attr``. The segmentation @@ -43,7 +50,7 @@ motile_toolbox called ensure_unique_labels that relabels a segmentation to be un across time if needed. If a segmentation is provided, the node ids in the graph should match label id of the corresponding segmentation. -An example script that loads a tracks object from a CSV and segmentation array is provided in `scripts/view_external_tracks.csv`. Once you have a Tracks object in the format described above, +An example script that loads a tracks object from a CSV and segmentation array is provided in `scripts/view_external_tracks.py`. 
Once you have a Tracks object in the format described above, the following lines will view it in the Tree View and create synchronized napari layers (Points, Labels, and Tracks) to visualize the provided tracks.:: diff --git a/pyproject.toml b/pyproject.toml index 640e3c97..4ba0edfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ authors = [ ] classifiers = [ "Development Status :: 5 - Production/Stable", - "Framework :: napari", "Intended Audience :: Developers", "Intended Audience :: End Users/Desktop", "Intended Audience :: Science/Research", @@ -29,7 +28,9 @@ classifiers = [ ] dependencies =[ - "napari[all]", + "finn-viewer>=0.2", + "funtracks", + "appdirs", "numpy", "magicgui", "qtpy", @@ -38,23 +39,24 @@ dependencies =[ "motile_toolbox == 0.4.0", "pydantic", "tifffile[all]", + "tqdm", + "dask[array]>=2021.10.0", "fonticon-fontawesome6", "pyqtgraph", - "lxml_html_clean", # only to deal with napari dependencies being broken ] [project.optional-dependencies] -testing =["napari", "pyqt5", "pytest", "pytest-cov", "pytest-qt"] +testing =["pyqt5", "pytest", "pytest-cov", "pytest-qt"] docs = ["myst-parser", "sphinx", "sphinx-autoapi", "sphinx_rtd_theme", "sphinxcontrib-video"] dev = ["ruff", "pre-commit"] all = ["motile-tracker[testing,docs,dev]"] -[project.entry-points."napari.manifest"] -motile-tracker = "motile_tracker:napari.yaml" - [project.urls] "Bug Tracker" = "https://github.com/funkelab/motile_tracker/issues" "Documentation" ="https://funkelab.github.io/motile_tracker/" +[project.scripts] +motile_tracker = "motile_tracker.__main__:main" + [tool.setuptools_scm] [tool.ruff] @@ -87,7 +89,6 @@ unfixable = [ "B905", # currently adds strict=False to zips. 
Should add strict=True (manually) ] - [tool.ruff.lint.per-file-ignores] "scripts/*.py" = ["F"] diff --git a/scripts/run_hela.py b/scripts/run_hela.py index 7d9a3b7f..038b1716 100644 --- a/scripts/run_hela.py +++ b/scripts/run_hela.py @@ -4,9 +4,10 @@ import napari import zarr from appdirs import AppDirs +from napari.utils.theme import _themes + from motile_tracker.application_menus import MainApp from motile_tracker.data_views import TreeWidget -from napari.utils.theme import _themes logging.basicConfig( level=logging.INFO, diff --git a/scripts/test_edge_selection.py b/scripts/test_edge_selection.py deleted file mode 100644 index ea64113e..00000000 --- a/scripts/test_edge_selection.py +++ /dev/null @@ -1,152 +0,0 @@ -from rtree import index -from scipy.spatial import KDTree -import numpy as np - -def build_tree(cylinders): - """ - Build a Tree for 3D cylinders. - - :param cylinders: List of cylinders, where each cylinder is represented as ((x1, y1, z1), (x2, y2, z2), radius). - :return: KDTree object. 
- """ - p = index.Property() - p.dimension = 3 - idx = index.Index(properties=p) - bboxes = [] - - for i, cylinder in enumerate(cylinders): - p1, p2, radius = cylinder - # Calculate the bounding box of the cylinder - axis_vector = np.array(p2) - np.array(p1) - axis_length = np.linalg.norm(axis_vector) - - if axis_length == 0: - continue # Skip degenerate cylinders - - axis_vector /= axis_length # Normalize the axis vector - orthogonal_vector = np.array([1, 0, 0]) if axis_vector[0] < 0.9 else np.array([0, 1, 0]) - ortho1 = np.cross(axis_vector, orthogonal_vector) - ortho2 = np.cross(axis_vector, ortho1) - - corner_points = [p1, p2] - for ortho in [ortho1, ortho2]: - for sign in [-1, 1]: - direction = ortho * radius * sign - corner_points.extend([p1 + direction, p2 + direction]) - - min_bounds = np.min(corner_points, axis=0) - max_bounds = np.max(corner_points, axis=0) - bbox = tuple(min_bounds) + tuple(max_bounds) - bboxes += [bbox] - - # Insert the bounding box into the R-tree - idx.insert(i, bbox) - - # Build and return the KD-Tree - return idx, bboxes - -# Example usage: -# cylinders = [((0,0,0), (1,1,1), 0.5), ((2,2,2), (3,3,3), 0.3)] -# tree = build_kdtree(cylinders) - - -# TODO dont hardcode graph -graph = viewer.layers[-1].data - -cylinder_radius = 5 -cylinders = [] -for edge in graph.get_edges(): - (start_node, end_node) = edge[0] - # cylinders += [(graph.get_coordinates()[start_node], graph.get_coordinates()[end_node], cylinder_radius)] - coords = graph.get_coordinates(edge[0]) - - cylinders += [(coords[0,:], coords[1,:], cylinder_radius)] - -tree, bboxes = build_tree(cylinders) - - - - -def ray_intersects_cylinder(ray_origin, ray_direction, cylinder, tolerance=1e-6): - """ - Check if a ray intersects with a cylinder. - - :param ray_origin: Origin of the ray (x, y, z). - :param ray_direction: Direction of the ray (dx, dy, dz). - :param cylinder: Cylinder defined as ((x1, y1, z1), (x2, y2, z2), radius). 
- :param tolerance: Numerical tolerance for the intersection test. - :return: Boolean indicating if there is an intersection. - """ - p1, p2, radius = cylinder - d = np.array(p2) - np.array(p1) - m = np.array(ray_origin) - np.array(p1) - n = np.array(ray_direction) - - md = np.dot(m, d) - nd = np.dot(n, d) - dd = np.dot(d, d) - - # Coefficients for the quadratic equation - a = dd * np.dot(n, n) - nd * nd - b = dd * np.dot(n, m) - nd * md - c = dd * np.dot(m, m) - md * md - radius * radius * dd - - # If a is approximately zero, the ray is parallel to the cylinder axis - if abs(a) < tolerance: - return False - - # Quadratic formula discriminant - discr = b * b - a * c - - # If discriminant is negative, no real roots - no intersection - if discr < 0: - return False - - # Ray intersects cylinder - return True - -def distance_to_bbox(ray_origin, bbox): - """ - Calculate the distance from a point to the closest point on a bounding box. - - :param ray_origin: Origin of the ray (x, y, z). - :param bbox: Bounding box represented as a tuple (xmin, ymin, zmin, xmax, ymax, zmax). - :return: Distance from the ray origin to the closest point on the bounding box. - """ - xmin, ymin, zmin, xmax, ymax, zmax = bbox - closest_point = np.maximum(np.minimum(ray_origin, [xmax, ymax, zmax]), [xmin, ymin, zmin]) - return np.linalg.norm(ray_origin - closest_point) - - -def query_ray_intersection(ray_origin, ray_direction, cylinders, rtree_index, bboxes): - """ - Query the R-tree for a ray intersection with cylinders. - - :param ray_origin: Origin of the ray (x, y, z). - :param ray_direction: Direction of the ray (dx, dy, dz). - :param cylinders: List of cylinders. - :param rtree_index: R-tree index object. - :return: Index of the intersecting cylinder or None. 
- """ - # Define a large bounding box along the ray for querying the R-tree - ray_point_far = np.array(ray_origin) + np.array(ray_direction) * 10000 # Arbitrary large number - bbox = tuple(np.minimum(ray_origin, ray_point_far)) + tuple(np.maximum(ray_origin, ray_point_far)) - - # Query the R-tree for intersecting bounding boxes - candidates = list(rtree_index.intersection(bbox)) - candidates.sort(key=lambda idx: distance_to_bbox(ray_origin, bboxes[idx])) - - for idx in candidates: - if ray_intersects_cylinder(ray_origin, ray_direction, cylinders[idx]): - return idx # Return the first intersecting cylinder's index - - return None # No intersection found - -# Example usage: -ray_origin = (2, 367, 514) -ray_direction = (0, 1, 0) -result = query_ray_intersection(ray_origin, ray_direction, cylinders, tree, bboxes) - -(result, cylinders[result]) - - diff --git a/scripts/view_external_tracks.py b/scripts/view_external_tracks.py index c77ff893..15cb83cb 100644 --- a/scripts/view_external_tracks.py +++ b/scripts/view_external_tracks.py @@ -1,8 +1,10 @@ import napari +import pandas as pd + from motile_tracker.application_menus import MainApp from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer from motile_tracker.example_data import Fluo_N2DL_HeLa -from motile_tracker.utils.load_tracks import tracks_from_csv +from motile_tracker.import_export.load_tracks import tracks_from_df if __name__ == "__main__": # load the example data @@ -10,8 +12,27 @@ segmentation_arr = labels_layer_info[0] # the segmentation ids in this file correspond to the segmentation ids in the # example segmentation data, loaded above - csvfile = "hela_example_tracks.csv" - tracks = tracks_from_csv(csvfile, segmentation_arr) + csvfile = "scripts/hela_example_tracks.csv" + selected_columns = { + "time": "t", + "y": "y", + "x": "x", + "id": "id", + "parent_id": "parent_id", + "seg_id": "id", + } + + df = pd.read_csv(csvfile) + + # Create new columns for each feature based on the 
original column values + for feature, column in selected_columns.items(): + df[feature] = df[column] + + tracks = tracks_from_df( + df=df, + segmentation=segmentation_arr, + scale=[1, 1, 1], + ) viewer = napari.Viewer() raw_data, raw_kwargs, _ = raw_layer_info diff --git a/src/motile_tracker/__main__.py b/src/motile_tracker/__main__.py new file mode 100644 index 00000000..958932c6 --- /dev/null +++ b/src/motile_tracker/__main__.py @@ -0,0 +1,22 @@ +import sys + +import finn +from finn.track_application_menus.main_app import MainApp + +from motile_tracker.motile.menus.motile_widget import MotileWidget + + +def main(): + # Auto-load the motile tracker + viewer = finn.Viewer() + main_app = MainApp(viewer) + motile_widget = MotileWidget(viewer) + main_app.menu_widget.tabwidget.addTab(motile_widget, "Track with Motile") + viewer.window.add_dock_widget(main_app) + + # Start finn event loop + finn.run() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/motile_tracker/application_menus/__init__.py b/src/motile_tracker/application_menus/__init__.py deleted file mode 100644 index d47f2e4d..00000000 --- a/src/motile_tracker/application_menus/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .main_app import MainApp # noqa -from .editing_menu import EditingMenu # noqa -from .menu_widget import MenuWidget # noqa diff --git a/src/motile_tracker/application_menus/editing_menu.py b/src/motile_tracker/application_menus/editing_menu.py deleted file mode 100644 index a20ad873..00000000 --- a/src/motile_tracker/application_menus/editing_menu.py +++ /dev/null @@ -1,100 +0,0 @@ -import napari -from qtpy.QtWidgets import ( - QGroupBox, - QPushButton, - QVBoxLayout, - QWidget, -) - -from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer - - -class EditingMenu(QWidget): - def __init__(self, viewer: napari.Viewer): - super().__init__() - - self.tracks_viewer = TracksViewer.get_instance(viewer) - 
self.tracks_viewer.selected_nodes.list_updated.connect(self.update_buttons) - layout = QVBoxLayout() - - node_box = QGroupBox("Edit Node(s)") - node_box.setMaximumHeight(60) - node_box_layout = QVBoxLayout() - - self.delete_node_btn = QPushButton("Delete [D]") - self.delete_node_btn.clicked.connect(self.tracks_viewer.delete_node) - self.delete_node_btn.setEnabled(False) - # self.split_node_btn = QPushButton("Set split [S]") - # self.split_node_btn.clicked.connect(self.tracks_viewer.set_split_node) - # self.split_node_btn.setEnabled(False) - # self.endpoint_node_btn = QPushButton("Set endpoint [E]") - # self.endpoint_node_btn.clicked.connect(self.tracks_viewer.set_endpoint_node) - # self.endpoint_node_btn.setEnabled(False) - # self.linear_node_btn = QPushButton("Set linear [C]") - # self.linear_node_btn.clicked.connect(self.tracks_viewer.set_linear_node) - # self.linear_node_btn.setEnabled(False) - - node_box_layout.addWidget(self.delete_node_btn) - # node_box_layout.addWidget(self.split_node_btn) - # node_box_layout.addWidget(self.endpoint_node_btn) - # node_box_layout.addWidget(self.linear_node_btn) - - node_box.setLayout(node_box_layout) - - edge_box = QGroupBox("Edit Edge(s)") - edge_box.setMaximumHeight(100) - edge_box_layout = QVBoxLayout() - - self.delete_edge_btn = QPushButton("Break [B]") - self.delete_edge_btn.clicked.connect(self.tracks_viewer.delete_edge) - self.delete_edge_btn.setEnabled(False) - self.create_edge_btn = QPushButton("Add [A]") - self.create_edge_btn.clicked.connect(self.tracks_viewer.create_edge) - self.create_edge_btn.setEnabled(False) - - edge_box_layout.addWidget(self.delete_edge_btn) - edge_box_layout.addWidget(self.create_edge_btn) - - edge_box.setLayout(edge_box_layout) - - self.undo_btn = QPushButton("Undo (Z)") - self.undo_btn.clicked.connect(self.tracks_viewer.undo) - - self.redo_btn = QPushButton("Redo (R)") - self.redo_btn.clicked.connect(self.tracks_viewer.redo) - - layout.addWidget(node_box) - layout.addWidget(edge_box) - 
layout.addWidget(self.undo_btn) - layout.addWidget(self.redo_btn) - - self.setLayout(layout) - self.setMaximumHeight(300) - - def update_buttons(self): - """Set the buttons to enabled/disabled depending on the currently selected nodes""" - - n_selected = len(self.tracks_viewer.selected_nodes) - if n_selected == 0: - self.delete_node_btn.setEnabled(False) - # self.split_node_btn.setEnabled(False) - # self.endpoint_node_btn.setEnabled(False) - # self.linear_node_btn.setEnabled(False) - self.delete_edge_btn.setEnabled(False) - self.create_edge_btn.setEnabled(False) - - elif n_selected == 2: - self.delete_node_btn.setEnabled(True) - # self.split_node_btn.setEnabled(True) - # self.endpoint_node_btn.setEnabled(True) - # self.linear_node_btn.setEnabled(True) - self.delete_edge_btn.setEnabled(True) - self.create_edge_btn.setEnabled(True) - - else: - self.delete_node_btn.setEnabled(True) - # self.split_node_btn.setEnabled(True) - # self.endpoint_node_btn.setEnabled(True) - # self.linear_node_btn.setEnabled(True) - self.delete_edge_btn.setEnabled(False) - self.create_edge_btn.setEnabled(False) diff --git a/src/motile_tracker/application_menus/main_app.py b/src/motile_tracker/application_menus/main_app.py deleted file mode 100644 index f95111f3..00000000 --- a/src/motile_tracker/application_menus/main_app.py +++ /dev/null @@ -1,26 +0,0 @@ -import napari -from qtpy.QtWidgets import ( - QVBoxLayout, - QWidget, -) - -from motile_tracker.data_views.views.tree_view.tree_widget import TreeWidget - -from .menu_widget import MenuWidget - - -class MainApp(QWidget): - """Combines the different tracker widgets for faster dock arrangement""" - - def __init__(self, viewer: napari.Viewer): - super().__init__() - - menu_widget = MenuWidget(viewer) - tree_widget = TreeWidget(viewer) - - viewer.window.add_dock_widget(tree_widget, area="bottom", name="Tree View") - - layout = QVBoxLayout() - layout.addWidget(menu_widget) - - self.setLayout(layout) diff --git 
a/src/motile_tracker/application_menus/menu_widget.py b/src/motile_tracker/application_menus/menu_widget.py deleted file mode 100644 index f0af726e..00000000 --- a/src/motile_tracker/application_menus/menu_widget.py +++ /dev/null @@ -1,32 +0,0 @@ -import napari -from qtpy.QtWidgets import QScrollArea, QTabWidget, QVBoxLayout - -from motile_tracker.application_menus.editing_menu import EditingMenu -from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer -from motile_tracker.motile.menus.motile_widget import MotileWidget - - -class MenuWidget(QScrollArea): - """Combines the different tracker menus into tabs for cleaner UI""" - - def __init__(self, viewer: napari.Viewer): - super().__init__() - - tracks_viewer = TracksViewer.get_instance(viewer) - - motile_widget = MotileWidget(viewer) - editing_widget = EditingMenu(viewer) - - tabwidget = QTabWidget() - - tabwidget.addTab(motile_widget, "Track with Motile") - tabwidget.addTab(editing_widget, "Edit Tracks") - tabwidget.addTab(tracks_viewer.tracks_list, "Results List") - - layout = QVBoxLayout() - layout.addWidget(tabwidget) - - self.setWidget(tabwidget) - self.setWidgetResizable(True) - - self.setLayout(layout) diff --git a/src/motile_tracker/data_model/__init__.py b/src/motile_tracker/data_model/__init__.py deleted file mode 100644 index c1aac86b..00000000 --- a/src/motile_tracker/data_model/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .tracks import Tracks # noqa -from .solution_tracks import SolutionTracks # noqa -from .node_type import NodeType # noqa -from .tracks_controller import TracksController # noqa diff --git a/src/motile_tracker/data_model/action_history.py b/src/motile_tracker/data_model/action_history.py deleted file mode 100644 index b3dc5c6f..00000000 --- a/src/motile_tracker/data_model/action_history.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .actions import TracksAction - - -class 
ActionHistory: - """An action history implementing the ideas from this blog: - https://github.com/zaboople/klonk/blob/master/TheGURQ.md - Essentially, if you go back and change something after undo-ing, you can always get - back to every state if you undo far enough (instead of throwing out - the undone actions) - """ - - def __init__(self): - self.undo_stack: list[TracksAction] = [] # list of actions that can be undone - self.redo_stack: list[TracksAction] = [] # list of actions that can be redone - - @property - def undo_pointer(self): - return len(self.undo_stack) - len(self.redo_stack) - 1 - - def add_new_action(self, action: TracksAction) -> None: - """Add a newly performed action to the history. - Args: - action (TracksAction): The new action to be added to the history. - """ - if len(self.redo_stack) > 0: - # add all the redo stuff to the undo stack, so that both the originial and - # inverse are on the stack - self.undo_stack.extend(self.redo_stack) - self.redo_stack = [] - self.undo_stack.append(action) - - def undo(self) -> bool: - """Undo the last performed action - - Returns: - bool: True if an action was undone, and False - if there was no previous action to undo. - """ - if self.undo_pointer < 0: - return False - else: - action = self.undo_stack[self.undo_pointer] - inverse = action.inverse() - self.redo_stack.append(inverse) - return True - - def redo(self) -> bool: - """Redo the last undone action - - Returns: - bool: True if an action was redone, and False - if there was no undone action to redo. 
- """ - if len(self.redo_stack) == 0: - return False - else: - action = self.redo_stack.pop(-1) - # apply the inverse but don't save it - # (the original is already on the undo stack) - action.inverse() - return True diff --git a/src/motile_tracker/data_model/actions.py b/src/motile_tracker/data_model/actions.py deleted file mode 100644 index 6ee5ea1c..00000000 --- a/src/motile_tracker/data_model/actions.py +++ /dev/null @@ -1,327 +0,0 @@ -"""This module contains all the low level actions used to control a Tracks object. -Low level actions should control these aspects of Tracks: - - adding/removing nodes and edges to/from the segmentation and graph - - Updating the segmentation and graph attributes that are controlled by the segmentation. - Currently, position and area for nodes, and IOU for edges. - - Keeping track of information needed to undo the given action. For removing a node, - this means keeping track of the incident edges that were removed, along with their - attributes. - -The low level actions do not contain application logic, such as manipulating track ids, -or validation of "allowed" actions. -The actions should work on candidate graphs as well as solution graphs. -Action groups can be constructed to represent application-level actions constructed -from many low-level actions. -""" - -from __future__ import annotations - -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr - -from .solution_tracks import SolutionTracks -from .tracks import Attrs, Edge, Node, SegMask, Tracks - - -class TracksAction: - def __init__(self, tracks: Tracks): - """An modular change that can be applied to the given Tracks. The tracks must - be passed in at construction time so that metadata needed to invert the action - can be extracted. - The change should be applied in the init function. - - Args: - tracks (Tracks): The tracks that this action will edit - """ - self.tracks = tracks - - def inverse(self) -> TracksAction: - """Get the inverse of this action. 
Calling this function does undo the action, - since the change is applied in the action constructor. - - Raises: - NotImplementedError: if the inverse is not implemented in the subclass - - Returns: - TracksAction: An action that un-does this action, bringing the tracks - back to the exact state it had before applying this action. - """ - raise NotImplementedError("Inverse not implemented") - - -class ActionGroup(TracksAction): - def __init__( - self, - tracks: Tracks, - actions: list[TracksAction], - ): - """A group of actions that is also an action, used to modify the given tracks. - This is useful for creating composite actions from the low-level actions. - Composite actions can contain application logic and can be un-done as a group. - - Args: - tracks (Tracks): The tracks that this action will edit - actions (list[TracksAction]): A list of actions contained within the group, - in the order in which they should be executed. - """ - super().__init__(tracks) - self.actions = actions - - def inverse(self) -> ActionGroup: - actions = [action.inverse() for action in self.actions[::-1]] - return ActionGroup(self.tracks, actions) - - -class AddNodes(TracksAction): - """Action for adding new nodes. If a segmentation should also be added, the - pixels for each node should be provided. The label to set the pixels will - be taken from the node id. The existing pixel values are assumed to be - zero - you must explicitly update any other segmentations that were overwritten - using an UpdateNodes action if you want to be able to undo the action. 
- """ - - def __init__( - self, - tracks: Tracks, - nodes: list[Node], - attributes: Attrs, - pixels: list[SegMask] | None = None, - ): - """Create an action to add new nodes, with optional segmentation - - Args: - tracks (Tracks): The Tracks to add the nodes to - nodes (Node): A list of node ids - attributes (Attrs): Includes times and optionally positions - pixels (list[SegMask] | None, optional): The segmentations associated with each node. - Defaults to None. - """ - super().__init__(tracks) - self.nodes = nodes - user_attrs = attributes.copy() - self.times = attributes.get(NodeAttr.TIME.value, None) - if NodeAttr.TIME.value in attributes: - del user_attrs[NodeAttr.TIME.value] - self.positions = attributes.get(tracks.pos_attr, None) - if tracks.pos_attr in attributes: - del user_attrs[tracks.pos_attr] - self.pixels = pixels - self.attributes = user_attrs - self._apply() - - def inverse(self): - """Invert the action to delete nodes instead""" - return DeleteNodes(self.tracks, self.nodes) - - def _apply(self): - """Apply the action, and set segmentation if provided in self.pixels""" - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.nodes) - self.tracks.add_nodes( - self.nodes, self.times, self.positions, attrs=self.attributes - ) - - -class DeleteNodes(TracksAction): - """Action of deleting existing nodes - If the tracks contain a segmentation, this action also constructs a reversible - operation for setting involved pixels to zero - """ - - def __init__( - self, tracks: Tracks, nodes: list[Node], pixels: list[SegMask] | None = None - ): - super().__init__(tracks) - self.nodes = nodes - self.attributes = { - NodeAttr.TIME.value: self.tracks.get_times(nodes), - self.tracks.pos_attr: self.tracks.get_positions(nodes), - NodeAttr.TRACK_ID.value: self.tracks._get_nodes_attr( - nodes, NodeAttr.TRACK_ID.value - ), - } - self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels - self._apply() - - def inverse(self): - """Invert this 
action, and provide inverse segmentation operation if given""" - - return AddNodes(self.tracks, self.nodes, self.attributes, pixels=self.pixels) - - def _apply(self): - """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be removed - by this operation - Steps: - - For each node - set pixels to 0 if self.pixels is provided - - Remove nodes from graph - """ - if self.pixels is not None: - self.tracks.set_pixels( - self.pixels, - [0] * len(self.pixels), - ) - - self.tracks.remove_nodes(self.nodes) - - -class UpdateNodeSegs(TracksAction): - """Action for updating the segmentation associated with nodes. Cannot mix adding - and removing pixels from segmentation: the added flag applies to all nodes""" - - def __init__( - self, - tracks: Tracks, - nodes: list[Node], - pixels: list[SegMask], - added: bool = True, - ): - """ - Args: - tracks (Tracks): The tracks to update the segmenatations for - nodes (list[Node]): The nodes with updated segmenatations - pixels (list[SegMask]): The pixels that were updated for each node - added (bool, optional): If the provided pixels were added (True) or deleted - (False) from all nodes. Defaults to True. Cannot mix adding and deleting - pixels in one action. - """ - super().__init__(tracks) - self.nodes = nodes - self.pixels = pixels - self.added = added - self._apply() - - def inverse(self): - """Restore previous attributes""" - return UpdateNodeSegs( - self.tracks, - self.nodes, - pixels=self.pixels, - added=not self.added, - ) - - def _apply(self): - """Set new attributes""" - self.tracks.update_segmentations(self.nodes, self.pixels, self.added) - - -class UpdateNodeAttrs(TracksAction): - """Action for user updates to node attributes. 
Cannot update protected - attributes (time, area, track id), as these are controlled by internal application - logic.""" - - def __init__( - self, - tracks: Tracks, - nodes: list[Node], - attrs: Attrs, - ): - """ - Args: - tracks (Tracks): The tracks to update the node attributes for - nodes (list[Node]): The nodes to update the attributes for - attrs (Attrs): A mapping from attribute name to list of new attribute values - for the given nodes. - - Raises: - ValueError: If a protected attribute is in the given attribute mapping. - """ - super().__init__(tracks) - protected_attrs = [ - tracks.time_attr, - NodeAttr.AREA.value, - NodeAttr.TRACK_ID.value, - ] - for attr in attrs: - if attr in protected_attrs: - raise ValueError(f"Cannot update attribute {attr} manually") - self.nodes = nodes - self.prev_attrs = { - attr: self.tracks._get_nodes_attr(nodes, attr) for attr in attrs - } - self.new_attrs = attrs - self._apply() - - def inverse(self): - """Restore previous attributes""" - return UpdateNodeAttrs( - self.tracks, - self.nodes, - self.prev_attrs, - ) - - def _apply(self): - """Set new attributes""" - for attr, values in self.new_attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) - - -class AddEdges(TracksAction): - """Action for adding new edges""" - - def __init__(self, tracks: Tracks, edges: list[Edge]): - super().__init__(tracks) - self.edges = edges - self._apply() - - def inverse(self): - """Delete edges""" - return DeleteEdges(self.tracks, self.edges) - - def _apply(self): - """ - Steps: - - add each edge to the graph. 
Assumes all edges are valid (they should be checked at this point already) - """ - self.tracks.add_edges(self.edges) - - -class DeleteEdges(TracksAction): - """Action for deleting edges""" - - def __init__(self, tracks: Tracks, edges: list[Edge]): - super().__init__(tracks) - self.edges = edges - self._apply() - - def inverse(self): - """Restore edges and their attributes""" - return AddEdges(self.tracks, self.edges) - - def _apply(self): - """Steps: - - Remove the edges from the graph - """ - self.tracks.remove_edges(self.edges) - - -class UpdateTrackID(TracksAction): - def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): - """ - Args: - tracks (Tracks): The tracks to update - start_node (Node): The node ID of the first node in the track. All successors - with the same track id as this node will be updated. - track_id (int): The new track id to assign. - """ - super().__init__(tracks) - self.start_node = start_node - self.old_track_id = self.tracks.get_track_id(start_node) - self.new_track_id = track_id - self._apply() - - def inverse(self) -> TracksAction: - """Restore the previous track_id""" - return UpdateTrackID(self.tracks, self.start_node, self.old_track_id) - - def _apply(self): - """Assign a new track id to the track starting with start_node.""" - old_track_id = self.tracks.get_track_id(self.start_node) - curr_node = self.start_node - while self.tracks.get_track_id(curr_node) == old_track_id: - # update the track id - self.tracks.set_track_id(curr_node, self.new_track_id) - # getting the next node (picks one if there are two) - successors = list(self.tracks.graph.successors(curr_node)) - if len(successors) == 0: - break - curr_node = successors[0] diff --git a/src/motile_tracker/data_model/node_type.py b/src/motile_tracker/data_model/node_type.py deleted file mode 100644 index 4b5f0b2f..00000000 --- a/src/motile_tracker/data_model/node_type.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import Enum - - -class NodeType(Enum): - 
"""Types of nodes in the track graph. Currently used for standardizing - visualization. All nodes are exactly one type. - """ - - SPLIT = "SPLIT" - END = "END" - CONTINUE = "CONTINUE" diff --git a/src/motile_tracker/data_model/solution_tracks.py b/src/motile_tracker/data_model/solution_tracks.py deleted file mode 100644 index 4864a016..00000000 --- a/src/motile_tracker/data_model/solution_tracks.py +++ /dev/null @@ -1,166 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import networkx as nx -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr - -from .tracks import Tracks - -if TYPE_CHECKING: - from collections.abc import Iterable - from pathlib import Path - - import numpy as np - - from .tracks import Attrs, Node - - -class SolutionTracks(Tracks): - """Difference from Tracks: every node must have a track_id""" - - def __init__( - self, - graph: nx.DiGraph, - segmentation: np.ndarray | None = None, - time_attr: str = NodeAttr.TIME.value, - pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, - scale: list[float] | None = None, - ndim: int | None = None, - ): - super().__init__( - graph, - segmentation=segmentation, - time_attr=time_attr, - pos_attr=pos_attr, - scale=scale, - ndim=ndim, - ) - self.max_track_id: int - self._initialize_track_ids() - - @classmethod - def from_tracks(cls, tracks: Tracks): - return cls( - tracks.graph, - segmentation=tracks.segmentation, - time_attr=tracks.time_attr, - pos_attr=tracks.pos_attr, - scale=tracks.scale, - ndim=tracks.ndim, - ) - - @property - def node_id_to_track_id(self) -> dict[Node, int]: - return nx.get_node_attributes(self.graph, NodeAttr.TRACK_ID.value) - - def get_next_track_id(self) -> int: - """Return the next available track_id and update self.max_track_id""" - computed_max = max(self.node_id_to_track_id.values()) - if self.max_track_id < computed_max: - self.max_track_id = computed_max - self.max_track_id = self.max_track_id + 1 - return 
self.max_track_id - - def get_track_id(self, node) -> int: - track_id = int( - self._get_node_attr(node, NodeAttr.TRACK_ID.value, required=True) - ) - return track_id - - def set_track_id(self, node: Node, value: int): - old_track_id = self.get_track_id(node) - self.track_id_to_node[old_track_id].remove(node) - self._set_node_attr(node, NodeAttr.TRACK_ID.value, value) - if value not in self.track_id_to_node: - self.track_id_to_node[value] = [] - self.track_id_to_node[value].append(node) - - def _initialize_track_ids(self): - self.max_track_id = 0 - self.track_id_to_node = {} - - if self.graph.number_of_nodes() != 0: - if len(self.node_id_to_track_id) < self.graph.number_of_nodes(): - # not all nodes have a track id: reassign - self._assign_tracklet_ids() - else: - self.max_track_id = max(self.node_id_to_track_id.values()) - for node, track_id in self.node_id_to_track_id.items(): - if track_id not in self.track_id_to_node: - self.track_id_to_node[track_id] = [] - self.track_id_to_node[track_id].append(node) - - def _assign_tracklet_ids(self): - """Add a track_id attribute to a graph by removing division edges, - assigning one id to each connected component. 
- Also sets the max_track_id and initializes a dictionary from track_id to nodes - """ - graph_copy = self.graph.copy() - - parents = [node for node, degree in self.graph.out_degree() if degree >= 2] - intertrack_edges = [] - - # Remove all intertrack edges from a copy of the original graph - for parent in parents: - daughters = [child for p, child in self.graph.out_edges(parent)] - for daughter in daughters: - graph_copy.remove_edge(parent, daughter) - intertrack_edges.append((parent, daughter)) - - track_id = 1 - for tracklet in nx.weakly_connected_components(graph_copy): - nx.set_node_attributes( - self.graph, - {node: {NodeAttr.TRACK_ID.value: track_id} for node in tracklet}, - ) - self.track_id_to_node[track_id] = list(tracklet) - track_id += 1 - self.max_track_id = track_id - 1 - - def export_tracks(self, outfile: Path | str): - """Export the tracks from this run to a csv with the following columns: - t,[z],y,x,id,parent_id,track_id - Cells without a parent_id will have an empty string for the parent_id. 
- Whether or not to include z is inferred from self.ndim - """ - header = ["t", "z", "y", "x", "id", "parent_id", "track_id"] - if self.ndim == 3: - header = [header[0]] + header[2:] # remove z - with open(outfile, "w") as f: - f.write(",".join(header)) - for node_id in self.graph.nodes(): - parents = list(self.graph.predecessors(node_id)) - parent_id = "" if len(parents) == 0 else parents[0] - track_id = self.get_track_id(node_id) - time = self.get_time(node_id) - position = self.get_position(node_id) - row = [ - time, - *position, - node_id, - parent_id, - track_id, - ] - f.write("\n") - f.write(",".join(map(str, row))) - - def add_nodes( - self, - nodes: Iterable[Node], - times: Iterable[int], - positions: np.ndarray | None = None, - attrs: Attrs | None = None, - ): - # overriding add_nodes to add new nodes to the track_id_to_node mapping - super().add_nodes(nodes, times, positions, attrs) - for node, track_id in zip(nodes, attrs[NodeAttr.TRACK_ID.value], strict=True): - if track_id not in self.track_id_to_node: - self.track_id_to_node[track_id] = [] - self.track_id_to_node[track_id].append(node) - - def remove_nodes(self, nodes: Iterable[Node]): - # overriding remove_nodes to remove nodes from the track_id_to_node mapping - for node in nodes: - self.track_id_to_node[self.get_track_id(node)].remove(node) - super().remove_nodes(nodes) diff --git a/src/motile_tracker/data_model/tracks.py b/src/motile_tracker/data_model/tracks.py deleted file mode 100644 index e745675b..00000000 --- a/src/motile_tracker/data_model/tracks.py +++ /dev/null @@ -1,706 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import Iterable, Mapping, Sequence -from typing import ( - TYPE_CHECKING, - Any, - Optional, - TypeAlias, -) - -import networkx as nx -import numpy as np -from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr -from motile_toolbox.candidate_graph.iou import _compute_ious -from psygnal import Signal -from skimage import measure - -if 
TYPE_CHECKING: - from pathlib import Path - -AttrValue: TypeAlias = Any -Node: TypeAlias = int -Edge: TypeAlias = tuple[Node, Node] -AttrValues: TypeAlias = Sequence[AttrValue] -Attrs: TypeAlias = Mapping[str, AttrValues] -SegMask: TypeAlias = tuple[np.ndarray, ...] - - -class Tracks: - """A set of tracks consisting of a graph and an optional segmentation. - The graph nodes represent detections and must have a time attribute and - position attribute. Edges in the graph represent links across time. - - Attributes: - graph (nx.DiGraph): A graph with nodes representing detections and - and edges representing links across time. - segmentation (Optional(np.ndarray)): An optional segmentation that - accompanies the tracking graph. If a segmentation is provided, - the node ids in the graph must match the segmentation labels. - Defaults to None. - time_attr (str): The attribute in the graph that specifies the time - frame each node is in. - pos_attr (str | tuple[str] | list[str]): The attribute in the graph - that specifies the position of each node. Can be a single attribute - that holds a list, or a list of attribute keys. - - For bulk operations on attributes, a KeyError will be raised if a node or edge - in the input set is not in the graph. All operations before the error node will - be performed, and those after will not. 
- """ - - refresh = Signal(Optional[str]) - GRAPH_FILE = "graph.json" - SEG_FILE = "seg.npy" - ATTRS_FILE = "attrs.json" - - def __init__( - self, - graph: nx.DiGraph, - segmentation: np.ndarray | None = None, - time_attr: str = NodeAttr.TIME.value, - pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, - scale: list[float] | None = None, - ndim: int | None = None, - ): - self.graph = graph - self.segmentation = segmentation - self.time_attr = time_attr - self.pos_attr = pos_attr - self.scale = scale - self.ndim = self._compute_ndim(segmentation, scale, ndim) - - def get_positions( - self, nodes: Iterable[Node], incl_time: bool = False - ) -> np.ndarray: - """Get the positions of nodes in the graph. Optionally include the - time frame as the first dimension. Raises an error if any of the nodes - are not in the graph. - - Args: - node (Iterable[Node]): The node ids in the graph to get the positions of - incl_time (bool, optional): If true, include the time as the - first element of each position array. Defaults to False. - - Returns: - np.ndarray: A N x ndim numpy array holding the positions, where N is the - number of nodes passed in - """ - if isinstance(self.pos_attr, tuple | list): - positions = np.stack( - [ - self._get_nodes_attr(nodes, dim, required=True) - for dim in self.pos_attr - ], - axis=1, - ) - else: - positions = np.array( - self._get_nodes_attr(nodes, self.pos_attr, required=True) - ) - - if incl_time: - times = np.array(self._get_nodes_attr(nodes, self.time_attr, required=True)) - positions = np.c_[times, positions] - - return positions - - def get_position(self, node: Node, incl_time=False) -> list: - return self.get_positions([node], incl_time=incl_time)[0].tolist() - - def set_positions( - self, - nodes: Iterable[Node], - positions: np.ndarray, - incl_time: bool = False, - ): - """Set the location of nodes in the graph. Optionally include the - time frame as the first dimension. Raises an error if any of the nodes - are not in the graph. 
- - Args: - nodes (Iterable[node]): The node ids in the graph to set the location of. - positions (np.ndarray): An (ndim, num_nodes) shape array of positions to set. - f incl_time is true, time is the first column and is included in ndim. - incl_time (bool, optional): If true, include the time as the - first column of the position array. Defaults to False. - """ - if not isinstance(positions, np.ndarray): - positions = np.array(positions) - if incl_time: - self.set_times(nodes, positions[:, 0].tolist()) - positions = positions[:, 1:] - - if isinstance(self.pos_attr, tuple | list): - for idx, attr in enumerate(self.pos_attr): - self._set_nodes_attr(nodes, attr, positions[:, idx].tolist()) - else: - self._set_nodes_attr(nodes, self.pos_attr, positions.tolist()) - - def set_position(self, node: Node, position: list, incl_time=False): - self.set_positions( - [node], np.expand_dims(np.array(position), axis=0), incl_time=incl_time - ) - - def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: - return self._get_nodes_attr(nodes, self.time_attr, required=True) - - def get_time(self, node: Node) -> int: - """Get the time frame of a given node. Raises an error if the node - is not in the graph. - - Args: - node (Any): The node id to get the time frame for - - Returns: - int: The time frame that the node is in - """ - return int(self.get_times([node])[0]) - - def set_times(self, nodes: Iterable[Node], times: Iterable[int]): - times = [int(t) for t in times] - self._set_nodes_attr(nodes, self.time_attr, times) - - def set_time(self, node: Any, time: int): - """Set the time frame of a given node. Raises an error if the node - is not in the graph. 
- - Args: - node (Any): The node id to set the time frame for - time (int): The time to set - - """ - self.set_times([node], [int(time)]) - - def add_nodes( - self, - nodes: Iterable[Node], - times: Iterable[int], - positions: np.ndarray | None = None, - attrs: Attrs | None = None, - ): - """Add a set of nodes to the tracks object. Includes computing node attributes - (position, area) from the segmentation if there is one. Does not include setting - the segmentation pixels - assumes this is already done. - - Args: - nodes (Iterable[Node]): node ids to add - times (Iterable[int]): times of nodes to add - positions (np.ndarray | None, optional): The positions to set for each node, - if no segmentation is present. If segmentation is present, these provided - values will take precedence over the computed centroids. Defaults to None. - attrs (Attrs | None, optional): The additional attributes to add to each node. - Defaults to None. - - Raises: - ValueError: If neither positions nor segmentations are provided - """ - if attrs is None: - attrs = {} - self.graph.add_nodes_from(nodes) - self.set_times(nodes, times) - final_pos: np.ndarray - if self.segmentation is not None: - computed_attrs = self._compute_node_attrs(nodes, times) - if positions is None: - final_pos = np.array(computed_attrs[NodeAttr.POS.value]) - else: - final_pos = positions - attrs[NodeAttr.AREA.value] = computed_attrs[NodeAttr.AREA.value] - elif positions is None: - raise ValueError("Must provide positions or segmentation and ids") - else: - final_pos = positions - - self.set_positions(nodes, final_pos) - for attr, values in attrs.items(): - self._set_nodes_attr(nodes, attr, values) - - def add_node( - self, - node: Node, - time: int, - position: Sequence | None = None, - attrs: Attrs | None = None, - ): - """Add a node to the graph. Will update the internal mappings and generate the - segmentation-controlled attributes if there is a segmentation present. 
- The segmentation should have been previously updated, otherwise the - attributes will not update properly. - - Args: - node (Node): The node id to add - time (int): the time frame of the node to add - position (Sequence | None): The spatial position of the node (excluding time). - Can be None if it should be automatically detected from the segmentation. - Either segmentation or position must be provided. Defaults to None. - attrs (Attrs | None, optional): The additional attributes to add to node. - Defaults to None. - """ - pos = np.expand_dims(position, axis=0) if position is not None else None - attributes: dict[str, Sequence[Any]] | None = ( - {key: [val] for key, val in attrs.items()} if attrs is not None else None - ) - self.add_nodes([node], [time], positions=pos, attrs=attributes) - - def remove_nodes(self, nodes: Iterable[Node]): - self.graph.remove_nodes_from(nodes) - - def remove_node(self, node: Node): - """Remove the node from the graph. - Does not update the segmentation if present. - - Args: - node (Node): The node to remove from the graph - """ - self.remove_nodes([node]) - - def add_edges(self, edges: Iterable[Edge]): - attrs: dict[str, Sequence[Any]] = {} - attrs.update(self._compute_edge_attrs(edges)) - for idx, edge in enumerate(edges): - for node in edge: - if not self.graph.has_node(node): - raise KeyError( - f"Cannot add edge {edge}: endpoint {node} not in graph yet" - ) - self.graph.add_edge( - edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} - ) - - def add_edge(self, edge: Edge): - self.add_edges([edge]) - - def remove_edges(self, edges: Iterable[Edge]): - for edge in edges: - self.remove_edge(edge) - - def remove_edge(self, edge: Edge): - if self.graph.has_edge(*edge): - self.graph.remove_edge(*edge) - else: - raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") - - def get_areas(self, nodes: Iterable[Node]) -> Sequence[int | None]: - """Get the area/volume of a given node. 
Raises a KeyError if the node - is not in the graph. Returns None if the given node does not have an Area - attribute. - - Args: - node (Node): The node id to get the area/volume for - - Returns: - int: The area/volume of the node - """ - return self._get_nodes_attr(nodes, NodeAttr.AREA.value) - - def get_area(self, node: Node) -> int | None: - """Get the area/volume of a given node. Raises a KeyError if the node - is not in the graph. Returns None if the given node does not have an Area - attribute. - - Args: - node (Node): The node id to get the area/volume for - - Returns: - int: The area/volume of the node - """ - return self.get_areas([node])[0] - - def get_ious(self, edges: Iterable[Edge]): - return self._get_edges_attr(edges, EdgeAttr.IOU.value) - - def get_iou(self, edge: Edge): - return self._get_edge_attr(edge, EdgeAttr.IOU.value) - - def get_pixels(self, nodes: list[Node]) -> list[tuple[np.ndarray, ...]] | None: - """Get the pixels corresponding to each node in the nodes list. - - Args: - nodes (list[Node]): A list of node to get the values for. - - Returns: - list[tuple[np.ndarray, ...]] | None: A list of tuples, where each tuple - represents the pixels for one of the input nodes, or None if the segmentation - is None. The tuple will have length equal to the number of segmentation - dimensions, and can be used to index the segmentation. - """ - if self.segmentation is None: - return None - pix_list = [] - for node in nodes: - time = self.get_time(node) - loc_pixels = np.nonzero(self.segmentation[time] == node) - time_array = np.ones_like(loc_pixels[0]) * time - pix_list.append((time_array, *loc_pixels)) - return pix_list - - def set_pixels( - self, pixels: Iterable[tuple[np.ndarray, ...]], values: Iterable[int | None] - ): - """Set the given pixels in the segmentation to the given value. 
- - Args: - pixels (Iterable[tuple[np.ndarray]]): The pixels that should be set, - formatted like the output of np.nonzero (each element of the tuple - represents one dimension, containing an array of indices in that dimension). - Can be used to directly index the segmentation. - value (Iterable[int | None]): The value to set each pixel to - """ - if self.segmentation is None: - raise ValueError("Cannot set pixels when segmentation is None") - for pix, val in zip(pixels, values, strict=False): - if val is None: - raise ValueError("Cannot set pixels to None value") - self.segmentation[pix] = val - - def update_segmentations( - self, nodes: Iterable[Node], pixels: Iterable[SegMask], added: bool = True - ) -> None: - """Updates the segmentation of the given nodes. Also updates the - auto-computed attributes of the nodes and incident edges. - """ - times = self.get_times(nodes) - values = ( - nodes - if added - else [ - 0, - ] - * len(nodes) - ) - self.set_pixels(pixels, values) - computed_attrs = self._compute_node_attrs(nodes, times) - positions = np.array(computed_attrs[NodeAttr.POS.value]) - self.set_positions(nodes, positions) - self._set_nodes_attr( - nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] - ) - - incident_edges = list(self.graph.in_edges(nodes)) + list( - self.graph.out_edges(nodes) - ) - for edge in incident_edges: - new_edge_attrs = self._compute_edge_attrs([edge]) - self._set_edge_attributes([edge], new_edge_attrs) - - def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): - """Update the attributes for given nodes""" - - for idx, node in enumerate(nodes): - if node in self.graph: - for key, values in attributes.items(): - self.graph.nodes[node][key] = values[idx] - else: - print(f"Node {node} not found in the graph.") - - def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None: - """Set the edge attributes for the given edges. 
Attributes should already exist - (although adding will work in current implementation, they cannot currently be - removed) - - Args: - edges (list[Edge]): A list of edges to set the attributes for - attributes (Attributes): A dictionary of attribute name -> numpy array, - where the length of the arrays matches the number of edges. - Attributes should already exist: this function will only - update the values. - """ - for idx, edge in enumerate(edges): - if self.graph.has_edge(*edge): - for key, value in attributes.items(): - self.graph.edges[edge][key] = value[idx] - else: - print(f"Edge {edge} not found in the graph.") - - def save(self, directory: Path): - """Save the tracks to the given directory. - Currently, saves the graph as a json file in networkx node link data format, - saves the segmentation as a numpy npz file, and saves the time and position - attributes and scale information in an attributes json file. - - Args: - directory (Path): The directory to save the tracks in. - """ - self._save_graph(directory) - if self.segmentation is not None: - self._save_seg(directory) - self._save_attrs(directory) - - def _save_graph(self, directory: Path): - """Save the graph to file. Currently uses networkx node link data - format (and saves it as json). - - Args: - directory (Path): The directory in which to save the graph file. 
- """ - graph_file = directory / self.GRAPH_FILE - graph_data = nx.node_link_data(self.graph) - - def convert_np_types(data): - """Recursively convert numpy types to native Python types.""" - - if isinstance(data, dict): - return {key: convert_np_types(value) for key, value in data.items()} - elif isinstance(data, list): - return [convert_np_types(item) for item in data] - elif isinstance(data, np.ndarray): - return data.tolist() # Convert numpy arrays to Python lists - elif isinstance(data, np.integer): - return int(data) # Convert numpy integers to Python int - elif isinstance(data, np.floating): - return float(data) # Convert numpy floats to Python float - else: - return ( - data # Return the data as-is if it's already a native Python type - ) - - graph_data = convert_np_types(graph_data) - with open(graph_file, "w") as f: - json.dump(graph_data, f) - - def _save_seg(self, directory: Path): - """Save a segmentation as a numpy array using np.save. In the future, - could be changed to use zarr or other file types. - - Args: - directory (Path): The directory in which to save the segmentation - """ - out_path = directory / self.SEG_FILE - np.save(out_path, self.segmentation) - - def _save_attrs(self, directory: Path): - """Save the time_attr, pos_attr, and scale in a json file in the given directory. - - Args: - directory (Path): The directory in which to save the attributes - """ - out_path = directory / self.ATTRS_FILE - attrs_dict = { - "time_attr": self.time_attr, - "pos_attr": self.pos_attr - if not isinstance(self.pos_attr, np.ndarray) - else self.pos_attr.tolist(), - "scale": self.scale - if not isinstance(self.scale, np.ndarray) - else self.scale.tolist(), - "ndim": self.ndim, - } - with open(out_path, "w") as f: - json.dump(attrs_dict, f) - - @classmethod - def load(cls, directory: Path, seg_required=False) -> Tracks: - """Load a Tracks object from the given directory. Looks for files - in the format generated by Tracks.save. 
- - Args: - directory (Path): The directory containing tracks to load - seg_required (bool, optional): If true, raises a FileNotFoundError if the - segmentation file is not present in the directory. Defaults to False. - - Returns: - Tracks: A tracks object loaded from the given directory - """ - graph_file = directory / cls.GRAPH_FILE - graph = cls._load_graph(graph_file) - - seg_file = directory / cls.SEG_FILE - seg = cls._load_seg(seg_file, seg_required=seg_required) - - attrs_file = directory / cls.ATTRS_FILE - attrs = cls._load_attrs(attrs_file) - - return cls(graph, seg, **attrs) - - @staticmethod - def _load_graph(graph_file: Path) -> nx.DiGraph: - """Load the graph from the given json file. Expects networkx node_link_graph - formatted json. - - Args: - graph_file (Path): The json file to load into a networkx graph - - Raises: - FileNotFoundError: If the file does not exist - - Returns: - nx.DiGraph: A networkx graph loaded from the file. - """ - if graph_file.is_file(): - with open(graph_file) as f: - json_graph = json.load(f) - return nx.node_link_graph(json_graph, directed=True) - else: - raise FileNotFoundError(f"No graph at {graph_file}") - - @staticmethod - def _load_seg(seg_file: Path, seg_required: bool = False) -> np.ndarray | None: - """Load a segmentation from a file. If the file doesn't exist, either return - None or raise a FileNotFoundError depending on the seg_required flag. - - Args: - seg_file (Path): The npz file to load. - seg_required (bool, optional): If true, raise a FileNotFoundError if the - segmentation is not present. Defaults to False. - - Returns: - np.ndarray | None: The segmentation array, or None if it wasn't present and - seg_required was False. 
- """ - if seg_file.is_file(): - return np.load(seg_file) - elif seg_required: - raise FileNotFoundError(f"No segmentation at {seg_file}") - else: - return None - - @staticmethod - def _load_attrs(attrs_file: Path) -> dict: - if attrs_file.is_file(): - with open(attrs_file) as f: - return json.load(f) - else: - raise FileNotFoundError(f"No attributes at {attrs_file}") - - @classmethod - def delete(cls, directory: Path): - # Lets be safe and remove the expected files and then the directory - (directory / cls.GRAPH_FILE).unlink() - (directory / cls.SEG_FILE).unlink() - (directory / cls.ATTRS_FILE).unlink() - directory.rmdir() - - def _compute_ndim( - self, - seg: np.ndarray | None, - scale: list[float] | None, - provided_ndim: int | None, - ): - seg_ndim = seg.ndim if seg is not None else None - scale_ndim = len(scale) if scale is not None else None - ndims = [seg_ndim, scale_ndim, provided_ndim] - ndims = [d for d in ndims if d is not None] - if len(ndims) == 0: - raise ValueError( - "Cannot compute dimensions from segmentation or scale: please provide ndim argument" - ) - ndim = ndims[0] - if not all(d == ndim for d in ndims): - raise ValueError( - f"Dimensions from segmentation {seg_ndim}, scale {scale_ndim}, and ndim {provided_ndim} must match" - ) - return ndim - - def _set_node_attr(self, node: Node, attr: NodeAttr, value: Any): - if isinstance(value, np.ndarray): - value = list(value) - self.graph.nodes[node][attr] = value - - def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): - for node, value in zip(nodes, values, strict=False): - if isinstance(value, np.ndarray): - value = list(value) - self.graph.nodes[node][attr] = value - - def _get_node_attr(self, node: Node, attr: str, required: bool = False): - if required: - return self.graph.nodes[node][attr] - else: - return self.graph.nodes[node].get(attr, None) - - def _get_nodes_attr(self, nodes: Iterable[Node], attr: str, required: bool = False): - return 
[self._get_node_attr(node, attr, required=required) for node in nodes] - - def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - self.graph.edge[edge][attr] = value - - def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): - for edge, value in zip(edges, values, strict=False): - self.graph.edges[edge][attr] = value - - def _get_edge_attr(self, edge: Edge, attr: str, required: bool = False): - if required: - return self.graph.edges[edge][attr] - else: - return self.graph.edges[edge].get(attr, None) - - def _get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): - return [self._get_edge_attr(edge, attr, required=required) for edge in edges] - - def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> Attrs: - """Get the segmentation controlled node attributes (area and position) - from the segmentation with label based on the node id in the given time point. - - Args: - nodes (Iterable[int]): The node ids to query the current segmentation for - time (int): The time frames of the current segmentation to query - - Returns: - dict[str, int]: A dictionary containing the attributes that could be - determined from the segmentation. It will be empty if self.segmentation - is None. If self.segmentation exists but node id is not present in time, - area will be 0 and position will be None. If self.segmentation - exists and node id is present in time, area and position will be included. 
- """ - if self.segmentation is None: - return {} - - attrs: dict[str, list[Any]] = { - NodeAttr.POS.value: [], - NodeAttr.AREA.value: [], - } - for node, time in zip(nodes, times, strict=False): - seg = self.segmentation[time] == node - pos_scale = self.scale[1:] if self.scale is not None else None - area = np.sum(seg) - if pos_scale is not None: - area *= np.prod(pos_scale) - # only include the position if the segmentation was actually there - pos = ( - measure.centroid(seg, spacing=pos_scale) - if area > 0 - else np.array( - [ - None, - ] - * (self.ndim - 1) - ) - ) - attrs[NodeAttr.AREA.value].append(area) - attrs[NodeAttr.POS.value].append(pos) - attrs[NodeAttr.POS.value] = np.array(attrs[NodeAttr.POS.value]) - return attrs - - def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: - """Get the segmentation controlled edge attributes (IOU) - from the segmentations associated with the endpoints of the edge. - The endpoints should already exist and have associated segmentations. - - Args: - edge (Edge): The edge to compute the segmentation-based attributes from - - Returns: - dict[str, int]: A dictionary containing the attributes that could be - determined from the segmentation. It will be empty if self.segmentation - is None or if self.segmentation exists but the endpoint segmentations - are not found. 
- """ - if self.segmentation is None: - return {} - - attrs: dict[str, list[Any]] = {EdgeAttr.IOU.value: []} - for edge in edges: - source, target = edge - source_time = self.get_time(source) - target_time = self.get_time(target) - - source_arr = self.segmentation[source_time] == source - target_arr = self.segmentation[target_time] == target - - iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) - iou = 0 if len(iou_list) == 0 else iou_list[0][2] - - attrs[EdgeAttr.IOU.value].append(iou) - return attrs diff --git a/src/motile_tracker/data_model/tracks_controller.py b/src/motile_tracker/data_model/tracks_controller.py deleted file mode 100644 index 8bcf1dc2..00000000 --- a/src/motile_tracker/data_model/tracks_controller.py +++ /dev/null @@ -1,602 +0,0 @@ -from collections.abc import Iterable - -import numpy as np -from motile_toolbox.candidate_graph import NodeAttr -from napari.utils.notifications import show_info, show_warning -from qtpy.QtWidgets import QMessageBox - -from .action_history import ActionHistory -from .actions import ( - ActionGroup, - AddEdges, - AddNodes, - DeleteEdges, - DeleteNodes, - TracksAction, - UpdateNodeAttrs, - UpdateNodeSegs, - UpdateTrackID, -) -from .solution_tracks import SolutionTracks -from .tracks import Attrs, Node, SegMask - - -class TracksController: - """A set of high level functions to change the data model. - All changes to the data should go through this API. - """ - - def __init__(self, tracks: SolutionTracks): - self.tracks = tracks - self.action_history = ActionHistory() - self.node_id_counter = 1 - - def add_nodes( - self, - attributes: Attrs, - pixels: list[SegMask] | None = None, - ) -> None: - """Calls the _add_nodes function to add nodes. Calls the refresh signal when finished. - - Args: - attributes (Attrs): dictionary containing at least time and position attributes - pixels (list[SegMask] | None, optional): The pixels associated with each node, - if a segmentation is present. 
Defaults to None. - """ - action, nodes = self._add_nodes(attributes, pixels) - self.action_history.add_new_action(action) - self.tracks.refresh.emit(nodes[0] if nodes else None) - - def _get_pred_and_succ( - self, track_id: int, time: int - ) -> tuple[Node | None, Node | None]: - """Get the last node with the given track id before time, and the first node - with the track id after time, if any. Does not assume that a node with - the given track_id and time is already in tracks, but it can be. - - Args: - track_id (int): The track id to search for - time (int): The time point to find the immediate predecessor and successor - for - - Returns: - tuple[Node | None, Node | None]: The last node before time with the given - track id, and the first node after time with the given track id, - or Nones if there are no such nodes. - """ - if ( - track_id not in self.tracks.track_id_to_node - or len(self.tracks.track_id_to_node[track_id]) == 0 - ): - return None, None - candidates = self.tracks.track_id_to_node[track_id] - candidates.sort(key=lambda n: self.tracks.get_time(n)) - - pred = None - succ = None - for cand in candidates: - if self.tracks.get_time(cand) < time: - pred = cand - elif self.tracks.get_time(cand) > time: - succ = cand - break - return pred, succ - - def _confirm_remove_division_edges(self) -> bool: - """Spawn a dialog box to ask the user if they want to break an upstream division - event or not. - - Returns: - bool: True if the upstream division edges should be removed to make room for - the new node in the track, False if the user wants to cancel the operation. - """ - msg = QMessageBox() - msg.setWindowTitle("Delete existing division?") - msg.setText( - "Painting a label with this track id involves breaking an upstream division event. Proceed?" 
- ) - msg.setIcon(QMessageBox.Information) - - # Set both OK and Cancel buttons - msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) - - # Execute the message box and catch the result - result = msg.exec_() - - # Check which button was clicked - return result == QMessageBox.Ok - - def _add_nodes( - self, - attributes: Attrs, - pixels: list[SegMask] | None = None, - ) -> tuple[TracksAction, list[Node]]: - """Add nodes to the graph. Includes all attributes and the segmentation. - Will return the actions needed to add the nodes, and the node ids generated for the - new nodes. - If there is a segmentation, the attributes must include: - - time - - node_id - - track_id - If there is not a segmentation, the attributes must include: - - time - - pos - - track_id - - Logic of the function: - - remove edges (when we add a node in a track between two nodes - connected by a skip edge) - - add the nodes - - add edges (to connect each node to its immediate - predecessor and successor with the same track_id, if any) - - Args: - attributes (Attrs): dictionary containing at least time and track id, - and either node_id (if pixels are provided) or position (if not) - pixels (list[SegMask] | None): A list of pixels associated with the node, - or None if there is no segmentation. These pixels will be updated - in the tracks.segmentation, set to the new node id - """ - if NodeAttr.TIME.value not in attributes: - raise ValueError( - f"Cannot add nodes without times. Please add {NodeAttr.TIME.value} attribute" - ) - if NodeAttr.TRACK_ID.value not in attributes: - raise ValueError( - f"Cannot add nodes without track ids. 
Please add {NodeAttr.TRACK_ID.value} attribute" - ) - - times = attributes[NodeAttr.TIME.value] - track_ids = attributes[NodeAttr.TRACK_ID.value] - if pixels is not None: - nodes = attributes["node_id"] - else: - nodes = self._get_new_node_ids(len(times)) - actions = [] - - # remove skip edges that will be replaced by new edges after adding nodes - edges_to_remove = [] - for time, track_id in zip(times, track_ids, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None and succ is not None: - edges_to_remove.append((pred, succ)) - - # Find and remove edges to nodes with different track_ids (upstream division events) - if track_id in self.tracks.track_id_to_node: - track_id_nodes = self.tracks.track_id_to_node[track_id] - for node in track_id_nodes: - if ( - self.tracks._get_node_attr(node, NodeAttr.TIME.value) <= time - and self.tracks.graph.out_degree(node) == 2 - ): # there is an upstream division event here - if self._confirm_remove_division_edges(): - for succ in self.tracks.graph.successors(node): - edges_to_remove.append((node, succ)) - else: - show_info("Action canceled by user") - self.tracks.refresh.emit() - return - - if len(edges_to_remove) > 0: - actions.append(DeleteEdges(self.tracks, edges_to_remove)) - - # add nodes - actions.append( - AddNodes( - tracks=self.tracks, - nodes=nodes, - attributes=attributes, - pixels=pixels, - ) - ) - - # add in edges to preds and succs with the same track id - edges_to_add = set() # make it a set to avoid double adding edges when you add - # two nodes next to each other in the same track - for node, time, track_id in zip(nodes, times, track_ids, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None: - edges_to_add.add((pred, node)) - if succ is not None: - edges_to_add.add((node, succ)) - actions.append(AddEdges(self.tracks, list(edges_to_add))) - - return ActionGroup(self.tracks, actions), nodes - - def delete_nodes(self, nodes: Iterable[Node]) 
-> None: - """Calls the _delete_nodes function and then emits the refresh signal - - Args: - nodes (Iterable[Node]): array of node_ids to be deleted - """ - - action = self._delete_nodes(nodes) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def _delete_nodes( - self, nodes: Iterable[Node], pixels: list[SegMask] | None = None - ) -> TracksAction: - """Delete the nodes provided by the array from the graph but maintain successor - track_ids. Reconnect to the nearest predecessor and/or nearest successor - on the same track, if any. - - Function logic: - - delete all edges incident to the nodes - - delete the nodes - - add edges to preds and succs of nodes if they have the same track id - - update track ids if we removed a division by deleting the dge - - Args: - nodes (np.ndarray): array of node_ids to be deleted - """ - actions = [] - - # find all the edges that should be deleted (no duplicates) and put them in a single action - # also keep track of which deletions removed a division, and save the sibling nodes so we can - # update the track ids - edges_to_delete = set() - new_track_ids = [] - for node in nodes: - for pred in self.tracks.graph.predecessors(node): - edges_to_delete.add((pred, node)) - # determine if we need to relabel any tracks - siblings = list(self.tracks.graph.successors(pred)) - if len(siblings) == 2: - # need to relabel the track id of the sibling to match the pred because - # you are implicitly deleting a division - siblings.remove(node) - sib = siblings[0] - # check if the sibling is also deleted, because then relabeling is not needed - if sib not in nodes: - new_track_id = self.tracks.get_track_id(pred) - new_track_ids.append((sib, new_track_id)) - for succ in self.tracks.graph.successors(node): - edges_to_delete.add((node, succ)) - if len(edges_to_delete) > 0: - actions.append(DeleteEdges(self.tracks, list(edges_to_delete))) - - if len(new_track_ids) > 0: - for node, track_id in new_track_ids: - 
actions.append(UpdateTrackID(self.tracks, node, track_id)) - - track_ids = [self.tracks.get_track_id(node) for node in nodes] - times = self.tracks.get_times(nodes) - # remove nodes - actions.append(DeleteNodes(self.tracks, nodes, pixels=pixels)) - - # find all the skip edges to be made (no duplicates or intermediates to nodes - # that are deleted) and put them in a single action - skip_edges = set() - for track_id, time in zip(track_ids, times, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None and succ is not None: - skip_edges.add((pred, succ)) - if len(skip_edges) > 0: - actions.append(AddEdges(self.tracks, list(skip_edges))) - - return ActionGroup(self.tracks, actions=actions) - - def _update_node_segs( - self, - nodes: Iterable[Node], - pixels: list[SegMask], - added=False, - ) -> TracksAction: - """Update the segmentation and segmentation-managed attributes for - a set of nodes. - - Args: - nodes (Iterable[Node]): The nodes to update - pixels (list[SegMask]): The pixels for each node that were edited - added (bool, optional): If the pixels were added to the nodes (True) - or deleted (False). Defaults to False. Cannot mix adding and removing - pixels in one call. - - Returns: - TracksAction: _description_ - """ - return UpdateNodeSegs(self.tracks, nodes, pixels, added=added) - - def add_edges(self, edges: np.ndarray[int]) -> None: - """Add edges to the graph. 
Also update the track ids and - corresponding segmentations if applicable - - Args: - edges (np.array[int]): An Nx2 array of N edges, each with source and target - node ids - """ - make_valid_actions = [] - for edge in edges: - is_valid, valid_action = self.is_valid(edge) - if not is_valid: - # warning was printed with details in is_valid call - return - if valid_action is not None: - make_valid_actions.append(valid_action) - main_action = self._add_edges(edges) - if len(make_valid_actions) > 0: - make_valid_actions.append(main_action) - action = ActionGroup(self.tracks, make_valid_actions) - else: - action = main_action - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def update_node_attrs(self, nodes: Iterable[Node], attributes: Attrs): - """Update the user provided node attributes (not the managed attributes). - Also adds the action to the history and emits the refresh signal. - - Args: - nodes (Iterable[Node]): The nodes to update the attributes for - attributes (Attrs): A mapping from user-provided attributes to values for - each node. - """ - action = self._update_node_attrs(nodes, attributes) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def _update_node_attrs( - self, nodes: Iterable[Node], attributes: Attrs - ) -> TracksAction: - """Update the user provided node attributes (not the managed attributes). - - Args: - nodes (Iterable[Node]): The nodes to update the attributes for - attributes (Attrs): A mapping from user-provided attributes to values for - each node. - - Returns: A TracksAction object that performed the update - """ - return UpdateNodeAttrs(self.tracks, nodes, attributes) - - def _add_edges(self, edges: np.ndarray[int]) -> TracksAction: - """Add edges and attributes to the graph. Also update the track ids of the target - node tracks and potentially sibling tracks. 
- - Args: - edges (np.array[int]): An Nx2 array of N edges, each with source and target - node ids - - Returns: - A TracksAction containing all edits performed in this call - """ - actions = [] - for edge in edges: - out_degree = self.tracks.graph.out_degree(edge[0]) - if out_degree == 0: # joining two segments - # assign the track id of the source node to the target and all out - # edges until end of track - new_track_id = self.tracks.get_track_id(edge[0]) - actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # creating a division - # assign a new track id to existing child - successor = next(iter(self.tracks.graph.successors(edge[0]))) - actions.append( - UpdateTrackID( - self.tracks, successor, self.tracks.get_next_track_id() - ) - ) - else: - raise RuntimeError( - f"Expected degree of 0 or 1 before adding edge, got {out_degree}" - ) - - actions.append(AddEdges(self.tracks, edges)) - return ActionGroup(self.tracks, actions) - - def is_valid(self, edge) -> tuple[bool, TracksAction | None]: - """Check if this edge is valid. - Criteria: - - not horizontal - - not existing yet - - no merges - - no triple divisions - - new edge should be the shortest possible connection between two nodes, given their track_ids. - (no skipping/bypassing any nodes of the same track_id). 
Check if there are any nodes of the same source or target track_id between source and target - - Args: - edge (np.ndarray[(int, int)]: edge to be validated - Returns: - True if the edge is valid, false if invalid""" - - # make sure that the node2 is downstream of node1 - time1 = self.tracks.get_time(edge[0]) - time2 = self.tracks.get_time(edge[1]) - - if time1 > time2: - edge = (edge[1], edge[0]) - time1, time2 = time2, time1 - action = None - # do all checks - # reject if edge already exists - if self.tracks.graph.has_edge(edge[0], edge[1]): - show_warning("Edge is rejected because it exists already.") - return False, action - - # reject if edge is horizontal - elif self.tracks.get_time(edge[0]) == self.tracks.get_time(edge[1]): - show_warning("Edge is rejected because it is horizontal.") - return False, action - - # reject if target node already has an incoming edge - elif self.tracks.graph.in_degree(edge[1]) > 0: - msg = QMessageBox() - msg.setWindowTitle("Delete existing edge?") - msg.setText( - "Creating this edge involves breaking an existing incoming edge to the target node. Proceed?" - ) - msg.setIcon(QMessageBox.Information) - - # Set both OK and Cancel buttons - msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) - - # Execute the message box and catch the result - result = msg.exec_() - - # Check which button was clicked - if result == QMessageBox.Ok: - print("User clicked OK") - - # identify incoming edge in the target node and insert a delete action - pred = next(self.tracks.graph.predecessors(edge[1])) - action = self._delete_edges(edges=np.array([[pred, edge[1]]])) - - elif result == QMessageBox.Cancel: - show_warning( - "Edge is rejected because merges are currently not allowed." - ) - return False, action - - elif self.tracks.graph.out_degree(edge[0]) > 1: - show_warning( - "Edge is rejected because triple divisions are currently not allowed." 
- ) - return False, action - - elif time2 - time1 > 1: - track_id2 = self.tracks.graph.nodes[edge[1]][NodeAttr.TRACK_ID.value] - # check whether there are already any nodes with the same track id between source and target (shortest path between equal track_ids rule) - for t in range(time1 + 1, time2): - nodes = [ - n - for n, attr in self.tracks.graph.nodes(data=True) - if attr.get(self.tracks.time_attr) == t - and attr.get(NodeAttr.TRACK_ID.value) == track_id2 - ] - if len(nodes) > 0: - show_warning("Please connect to the closest node") - return False, action - - # all checks passed! - return True, action - - def delete_edges(self, edges: np.ndarray): - """Delete edges from the graph. - - Args: - edges (np.ndarray): The Nx2 array of edges to be deleted - """ - - for edge in edges: - # First check if the to be deleted edges exist - if not self.tracks.graph.has_edge(edge[0], edge[1]): - show_warning("Cannot delete non-existing edge!") - return - action = self._delete_edges(edges) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def _delete_edges(self, edges: np.ndarray) -> ActionGroup: - actions = [DeleteEdges(self.tracks, edges)] - for edge in edges: - out_degree = self.tracks.graph.out_degree(edge[0]) - if out_degree == 0: # removed a normal (non division) edge - new_track_id = self.tracks.get_next_track_id() - actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # removed a division edge - sibling = next(self.tracks.graph.successors(edge[0])) - new_track_id = self.tracks.get_track_id(edge[0]) - actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) - else: - raise RuntimeError( - f"Expected degree of 0 or 1 after removing edge, got {out_degree}" - ) - return ActionGroup(self.tracks, actions) - - def update_segmentations( - self, - to_remove: list[Node], # (node_ids, pixels) - to_update_smaller: list[tuple], # (node_id, pixels) - to_update_bigger: list[tuple], # (node_id, pixels) - 
to_add: list[tuple], # (node_id, track_id, pixels) - current_timepoint: int, - ) -> None: - """Handle a change in the segmentation mask, checking for node addition, deletion, and attribute updates. - Args: - updated_pixels (list[(tuple(np.ndarray, np.ndarray, np.ndarray), np.ndarray, int)]): - list holding the operations that updated the segmentation (directly from - the napari labels paint event). - Each element in the list consists of a tuple of np.ndarrays representing - indices for each dimension, an array of the previous values, and an array - or integer representing the new value(s) - current_timepoint (int): the current time point in the viewer, used to set the selected node. - """ - actions = [] - node_to_select = None - - if len(to_remove) > 0: - nodes = [node_id for node_id, _ in to_remove] - pixels = [pixels for _, pixels in to_remove] - actions.append(self._delete_nodes(nodes, pixels=pixels)) - if len(to_update_smaller) > 0: - nodes = [node_id for node_id, _ in to_update_smaller] - pixels = [pixels for _, pixels in to_update_smaller] - actions.append(self._update_node_segs(nodes, pixels, added=False)) - if len(to_update_bigger) > 0: - nodes = [node_id for node_id, _ in to_update_bigger] - pixels = [pixels for _, pixels in to_update_bigger] - actions.append(self._update_node_segs(nodes, pixels, added=True)) - if len(to_add) > 0: - nodes = [node for node, _, _ in to_add] - pixels = [pix for _, _, pix in to_add] - track_ids = [ - val if val is not None else self.tracks.get_next_track_id() - for _, val, _ in to_add - ] - times = [pix[0][0] for pix in pixels] - attributes = { - NodeAttr.TRACK_ID.value: track_ids, - NodeAttr.TIME.value: times, - "node_id": nodes, - } - - result = self._add_nodes(attributes=attributes, pixels=pixels) - if result is None: - return - else: - action, nodes = result - - actions.append(action) - - # if this is the time point where the user added a node, select the new node - if current_timepoint in times: - index = 
times.index(current_timepoint) - node_to_select = nodes[index] - - action_group = ActionGroup(self.tracks, actions) - self.action_history.add_new_action(action_group) - self.tracks.refresh.emit(node_to_select) - - def undo(self) -> None: - """Obtain the action to undo from the history, and invert""" - if self.action_history.undo(): - self.tracks.refresh.emit() - else: - show_info("No more actions to undo") - - def redo(self) -> None: - """Obtain the action to redo from the history""" - if self.action_history.redo(): - self.tracks.refresh.emit() - else: - show_info("No more actions to redo") - - def _get_new_node_ids(self, n: int) -> list[Node]: - """Get a list of new node ids for creating new nodes. - They will be unique from all existing nodes, but have no other guarantees. - - Args: - n (int): The number of new node ids to return - - Returns: - list[Node]: A list of new node ids. - """ - ids = [self.node_id_counter + i for i in range(n)] - self.node_id_counter += n - for idx, _id in enumerate(ids): - while self.tracks.graph.has_node(_id): - _id = self.node_id_counter - self.node_id_counter += 1 - ids[idx] = _id - return ids diff --git a/src/motile_tracker/data_views/__init__.py b/src/motile_tracker/data_views/__init__.py deleted file mode 100644 index 88986c1b..00000000 --- a/src/motile_tracker/data_views/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .views.tree_view.tree_widget import TreeWidget # noqa -from .views.layers.track_graph import TrackGraph # noqa -from .views.layers.track_labels import TrackLabels # noqa -from .views.layers.track_points import TrackPoints # noqa - -from .views_coordinator.node_selection_list import NodeSelectionList # noqa -from .views_coordinator.tracks_viewer import TracksViewer # noqa diff --git a/src/motile_tracker/data_views/views/__init__.py b/src/motile_tracker/data_views/views/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/motile_tracker/data_views/views/layers/__init__.py 
b/src/motile_tracker/data_views/views/layers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/motile_tracker/data_views/views/layers/track_graph.py b/src/motile_tracker/data_views/views/layers/track_graph.py deleted file mode 100644 index 66386b86..00000000 --- a/src/motile_tracker/data_views/views/layers/track_graph.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import annotations - -import copy -from typing import TYPE_CHECKING - -import napari -import numpy as np - -if TYPE_CHECKING: - from motile_tracker.data_model.solution_tracks import SolutionTracks - from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer - - -def update_napari_tracks( - tracks: SolutionTracks, -): - """Function to take a networkx graph with assigned track_ids and return the data needed to add to - a napari tracks layer. - - Args: - tracks (SolutionTracks): tracks that have track_ids and have a tree structure - - Returns: - data: array (N, D+1) - Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first - axis is the integer ID of the track. D is either 3 or 4 for planar - or volumetric timeseries respectively. - graph: dict {int: list} - Graph representing associations between tracks. Dictionary defines the - mapping between a track ID and the parents of the track. This can be - one (the track has one parent, and the parent has >=1 child) in the - case of track splitting, or more than one (the track has multiple - parents, but only one child) in the case of track merging. 
- """ - - ndim = tracks.ndim - 1 - graph = tracks.graph - napari_data = np.zeros((graph.number_of_nodes(), ndim + 2)) - napari_edges = {} - - parents = [node for node, degree in graph.out_degree() if degree >= 2] - intertrack_edges = [] - - # Remove all intertrack edges from a copy of the original graph - graph_copy = graph.copy() - for parent in parents: - daughters = [child for _, child in graph.out_edges(parent)] - for daughter in daughters: - graph_copy.remove_edge(parent, daughter) - intertrack_edges.append((parent, daughter)) - - for index, node in enumerate(graph.nodes(data=True)): - node_id, data = node - location = tracks.get_position(node_id) - napari_data[index] = [ - tracks.get_track_id(node_id), - tracks.get_time(node_id), - *location, - ] - - for parent, child in intertrack_edges: - parent_track_id = tracks.get_track_id(parent) - child_track_id = tracks.get_track_id(child) - if child_track_id in napari_edges: - napari_edges[child_track_id].append(parent_track_id) - else: - napari_edges[child_track_id] = [parent_track_id] - - return napari_data, napari_edges - - -class TrackGraph(napari.layers.Tracks): - """Extended tracks layer that holds the track information and emits and responds - to dynamics visualization signals""" - - def __init__( - self, - name: str, - tracks_viewer: TracksViewer, - ): - self.tracks_viewer = tracks_viewer - track_data, track_edges = update_napari_tracks( - self.tracks_viewer.tracks, - ) - - super().__init__( - data=track_data, - graph=track_edges, - name=name, - tail_length=3, - color_by="track_id", - ) - - self.colormaps_dict["track_id"] = self.tracks_viewer.colormap - self.tracks_layer_graph = copy.deepcopy(self.graph) # for restoring graph later - self.colormap = "turbo" # just to 'refresh' the track_id colormap, we do not actually use turbo - - def _refresh(self): - """Refreshes the displayed tracks based on the graph in the current tracks_viewer.tracks""" - - track_data, track_edges = update_napari_tracks( - 
self.tracks_viewer.tracks, - ) - - self.data = track_data - self.graph = track_edges - self.tracks_layer_graph = copy.deepcopy(self.graph) - self.colormaps_dict["track_id"] = self.tracks_viewer.colormap - self.colormap = "turbo" # just to 'refresh' the track_id colormap, we do not actually use turbo - - def update_track_visibility(self, visible: list[int] | str) -> None: - """Optionally show only the tracks of a current lineage""" - - if visible == "all": - self.track_colors[:, 3] = 1 - self.graph = self.tracks_layer_graph - else: - track_id_mask = np.isin( - self.properties["track_id"], - visible, - ) - self.graph = { - key: self.tracks_layer_graph[key] - for key in visible - if key in self.tracks_layer_graph - } - - self.track_colors[:, 3] = 0 - self.track_colors[track_id_mask, 3] = 1 - if len(self.graph.items()) == 0: - self.display_graph = False # empty dicts to not trigger update (bug?) so disable the graph entirely as a workaround - else: - self.display_graph = True diff --git a/src/motile_tracker/data_views/views/layers/track_labels.py b/src/motile_tracker/data_views/views/layers/track_labels.py deleted file mode 100644 index e20e05a9..00000000 --- a/src/motile_tracker/data_views/views/layers/track_labels.py +++ /dev/null @@ -1,460 +0,0 @@ -from __future__ import annotations - -import random -from typing import TYPE_CHECKING - -import napari -import numpy as np -from napari.utils import DirectLabelColormap -from napari.utils.action_manager import action_manager -from napari.utils.notifications import show_info, show_warning -from napari.utils.translations import trans - -if TYPE_CHECKING: - from napari.utils.events import Event - - from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer - -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr - - -def new_label(layer: TrackLabels): - """A function to override the default napari labels new_label function. 
- Must be registered (see end of this file)""" - layer.events.selected_label.disconnect(layer._ensure_valid_label) - _new_label(layer, new_track_id=True) - layer.events.selected_label.connect(layer._ensure_valid_label) - - -def _new_label(layer: TrackLabels, new_track_id=True): - """A function to get a new label for a given TrackLabels layer. Should properly - go on the class, but needs to be registered to override the default napari function - in the action manager. This helper is abstracted out because we want to do the same - thing without making a new track id in the layer, and with the new track id in the - overriden action. - - Args: - layer (TrackLabels): A TrackLabels layer from which get a new label for drawing a - new segmentation. Updates the selected_label attribute. - new_track_id (bool, optional): If you should also generate a new track id and set - it to the selected_track attribute. Defaults to True. - """ - - if isinstance(layer.data, np.ndarray): - new_selected_label = np.max(layer.data) + 1 - if layer.selected_label == new_selected_label: - show_info( - trans._( - "Current selected label is not being used. 
You will need to use it first " - "to be able to set the current select label to the next one available", - ) - ) - else: - if new_track_id: - new_selected_track = layer.tracks_viewer.tracks.get_next_track_id() - layer.selected_track = new_selected_track - layer.selected_label = new_selected_label - layer.colormap.color_dict[new_selected_label] = ( - layer.tracks_viewer.colormap.map(layer.selected_track) - ) - layer.colormap = DirectLabelColormap( - color_dict=layer.colormap.color_dict - ) # to refresh, otherwise you paint with a transparent label until you release the mouse - else: - show_info( - trans._("Calculating empty label on non-numpy array is not supported") - ) - - -class TrackLabels(napari.layers.Labels): - """Extended labels layer that holds the track information and emits - and responds to dynamics visualization signals""" - - @property - def _type_string(self) -> str: - return "labels" # to make sure that the layer is treated as labels layer for saving - - def __init__( - self, - viewer: napari.Viewer, - data: np.array, - name: str, - opacity: float, - scale: tuple, - tracks_viewer: TracksViewer, - ): - self.tracks_viewer = tracks_viewer - self.selected_track = None - colormap = self._get_colormap() - - super().__init__( - data=data, - name=name, - opacity=opacity, - colormap=colormap, - scale=scale, - ) - - self.viewer = viewer - - # Key bindings (should be specified both on the viewer (in tracks_viewer) - # and on the layer to overwrite napari defaults) - self.bind_key("q")(self.tracks_viewer.toggle_display_mode) - self.bind_key("a")(self.tracks_viewer.create_edge) - self.bind_key("d")(self.tracks_viewer.delete_node) - self.bind_key("Delete")(self.tracks_viewer.delete_node) - self.bind_key("b")(self.tracks_viewer.delete_edge) - # self.bind_key("s")(self.tracks_viewer.set_split_node) - # self.bind_key("e")(self.tracks_viewer.set_endpoint_node) - # self.bind_key("c")(self.tracks_viewer.set_linear_node) - self.bind_key("z")(self.tracks_viewer.undo) - 
self.bind_key("r")(self.tracks_viewer.redo) - - # Connect click events to node selection - @self.mouse_drag_callbacks.append - def click(_, event): - if ( - event.type == "mouse_press" - and self.mode == "pan_zoom" - and not ( - self.tracks_viewer.mode == "lineage" - and self.viewer.dims.ndisplay == 3 - ) - ): # disable selecting in lineage mode in 3D - label = self.get_value( - event.position, - view_direction=event.view_direction, - dims_displayed=event.dims_displayed, - world=True, - ) - - if ( - label is not None - and label != 0 - and self.colormap.map(label)[-1] != 0 - ): # check opacity (=visibility) in the colormap - append = "Shift" in event.modifiers - self.tracks_viewer.selected_nodes.add(label, append) - - # Listen to paint events and changing the selected label - self.events.paint.connect(self._on_paint) - self.tracks_viewer.selected_nodes.list_updated.connect( - self.update_selected_label - ) - self.events.selected_label.connect(self._ensure_valid_label) - self.events.mode.connect(self._check_mode) - self.viewer.dims.events.current_step.connect(self._ensure_valid_label) - - def _get_colormap(self) -> DirectLabelColormap: - """Get a DirectLabelColormap that maps node ids to their track ids, and then - uses the tracks_viewer.colormap to map from track_id to color. 
- - Returns: - DirectLabelColormap: A map from node ids to colors based on track id - """ - tracks = self.tracks_viewer.tracks - if tracks is not None: - nodes = list(tracks.graph.nodes()) - track_ids = [tracks.get_track_id(node) for node in nodes] - colors = [self.tracks_viewer.colormap.map(tid) for tid in track_ids] - else: - nodes = [] - colors = [] - return DirectLabelColormap( - color_dict={ - **dict(zip(nodes, colors, strict=True)), - None: [0, 0, 0, 0], - } - ) - - def _check_mode(self): - """Check if the mode is valid and call the ensure_valid_label function""" - - self.events.mode.disconnect( - self._check_mode - ) # here disconnecting the event listener is still necessary because self.mode = paint triggers the event internally and it is not blocked with event.blocker() - if self.mode == "polygon": - show_info( - trans._( - "Please use the paint tool to update the label", - ) - ) - self.mode = "paint" - - self._ensure_valid_label() - self.events.mode.connect(self._check_mode) - - def redo(self): - """Overwrite the redo functionality of the labels layer and invoke redo action on the tracks_viewer.tracks_controller first""" - - self.tracks_viewer.redo() - - def undo(self): - """Overwrite undo function and invoke undo action on the tracks_viewer.tracks_controller""" - - self.tracks_viewer.undo() - - def _parse_paint_event(self, event_val): - """_summary_ - - Args: - event_val (list[tuple]): A list of paint "atoms" generated by the labels layer. 
- Each atom is a 3-tuple of arrays containing: - - a numpy multi-index, pointing to the array elements that were - changed (a tuple with len ndims) - - the values corresponding to those elements before the change - - the value after the change - Returns: - tuple(int, list[tuple]): The new value, and a list of node update actions - defined by the time point and node update item - Each "action" is a 2-tuple containing: - - a numpy multi-index, pointing to the array elements that were - changed (a tuple with len ndims) - - the value before the change - """ - - new_value = event_val[-1][-1] - ndim = len(event_val[-1][0]) - concatenated_indices = tuple( - np.concatenate([ev[0][dim] for ev in event_val]) for dim in range(ndim) - ) - concatenated_values = np.concatenate([ev[1] for ev in event_val]) - old_values = np.unique(concatenated_values) - actions = [] - for old_value in old_values: - mask = concatenated_values == old_value - indices = tuple(concatenated_indices[dim][mask] for dim in range(ndim)) - time_points = np.unique(indices[0]) - for time in time_points: - time_mask = indices[0] == time - actions.append( - (tuple(indices[dim][time_mask] for dim in range(ndim)), old_value) - ) - return new_value, actions - - def _revert_paint(self, event): - """Revert a paint event after it fails validation (no motile tracker Actions have - been created). This keeps the view synced with the backend data. - """ - super().undo() - - def _on_paint(self, event): - """Listen to the paint event and check which track_ids have changed""" - - with self.events.selected_label.blocker(): - current_timepoint = self.viewer.dims.current_step[ - 0 - ] # also pass on the current time point to know which node to select later - new_value, updated_pixels = self._parse_paint_event(event.value) - # updated_pixels is a list of tuples. 
Each tuple is (indices, old_value) - to_delete = [] # (node_ids, pixels) - to_update_smaller = [] # (node_id, pixels) - to_update_bigger = [] # (node_id, pixels) - to_add = [] # (node_id, track_id, pixels) - for pixels, old_value in updated_pixels: - ndim = len(pixels) - if old_value == 0: - continue - time = pixels[0][0] - removed_node = old_value - assert ( - removed_node is not None - ), f"Node with label {old_value} in time {time} was not found" - # check if all pixels of old_value are removed - if np.sum(self.data[time] == old_value) == 0: - to_delete.append((removed_node, pixels)) - else: - to_update_smaller.append((removed_node, pixels)) - if new_value != 0: - all_pixels = tuple( - np.concatenate([pixels[dim] for pixels, _ in updated_pixels]) - for dim in range(ndim) - ) - for _ in np.unique(all_pixels[0]): - existing_node = self.tracks_viewer.tracks.graph.has_node(new_value) - if existing_node: - to_update_bigger.append((new_value, all_pixels)) - else: - to_add.append((new_value, self.selected_track, all_pixels)) - - if len(to_delete) > 0 and len(to_add) > 0: - show_warning( - "This paint or fill operation completely replaced one label with a new label. This is currently not supported." - " If you want to update the track id of the node, please edit the edges directly instead." 
- ) - self._revert_paint(event) - self.refresh() - return - self.tracks_viewer.tracks_controller.update_segmentations( - to_delete, - to_update_smaller, - to_update_bigger, - to_add, - current_timepoint, - ) - - def _refresh(self): - """Refresh the data in the labels layer""" - self.data = self.tracks_viewer.tracks.segmentation - self.colormap = self._get_colormap() - self.refresh() - - def update_label_colormap(self, visible: list[int] | str) -> None: - """Updates the opacity of the label colormap to highlight the selected label - and optionally hide cells not belonging to the current lineage - - Visible is a list of visible node id - """ - with self.events.selected_label.blocker(): - highlighted = self.tracks_viewer.selected_nodes - - # update the opacity of the cyclic label colormap values according to whether nodes are visible/invisible/highlighted - if visible == "all": - self.colormap.color_dict = { - key: np.array( - [ - *value[:-1], - 0.6 if key is not None and key != 0 else value[-1], - ], - dtype=np.float32, - ) - for key, value in self.colormap.color_dict.items() - } - - else: - self.colormap.color_dict = { - key: np.array([*value[:-1], 0], dtype=np.float32) - for key, value in self.colormap.color_dict.items() - } - for node in visible: - # find the index in the colormap - self.colormap.color_dict[node][-1] = 0.6 - - for node in highlighted: - self.colormap.color_dict[node][-1] = 1 # full opacity - - self.colormap = DirectLabelColormap( - color_dict=self.colormap.color_dict - ) # create a new colormap from the updated colors (otherwise it does not refresh) - - def new_colormap(self): - """Override existing function to generate new colormap on tracks_viewer and - emit refresh signal to update colors in all layers/widgets""" - - self.tracks_viewer.colormap = napari.utils.colormaps.label_colormap( - 49, - seed=random.uniform(0, 1), - background_value=0, - ) - self.tracks_viewer._refresh() - - def update_selected_label(self): - """Update the selected label 
in the labels layer""" - - self.events.selected_label.disconnect(self._ensure_valid_label) - if len(self.tracks_viewer.selected_nodes) > 0: - self.selected_label = int(self.tracks_viewer.selected_nodes[0]) - self.events.selected_label.connect(self._ensure_valid_label) - - def _ensure_valid_label(self, event: Event | None = None): - """Make sure a valid label is selected, because it is not allowed to paint with a - label that already exists at a different timepoint. - Scenarios: - 1. If a node with the selected label value (node id) exists at a different time point, - check if there is any node with the same track_id at the current time point - 1.a if there is a node with the same track id, select that one, so that it can be used to update an existing node - 1.b if there is no node with the same track id, create a new node id and paint with the track_id of the selected label. - This can be used to add a new node with the same track id at a time point where it does not (yet) exist (anymore). - 2. if there is no existing node with this value in the graph, it is assume that you want to add a node with the current track id - Retrieve the track_id from self.current_track_id and use it to find if there are any nodes of this track id - at current time point - 3. If no node with this label exists yet, it is valid and can be used to start a new track id. - Therefore, create a new node id and map a new color. Add it to the dictionary. - 4. If a node with the label exists at the current time point, it is valid and can be used to update the existing node in a paint event. 
No action is needed""" - - if self.tracks_viewer.tracks is not None and self.mode in ( - "fill", - "paint", - "erase", - "pick", - ): - self.events.selected_label.disconnect(self._ensure_valid_label) - - current_timepoint = self.viewer.dims.current_step[0] - # if a node with the given label is already in the graph - if self.tracks_viewer.tracks.graph.has_node(self.selected_label): - # Update the track id - self.selected_track = self.tracks_viewer.tracks._get_node_attr( - self.selected_label, NodeAttr.TRACK_ID.value - ) - existing_time = self.tracks_viewer.tracks._get_node_attr( - self.selected_label, NodeAttr.TIME.value - ) - if existing_time == current_timepoint: - # we are changing the existing node. This is fine - pass - else: - # if there is already a node in that track in this frame, edit that instead - edit = False - if ( - self.selected_track - in self.tracks_viewer.tracks.track_id_to_node - ): - for node in self.tracks_viewer.tracks.track_id_to_node[ - self.selected_track - ]: - if ( - self.tracks_viewer.tracks._get_node_attr( - node, NodeAttr.TIME.value - ) - == current_timepoint - ): - self.selected_label = int(node) - edit = True - break - - if not edit: - # use a new label, but the same track id - _new_label(self, new_track_id=False) - self.colormap = DirectLabelColormap( - color_dict=self.colormap.color_dict - ) - - # the current node does not exist in the graph. - # Use the current selected_track as the track id (will be a new track if a new label was found with "m") - # Check that the track id is not already in this frame. 
- else: - # if there is already a node in that track in this frame, edit that instead - edit = False - if self.selected_track in self.tracks_viewer.tracks.track_id_to_node: - for node in self.tracks_viewer.tracks.track_id_to_node[ - self.selected_track - ]: - if ( - self.tracks_viewer.tracks._get_node_attr( - node, NodeAttr.TIME.value - ) - == current_timepoint - ): - self.selected_label = int(node) - edit = True - break - - self.events.selected_label.connect(self._ensure_valid_label) - - @napari.layers.Labels.n_edit_dimensions.setter - def n_edit_dimensions(self, n_edit_dimensions): - # Overriding the setter to disable editing in time dimension - if n_edit_dimensions > self.tracks_viewer.tracks.ndim - 1: - n_edit_dimensions = self.tracks_viewer.tracks.ndim - 1 - self._n_edit_dimensions = n_edit_dimensions - self.events.n_edit_dimensions() - - -# This is to override the default napari function to get a new label for the labels layer -action_manager.register_action( - name="napari:new_label", - command=new_label, - keymapprovider=TrackLabels, - description="", -) diff --git a/src/motile_tracker/data_views/views/layers/track_points.py b/src/motile_tracker/data_views/views/layers/track_points.py deleted file mode 100644 index c4ec0722..00000000 --- a/src/motile_tracker/data_views/views/layers/track_points.py +++ /dev/null @@ -1,232 +0,0 @@ -from __future__ import annotations - -import math -from typing import TYPE_CHECKING - -import napari -import numpy as np -from motile_toolbox.candidate_graph import NodeAttr -from napari.utils.notifications import show_info - -from motile_tracker.data_model import NodeType, Tracks - -if TYPE_CHECKING: - from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer - - -class TrackPoints(napari.layers.Points): - """Extended points layer that holds the track information and emits and - responds to dynamics visualization signals - """ - - @property - def _type_string(self) -> str: - return "points" # to make sure 
that the layer is treated as points layer for saving - - def __init__( - self, - name: str, - tracks_viewer: TracksViewer, - ): - self.tracks_viewer = tracks_viewer - self.nodes = list(tracks_viewer.tracks.graph.nodes) - self.node_index_dict = {node: idx for idx, node in enumerate(self.nodes)} - - points = self.tracks_viewer.tracks.get_positions(self.nodes, incl_time=True) - track_ids = [ - self.tracks_viewer.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value] - for node in self.nodes - ] - colors = [self.tracks_viewer.colormap.map(track_id) for track_id in track_ids] - symbols = self.get_symbols( - self.tracks_viewer.tracks, self.tracks_viewer.symbolmap - ) - - self.default_size = 5 - - super().__init__( - data=points, - name=name, - symbol=symbols, - face_color=colors, - size=self.default_size, - properties={ - "node_id": self.nodes, - "track_id": track_ids, - }, # TODO: use features - border_color=[1, 1, 1, 1], - blending="translucent_no_depth", - ) - - # Key bindings (should be specified both on the viewer (in tracks_viewer) - # and on the layer to overwrite napari defaults) - self.bind_key("q")(self.tracks_viewer.toggle_display_mode) - self.bind_key("a")(self.tracks_viewer.create_edge) - self.bind_key("d")(self.tracks_viewer.delete_node) - self.bind_key("Delete")(self.tracks_viewer.delete_node) - self.bind_key("b")(self.tracks_viewer.delete_edge) - # self.bind_key("s")(self.tracks_viewer.set_split_node) - # self.bind_key("e")(self.tracks_viewer.set_endpoint_node) - # self.bind_key("c")(self.tracks_viewer.set_linear_node) - self.bind_key("z")(self.tracks_viewer.undo) - self.bind_key("r")(self.tracks_viewer.redo) - - # Connect to click events to select nodes - @self.mouse_drag_callbacks.append - def click(layer, event): - if event.type == "mouse_press": - # is the value passed from the click event? 
- point_index = layer.get_value( - event.position, - view_direction=event.view_direction, - dims_displayed=event.dims_displayed, - world=True, - ) - if point_index is not None: - node_id = self.nodes[point_index] - append = "Shift" in event.modifiers - self.tracks_viewer.selected_nodes.add(node_id, append) - - # listen to updates of the data - self.events.data.connect(self._update_data) - - # listen to updates in the selected data (from the point selection tool) - # to update the nodes in self.tracks_viewer.selected_nodes - self.selected_data.events.items_changed.connect(self._update_selection) - - def set_point_size(self, size: int) -> None: - """Sets a new default point size""" - - self.default_size = size - self._refresh() - - def _refresh(self): - """Refresh the data in the points layer""" - - self.events.data.disconnect( - self._update_data - ) # do not listen to new events until updates are complete - self.nodes = list(self.tracks_viewer.tracks.graph.nodes) - - self.node_index_dict = {node: idx for idx, node in enumerate(self.nodes)} - - track_ids = [ - self.tracks_viewer.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value] - for node in self.nodes - ] - self.data = self.tracks_viewer.tracks.get_positions(self.nodes, incl_time=True) - self.symbol = self.get_symbols( - self.tracks_viewer.tracks, self.tracks_viewer.symbolmap - ) - self.face_color = [ - self.tracks_viewer.colormap.map(track_id) for track_id in track_ids - ] - self.properties = {"node_id": self.nodes, "track_id": track_ids} - self.size = self.default_size - self.border_color = [1, 1, 1, 1] - - self.events.data.connect( - self._update_data - ) # reconnect listening to update events - - def _create_node_attrs(self, new_point: np.array) -> tuple[np.array, dict]: - """Create attributes for a new node at given time point""" - - t = int(new_point[0]) - track_id = self.tracks_viewer.tracks.get_next_track_id() - area = 0 - - attributes = { - NodeAttr.POS.value: np.array([new_point[1:]]), - 
NodeAttr.TIME.value: np.array([t]), - NodeAttr.TRACK_ID.value: np.array([track_id]), - NodeAttr.AREA.value: np.array([area]), - } - return attributes - - def _update_data(self, event): - """Calls the tracks controller with to update the data in the Tracks object and dispatch the update""" - - if event.action == "added": - # we only want to allow this update if there is no seg layer - if self.tracks_viewer.tracking_layers.seg_layer is None: - new_point = event.value[-1] - attributes = self._create_node_attrs(new_point) - self.tracks_viewer.tracks_controller.add_nodes(attributes) - else: - show_info( - "Mixed point and segmentation nodes not allowed: add points by drawing on segmentation layer" - ) - self._refresh() - - if event.action == "removed": - self.tracks_viewer.tracks_controller.delete_nodes( - self.tracks_viewer.selected_nodes._list - ) - - if event.action == "changed": - # we only want to allow this update if there is no seg layer - if self.tracks_viewer.tracking_layers.seg_layer is None: - positions = [] - node_ids = [] - for ind in self.selected_data: - point = self.data[ind] - pos = point[1:] - positions.append(pos) - node_id = self.properties["node_id"][ind] - node_ids.append(node_id) - - attributes = {NodeAttr.POS.value: positions} - self.tracks_viewer.tracks_controller.update_node_attrs( - node_ids, attributes - ) - else: - self._refresh() # refresh to move points back where they belong - - def _update_selection(self): - """Replaces the list of selected_nodes with the selection provided by the user""" - - selected_points = self.selected_data - self.tracks_viewer.selected_nodes.reset() - for point in selected_points: - node_id = self.nodes[point] - self.tracks_viewer.selected_nodes.add(node_id, True) - - def get_symbols(self, tracks: Tracks, symbolmap: dict[NodeType, str]) -> list[str]: - statemap = { - 0: NodeType.END, - 1: NodeType.CONTINUE, - 2: NodeType.SPLIT, - } - symbols = [symbolmap[statemap[degree]] for _, degree in tracks.graph.out_degree] - 
return symbols - - def update_point_outline(self, visible: list[int] | str) -> None: - """Update the outline color of the selected points and visibility according to display mode - - Args: - visible (list[int] | str): A list of track ids, or "all" - """ - # filter out the non-selected tracks if in lineage mode - if visible == "all": - self.shown[:] = True - else: - indices = np.where(np.isin(self.properties["track_id"], visible))[ - 0 - ].tolist() - self.shown[:] = False - self.shown[indices] = True - - # set border color for selected item - self.border_color = [1, 1, 1, 1] - self.size = self.default_size - for node in self.tracks_viewer.selected_nodes: - index = self.node_index_dict[node] - self.border_color[index] = ( - 0, - 1, - 1, - 1, - ) - self.size[index] = math.ceil(self.default_size + 0.3 * self.default_size) - self.refresh() diff --git a/src/motile_tracker/data_views/views/layers/tracks_layer_group.py b/src/motile_tracker/data_views/views/layers/tracks_layer_group.py deleted file mode 100644 index 67c30b1f..00000000 --- a/src/motile_tracker/data_views/views/layers/tracks_layer_group.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import napari - -from motile_tracker.data_model.tracks import Tracks - -from .track_graph import TrackGraph -from .track_labels import TrackLabels -from .track_points import TrackPoints - -if TYPE_CHECKING: - from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer - - -class TracksLayerGroup: - def __init__( - self, - viewer: napari.Viewer, - tracks: Tracks, - name: str, - tracks_viewer: TracksViewer, - ): - self.viewer = viewer - self.tracks_viewer = tracks_viewer - self.tracks = tracks - self.name = name - self.tracks_layer: TrackGraph | None = None - self.points_layer: TrackPoints | None = None - self.seg_layer: TrackLabels | None = None - - def set_tracks(self, tracks, name): - self.remove_napari_layers() - self.tracks = tracks - 
self.name = name - # Create new layers - if self.tracks is not None and self.tracks.segmentation is not None: - self.seg_layer = TrackLabels( - viewer=self.viewer, - data=self.tracks.segmentation, - name=self.name + "_seg", - opacity=0.9, - scale=self.tracks.scale, - tracks_viewer=self.tracks_viewer, - ) - else: - self.seg_layer = None - - if ( - self.tracks is not None - and self.tracks.graph is not None - and self.tracks.graph.number_of_nodes() != 0 - ): - self.tracks_layer = TrackGraph( - name=self.name + "_tracks", - tracks_viewer=self.tracks_viewer, - ) - - self.points_layer = TrackPoints( - name=self.name + "_points", - tracks_viewer=self.tracks_viewer, - ) - else: - self.tracks_layer = None - self.points_layer = None - self.add_napari_layers() - - def remove_napari_layer(self, layer: napari.layers.Layer | None) -> None: - """Remove a layer from the napari viewer, if present""" - if layer and layer in self.viewer.layers: - self.viewer.layers.remove(layer) - - def remove_napari_layers(self) -> None: - """Remove all tracking layers from the viewer""" - self.remove_napari_layer(self.tracks_layer) - self.remove_napari_layer(self.seg_layer) - self.remove_napari_layer(self.points_layer) - - def add_napari_layers(self) -> None: - """Add new tracking layers to the viewer""" - if self.tracks_layer is not None: - self.viewer.add_layer(self.tracks_layer) - if self.seg_layer is not None: - self.viewer.add_layer(self.seg_layer) - if self.points_layer is not None: - self.viewer.add_layer(self.points_layer) - - def _refresh(self) -> None: - """Refresh the tracking layers with new tracks info""" - if self.tracks_layer is not None: - self.tracks_layer._refresh() - if self.seg_layer is not None: - self.seg_layer._refresh() - if self.points_layer is not None: - self.points_layer._refresh() - - def update_visible(self, visible_tracks: list[int], visible_nodes: list[int]): - if self.seg_layer is not None: - self.seg_layer.update_label_colormap(visible_nodes) - if 
self.points_layer is not None: - self.points_layer.update_point_outline(visible_tracks) - if self.tracks_layer is not None: - self.tracks_layer.update_track_visibility(visible_tracks) - - def center_view(self, node): - """Adjust the current_step and camera center of the viewer to jump to the node - location, if the node is not already in the field of view""" - - if self.seg_layer is None or self.seg_layer.mode == "pan_zoom": - location = self.tracks.get_positions([node], incl_time=True)[0].tolist() - assert ( - len(location) == self.viewer.dims.ndim - ), f"Location {location} does not match viewer number of dims {self.viewer.dims.ndim}" - - step = list(self.viewer.dims.current_step) - for dim in self.viewer.dims.not_displayed: - step[dim] = int( - location[dim] + 0.5 - ) # use the world location, since the 'step' in viewer.dims.range - # already in world units - self.viewer.dims.current_step = step - - # check whether the new coordinates are inside or outside the field of view, - # then adjust the camera if needed - example_layer = ( - self.points_layer - ) # the points layer is always in world units, - # because it directly reads the scaled coordinates. Therefore, no rescaling - # is necessary to compute the camera center - corner_coordinates = example_layer.corner_pixels - - # check which dimensions are shown, the first dimension is displayed on the - # x axis, and the second on the y_axis - dims_displayed = self.viewer.dims.displayed - - # Note: This centering does not work in 3D. What we should do instead is take - # the view direction vector, start at the point, and move backward along the - # vector a certain amount to put the point in view. - # Note #2: Points already does centering when you add the first point, and it - # works in 3D. We can look at that to see what logic they use. 
- - # self.viewer.dims.displayed_order - x_dim = dims_displayed[-1] - y_dim = dims_displayed[-2] - - # find corner pixels for the displayed axes - _min_x = corner_coordinates[0][x_dim] - _max_x = corner_coordinates[1][x_dim] - _min_y = corner_coordinates[0][y_dim] - _max_y = corner_coordinates[1][y_dim] - - # check whether the node location falls within the corner spatial range - if not ( - (location[x_dim] > _min_x and location[x_dim] < _max_x) - and (location[y_dim] > _min_y and location[y_dim] < _max_y) - ): - camera_center = self.viewer.camera.center - - # set the center y and x to the center of the node, by using the index of the - # currently displayed dimensions - self.viewer.camera.center = ( - camera_center[0], - location[y_dim], - # camera center is calculated in scaled coordinates, and the optional - # labels layer is scaled by the layer.scale attribute - location[x_dim], - ) diff --git a/src/motile_tracker/data_views/views/tree_view/__init__.py b/src/motile_tracker/data_views/views/tree_view/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/motile_tracker/data_views/views/tree_view/flip_axes_widget.py b/src/motile_tracker/data_views/views/tree_view/flip_axes_widget.py deleted file mode 100644 index 8721e77b..00000000 --- a/src/motile_tracker/data_views/views/tree_view/flip_axes_widget.py +++ /dev/null @@ -1,29 +0,0 @@ -from psygnal import Signal -from qtpy.QtWidgets import QGroupBox, QPushButton, QVBoxLayout, QWidget - - -class FlipTreeWidget(QWidget): - """Widget to flip the axis of the tree view""" - - flip_tree = Signal() - - def __init__(self): - super().__init__() - - flip_layout = QVBoxLayout() - display_box = QGroupBox("Plot axes [F]") - flip_button = QPushButton("Flip") - flip_button.clicked.connect(self.flip) - flip_layout.addWidget(flip_button) - display_box.setLayout(flip_layout) - - layout = QVBoxLayout() - layout.addWidget(display_box) - self.setLayout(layout) - display_box.setMaximumWidth(90) - 
display_box.setMaximumHeight(82) - - def flip(self): - """Send a signal to flip the axes of the plot""" - - self.flip_tree.emit() diff --git a/src/motile_tracker/data_views/views/tree_view/navigation_widget.py b/src/motile_tracker/data_views/views/tree_view/navigation_widget.py deleted file mode 100644 index 27dc891a..00000000 --- a/src/motile_tracker/data_views/views/tree_view/navigation_widget.py +++ /dev/null @@ -1,186 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pandas as pd -from qtpy.QtWidgets import ( - QGroupBox, - QHBoxLayout, - QPushButton, - QWidget, -) - -if TYPE_CHECKING: - from motile_tracker.data_views import NodeSelectionList - - -class NavigationWidget(QWidget): - def __init__( - self, - track_df: pd.DataFrame, - lineage_df: pd.DataFrame, - view_direction: str, - selected_nodes: NodeSelectionList, - feature: str, - ): - """Widget for controlling navigation in the tree widget - - Args: - track_df (pd.DataFrame): The dataframe holding the track information - view_direction (str): The view direction of the tree widget. Options: "vertical", "horizontal". - selected_nodes (NodeSelectionList): The list of selected nodes. 
- feature (str): The feature currently being displayed - """ - - super().__init__() - self.track_df = track_df - self.lineage_df = lineage_df - self.view_direction = view_direction - self.selected_nodes = selected_nodes - self.feature = feature - - navigation_box = QGroupBox("Navigation [\u2b05 \u27a1 \u2b06 \u2b07]") - navigation_layout = QHBoxLayout() - left_button = QPushButton("\u2b05") - right_button = QPushButton("\u27a1") - up_button = QPushButton("\u2b06") - down_button = QPushButton("\u2b07") - - left_button.clicked.connect(lambda: self.move("left")) - right_button.clicked.connect(lambda: self.move("right")) - up_button.clicked.connect(lambda: self.move("up")) - down_button.clicked.connect(lambda: self.move("down")) - - navigation_layout.addWidget(left_button) - navigation_layout.addWidget(right_button) - navigation_layout.addWidget(up_button) - navigation_layout.addWidget(down_button) - navigation_box.setLayout(navigation_layout) - navigation_box.setMaximumWidth(250) - navigation_box.setMaximumHeight(60) - - layout = QHBoxLayout() - layout.addWidget(navigation_box) - - self.setLayout(layout) - - def move(self, direction: str) -> None: - """Move in the given direction on the tree view. Will select the next - node in that direction, based on the orientation of the widget. - - Args: - direction (str): The direction to move. 
Options: "up", "down", - "left", "right" - """ - if len(self.selected_nodes) == 0: - return - node_id = self.selected_nodes[0] - - if direction == "left": - if self.view_direction == "horizontal": - next_node = self.get_predecessor(node_id) - else: - next_node = self.get_next_track_node( - self.track_df, node_id, forward=False - ) - elif direction == "right": - if self.view_direction == "horizontal": - next_node = self.get_successor(node_id) - else: - next_node = self.get_next_track_node(self.track_df, node_id) - elif direction == "up": - if self.view_direction == "horizontal": - next_node = self.get_next_track_node(self.lineage_df, node_id) - if next_node is None: - next_node = self.get_next_track_node(self.track_df, node_id) - else: - next_node = self.get_predecessor(node_id) - elif direction == "down": - if self.view_direction == "horizontal": - # try navigation within the current lineage_df first - next_node = self.get_next_track_node( - self.lineage_df, node_id, forward=False - ) - # if not found, look in the whole dataframe - # to enable jumping to the next node outside the current tree view content - if next_node is None: - next_node = self.get_next_track_node( - self.track_df, node_id, forward=False - ) - else: - next_node = self.get_successor(node_id) - else: - raise ValueError( - f"Direction must be one of 'left', 'right', 'up', 'down', got {direction}" - ) - if next_node is not None: - self.selected_nodes.add(next_node) - - def get_next_track_node( - self, df: pd.DataFrame, node_id: str, forward=True - ) -> str | None: - """Get the node at the same time point in an adjacent track. - - Args: - df (pd.DataFrame): The dataframe to be used (full track_df or subset lineage_df). - node_id (str): The current node ID to get the next from. - forward (bool, optional): If true, pick the next track (right/down). - Otherwise, pick the previous track (left/up). Defaults to True. 
- """ - # Determine which axis to use for finding neighbors - axis_label = "area" if self.feature == "area" else "x_axis_pos" - - if df.empty: - return None - node_data = df.loc[df["node_id"] == node_id] - if node_data.empty: - return None - - # Fetch the axis value for the given node ID - axis_label_value = node_data[axis_label].iloc[0] - t = node_data["t"].iloc[0] - - if forward: - neighbors = df.loc[(df[axis_label] > axis_label_value) & (df["t"] == t)] - else: - neighbors = df.loc[(df[axis_label] < axis_label_value) & (df["t"] == t)] - if not neighbors.empty: - # Find the closest index label - closest_index_label = ( - (neighbors[axis_label] - axis_label_value).abs().idxmin() - ) - neighbor = neighbors.loc[closest_index_label, "node_id"] - return neighbor - - def get_predecessor(self, node_id: str) -> str | None: - """Get the predecessor node of the given node_id - - Args: - node_id (str): the node id to get the predecessor of - - Returns: - str | None: THe node id of the predecessor, or none if no predecessor - is found - """ - parent_id = self.track_df.loc[ - self.track_df["node_id"] == node_id, "parent_id" - ].values[0] - parent_row = self.track_df.loc[self.track_df["node_id"] == parent_id] - if not parent_row.empty: - return parent_row["node_id"].values[0] - - def get_successor(self, node_id: str) -> str | None: - """Get the successor node of the given node_id. If there are two children, - picks one arbitrarily. 
- - Args: - node_id (str): the node id to get the successor of - - Returns: - str | None: THe node id of the successor, or none if no successor - is found - """ - children = self.track_df.loc[self.track_df["parent_id"] == node_id] - if not children.empty: - child = children.to_dict("records")[0] - return child["node_id"] diff --git a/src/motile_tracker/data_views/views/tree_view/tree_view_feature_widget.py b/src/motile_tracker/data_views/views/tree_view/tree_view_feature_widget.py deleted file mode 100644 index 7058a45c..00000000 --- a/src/motile_tracker/data_views/views/tree_view/tree_view_feature_widget.py +++ /dev/null @@ -1,60 +0,0 @@ -from psygnal import Signal -from qtpy.QtWidgets import ( - QButtonGroup, - QGroupBox, - QHBoxLayout, - QRadioButton, - QVBoxLayout, - QWidget, -) - - -class TreeViewFeatureWidget(QWidget): - """Widget to switch between viewing all nodes versus nodes of one or more lineages in the tree widget""" - - change_feature = Signal(str) - - def __init__(self): - super().__init__() - - self.feature = "tree" - - display_box = QGroupBox("Feature [W]") - display_layout = QHBoxLayout() - button_group = QButtonGroup() - self.show_tree_radio = QRadioButton("Lineage Tree") - self.show_tree_radio.setChecked(True) - self.show_tree_radio.clicked.connect(lambda: self._set_feature("tree")) - self.show_area_radio = QRadioButton("Object size") - self.show_area_radio.clicked.connect(lambda: self._set_feature("area")) - button_group.addButton(self.show_tree_radio) - button_group.addButton(self.show_area_radio) - display_layout.addWidget(self.show_tree_radio) - display_layout.addWidget(self.show_area_radio) - display_box.setLayout(display_layout) - display_box.setMaximumWidth(250) - display_box.setMaximumHeight(60) - - layout = QVBoxLayout() - layout.addWidget(display_box) - - self.setLayout(layout) - - def _toggle_feature_mode(self, event=None) -> None: - """Toggle display mode""" - - if ( - self.show_area_radio.isEnabled - ): # if button is disabled, 
toggle is not allowed - if self.feature == "area": - self._set_feature("tree") - self.show_tree_radio.setChecked(True) - else: - self._set_feature("area") - self.show_area_radio.setChecked(True) - - def _set_feature(self, mode: str): - """Emit signal to change the display mode""" - - self.feature = mode - self.change_feature.emit(mode) diff --git a/src/motile_tracker/data_views/views/tree_view/tree_view_mode_widget.py b/src/motile_tracker/data_views/views/tree_view/tree_view_mode_widget.py deleted file mode 100644 index 552b77d7..00000000 --- a/src/motile_tracker/data_views/views/tree_view/tree_view_mode_widget.py +++ /dev/null @@ -1,57 +0,0 @@ -from psygnal import Signal -from qtpy.QtWidgets import ( - QButtonGroup, - QGroupBox, - QHBoxLayout, - QRadioButton, - QVBoxLayout, - QWidget, -) - - -class TreeViewModeWidget(QWidget): - """Widget to switch between viewing all nodes versus nodes of one or more lineages in the tree widget""" - - change_mode = Signal(str) - - def __init__(self): - super().__init__() - - self.mode = "all" - - display_box = QGroupBox("Display [Q]") - display_layout = QHBoxLayout() - button_group = QButtonGroup() - self.show_all_radio = QRadioButton("All cells") - self.show_all_radio.setChecked(True) - self.show_all_radio.clicked.connect(lambda: self._set_mode("all")) - self.show_lineage_radio = QRadioButton("Current lineage(s)") - self.show_lineage_radio.clicked.connect(lambda: self._set_mode("lineage")) - button_group.addButton(self.show_all_radio) - button_group.addButton(self.show_lineage_radio) - display_layout.addWidget(self.show_all_radio) - display_layout.addWidget(self.show_lineage_radio) - display_box.setLayout(display_layout) - display_box.setMaximumWidth(250) - display_box.setMaximumHeight(60) - - layout = QVBoxLayout() - layout.addWidget(display_box) - - self.setLayout(layout) - - def _toggle_display_mode(self, event=None) -> None: - """Toggle display mode""" - - if self.mode == "lineage": - self._set_mode("all") - 
self.show_all_radio.setChecked(True) - else: - self._set_mode("lineage") - self.show_lineage_radio.setChecked(True) - - def _set_mode(self, mode: str): - """Emit signal to change the display mode""" - - self.mode = mode - self.change_mode.emit(mode) diff --git a/src/motile_tracker/data_views/views/tree_view/tree_widget.py b/src/motile_tracker/data_views/views/tree_view/tree_widget.py deleted file mode 100644 index c7329af2..00000000 --- a/src/motile_tracker/data_views/views/tree_view/tree_widget.py +++ /dev/null @@ -1,737 +0,0 @@ -# do not put the from __future__ import annotations as it breaks the injection - -from typing import Any - -import napari -import numpy as np -import pandas as pd -import pyqtgraph as pg -from psygnal import Signal -from pyqtgraph.Qt import QtCore -from qtpy.QtCore import Qt -from qtpy.QtGui import QColor, QKeyEvent, QMouseEvent -from qtpy.QtWidgets import ( - QHBoxLayout, - QVBoxLayout, - QWidget, -) -from superqt import QCollapsible - -from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer - -from .flip_axes_widget import FlipTreeWidget -from .navigation_widget import NavigationWidget -from .tree_view_feature_widget import TreeViewFeatureWidget -from .tree_view_mode_widget import TreeViewModeWidget -from .tree_widget_utils import ( - extract_lineage_tree, - extract_sorted_tracks, -) - - -class CustomViewBox(pg.ViewBox): - selected_rect = Signal(Any) - - def __init__(self, *args, **kwds): - kwds["enableMenu"] = False - pg.ViewBox.__init__(self, *args, **kwds) - # self.setMouseMode(self.RectMode) - - ## reimplement right-click to zoom out - def mouseClickEvent(self, ev): - if ev.button() == QtCore.Qt.MouseButton.RightButton: - self.autoRange() - - def showAxRect(self, ax, **kwargs): - """Set the visible range to the given rectangle - Emits sigRangeChangedManually without changing the range. 
- """ - # Emit the signal without setting the range - self.sigRangeChangedManually.emit(self.state["mouseEnabled"]) - - def mouseDragEvent(self, ev, axis=None): - """Modified mouseDragEvent function to check which mouse mode to use - and to submit rectangle coordinates for selecting multiple nodes if necessary""" - - super().mouseDragEvent(ev, axis) - - # use RectMode when pressing shift - if ev.modifiers() == QtCore.Qt.ShiftModifier: - self.setMouseMode(self.RectMode) - - if ev.isStart(): - self.mouse_start_pos = self.mapSceneToView(ev.scenePos()) - elif ev.isFinish(): - rect_end_pos = self.mapSceneToView(ev.scenePos()) - rect = QtCore.QRectF(self.mouse_start_pos, rect_end_pos).normalized() - self.selected_rect.emit(rect) # emit the rectangle - ev.accept() - else: - ev.ignore() - else: - # Otherwise, set pan mode - self.setMouseMode(self.PanMode) - - -class TreePlot(pg.PlotWidget): - node_clicked = Signal(Any, bool) # node_id, append - nodes_selected = Signal(list, bool) - - def __init__(self) -> pg.PlotWidget: - """Construct the pyqtgraph treewidget. This is the actual canvas - on which the tree view is drawn. 
- """ - super().__init__(viewBox=CustomViewBox()) - - self.setFocusPolicy(Qt.StrongFocus) - self.setTitle("Lineage Tree") - - self._pos = [] - self.adj = [] - self.symbolBrush = [] - self.symbols = [] - self.pen = [] - self.outline_pen = [] - self.node_ids = [] - self.sizes = [] - - self.view_direction = None - self.feature = None - self.g = pg.GraphItem() - self.g.scatter.sigClicked.connect(self._on_click) - self.addItem(self.g) - self.set_view("vertical", feature="tree") - self.getViewBox().selected_rect.connect(self.select_points_in_rect) - - def select_points_in_rect(self, rect: QtCore.QRectF): - """Select all nodes in given rectangle""" - - scatter_data = self.g.scatter.data - x = scatter_data["x"] - y = scatter_data["y"] - data = scatter_data["data"] - - # Filter points that are within the rectangle - points_within_rect = [ - (x[i], y[i], data[i]) for i in range(len(x)) if rect.contains(x[i], y[i]) - ] - selected_nodes = [point[2] for point in points_within_rect] - self.nodes_selected.emit(selected_nodes, True) - - def update( - self, - track_df: pd.DataFrame, - view_direction: str, - feature: str, - selected_nodes: list[Any], - reset_view: bool | None = False, - allow_flip: bool | None = True, - ): - """Update the entire view, including the data, view direction, and - selected nodes - - Args: - track_df (pd.DataFrame): The dataframe containing the graph data - view_direction (str): The view direction - feature (str): The feature to be plotted ('tree' or 'area') - selected_nodes (list[Any]): The currently selected nodes to be highlighted - """ - self.set_data(track_df, feature) - self._update_viewed_data(view_direction) # this can be expensive - self.set_view(view_direction, feature, reset_view, allow_flip) - self.set_selection(selected_nodes, feature) - - def set_view( - self, - view_direction: str, - feature: str, - reset_view: bool | None = False, - allow_flip: bool | None = True, - ): - """Set the view direction, saving the new value as an attribute and - 
changing the axes labels. Shortcuts if the view direction is already - correct. Does not actually update the rendered graph (need to call - _update_viewed_data). - - Args: - view_direction (str): "horizontal" or "vertical" - feature (str): the feature being displayed, it can be 'tree' or 'area' - """ - - if view_direction == self.view_direction and feature == self.feature: - if reset_view: - self.autoRange() - return - - axis_titles = { - "time": "Time Point", - "area": "Object size in calibrated units", - "tree": "", - } - if allow_flip: - if view_direction == "vertical": - time_axis = "left" # time is on y axis - feature_axis = "bottom" - self.invertY(True) # to show tracks from top to bottom - else: - time_axis = "bottom" # time is on y axis - feature_axis = "left" - self.invertY(False) - self.setLabel(time_axis, text=axis_titles["time"]) - self.getAxis(time_axis).setStyle(showValues=True) - - self.setLabel(feature_axis, text=axis_titles[feature]) - if feature == "tree": - self.getAxis(feature_axis).setStyle(showValues=False) - else: - self.getAxis(feature_axis).setStyle(showValues=True) - self.autoRange() # not sure if this is necessary or not - - if ( - self.view_direction != view_direction - or self.feature != feature - or reset_view - ): - self.autoRange() - self.view_direction = view_direction - self.feature = feature - - def _on_click(self, _, points: np.ndarray, ev: QMouseEvent) -> None: - """Adds the selected point to the selected_nodes list. Called when - the user clicks on the TreeWidget to select nodes. - - Args: - points (np.ndarray): _description_ - ev (QMouseEvent): _description_ - """ - - modifiers = ev.modifiers() - node_id = points[0].data() - append = Qt.ShiftModifier == modifiers - self.node_clicked.emit(node_id, append) - self.setFocus() - - def set_data(self, track_df: pd.DataFrame, feature: str) -> None: - """Updates the stored pyqtgraph content based on the given dataframe. 
- Does not render the new information (need to call _update_viewed_data). - - Args: - track_df (pd.DataFrame): The tracks df to compute the pyqtgraph - content for. Can be all lineages or any subset of them. - feature (str): The feature to be plotted. Can either be 'tree', or 'area'. - """ - self.track_df = track_df - self._create_pyqtgraph_content(track_df, feature) - - def _update_viewed_data(self, view_direction: str): - """Set the data according to the view direction - Args: - view_direction (str): direction to plot the data, either 'horizontal' or 'vertical' - """ - self.g.scatter.setPen( - pg.mkPen(QColor(150, 150, 150)) - ) # first reset the pen to avoid problems with length mismatch between the different properties - self.g.scatter.setSize(10) - if len(self._pos) == 0 or view_direction == "vertical": - pos_data = self._pos - else: - pos_data = np.flip(self._pos, axis=1) - - self.g.setData( - pos=pos_data, - adj=self.adj, - symbol=self.symbols, - symbolBrush=self.symbolBrush, - pen=self.pen, - data=self.node_ids, - ) - self.g.scatter.setPen(self.outline_pen) - self.g.scatter.setSize(self.sizes) - - def _create_pyqtgraph_content(self, track_df: pd.DataFrame, feature: str) -> None: - """Parse the given track_df into the format that pyqtgraph expects - and save the information as attributes. - - Args: - track_df (pd.DataFrame): The dataframe containing the graph to be - rendered in the tree view. Can be all lineages or a subset. - feature (str): The feature to be plotted. Can either be 'tree' or 'area'. 
- """ - self._pos = [] - self.adj = [] - self.symbols = [] - self.symbolBrush = [] - self.pen = [] - self.sizes = [] - self.node_ids = [] - - if track_df is not None and not track_df.empty: - self.symbols = track_df["symbol"].to_list() - self.symbolBrush = track_df["color"].to_numpy() - if feature == "tree": - self._pos = track_df[["x_axis_pos", "t"]].to_numpy() - elif feature == "area": - self._pos = track_df[["area", "t"]].to_numpy() - self.node_ids = track_df["node_id"].to_list() - self.sizes = np.array( - [ - 8, - ] - * len(self.symbols) - ) - - valid_edges_df = track_df[track_df["parent_id"] != 0] - node_ids_to_index = { - node_id: index for index, node_id in enumerate(self.node_ids) - } - edges_df = valid_edges_df[["node_id", "parent_id"]] - self.pen = valid_edges_df["color"].to_numpy() - edges_df_mapped = edges_df.map(lambda _id: node_ids_to_index[_id]) - self.adj = edges_df_mapped.to_numpy() - - self.outline_pen = np.array( - [pg.mkPen(QColor(150, 150, 150)) for i in range(len(self._pos))] - ) - - def set_selection(self, selected_nodes: list[Any], feature: str) -> None: - """Set the provided list of nodes to be selected. Increases the size - and highlights the outline with blue. Also centers the view - if the first selected node is not visible in the current canvas. - - Args: - selected_nodes (list[Any]): A list of node ids to be selected. 
- feature (str): the feature that is being plotted, either 'tree' or 'area' - """ - - # reset to default size and color to avoid problems with the array lengths - self.g.scatter.setPen(pg.mkPen(QColor(150, 150, 150))) - self.g.scatter.setSize(10) - - size = ( - self.sizes.copy() - ) # just copy the size here to keep the original self.sizes intact - - outlines = self.outline_pen.copy() - axis_label = ( - "area" if feature == "area" else "x_axis_pos" - ) # check what is currently being shown, to know how to scale the view - - if len(selected_nodes) > 0: - x_values = [] - t_values = [] - for node_id in selected_nodes: - node_df = self.track_df.loc[self.track_df["node_id"] == node_id] - if not node_df.empty: - x_axis_value = node_df[axis_label].values[0] - t = node_df["t"].values[0] - - x_values.append(x_axis_value) - t_values.append(t) - - # Update size and outline - index = self.node_ids.index(node_id) - size[index] += 5 - outlines[index] = pg.mkPen(color="c", width=2) - - # Center point if a single node is selected, center range if multiple nodes are selected - if len(selected_nodes) == 1: - self._center_view(x_axis_value, t) - else: - min_x = np.min(x_values) - max_x = np.max(x_values) - min_t = np.min(t_values) - max_t = np.max(t_values) - self._center_range(min_x, max_x, min_t, max_t) - - self.g.scatter.setPen(outlines) - self.g.scatter.setSize(size) - - def _center_range(self, min_x: int, max_x: int, min_t: int, max_t: int): - """Check whether viewbox contains current range and adjust if not""" - - if self.view_direction == "horizontal": - min_x, max_x, min_t, max_t = min_t, max_t, min_x, max_x - - view_box = self.plotItem.getViewBox() - current_range = view_box.viewRange() - - x_range = current_range[0] - y_range = current_range[1] - - # Check if the new range is within the current range - if ( - x_range[0] <= min_x - and x_range[1] >= max_x - and y_range[0] <= min_t - and y_range[1] >= max_t - ): - return - else: - view_box.setRange(xRange=(min_x, max_x), 
yRange=(min_t, max_t)) - - def _center_view(self, center_x: int, center_y: int): - """Center the Viewbox on given coordinates""" - - if self.view_direction == "horizontal": - center_x, center_y = ( - center_y, - center_x, - ) # flip because the axes have changed in horizontal mode - - view_box = self.plotItem.getViewBox() - current_range = view_box.viewRange() - - x_range = current_range[0] - y_range = current_range[1] - - # Check if the new center is within the current range - if ( - x_range[0] <= center_x <= x_range[1] - and y_range[0] <= center_y <= y_range[1] - ): - return - - # Calculate the width and height of the current view - current_width = x_range[1] - x_range[0] - current_height = y_range[1] - y_range[0] - - # Calculate new ranges maintaining the current width and height - new_x_range = ( - center_x - current_width / 2, - center_x + current_width / 2, - ) - new_y_range = ( - center_y - current_height / 2, - center_y + current_height / 2, - ) - - view_box.setRange(xRange=new_x_range, yRange=new_y_range, padding=0) - - -class TreeWidget(QWidget): - """pyqtgraph-based widget for lineage tree visualization and navigation""" - - def __init__(self, viewer: napari.Viewer): - super().__init__() - self.track_df = pd.DataFrame() # all tracks - self.lineage_df = pd.DataFrame() # the currently viewed subset of lineages - self.graph = None - self.mode = "all" # options: "all", "lineage" - self.feature = "tree" # options: "tree", "area" - self.view_direction = "vertical" # options: "horizontal", "vertical" - - self.tracks_viewer = TracksViewer.get_instance(viewer) - self.selected_nodes = self.tracks_viewer.selected_nodes - self.selected_nodes.list_updated.connect(self._update_selected) - self.tracks_viewer.tracks_updated.connect(self._update_track_data) - - # Construct the tree view pyqtgraph widget - layout = QVBoxLayout() - - self.tree_widget: TreePlot = TreePlot() - self.tree_widget.node_clicked.connect(self.selected_nodes.add) - 
self.tree_widget.nodes_selected.connect(self.selected_nodes.add_list) - - # Add radiobuttons for switching between different display modes - self.mode_widget = TreeViewModeWidget() - self.mode_widget.change_mode.connect(self._set_mode) - - # Add buttons to change which feature to display - self.feature_widget = TreeViewFeatureWidget() - self.feature_widget.change_feature.connect(self._set_feature) - - # Add navigation widget - self.navigation_widget = NavigationWidget( - self.track_df, - self.lineage_df, - self.view_direction, - self.selected_nodes, - self.feature, - ) - - # Add widget to flip the axes - self.flip_widget = FlipTreeWidget() - self.flip_widget.flip_tree.connect(self._flip_axes) - - # Construct a toolbar and set main layout - panel_layout = QHBoxLayout() - panel_layout.addWidget(self.mode_widget) - panel_layout.addWidget(self.feature_widget) - panel_layout.addWidget(self.navigation_widget) - panel_layout.addWidget(self.flip_widget) - panel_layout.setSpacing(0) - panel_layout.setContentsMargins(0, 0, 0, 0) - - panel = QWidget() - panel.setLayout(panel_layout) - panel.setMaximumWidth(930) - panel.setMaximumHeight(82) - - # Make a collapsible for TreeView widgets - collapsable_widget = QCollapsible("Show/Hide Tree View Controls") - collapsable_widget.layout().setContentsMargins(0, 0, 0, 0) - collapsable_widget.layout().setSpacing(0) - collapsable_widget.addWidget(panel) - collapsable_widget.collapse(animate=False) - - layout.addWidget(collapsable_widget) - layout.addWidget(self.tree_widget) - layout.setSpacing(0) - self.setLayout(layout) - self._update_track_data(reset_view=True) - - def keyPressEvent(self, event: QKeyEvent) -> None: - """Handle key press events.""" - key_map = { - Qt.Key_Delete: self.delete_node, - Qt.Key_D: self.delete_node, - Qt.Key_A: self.create_edge, - Qt.Key_B: self.delete_edge, - Qt.Key_Z: self.undo, - Qt.Key_R: self.redo, - Qt.Key_Q: self.toggle_display_mode, - Qt.Key_W: self.toggle_feature_mode, - Qt.Key_F: self._flip_axes, - 
Qt.Key_X: lambda: self.set_mouse_enabled(x=True, y=False), - Qt.Key_Y: lambda: self.set_mouse_enabled(x=False, y=True), - } - - # Check if the key has a handler in the map - handler = key_map.get(event.key()) - - if handler: - handler() # Call the function bound to the key - else: - # Handle navigation (Arrow keys) - direction_map = { - Qt.Key_Left: "left", - Qt.Key_Right: "right", - Qt.Key_Up: "up", - Qt.Key_Down: "down", - } - direction = direction_map.get(event.key()) - if direction: - self.navigation_widget.move(direction) - self.tree_widget.setFocus() - - def delete_node(self): - """Delete a node.""" - self.tracks_viewer.delete_node() - - def create_edge(self): - """Create an edge.""" - self.tracks_viewer.create_edge() - - def delete_edge(self): - """Delete an edge.""" - self.tracks_viewer.delete_edge() - - def undo(self): - """Undo action.""" - self.tracks_viewer.undo() - - def redo(self): - """Redo action.""" - self.tracks_viewer.redo() - - def toggle_display_mode(self): - """Toggle display mode.""" - self.mode_widget._toggle_display_mode() - - def toggle_feature_mode(self): - """Toggle feature mode.""" - self.feature_widget._toggle_feature_mode() - - def _flip_axes(self): - """Flip the axes of the plot""" - - if self.view_direction == "horizontal": - self.view_direction = "vertical" - else: - self.view_direction = "horizontal" - - self.navigation_widget.view_direction = self.view_direction - self.tree_widget._update_viewed_data(self.view_direction) - self.tree_widget.set_view( - view_direction=self.view_direction, - feature=self.tree_widget.feature, - reset_view=False, - ) - - def set_mouse_enabled(self, x: bool, y: bool): - """Enable or disable mouse zoom scrolling in X or Y direction.""" - self.tree_widget.setMouseEnabled(x=x, y=y) - - def keyReleaseEvent(self, ev): - """Reset the mouse scrolling when releasing the X/Y key""" - - if ev.key() == Qt.Key_X or ev.key() == Qt.Key_Y: - self.tree_widget.setMouseEnabled(x=True, y=True) - - def 
_update_selected(self): - """Called whenever the selection list is updated. Only re-computes - the full graph information when the new selection is not in the - lineage df (and in lineage mode) - """ - - if self.mode == "lineage" and any( - node not in np.unique(self.lineage_df["node_id"].values) - for node in self.selected_nodes - ): - self._update_lineage_df() - self.tree_widget.update( - self.lineage_df, - self.view_direction, - self.feature, - self.selected_nodes, - ) - else: - self.tree_widget.set_selection(self.selected_nodes, self.feature) - - def _update_track_data(self, reset_view: bool | None = None) -> None: - """Called when the TracksViewer emits the tracks_updated signal, indicating - that a new set of tracks should be viewed. - """ - - if self.tracks_viewer.tracks is None: - self.track_df = pd.DataFrame() - self.graph = None - else: - if reset_view: - self.track_df = extract_sorted_tracks( - self.tracks_viewer.tracks, self.tracks_viewer.colormap - ) - else: - self.track_df = extract_sorted_tracks( - self.tracks_viewer.tracks, - self.tracks_viewer.colormap, - self.track_df, - ) - self.graph = self.tracks_viewer.tracks.graph - - # check whether we have area measurements and therefore should activate the area - # button - if "area" not in self.track_df.columns: - if self.feature_widget.feature == "area": - self.feature_widget._toggle_feature_mode() - self.feature_widget.show_area_radio.setEnabled(False) - else: - self.feature_widget.show_area_radio.setEnabled(True) - - # if reset_view, we got new data and want to reset display and feature before calling the plot update - if reset_view: - self.lineage_df = pd.DataFrame() - self.mode = "all" - self.mode_widget.show_all_radio.setChecked(True) - self.view_direction = "vertical" - self.feature = "tree" - self.feature_widget.show_tree_radio.setChecked(True) - allow_flip = True - else: - allow_flip = False - - # also update the navigation widget - self.navigation_widget.track_df = self.track_df - 
self.navigation_widget.lineage_df = self.lineage_df - - # check which view to set - if self.mode == "lineage": - self._update_lineage_df() - self.tree_widget.update( - self.lineage_df, - self.view_direction, - self.feature, - self.selected_nodes, - reset_view=reset_view, - allow_flip=allow_flip, - ) - - else: - self.tree_widget.update( - self.track_df, - self.view_direction, - self.feature, - self.selected_nodes, - reset_view=reset_view, - allow_flip=allow_flip, - ) - - def _set_mode(self, mode: str) -> None: - """Set the display mode to all or lineage view. Currently, linage - view is always horizontal and all view is always vertical. - - Args: - mode (str): The mode to set the view to. Options are "all" or "lineage" - """ - if mode not in ["all", "lineage"]: - raise ValueError(f"Mode must be 'all' or 'lineage', got {mode}") - - self.mode = mode - if mode == "all": - if self.feature == "tree": - self.view_direction = "vertical" - else: - self.view_direction = "horizontal" - df = self.track_df - elif mode == "lineage": - self.view_direction = "horizontal" - self._update_lineage_df() - df = self.lineage_df - self.navigation_widget.view_direction = self.view_direction - self.tree_widget.update( - df, self.view_direction, self.feature, self.selected_nodes - ) - - def _set_feature(self, feature: str) -> None: - """Set the feature mode to 'tree' or 'area'. For this the view is always - horizontal. - - Args: - feature (str): The feature to plot. 
Options are "tree" or "area" - """ - if feature not in ["tree", "area"]: - raise ValueError(f"Feature must be 'tree' or 'area', got {feature}") - - self.feature = feature - if feature == "tree" and self.mode == "all": - self.view_direction = "vertical" - else: - self.view_direction = "horizontal" - self.navigation_widget.view_direction = self.view_direction - - if self.mode == "all": - df = self.track_df - if self.mode == "lineage": - df = self.lineage_df - - self.navigation_widget.feature = self.feature - self.tree_widget.update( - df, self.view_direction, self.feature, self.selected_nodes - ) - - def _update_lineage_df(self) -> None: - """Subset dataframe to include only nodes belonging to the current lineage""" - - if len(self.selected_nodes) == 0 and not self.lineage_df.empty: - # try to restore lineage df based on previous selection, even if those nodes are now deleted. - # this is to prevent that deleting nodes will remove those lineages from the lineage view, which is confusing. - prev_visible_set = set(self.lineage_df["node_id"]) - prev_visible = [ - node for node in prev_visible_set if self.graph.has_node(node) - ] - visible = [] - for node_id in prev_visible: - visible += extract_lineage_tree(self.graph, node_id) - if set(prev_visible).issubset(visible): - break - else: - visible = [] - for node_id in self.selected_nodes: - visible += extract_lineage_tree(self.graph, node_id) - self.lineage_df = self.track_df[ - self.track_df["node_id"].isin(visible) - ].reset_index() - self.lineage_df["x_axis_pos"] = ( - self.lineage_df["x_axis_pos"].rank(method="dense").astype(int) - 1 - ) - self.navigation_widget.lineage_df = self.lineage_df diff --git a/src/motile_tracker/data_views/views/tree_view/tree_widget_utils.py b/src/motile_tracker/data_views/views/tree_view/tree_widget_utils.py deleted file mode 100644 index c24008d8..00000000 --- a/src/motile_tracker/data_views/views/tree_view/tree_widget_utils.py +++ /dev/null @@ -1,247 +0,0 @@ -from __future__ import 
annotations - -import napari.layers -import networkx as nx -import numpy as np -import pandas as pd -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr - -from motile_tracker.data_model import NodeType, Tracks - - -def extract_sorted_tracks( - tracks: Tracks, - colormap: napari.utils.CyclicLabelColormap, - prev_df: pd.DataFrame | None = None, -) -> pd.DataFrame | None: - """ - Extract the information of individual tracks required for constructing the pyqtgraph plot. Follows the same logic as the relabel_segmentation - function from the Motile toolbox. - - Args: - tracks (motile_tracker.core.Tracks): A tracks object containing a graph - to be converted into a dataframe. - colormap (napari.utils.CyclicLabelColormap): The colormap to use to - extract the color of each node from the track ID - prev_df (pd.DataFrame, Optional). Dataframe that holds the previous track_df, including the order of the tracks. - - Returns: - pd.DataFrame | None: data frame with all the information needed to - construct the pyqtgraph plot. 
Columns are: 't', 'node_id', 'track_id', - 'color', 'x', 'y', ('z'), 'index', 'parent_id', 'parent_track_id', - 'state', 'symbol', and 'x_axis_pos' - """ - if tracks is None or tracks.graph is None: - return None - - solution_nx_graph = tracks.graph - - track_list = [] - parent_mapping = [] - - # Identify parent nodes (nodes with more than one child) - parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] - end_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d == 0] - - # Make a copy of the graph and remove outgoing edges from parent nodes to isolate tracks - soln_copy = solution_nx_graph.copy() - for parent_node in parent_nodes: - out_edges = solution_nx_graph.out_edges(parent_node) - soln_copy.remove_edges_from(out_edges) - - # Process each weakly connected component as a separate track - for node_set in nx.weakly_connected_components(soln_copy): - # Sort nodes in each weakly connected component by their time attribute to ensure correct order - sorted_nodes = sorted( - node_set, - key=lambda node: tracks.get_time(node), - ) - positions = tracks.get_positions(sorted_nodes).tolist() - - # track_id and color are the same for all nodes in a node_set - parent_track_id = None - track_id = tracks.get_track_id(sorted_nodes[0]) - color = np.concatenate((colormap.map(track_id)[:3] * 255, [255])) - - for node, pos in zip(sorted_nodes, positions, strict=False): - if node in parent_nodes: - state = NodeType.SPLIT - symbol = "t1" - elif node in end_nodes: - state = NodeType.END - symbol = "x" - else: - state = NodeType.CONTINUE - symbol = "o" - - track_dict = { - "t": tracks.get_time(node), - "node_id": node, - "track_id": track_id, - "color": color, - "x": pos[-1], - "y": pos[-2], - "parent_id": 0, - "parent_track_id": 0, - "state": state, - "symbol": symbol, - } - - if tracks.get_area(node) is not None: - track_dict["area"] = tracks.get_area(node) - - if len(pos) == 3: - track_dict["z"] = pos[0] - - # Determine parent_id and parent_track_id - 
predecessors = list(solution_nx_graph.predecessors(node)) - if predecessors: - parent_id = predecessors[ - 0 - ] # There should be only one predecessor in a lineage tree - track_dict["parent_id"] = parent_id - - if parent_track_id is None: - parent_track_id = solution_nx_graph.nodes[parent_id][ - NodeAttr.TRACK_ID.value - ] - track_dict["parent_track_id"] = parent_track_id - - else: - parent_track_id = 0 - track_dict["parent_id"] = 0 - track_dict["parent_track_id"] = parent_track_id - - track_list.append(track_dict) - - parent_mapping.append( - {"track_id": track_id, "parent_track_id": parent_track_id, "node_id": node} - ) - - x_axis_order = sort_track_ids(parent_mapping, prev_df) - - for node in track_list: - node["x_axis_pos"] = x_axis_order.index(node["track_id"]) - - df = pd.DataFrame(track_list) - if "area" in df.columns: - df["area"] = df["area"].fillna(0) - - return df - - -def find_root(track_id: int, parent_map: dict) -> int: - """Function to find the root associated with a track by tracing its lineage""" - - # Keep traversing a track is found where parent_track_id == 0 (i.e., it's a root) - current_track = track_id - while parent_map.get(current_track) != 0: - current_track = parent_map.get(current_track) - return current_track - - -def sort_track_ids( - track_list: list[dict], prev_df: pd.DataFrame | None = None -) -> list[dict]: - """ - Sort track IDs such to maintain left-first order in the tree formed by parent-child relationships. - Used to determine the x-axis order of the tree plot. - - Args: - track_list (list): List of dictionaries with 'track_id' and 'parent_track_id'. - prev_df (pd.DataFrame, Optional). Dataframe that holds the previous track_df, including the order of the tracks. - - Returns: - list: Ordered list of track IDs for the x-axis. 
- """ - - roots = [node["track_id"] for node in track_list if node["parent_track_id"] == 0] - - if prev_df is not None and not prev_df.empty: - prev_roots = ( - prev_df.loc[prev_df["parent_track_id"] == 0, "track_id"].unique().tolist() - ) - new_roots = set(roots) - set( - prev_roots - ) # Detect new roots (those in the current list but not in the previous list) - - # Create mappings for fast lookup - track_id_map = { - n["track_id"]: n["node_id"] for n in track_list - } # track_id -> node_id - parent_map = { - n["track_id"]: n["parent_track_id"] for n in track_list - } # track_id -> parent_track_id - position_map = prev_df.set_index("node_id")[ - "x_axis_pos" - ].to_dict() # node_id -> x_axis_pos - - # Iterate over each new root and place it based on previous positions (to the right of its previous left neighbor) - for new_root in new_roots: - new_node_id = track_id_map.get(new_root) - # if node_id of new root does not exist in track_id_map, it is a completely new node and we can skip the rest of the code below and add the new track at the end. 
- if new_node_id and new_node_id in position_map: - prev_pos = position_map[ - new_node_id - ] # Get the previous position of the new root - # Find which track was on the left of this new root based on previous x_axis_pos - left_track = prev_df.loc[ - prev_df["x_axis_pos"] == prev_pos - 1, "track_id" - ].unique() - if len(left_track) > 0: - left_track_id = left_track[0] # Get the track ID of the left track - # Check if the left_track is a root or further downstream - if left_track_id not in roots: - # If the left_track is not a root, find the root associated with it - left_root = find_root(left_track_id, parent_map) - else: - # If left_track is already a root, use it as-is - left_root = left_track_id - # Find the index of the root where we need to insert the new root - left_ind = roots.index(left_root) if left_root in roots else -1 - else: - # If no left track is found, insert the new root at the beginning - left_ind = -1 - - # Remove the new root from its current position and reinsert it after the left root - roots.remove(new_root) - roots.insert(left_ind + 1, new_root) - - # Final sorted order of roots - x_axis_order = list(roots) - - # Find the children of each of the starting points, and work down the tree. 
- while len(roots) > 0: - children_list = [] - for track_id in roots: - children = [ - node["track_id"] - for node in track_list - if node["parent_track_id"] == track_id - ] - for i, child in enumerate(children): - [children_list.append(child)] - x_axis_order.insert(x_axis_order.index(track_id) + i, child) - roots = children_list - - return x_axis_order - - -def extract_lineage_tree(graph: nx.DiGraph, node_id: str) -> list[str]: - """Extract the entire lineage tree including horizontal relations for a given node""" - - # go up the tree to identify the root node - root_node = node_id - while True: - predecessors = list(graph.predecessors(root_node)) - if not predecessors: - break - root_node = predecessors[0] - - # extract all descendants to get the full tree - nodes = nx.descendants(graph, root_node) - - # include root - nodes.add(root_node) - - return list(nodes) diff --git a/src/motile_tracker/data_views/views_coordinator/__init__.py b/src/motile_tracker/data_views/views_coordinator/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/motile_tracker/data_views/views_coordinator/node_selection_list.py b/src/motile_tracker/data_views/views_coordinator/node_selection_list.py deleted file mode 100644 index 63af3bff..00000000 --- a/src/motile_tracker/data_views/views_coordinator/node_selection_list.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from psygnal import Signal -from PyQt5.QtCore import QObject - - -class NodeSelectionList(QObject): - """Updates the current selection (0, 1, or 2) of nodes. Sends a signal on every update. - Stores a list of node ids only.""" - - list_updated = Signal() - - def __init__(self): - super().__init__() - self._list = [] - - def add(self, item, append: bool | None = False): - """Append or replace an item to the list, depending on the number of items present and the keyboard modifiers used. Emit update signal""" - - # first check if this node was already present, if so, remove it. 
- if item in self._list: - self._list.remove(item) - - # single selection plus shift modifier: append to list to have two items in it - elif append: - self._list.append(item) - - # replace item in list - else: - self._list = [] - self._list.append(item) - - # emit update signal - self.list_updated.emit() - - def add_list(self, items: list, append: bool | None = False): - """Add nodes from a list and emit a single signal""" - - if append: - for item in items: - if item in self._list: - self._list.remove(item) - else: - self._list.append(item) - - else: - self._list = items - - self.list_updated.emit() - - def flip(self): - """Change the order of the items in the list""" - if len(self) == 2: - self._list = [self._list[1], self._list[0]] - - def reset(self): - """Empty list and emit update signal""" - self._list = [] - self.list_updated.emit() - - def __getitem__(self, index): - return self._list[index] - - def __len__(self): - return len(self._list) diff --git a/src/motile_tracker/data_views/views_coordinator/tracks_list.py b/src/motile_tracker/data_views/views_coordinator/tracks_list.py deleted file mode 100644 index b985fb6b..00000000 --- a/src/motile_tracker/data_views/views_coordinator/tracks_list.py +++ /dev/null @@ -1,168 +0,0 @@ -from functools import partial -from pathlib import Path -from warnings import warn - -import napari -from fonticon_fa6 import FA6S -from napari._qt.qt_resources import QColoredSVGIcon -from qtpy.QtCore import Signal -from qtpy.QtWidgets import ( - QFileDialog, - QGroupBox, - QHBoxLayout, - QLabel, - QListWidget, - QListWidgetItem, - QPushButton, - QVBoxLayout, - QWidget, -) -from superqt.fonticon import icon as qticon - -from motile_tracker.data_model import Tracks -from motile_tracker.motile.backend.motile_run import MotileRun - - -class TrackListWidget(QWidget): - """Creates or finds a TracksViewer and displays its TrackList widget. 
This is only used in case the user wants to open the trackslist from the plugins menu.""" - - def __init__(self, viewer: napari.Viewer): - super().__init__() - - from motile_tracker.data_views.views_coordinator.tracks_viewer import ( - TracksViewer, - ) - - tracks_viewer = TracksViewer.get_instance(viewer) - layout = QVBoxLayout() - layout.addWidget(tracks_viewer.tracks_list) - - self.setLayout(layout) - - -class TracksButton(QWidget): - # https://doc.qt.io/qt-5/qlistwidget.html#setItemWidget - # I think this means if we want static buttons we can just make the row here - # but if we want to change the buttons we need to do something more complex - # Columns: Run name, Date/time, delete btn - def __init__(self, tracks: Tracks, name: str): - super().__init__() - self.tracks = tracks - self.name = QLabel(name) - self.name.setFixedHeight(20) - delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") - self.delete = QPushButton(icon=delete_icon) - self.delete.setFixedSize(20, 20) - save_icon = qticon(FA6S.floppy_disk, color="white") - self.save = QPushButton(icon=save_icon) - self.save.setFixedSize(20, 20) - layout = QHBoxLayout() - layout.setSpacing(1) - layout.addWidget(self.name) - layout.addWidget(self.save) - layout.addWidget(self.delete) - self.setLayout(layout) - - def sizeHint(self): - hint = super().sizeHint() - hint.setHeight(30) - return hint - - -class TracksList(QGroupBox): - """Widget for holding in-memory Tracks. Emits a view_tracks signal whenever - a run is selected in the list, useful for telling the TracksViewer to display the - selected tracks. 
- """ - - view_tracks = Signal(Tracks, str) - - def __init__(self): - super().__init__(title="Results List") - self.file_dialog = QFileDialog() - self.file_dialog.setFileMode(QFileDialog.Directory) - self.file_dialog.setOption(QFileDialog.ShowDirsOnly, True) - - self.save_dialog = QFileDialog() - self.save_dialog.setFileMode(QFileDialog.Directory) - self.save_dialog.setOption(QFileDialog.ShowDirsOnly, True) - - self.tracks_list = QListWidget() - self.tracks_list.setSelectionMode(1) # single selection - self.tracks_list.itemSelectionChanged.connect(self._selection_changed) - - load_button = QPushButton("Load tracks") - load_button.clicked.connect(self.load_tracks) - - layout = QVBoxLayout() - layout.addWidget(self.tracks_list) - layout.addWidget(load_button) - self.setLayout(layout) - - def _selection_changed(self): - selected = self.tracks_list.selectedItems() - if selected: - tracks_button = self.tracks_list.itemWidget(selected[0]) - self.view_tracks.emit(tracks_button.tracks, tracks_button.name.text()) - - def add_tracks(self, tracks: Tracks, name: str, select=True): - """Add a run to the list and optionally select it. Will make a new - row in the list UI representing the given run. - - Note: selecting the run will also emit the selection changed event on - the list. - - Args: - tracks (Tracks): _description_ - select (bool, optional): _description_. Defaults to True. - """ - item = QListWidgetItem(self.tracks_list) - tracks_row = TracksButton(tracks, name) - self.tracks_list.setItemWidget(item, tracks_row) - item.setSizeHint(tracks_row.minimumSizeHint()) - self.tracks_list.addItem(item) - tracks_row.delete.clicked.connect(partial(self.remove_tracks, item)) - tracks_row.save.clicked.connect(partial(self.save_tracks, item)) - if select: - self.tracks_list.setCurrentRow(len(self.tracks_list) - 1) - - def save_tracks(self, item: QListWidgetItem): - """Saves a tracks object from the list. 
You must pass the list item that - represents the tracks, not the tracks object itself. - - Args: - item (QListWidgetItem): The list item to save. This list item - contains the TracksButton that represents a set of tracks. - """ - tracks: Tracks = self.tracks_list.itemWidget(item).tracks - if self.save_dialog.exec_(): - directory = Path(self.save_dialog.selectedFiles()[0]) - tracks.save(directory) - - def remove_tracks(self, item: QListWidgetItem): - """Remove a tracks object from the list. You must pass the list item that - represents the tracks, not the tracks object itself. - - Args: - item (QListWidgetItem): The list item to remove. This list item - contains the TracksButton that represents a set of tracks. - """ - row = self.tracks_list.indexFromItem(item).row() - self.tracks_list.takeItem(row) - - def load_tracks(self): - """Load a set of tracks from disk. The user selects the directory created - by calling save_tracks. - """ - if self.file_dialog.exec_(): - directory = Path(self.file_dialog.selectedFiles()[0]) - name = directory.stem - try: - tracks = MotileRun.load(directory) - self.add_tracks(tracks, name, select=True) - except (ValueError, FileNotFoundError): - try: - tracks = Tracks.load(directory) - self.add_tracks(tracks, name, select=True) - except (ValueError, FileNotFoundError) as e: - warn(f"Could not load tracks from {directory}: {e}", stacklevel=2) diff --git a/src/motile_tracker/data_views/views_coordinator/tracks_viewer.py b/src/motile_tracker/data_views/views_coordinator/tracks_viewer.py deleted file mode 100644 index 78f59eea..00000000 --- a/src/motile_tracker/data_views/views_coordinator/tracks_viewer.py +++ /dev/null @@ -1,246 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -import napari -import numpy as np -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr -from psygnal import Signal - -from motile_tracker.data_model import NodeType, SolutionTracks -from 
motile_tracker.data_model.tracks_controller import TracksController -from motile_tracker.data_views.views.layers.tracks_layer_group import TracksLayerGroup -from motile_tracker.data_views.views.tree_view.tree_widget_utils import ( - extract_lineage_tree, -) - -from .node_selection_list import NodeSelectionList -from .tracks_list import TracksList - - -class TracksViewer: - """Purposes of the TracksViewer: - - Emit signals that all widgets should use to update selection or update - the currently displayed Tracks object - - Storing the currently displayed tracks - - Store shared rendering information like colormaps (or symbol maps) - """ - - tracks_updated = Signal(Optional[bool]) - - @classmethod - def get_instance(cls, viewer=None): - if not hasattr(cls, "_instance"): - print("Making new tracking view controller") - if viewer is None: - raise ValueError("Make a viewer first please!") - cls._instance = TracksViewer(viewer) - return cls._instance - - def __init__( - self, - viewer: napari.viewer, - ): - self.viewer = viewer - self.colormap = napari.utils.colormaps.label_colormap( - 49, - seed=0.5, - background_value=0, - ) - - self.symbolmap: dict[NodeType, str] = { - NodeType.END: "x", - NodeType.CONTINUE: "disc", - NodeType.SPLIT: "triangle_up", - } - self.mode = "all" - self.tracks = None - self.visible = None - self.tracking_layers = TracksLayerGroup(self.viewer, self.tracks, "", self) - self.selected_nodes = NodeSelectionList() - self.selected_nodes.list_updated.connect(self.update_selection) - - self.tracks_list = TracksList() - self.tracks_list.view_tracks.connect(self.update_tracks) - - self.set_keybinds() - - def set_keybinds(self): - # TODO: separate and document keybinds (and maybe allow user to choose) - self.viewer.bind_key("q")(self.toggle_display_mode) - self.viewer.bind_key("a")(self.create_edge) - self.viewer.bind_key("d")(self.delete_node) - self.viewer.bind_key("Delete")(self.delete_node) - self.viewer.bind_key("b")(self.delete_edge) - # 
self.viewer.bind_key("s")(self.set_split_node) - # self.viewer.bind_key("e")(self.set_endpoint_node) - # self.viewer.bind_key("c")(self.set_linear_node) - self.viewer.bind_key("z")(self.undo) - self.viewer.bind_key("r")(self.redo) - - def _refresh(self, node: str | None = None, refresh_view: bool = False) -> None: - """Call refresh function on napari layers and the submit signal that tracks are updated - Restore the selected_nodes, if possible - """ - - if len(self.selected_nodes) > 0 and any( - not self.tracks.graph.has_node(node) for node in self.selected_nodes - ): - self.selected_nodes.reset() - - self.tracking_layers._refresh() - - self.tracks_updated.emit(refresh_view) - - # if a new node was added, we would like to select this one now (call this after emitting the signal, because if the node is a new node, we have to update the table in the tree widget first, or it won't be present) - if node is not None: - self.selected_nodes.add(node) - - # restore selection and/or highlighting in all napari Views (napari Views do not know about their selection ('all' vs 'lineage'), but TracksViewer does) - self.update_selection() - - def update_tracks(self, tracks: SolutionTracks, name: str) -> None: - """Stop viewing a previous set of tracks and replace it with a new one. - Will create new segmentation and tracks layers and add them to the viewer. - - Args: - tracks (motile_tracker.core.Tracks): The tracks to visualize in napari. 
- name (str): The name of the tracks to display in the layer names - """ - self.selected_nodes._list = [] - - if self.tracks is not None: - self.tracks.refresh.disconnect(self._refresh) - - self.tracks = tracks - self.tracks_controller = TracksController(self.tracks) - - # listen to refresh signals from the tracks - self.tracks.refresh.connect(self._refresh) - - # deactivate the input labels layer - for layer in self.viewer.layers: - if isinstance(layer, (napari.layers.Labels | napari.layers.Points)): - layer.visible = False - - self.set_display_mode("all") - self.tracking_layers.set_tracks(tracks, name) - self.selected_nodes.reset() - self.tracks_updated.emit(True) - - def toggle_display_mode(self, event=None) -> None: - """Toggle the display mode between available options""" - - if self.mode == "lineage": - self.set_display_mode("all") - else: - self.set_display_mode("lineage") - - def set_display_mode(self, mode: str) -> None: - """Update the display mode and call to update colormaps for points, labels, and tracks""" - - # toggle between 'all' and 'lineage' - if mode == "lineage": - self.mode = "lineage" - self.viewer.text_overlay.text = "Toggle Display [Q]\n Lineage" - else: - self.mode = "all" - self.viewer.text_overlay.text = "Toggle Display [Q]\n All" - - self.viewer.text_overlay.visible = True - visible_tracks = self.filter_visible_nodes() - self.tracking_layers.update_visible(visible_tracks, self.visible) - - def filter_visible_nodes(self) -> list[int]: - """Construct a list of track_ids that should be displayed""" - - if self.tracks is None or self.tracks.graph is None: - return [] - if self.mode == "lineage": - # if no nodes are selected, check which nodes were previously visible and filter those - if len(self.selected_nodes) == 0 and self.visible is not None: - prev_visible = [ - node for node in self.visible if self.tracks.graph.has_node(node) - ] - self.visible = [] - for node_id in prev_visible: - self.visible += 
extract_lineage_tree(self.tracks.graph, node_id) - if set(prev_visible).issubset(self.visible): - break - else: - self.visible = [] - for node in self.selected_nodes: - self.visible += extract_lineage_tree(self.tracks.graph, node) - - return list( - { - self.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value] - for node in self.visible - } - ) - else: - self.visible = "all" - return "all" - - def update_selection(self) -> None: - """Sets the view and triggers visualization updates in other components""" - - self.set_napari_view() - visible_tracks = self.filter_visible_nodes() - self.tracking_layers.update_visible(visible_tracks, self.visible) - - def set_napari_view(self) -> None: - """Adjust the current_step of the viewer to jump to the last item of the selected_nodes list""" - if len(self.selected_nodes) > 0: - node = self.selected_nodes[-1] - self.tracking_layers.center_view(node) - - def delete_node(self, event=None): - """Calls the tracks controller to delete currently selected nodes""" - - self.tracks_controller.delete_nodes(self.selected_nodes._list) - - def set_split_node(self, event=None): - print("split this node") - - def set_endpoint_node(self, event=None): - print("make this node an endpoint") - - def set_linear_node(self, event=None): - print("make this node linear") - - def delete_edge(self, event=None): - """Calls the tracks controller to delete an edge between the two currently selected nodes""" - - if len(self.selected_nodes) == 2: - node1 = self.selected_nodes[0] - node2 = self.selected_nodes[1] - - time1 = self.tracks.get_time(node1) - time2 = self.tracks.get_time(node2) - - if time1 > time2: - node1, node2 = node2, node1 - - self.tracks_controller.delete_edges(edges=np.array([[node1, node2]])) - - def create_edge(self, event=None): - """Calls the tracks controller to add an edge between the two currently selected nodes""" - - if len(self.selected_nodes) == 2: - node1 = self.selected_nodes[0] - node2 = self.selected_nodes[1] - - time1 = 
self.tracks.get_time(node1) - time2 = self.tracks.get_time(node2) - - if time1 > time2: - node1, node2 = node2, node1 - - self.tracks_controller.add_edges(edges=np.array([[node1, node2]])) - - def undo(self, event=None): - self.tracks_controller.undo() - - def redo(self, event=None): - self.tracks_controller.redo() diff --git a/src/motile_tracker/example_data.py b/src/motile_tracker/example_data.py deleted file mode 100644 index fb6bace9..00000000 --- a/src/motile_tracker/example_data.py +++ /dev/null @@ -1,284 +0,0 @@ -import logging -import os -import shutil -import zipfile -from pathlib import Path -from urllib.request import urlretrieve - -import numpy as np -import tifffile -import zarr -from appdirs import AppDirs -from napari.types import LayerData -from skimage.measure import regionprops - -logger = logging.getLogger(__name__) - - -def Mouse_Embryo_Membrane() -> list[LayerData]: - """Loads the Mouse Embryo Membrane raw data and segmentation data from - the appdir "user data dir". Will download it from the Zenodo DOI if it is not present already. - Returns: - list[LayerData]: An image layer of raw data and a segmentation labels - layer - """ - ds_name = "Mouse_Embryo_Membrane" - appdir = AppDirs("motile-tracker") - data_dir = Path(appdir.user_data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - raw_name = "imaging.tif" - label_name = "segmentation.tif" - return read_zenodo_dataset(ds_name, raw_name, label_name, data_dir) - - -def Fluo_N2DL_HeLa() -> list[LayerData]: - """Loads the Fluo-N2DL-HeLa 01 training raw data and silver truth from - the appdir "user data dir". Will download it from the CTC and convert it to - zarr if it is not present already. 
- Returns: - list[LayerData]: An image layer of 01 training raw data and a labels - layer of 01 training silver truth labels - """ - ds_name = "Fluo-N2DL-HeLa" - appdir = AppDirs("motile-tracker") - data_dir = Path(appdir.user_data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - return read_ctc_dataset(ds_name, data_dir) - - -def Fluo_N2DL_HeLa_crop() -> list[LayerData]: - """Loads the Fluo-N2DL-HeLa 01 training raw data and silver truth from - the appdir "user data dir". Will download it from the CTC and convert it to - zarr if it is not present already. - Returns: - list[LayerData]: An image layer of 01 training raw data and a labels - layer of 01 training silver truth labels - """ - ds_name = "Fluo-N2DL-HeLa" - appdir = AppDirs("motile-tracker") - data_dir = Path(appdir.user_data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - return read_ctc_dataset(ds_name, data_dir, crop_region=True) - - -def read_zenodo_dataset( - ds_name: str, raw_name: str, label_name: str, data_dir: Path -) -> list[LayerData]: - """Read a zenodo dataset (assumes pre-downloaded) - and returns a list of layer data for making napari layers - - Args: - ds_name (str): name to give to the dataset - raw_name (str): name of the file that points to the intensity data - label_name (str): name of the file that points to the segmentation data - data_dir (Path): Path to the directory containing the images - - Returns: - list[LayerData]: An image layer of raw data and a segmentation labels - layer - """ - ds_zarr = data_dir / (ds_name + ".zarr") - if not ds_zarr.exists(): - logger.info("Downloading %s", ds_name) - download_zenodo_dataset(ds_name, raw_name, label_name, data_dir) - - raw_data = zarr.open(store=ds_zarr, path="01_membrane", dimension_separator="/")[:] - raw_layer_data = (raw_data, {"name": "01_membrane"}, "image") - seg_data = zarr.open(ds_zarr, path="01_labels", dimension_separator="/")[:] - seg_layer_data = (seg_data, {"name": "01_labels"}, "labels") - return [raw_layer_data, 
seg_layer_data] - - -def read_ctc_dataset( - ds_name: str, data_dir: Path, crop_region=False -) -> list[LayerData]: - """Read a CTC dataset from a zarr (assumes pre-downloaded and converted) - and returns a list of layer data for making napari layers - - Args: - ds_name (str): Dataset name - data_dir (Path): Path to the directory containing the zarr - - Returns: - list[LayerData]: An image layer of 01 training raw data and a labels - layer of 01 training silver truth labels - """ - ds_zarr = data_dir / (ds_name + ".zarr") - if not ds_zarr.exists(): - logger.info("Downloading %s", ds_name) - download_ctc_dataset(ds_name, data_dir) - zarr_store = zarr.open(store=ds_zarr, mode="a") # Open in append mode ('a') - raw_data = zarr_store["01"] - seg_data = zarr_store["01_ST"] - min_y = 90 - min_x = 700 - max_y = 300 - max_x = 1040 - if crop_region: - raw_data = raw_data[:, min_y:max_y, min_x:max_x] - seg_data = seg_data[:, min_y:max_y, min_x:max_x] - else: - raw_data = raw_data[:] - seg_data = seg_data[:] - raw_layer_data = (raw_data, {"name": "01_raw"}, "image") - seg_layer_data = (seg_data, {"name": "01_ST"}, "labels") - - # Check if 'points' dataset exists in the zarr file - points_name = "points_crop" if crop_region else "points" - if "points_file" not in zarr_store: - logger.info("extracting centroids...") - centroids_list = [] - for t in range(seg_data.shape[0]): # Iterate over time frames - frame_seg = seg_data[t] - props = regionprops(frame_seg) - centroids = np.array([prop.centroid for prop in props]) - time_stamped_centroids = np.column_stack( - [np.full(centroids.shape[0], t), centroids] - ) - centroids_list.append(time_stamped_centroids) - all_centroids = np.vstack(centroids_list) - - # Save the centroids inside the zarr file under the 'points' key - zarr_store.create_dataset(points_name, data=all_centroids, overwrite=True) - logger.info("Centroids extracted and saved") - else: - # If 'points' dataset exists, load it - logger.info("points dataset found, 
loading...") - all_centroids = zarr_store[points_name][:] - - # Prepare points layer data for napari - points_layer_data = (all_centroids, {"name": "centroids"}, "points") - - return [raw_layer_data, seg_layer_data, points_layer_data] - - -def download_zenodo_dataset( - ds_name: str, raw_name: str, label_name: str, data_dir: Path -) -> None: - """Download a sample dataset from zenodo doi and unzip it, then delete the zip. Then convert the tiffs to - zarrs for the first training set consisting of 3D membrane intensity images and segmentation. - - Args: - ds_name (str): Name to give to the dataset - raw_name (str): Name of the file that contains the intensity data - label_name (str): Name of the file that contains the label data - data_dir (Path): The directory in which to store the data. - """ - ds_file_raw = data_dir / raw_name - ds_file_labels = data_dir / label_name - ds_zarr = data_dir / (ds_name + ".zarr") - url_raw = "https://zenodo.org/records/13903500/files/imaging.zip" - url_labels = "https://zenodo.org/records/13903500/files/segmentation.zip" - zip_filename_raw = data_dir / "imaging.zip" - zip_filename_labels = data_dir / "segmentation.zip" - - if not zip_filename_raw.is_file(): - urlretrieve(url_raw, filename=zip_filename_raw) - if not zip_filename_labels.is_file(): - urlretrieve(url_labels, filename=zip_filename_labels) - - with zipfile.ZipFile(zip_filename_raw, "r") as zip_ref: - zip_ref.extractall(data_dir) - with zipfile.ZipFile(zip_filename_labels, "r") as zip_ref: - zip_ref.extractall(data_dir) - - zip_filename_raw.unlink() - zip_filename_labels.unlink() - - convert_4d_arr_to_zarr(ds_file_raw, ds_zarr, "01_membrane") - convert_4d_arr_to_zarr(ds_file_labels, ds_zarr, "01_labels") - - -def download_ctc_dataset(ds_name: str, data_dir: Path) -> None: - """Download a dataset from the Cell Tracking Challenge - and unzip it, then delete the zip. Then convert the tiffs to - zarrs for the first training set images and silver truth. 
- - Args: - ds_name (str): Dataset name, according to the CTC - data_dir (Path): The directory in which to store the data. - """ - ds_dir = data_dir / ds_name - ds_zarr = data_dir / (ds_name + ".zarr") - ctc_url = f"http://data.celltrackingchallenge.net/training-datasets/{ds_name}.zip" - zip_filename = data_dir / f"{ds_name}.zip" - if not zip_filename.is_file(): - urlretrieve(ctc_url, filename=zip_filename) - with zipfile.ZipFile(zip_filename, "r") as zip_ref: - zip_ref.extractall(data_dir) - zip_filename.unlink() - - convert_to_zarr(ds_dir / "01", ds_zarr, "01") - convert_to_zarr(ds_dir / "01_ST" / "SEG", ds_zarr, "01_ST", relabel=True) - shutil.rmtree(ds_dir) - - -def convert_4d_arr_to_zarr( - tiff_file: str, zarr_path: str, zarr_group: str, relabel=False -): - """Convert 4D tiff file image data to zarr. Also deletes the tiffs! - Args: - tiff_file (str): string representing path to tif file to be converted - zarr_path (str): path to the zarr file to write the output to - zarr_group (str): group within the zarr store to write the data to - relabel (bool): if true, relabels the segmentations to be unique over time - """ - img = tifffile.imread(tiff_file) - data_shape = img.shape - data_dtype = img.dtype - - # prepare zarr - if not os.path.exists(zarr_path): - os.mkdir(zarr_path) - store = zarr.NestedDirectoryStore(zarr_path) - zarr_array = zarr.open( - store=store, - mode="w", - path=zarr_group, - shape=data_shape, - dtype=data_dtype, - ) - # save the time points to the zarr file - max_label = 0 - for t in range(img.shape[0]): - frame = img[t] - if relabel: - frame[frame != 0] += max_label - max_label = int(np.max(frame)) - zarr_array[t] = frame - os.remove(tiff_file) - - -def convert_to_zarr(tiff_path: Path, zarr_path: Path, zarr_group: str, relabel=False): - """Convert tiff file image data to zarr. Also deletes the tiffs! 
- Args: - tif_path (Path): Path to the directory containing the tiff files - zarr_path (Path): path to the zarr file to write the output to - zarr_group (Path): group within the zarr store to write the data to - """ - # get data dimensions - files = sorted(tiff_path.glob("*.tif")) - logger.info("%s time points found.", len(files)) - example_image = tifffile.imread(files[0]) - data_shape = (len(files), *example_image.shape) - data_dtype = example_image.dtype - # prepare zarr - zarr_path.mkdir(parents=True, exist_ok=True) - store = zarr.NestedDirectoryStore(zarr_path) - zarr_array = zarr.open( - store=store, - mode="w", - path=zarr_group, - shape=data_shape, - dtype=data_dtype, - ) - # load and save data in zarr - max_label = 0 - for t, file in enumerate(files): - frame = tifffile.imread(file) - if relabel: - frame[frame != 0] += max_label - max_label = int(np.max(frame)) - zarr_array[t] = frame - file.unlink() - tiff_path.rmdir() diff --git a/src/motile_tracker/motile/backend/motile_run.py b/src/motile_tracker/motile/backend/motile_run.py index 9d428d3a..699d5bd6 100644 --- a/src/motile_tracker/motile/backend/motile_run.py +++ b/src/motile_tracker/motile/backend/motile_run.py @@ -6,10 +6,9 @@ from typing import TYPE_CHECKING import numpy as np +from funtracks.data_model import SolutionTracks from motile_toolbox.candidate_graph.graph_attributes import NodeAttr -from motile_tracker.data_model import SolutionTracks - from .solver_params import SolverParams if TYPE_CHECKING: diff --git a/src/motile_tracker/motile/menus/motile_widget.py b/src/motile_tracker/motile/menus/motile_widget.py index 42bcf91e..91dcaf95 100644 --- a/src/motile_tracker/motile/menus/motile_widget.py +++ b/src/motile_tracker/motile/menus/motile_widget.py @@ -2,8 +2,10 @@ import logging -from napari import Viewer -from napari.utils.notifications import show_warning +from finn import Viewer +from finn.track_data_views.views_coordinator.tracks_viewer import TracksViewer +from finn.utils.notifications 
import show_warning +from funtracks.data_model import SolutionTracks from psygnal import Signal from qtpy.QtWidgets import ( QLabel, @@ -12,8 +14,6 @@ ) from superqt.utils import thread_worker -from motile_tracker.data_model import SolutionTracks -from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer from motile_tracker.motile.backend import MotileRun, solve from .run_editor import RunEditor @@ -23,7 +23,7 @@ class MotileWidget(QWidget): - """A widget that controls the backend components of the motile napari tracker. + """A widget that controls the backend components of the motile tracker. Recieves user input about solver parameters, runs motile, and passes results to the TrackingViewController. """ @@ -69,7 +69,6 @@ def view_run(self, tracks: SolutionTracks) -> None: self.edit_run_widget.hide() self.view_run_widget.show() else: - show_warning("Tried to view a Tracks that is not a MotileRun") self.view_run_widget.hide() def edit_run(self, run: MotileRun | None): diff --git a/src/motile_tracker/motile/menus/run_editor.py b/src/motile_tracker/motile/menus/run_editor.py index fdbaf9f0..19c9853d 100644 --- a/src/motile_tracker/motile/menus/run_editor.py +++ b/src/motile_tracker/motile/menus/run_editor.py @@ -5,7 +5,8 @@ from typing import TYPE_CHECKING from warnings import warn -import napari.layers +import dask.array as da +import finn.layers import networkx as nx import numpy as np from motile_toolbox.utils.relabel_segmentation import ensure_unique_labels @@ -21,13 +22,14 @@ QVBoxLayout, QWidget, ) +from tqdm import tqdm from motile_tracker.motile.backend import MotileRun from .params_editor import SolverParamsEditor if TYPE_CHECKING: - import napari + import finn logger = logging.getLogger(__name__) @@ -35,13 +37,13 @@ class RunEditor(QGroupBox): start_run = Signal(MotileRun) - def __init__(self, viewer: napari.Viewer): + def __init__(self, viewer: finn.Viewer): """A widget for editing run parameters and starting solving. 
Has to know about the viewer to get the input segmentation from the current layers. Args: - viewer (napari.Viewer): The napari viewer that the editor should + viewer (finn.Viewer): The finn viewer that the editor should get the input segmentation from. """ super().__init__(title="Run Editor") @@ -70,7 +72,7 @@ def update_labels_layers(self) -> None: prev_selection = self.layer_selection_box.currentText() self.layer_selection_box.clear() for layer in self.viewer.layers: - if isinstance(layer, napari.layers.Labels | napari.layers.Points): + if isinstance(layer, finn.layers.Labels | finn.layers.Points): self.layer_selection_box.addItem(layer.name) self.layer_selection_box.setCurrentText(prev_selection) @@ -79,9 +81,9 @@ def update_layer_selection(self) -> None: layer = self.get_input_layer() if layer is None: return - if isinstance(layer, napari.layers.Labels): + if isinstance(layer, finn.layers.Labels): enable_iou = True - elif isinstance(layer, napari.layers.Points): + elif isinstance(layer, finn.layers.Points): enable_iou = False self.solver_params_widget.iou_row.toggle_visible(enable_iou) @@ -91,7 +93,7 @@ def _labels_layer_widget(self) -> QWidget: Returns: QWidget: A dropdown select with all the labels layers in layers - and a refresh button to sync with napari. + and a refresh button to sync with finn. """ layer_group = QWidget() layer_layout = QHBoxLayout() @@ -117,12 +119,12 @@ def _labels_layer_widget(self) -> QWidget: layer_group.setLayout(layer_layout) return layer_group - def get_input_layer(self) -> napari.layers.Layer | None: + def get_input_layer(self) -> finn.layers.Layer | None: """Get the input segmentation or points in current selection in the layer dropdown. Returns: - napari.layers.Layer | None: The points or labels layer with the name + finn.layers.Layer | None: The points or labels layer with the name that is selected, or None if no layer is selected. 
""" layer_name = self.layer_selection_box.currentText() @@ -184,8 +186,13 @@ def get_run(self) -> MotileRun | None: if input_layer is None: warn("No input layer selected", stacklevel=2) return None - if isinstance(input_layer, napari.layers.Labels): - input_seg = input_layer.data + if isinstance(input_layer, finn.layers.Labels): + if isinstance(input_layer.data, da.core.Array): + input_seg = self._convert_da_to_np_array( + input_layer.data + ) # silently convert to in-memory array + else: + input_seg = input_layer.data ndim = input_seg.ndim if ndim > 4: raise ValueError( @@ -199,7 +206,7 @@ def get_run(self) -> MotileRun | None: input_seg = ensure_unique_labels(input_seg) input_points = None - elif isinstance(input_layer, napari.layers.Points): + elif isinstance(input_layer, finn.layers.Points): input_seg = None input_points = input_layer.data params = self.solver_params_widget.solver_params.copy() @@ -213,6 +220,24 @@ def get_run(self) -> MotileRun | None: scale=input_layer.scale, ) + def _convert_da_to_np_array(self, dask_array: da.core.Array) -> np.ndarray: + """Convert from dask array to in-memory array. + + Args: + dask_array (da.core.Array): a dask array + + Returns: + np.ndarray: data as an in-memory numpy array + """ + + stack_list = [] + for i in tqdm( + range(dask_array.shape[0]), + desc="Converting dask array to in-memory array", + ): + stack_list.append(dask_array[i].compute()) + return np.stack(stack_list, axis=0) + def emit_run(self) -> None: """Construct a run and start solving by emitting the start run signal for the main widget to connect to. 
If run is invalid, will diff --git a/src/motile_tracker/motile/menus/run_viewer.py b/src/motile_tracker/motile/menus/run_viewer.py index f3ae12f4..f960d93f 100644 --- a/src/motile_tracker/motile/menus/run_viewer.py +++ b/src/motile_tracker/motile/menus/run_viewer.py @@ -1,14 +1,10 @@ from __future__ import annotations from functools import partial -from pathlib import Path -from warnings import warn import pyqtgraph as pg -from fonticon_fa6 import FA6S from qtpy.QtCore import Signal from qtpy.QtWidgets import ( - QFileDialog, QGroupBox, QHBoxLayout, QLabel, @@ -17,7 +13,6 @@ QWidget, ) from superqt import QCollapsible, ensure_main_thread -from superqt.fonticon import icon as qticon from motile_tracker.motile.backend import MotileRun @@ -26,9 +21,8 @@ class RunViewer(QGroupBox): """A widget for viewing in progress or completed runs, including - the progress of the solver and the parameters. Can also save the whole - run or export the tracks to CSV. - Output tracks and segmentation are visualized separately in napari layers. + the progress of the solver and the parameters. + Output tracks and segmentation are visualized separately in finn layers. """ edit_run = Signal(object) @@ -42,15 +36,10 @@ def __init__(self): self.solver_label: QLabel self.gap_plot: pg.PlotWidget - # Define persistent file dialogs for saving and exporting - self.save_run_dialog = self._save_dialog() - self.export_tracks_dialog = self._export_tracks_dialog() - # Create layout and add subwidgets main_layout = QVBoxLayout() main_layout.addWidget(self._progress_widget()) main_layout.addWidget(self.params_widget) - main_layout.addWidget(self._save_and_export_widget()) main_layout.addWidget(self._back_to_edit_widget()) self.setLayout(main_layout) @@ -68,31 +57,6 @@ def update_run(self, run: MotileRun): self.solver_event_update() self.params_widget.new_params.emit(run.solver_params) - def _save_and_export_widget(self) -> QWidget: - """Create a widget for saving and exporting tracking results. 
- - Returns: - QWidget: A widget containing a save button and an export tracks - button. - """ - widget = QWidget() - layout = QHBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - - # Save button - icon = qticon(FA6S.floppy_disk, color="white") - save_run_button = QPushButton(icon=icon, text="Save run") - save_run_button.clicked.connect(self.save_run) - layout.addWidget(save_run_button) - - # create button to export tracks - export_tracks_btn = QPushButton("Export tracks to CSV") - export_tracks_btn.clicked.connect(self.export_tracks) - layout.addWidget(export_tracks_btn) - - widget.setLayout(layout) - return widget - def _back_to_edit_widget(self) -> QWidget: """Create a widget for navigating back to the run editor with different parameters. @@ -164,42 +128,6 @@ def _plot_widget(self) -> pg.PlotWidget: gap_plot.plotItem.setLabel("bottom", "Solver round", **styles) return gap_plot - def _save_dialog(self) -> QFileDialog: - save_run_dialog = QFileDialog() - save_run_dialog.setFileMode(QFileDialog.Directory) - save_run_dialog.setOption(QFileDialog.ShowDirsOnly, True) - return save_run_dialog - - def _export_tracks_dialog(self) -> QFileDialog: - export_tracks_dialog = QFileDialog() - export_tracks_dialog.setFileMode(QFileDialog.AnyFile) - export_tracks_dialog.setAcceptMode(QFileDialog.AcceptSave) - export_tracks_dialog.setDefaultSuffix("csv") - return export_tracks_dialog - - def save_run(self): - if self.save_run_dialog.exec_(): - directory = self.save_run_dialog.selectedFiles()[0] - self.run.save(directory) - - def export_tracks(self): - """Export the tracks from this run to a csv with the following columns: - t,[z],y,x,id,parent_id,[seg_id] - Cells without a parent_id will have an empty string for the parent_id. - Whether or not to include z is inferred from the length of an - arbitrary node's position attribute. If the nodes have a "seg_id" attribute, - the "seg_id" column is included. 
- """ - default_name = self.run._make_id() - default_name = f"{default_name}_tracks.csv" - base_path = Path(self.export_tracks_dialog.directory().path()) - self.export_tracks_dialog.selectFile(str(base_path / default_name)) - if self.export_tracks_dialog.exec_(): - outfile = self.export_tracks_dialog.selectedFiles()[0] - self.run.export_tracks(outfile) - else: - warn("Exporting aborted", stacklevel=2) - def _set_solver_label(self, status: str): message = "Solver status: " + status self.solver_label.setText(message) diff --git a/src/motile_tracker/napari.yaml b/src/motile_tracker/napari.yaml deleted file mode 100644 index b134b6a7..00000000 --- a/src/motile_tracker/napari.yaml +++ /dev/null @@ -1,46 +0,0 @@ -name: motile-tracker -display_name: Motile -# use 'hidden' to remove plugin from napari hub search results -visibility: public -# see https://napari.org/stable/plugins/manifest.html for valid categories -categories: ["Utilities"] -contributions: - commands: - - id: motile-tracker.main_app - python_name: motile_tracker.application_menus.main_app:MainApp - title: "Open the motile main application" - - id: motile-tracker.tree_widget - python_name: motile_tracker.data_views.views.tree_view.tree_widget:TreeWidget - title: "Open the lineage view widget" - - id: motile-tracker.menus_widget - python_name: motile_tracker.application_menus.menu_widget:MenuWidget - title: "Open the motile menus widget" - - id: motile-tracker.solve - python_name: motile_tracker.motile.backend.solve:solve - title: "Run motile tracking (backend only)" - - id: motile-tracker.Fluo_N2DL_HeLa - python_name: motile_tracker.example_data:Fluo_N2DL_HeLa - title: "Load Fluo-N2DL-HeLa tracking dataset" - - id: motile-tracker.Fluo_N2DL_HeLa_crop - python_name: motile_tracker.example_data:Fluo_N2DL_HeLa_crop - title: "Load Fluo-N2DL-HeLa tracking dataset (crop)" - - id: motile-tracker.Mouse_Embryo_Membrane - python_name: motile_tracker.example_data:Mouse_Embryo_Membrane - title: "Load Mouse 
Embryo_Membrane tracking dataset" - widgets: - - command: motile-tracker.main_app - display_name: Motile Main Widget - - command: motile-tracker.menus_widget - display_name: Motile Menus Widget - - command: motile-tracker.tree_widget - display_name: Motile Lineage View - sample_data: - - command: motile-tracker.Fluo_N2DL_HeLa - key: "Fluo-N2DL-HeLa" - display_name: "Fluo-N2DL-HeLa (2D)" - - command: motile-tracker.Fluo_N2DL_HeLa_crop - key: "Fluo-N2DL-HeLa-crop" - display_name: "Fluo-N2DL-HeLa crop (2D)" - - command: motile-tracker.Mouse_Embryo_Membrane - key: "Mouse_Embryo_Membrane" - display_name: "Mouse Embryo Membranes (3D)" diff --git a/src/motile_tracker/utils/load_tracks.py b/src/motile_tracker/utils/load_tracks.py deleted file mode 100644 index ce128932..00000000 --- a/src/motile_tracker/utils/load_tracks.py +++ /dev/null @@ -1,51 +0,0 @@ -from csv import DictReader - -import networkx as nx -import numpy as np - -from motile_tracker.data_model import SolutionTracks - - -def tracks_from_csv( - csvfile: str, segmentation: np.ndarray | None = None -) -> SolutionTracks: - """Assumes a csv similar to that created from "export tracks to csv" with columns: - t,[z],y,x,id,parent_id,[seg_id] - Cells without a parent_id will have an empty string or a -1 for the parent_id. - - Args: - csvfile (str): - path to the csv to load - segmentation (np.ndarray | None, optional): - An optional accompanying segmentation. 
- If provided, assumes that the seg_id column in the csv file exists and - corresponds to the label ids in the segmentation array - - Returns: - Tracks: a tracks object ready to be visualized with - TracksViewer.view_external_tracks - """ - graph = nx.DiGraph() - with open(csvfile) as f: - reader = DictReader(f) - for row in reader: - _id = int(row["id"]) - attrs = { - "pos": [float(row["y"]), float(row["x"])], - "time": int(row["t"]), - } - if "seg_id" in row: - attrs["seg_id"] = int(row["seg_id"]) - graph.add_node(_id, **attrs) - parent_id = row["parent_id"].strip() - if parent_id != "": - parent_id = int(parent_id) - if parent_id != -1: - graph.add_edge(parent_id, _id) - tracks = SolutionTracks( - graph=graph, - segmentation=segmentation, - pos_attr="pos", - time_attr="time", - ) - return tracks diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py deleted file mode 100644 index d3b6ef2e..00000000 --- a/tests/data_model/test_action_history.py +++ /dev/null @@ -1,60 +0,0 @@ -import networkx as nx - -from motile_tracker.data_model.action_history import ActionHistory -from motile_tracker.data_model.actions import AddNodes -from motile_tracker.data_model.tracks import Tracks - -# https://github.com/zaboople/klonk/blob/master/TheGURQ.md - - -def test_action_history(): - history = ActionHistory() - tracks = Tracks(nx.DiGraph(), ndim=3) - action1 = AddNodes( - tracks, nodes=[0, 1], attributes={"time": [0, 1], "pos": [[0, 1], [1, 2]]} - ) - - # empty history has no undo or redo - assert not history.undo() - assert not history.redo() - - # add an action to the history - history.add_new_action(action1) - # undo the action - assert history.undo() - assert tracks.graph.number_of_nodes() == 0 - assert len(history.undo_stack) == 1 - assert len(history.redo_stack) == 1 - assert history.undo_pointer == -1 - - # no more actions to undo - assert not history.undo() - - # redo the action - assert history.redo() - assert 
tracks.graph.number_of_nodes() == 2 - assert len(history.undo_stack) == 1 - assert len(history.redo_stack) == 0 - assert history.undo_pointer == 0 - - # no more actions to redo - assert not history.redo() - - # undo and then add new action - assert history.undo() - action2 = AddNodes(tracks, nodes=[10], attributes={"time": [10], "pos": [[0, 1]]}) - history.add_new_action(action2) - assert tracks.graph.number_of_nodes() == 1 - # there are 3 things on the stack: action1, action1's inverse, and action 2 - assert len(history.undo_stack) == 3 - assert len(history.redo_stack) == 0 - assert history.undo_pointer == 2 - - # undo back to after action 1 - assert history.undo() - assert history.undo() - assert tracks.graph.number_of_nodes() == 2 - - assert len(history.undo_stack) == 3 - assert len(history.redo_stack) == 2 - assert history.undo_pointer == 0 diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py deleted file mode 100644 index 537c0241..00000000 --- a/tests/data_model/test_actions.py +++ /dev/null @@ -1,129 +0,0 @@ -import networkx as nx -import numpy as np -import pytest -from motile_toolbox.candidate_graph.graph_attributes import EdgeAttr, NodeAttr -from numpy.testing import assert_array_almost_equal - -from motile_tracker.data_model import Tracks -from motile_tracker.data_model.actions import ( - AddEdges, - AddNodes, - UpdateNodeSegs, -) - - -def test_add_delete_nodes(segmentation_2d, graph_2d): - empty_graph = nx.DiGraph() - empty_seg = np.zeros_like(segmentation_2d) - tracks = Tracks(empty_graph, segmentation=empty_seg) - nodes = list(graph_2d.nodes()) - attrs = {} - attrs[NodeAttr.TIME.value] = [ - graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes - ] - attrs[NodeAttr.POS.value] = [ - graph_2d.nodes[node][NodeAttr.POS.value] for node in nodes - ] - attrs[NodeAttr.TRACK_ID.value] = [ - graph_2d.nodes[node][NodeAttr.TRACK_ID.value] for node in nodes - ] - pixels = [ - np.nonzero(segmentation_2d[time] == node_id) - for 
time, node_id in zip(attrs[NodeAttr.TIME.value], nodes, strict=True) - ] - pixels = [ - (np.ones_like(pix[0]) * time, *pix) - for time, pix in zip(attrs[NodeAttr.TIME.value], pixels, strict=True) - ] - add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - if NodeAttr.AREA.value not in graph_2d_data: - graph_2d_data[NodeAttr.AREA.value] = ( - 305 # hard coding the case Anniek took out for now - ) - assert data == graph_2d_data - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - del_nodes = add_nodes.inverse() - assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) - assert_array_almost_equal(tracks.segmentation, empty_seg) - - del_nodes.inverse() - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - if NodeAttr.AREA.value not in graph_2d_data: - graph_2d_data[NodeAttr.AREA.value] = None - assert data == graph_2d_data - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - -def test_update_node_segs(segmentation_2d, graph_2d): - tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - nodes = list(graph_2d.nodes()) - - # add a couple pixels to the first node - new_seg = segmentation_2d.copy() - new_seg[0][0] = 1 - nodes = [1] - - pixels = [np.nonzero(segmentation_2d != new_seg)] - action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] - != graph_2d.nodes[1][NodeAttr.POS.value] - ) - assert_array_almost_equal(tracks.segmentation, new_seg) - - inverse = action.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in 
tracks.graph.nodes(data=True): - assert data == graph_2d.nodes[node] - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse.inverse() - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] - != graph_2d.nodes[1][NodeAttr.POS.value] - ) - assert_array_almost_equal(tracks.segmentation, new_seg) - - -def test_add_delete_edges(graph_2d, segmentation_2d): - node_graph = nx.create_empty_copy(graph_2d, with_data=True) - tracks = Tracks(node_graph, segmentation_2d) - - edges = [[1, 2], [1, 3], [3, 4], [4, 5]] - - action = AddEdges(tracks, edges) - # TODO: What if adding an edge that already exists? - # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 - ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse = action.inverse() - assert set(tracks.graph.edges()) == set() - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert set(tracks.graph.edges()) == set(graph_2d.edges()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 - ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py deleted file mode 100644 index 986aa1c6..00000000 --- a/tests/data_model/test_solution_tracks.py +++ /dev/null @@ -1,25 +0,0 @@ -from motile_tracker.data_model import SolutionTracks - - -def test_export_to_csv(graph_2d, graph_3d, tmp_path): - tracks = SolutionTracks(graph_2d, 
ndim=3) - temp_file = tmp_path / "test_export_2d.csv" - tracks.export_tracks(temp_file) - with open(temp_file) as f: - lines = f.readlines() - - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header - - header = ["t", "y", "x", "id", "parent_id", "track_id"] - assert lines[0].strip().split(",") == header - - tracks = SolutionTracks(graph_3d, ndim=4) - temp_file = tmp_path / "test_export_3d.csv" - tracks.export_tracks(temp_file) - with open(temp_file) as f: - lines = f.readlines() - - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header - - header = ["t", "z", "y", "x", "id", "parent_id", "track_id"] - assert lines[0].strip().split(",") == header diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py deleted file mode 100644 index 042c82f3..00000000 --- a/tests/data_model/test_tracks.py +++ /dev/null @@ -1,192 +0,0 @@ -import networkx as nx -import numpy as np -import pytest -from motile_toolbox.candidate_graph import NodeAttr - -from motile_tracker.data_model import Tracks - - -def test_create_tracks(graph_3d, segmentation_3d): - # create empty tracks - tracks = Tracks(graph=nx.DiGraph(), ndim=3) - with pytest.raises(KeyError): - tracks.get_positions([1]) - - # create tracks with graph only - tracks = Tracks(graph=graph_3d, ndim=4) - assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] - assert tracks.get_time(1) == 0 - with pytest.raises(KeyError): - tracks.get_positions(["0"]) - - # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) - assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] - assert tracks.get_time(1) == 0 - assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] - tracks.set_time(1, 1) - assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] - - tracks_wrong_attr = Tracks( - graph=graph_3d, segmentation=segmentation_3d, time_attr="test" - ) - with pytest.raises(KeyError): # raises error at 
access if time is wrong - tracks_wrong_attr.get_times([1]) - - tracks_wrong_attr = Tracks(graph=graph_3d, pos_attr="test", ndim=3) - with pytest.raises(KeyError): # raises error at access if pos is wrong - tracks_wrong_attr.get_positions([1]) - - # test multiple position attrs - pos_attr = ("z", "y", "x") - for node in graph_3d.nodes(): - pos = graph_3d.nodes[node][NodeAttr.POS.value] - z, y, x = pos - del graph_3d.nodes[node][NodeAttr.POS.value] - graph_3d.nodes[node]["z"] = z - graph_3d.nodes[node]["y"] = y - graph_3d.nodes[node]["x"] = x - - tracks = Tracks(graph=graph_3d, pos_attr=pos_attr, ndim=4) - assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] - tracks.set_position(1, [55, 56, 57]) - assert tracks.get_position(1) == [55, 56, 57] - - tracks.set_position(1, [1, 50, 50, 50], incl_time=True) - assert tracks.get_time(1) == 1 - - -def test_add_remove_nodes(graph_2d, segmentation_2d): - # create empty tracks - tracks = Tracks(graph=nx.DiGraph(), ndim=3) - with pytest.raises(KeyError): - tracks.get_positions([1]) - # add a node - tracks.add_node(1, time=0, position=[0, 0, 0]) - assert tracks.get_positions([1]).tolist() == [[0, 0, 0]] - assert tracks.get_time(1) == 0 - # remove the node - tracks.remove_node(1) - with pytest.raises(KeyError): - tracks.get_positions([1]) - - # add a position-less node - with pytest.raises(ValueError): - tracks.add_node(1, time=10) - - # create tracks with segmentation - tracks = Tracks(graph=graph_2d, segmentation=segmentation_2d, scale=[1, 2, 1]) - - # removing a node - node = 3 - tracks.remove_node(node) - with pytest.raises(KeyError): - tracks.get_position(node) - - with pytest.raises(KeyError): - tracks.get_positions([node]) - # adding a node without position infers position from segmentation - tracks.add_node(node, time=1) - assert tracks.get_area(node) == 697 * 2 - - -def test_add_remove_edges(graph_2d, segmentation_2d): - # create empty tracks - tracks = Tracks(graph=nx.DiGraph(), ndim=3) - with 
pytest.raises(KeyError): - tracks.get_positions([1]) - # add a node - tracks.add_node(1, time=0, position=[0, 0, 0]) - assert tracks.get_positions([1]).tolist() == [[0, 0, 0]] - assert tracks.get_time(1) == 0 - # remove the node - tracks.remove_node(1) - with pytest.raises(KeyError): - tracks.get_positions([1]) - - # add an edge - with pytest.raises(KeyError): - tracks.add_edge((1, 2)) - - tracks.add_node(1, time=0, position=[0, 0, 0]) - tracks.add_node(2, time=1, position=[1, 1, 1]) - tracks.add_edge((1, 2)) - assert tracks.graph.number_of_edges() == 1 - - # create track with graph and seg - tracks = Tracks(graph=graph_2d, segmentation=segmentation_2d) - num_edges = tracks.graph.number_of_edges() - - edge = (1, 3) - iou = tracks.get_iou(edge) - tracks.remove_edge(edge) - assert tracks.graph.number_of_edges() == num_edges - 1 - tracks.add_edge(edge) - assert tracks.graph.number_of_edges() == num_edges - assert pytest.approx(tracks.get_iou(edge), abs=0.01) == iou - - edges = [(1, 3), (1, 2)] - tracks.remove_edges(edges) - assert tracks.graph.number_of_edges() == num_edges - 2 - - with pytest.raises(KeyError): - tracks.remove_edge((1, 2)) - - with pytest.raises(KeyError): - tracks.remove_edges([(1, 3), (1, 2)]) - - # with pytest.raises(ValueError): - # # TODO: what happens if you add a duplicate edge? remove a nonexisting edge? 
- # tracks.add_edge(edge) - - -def test_pixels_and_seg_id(graph_3d, segmentation_3d): - # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) - - # changing a segmentation id changes it in the mapping - pix = tracks.get_pixels([1]) - new_seg_id = 10 - tracks.set_pixels(pix, [new_seg_id]) - - with pytest.raises(KeyError): - tracks.get_positions(["0"]) - - -def test_update_segmentations(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - - # remove pixels from a segmentation - nodes = [1] - edge = (1, 3) - current_pix = tracks.get_pixels(nodes) - print(current_pix, segmentation_2d.ndim) - areas = tracks.get_areas(nodes) - iou = tracks.get_iou(edge) - # get the first 5 pixels of each segmentation - pix_to_remove = [ - tuple(pix[dim][0:5] for dim in range(segmentation_2d.ndim)) - for pix in current_pix - ] - tracks.update_segmentations(nodes, pix_to_remove, added=False) - - # there are 5 different pixels for each node - assert np.sum(segmentation_2d != tracks.segmentation) == len(nodes) * 5 - - # the areas have updated - for node, area in zip(nodes, areas, strict=False): - assert tracks.get_area(node) == area - 5 - - # the edge IOUs have updated - assert tracks.get_iou(edge) < iou - - # add pixels back to the segmentation - tracks.update_segmentations(nodes, pix_to_remove, added=True) - assert np.sum(segmentation_2d != tracks.segmentation) == 0 - - # the areas have updated - for node, area in zip(nodes, areas, strict=False): - assert tracks.get_area(node) == area - - # the edge IOUs have updated - assert tracks.get_iou(edge) == pytest.approx(iou, abs=0.01) diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py deleted file mode 100644 index 383883b1..00000000 --- a/tests/data_model/test_tracks_controller.py +++ /dev/null @@ -1,306 +0,0 @@ -import numpy as np -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr - 
-from motile_tracker.data_model.solution_tracks import SolutionTracks -from motile_tracker.data_model.tracks_controller import TracksController - - -def test__add_nodes_no_seg(graph_2d): - # add without segmentation - tracks = SolutionTracks(graph_2d, ndim=3) - controller = TracksController(tracks) - - num_edges = tracks.graph.number_of_edges() - - # start a new track with multiple nodes - attrs = { - NodeAttr.TIME.value: [0, 1], - NodeAttr.POS.value: np.array([[1, 3], [1, 3]]), - NodeAttr.TRACK_ID.value: [6, 6], - } - - action, node_ids = controller._add_nodes(attrs) - - node = node_ids[0] - assert tracks.graph.has_node(node) - assert tracks.get_position(node) == [1, 3] - assert tracks.get_track_id(node) == 6 - - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added - - # add nodes to end of existing track - attrs = { - NodeAttr.TIME.value: [2, 3], - NodeAttr.POS.value: np.array([[1, 3], [1, 3]]), - NodeAttr.TRACK_ID.value: [2, 2], - } - - action, node_ids = controller._add_nodes(attrs) - - node1 = node_ids[0] - node2 = node_ids[1] - assert tracks.get_position(node1) == [1, 3] - assert tracks.get_track_id(node1) == 2 - assert tracks.graph.has_edge(2, node1) - assert tracks.graph.has_edge(node1, node2) - - # add node to middle of existing track - attrs = { - NodeAttr.TIME.value: [3], - NodeAttr.POS.value: np.array([[1, 3]]), - NodeAttr.TRACK_ID.value: [3], - } - - action, node_ids = controller._add_nodes(attrs) - - node = node_ids[0] - assert tracks.get_position(node) == [1, 3] - assert tracks.get_track_id(node) == 3 - - assert tracks.graph.has_edge(4, node) - assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) - - -def test__add_nodes_with_seg(graph_2d, segmentation_2d): - # add with segmentation - tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) - controller = TracksController(tracks) - - num_edges = tracks.graph.number_of_edges() - - new_seg = segmentation_2d.copy() - time = 0 - track_id = 6 - node1 = 7 
- node2 = 8 - new_seg[time : time + 1, 90:100, 0:4] = node1 - new_seg[time + 1 : time + 2, 90:100, 0:4] = node2 - expected_center = [94.5, 1.5] - # start a new track - attrs = { - NodeAttr.TIME.value: [time, time + 1], - NodeAttr.TRACK_ID.value: [track_id, track_id], - "node_id": [node1, node2], - } - - loc_pix = np.where(new_seg[time] == node1) - time_pix = np.ones_like(loc_pix[0]) * time - time_pix2 = np.ones_like(loc_pix[0]) * (time + 1) - pixels = [ - (time_pix, *loc_pix), - (time_pix2, *loc_pix), - ] # TODO: get time from pixels? - - action, node_ids = controller._add_nodes(attrs, pixels=pixels) - - node1, node2 = node_ids - assert tracks.get_time(node1) == 0 - assert tracks.get_position(node1) == expected_center - assert tracks.get_track_id(node1) == 6 - assert tracks.get_time(node2) == 1 - assert tracks.get_position(node2) == expected_center - assert tracks.get_track_id(node2) == 6 - assert np.sum(tracks.segmentation != new_seg) == 0 - - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added - - # add nodes to end of existing track - time = 2 - track_id = 2 - node1 = 9 - node2 = 10 - new_seg[time : time + 1, 0:10, 0:4] = node1 - new_seg[time + 1 : time + 2, 0:10, 0:4] = node2 - expected_center = [4.5, 1.5] - # start a new track - attrs = { - NodeAttr.TIME.value: [time, time + 1], - NodeAttr.TRACK_ID.value: [track_id, track_id], - "node_id": [node1, node2], - } - - loc_pix = np.where(new_seg[time] == node1) - time_pix = np.ones_like(loc_pix[0]) * time - time_pix2 = np.ones_like(loc_pix[0]) * (time + 1) - pixels = [(time_pix, *loc_pix), (time_pix2, *loc_pix)] - - action, node_ids = controller._add_nodes(attrs, pixels) - print(node_ids, pixels) - - node = node_ids[0] - assert tracks.get_position(node) == expected_center - assert tracks.get_track_id(node) == 2 - assert np.sum(tracks.segmentation != new_seg) == 0 - - assert tracks.graph.has_edge(2, node) - assert tracks.graph.has_edge(node, node_ids[1]) - - # add node to middle of existing track 
- time = 3 - track_id = 3 - node1 = 11 - new_seg[time, 0:10, 0:4] = node1 - expected_center = [4.5, 1.5] - attrs = { - NodeAttr.TIME.value: [time], - NodeAttr.TRACK_ID.value: [track_id], - "node_id": [node1], - } - - loc_pix = np.where(new_seg[time] == node1) - time_pix = np.ones_like(loc_pix[0]) * time - pixels = [(time_pix, *loc_pix)] - - action, node_ids = controller._add_nodes(attrs, pixels=pixels) - - node = node_ids[0] - assert tracks.get_position(node) == expected_center - assert tracks.get_track_id(node) == 3 - assert np.sum(tracks.segmentation != new_seg) == 0 - - assert tracks.graph.has_edge(4, node) - assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) - - -def test__delete_nodes_no_seg(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() - - # delete unconnected node - node = 6 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert tracks.graph.number_of_edges() == num_edges - action.inverse() - - # delete end node - node = 5 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert not tracks.graph.has_edge(4, node) - action.inverse() - - # delete continuation node - node = 4 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert not tracks.graph.has_edge(3, node) - assert not tracks.graph.has_edge(node, 5) - assert tracks.graph.has_edge(3, 5) - assert tracks.get_track_id(5) == 3 - action.inverse() - - # delete div parent - node = 1 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert not tracks.graph.has_edge(node, 2) - assert not tracks.graph.has_edge(node, 3) - action.inverse() - - # delete div child - node = 3 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert tracks.get_track_id(2) == 1 # update track id for other child - - -def 
test__delete_nodes_with_seg(graph_2d, segmentation_2d): - tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) - controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() - - # delete unconnected node - node = 6 - track_id = 6 - time = 4 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert tracks.graph.number_of_edges() == num_edges - action.inverse() - - # delete end node - node = 5 - track_id = 3 - time = 4 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(4, node) - action.inverse() - - # delete continuation node - node = 4 - track_id = 3 - time = 2 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(3, node) - assert not tracks.graph.has_edge(node, 5) - assert tracks.graph.has_edge(3, 5) - assert tracks.get_track_id(5) == 3 - action.inverse() - - # delete div parent - node = 1 - track_id = 1 - time = 0 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(node, 2) - assert not tracks.graph.has_edge(node, 3) - action.inverse() - - # delete div child - node = 2 - track_id = 2 - time = 1 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert tracks.get_track_id(3) == 1 # update track id for other child - assert tracks.get_track_id(5) == 1 # update track id for other child - - -def test__add_remove_edges_no_seg(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - controller = TracksController(tracks) - num_edges = 
tracks.graph.number_of_edges() - - # delete continuation edge - edge = (3, 4) - track_id = 3 - controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) != track_id # relabeled the rest of the track - assert tracks.graph.number_of_edges() == num_edges - 1 - - # add back in continuation edge - controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) == track_id # track id was changed back - assert tracks.graph.number_of_edges() == num_edges - - # delete division edge - edge = (1, 3) - track_id = 3 - controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal - assert tracks.get_track_id(2) == 1 # but do relabel the sibling - assert tracks.graph.number_of_edges() == num_edges - 1 - - # add back in division edge - edge = (1, 3) - track_id = 3 - controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal - assert tracks.get_track_id(2) != 1 # give sibling new id again (not necessarily 2) - assert tracks.graph.number_of_edges() == num_edges diff --git a/tests/motile_plugin/utils/test_tree_widget_utils.py b/tests/motile_plugin/utils/test_tree_widget_utils.py deleted file mode 100644 index 820fd010..00000000 --- a/tests/motile_plugin/utils/test_tree_widget_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -import napari -import pandas as pd -from motile_toolbox.visualization.napari_utils import assign_tracklet_ids - -from motile_tracker.data_model import SolutionTracks -from motile_tracker.data_views.views.tree_view.tree_widget_utils import ( - extract_sorted_tracks, -) - - -def test_track_df(graph_2d): - tracks = SolutionTracks(graph=graph_2d, ndim=3) - - assert tracks.get_area(1) == 1245 - assert tracks.get_area(2) is None - - tracks.graph, _, _ = assign_tracklet_ids(tracks.graph) - - colormap = 
napari.utils.colormaps.label_colormap( - 49, - seed=0.5, - background_value=0, - ) - - track_df = extract_sorted_tracks(tracks, colormap) - assert isinstance(track_df, pd.DataFrame) - assert track_df.loc[track_df["node_id"] == 1, "area"].values[0] == 1245 - assert track_df.loc[track_df["node_id"] == 2, "area"].values[0] == 0 - assert track_df["area"].notna().all()